koprogo_api/infrastructure/web/login_rate_limiter.rs

1use actix_web::{
2    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
3    Error, HttpResponse,
4};
5use futures_util::future::LocalBoxFuture;
6use std::collections::HashMap;
7use std::future::{ready, Ready};
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10
/// Login attempt tracking for a single client IP address.
///
/// Uses a fixed window: the window opens at `first_attempt` and lasts for the
/// limiter's configured duration. Once it has elapsed the record is expired and
/// the next attempt starts a fresh record.
#[derive(Debug, Clone)]
struct LoginAttempts {
    /// Attempts allowed through in the current window.
    count: u32,
    /// When the current window opened.
    first_attempt: Instant,
}

impl LoginAttempts {
    /// Opens a new window, counting the current attempt as the first one.
    fn new() -> Self {
        Self {
            count: 1,
            first_attempt: Instant::now(),
        }
    }

    /// Records one more allowed attempt in the current window.
    fn increment(&mut self) {
        // Saturate instead of overflowing (a plain `+= 1` panics in debug builds).
        self.count = self.count.saturating_add(1);
    }

    /// Returns `true` once the window that opened at `first_attempt` has elapsed.
    ///
    /// Expiry is anchored on the *first* attempt on purpose. Anchoring it on the
    /// most recent attempt (while `is_rate_limited` looked at the first one) let a
    /// client that retried more often than once per `window` keep its record alive
    /// forever after the first window passed. That record was never limited again,
    /// which meant unlimited login attempts.
    fn is_expired(&self, window: Duration) -> bool {
        self.first_attempt.elapsed() > window
    }

