koprogo_api/infrastructure/web/
middleware.rs

1use crate::infrastructure::web::app_state::AppState;
2// Note: Rate limiting is configured in main.rs using actix_governor
3// The actix_governor imports are kept in main.rs, not here
4use actix_web::{
5    body::MessageBody,
6    dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
7    error::ErrorUnauthorized,
8    http::StatusCode,
9    web, Error, FromRequest, HttpRequest, HttpResponse,
10};
11use std::collections::HashMap;
12use std::future::{ready, Future, Ready};
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use uuid::Uuid;
17
18/// Authenticated user claims extracted from JWT token
19///
20/// This struct automatically extracts and validates JWT tokens from the Authorization header.
21/// Use it as a parameter in your handler functions to require authentication:
22///
23/// ```rust,ignore
24/// use actix_web::Responder;
25/// use koprogo_api::infrastructure::web::middleware::AuthenticatedUser;
26///
27/// async fn protected_handler(claims: AuthenticatedUser) -> impl Responder {
28///     // claims.user_id and claims.organization_id are now available
29/// }
30/// ```
31#[derive(Debug, Clone)]
32pub struct AuthenticatedUser {
33    pub user_id: Uuid,
34    pub email: String,
35    pub role: String,
36    pub role_id: Option<Uuid>,
37    pub organization_id: Option<Uuid>,
38}
39
40impl AuthenticatedUser {
41    /// Get the organization_id or return an error if not present
42    pub fn require_organization(&self) -> Result<Uuid, Error> {
43        self.organization_id
44            .ok_or_else(|| ErrorUnauthorized("User does not belong to an organization"))
45    }
46
47    /// Check if user is superadmin (can access all organizations)
48    pub fn is_superadmin(&self) -> bool {
49        self.role == "superadmin"
50    }
51
52    /// Get effective organization_id for filtering:
53    /// - SuperAdmin: None (sees everything)
54    /// - Others: Some(org_id)
55    pub fn effective_org_filter(&self) -> Option<Uuid> {
56        if self.is_superadmin() {
57            None
58        } else {
59            self.organization_id
60        }
61    }
62
63    /// Verify that a resource's organization matches the user's organization.
64    /// SuperAdmin bypasses this check.
65    /// Returns Ok(()) if access is allowed, Err(message) if denied.
66    pub fn verify_org_access(&self, resource_org_id: Uuid) -> Result<(), String> {
67        if self.is_superadmin() {
68            return Ok(());
69        }
70        match self.organization_id {
71            Some(user_org_id) if user_org_id == resource_org_id => Ok(()),
72            Some(_) => Err("Access denied: resource belongs to another organization".to_string()),
73            None => Err("User does not belong to an organization".to_string()),
74        }
75    }
76}
77
78impl FromRequest for AuthenticatedUser {
79    type Error = Error;
80    type Future = Ready<Result<Self, Self::Error>>;
81
82    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
83        // Get AppState from request
84        let app_state = match req.app_data::<web::Data<AppState>>() {
85            Some(state) => state,
86            None => return ready(Err(ErrorUnauthorized("Internal server error"))),
87        };
88
89        // Extract Authorization header
90        let auth_header = match req.headers().get("Authorization") {
91            Some(header) => match header.to_str() {
92                Ok(s) => s,
93                Err(_) => return ready(Err(ErrorUnauthorized("Invalid authorization header"))),
94            },
95            None => return ready(Err(ErrorUnauthorized("Missing authorization header"))),
96        };
97
98        // Extract token from "Bearer <token>"
99        let token = auth_header.trim_start_matches("Bearer ").trim();
100
101        // Verify token and extract claims
102        match app_state.auth_use_cases.verify_token(token) {
103            Ok(claims) => {
104                // Parse user_id from claims.sub
105                match Uuid::parse_str(&claims.sub) {
106                    Ok(user_id) => ready(Ok(AuthenticatedUser {
107                        user_id,
108                        email: claims.email,
109                        role: claims.role,
110                        role_id: claims.role_id,
111                        organization_id: claims.organization_id,
112                    })),
113                    Err(_) => ready(Err(ErrorUnauthorized("Invalid user ID in token"))),
114                }
115            }
116            Err(e) => ready(Err(ErrorUnauthorized(e))),
117        }
118    }
119}
120
121/// Organization ID extracted from authenticated user's JWT token
122///
123/// This extractor requires that the user belongs to an organization.
124/// Use it when you need to enforce organization-scoped operations:
125///
126/// ```rust,ignore
127/// use actix_web::{Responder, web};
128/// use koprogo_api::application::dto::CreateBuildingDto;
129/// use koprogo_api::infrastructure::web::middleware::OrganizationId;
130///
131/// async fn create_building(
132///     organization: OrganizationId,
133///     dto: web::Json<CreateBuildingDto>
134/// ) -> impl Responder {
135///     // organization.0 contains the Uuid
136/// }
137/// ```
138#[derive(Debug, Clone, Copy)]
139pub struct OrganizationId(pub Uuid);
140
141impl FromRequest for OrganizationId {
142    type Error = Error;
143    type Future = Ready<Result<Self, Self::Error>>;
144
145    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
146        // First extract AuthenticatedUser
147        let user_future = AuthenticatedUser::from_request(req, payload);
148
149        // Get the result
150        match user_future.into_inner() {
151            Ok(user) => match user.organization_id {
152                Some(org_id) => ready(Ok(OrganizationId(org_id))),
153                None => ready(Err(ErrorUnauthorized(
154                    "User does not belong to an organization",
155                ))),
156            },
157            Err(e) => ready(Err(e)),
158        }
159    }
160}
161
162// ========================================
163// GDPR Rate Limiting Middleware
164// ========================================
165
/// Limits for the GDPR rate limiter: at most `max_requests` requests per user
/// within each `window_duration`.
#[derive(Clone, Debug)]
pub struct GdprRateLimitConfig {
    /// Maximum number of requests allowed per window
    pub max_requests: usize,
    /// Duration of the rate limit window
    pub window_duration: Duration,
}

