koprogo_api/infrastructure/database/repositories/
security_incident_repository_impl.rs

1use crate::application::ports::security_incident_repository::{
2    SecurityIncidentFilters, SecurityIncidentRepository,
3};
4use crate::domain::entities::SecurityIncident;
5use crate::infrastructure::database::pool::DbPool;
6use async_trait::async_trait;
7use sqlx::Row;
8use uuid::Uuid;
9
10pub struct PostgresSecurityIncidentRepository {
11    pool: DbPool,
12}
13
14impl PostgresSecurityIncidentRepository {
15    pub fn new(pool: DbPool) -> Self {
16        Self { pool }
17    }
18
19    fn row_to_incident(row: &sqlx::postgres::PgRow) -> SecurityIncident {
20        let data_categories: Option<Vec<String>> =
21            row.try_get("data_categories_affected").ok().unwrap_or(None);
22
23        SecurityIncident {
24            id: row.get("id"),
25            organization_id: row.try_get("organization_id").ok().flatten(),
26            severity: row.get("severity"),
27            incident_type: row.get("incident_type"),
28            title: row.get("title"),
29            description: row.get("description"),
30            data_categories_affected: data_categories.unwrap_or_default(),
31            affected_subjects_count: row.try_get("affected_subjects_count").ok().flatten(),
32            discovery_at: row.get("discovery_at"),
33            notification_at: row.try_get("notification_at").ok().flatten(),
34            apd_reference_number: row.try_get("apd_reference_number").ok().flatten(),
35            status: row.get("status"),
36            reported_by: row.get("reported_by"),
37            investigation_notes: row.try_get("investigation_notes").ok().flatten(),
38            root_cause: row.try_get("root_cause").ok().flatten(),
39            remediation_steps: row.try_get("remediation_steps").ok().flatten(),
40            created_at: row.get("created_at"),
41            updated_at: row.get("updated_at"),
42        }
43    }
44}
45
46#[async_trait]
47impl SecurityIncidentRepository for PostgresSecurityIncidentRepository {
48    async fn create(&self, incident: &SecurityIncident) -> Result<SecurityIncident, String> {
49        let row = sqlx::query(
50            r#"
51            INSERT INTO security_incidents (
52                id, organization_id, severity, incident_type, title, description,
53                data_categories_affected, affected_subjects_count,
54                discovery_at, status, reported_by, created_at, updated_at
55            )
56            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
57            RETURNING
58                id, organization_id, severity, incident_type, title, description,
59                data_categories_affected, affected_subjects_count, discovery_at,
60                notification_at, apd_reference_number, status, reported_by,
61                investigation_notes, root_cause, remediation_steps, created_at, updated_at
62            "#,
63        )
64        .bind(incident.id)
65        .bind(incident.organization_id)
66        .bind(&incident.severity)
67        .bind(&incident.incident_type)
68        .bind(&incident.title)
69        .bind(&incident.description)
70        .bind(&incident.data_categories_affected)
71        .bind(incident.affected_subjects_count)
72        .bind(incident.discovery_at)
73        .bind(&incident.status)
74        .bind(incident.reported_by)
75        .bind(incident.created_at)
76        .bind(incident.updated_at)
77        .fetch_one(&self.pool)
78        .await
79        .map_err(|e| format!("Failed to create security incident: {}", e))?;
80
81        Ok(Self::row_to_incident(&row))
82    }
83
84    async fn find_by_id(
85        &self,
86        id: Uuid,
87        organization_id: Option<Uuid>,
88    ) -> Result<Option<SecurityIncident>, String> {
89        let row = if let Some(org_id) = organization_id {
90            sqlx::query(
91                r#"
92                SELECT id, organization_id, severity, incident_type, title, description,
93                    data_categories_affected, affected_subjects_count, discovery_at,
94                    notification_at, apd_reference_number, status, reported_by,
95                    investigation_notes, root_cause, remediation_steps, created_at, updated_at
96                FROM security_incidents
97                WHERE id = $1 AND organization_id = $2
98                "#,
99            )
100            .bind(id)
101            .bind(org_id)
102            .fetch_optional(&self.pool)
103            .await
104        } else {
105            sqlx::query(
106                r#"
107                SELECT id, organization_id, severity, incident_type, title, description,
108                    data_categories_affected, affected_subjects_count, discovery_at,
109                    notification_at, apd_reference_number, status, reported_by,
110                    investigation_notes, root_cause, remediation_steps, created_at, updated_at
111                FROM security_incidents
112                WHERE id = $1
113                "#,
114            )
115            .bind(id)
116            .fetch_optional(&self.pool)
117            .await
118        }
119        .map_err(|e| format!("Failed to fetch security incident: {}", e))?;
120
121        Ok(row.map(|r| Self::row_to_incident(&r)))
122    }
123
124    async fn find_all(
125        &self,
126        organization_id: Option<Uuid>,
127        filters: SecurityIncidentFilters,
128    ) -> Result<(Vec<SecurityIncident>, i64), String> {
129        let offset = (filters.page - 1) * filters.per_page;
130
131        let rows = match (organization_id, &filters.severity, &filters.status) {
132            (Some(org_id), Some(sev), Some(st)) => {
133                sqlx::query(
134                    r#"
135                    SELECT id, organization_id, severity, incident_type, title, description,
136                        data_categories_affected, affected_subjects_count, discovery_at,
137                        notification_at, apd_reference_number, status, reported_by,
138                        investigation_notes, root_cause, remediation_steps, created_at, updated_at
139                    FROM security_incidents
140                    WHERE organization_id = $1 AND severity = $2 AND status = $3
141                    ORDER BY discovery_at DESC
142                    LIMIT $4 OFFSET $5
143                    "#,
144                )
145                .bind(org_id)
146                .bind(sev)
147                .bind(st)
148                .bind(filters.per_page)
149                .bind(offset)
150                .fetch_all(&self.pool)
151                .await
152            }
153            (Some(org_id), Some(sev), None) => {
154                sqlx::query(
155                    r#"
156                    SELECT id, organization_id, severity, incident_type, title, description,
157                        data_categories_affected, affected_subjects_count, discovery_at,
158                        notification_at, apd_reference_number, status, reported_by,
159                        investigation_notes, root_cause, remediation_steps, created_at, updated_at
160                    FROM security_incidents
161                    WHERE organization_id = $1 AND severity = $2
162                    ORDER BY discovery_at DESC
163                    LIMIT $3 OFFSET $4
164                    "#,
165                )
166                .bind(org_id)
167                .bind(sev)
168                .bind(filters.per_page)
169                .bind(offset)
170                .fetch_all(&self.pool)
171                .await
172            }
173            (Some(org_id), None, Some(st)) => {
174                sqlx::query(
175                    r#"
176                    SELECT id, organization_id, severity, incident_type, title, description,
177                        data_categories_affected, affected_subjects_count, discovery_at,
178                        notification_at, apd_reference_number, status, reported_by,
179                        investigation_notes, root_cause, remediation_steps, created_at, updated_at
180                    FROM security_incidents
181                    WHERE organization_id = $1 AND status = $2
182                    ORDER BY discovery_at DESC
183                    LIMIT $3 OFFSET $4
184                    "#,
185                )
186                .bind(org_id)
187                .bind(st)
188                .bind(filters.per_page)
189                .bind(offset)
190                .fetch_all(&self.pool)
191                .await
192            }
193            (Some(org_id), None, None) => {
194                sqlx::query(
195                    r#"
196                    SELECT id, organization_id, severity, incident_type, title, description,
197                        data_categories_affected, affected_subjects_count, discovery_at,
198                        notification_at, apd_reference_number, status, reported_by,
199                        investigation_notes, root_cause, remediation_steps, created_at, updated_at
200                    FROM security_incidents
201                    WHERE organization_id = $1
202                    ORDER BY discovery_at DESC
203                    LIMIT $2 OFFSET $3
204                    "#,
205                )
206                .bind(org_id)
207                .bind(filters.per_page)
208                .bind(offset)
209                .fetch_all(&self.pool)
210                .await
211            }
212            _ => {
213                sqlx::query(
214                    r#"
215                    SELECT id, organization_id, severity, incident_type, title, description,
216                        data_categories_affected, affected_subjects_count, discovery_at,
217                        notification_at, apd_reference_number, status, reported_by,
218                        investigation_notes, root_cause, remediation_steps, created_at, updated_at
219                    FROM security_incidents
220                    ORDER BY discovery_at DESC
221                    LIMIT $1 OFFSET $2
222                    "#,
223                )
224                .bind(filters.per_page)
225                .bind(offset)
226                .fetch_all(&self.pool)
227                .await
228            }
229        }
230        .map_err(|e| format!("Failed to list security incidents: {}", e))?;
231
232        let total: i64 = if let Some(org_id) = organization_id {
233            sqlx::query_scalar("SELECT COUNT(*) FROM security_incidents WHERE organization_id = $1")
234                .bind(org_id)
235                .fetch_one(&self.pool)
236                .await
237                .map_err(|e| format!("Failed to count incidents: {}", e))?
238        } else {
239            sqlx::query_scalar("SELECT COUNT(*) FROM security_incidents")
240                .fetch_one(&self.pool)
241                .await
242                .map_err(|e| format!("Failed to count incidents: {}", e))?
243        };
244
245        let incidents = rows.iter().map(Self::row_to_incident).collect();
246        Ok((incidents, total))
247    }
248
249    async fn report_to_apd(
250        &self,
251        id: Uuid,
252        organization_id: Option<Uuid>,
253        apd_reference_number: String,
254        investigation_notes: Option<String>,
255    ) -> Result<Option<SecurityIncident>, String> {
256        let row = if let Some(org_id) = organization_id {
257            sqlx::query(
258                r#"
259                UPDATE security_incidents
260                SET notification_at = now(),
261                    apd_reference_number = $1,
262                    status = 'reported',
263                    investigation_notes = $2,
264                    updated_at = now()
265                WHERE id = $3 AND organization_id = $4
266                RETURNING
267                    id, organization_id, severity, incident_type, title, description,
268                    data_categories_affected, affected_subjects_count, discovery_at,
269                    notification_at, apd_reference_number, status, reported_by,
270                    investigation_notes, root_cause, remediation_steps, created_at, updated_at
271                "#,
272            )
273            .bind(&apd_reference_number)
274            .bind(&investigation_notes)
275            .bind(id)
276            .bind(org_id)
277            .fetch_optional(&self.pool)
278            .await
279        } else {
280            sqlx::query(
281                r#"
282                UPDATE security_incidents
283                SET notification_at = now(),
284                    apd_reference_number = $1,
285                    status = 'reported',
286                    investigation_notes = $2,
287                    updated_at = now()
288                WHERE id = $3
289                RETURNING
290                    id, organization_id, severity, incident_type, title, description,
291                    data_categories_affected, affected_subjects_count, discovery_at,
292                    notification_at, apd_reference_number, status, reported_by,
293                    investigation_notes, root_cause, remediation_steps, created_at, updated_at
294                "#,
295            )
296            .bind(&apd_reference_number)
297            .bind(&investigation_notes)
298            .bind(id)
299            .fetch_optional(&self.pool)
300            .await
301        }
302        .map_err(|e| format!("Failed to report incident to APD: {}", e))?;
303
304        Ok(row.map(|r| Self::row_to_incident(&r)))
305    }
306
307    async fn find_overdue(
308        &self,
309        organization_id: Option<Uuid>,
310    ) -> Result<Vec<SecurityIncident>, String> {
311        let rows = if let Some(org_id) = organization_id {
312            sqlx::query(
313                r#"
314                SELECT id, organization_id, severity, incident_type, title, description,
315                    data_categories_affected, affected_subjects_count, discovery_at,
316                    notification_at, apd_reference_number, status, reported_by,
317                    investigation_notes, root_cause, remediation_steps, created_at, updated_at
318                FROM security_incidents
319                WHERE organization_id = $1
320                  AND notification_at IS NULL
321                  AND status IN ('detected', 'investigating', 'contained')
322                  AND discovery_at < (NOW() - INTERVAL '72 hours')
323                ORDER BY discovery_at ASC
324                "#,
325            )
326            .bind(org_id)
327            .fetch_all(&self.pool)
328            .await
329        } else {
330            sqlx::query(
331                r#"
332                SELECT id, organization_id, severity, incident_type, title, description,
333                    data_categories_affected, affected_subjects_count, discovery_at,
334                    notification_at, apd_reference_number, status, reported_by,
335                    investigation_notes, root_cause, remediation_steps, created_at, updated_at
336                FROM security_incidents
337                WHERE notification_at IS NULL
338                  AND status IN ('detected', 'investigating', 'contained')
339                  AND discovery_at < (NOW() - INTERVAL '72 hours')
340                ORDER BY discovery_at ASC
341                "#,
342            )
343            .fetch_all(&self.pool)
344            .await
345        }
346        .map_err(|e| format!("Failed to fetch overdue incidents: {}", e))?;
347
348        Ok(rows.iter().map(Self::row_to_incident).collect())
349    }
350}