// koprogo_api/infrastructure/database/repositories/service_provider_repository_impl.rs

use crate::application::ports::service_provider_repository::ServiceProviderRepository;
use crate::domain::entities::service_provider::{ServiceProvider, TradeCategory};
use crate::infrastructure::database::pool::DbPool;
use async_trait::async_trait;
use sqlx::Row;
use sqlx::postgres::PgRow;
use uuid::Uuid;

8pub struct PostgresServiceProviderRepository {
9    pool: DbPool,
10}
11
12impl PostgresServiceProviderRepository {
13    pub fn new(pool: DbPool) -> Self {
14        Self { pool }
15    }
16}
17
18#[async_trait]
19impl ServiceProviderRepository for PostgresServiceProviderRepository {
20    async fn create(&self, provider: &ServiceProvider) -> Result<ServiceProvider, String> {
21        sqlx::query(
22            r#"
23            INSERT INTO service_providers (
24                id, organization_id, company_name, trade_category,
25                specializations, service_zone_postal_codes, certifications,
26                ipi_registration, bce_number, rating_avg, reviews_count,
27                is_verified, public_profile_slug, created_at, updated_at
28            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
29            "#,
30        )
31        .bind(provider.id)
32        .bind(provider.organization_id)
33        .bind(&provider.company_name)
34        .bind(provider.trade_category.to_sql())
35        .bind(&provider.specializations)
36        .bind(&provider.service_zone_postal_codes)
37        .bind(&provider.certifications)
38        .bind(&provider.ipi_registration)
39        .bind(&provider.bce_number)
40        .bind(provider.rating_avg)
41        .bind(provider.reviews_count)
42        .bind(provider.is_verified)
43        .bind(&provider.public_profile_slug)
44        .bind(provider.created_at)
45        .bind(provider.updated_at)
46        .execute(&self.pool)
47        .await
48        .map_err(|e| format!("Database error creating service provider: {}", e))?;
49
50        Ok(provider.clone())
51    }
52
53    async fn find_by_id(&self, id: Uuid) -> Result<Option<ServiceProvider>, String> {
54        let row = sqlx::query(
55            r#"
56            SELECT id, organization_id, company_name, trade_category,
57                   specializations, service_zone_postal_codes, certifications,
58                   ipi_registration, bce_number, rating_avg, reviews_count,
59                   is_verified, public_profile_slug, created_at, updated_at
60            FROM service_providers
61            WHERE id = $1
62            "#,
63        )
64        .bind(id)
65        .fetch_optional(&self.pool)
66        .await
67        .map_err(|e| format!("Database error: {}", e))?;
68
69        Ok(row.map(|row| ServiceProvider {
70            id: row.get("id"),
71            organization_id: row.get("organization_id"),
72            company_name: row.get("company_name"),
73            trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
74                .unwrap_or(TradeCategory::Syndic),
75            specializations: row.get("specializations"),
76            service_zone_postal_codes: row.get("service_zone_postal_codes"),
77            certifications: row.get("certifications"),
78            ipi_registration: row.get("ipi_registration"),
79            bce_number: row.get("bce_number"),
80            rating_avg: row.get("rating_avg"),
81            reviews_count: row.get("reviews_count"),
82            is_verified: row.get("is_verified"),
83            public_profile_slug: row.get("public_profile_slug"),
84            created_at: row.get("created_at"),
85            updated_at: row.get("updated_at"),
86        }))
87    }
88
89    async fn find_by_slug(&self, slug: &str) -> Result<Option<ServiceProvider>, String> {
90        let row = sqlx::query(
91            r#"
92            SELECT id, organization_id, company_name, trade_category,
93                   specializations, service_zone_postal_codes, certifications,
94                   ipi_registration, bce_number, rating_avg, reviews_count,
95                   is_verified, public_profile_slug, created_at, updated_at
96            FROM service_providers
97            WHERE public_profile_slug = $1
98            "#,
99        )
100        .bind(slug)
101        .fetch_optional(&self.pool)
102        .await
103        .map_err(|e| format!("Database error: {}", e))?;
104
105        Ok(row.map(|row| ServiceProvider {
106            id: row.get("id"),
107            organization_id: row.get("organization_id"),
108            company_name: row.get("company_name"),
109            trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
110                .unwrap_or(TradeCategory::Syndic),
111            specializations: row.get("specializations"),
112            service_zone_postal_codes: row.get("service_zone_postal_codes"),
113            certifications: row.get("certifications"),
114            ipi_registration: row.get("ipi_registration"),
115            bce_number: row.get("bce_number"),
116            rating_avg: row.get("rating_avg"),
117            reviews_count: row.get("reviews_count"),
118            is_verified: row.get("is_verified"),
119            public_profile_slug: row.get("public_profile_slug"),
120            created_at: row.get("created_at"),
121            updated_at: row.get("updated_at"),
122        }))
123    }
124
125    async fn find_all(
126        &self,
127        organization_id: Option<Uuid>,
128        page: i64,
129        per_page: i64,
130    ) -> Result<Vec<ServiceProvider>, String> {
131        if page < 1 || per_page < 1 {
132            return Err("Page and per_page must be >= 1".to_string());
133        }
134
135        let offset = (page - 1) * per_page;
136
137        let query = if let Some(org_id) = organization_id {
138            sqlx::query(
139                r#"
140                SELECT id, organization_id, company_name, trade_category,
141                       specializations, service_zone_postal_codes, certifications,
142                       ipi_registration, bce_number, rating_avg, reviews_count,
143                       is_verified, public_profile_slug, created_at, updated_at
144                FROM service_providers
145                WHERE organization_id = $1
146                ORDER BY created_at DESC
147                LIMIT $2 OFFSET $3
148                "#,
149            )
150            .bind(org_id)
151            .bind(per_page)
152            .bind(offset)
153            .fetch_all(&self.pool)
154            .await
155        } else {
156            sqlx::query(
157                r#"
158                SELECT id, organization_id, company_name, trade_category,
159                       specializations, service_zone_postal_codes, certifications,
160                       ipi_registration, bce_number, rating_avg, reviews_count,
161                       is_verified, public_profile_slug, created_at, updated_at
162                FROM service_providers
163                ORDER BY created_at DESC
164                LIMIT $1 OFFSET $2
165                "#,
166            )
167            .bind(per_page)
168            .bind(offset)
169            .fetch_all(&self.pool)
170            .await
171        };
172
173        let rows = query.map_err(|e| format!("Database error: {}", e))?;
174
175        Ok(rows
176            .iter()
177            .map(|row| ServiceProvider {
178                id: row.get("id"),
179                organization_id: row.get("organization_id"),
180                company_name: row.get("company_name"),
181                trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
182                    .unwrap_or(TradeCategory::Syndic),
183                specializations: row.get("specializations"),
184                service_zone_postal_codes: row.get("service_zone_postal_codes"),
185                certifications: row.get("certifications"),
186                ipi_registration: row.get("ipi_registration"),
187                bce_number: row.get("bce_number"),
188                rating_avg: row.get("rating_avg"),
189                reviews_count: row.get("reviews_count"),
190                is_verified: row.get("is_verified"),
191                public_profile_slug: row.get("public_profile_slug"),
192                created_at: row.get("created_at"),
193                updated_at: row.get("updated_at"),
194            })
195            .collect())
196    }
197
198    async fn find_by_trade_category(
199        &self,
200        category: TradeCategory,
201        page: i64,
202        per_page: i64,
203    ) -> Result<Vec<ServiceProvider>, String> {
204        if page < 1 || per_page < 1 {
205            return Err("Page and per_page must be >= 1".to_string());
206        }
207
208        let offset = (page - 1) * per_page;
209        let category_str = category.to_sql();
210
211        let rows = sqlx::query(
212            r#"
213            SELECT id, organization_id, company_name, trade_category,
214                   specializations, service_zone_postal_codes, certifications,
215                   ipi_registration, bce_number, rating_avg, reviews_count,
216                   is_verified, public_profile_slug, created_at, updated_at
217            FROM service_providers
218            WHERE trade_category = $1
219            ORDER BY rating_avg DESC NULLS LAST, created_at DESC
220            LIMIT $2 OFFSET $3
221            "#,
222        )
223        .bind(category_str)
224        .bind(per_page)
225        .bind(offset)
226        .fetch_all(&self.pool)
227        .await
228        .map_err(|e| format!("Database error: {}", e))?;
229
230        Ok(rows
231            .iter()
232            .map(|row| ServiceProvider {
233                id: row.get("id"),
234                organization_id: row.get("organization_id"),
235                company_name: row.get("company_name"),
236                trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
237                    .unwrap_or(TradeCategory::Syndic),
238                specializations: row.get("specializations"),
239                service_zone_postal_codes: row.get("service_zone_postal_codes"),
240                certifications: row.get("certifications"),
241                ipi_registration: row.get("ipi_registration"),
242                bce_number: row.get("bce_number"),
243                rating_avg: row.get("rating_avg"),
244                reviews_count: row.get("reviews_count"),
245                is_verified: row.get("is_verified"),
246                public_profile_slug: row.get("public_profile_slug"),
247                created_at: row.get("created_at"),
248                updated_at: row.get("updated_at"),
249            })
250            .collect())
251    }
252
253    async fn search(
254        &self,
255        query: &str,
256        postal_code: Option<&str>,
257        page: i64,
258        per_page: i64,
259    ) -> Result<Vec<ServiceProvider>, String> {
260        if page < 1 || per_page < 1 {
261            return Err("Page and per_page must be >= 1".to_string());
262        }
263
264        let offset = (page - 1) * per_page;
265        let search_query = format!("%{}%", query);
266
267        let rows = if let Some(postal) = postal_code {
268            sqlx::query(
269                r#"
270                SELECT id, organization_id, company_name, trade_category,
271                       specializations, service_zone_postal_codes, certifications,
272                       ipi_registration, bce_number, rating_avg, reviews_count,
273                       is_verified, public_profile_slug, created_at, updated_at
274                FROM service_providers
275                WHERE (company_name ILIKE $1 OR specializations::text ILIKE $1)
276                  AND service_zone_postal_codes @> ARRAY[$2]
277                ORDER BY rating_avg DESC NULLS LAST, created_at DESC
278                LIMIT $3 OFFSET $4
279                "#,
280            )
281            .bind(&search_query)
282            .bind(postal)
283            .bind(per_page)
284            .bind(offset)
285            .fetch_all(&self.pool)
286            .await
287        } else {
288            sqlx::query(
289                r#"
290                SELECT id, organization_id, company_name, trade_category,
291                       specializations, service_zone_postal_codes, certifications,
292                       ipi_registration, bce_number, rating_avg, reviews_count,
293                       is_verified, public_profile_slug, created_at, updated_at
294                FROM service_providers
295                WHERE company_name ILIKE $1 OR specializations::text ILIKE $1
296                ORDER BY rating_avg DESC NULLS LAST, created_at DESC
297                LIMIT $2 OFFSET $3
298                "#,
299            )
300            .bind(&search_query)
301            .bind(per_page)
302            .bind(offset)
303            .fetch_all(&self.pool)
304            .await
305        };
306
307        let rows = rows.map_err(|e| format!("Database error: {}", e))?;
308
309        Ok(rows
310            .iter()
311            .map(|row| ServiceProvider {
312                id: row.get("id"),
313                organization_id: row.get("organization_id"),
314                company_name: row.get("company_name"),
315                trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
316                    .unwrap_or(TradeCategory::Syndic),
317                specializations: row.get("specializations"),
318                service_zone_postal_codes: row.get("service_zone_postal_codes"),
319                certifications: row.get("certifications"),
320                ipi_registration: row.get("ipi_registration"),
321                bce_number: row.get("bce_number"),
322                rating_avg: row.get("rating_avg"),
323                reviews_count: row.get("reviews_count"),
324                is_verified: row.get("is_verified"),
325                public_profile_slug: row.get("public_profile_slug"),
326                created_at: row.get("created_at"),
327                updated_at: row.get("updated_at"),
328            })
329            .collect())
330    }
331
332    async fn update(&self, provider: &ServiceProvider) -> Result<ServiceProvider, String> {
333        sqlx::query(
334            r#"
335            UPDATE service_providers
336            SET company_name = $1,
337                specializations = $2,
338                service_zone_postal_codes = $3,
339                certifications = $4,
340                ipi_registration = $5,
341                bce_number = $6,
342                rating_avg = $7,
343                reviews_count = $8,
344                is_verified = $9,
345                updated_at = $10
346            WHERE id = $11
347            "#,
348        )
349        .bind(&provider.company_name)
350        .bind(&provider.specializations)
351        .bind(&provider.service_zone_postal_codes)
352        .bind(&provider.certifications)
353        .bind(&provider.ipi_registration)
354        .bind(&provider.bce_number)
355        .bind(provider.rating_avg)
356        .bind(provider.reviews_count)
357        .bind(provider.is_verified)
358        .bind(provider.updated_at)
359        .bind(provider.id)
360        .execute(&self.pool)
361        .await
362        .map_err(|e| format!("Database error updating service provider: {}", e))?;
363
364        Ok(provider.clone())
365    }
366
367    async fn delete(&self, id: Uuid) -> Result<(), String> {
368        sqlx::query("DELETE FROM service_providers WHERE id = $1")
369            .bind(id)
370            .execute(&self.pool)
371            .await
372            .map_err(|e| format!("Database error deleting service provider: {}", e))?;
373
374        Ok(())
375    }
376
377    async fn update_rating(
378        &self,
379        id: Uuid,
380        rating_avg: f64,
381        reviews_count: i32,
382    ) -> Result<(), String> {
383        sqlx::query(
384            r#"
385            UPDATE service_providers
386            SET rating_avg = $1, reviews_count = $2, updated_at = NOW()
387            WHERE id = $3
388            "#,
389        )
390        .bind(rating_avg)
391        .bind(reviews_count)
392        .bind(id)
393        .execute(&self.pool)
394        .await
395        .map_err(|e| format!("Database error updating rating: {}", e))?;
396
397        Ok(())
398    }
399
400    async fn count_by_organization(&self, organization_id: Uuid) -> Result<i64, String> {
401        let row = sqlx::query(
402            "SELECT COUNT(*) as count FROM service_providers WHERE organization_id = $1",
403        )
404        .bind(organization_id)
405        .fetch_one(&self.pool)
406        .await
407        .map_err(|e| format!("Database error: {}", e))?;
408
409        Ok(row.get::<i64, _>("count"))
410    }
411}