impl Default for GdprRateLimitConfig {
    /// Ten requests per one-hour window.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        const REQUESTS_PER_WINDOW: usize = 10;

        GdprRateLimitConfig {
            window_duration: ONE_HOUR,
            max_requests: REQUESTS_PER_WINDOW,
        }
    }
}
183
184/// Rate limit state tracking
185#[derive(Clone)]
186pub struct GdprRateLimitState {
187    state: Arc<Mutex<HashMap<String, (usize, Instant)>>>,
188    config: GdprRateLimitConfig,
189}
190
191impl GdprRateLimitState {
192    pub fn new(config: GdprRateLimitConfig) -> Self {
193        Self {
194            state: Arc::new(Mutex::new(HashMap::new())),
195            config,
196        }
197    }
198
199    /// Check if user has exceeded rate limit
200    pub fn check_rate_limit(&self, user_id: &str) -> Result<(), String> {
201        let mut state = self.state.lock().unwrap();
202        let now = Instant::now();
203        let entry = state.entry(user_id.to_string()).or_insert((0, now));
204        let (count, window_start) = entry;
205
206        // Reset window if expired
207        if now.duration_since(*window_start) > self.config.window_duration {
208            *count = 0;
209            *window_start = now;
210        }
211
212        // Check limit
213        if *count >= self.config.max_requests {
214            let reset_in = self
215                .config
216                .window_duration
217                .saturating_sub(now.duration_since(*window_start));
218            return Err(format!(
219                "Rate limit exceeded. Try again in {} seconds.",
220                reset_in.as_secs()
221            ));
222        }
223
224        *count += 1;
225        Ok(())
226    }
227}
228
229/// GDPR-specific rate limiting middleware
230///
231/// Only applies rate limits to GDPR-related endpoints:
232/// - `/api/v1/gdpr/*`
233/// - `/api/v1/admin/gdpr/*`
234#[derive(Clone)]
235pub struct GdprRateLimit {
236    state: GdprRateLimitState,
237}
238
239impl GdprRateLimit {
240    pub fn new(config: GdprRateLimitConfig) -> Self {
241        Self {
242            state: GdprRateLimitState::new(config),
243        }
244    }
245}
246
247impl<S, B> Transform<S, ServiceRequest> for GdprRateLimit
248where
249    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
250    S::Future: 'static,
251    B: MessageBody + 'static,
252{
253    type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
254    type Error = Error;
255    type InitError = ();
256    type Transform = GdprRateLimitMiddleware<S>;
257    type Future = Ready<Result<Self::Transform, Self::InitError>>;
258
259    fn new_transform(&self, service: S) -> Self::Future {
260        ready(Ok(GdprRateLimitMiddleware {
261            service: Arc::new(service),
262            state: self.state.clone(),
263        }))
264    }
265}
266
267pub struct GdprRateLimitMiddleware<S> {
268    service: Arc<S>,
269    state: GdprRateLimitState,
270}
271
272impl<S, B> Service<ServiceRequest> for GdprRateLimitMiddleware<S>
273where
274    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
275    S::Future: 'static,
276    B: MessageBody + 'static,
277{
278    type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
279    type Error = Error;
280    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
281
282    forward_ready!(service);
283
284    fn call(&self, req: ServiceRequest) -> Self::Future {
285        let path = req.path().to_string();
286
287        // Only apply rate limiting to GDPR endpoints
288        let is_gdpr_endpoint =
289            path.starts_with("/api/v1/gdpr") || path.starts_with("/api/v1/admin/gdpr");
290
291        if !is_gdpr_endpoint {
292            let fut = self.service.call(req);
293            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
294        }
295
296        // Extract user_id from AuthenticatedUser
297        let user_id = match req.app_data::<web::Data<AppState>>() {
298            Some(app_state) => {
299                // Extract Authorization header
300                let auth_header = match req.headers().get("Authorization") {
301                    Some(header) => match header.to_str() {
302                        Ok(s) => s.to_string(),
303                        Err(_) => {
304                            // Let the handler deal with invalid auth
305                            let fut = self.service.call(req);
306                            return Box::pin(async move {
307                                fut.await.map(|res| res.map_into_left_body())
308                            });
309                        }
310                    },
311                    None => {
312                        // Let the handler deal with missing auth
313                        let fut = self.service.call(req);
314                        return Box::pin(
315                            async move { fut.await.map(|res| res.map_into_left_body()) },
316                        );
317                    }
318                };
319
320                let token = auth_header.trim_start_matches("Bearer ").trim();
321
322                match app_state.auth_use_cases.verify_token(token) {
323                    Ok(claims) => claims.sub,
324                    Err(_) => {
325                        // Let the handler deal with invalid token
326                        let fut = self.service.call(req);
327                        return Box::pin(
328                            async move { fut.await.map(|res| res.map_into_left_body()) },
329                        );
330                    }
331                }
332            }
333            None => {
334                let fut = self.service.call(req);
335                return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
336            }
337        };
338
339        // Check rate limit
340        let state = self.state.clone();
341        let service = self.service.clone();
342
343        Box::pin(async move {
344            match state.check_rate_limit(&user_id) {
345                Ok(_) => {
346                    // Rate limit not exceeded, proceed with request
347                    service.call(req).await.map(|res| res.map_into_left_body())
348                }
349                Err(msg) => {
350                    // Rate limit exceeded, return 429
351                    let retry_after = state.config.window_duration.as_secs().to_string();
352                    let response = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
353                        .insert_header(("Retry-After", retry_after.clone()))
354                        .json(serde_json::json!({
355                            "error": msg,
356                            "retry_after_seconds": state.config.window_duration.as_secs()
357                        }));
358
359                    Ok(req.into_response(response).map_into_right_body())
360                }
361            }
362        })
363    }
364}
365
366// ========================================
367// Global Rate Limiting (Issue #78)
368// ========================================
369//
370// Rate limiting is configured directly in main.rs using GovernorConfigBuilder.
371// Three-tier strategy:
372// 1. Public endpoints: 100 req/min per IP (DDoS prevention)
373// 2. Authenticated endpoints: 1000 req/min per IP (higher trust, still IP-based for simplicity)
374// 3. Login endpoint: 5 attempts per 15min per IP (brute-force prevention via LoginRateLimiter)
375
#[cfg(test)]
mod tests {
    use super::*;

