koprogo_api/infrastructure/web/middleware.rs

1use crate::infrastructure::web::app_state::AppState;
// Note: global (per-IP) rate limiting is configured in main.rs with actix_governor,
// so the actix_governor imports live there; this file only adds the GDPR-specific
// per-user limiter defined below.
4use actix_web::{
5    body::MessageBody,
6    dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
7    error::ErrorUnauthorized,
8    http::StatusCode,
9    web, Error, FromRequest, HttpRequest, HttpResponse,
10};
11use std::collections::HashMap;
12use std::future::{ready, Future, Ready};
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use uuid::Uuid;
17
18/// Authenticated user claims extracted from JWT token
19///
20/// This struct automatically extracts and validates JWT tokens from the Authorization header.
21/// Use it as a parameter in your handler functions to require authentication:
22///
23/// ```rust,ignore
24/// use actix_web::Responder;
25/// use koprogo_api::infrastructure::web::middleware::AuthenticatedUser;
26///
27/// async fn protected_handler(claims: AuthenticatedUser) -> impl Responder {
28///     // claims.user_id and claims.organization_id are now available
29/// }
30/// ```
31#[derive(Debug, Clone)]
32pub struct AuthenticatedUser {
33    pub user_id: Uuid,
34    pub email: String,
35    pub role: String,
36    pub role_id: Option<Uuid>,
37    pub organization_id: Option<Uuid>,
38}
39
40impl AuthenticatedUser {
41    /// Get the organization_id or return an error if not present
42    pub fn require_organization(&self) -> Result<Uuid, Error> {
43        self.organization_id
44            .ok_or_else(|| ErrorUnauthorized("User does not belong to an organization"))
45    }
46}
47
48impl FromRequest for AuthenticatedUser {
49    type Error = Error;
50    type Future = Ready<Result<Self, Self::Error>>;
51
52    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
53        // Get AppState from request
54        let app_state = match req.app_data::<web::Data<AppState>>() {
55            Some(state) => state,
56            None => return ready(Err(ErrorUnauthorized("Internal server error"))),
57        };
58
59        // Extract Authorization header
60        let auth_header = match req.headers().get("Authorization") {
61            Some(header) => match header.to_str() {
62                Ok(s) => s,
63                Err(_) => return ready(Err(ErrorUnauthorized("Invalid authorization header"))),
64            },
65            None => return ready(Err(ErrorUnauthorized("Missing authorization header"))),
66        };
67
68        // Extract token from "Bearer <token>"
69        let token = auth_header.trim_start_matches("Bearer ").trim();
70
71        // Verify token and extract claims
72        match app_state.auth_use_cases.verify_token(token) {
73            Ok(claims) => {
74                // Parse user_id from claims.sub
75                match Uuid::parse_str(&claims.sub) {
76                    Ok(user_id) => ready(Ok(AuthenticatedUser {
77                        user_id,
78                        email: claims.email,
79                        role: claims.role,
80                        role_id: claims.role_id,
81                        organization_id: claims.organization_id,
82                    })),
83                    Err(_) => ready(Err(ErrorUnauthorized("Invalid user ID in token"))),
84                }
85            }
86            Err(e) => ready(Err(ErrorUnauthorized(e))),
87        }
88    }
89}
90
91/// Organization ID extracted from authenticated user's JWT token
92///
93/// This extractor requires that the user belongs to an organization.
94/// Use it when you need to enforce organization-scoped operations:
95///
96/// ```rust,ignore
97/// use actix_web::{Responder, web};
98/// use koprogo_api::application::dto::CreateBuildingDto;
99/// use koprogo_api::infrastructure::web::middleware::OrganizationId;
100///
101/// async fn create_building(
102///     organization: OrganizationId,
103///     dto: web::Json<CreateBuildingDto>
104/// ) -> impl Responder {
105///     // organization.0 contains the Uuid
106/// }
107/// ```
108#[derive(Debug, Clone, Copy)]
109pub struct OrganizationId(pub Uuid);
110
111impl FromRequest for OrganizationId {
112    type Error = Error;
113    type Future = Ready<Result<Self, Self::Error>>;
114
115    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
116        // First extract AuthenticatedUser
117        let user_future = AuthenticatedUser::from_request(req, payload);
118
119        // Get the result
120        match user_future.into_inner() {
121            Ok(user) => match user.organization_id {
122                Some(org_id) => ready(Ok(OrganizationId(org_id))),
123                None => ready(Err(ErrorUnauthorized(
124                    "User does not belong to an organization",
125                ))),
126            },
127            Err(e) => ready(Err(e)),
128        }
129    }
130}
131
132// ========================================
133// GDPR Rate Limiting Middleware
134// ========================================
135
/// Configuration for GDPR rate limiting.
///
/// Requests are counted per user in fixed windows. Once `max_requests`
/// requests have been accepted in the current window, further requests are
/// refused until the window is older than `window_duration`.
#[derive(Clone, Debug)]
pub struct GdprRateLimitConfig {
    /// Maximum number of requests allowed per window
    pub max_requests: usize,
    /// Duration of the rate limit window
    pub window_duration: Duration,
}

