koprogo_api/infrastructure/database/repositories/
refresh_token_repository_impl.rs

1use crate::application::ports::RefreshTokenRepository;
2use crate::domain::entities::RefreshToken;
3use crate::infrastructure::database::pool::DbPool;
4use async_trait::async_trait;
5use sqlx::Row;
6use uuid::Uuid;
7
8pub struct PostgresRefreshTokenRepository {
9    pool: DbPool,
10}
11
12impl PostgresRefreshTokenRepository {
13    pub fn new(pool: DbPool) -> Self {
14        Self { pool }
15    }
16}
17
18#[async_trait]
19impl RefreshTokenRepository for PostgresRefreshTokenRepository {
20    async fn create(&self, refresh_token: &RefreshToken) -> Result<RefreshToken, String> {
21        sqlx::query(
22            r#"
23            INSERT INTO refresh_tokens (id, user_id, token, expires_at, revoked, created_at, updated_at)
24            VALUES ($1, $2, $3, $4, $5, $6, $7)
25            "#,
26        )
27        .bind(refresh_token.id)
28        .bind(refresh_token.user_id)
29        .bind(&refresh_token.token)
30        .bind(refresh_token.expires_at)
31        .bind(refresh_token.revoked)
32        .bind(refresh_token.created_at)
33        .bind(refresh_token.updated_at)
34        .execute(&self.pool)
35        .await
36        .map_err(|e| format!("Database error: {}", e))?;
37
38        Ok(refresh_token.clone())
39    }
40
41    async fn find_by_token(&self, token: &str) -> Result<Option<RefreshToken>, String> {
42        let row = sqlx::query(
43            r#"
44            SELECT id, user_id, token, expires_at, revoked, created_at, updated_at
45            FROM refresh_tokens
46            WHERE token = $1
47            "#,
48        )
49        .bind(token)
50        .fetch_optional(&self.pool)
51        .await
52        .map_err(|e| format!("Database error: {}", e))?;
53
54        Ok(row.map(|row| RefreshToken {
55            id: row.get("id"),
56            user_id: row.get("user_id"),
57            token: row.get("token"),
58            expires_at: row.get("expires_at"),
59            revoked: row.get("revoked"),
60            created_at: row.get("created_at"),
61            updated_at: row.get("updated_at"),
62        }))
63    }
64
65    async fn find_by_user_id(&self, user_id: Uuid) -> Result<Vec<RefreshToken>, String> {
66        let rows = sqlx::query(
67            r#"
68            SELECT id, user_id, token, expires_at, revoked, created_at, updated_at
69            FROM refresh_tokens
70            WHERE user_id = $1
71            ORDER BY created_at DESC
72            "#,
73        )
74        .bind(user_id)
75        .fetch_all(&self.pool)
76        .await
77        .map_err(|e| format!("Database error: {}", e))?;
78
79        Ok(rows
80            .iter()
81            .map(|row| RefreshToken {
82                id: row.get("id"),
83                user_id: row.get("user_id"),
84                token: row.get("token"),
85                expires_at: row.get("expires_at"),
86                revoked: row.get("revoked"),
87                created_at: row.get("created_at"),
88                updated_at: row.get("updated_at"),
89            })
90            .collect())
91    }
92
93    async fn revoke(&self, token: &str) -> Result<bool, String> {
94        let result = sqlx::query(
95            r#"
96            UPDATE refresh_tokens
97            SET revoked = true, updated_at = NOW()
98            WHERE token = $1
99            "#,
100        )
101        .bind(token)
102        .execute(&self.pool)
103        .await
104        .map_err(|e| format!("Database error: {}", e))?;
105
106        Ok(result.rows_affected() > 0)
107    }
108
109    async fn revoke_all_for_user(&self, user_id: Uuid) -> Result<u64, String> {
110        let result = sqlx::query(
111            r#"
112            UPDATE refresh_tokens
113            SET revoked = true, updated_at = NOW()
114            WHERE user_id = $1
115            "#,
116        )
117        .bind(user_id)
118        .execute(&self.pool)
119        .await
120        .map_err(|e| format!("Database error: {}", e))?;
121
122        Ok(result.rows_affected())
123    }
124
125    async fn delete_expired(&self) -> Result<u64, String> {
126        let result = sqlx::query(
127            r#"
128            DELETE FROM refresh_tokens
129            WHERE expires_at < NOW() OR revoked = true
130            "#,
131        )
132        .execute(&self.pool)
133        .await
134        .map_err(|e| format!("Database error: {}", e))?;
135
136        Ok(result.rows_affected())
137    }
138}