    /// An `"admin"` user with a random id and the given organization membership.
    fn admin_with_org(organization_id: Option<Uuid>) -> AuthenticatedUser {
        AuthenticatedUser {
            user_id: Uuid::new_v4(),
            email: "test@example.com".to_string(),
            role: "admin".to_string(),
            role_id: None,
            organization_id,
        }
    }

    /// A limiter allowing `max_requests` calls per 60-second window.
    fn limiter(max_requests: usize) -> GdprRateLimitState {
        GdprRateLimitState::new(GdprRateLimitConfig {
            max_requests,
            window_duration: Duration::from_secs(60),
        })
    }

    #[test]
    fn test_authenticated_user_require_organization() {
        assert!(admin_with_org(Some(Uuid::new_v4()))
            .require_organization()
            .is_ok());
        assert!(admin_with_org(None).require_organization().is_err());
    }

    #[test]
    fn test_gdpr_rate_limit_config_default() {
        let GdprRateLimitConfig {
            max_requests,
            window_duration,
        } = GdprRateLimitConfig::default();
        assert_eq!(max_requests, 10);
        assert_eq!(window_duration, Duration::from_secs(3600));
    }

    #[test]
    fn test_gdpr_rate_limit_state_allows_within_limit() {
        let state = limiter(3);
        for _ in 0..3 {
            assert!(state.check_rate_limit("user1").is_ok());
        }
    }

    #[test]
    fn test_gdpr_rate_limit_state_blocks_exceeding_limit() {
        let state = limiter(2);
        (0..2).for_each(|_| assert!(state.check_rate_limit("user1").is_ok()));

        let message = state
            .check_rate_limit("user1")
            .expect_err("the third request within the window must be rejected");
        assert!(message.contains("Rate limit exceeded. Try again in"));
    }
}