impl Default for GdprRateLimitConfig {
    /// 10 requests per user per hour.
    fn default() -> Self {
        Self {
            max_requests: 10,
            window_duration: Duration::from_secs(3600), // 1 hour
        }
    }
}

/// Round `d` up to whole seconds, so "try again in N seconds" is never too early.
fn ceil_secs(d: Duration) -> u64 {
    d.as_secs().saturating_add(u64::from(d.subsec_nanos() > 0))
}

/// Per-user rate-limit counters for GDPR endpoints.
///
/// Cloning is cheap and every clone shares the same counters, because the map
/// sits behind an `Arc<Mutex<..>>`.
#[derive(Clone)]
pub struct GdprRateLimitState {
    /// user id -> (requests accepted in the current window, start of that window).
    state: Arc<Mutex<HashMap<String, (usize, Instant)>>>,
    config: GdprRateLimitConfig,
}

impl GdprRateLimitState {
    /// Create an empty limiter that enforces `config`.
    pub fn new(config: GdprRateLimitConfig) -> Self {
        Self {
            state: Arc::new(Mutex::new(HashMap::new())),
            config,
        }
    }

    /// Count one request for `user_id`, or refuse it if the user's quota for
    /// the current window is used up. A refused request is not counted.
    ///
    /// # Errors
    ///
    /// Returns `"Rate limit exceeded. Try again in N seconds."`, where `N` is
    /// the time left in the current window rounded up to whole seconds.
    pub fn check_rate_limit(&self, user_id: &str) -> Result<(), String> {
        // A poisoned mutex only means some thread panicked while holding it. The
        // entries are plain counters and timestamps, so keep using them instead
        // of panicking on every later GDPR request.
        let mut windows = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();
        let (count, window_start) = windows.entry(user_id.to_string()).or_insert((0, now));

        // Start a fresh window once the current one has fully elapsed.
        if now.duration_since(*window_start) > self.config.window_duration {
            *count = 0;
            *window_start = now;
        }

        if *count >= self.config.max_requests {
            let reset_in = self
                .config
                .window_duration
                .saturating_sub(now.duration_since(*window_start));
            return Err(format!(
                "Rate limit exceeded. Try again in {} seconds.",
                ceil_secs(reset_in)
            ));
        }

        *count += 1;
        Ok(())
    }
}
198
199/// GDPR-specific rate limiting middleware
200///
201/// Only applies rate limits to GDPR-related endpoints:
202/// - `/api/v1/gdpr/*`
203/// - `/api/v1/admin/gdpr/*`
204#[derive(Clone)]
205pub struct GdprRateLimit {
206    state: GdprRateLimitState,
207}
208
209impl GdprRateLimit {
210    pub fn new(config: GdprRateLimitConfig) -> Self {
211        Self {
212            state: GdprRateLimitState::new(config),
213        }
214    }
215}
216
217impl<S, B> Transform<S, ServiceRequest> for GdprRateLimit
218where
219    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
220    S::Future: 'static,
221    B: MessageBody + 'static,
222{
223    type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
224    type Error = Error;
225    type InitError = ();
226    type Transform = GdprRateLimitMiddleware<S>;
227    type Future = Ready<Result<Self::Transform, Self::InitError>>;
228
229    fn new_transform(&self, service: S) -> Self::Future {
230        ready(Ok(GdprRateLimitMiddleware {
231            service: Arc::new(service),
232            state: self.state.clone(),
233        }))
234    }
235}
236
237pub struct GdprRateLimitMiddleware<S> {
238    service: Arc<S>,
239    state: GdprRateLimitState,
240}
241
242impl<S, B> Service<ServiceRequest> for GdprRateLimitMiddleware<S>
243where
244    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
245    S::Future: 'static,
246    B: MessageBody + 'static,
247{
248    type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
249    type Error = Error;
250    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
251
252    forward_ready!(service);
253
254    fn call(&self, req: ServiceRequest) -> Self::Future {
255        let path = req.path().to_string();
256
257        // Only apply rate limiting to GDPR endpoints
258        let is_gdpr_endpoint =
259            path.starts_with("/api/v1/gdpr") || path.starts_with("/api/v1/admin/gdpr");
260
261        if !is_gdpr_endpoint {
262            let fut = self.service.call(req);
263            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
264        }
265
266        // Extract user_id from AuthenticatedUser
267        let user_id = match req.app_data::<web::Data<AppState>>() {
268            Some(app_state) => {
269                // Extract Authorization header
270                let auth_header = match req.headers().get("Authorization") {
271                    Some(header) => match header.to_str() {
272                        Ok(s) => s.to_string(),
273                        Err(_) => {
274                            // Let the handler deal with invalid auth
275                            let fut = self.service.call(req);
276                            return Box::pin(async move {
277                                fut.await.map(|res| res.map_into_left_body())
278                            });
279                        }
280                    },
281                    None => {
282                        // Let the handler deal with missing auth
283                        let fut = self.service.call(req);
284                        return Box::pin(
285                            async move { fut.await.map(|res| res.map_into_left_body()) },
286                        );
287                    }
288                };
289
290                let token = auth_header.trim_start_matches("Bearer ").trim();
291
292                match app_state.auth_use_cases.verify_token(token) {
293                    Ok(claims) => claims.sub,
294                    Err(_) => {
295                        // Let the handler deal with invalid token
296                        let fut = self.service.call(req);
297                        return Box::pin(
298                            async move { fut.await.map(|res| res.map_into_left_body()) },
299                        );
300                    }
301                }
302            }
303            None => {
304                let fut = self.service.call(req);
305                return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
306            }
307        };
308
309        // Check rate limit
310        let state = self.state.clone();
311        let service = self.service.clone();
312
313        Box::pin(async move {
314            match state.check_rate_limit(&user_id) {
315                Ok(_) => {
316                    // Rate limit not exceeded, proceed with request
317                    service.call(req).await.map(|res| res.map_into_left_body())
318                }
319                Err(msg) => {
320                    // Rate limit exceeded, return 429
321                    let retry_after = state.config.window_duration.as_secs().to_string();
322                    let response = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
323                        .insert_header(("Retry-After", retry_after.clone()))
324                        .json(serde_json::json!({
325                            "error": msg,
326                            "retry_after_seconds": state.config.window_duration.as_secs()
327                        }));
328
329                    Ok(req.into_response(response).map_into_right_body())
330                }
331            }
332        })
333    }
334}
335
336// ========================================
337// Global Rate Limiting (Issue #78)
338// ========================================
339//
340// Rate limiting is configured directly in main.rs using GovernorConfigBuilder.
341// Three-tier strategy:
342// 1. Public endpoints: 100 req/min per IP (DDoS prevention)
343// 2. Authenticated endpoints: 1000 req/min per IP (higher trust, still IP-based for simplicity)
344// 3. Login endpoint: 5 attempts per 15min per IP (brute-force prevention via LoginRateLimiter)
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture user; only the organization varies between tests.
    fn user_with_organization(organization_id: Option<Uuid>) -> AuthenticatedUser {
        AuthenticatedUser {
            user_id: Uuid::new_v4(),
            email: "test@example.com".to_string(),
            role: "admin".to_string(),
            role_id: None,
            organization_id,
        }
    }

    /// Limiter allowing `max_requests` per 60-second window.
    fn limiter(max_requests: usize) -> GdprRateLimitState {
        GdprRateLimitState::new(GdprRateLimitConfig {
            max_requests,
            window_duration: Duration::from_secs(60),
        })
    }

    #[test]
    fn test_authenticated_user_require_organization() {
        let member = user_with_organization(Some(Uuid::new_v4()));
        assert!(member.require_organization().is_ok());

        let orphan = user_with_organization(None);
        assert!(orphan.require_organization().is_err());
    }

    #[test]
    fn test_gdpr_rate_limit_config_default() {
        let GdprRateLimitConfig {
            max_requests,
            window_duration,
        } = GdprRateLimitConfig::default();
        assert_eq!(max_requests, 10);
        assert_eq!(window_duration, Duration::from_secs(3600));
    }

    #[test]
    fn test_gdpr_rate_limit_state_allows_within_limit() {
        let state = limiter(3);
        for _ in 0..3 {
            assert!(state.check_rate_limit("user1").is_ok());
        }
    }

    #[test]
    fn test_gdpr_rate_limit_state_blocks_exceeding_limit() {
        let state = limiter(2);
        for _ in 0..2 {
            assert!(state.check_rate_limit("user1").is_ok());
        }
        let message = state
            .check_rate_limit("user1")
            .expect_err("third request within the window must be refused");
        assert!(message.contains("Rate limit exceeded. Try again in"));
    }
}