koprogo_api/infrastructure/web/
middleware.rs1use crate::infrastructure::web::app_state::AppState;
2use actix_web::{
5 body::MessageBody,
6 dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
7 error::ErrorUnauthorized,
8 http::StatusCode,
9 web, Error, FromRequest, HttpRequest, HttpResponse,
10};
11use std::collections::HashMap;
12use std::future::{ready, Future, Ready};
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use uuid::Uuid;
17
18#[derive(Debug, Clone)]
32pub struct AuthenticatedUser {
33 pub user_id: Uuid,
34 pub email: String,
35 pub role: String,
36 pub role_id: Option<Uuid>,
37 pub organization_id: Option<Uuid>,
38}
39
40impl AuthenticatedUser {
41 pub fn require_organization(&self) -> Result<Uuid, Error> {
43 self.organization_id
44 .ok_or_else(|| ErrorUnauthorized("User does not belong to an organization"))
45 }
46}
47
48impl FromRequest for AuthenticatedUser {
49 type Error = Error;
50 type Future = Ready<Result<Self, Self::Error>>;
51
52 fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
53 let app_state = match req.app_data::<web::Data<AppState>>() {
55 Some(state) => state,
56 None => return ready(Err(ErrorUnauthorized("Internal server error"))),
57 };
58
59 let auth_header = match req.headers().get("Authorization") {
61 Some(header) => match header.to_str() {
62 Ok(s) => s,
63 Err(_) => return ready(Err(ErrorUnauthorized("Invalid authorization header"))),
64 },
65 None => return ready(Err(ErrorUnauthorized("Missing authorization header"))),
66 };
67
68 let token = auth_header.trim_start_matches("Bearer ").trim();
70
71 match app_state.auth_use_cases.verify_token(token) {
73 Ok(claims) => {
74 match Uuid::parse_str(&claims.sub) {
76 Ok(user_id) => ready(Ok(AuthenticatedUser {
77 user_id,
78 email: claims.email,
79 role: claims.role,
80 role_id: claims.role_id,
81 organization_id: claims.organization_id,
82 })),
83 Err(_) => ready(Err(ErrorUnauthorized("Invalid user ID in token"))),
84 }
85 }
86 Err(e) => ready(Err(ErrorUnauthorized(e))),
87 }
88 }
89}
90
91#[derive(Debug, Clone, Copy)]
109pub struct OrganizationId(pub Uuid);
110
111impl FromRequest for OrganizationId {
112 type Error = Error;
113 type Future = Ready<Result<Self, Self::Error>>;
114
115 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
116 let user_future = AuthenticatedUser::from_request(req, payload);
118
119 match user_future.into_inner() {
121 Ok(user) => match user.organization_id {
122 Some(org_id) => ready(Ok(OrganizationId(org_id))),
123 None => ready(Err(ErrorUnauthorized(
124 "User does not belong to an organization",
125 ))),
126 },
127 Err(e) => ready(Err(e)),
128 }
129 }
130}
131
/// Settings for the per-user GDPR endpoint rate limiter.
#[derive(Debug, Clone)]
pub struct GdprRateLimitConfig {
    /// Maximum number of accepted requests per user within one window.
    pub max_requests: usize,
    /// Length of the fixed window after which a user's counter starts over.
    pub window_duration: Duration,
}
144
145impl Default for GdprRateLimitConfig {
146 fn default() -> Self {
147 Self {
148 max_requests: 10,
149 window_duration: Duration::from_secs(3600), }
151 }
152}
153
154#[derive(Clone)]
156pub struct GdprRateLimitState {
157 state: Arc<Mutex<HashMap<String, (usize, Instant)>>>,
158 config: GdprRateLimitConfig,
159}
160
161impl GdprRateLimitState {
162 pub fn new(config: GdprRateLimitConfig) -> Self {
163 Self {
164 state: Arc::new(Mutex::new(HashMap::new())),
165 config,
166 }
167 }
168
169 pub fn check_rate_limit(&self, user_id: &str) -> Result<(), String> {
171 let mut state = self.state.lock().unwrap();
172 let now = Instant::now();
173 let entry = state.entry(user_id.to_string()).or_insert((0, now));
174 let (count, window_start) = entry;
175
176 if now.duration_since(*window_start) > self.config.window_duration {
178 *count = 0;
179 *window_start = now;
180 }
181
182 if *count >= self.config.max_requests {
184 let reset_in = self
185 .config
186 .window_duration
187 .saturating_sub(now.duration_since(*window_start));
188 return Err(format!(
189 "Rate limit exceeded. Try again in {} seconds.",
190 reset_in.as_secs()
191 ));
192 }
193
194 *count += 1;
195 Ok(())
196 }
197}
198
199#[derive(Clone)]
205pub struct GdprRateLimit {
206 state: GdprRateLimitState,
207}
208
209impl GdprRateLimit {
210 pub fn new(config: GdprRateLimitConfig) -> Self {
211 Self {
212 state: GdprRateLimitState::new(config),
213 }
214 }
215}
216
217impl<S, B> Transform<S, ServiceRequest> for GdprRateLimit
218where
219 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
220 S::Future: 'static,
221 B: MessageBody + 'static,
222{
223 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
224 type Error = Error;
225 type InitError = ();
226 type Transform = GdprRateLimitMiddleware<S>;
227 type Future = Ready<Result<Self::Transform, Self::InitError>>;
228
229 fn new_transform(&self, service: S) -> Self::Future {
230 ready(Ok(GdprRateLimitMiddleware {
231 service: Arc::new(service),
232 state: self.state.clone(),
233 }))
234 }
235}
236
237pub struct GdprRateLimitMiddleware<S> {
238 service: Arc<S>,
239 state: GdprRateLimitState,
240}
241
242impl<S, B> Service<ServiceRequest> for GdprRateLimitMiddleware<S>
243where
244 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
245 S::Future: 'static,
246 B: MessageBody + 'static,
247{
248 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
249 type Error = Error;
250 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
251
252 forward_ready!(service);
253
254 fn call(&self, req: ServiceRequest) -> Self::Future {
255 let path = req.path().to_string();
256
257 let is_gdpr_endpoint =
259 path.starts_with("/api/v1/gdpr") || path.starts_with("/api/v1/admin/gdpr");
260
261 if !is_gdpr_endpoint {
262 let fut = self.service.call(req);
263 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
264 }
265
266 let user_id = match req.app_data::<web::Data<AppState>>() {
268 Some(app_state) => {
269 let auth_header = match req.headers().get("Authorization") {
271 Some(header) => match header.to_str() {
272 Ok(s) => s.to_string(),
273 Err(_) => {
274 let fut = self.service.call(req);
276 return Box::pin(async move {
277 fut.await.map(|res| res.map_into_left_body())
278 });
279 }
280 },
281 None => {
282 let fut = self.service.call(req);
284 return Box::pin(
285 async move { fut.await.map(|res| res.map_into_left_body()) },
286 );
287 }
288 };
289
290 let token = auth_header.trim_start_matches("Bearer ").trim();
291
292 match app_state.auth_use_cases.verify_token(token) {
293 Ok(claims) => claims.sub,
294 Err(_) => {
295 let fut = self.service.call(req);
297 return Box::pin(
298 async move { fut.await.map(|res| res.map_into_left_body()) },
299 );
300 }
301 }
302 }
303 None => {
304 let fut = self.service.call(req);
305 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
306 }
307 };
308
309 let state = self.state.clone();
311 let service = self.service.clone();
312
313 Box::pin(async move {
314 match state.check_rate_limit(&user_id) {
315 Ok(_) => {
316 service.call(req).await.map(|res| res.map_into_left_body())
318 }
319 Err(msg) => {
320 let retry_after = state.config.window_duration.as_secs().to_string();
322 let response = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
323 .insert_header(("Retry-After", retry_after.clone()))
324 .json(serde_json::json!({
325 "error": msg,
326 "retry_after_seconds": state.config.window_duration.as_secs()
327 }));
328
329 Ok(req.into_response(response).map_into_right_body())
330 }
331 }
332 })
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter that allows `max_requests` per `window_secs`-second window.
    fn limiter(max_requests: usize, window_secs: u64) -> GdprRateLimitState {
        GdprRateLimitState::new(GdprRateLimitConfig {
            max_requests,
            window_duration: Duration::from_secs(window_secs),
        })
    }

    #[test]
    fn test_authenticated_user_require_organization() {
        let org_id = Uuid::new_v4();
        let user_with_org = AuthenticatedUser {
            user_id: Uuid::new_v4(),
            email: "test@example.com".to_string(),
            role: "admin".to_string(),
            role_id: None,
            organization_id: Some(org_id),
        };

        assert_eq!(user_with_org.require_organization().unwrap(), org_id);

        let user_without_org = AuthenticatedUser {
            organization_id: None,
            ..user_with_org
        };

        assert!(user_without_org.require_organization().is_err());
    }

    #[test]
    fn test_gdpr_rate_limit_config_default() {
        let config = GdprRateLimitConfig::default();
        assert_eq!(config.max_requests, 10);
        assert_eq!(config.window_duration, Duration::from_secs(3600));
    }

    #[test]
    fn test_gdpr_rate_limit_state_allows_within_limit() {
        let state = limiter(3, 60);
        for _ in 0..3 {
            assert!(state.check_rate_limit("user1").is_ok());
        }
    }

    #[test]
    fn test_gdpr_rate_limit_state_blocks_exceeding_limit() {
        let state = limiter(2, 60);

        assert!(state.check_rate_limit("user1").is_ok());
        assert!(state.check_rate_limit("user1").is_ok());
        let result = state.check_rate_limit("user1");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("Rate limit exceeded. Try again in"));
    }

    #[test]
    fn test_gdpr_rate_limit_state_tracks_users_independently() {
        let state = limiter(1, 60);

        assert!(state.check_rate_limit("user1").is_ok());
        assert!(state.check_rate_limit("user1").is_err());
        // Exhausting user1's budget must not affect user2.
        assert!(state.check_rate_limit("user2").is_ok());
    }

    #[test]
    fn test_gdpr_rate_limit_state_clones_share_counters() {
        let original = limiter(1, 60);
        let clone = original.clone();

        assert!(original.check_rate_limit("user1").is_ok());
        // The clone sees the request recorded through the original.
        assert!(clone.check_rate_limit("user1").is_err());
    }
}