koprogo_api/infrastructure/web/
middleware.rs1use crate::infrastructure::web::app_state::AppState;
2use actix_web::{
5 body::MessageBody,
6 dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
7 error::ErrorUnauthorized,
8 http::StatusCode,
9 web, Error, FromRequest, HttpRequest, HttpResponse,
10};
11use std::collections::HashMap;
12use std::future::{ready, Future, Ready};
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use uuid::Uuid;
17
18#[derive(Debug, Clone)]
32pub struct AuthenticatedUser {
33 pub user_id: Uuid,
34 pub email: String,
35 pub role: String,
36 pub role_id: Option<Uuid>,
37 pub organization_id: Option<Uuid>,
38}
39
40impl AuthenticatedUser {
41 pub fn require_organization(&self) -> Result<Uuid, Error> {
43 self.organization_id
44 .ok_or_else(|| ErrorUnauthorized("User does not belong to an organization"))
45 }
46
47 pub fn is_superadmin(&self) -> bool {
49 self.role == "superadmin"
50 }
51
52 pub fn effective_org_filter(&self) -> Option<Uuid> {
56 if self.is_superadmin() {
57 None
58 } else {
59 self.organization_id
60 }
61 }
62
63 pub fn verify_org_access(&self, resource_org_id: Uuid) -> Result<(), String> {
67 if self.is_superadmin() {
68 return Ok(());
69 }
70 match self.organization_id {
71 Some(user_org_id) if user_org_id == resource_org_id => Ok(()),
72 Some(_) => Err("Access denied: resource belongs to another organization".to_string()),
73 None => Err("User does not belong to an organization".to_string()),
74 }
75 }
76}
77
78impl FromRequest for AuthenticatedUser {
79 type Error = Error;
80 type Future = Ready<Result<Self, Self::Error>>;
81
82 fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
83 let app_state = match req.app_data::<web::Data<AppState>>() {
85 Some(state) => state,
86 None => return ready(Err(ErrorUnauthorized("Internal server error"))),
87 };
88
89 let auth_header = match req.headers().get("Authorization") {
91 Some(header) => match header.to_str() {
92 Ok(s) => s,
93 Err(_) => return ready(Err(ErrorUnauthorized("Invalid authorization header"))),
94 },
95 None => return ready(Err(ErrorUnauthorized("Missing authorization header"))),
96 };
97
98 let token = auth_header.trim_start_matches("Bearer ").trim();
100
101 match app_state.auth_use_cases.verify_token(token) {
103 Ok(claims) => {
104 match Uuid::parse_str(&claims.sub) {
106 Ok(user_id) => ready(Ok(AuthenticatedUser {
107 user_id,
108 email: claims.email,
109 role: claims.role,
110 role_id: claims.role_id,
111 organization_id: claims.organization_id,
112 })),
113 Err(_) => ready(Err(ErrorUnauthorized("Invalid user ID in token"))),
114 }
115 }
116 Err(e) => ready(Err(ErrorUnauthorized(e))),
117 }
118 }
119}
120
121#[derive(Debug, Clone, Copy)]
139pub struct OrganizationId(pub Uuid);
140
141impl FromRequest for OrganizationId {
142 type Error = Error;
143 type Future = Ready<Result<Self, Self::Error>>;
144
145 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
146 let user_future = AuthenticatedUser::from_request(req, payload);
148
149 match user_future.into_inner() {
151 Ok(user) => match user.organization_id {
152 Some(org_id) => ready(Ok(OrganizationId(org_id))),
153 None => ready(Err(ErrorUnauthorized(
154 "User does not belong to an organization",
155 ))),
156 },
157 Err(e) => ready(Err(e)),
158 }
159 }
160}
161
/// Tunables for the GDPR endpoint rate limiter.
#[derive(Clone, Debug)]
pub struct GdprRateLimitConfig {
    /// Maximum number of requests one user may make within a window.
    pub max_requests: usize,
    /// Length of the fixed window after which a user's counter starts over.
    pub window_duration: Duration,
}

impl Default for GdprRateLimitConfig {
    /// 10 requests per hour.
    fn default() -> Self {
        Self {
            max_requests: 10,
            window_duration: Duration::from_secs(3600),
        }
    }
}

