// koprogo_api/infrastructure/database/repositories/expense_repository_impl.rs

use crate::application::dto::{ExpenseFilters, PageRequest};
use crate::application::ports::ExpenseRepository;
use crate::domain::entities::{ApprovalStatus, Expense, ExpenseCategory, PaymentStatus};
use crate::infrastructure::database::pool::DbPool;
use async_trait::async_trait;
use sqlx::postgres::PgRow;
use sqlx::Row;
use uuid::Uuid;

9pub struct PostgresExpenseRepository {
10    pool: DbPool,
11}
12
13impl PostgresExpenseRepository {
14    pub fn new(pool: DbPool) -> Self {
15        Self { pool }
16    }
17}
18
19#[async_trait]
20impl ExpenseRepository for PostgresExpenseRepository {
21    async fn create(&self, expense: &Expense) -> Result<Expense, String> {
22        let category_str = match expense.category {
23            ExpenseCategory::Maintenance => "maintenance",
24            ExpenseCategory::Repairs => "repairs",
25            ExpenseCategory::Insurance => "insurance",
26            ExpenseCategory::Utilities => "utilities",
27            ExpenseCategory::Cleaning => "cleaning",
28            ExpenseCategory::Administration => "administration",
29            ExpenseCategory::Works => "works",
30            ExpenseCategory::Other => "other",
31        };
32
33        let status_str = match expense.payment_status {
34            PaymentStatus::Pending => "pending",
35            PaymentStatus::Paid => "paid",
36            PaymentStatus::Overdue => "overdue",
37            PaymentStatus::Cancelled => "cancelled",
38        };
39
40        sqlx::query(
41            r#"
42            INSERT INTO expenses (id, organization_id, building_id, category, description, amount, expense_date, payment_status, supplier, invoice_number, account_code, contractor_report_id, created_at, updated_at)
43            VALUES ($1, $2, $3, CAST($4 AS expense_category), $5, $6, $7, CAST($8 AS payment_status), $9, $10, $11, $12, $13, $14)
44            "#,
45        )
46        .bind(expense.id)
47        .bind(expense.organization_id)
48        .bind(expense.building_id)
49        .bind(category_str)
50        .bind(&expense.description)
51        .bind(expense.amount)
52        .bind(expense.expense_date)
53        .bind(status_str)
54        .bind(&expense.supplier)
55        .bind(&expense.invoice_number)
56        .bind(&expense.account_code)
57        .bind(expense.contractor_report_id)
58        .bind(expense.created_at)
59        .bind(expense.updated_at)
60        .execute(&self.pool)
61        .await
62        .map_err(|e| format!("Database error: {}", e))?;
63
64        Ok(expense.clone())
65    }
66
67    async fn find_by_id(&self, id: Uuid) -> Result<Option<Expense>, String> {
68        let row = sqlx::query(
69            r#"
70            SELECT id, organization_id, building_id,
71                   category::text AS category, description, amount, expense_date,
72                   payment_status::text AS payment_status, approval_status::text AS approval_status,
73                   submitted_at, approved_by, approved_at, rejection_reason, paid_date,
74                   supplier, invoice_number, account_code, contractor_report_id, created_at, updated_at
75            FROM expenses
76            WHERE id = $1
77            "#,
78        )
79        .bind(id)
80        .fetch_optional(&self.pool)
81        .await
82        .map_err(|e| format!("Database error: {}", e))?;
83
84        Ok(row.map(|row| {
85            let category_str: String = row.get("category");
86            let category = match category_str.as_str() {
87                "maintenance" => ExpenseCategory::Maintenance,
88                "repairs" => ExpenseCategory::Repairs,
89                "insurance" => ExpenseCategory::Insurance,
90                "utilities" => ExpenseCategory::Utilities,
91                "cleaning" => ExpenseCategory::Cleaning,
92                "administration" => ExpenseCategory::Administration,
93                "works" => ExpenseCategory::Works,
94                _ => ExpenseCategory::Other,
95            };
96
97            let status_str: String = row.get("payment_status");
98            let payment_status = match status_str.as_str() {
99                "paid" => PaymentStatus::Paid,
100                "overdue" => PaymentStatus::Overdue,
101                "cancelled" => PaymentStatus::Cancelled,
102                _ => PaymentStatus::Pending,
103            };
104
105            let approval_status_str: String = row.get("approval_status");
106            let approval_status = match approval_status_str.as_str() {
107                "pending_approval" => ApprovalStatus::PendingApproval,
108                "approved" => ApprovalStatus::Approved,
109                "rejected" => ApprovalStatus::Rejected,
110                _ => ApprovalStatus::Draft,
111            };
112
113            Expense {
114                id: row.get("id"),
115                organization_id: row.get("organization_id"),
116                building_id: row.get("building_id"),
117                category,
118                description: row.get("description"),
119                amount: row.get("amount"),
120                amount_excl_vat: None,
121                vat_rate: None,
122                vat_amount: None,
123                amount_incl_vat: Some(row.get("amount")),
124                expense_date: row.get("expense_date"),
125                invoice_date: None,
126                due_date: None,
127                paid_date: row.try_get("paid_date").ok(),
128                approval_status,
129                submitted_at: row.try_get("submitted_at").ok(),
130                approved_by: row.try_get("approved_by").ok(),
131                approved_at: row.try_get("approved_at").ok(),
132                rejection_reason: row.try_get("rejection_reason").ok(),
133                payment_status,
134                supplier: row.get("supplier"),
135                invoice_number: row.get("invoice_number"),
136                account_code: row.get("account_code"),
137                contractor_report_id: row.try_get("contractor_report_id").ok(),
138                created_at: row.get("created_at"),
139                updated_at: row.get("updated_at"),
140            }
141        }))
142    }
143
144    async fn find_by_building(&self, building_id: Uuid) -> Result<Vec<Expense>, String> {
145        let rows = sqlx::query(
146            r#"
147            SELECT id, organization_id, building_id,
148                   category::text AS category, description, amount, expense_date,
149                   payment_status::text AS payment_status, approval_status::text AS approval_status,
150                   submitted_at, approved_by, approved_at, rejection_reason, paid_date,
151                   supplier, invoice_number, account_code, contractor_report_id, created_at, updated_at
152            FROM expenses
153            WHERE building_id = $1
154            ORDER BY expense_date DESC
155            "#,
156        )
157        .bind(building_id)
158        .fetch_all(&self.pool)
159        .await
160        .map_err(|e| format!("Database error: {}", e))?;
161
162        Ok(rows
163            .iter()
164            .map(|row| {
165                let category_str: String = row.get("category");
166                let category = match category_str.as_str() {
167                    "maintenance" => ExpenseCategory::Maintenance,
168                    "repairs" => ExpenseCategory::Repairs,
169                    "insurance" => ExpenseCategory::Insurance,
170                    "utilities" => ExpenseCategory::Utilities,
171                    "cleaning" => ExpenseCategory::Cleaning,
172                    "administration" => ExpenseCategory::Administration,
173                    "works" => ExpenseCategory::Works,
174                    _ => ExpenseCategory::Other,
175                };
176
177                let status_str: String = row.get("payment_status");
178                let payment_status = match status_str.as_str() {
179                    "paid" => PaymentStatus::Paid,
180                    "overdue" => PaymentStatus::Overdue,
181                    "cancelled" => PaymentStatus::Cancelled,
182                    _ => PaymentStatus::Pending,
183                };
184
185                let approval_status_str: String = row.get("approval_status");
186                let approval_status = match approval_status_str.as_str() {
187                    "pending_approval" => ApprovalStatus::PendingApproval,
188                    "approved" => ApprovalStatus::Approved,
189                    "rejected" => ApprovalStatus::Rejected,
190                    _ => ApprovalStatus::Draft,
191                };
192
193                Expense {
194                    id: row.get("id"),
195                    organization_id: row.get("organization_id"),
196                    building_id: row.get("building_id"),
197                    category,
198                    description: row.get("description"),
199                    amount: row.get("amount"),
200                    amount_excl_vat: None,
201                    vat_rate: None,
202                    vat_amount: None,
203                    amount_incl_vat: Some(row.get("amount")),
204                    expense_date: row.get("expense_date"),
205                    invoice_date: None,
206                    due_date: None,
207                    paid_date: row.try_get("paid_date").ok(),
208                    approval_status,
209                    submitted_at: row.try_get("submitted_at").ok(),
210                    approved_by: row.try_get("approved_by").ok(),
211                    approved_at: row.try_get("approved_at").ok(),
212                    rejection_reason: row.try_get("rejection_reason").ok(),
213                    payment_status,
214                    supplier: row.get("supplier"),
215                    invoice_number: row.get("invoice_number"),
216                    account_code: row.get("account_code"),
217                    contractor_report_id: row.try_get("contractor_report_id").ok(),
218                    created_at: row.get("created_at"),
219                    updated_at: row.get("updated_at"),
220                }
221            })
222            .collect())
223    }
224
225    async fn find_all_paginated(
226        &self,
227        page_request: &PageRequest,
228        filters: &ExpenseFilters,
229    ) -> Result<(Vec<Expense>, i64), String> {
230        // Validate page request
231        page_request.validate()?;
232
233        // Build WHERE clause dynamically
234        let mut where_clauses = Vec::new();
235        let mut param_count = 0;
236
237        if filters.organization_id.is_some() {
238            param_count += 1;
239            where_clauses.push(format!("organization_id = ${}", param_count));
240        }
241
242        if filters.building_id.is_some() {
243            param_count += 1;
244            where_clauses.push(format!("building_id = ${}", param_count));
245        }
246
247        if filters.category.is_some() {
248            param_count += 1;
249            where_clauses.push(format!("category = ${}", param_count));
250        }
251
252        if filters.status.is_some() {
253            param_count += 1;
254            where_clauses.push(format!("payment_status = ${}", param_count));
255        }
256
257        if filters.date_from.is_some() {
258            param_count += 1;
259            where_clauses.push(format!("expense_date >= ${}", param_count));
260        }
261
262        if filters.date_to.is_some() {
263            param_count += 1;
264            where_clauses.push(format!("expense_date <= ${}", param_count));
265        }
266
267        if filters.min_amount.is_some() {
268            param_count += 1;
269            where_clauses.push(format!("amount >= ${}", param_count));
270        }
271
272        if filters.max_amount.is_some() {
273            param_count += 1;
274            where_clauses.push(format!("amount <= ${}", param_count));
275        }
276
277        if filters.approval_status.is_some() {
278            param_count += 1;
279            where_clauses.push(format!("approval_status::text = ${}", param_count));
280        }
281
282        let where_clause = if where_clauses.is_empty() {
283            String::new()
284        } else {
285            format!("WHERE {}", where_clauses.join(" AND "))
286        };
287
288        // Validate sort column (whitelist)
289        let allowed_columns = ["expense_date", "amount", "created_at", "payment_status"];
290        let sort_column = page_request.sort_by.as_deref().unwrap_or("expense_date");
291
292        if !allowed_columns.contains(&sort_column) {
293            return Err(format!("Invalid sort column: {}", sort_column));
294        }
295
296        // Count total items
297        let count_query = format!("SELECT COUNT(*) FROM expenses {}", where_clause);
298        let mut count_query = sqlx::query_scalar::<_, i64>(&count_query);
299
300        if let Some(organization_id) = filters.organization_id {
301            count_query = count_query.bind(organization_id);
302        }
303        if let Some(building_id) = filters.building_id {
304            count_query = count_query.bind(building_id);
305        }
306        if let Some(category) = &filters.category {
307            count_query = count_query.bind(category);
308        }
309        if let Some(status) = &filters.status {
310            count_query = count_query.bind(status);
311        }
312        if let Some(date_from) = filters.date_from {
313            count_query = count_query.bind(date_from);
314        }
315        if let Some(date_to) = filters.date_to {
316            count_query = count_query.bind(date_to);
317        }
318        if let Some(min_amount) = filters.min_amount {
319            count_query = count_query.bind(min_amount);
320        }
321        if let Some(max_amount) = filters.max_amount {
322            count_query = count_query.bind(max_amount);
323        }
324        if let Some(ref approval_status) = filters.approval_status {
325            let status_str = match approval_status {
326                ApprovalStatus::Draft => "draft",
327                ApprovalStatus::PendingApproval => "pending_approval",
328                ApprovalStatus::Approved => "approved",
329                ApprovalStatus::Rejected => "rejected",
330            };
331            count_query = count_query.bind(status_str);
332        }
333
334        let total_items = count_query
335            .fetch_one(&self.pool)
336            .await
337            .map_err(|e| format!("Database error: {}", e))?;
338
339        // Fetch paginated data
340        param_count += 1;
341        let limit_param = param_count;
342        param_count += 1;
343        let offset_param = param_count;
344
345        let data_query = format!(
346            "SELECT id, organization_id, building_id, category::text AS category, description, amount, expense_date, payment_status::text AS payment_status, approval_status::text AS approval_status, submitted_at, approved_by, approved_at, rejection_reason, paid_date, supplier, invoice_number, account_code, contractor_report_id, created_at, updated_at \
347             FROM expenses {} ORDER BY {} {} LIMIT ${} OFFSET ${}",
348            where_clause,
349            sort_column,
350            page_request.order.to_sql(),
351            limit_param,
352            offset_param
353        );
354
355        let mut data_query = sqlx::query(&data_query);
356
357        if let Some(organization_id) = filters.organization_id {
358            data_query = data_query.bind(organization_id);
359        }
360        if let Some(building_id) = filters.building_id {
361            data_query = data_query.bind(building_id);
362        }
363        if let Some(category) = &filters.category {
364            data_query = data_query.bind(category);
365        }
366        if let Some(status) = &filters.status {
367            data_query = data_query.bind(status);
368        }
369        if let Some(date_from) = filters.date_from {
370            data_query = data_query.bind(date_from);
371        }
372        if let Some(date_to) = filters.date_to {
373            data_query = data_query.bind(date_to);
374        }
375        if let Some(min_amount) = filters.min_amount {
376            data_query = data_query.bind(min_amount);
377        }
378        if let Some(max_amount) = filters.max_amount {
379            data_query = data_query.bind(max_amount);
380        }
381        if let Some(ref approval_status) = filters.approval_status {
382            let status_str = match approval_status {
383                ApprovalStatus::Draft => "draft",
384                ApprovalStatus::PendingApproval => "pending_approval",
385                ApprovalStatus::Approved => "approved",
386                ApprovalStatus::Rejected => "rejected",
387            };
388            data_query = data_query.bind(status_str);
389        }
390
391        data_query = data_query
392            .bind(page_request.limit())
393            .bind(page_request.offset());
394
395        let rows = data_query
396            .fetch_all(&self.pool)
397            .await
398            .map_err(|e| format!("Database error: {}", e))?;
399
400        let expenses: Vec<Expense> = rows
401            .iter()
402            .map(|row| {
403                let category_str: String = row.get("category");
404                let category = match category_str.as_str() {
405                    "maintenance" => ExpenseCategory::Maintenance,
406                    "repairs" => ExpenseCategory::Repairs,
407                    "insurance" => ExpenseCategory::Insurance,
408                    "utilities" => ExpenseCategory::Utilities,
409                    "cleaning" => ExpenseCategory::Cleaning,
410                    "administration" => ExpenseCategory::Administration,
411                    "works" => ExpenseCategory::Works,
412                    _ => ExpenseCategory::Other,
413                };
414
415                let status_str: String = row.get("payment_status");
416                let payment_status = match status_str.as_str() {
417                    "paid" => PaymentStatus::Paid,
418                    "overdue" => PaymentStatus::Overdue,
419                    "cancelled" => PaymentStatus::Cancelled,
420                    _ => PaymentStatus::Pending,
421                };
422
423                let approval_status_str: String = row.get("approval_status");
424                let approval_status = match approval_status_str.as_str() {
425                    "pending_approval" => ApprovalStatus::PendingApproval,
426                    "approved" => ApprovalStatus::Approved,
427                    "rejected" => ApprovalStatus::Rejected,
428                    _ => ApprovalStatus::Draft,
429                };
430
431                Expense {
432                    id: row.get("id"),
433                    organization_id: row.get("organization_id"),
434                    building_id: row.get("building_id"),
435                    category,
436                    description: row.get("description"),
437                    amount: row.get("amount"),
438                    amount_excl_vat: None,
439                    vat_rate: None,
440                    vat_amount: None,
441                    amount_incl_vat: Some(row.get("amount")),
442                    expense_date: row.get("expense_date"),
443                    invoice_date: None,
444                    due_date: None,
445                    paid_date: row.try_get("paid_date").ok(),
446                    approval_status,
447                    submitted_at: row.try_get("submitted_at").ok(),
448                    approved_by: row.try_get("approved_by").ok(),
449                    approved_at: row.try_get("approved_at").ok(),
450                    rejection_reason: row.try_get("rejection_reason").ok(),
451                    payment_status,
452                    supplier: row.get("supplier"),
453                    invoice_number: row.get("invoice_number"),
454                    account_code: row.get("account_code"),
455                    contractor_report_id: row.try_get("contractor_report_id").ok(),
456                    created_at: row.get("created_at"),
457                    updated_at: row.get("updated_at"),
458                }
459            })
460            .collect();
461
462        Ok((expenses, total_items))
463    }
464
465    async fn update(&self, expense: &Expense) -> Result<Expense, String> {
466        let payment_status_str = match expense.payment_status {
467            PaymentStatus::Pending => "pending",
468            PaymentStatus::Paid => "paid",
469            PaymentStatus::Overdue => "overdue",
470            PaymentStatus::Cancelled => "cancelled",
471        };
472
473        let approval_status_str = match expense.approval_status {
474            ApprovalStatus::Draft => "draft",
475            ApprovalStatus::PendingApproval => "pending_approval",
476            ApprovalStatus::Approved => "approved",
477            ApprovalStatus::Rejected => "rejected",
478        };
479
480        sqlx::query(
481            r#"
482            UPDATE expenses
483            SET
484                payment_status = CAST($2 AS payment_status),
485                approval_status = CAST($3 AS approval_status),
486                submitted_at = $4,
487                approved_by = $5,
488                approved_at = $6,
489                rejection_reason = $7,
490                paid_date = $8,
491                updated_at = $9
492            WHERE id = $1
493            "#,
494        )
495        .bind(expense.id)
496        .bind(payment_status_str)
497        .bind(approval_status_str)
498        .bind(expense.submitted_at)
499        .bind(expense.approved_by)
500        .bind(expense.approved_at)
501        .bind(expense.rejection_reason.as_deref())
502        .bind(expense.paid_date)
503        .bind(expense.updated_at)
504        .execute(&self.pool)
505        .await
506        .map_err(|e| format!("Database error: {}", e))?;
507
508        Ok(expense.clone())
509    }
510
511    async fn delete(&self, id: Uuid) -> Result<bool, String> {
512        let result = sqlx::query("DELETE FROM expenses WHERE id = $1")
513            .bind(id)
514            .execute(&self.pool)
515            .await
516            .map_err(|e| format!("Database error: {}", e))?;
517
518        Ok(result.rows_affected() > 0)
519    }
520}