    /// Returns `true` when `max_attempts` have already been used in a window that
    /// is still open.
    fn is_rate_limited(&self, max_attempts: u32, window: Duration) -> bool {
        !self.is_expired(window) && self.count >= max_attempts
    }
}
48
49/// Login rate limiter to prevent brute-force attacks
50///
51/// Default configuration:
52/// - 5 login attempts per 15 minutes per IP
53/// - Automatic cleanup of expired entries every 5 minutes
54#[derive(Clone)]
55pub struct LoginRateLimiter {
56    store: Arc<Mutex<HashMap<String, LoginAttempts>>>,
57    max_attempts: u32,
58    window_duration: Duration,
59}
60
61impl Default for LoginRateLimiter {
62    fn default() -> Self {
63        Self::new(5, Duration::from_secs(15 * 60)) // 5 attempts per 15 minutes
64    }
65}
66
67impl LoginRateLimiter {
68    pub fn new(max_attempts: u32, window_duration: Duration) -> Self {
69        let limiter = Self {
70            store: Arc::new(Mutex::new(HashMap::new())),
71            max_attempts,
72            window_duration,
73        };
74
75        // Spawn cleanup task (simplified - in production use tokio::spawn with proper cleanup)
76        limiter.cleanup_expired_entries();
77
78        limiter
79    }
80
81    /// Check if IP is rate limited
82    pub fn check_rate_limit(&self, ip: &str) -> bool {
83        let mut store = self.store.lock().unwrap();
84
85        match store.get_mut(ip) {
86            Some(attempts) => {
87                if attempts.is_expired(self.window_duration) {
88                    // Expired, reset
89                    *attempts = LoginAttempts::new();
90                    false
91                } else if attempts.is_rate_limited(self.max_attempts, self.window_duration) {
92                    // Rate limited
93                    true
94                } else {
95                    // Within limits, increment
96                    attempts.increment();
97                    false
98                }
99            }
100            None => {
101                // First attempt
102                store.insert(ip.to_string(), LoginAttempts::new());
103                false
104            }
105        }
106    }
107
108    /// Cleanup expired entries (call periodically)
109    fn cleanup_expired_entries(&self) {
110        let store = self.store.clone();
111        let window = self.window_duration;
112
113        // In a real implementation, use tokio::spawn for async cleanup
114        // For MVP, rely on per-request cleanup in check_rate_limit
115        std::thread::spawn(move || loop {
116            std::thread::sleep(Duration::from_secs(300)); // Clean every 5 minutes
117
118            let mut store_lock = store.lock().unwrap();
119            store_lock.retain(|_, attempts| !attempts.is_expired(window));
120
121            log::debug!(
122                "Login rate limiter cleanup: {} active IPs tracked",
123                store_lock.len()
124            );
125        });
126    }
127
128    /// Get current attempt count for IP (for testing/monitoring)
129    #[allow(dead_code)]
130    pub fn get_attempt_count(&self, ip: &str) -> u32 {
131        self.store
132            .lock()
133            .unwrap()
134            .get(ip)
135            .map(|a| a.count)
136            .unwrap_or(0)
137    }
138}
139
140impl<S, B> Transform<S, ServiceRequest> for LoginRateLimiter
141where
142    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
143    S::Future: 'static,
144    B: actix_web::body::MessageBody + 'static,
145{
146    type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
147    type Error = Error;
148    type InitError = ();
149    type Transform = LoginRateLimiterMiddleware<S>;
150    type Future = Ready<Result<Self::Transform, Self::InitError>>;
151
152    fn new_transform(&self, service: S) -> Self::Future {
153        ready(Ok(LoginRateLimiterMiddleware {
154            service,
155            limiter: self.clone(),
156        }))
157    }
158}
159
160pub struct LoginRateLimiterMiddleware<S> {
161    service: S,
162    limiter: LoginRateLimiter,
163}
164
165impl<S, B> Service<ServiceRequest> for LoginRateLimiterMiddleware<S>
166where
167    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
168    S::Future: 'static,
169    B: actix_web::body::MessageBody + 'static,
170{
171    type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
172    type Error = Error;
173    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
174
175    forward_ready!(service);
176
177    fn call(&self, req: ServiceRequest) -> Self::Future {
178        // Only apply rate limiting to login endpoint
179        let path = req.path();
180        let is_login_endpoint = path == "/api/v1/auth/login" || path.ends_with("/login");
181
182        if !is_login_endpoint {
183            // Not a login endpoint, skip rate limiting
184            let fut = self.service.call(req);
185            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
186        }
187
188        // Extract IP address
189        let ip = req
190            .connection_info()
191            .realip_remote_addr()
192            .unwrap_or("unknown")
193            .to_string();
194
195        // Check rate limit
196        let is_limited = self.limiter.check_rate_limit(&ip);
197
198        if is_limited {
199            log::warn!("Login rate limit exceeded for IP: {}", ip);
200
201            // Return 429 Too Many Requests
202            let response = HttpResponse::TooManyRequests()
203                .insert_header(("Retry-After", "900"))
204                .json(serde_json::json!({
205                    "error": "Too many login attempts. Please try again in 15 minutes.",
206                    "retry_after": 900 // 15 minutes in seconds
207                }));
208
209            return Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) });
210        }
211
212        // Not rate limited, proceed
213        let fut = self.service.call(req);
214        Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_login_rate_limiter() {
        let ip = "192.168.1.1";
        let limiter = LoginRateLimiter::new(5, Duration::from_secs(60));

        // With a limit of 5, attempts 1 through 5 are let through...
        (1..=5).for_each(|attempt| {
            let limited = limiter.check_rate_limit(ip);
            assert!(!limited, "Attempt {} should be allowed", attempt);
        });

        // ...while the 6th is rejected and is not added to the count.
        let sixth_limited = limiter.check_rate_limit(ip);
        assert!(sixth_limited, "Attempt 6 should be rate limited");
        assert_eq!(limiter.get_attempt_count(ip), 5);
    }

    #[test]
    fn test_rate_limiter_expiration() {
        let ip = "192.168.1.2";
        let window = Duration::from_millis(100);
        let limiter = LoginRateLimiter::new(2, window);

        // Spend both allowed attempts; the third one must be refused.
        for _ in 0..2 {
            limiter.check_rate_limit(ip);
        }
        assert!(limiter.check_rate_limit(ip), "Should be rate limited");

        // After the window has passed, the IP starts over.
        std::thread::sleep(window + Duration::from_millis(50));
        assert!(
            !limiter.check_rate_limit(ip),
            "Should be allowed after expiration"
        );
    }

    #[test]
    fn test_different_ips_independent() {
        let limiter = LoginRateLimiter::new(2, Duration::from_secs(60));
        let (ip1, ip2) = ("192.168.1.1", "192.168.1.2");

        // Exhaust the first IP's budget.
        limiter.check_rate_limit(ip1);
        limiter.check_rate_limit(ip1);

        // The second IP is tracked separately.
        assert!(!limiter.check_rate_limit(ip2), "IP2 should be independent");
    }
}