koprogo_api/infrastructure/database/repositories/
poll_repository_impl.rs

1use crate::application::dto::{PageRequest, PollFilters};
2use crate::application::ports::{PollRepository, PollStatistics};
3use crate::domain::entities::{Poll, PollOption, PollStatus, PollType};
4use crate::infrastructure::database::pool::DbPool;
5use async_trait::async_trait;
6use chrono::{DateTime, Utc};
7use serde_json;
8use sqlx::Row;
9use uuid::Uuid;
10
/// Explicit column list shared by the `SELECT` queries in this file.
///
/// `poll_type` and `status` are PostgreSQL custom enums; they are cast to
/// `text` and aliased back to their own names so `row_to_poll` can read them
/// as plain strings.
const POLL_COLUMNS: &str = r#"
    id, building_id, created_by, title, description,
    poll_type::text as poll_type, options, is_anonymous,
    allow_multiple_votes, require_all_owners, starts_at, ends_at,
    status::text as status, total_eligible_voters, total_votes_cast,
    created_at, updated_at
"#;
19
/// `PollRepository` implementation that stores polls in the `polls` table
/// through sqlx.
pub struct PostgresPollRepository {
    /// Connection pool every query of this repository is executed on.
    pool: DbPool,
}
23
24impl PostgresPollRepository {
25    pub fn new(pool: DbPool) -> Self {
26        Self { pool }
27    }
28
29    /// Convert PollType to database string format
30    fn poll_type_to_string(poll_type: &PollType) -> &'static str {
31        match poll_type {
32            PollType::YesNo => "yes_no",
33            PollType::MultipleChoice => "multiple_choice",
34            PollType::Rating => "rating",
35            PollType::OpenEnded => "open_ended",
36        }
37    }
38
39    /// Parse PollType from database string
40    fn parse_poll_type(s: &str) -> PollType {
41        match s {
42            "yes_no" => PollType::YesNo,
43            "multiple_choice" => PollType::MultipleChoice,
44            "rating" => PollType::Rating,
45            "open_ended" => PollType::OpenEnded,
46            _ => PollType::YesNo, // Default fallback
47        }
48    }
49
50    /// Convert PollStatus to database string format
51    fn poll_status_to_string(status: &PollStatus) -> &'static str {
52        match status {
53            PollStatus::Draft => "draft",
54            PollStatus::Active => "active",
55            PollStatus::Closed => "closed",
56            PollStatus::Cancelled => "cancelled",
57        }
58    }
59
60    /// Parse PollStatus from database string
61    fn parse_poll_status(s: &str) -> PollStatus {
62        match s {
63            "draft" => PollStatus::Draft,
64            "active" => PollStatus::Active,
65            "closed" => PollStatus::Closed,
66            "cancelled" => PollStatus::Cancelled,
67            _ => PollStatus::Draft, // Default fallback
68        }
69    }
70
71    /// Map database row to Poll entity
72    fn row_to_poll(&self, row: &sqlx::postgres::PgRow) -> Result<Poll, String> {
73        // Parse options from JSONB
74        let options_json: serde_json::Value = row
75            .try_get("options")
76            .map_err(|e| format!("Failed to get options: {}", e))?;
77
78        let options: Vec<PollOption> = serde_json::from_value(options_json)
79            .map_err(|e| format!("Failed to deserialize options: {}", e))?;
80
81        Ok(Poll {
82            id: row
83                .try_get("id")
84                .map_err(|e| format!("Failed to get id: {}", e))?,
85            building_id: row
86                .try_get("building_id")
87                .map_err(|e| format!("Failed to get building_id: {}", e))?,
88            created_by: row
89                .try_get("created_by")
90                .map_err(|e| format!("Failed to get created_by: {}", e))?,
91            title: row
92                .try_get("title")
93                .map_err(|e| format!("Failed to get title: {}", e))?,
94            description: row
95                .try_get("description")
96                .map_err(|e| format!("Failed to get description: {}", e))?,
97            poll_type: Self::parse_poll_type(
98                row.try_get("poll_type")
99                    .map_err(|e| format!("Failed to get poll_type: {}", e))?,
100            ),
101            options,
102            is_anonymous: row
103                .try_get("is_anonymous")
104                .map_err(|e| format!("Failed to get is_anonymous: {}", e))?,
105            allow_multiple_votes: row
106                .try_get("allow_multiple_votes")
107                .map_err(|e| format!("Failed to get allow_multiple_votes: {}", e))?,
108            require_all_owners: row
109                .try_get("require_all_owners")
110                .map_err(|e| format!("Failed to get require_all_owners: {}", e))?,
111            starts_at: row
112                .try_get("starts_at")
113                .map_err(|e| format!("Failed to get starts_at: {}", e))?,
114            ends_at: row
115                .try_get("ends_at")
116                .map_err(|e| format!("Failed to get ends_at: {}", e))?,
117            status: Self::parse_poll_status(
118                row.try_get("status")
119                    .map_err(|e| format!("Failed to get status: {}", e))?,
120            ),
121            total_eligible_voters: row
122                .try_get("total_eligible_voters")
123                .map_err(|e| format!("Failed to get total_eligible_voters: {}", e))?,
124            total_votes_cast: row
125                .try_get("total_votes_cast")
126                .map_err(|e| format!("Failed to get total_votes_cast: {}", e))?,
127            created_at: row
128                .try_get("created_at")
129                .map_err(|e| format!("Failed to get created_at: {}", e))?,
130            updated_at: row
131                .try_get("updated_at")
132                .map_err(|e| format!("Failed to get updated_at: {}", e))?,
133        })
134    }
135}
136
137#[async_trait]
138impl PollRepository for PostgresPollRepository {
139    async fn create(&self, poll: &Poll) -> Result<Poll, String> {
140        let poll_type_str = Self::poll_type_to_string(&poll.poll_type);
141        let status_str = Self::poll_status_to_string(&poll.status);
142
143        // Serialize options to JSON
144        let options_json = serde_json::to_value(&poll.options)
145            .map_err(|e| format!("Failed to serialize options: {}", e))?;
146
147        sqlx::query(
148            r#"
149            INSERT INTO polls (
150                id, building_id, created_by, title, description, poll_type, options,
151                is_anonymous, allow_multiple_votes, require_all_owners,
152                starts_at, ends_at, status, total_eligible_voters, total_votes_cast,
153                created_at, updated_at
154            )
155            VALUES ($1, $2, $3, $4, $5, $6::poll_type, $7, $8, $9, $10, $11, $12, $13::poll_status, $14, $15, $16, $17)
156            "#,
157        )
158        .bind(poll.id)
159        .bind(poll.building_id)
160        .bind(poll.created_by)
161        .bind(&poll.title)
162        .bind(&poll.description)
163        .bind(poll_type_str)
164        .bind(options_json)
165        .bind(poll.is_anonymous)
166        .bind(poll.allow_multiple_votes)
167        .bind(poll.require_all_owners)
168        .bind(poll.starts_at)
169        .bind(poll.ends_at)
170        .bind(status_str)
171        .bind(poll.total_eligible_voters)
172        .bind(poll.total_votes_cast)
173        .bind(poll.created_at)
174        .bind(poll.updated_at)
175        .execute(&self.pool)
176        .await
177        .map_err(|e| format!("Failed to insert poll: {}", e))?;
178
179        Ok(poll.clone())
180    }
181
182    async fn find_by_id(&self, id: Uuid) -> Result<Option<Poll>, String> {
183        let row = sqlx::query(&format!("SELECT {} FROM polls WHERE id = $1", POLL_COLUMNS))
184            .bind(id)
185            .fetch_optional(&self.pool)
186            .await
187            .map_err(|e| format!("Failed to fetch poll: {}", e))?;
188
189        match row {
190            Some(r) => Ok(Some(self.row_to_poll(&r)?)),
191            None => Ok(None),
192        }
193    }
194
195    async fn find_by_building(&self, building_id: Uuid) -> Result<Vec<Poll>, String> {
196        let rows = sqlx::query(&format!(
197            "SELECT {} FROM polls WHERE building_id = $1 ORDER BY created_at DESC",
198            POLL_COLUMNS
199        ))
200        .bind(building_id)
201        .fetch_all(&self.pool)
202        .await
203        .map_err(|e| format!("Failed to fetch polls: {}", e))?;
204
205        rows.iter()
206            .map(|row| self.row_to_poll(row))
207            .collect::<Result<Vec<Poll>, String>>()
208    }
209
210    async fn find_by_created_by(&self, created_by: Uuid) -> Result<Vec<Poll>, String> {
211        let rows = sqlx::query(&format!(
212            "SELECT {} FROM polls WHERE created_by = $1 ORDER BY created_at DESC",
213            POLL_COLUMNS
214        ))
215        .bind(created_by)
216        .fetch_all(&self.pool)
217        .await
218        .map_err(|e| format!("Failed to fetch polls: {}", e))?;
219
220        rows.iter()
221            .map(|row| self.row_to_poll(row))
222            .collect::<Result<Vec<Poll>, String>>()
223    }
224
225    async fn find_all_paginated(
226        &self,
227        page_request: &PageRequest,
228        filters: &PollFilters,
229    ) -> Result<(Vec<Poll>, i64), String> {
230        let offset = (page_request.page - 1) * page_request.per_page;
231
232        // Build WHERE clause dynamically
233        let mut where_clauses = Vec::new();
234        let mut bind_index = 1;
235
236        if filters.building_id.is_some() {
237            where_clauses.push(format!("building_id = ${}", bind_index));
238            bind_index += 1;
239        }
240
241        if filters.created_by.is_some() {
242            where_clauses.push(format!("created_by = ${}", bind_index));
243            bind_index += 1;
244        }
245
246        if filters.status.is_some() {
247            where_clauses.push(format!("status = ${}::poll_status", bind_index));
248            bind_index += 1;
249        }
250
251        if filters.poll_type.is_some() {
252            where_clauses.push(format!("poll_type = ${}::poll_type", bind_index));
253            bind_index += 1;
254        }
255
256        if filters.ends_before.is_some() {
257            where_clauses.push(format!("ends_at < ${}", bind_index));
258            bind_index += 1;
259        }
260
261        if filters.ends_after.is_some() {
262            where_clauses.push(format!("ends_at > ${}", bind_index));
263            bind_index += 1;
264        }
265
266        let where_sql = if where_clauses.is_empty() {
267            String::new()
268        } else {
269            format!("WHERE {}", where_clauses.join(" AND "))
270        };
271
272        // Count total
273        let count_query = format!("SELECT COUNT(*) as count FROM polls {}", where_sql);
274        let mut count_query_builder = sqlx::query(&count_query);
275
276        if let Some(ref building_id) = filters.building_id {
277            let id = Uuid::parse_str(building_id)
278                .map_err(|_| "Invalid building_id format".to_string())?;
279            count_query_builder = count_query_builder.bind(id);
280        }
281        if let Some(ref created_by) = filters.created_by {
282            let id =
283                Uuid::parse_str(created_by).map_err(|_| "Invalid created_by format".to_string())?;
284            count_query_builder = count_query_builder.bind(id);
285        }
286        if let Some(ref status) = filters.status {
287            count_query_builder = count_query_builder.bind(Self::poll_status_to_string(status));
288        }
289        if let Some(ref poll_type) = filters.poll_type {
290            count_query_builder = count_query_builder.bind(Self::poll_type_to_string(poll_type));
291        }
292        if let Some(ref ends_before) = filters.ends_before {
293            let date = DateTime::parse_from_rfc3339(ends_before)
294                .map_err(|_| "Invalid ends_before format".to_string())?
295                .with_timezone(&Utc);
296            count_query_builder = count_query_builder.bind(date);
297        }
298        if let Some(ref ends_after) = filters.ends_after {
299            let date = DateTime::parse_from_rfc3339(ends_after)
300                .map_err(|_| "Invalid ends_after format".to_string())?
301                .with_timezone(&Utc);
302            count_query_builder = count_query_builder.bind(date);
303        }
304
305        let count_row = count_query_builder
306            .fetch_one(&self.pool)
307            .await
308            .map_err(|e| format!("Failed to count polls: {}", e))?;
309        let total: i64 = count_row
310            .try_get("count")
311            .map_err(|e| format!("Failed to get count: {}", e))?;
312
313        // Fetch paginated results
314        let data_query = format!(
315            "SELECT {} FROM polls {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
316            POLL_COLUMNS,
317            where_sql,
318            bind_index,
319            bind_index + 1
320        );
321        let mut data_query_builder = sqlx::query(&data_query);
322
323        if let Some(ref building_id) = filters.building_id {
324            let id = Uuid::parse_str(building_id)
325                .map_err(|_| "Invalid building_id format".to_string())?;
326            data_query_builder = data_query_builder.bind(id);
327        }
328        if let Some(ref created_by) = filters.created_by {
329            let id =
330                Uuid::parse_str(created_by).map_err(|_| "Invalid created_by format".to_string())?;
331            data_query_builder = data_query_builder.bind(id);
332        }
333        if let Some(ref status) = filters.status {
334            data_query_builder = data_query_builder.bind(Self::poll_status_to_string(status));
335        }
336        if let Some(ref poll_type) = filters.poll_type {
337            data_query_builder = data_query_builder.bind(Self::poll_type_to_string(poll_type));
338        }
339        if let Some(ref ends_before) = filters.ends_before {
340            let date = DateTime::parse_from_rfc3339(ends_before)
341                .map_err(|_| "Invalid ends_before format".to_string())?
342                .with_timezone(&Utc);
343            data_query_builder = data_query_builder.bind(date);
344        }
345        if let Some(ref ends_after) = filters.ends_after {
346            let date = DateTime::parse_from_rfc3339(ends_after)
347                .map_err(|_| "Invalid ends_after format".to_string())?
348                .with_timezone(&Utc);
349            data_query_builder = data_query_builder.bind(date);
350        }
351
352        data_query_builder = data_query_builder.bind(page_request.per_page).bind(offset);
353
354        let rows = data_query_builder
355            .fetch_all(&self.pool)
356            .await
357            .map_err(|e| format!("Failed to fetch polls: {}", e))?;
358
359        let polls = rows
360            .iter()
361            .map(|row| self.row_to_poll(row))
362            .collect::<Result<Vec<Poll>, String>>()?;
363
364        Ok((polls, total))
365    }
366
367    async fn find_active(&self, building_id: Uuid) -> Result<Vec<Poll>, String> {
368        let now = Utc::now();
369
370        let rows = sqlx::query(&format!(
371            "SELECT {} FROM polls WHERE building_id = $1 AND status = 'active' AND starts_at <= $2 AND ends_at > $2 ORDER BY created_at DESC",
372            POLL_COLUMNS
373        ))
374        .bind(building_id)
375        .bind(now)
376        .fetch_all(&self.pool)
377        .await
378        .map_err(|e| format!("Failed to fetch active polls: {}", e))?;
379
380        rows.iter()
381            .map(|row| self.row_to_poll(row))
382            .collect::<Result<Vec<Poll>, String>>()
383    }
384
385    async fn find_by_status(&self, building_id: Uuid, status: &str) -> Result<Vec<Poll>, String> {
386        let rows = sqlx::query(&format!(
387            "SELECT {} FROM polls WHERE building_id = $1 AND status = $2::poll_status ORDER BY created_at DESC",
388            POLL_COLUMNS
389        ))
390        .bind(building_id)
391        .bind(status)
392        .fetch_all(&self.pool)
393        .await
394        .map_err(|e| format!("Failed to fetch polls by status: {}", e))?;
395
396        rows.iter()
397            .map(|row| self.row_to_poll(row))
398            .collect::<Result<Vec<Poll>, String>>()
399    }
400
401    async fn find_expired_active(&self) -> Result<Vec<Poll>, String> {
402        let now = Utc::now();
403
404        let rows = sqlx::query(&format!(
405            "SELECT {} FROM polls WHERE status = 'active' AND ends_at <= $1 ORDER BY ends_at ASC",
406            POLL_COLUMNS
407        ))
408        .bind(now)
409        .fetch_all(&self.pool)
410        .await
411        .map_err(|e| format!("Failed to fetch expired polls: {}", e))?;
412
413        rows.iter()
414            .map(|row| self.row_to_poll(row))
415            .collect::<Result<Vec<Poll>, String>>()
416    }
417
418    async fn update(&self, poll: &Poll) -> Result<Poll, String> {
419        let poll_type_str = Self::poll_type_to_string(&poll.poll_type);
420        let status_str = Self::poll_status_to_string(&poll.status);
421
422        // Serialize options to JSON
423        let options_json = serde_json::to_value(&poll.options)
424            .map_err(|e| format!("Failed to serialize options: {}", e))?;
425
426        sqlx::query(
427            r#"
428            UPDATE polls SET
429                title = $2,
430                description = $3,
431                poll_type = $4::poll_type,
432                options = $5,
433                is_anonymous = $6,
434                allow_multiple_votes = $7,
435                require_all_owners = $8,
436                starts_at = $9,
437                ends_at = $10,
438                status = $11::poll_status,
439                total_eligible_voters = $12,
440                total_votes_cast = $13,
441                updated_at = $14
442            WHERE id = $1
443            "#,
444        )
445        .bind(poll.id)
446        .bind(&poll.title)
447        .bind(&poll.description)
448        .bind(poll_type_str)
449        .bind(options_json)
450        .bind(poll.is_anonymous)
451        .bind(poll.allow_multiple_votes)
452        .bind(poll.require_all_owners)
453        .bind(poll.starts_at)
454        .bind(poll.ends_at)
455        .bind(status_str)
456        .bind(poll.total_eligible_voters)
457        .bind(poll.total_votes_cast)
458        .bind(poll.updated_at)
459        .execute(&self.pool)
460        .await
461        .map_err(|e| format!("Failed to update poll: {}", e))?;
462
463        Ok(poll.clone())
464    }
465
466    async fn delete(&self, id: Uuid) -> Result<bool, String> {
467        let result = sqlx::query(
468            r#"
469            DELETE FROM polls WHERE id = $1
470            "#,
471        )
472        .bind(id)
473        .execute(&self.pool)
474        .await
475        .map_err(|e| format!("Failed to delete poll: {}", e))?;
476
477        Ok(result.rows_affected() > 0)
478    }
479
480    async fn get_building_statistics(&self, building_id: Uuid) -> Result<PollStatistics, String> {
481        let row = sqlx::query(
482            r#"
483            SELECT 
484                COUNT(*) as total_polls,
485                COUNT(*) FILTER (WHERE status = 'active') as active_polls,
486                COUNT(*) FILTER (WHERE status = 'closed') as closed_polls,
487                COALESCE(AVG(CASE 
488                    WHEN total_eligible_voters > 0 
489                    THEN (total_votes_cast::float / total_eligible_voters::float) * 100.0
490                    ELSE 0
491                END), 0.0) as avg_participation
492            FROM polls
493            WHERE building_id = $1
494            "#,
495        )
496        .bind(building_id)
497        .fetch_one(&self.pool)
498        .await
499        .map_err(|e| format!("Failed to fetch poll statistics: {}", e))?;
500
501        Ok(PollStatistics {
502            total_polls: row
503                .try_get("total_polls")
504                .map_err(|e| format!("Failed to get total_polls: {}", e))?,
505            active_polls: row
506                .try_get("active_polls")
507                .map_err(|e| format!("Failed to get active_polls: {}", e))?,
508            closed_polls: row
509                .try_get("closed_polls")
510                .map_err(|e| format!("Failed to get closed_polls: {}", e))?,
511            average_participation_rate: row
512                .try_get("avg_participation")
513                .map_err(|e| format!("Failed to get avg_participation: {}", e))?,
514        })
515    }
516}