/// Fixed-window request counters keyed by user id.
///
/// Clones share the same counters (they live behind an `Arc<Mutex<_>>`).
#[derive(Clone)]
pub struct GdprRateLimitState {
    /// user id -> (requests counted in the current window, window start).
    state: Arc<Mutex<HashMap<String, (usize, Instant)>>>,
    config: GdprRateLimitConfig,
}

impl GdprRateLimitState {
    /// Creates a limiter with no recorded requests.
    pub fn new(config: GdprRateLimitConfig) -> Self {
        Self {
            state: Arc::new(Mutex::new(HashMap::new())),
            config,
        }
    }

    /// Records one request for `user_id` if it is still within its quota.
    ///
    /// A user's window starts at their first request and lasts
    /// `config.window_duration`; rejected requests are not counted.
    ///
    /// # Errors
    /// Returns `"Rate limit exceeded. Try again in N seconds."` when the user
    /// already made `config.max_requests` requests in the current window.
    pub fn check_rate_limit(&self, user_id: &str) -> Result<(), String> {
        // The map only holds counters, so data left by a panicking holder is still
        // usable; recovering avoids panicking on every later request.
        let mut counters = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();
        let window = self.config.window_duration;

        // An expired entry would be reset on its next use anyway, so dropping it is
        // equivalent — and keeps the map from growing with every user ever seen.
        counters.retain(|_, (_, window_start)| now.duration_since(*window_start) <= window);

        let (count, window_start) = counters.entry(user_id.to_string()).or_insert((0, now));

        if *count >= self.config.max_requests {
            let reset_in = window.saturating_sub(now.duration_since(*window_start));
            return Err(format!(
                "Rate limit exceeded. Try again in {} seconds.",
                reset_in.as_secs()
            ));
        }

        *count += 1;
        Ok(())
    }
}
228
229#[derive(Clone)]
235pub struct GdprRateLimit {
236 state: GdprRateLimitState,
237}
238
239impl GdprRateLimit {
240 pub fn new(config: GdprRateLimitConfig) -> Self {
241 Self {
242 state: GdprRateLimitState::new(config),
243 }
244 }
245}
246
247impl<S, B> Transform<S, ServiceRequest> for GdprRateLimit
248where
249 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
250 S::Future: 'static,
251 B: MessageBody + 'static,
252{
253 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
254 type Error = Error;
255 type InitError = ();
256 type Transform = GdprRateLimitMiddleware<S>;
257 type Future = Ready<Result<Self::Transform, Self::InitError>>;
258
259 fn new_transform(&self, service: S) -> Self::Future {
260 ready(Ok(GdprRateLimitMiddleware {
261 service: Arc::new(service),
262 state: self.state.clone(),
263 }))
264 }
265}
266
267pub struct GdprRateLimitMiddleware<S> {
268 service: Arc<S>,
269 state: GdprRateLimitState,
270}
271
272impl<S, B> Service<ServiceRequest> for GdprRateLimitMiddleware<S>
273where
274 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
275 S::Future: 'static,
276 B: MessageBody + 'static,
277{
278 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
279 type Error = Error;
280 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
281
282 forward_ready!(service);
283
284 fn call(&self, req: ServiceRequest) -> Self::Future {
285 let path = req.path().to_string();
286
287 let is_gdpr_endpoint =
289 path.starts_with("/api/v1/gdpr") || path.starts_with("/api/v1/admin/gdpr");
290
291 if !is_gdpr_endpoint {
292 let fut = self.service.call(req);
293 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
294 }
295
296 let user_id = match req.app_data::<web::Data<AppState>>() {
298 Some(app_state) => {
299 let auth_header = match req.headers().get("Authorization") {
301 Some(header) => match header.to_str() {
302 Ok(s) => s.to_string(),
303 Err(_) => {
304 let fut = self.service.call(req);
306 return Box::pin(async move {
307 fut.await.map(|res| res.map_into_left_body())
308 });
309 }
310 },
311 None => {
312 let fut = self.service.call(req);
314 return Box::pin(
315 async move { fut.await.map(|res| res.map_into_left_body()) },
316 );
317 }
318 };
319
320 let token = auth_header.trim_start_matches("Bearer ").trim();
321
322 match app_state.auth_use_cases.verify_token(token) {
323 Ok(claims) => claims.sub,
324 Err(_) => {
325 let fut = self.service.call(req);
327 return Box::pin(
328 async move { fut.await.map(|res| res.map_into_left_body()) },
329 );
330 }
331 }
332 }
333 None => {
334 let fut = self.service.call(req);
335 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
336 }
337 };
338
339 let state = self.state.clone();
341 let service = self.service.clone();
342
343 Box::pin(async move {
344 match state.check_rate_limit(&user_id) {
345 Ok(_) => {
346 service.call(req).await.map(|res| res.map_into_left_body())
348 }
349 Err(msg) => {
350 let retry_after = state.config.window_duration.as_secs().to_string();
352 let response = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
353 .insert_header(("Retry-After", retry_after.clone()))
354 .json(serde_json::json!({
355 "error": msg,
356 "retry_after_seconds": state.config.window_duration.as_secs()
357 }));
358
359 Ok(req.into_response(response).map_into_right_body())
360 }
361 }
362 })
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user with the given role and organization for access-control tests.
    fn user(role: &str, organization_id: Option<Uuid>) -> AuthenticatedUser {
        AuthenticatedUser {
            user_id: Uuid::new_v4(),
            email: "test@example.com".to_string(),
            role: role.to_string(),
            role_id: None,
            organization_id,
        }
    }

    /// Limiter with a one-minute window and the given quota.
    fn limiter(max_requests: usize) -> GdprRateLimitState {
        GdprRateLimitState::new(GdprRateLimitConfig {
            max_requests,
            window_duration: Duration::from_secs(60),
        })
    }

    #[test]
    fn test_authenticated_user_require_organization() {
        let org_id = Uuid::new_v4();
        assert_eq!(
            user("admin", Some(org_id)).require_organization().unwrap(),
            org_id
        );
        assert!(user("admin", None).require_organization().is_err());
    }

    #[test]
    fn test_authenticated_user_verify_org_access() {
        let org_id = Uuid::new_v4();
        let member = user("admin", Some(org_id));

        assert!(member.verify_org_access(org_id).is_ok());
        assert_eq!(
            member.verify_org_access(Uuid::new_v4()).unwrap_err(),
            "Access denied: resource belongs to another organization"
        );
        assert_eq!(
            user("admin", None).verify_org_access(org_id).unwrap_err(),
            "User does not belong to an organization"
        );
        assert!(user("superadmin", None).verify_org_access(org_id).is_ok());
    }

    #[test]
    fn test_authenticated_user_effective_org_filter() {
        let org_id = Uuid::new_v4();
        assert_eq!(user("admin", Some(org_id)).effective_org_filter(), Some(org_id));
        assert_eq!(user("admin", None).effective_org_filter(), None);
        assert_eq!(user("superadmin", Some(org_id)).effective_org_filter(), None);
    }

    #[test]
    fn test_gdpr_rate_limit_config_default() {
        let config = GdprRateLimitConfig::default();
        assert_eq!(config.max_requests, 10);
        assert_eq!(config.window_duration, Duration::from_secs(3600));
    }

    #[test]
    fn test_gdpr_rate_limit_state_allows_within_limit() {
        let state = limiter(3);
        assert!(state.check_rate_limit("user1").is_ok());
        assert!(state.check_rate_limit("user1").is_ok());
        assert!(state.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_gdpr_rate_limit_state_blocks_exceeding_limit() {
        let state = limiter(2);
        assert!(state.check_rate_limit("user1").is_ok());
        assert!(state.check_rate_limit("user1").is_ok());
        let result = state.check_rate_limit("user1");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("Rate limit exceeded. Try again in"));
    }

    #[test]
    fn test_gdpr_rate_limit_state_is_per_user() {
        let state = limiter(1);
        assert!(state.check_rate_limit("user1").is_ok());
        assert!(state.check_rate_limit("user1").is_err());
        assert!(state.check_rate_limit("user2").is_ok());
    }

    #[test]
    fn test_gdpr_rate_limit_state_resets_after_window() {
        let state = limiter(1);
        // Backdate an exhausted window instead of sleeping through it.
        let Some(expired_start) = Instant::now().checked_sub(Duration::from_secs(120)) else {
            return; // monotonic clock too close to its origin to backdate
        };
        state
            .state
            .lock()
            .unwrap()
            .insert("user1".to_string(), (1, expired_start));

        assert!(state.check_rate_limit("user1").is_ok());
        assert!(state.check_rate_limit("user1").is_err());
    }
}