koprogo_api/infrastructure/web/
login_rate_limiter.rs1use actix_web::{
2 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
3 Error, HttpResponse,
4};
5use futures_util::future::LocalBoxFuture;
6use std::collections::HashMap;
7use std::future::{ready, Ready};
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10
/// Login attempts recorded for one client key (the callers in this file pass an
/// IP string) during a single fixed rate-limit window.
#[derive(Debug, Clone)]
struct LoginAttempts {
    /// Attempts counted in the current window.
    count: u32,
    /// Start of the current window; the window closes `window` after this instant.
    first_attempt: Instant,
}

impl LoginAttempts {
    /// Opens a new window containing a single attempt made now.
    fn new() -> Self {
        Self {
            count: 1,
            first_attempt: Instant::now(),
        }
    }

    /// Counts one more attempt in the current window.
    ///
    /// Saturates instead of overflowing (which would panic in debug builds).
    fn increment(&mut self) {
        self.count = self.count.saturating_add(1);
    }

    /// Returns `true` once the window opened by `first_attempt` is over.
    ///
    /// Both this check and `is_rate_limited` are anchored to the first attempt.
    /// If expiry were measured from the last attempt instead, then once the first
    /// attempt aged out a client making at least one attempt per `window` would
    /// be neither reset nor limited again, i.e. get unlimited login attempts.
    fn is_expired(&self, window: Duration) -> bool {
        self.first_attempt.elapsed() > window
    }

    /// Returns `true` if `max_attempts` have already been used in a window that
    /// is still open.
    fn is_rate_limited(&self, max_attempts: u32, window: Duration) -> bool {
        !self.is_expired(window) && self.count >= max_attempts
    }
}
48
49#[derive(Clone)]
55pub struct LoginRateLimiter {
56 store: Arc<Mutex<HashMap<String, LoginAttempts>>>,
57 max_attempts: u32,
58 window_duration: Duration,
59}
60
61impl Default for LoginRateLimiter {
62 fn default() -> Self {
63 Self::new(5, Duration::from_secs(15 * 60)) }
65}
66
67impl LoginRateLimiter {
68 pub fn new(max_attempts: u32, window_duration: Duration) -> Self {
69 let limiter = Self {
70 store: Arc::new(Mutex::new(HashMap::new())),
71 max_attempts,
72 window_duration,
73 };
74
75 limiter.cleanup_expired_entries();
77
78 limiter
79 }
80
81 pub fn check_rate_limit(&self, ip: &str) -> bool {
83 let mut store = self.store.lock().unwrap();
84
85 match store.get_mut(ip) {
86 Some(attempts) => {
87 if attempts.is_expired(self.window_duration) {
88 *attempts = LoginAttempts::new();
90 false
91 } else if attempts.is_rate_limited(self.max_attempts, self.window_duration) {
92 true
94 } else {
95 attempts.increment();
97 false
98 }
99 }
100 None => {
101 store.insert(ip.to_string(), LoginAttempts::new());
103 false
104 }
105 }
106 }
107
108 fn cleanup_expired_entries(&self) {
110 let store = self.store.clone();
111 let window = self.window_duration;
112
113 std::thread::spawn(move || loop {
116 std::thread::sleep(Duration::from_secs(300)); let mut store_lock = store.lock().unwrap();
119 store_lock.retain(|_, attempts| !attempts.is_expired(window));
120
121 log::debug!(
122 "Login rate limiter cleanup: {} active IPs tracked",
123 store_lock.len()
124 );
125 });
126 }
127
128 #[allow(dead_code)]
130 pub fn get_attempt_count(&self, ip: &str) -> u32 {
131 self.store
132 .lock()
133 .unwrap()
134 .get(ip)
135 .map(|a| a.count)
136 .unwrap_or(0)
137 }
138}
139
140impl<S, B> Transform<S, ServiceRequest> for LoginRateLimiter
141where
142 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
143 S::Future: 'static,
144 B: actix_web::body::MessageBody + 'static,
145{
146 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
147 type Error = Error;
148 type InitError = ();
149 type Transform = LoginRateLimiterMiddleware<S>;
150 type Future = Ready<Result<Self::Transform, Self::InitError>>;
151
152 fn new_transform(&self, service: S) -> Self::Future {
153 ready(Ok(LoginRateLimiterMiddleware {
154 service,
155 limiter: self.clone(),
156 }))
157 }
158}
159
160pub struct LoginRateLimiterMiddleware<S> {
161 service: S,
162 limiter: LoginRateLimiter,
163}
164
165impl<S, B> Service<ServiceRequest> for LoginRateLimiterMiddleware<S>
166where
167 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
168 S::Future: 'static,
169 B: actix_web::body::MessageBody + 'static,
170{
171 type Response = ServiceResponse<actix_web::body::EitherBody<B>>;
172 type Error = Error;
173 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
174
175 forward_ready!(service);
176
177 fn call(&self, req: ServiceRequest) -> Self::Future {
178 let path = req.path();
180 let is_login_endpoint = path == "/api/v1/auth/login" || path.ends_with("/login");
181
182 if !is_login_endpoint {
183 let fut = self.service.call(req);
185 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
186 }
187
188 let ip = req
190 .connection_info()
191 .realip_remote_addr()
192 .unwrap_or("unknown")
193 .to_string();
194
195 let is_limited = self.limiter.check_rate_limit(&ip);
197
198 if is_limited {
199 log::warn!("Login rate limit exceeded for IP: {}", ip);
200
201 let response = HttpResponse::TooManyRequests()
203 .insert_header(("Retry-After", "900"))
204 .json(serde_json::json!({
205 "error": "Too many login attempts. Please try again in 15 minutes.",
206 "retry_after": 900 }));
208
209 return Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) });
210 }
211
212 let fut = self.service.call(req);
214 Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_login_rate_limiter() {
        let limiter = LoginRateLimiter::new(5, Duration::from_secs(60));
        let client = "192.168.1.1";

        // The first five attempts fit inside the budget.
        for attempt in 1..=5 {
            let limited = limiter.check_rate_limit(client);
            assert!(!limited, "Attempt {} should be allowed", attempt);
        }

        // The sixth is rejected and therefore not counted.
        let sixth_limited = limiter.check_rate_limit(client);
        assert!(sixth_limited, "Attempt 6 should be rate limited");
        assert_eq!(limiter.get_attempt_count(client), 5);
    }

    #[test]
    fn test_rate_limiter_expiration() {
        let window = Duration::from_millis(100);
        let limiter = LoginRateLimiter::new(2, window);
        let client = "192.168.1.2";

        // Use up the two-attempt budget, then hit the limit.
        let _ = limiter.check_rate_limit(client);
        let _ = limiter.check_rate_limit(client);
        assert!(limiter.check_rate_limit(client), "Should be rate limited");

        // After the window has passed, the client may try again.
        std::thread::sleep(Duration::from_millis(150));
        let limited_after_wait = limiter.check_rate_limit(client);
        assert!(!limited_after_wait, "Should be allowed after expiration");
    }

    #[test]
    fn test_different_ips_independent() {
        let limiter = LoginRateLimiter::new(2, Duration::from_secs(60));
        let (first_client, second_client) = ("192.168.1.1", "192.168.1.2");

        // Exhaust the first client's budget...
        (0..2).for_each(|_| {
            limiter.check_rate_limit(first_client);
        });

        // ...which must not affect the second client.
        assert!(
            !limiter.check_rate_limit(second_client),
            "IP2 should be independent"
        );
    }
}