// koprogo_api/infrastructure/database/repositories/poll_repository_impl.rs

use crate::application::dto::{PageRequest, PollFilters};
use crate::application::ports::{PollRepository, PollStatistics};
use crate::domain::entities::{Poll, PollOption, PollStatus, PollType};
use crate::infrastructure::database::pool::DbPool;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde_json;
use sqlx::postgres::PgArguments;
use sqlx::query::Query;
use sqlx::Postgres;
use sqlx::Row;
use uuid::Uuid;

11pub struct PostgresPollRepository {
12    pool: DbPool,
13}
14
15impl PostgresPollRepository {
16    pub fn new(pool: DbPool) -> Self {
17        Self { pool }
18    }
19
20    /// Convert PollType to database string format
21    fn poll_type_to_string(poll_type: &PollType) -> &'static str {
22        match poll_type {
23            PollType::YesNo => "yes_no",
24            PollType::MultipleChoice => "multiple_choice",
25            PollType::Rating => "rating",
26            PollType::OpenEnded => "open_ended",
27        }
28    }
29
30    /// Parse PollType from database string
31    fn parse_poll_type(s: &str) -> PollType {
32        match s {
33            "yes_no" => PollType::YesNo,
34            "multiple_choice" => PollType::MultipleChoice,
35            "rating" => PollType::Rating,
36            "open_ended" => PollType::OpenEnded,
37            _ => PollType::YesNo, // Default fallback
38        }
39    }
40
41    /// Convert PollStatus to database string format
42    fn poll_status_to_string(status: &PollStatus) -> &'static str {
43        match status {
44            PollStatus::Draft => "draft",
45            PollStatus::Active => "active",
46            PollStatus::Closed => "closed",
47            PollStatus::Cancelled => "cancelled",
48        }
49    }
50
51    /// Parse PollStatus from database string
52    fn parse_poll_status(s: &str) -> PollStatus {
53        match s {
54            "draft" => PollStatus::Draft,
55            "active" => PollStatus::Active,
56            "closed" => PollStatus::Closed,
57            "cancelled" => PollStatus::Cancelled,
58            _ => PollStatus::Draft, // Default fallback
59        }
60    }
61
62    /// Map database row to Poll entity
63    fn row_to_poll(&self, row: &sqlx::postgres::PgRow) -> Result<Poll, String> {
64        // Parse options from JSONB
65        let options_json: serde_json::Value = row
66            .try_get("options")
67            .map_err(|e| format!("Failed to get options: {}", e))?;
68
69        let options: Vec<PollOption> = serde_json::from_value(options_json)
70            .map_err(|e| format!("Failed to deserialize options: {}", e))?;
71
72        Ok(Poll {
73            id: row
74                .try_get("id")
75                .map_err(|e| format!("Failed to get id: {}", e))?,
76            building_id: row
77                .try_get("building_id")
78                .map_err(|e| format!("Failed to get building_id: {}", e))?,
79            created_by: row
80                .try_get("created_by")
81                .map_err(|e| format!("Failed to get created_by: {}", e))?,
82            title: row
83                .try_get("title")
84                .map_err(|e| format!("Failed to get title: {}", e))?,
85            description: row
86                .try_get("description")
87                .map_err(|e| format!("Failed to get description: {}", e))?,
88            poll_type: Self::parse_poll_type(
89                row.try_get("poll_type")
90                    .map_err(|e| format!("Failed to get poll_type: {}", e))?,
91            ),
92            options,
93            is_anonymous: row
94                .try_get("is_anonymous")
95                .map_err(|e| format!("Failed to get is_anonymous: {}", e))?,
96            allow_multiple_votes: row
97                .try_get("allow_multiple_votes")
98                .map_err(|e| format!("Failed to get allow_multiple_votes: {}", e))?,
99            require_all_owners: row
100                .try_get("require_all_owners")
101                .map_err(|e| format!("Failed to get require_all_owners: {}", e))?,
102            starts_at: row
103                .try_get("starts_at")
104                .map_err(|e| format!("Failed to get starts_at: {}", e))?,
105            ends_at: row
106                .try_get("ends_at")
107                .map_err(|e| format!("Failed to get ends_at: {}", e))?,
108            status: Self::parse_poll_status(
109                row.try_get("status")
110                    .map_err(|e| format!("Failed to get status: {}", e))?,
111            ),
112            total_eligible_voters: row
113                .try_get("total_eligible_voters")
114                .map_err(|e| format!("Failed to get total_eligible_voters: {}", e))?,
115            total_votes_cast: row
116                .try_get("total_votes_cast")
117                .map_err(|e| format!("Failed to get total_votes_cast: {}", e))?,
118            created_at: row
119                .try_get("created_at")
120                .map_err(|e| format!("Failed to get created_at: {}", e))?,
121            updated_at: row
122                .try_get("updated_at")
123                .map_err(|e| format!("Failed to get updated_at: {}", e))?,
124        })
125    }
126}
127
128#[async_trait]
129impl PollRepository for PostgresPollRepository {
130    async fn create(&self, poll: &Poll) -> Result<Poll, String> {
131        let poll_type_str = Self::poll_type_to_string(&poll.poll_type);
132        let status_str = Self::poll_status_to_string(&poll.status);
133
134        // Serialize options to JSON
135        let options_json = serde_json::to_value(&poll.options)
136            .map_err(|e| format!("Failed to serialize options: {}", e))?;
137
138        sqlx::query(
139            r#"
140            INSERT INTO polls (
141                id, building_id, created_by, title, description, poll_type, options,
142                is_anonymous, allow_multiple_votes, require_all_owners,
143                starts_at, ends_at, status, total_eligible_voters, total_votes_cast,
144                created_at, updated_at
145            )
146            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
147            "#,
148        )
149        .bind(poll.id)
150        .bind(poll.building_id)
151        .bind(poll.created_by)
152        .bind(&poll.title)
153        .bind(&poll.description)
154        .bind(poll_type_str)
155        .bind(options_json)
156        .bind(poll.is_anonymous)
157        .bind(poll.allow_multiple_votes)
158        .bind(poll.require_all_owners)
159        .bind(poll.starts_at)
160        .bind(poll.ends_at)
161        .bind(status_str)
162        .bind(poll.total_eligible_voters)
163        .bind(poll.total_votes_cast)
164        .bind(poll.created_at)
165        .bind(poll.updated_at)
166        .execute(&self.pool)
167        .await
168        .map_err(|e| format!("Failed to insert poll: {}", e))?;
169
170        Ok(poll.clone())
171    }
172
173    async fn find_by_id(&self, id: Uuid) -> Result<Option<Poll>, String> {
174        let row = sqlx::query(
175            r#"
176            SELECT * FROM polls WHERE id = $1
177            "#,
178        )
179        .bind(id)
180        .fetch_optional(&self.pool)
181        .await
182        .map_err(|e| format!("Failed to fetch poll: {}", e))?;
183
184        match row {
185            Some(r) => Ok(Some(self.row_to_poll(&r)?)),
186            None => Ok(None),
187        }
188    }
189
190    async fn find_by_building(&self, building_id: Uuid) -> Result<Vec<Poll>, String> {
191        let rows = sqlx::query(
192            r#"
193            SELECT * FROM polls 
194            WHERE building_id = $1
195            ORDER BY created_at DESC
196            "#,
197        )
198        .bind(building_id)
199        .fetch_all(&self.pool)
200        .await
201        .map_err(|e| format!("Failed to fetch polls: {}", e))?;
202
203        rows.iter()
204            .map(|row| self.row_to_poll(row))
205            .collect::<Result<Vec<Poll>, String>>()
206    }
207
208    async fn find_by_created_by(&self, created_by: Uuid) -> Result<Vec<Poll>, String> {
209        let rows = sqlx::query(
210            r#"
211            SELECT * FROM polls 
212            WHERE created_by = $1
213            ORDER BY created_at DESC
214            "#,
215        )
216        .bind(created_by)
217        .fetch_all(&self.pool)
218        .await
219        .map_err(|e| format!("Failed to fetch polls: {}", e))?;
220
221        rows.iter()
222            .map(|row| self.row_to_poll(row))
223            .collect::<Result<Vec<Poll>, String>>()
224    }
225
226    async fn find_all_paginated(
227        &self,
228        page_request: &PageRequest,
229        filters: &PollFilters,
230    ) -> Result<(Vec<Poll>, i64), String> {
231        let offset = (page_request.page - 1) * page_request.per_page;
232
233        // Build WHERE clause dynamically
234        let mut where_clauses = Vec::new();
235        let mut bind_index = 1;
236
237        if filters.building_id.is_some() {
238            where_clauses.push(format!("building_id = ${}", bind_index));
239            bind_index += 1;
240        }
241
242        if filters.created_by.is_some() {
243            where_clauses.push(format!("created_by = ${}", bind_index));
244            bind_index += 1;
245        }
246
247        if filters.status.is_some() {
248            where_clauses.push(format!("status = ${}", bind_index));
249            bind_index += 1;
250        }
251
252        if filters.poll_type.is_some() {
253            where_clauses.push(format!("poll_type = ${}", bind_index));
254            bind_index += 1;
255        }
256
257        if filters.ends_before.is_some() {
258            where_clauses.push(format!("ends_at < ${}", bind_index));
259            bind_index += 1;
260        }
261
262        if filters.ends_after.is_some() {
263            where_clauses.push(format!("ends_at > ${}", bind_index));
264            bind_index += 1;
265        }
266
267        let where_sql = if where_clauses.is_empty() {
268            String::new()
269        } else {
270            format!("WHERE {}", where_clauses.join(" AND "))
271        };
272
273        // Count total
274        let count_query = format!("SELECT COUNT(*) as count FROM polls {}", where_sql);
275        let mut count_query_builder = sqlx::query(&count_query);
276
277        if let Some(ref building_id) = filters.building_id {
278            let id = Uuid::parse_str(building_id)
279                .map_err(|_| "Invalid building_id format".to_string())?;
280            count_query_builder = count_query_builder.bind(id);
281        }
282        if let Some(ref created_by) = filters.created_by {
283            let id =
284                Uuid::parse_str(created_by).map_err(|_| "Invalid created_by format".to_string())?;
285            count_query_builder = count_query_builder.bind(id);
286        }
287        if let Some(ref status) = filters.status {
288            count_query_builder = count_query_builder.bind(Self::poll_status_to_string(status));
289        }
290        if let Some(ref poll_type) = filters.poll_type {
291            count_query_builder = count_query_builder.bind(Self::poll_type_to_string(poll_type));
292        }
293        if let Some(ref ends_before) = filters.ends_before {
294            let date = DateTime::parse_from_rfc3339(ends_before)
295                .map_err(|_| "Invalid ends_before format".to_string())?
296                .with_timezone(&Utc);
297            count_query_builder = count_query_builder.bind(date);
298        }
299        if let Some(ref ends_after) = filters.ends_after {
300            let date = DateTime::parse_from_rfc3339(ends_after)
301                .map_err(|_| "Invalid ends_after format".to_string())?
302                .with_timezone(&Utc);
303            count_query_builder = count_query_builder.bind(date);
304        }
305
306        let count_row = count_query_builder
307            .fetch_one(&self.pool)
308            .await
309            .map_err(|e| format!("Failed to count polls: {}", e))?;
310        let total: i64 = count_row
311            .try_get("count")
312            .map_err(|e| format!("Failed to get count: {}", e))?;
313
314        // Fetch paginated results
315        let data_query = format!(
316            "SELECT * FROM polls {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
317            where_sql,
318            bind_index,
319            bind_index + 1
320        );
321        let mut data_query_builder = sqlx::query(&data_query);
322
323        if let Some(ref building_id) = filters.building_id {
324            let id = Uuid::parse_str(building_id)
325                .map_err(|_| "Invalid building_id format".to_string())?;
326            data_query_builder = data_query_builder.bind(id);
327        }
328        if let Some(ref created_by) = filters.created_by {
329            let id =
330                Uuid::parse_str(created_by).map_err(|_| "Invalid created_by format".to_string())?;
331            data_query_builder = data_query_builder.bind(id);
332        }
333        if let Some(ref status) = filters.status {
334            data_query_builder = data_query_builder.bind(Self::poll_status_to_string(status));
335        }
336        if let Some(ref poll_type) = filters.poll_type {
337            data_query_builder = data_query_builder.bind(Self::poll_type_to_string(poll_type));
338        }
339        if let Some(ref ends_before) = filters.ends_before {
340            let date = DateTime::parse_from_rfc3339(ends_before)
341                .map_err(|_| "Invalid ends_before format".to_string())?
342                .with_timezone(&Utc);
343            data_query_builder = data_query_builder.bind(date);
344        }
345        if let Some(ref ends_after) = filters.ends_after {
346            let date = DateTime::parse_from_rfc3339(ends_after)
347                .map_err(|_| "Invalid ends_after format".to_string())?
348                .with_timezone(&Utc);
349            data_query_builder = data_query_builder.bind(date);
350        }
351
352        data_query_builder = data_query_builder.bind(page_request.per_page).bind(offset);
353
354        let rows = data_query_builder
355            .fetch_all(&self.pool)
356            .await
357            .map_err(|e| format!("Failed to fetch polls: {}", e))?;
358
359        let polls = rows
360            .iter()
361            .map(|row| self.row_to_poll(row))
362            .collect::<Result<Vec<Poll>, String>>()?;
363
364        Ok((polls, total))
365    }
366
367    async fn find_active(&self, building_id: Uuid) -> Result<Vec<Poll>, String> {
368        let now = Utc::now();
369
370        let rows = sqlx::query(
371            r#"
372            SELECT * FROM polls 
373            WHERE building_id = $1 
374              AND status = 'active'
375              AND starts_at <= $2
376              AND ends_at > $2
377            ORDER BY created_at DESC
378            "#,
379        )
380        .bind(building_id)
381        .bind(now)
382        .fetch_all(&self.pool)
383        .await
384        .map_err(|e| format!("Failed to fetch active polls: {}", e))?;
385
386        rows.iter()
387            .map(|row| self.row_to_poll(row))
388            .collect::<Result<Vec<Poll>, String>>()
389    }
390
391    async fn find_by_status(&self, building_id: Uuid, status: &str) -> Result<Vec<Poll>, String> {
392        let rows = sqlx::query(
393            r#"
394            SELECT * FROM polls 
395            WHERE building_id = $1 AND status = $2
396            ORDER BY created_at DESC
397            "#,
398        )
399        .bind(building_id)
400        .bind(status)
401        .fetch_all(&self.pool)
402        .await
403        .map_err(|e| format!("Failed to fetch polls by status: {}", e))?;
404
405        rows.iter()
406            .map(|row| self.row_to_poll(row))
407            .collect::<Result<Vec<Poll>, String>>()
408    }
409
410    async fn find_expired_active(&self) -> Result<Vec<Poll>, String> {
411        let now = Utc::now();
412
413        let rows = sqlx::query(
414            r#"
415            SELECT * FROM polls 
416            WHERE status = 'active' AND ends_at <= $1
417            ORDER BY ends_at ASC
418            "#,
419        )
420        .bind(now)
421        .fetch_all(&self.pool)
422        .await
423        .map_err(|e| format!("Failed to fetch expired polls: {}", e))?;
424
425        rows.iter()
426            .map(|row| self.row_to_poll(row))
427            .collect::<Result<Vec<Poll>, String>>()
428    }
429
430    async fn update(&self, poll: &Poll) -> Result<Poll, String> {
431        let poll_type_str = Self::poll_type_to_string(&poll.poll_type);
432        let status_str = Self::poll_status_to_string(&poll.status);
433
434        // Serialize options to JSON
435        let options_json = serde_json::to_value(&poll.options)
436            .map_err(|e| format!("Failed to serialize options: {}", e))?;
437
438        sqlx::query(
439            r#"
440            UPDATE polls SET
441                title = $2,
442                description = $3,
443                poll_type = $4,
444                options = $5,
445                is_anonymous = $6,
446                allow_multiple_votes = $7,
447                require_all_owners = $8,
448                starts_at = $9,
449                ends_at = $10,
450                status = $11,
451                total_eligible_voters = $12,
452                total_votes_cast = $13,
453                updated_at = $14
454            WHERE id = $1
455            "#,
456        )
457        .bind(poll.id)
458        .bind(&poll.title)
459        .bind(&poll.description)
460        .bind(poll_type_str)
461        .bind(options_json)
462        .bind(poll.is_anonymous)
463        .bind(poll.allow_multiple_votes)
464        .bind(poll.require_all_owners)
465        .bind(poll.starts_at)
466        .bind(poll.ends_at)
467        .bind(status_str)
468        .bind(poll.total_eligible_voters)
469        .bind(poll.total_votes_cast)
470        .bind(poll.updated_at)
471        .execute(&self.pool)
472        .await
473        .map_err(|e| format!("Failed to update poll: {}", e))?;
474
475        Ok(poll.clone())
476    }
477
478    async fn delete(&self, id: Uuid) -> Result<bool, String> {
479        let result = sqlx::query(
480            r#"
481            DELETE FROM polls WHERE id = $1
482            "#,
483        )
484        .bind(id)
485        .execute(&self.pool)
486        .await
487        .map_err(|e| format!("Failed to delete poll: {}", e))?;
488
489        Ok(result.rows_affected() > 0)
490    }
491
492    async fn get_building_statistics(&self, building_id: Uuid) -> Result<PollStatistics, String> {
493        let row = sqlx::query(
494            r#"
495            SELECT 
496                COUNT(*) as total_polls,
497                COUNT(*) FILTER (WHERE status = 'active') as active_polls,
498                COUNT(*) FILTER (WHERE status = 'closed') as closed_polls,
499                COALESCE(AVG(CASE 
500                    WHEN total_eligible_voters > 0 
501                    THEN (total_votes_cast::float / total_eligible_voters::float) * 100.0
502                    ELSE 0
503                END), 0.0) as avg_participation
504            FROM polls
505            WHERE building_id = $1
506            "#,
507        )
508        .bind(building_id)
509        .fetch_one(&self.pool)
510        .await
511        .map_err(|e| format!("Failed to fetch poll statistics: {}", e))?;
512
513        Ok(PollStatistics {
514            total_polls: row
515                .try_get("total_polls")
516                .map_err(|e| format!("Failed to get total_polls: {}", e))?,
517            active_polls: row
518                .try_get("active_polls")
519                .map_err(|e| format!("Failed to get active_polls: {}", e))?,
520            closed_polls: row
521                .try_get("closed_polls")
522                .map_err(|e| format!("Failed to get closed_polls: {}", e))?,
523            average_participation_rate: row
524                .try_get("avg_participation")
525                .map_err(|e| format!("Failed to get avg_participation: {}", e))?,
526        })
527    }
528}