1use crate::application::ports::service_provider_repository::ServiceProviderRepository;
2use crate::domain::entities::service_provider::{ServiceProvider, TradeCategory};
3use crate::infrastructure::database::pool::DbPool;
4use async_trait::async_trait;
5use sqlx::Row;
6use uuid::Uuid;
7
8pub struct PostgresServiceProviderRepository {
9 pool: DbPool,
10}
11
12impl PostgresServiceProviderRepository {
13 pub fn new(pool: DbPool) -> Self {
14 Self { pool }
15 }
16}
17
18#[async_trait]
19impl ServiceProviderRepository for PostgresServiceProviderRepository {
20 async fn create(&self, provider: &ServiceProvider) -> Result<ServiceProvider, String> {
21 sqlx::query(
22 r#"
23 INSERT INTO service_providers (
24 id, organization_id, company_name, trade_category,
25 specializations, service_zone_postal_codes, certifications,
26 ipi_registration, bce_number, rating_avg, reviews_count,
27 is_verified, public_profile_slug, created_at, updated_at
28 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
29 "#,
30 )
31 .bind(provider.id)
32 .bind(provider.organization_id)
33 .bind(&provider.company_name)
34 .bind(provider.trade_category.to_sql())
35 .bind(&provider.specializations)
36 .bind(&provider.service_zone_postal_codes)
37 .bind(&provider.certifications)
38 .bind(&provider.ipi_registration)
39 .bind(&provider.bce_number)
40 .bind(provider.rating_avg)
41 .bind(provider.reviews_count)
42 .bind(provider.is_verified)
43 .bind(&provider.public_profile_slug)
44 .bind(provider.created_at)
45 .bind(provider.updated_at)
46 .execute(&self.pool)
47 .await
48 .map_err(|e| format!("Database error creating service provider: {}", e))?;
49
50 Ok(provider.clone())
51 }
52
53 async fn find_by_id(&self, id: Uuid) -> Result<Option<ServiceProvider>, String> {
54 let row = sqlx::query(
55 r#"
56 SELECT id, organization_id, company_name, trade_category,
57 specializations, service_zone_postal_codes, certifications,
58 ipi_registration, bce_number, rating_avg, reviews_count,
59 is_verified, public_profile_slug, created_at, updated_at
60 FROM service_providers
61 WHERE id = $1
62 "#,
63 )
64 .bind(id)
65 .fetch_optional(&self.pool)
66 .await
67 .map_err(|e| format!("Database error: {}", e))?;
68
69 Ok(row.map(|row| ServiceProvider {
70 id: row.get("id"),
71 organization_id: row.get("organization_id"),
72 company_name: row.get("company_name"),
73 trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
74 .unwrap_or(TradeCategory::Syndic),
75 specializations: row.get("specializations"),
76 service_zone_postal_codes: row.get("service_zone_postal_codes"),
77 certifications: row.get("certifications"),
78 ipi_registration: row.get("ipi_registration"),
79 bce_number: row.get("bce_number"),
80 rating_avg: row.get("rating_avg"),
81 reviews_count: row.get("reviews_count"),
82 is_verified: row.get("is_verified"),
83 public_profile_slug: row.get("public_profile_slug"),
84 created_at: row.get("created_at"),
85 updated_at: row.get("updated_at"),
86 }))
87 }
88
89 async fn find_by_slug(&self, slug: &str) -> Result<Option<ServiceProvider>, String> {
90 let row = sqlx::query(
91 r#"
92 SELECT id, organization_id, company_name, trade_category,
93 specializations, service_zone_postal_codes, certifications,
94 ipi_registration, bce_number, rating_avg, reviews_count,
95 is_verified, public_profile_slug, created_at, updated_at
96 FROM service_providers
97 WHERE public_profile_slug = $1
98 "#,
99 )
100 .bind(slug)
101 .fetch_optional(&self.pool)
102 .await
103 .map_err(|e| format!("Database error: {}", e))?;
104
105 Ok(row.map(|row| ServiceProvider {
106 id: row.get("id"),
107 organization_id: row.get("organization_id"),
108 company_name: row.get("company_name"),
109 trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
110 .unwrap_or(TradeCategory::Syndic),
111 specializations: row.get("specializations"),
112 service_zone_postal_codes: row.get("service_zone_postal_codes"),
113 certifications: row.get("certifications"),
114 ipi_registration: row.get("ipi_registration"),
115 bce_number: row.get("bce_number"),
116 rating_avg: row.get("rating_avg"),
117 reviews_count: row.get("reviews_count"),
118 is_verified: row.get("is_verified"),
119 public_profile_slug: row.get("public_profile_slug"),
120 created_at: row.get("created_at"),
121 updated_at: row.get("updated_at"),
122 }))
123 }
124
125 async fn find_all(
126 &self,
127 organization_id: Option<Uuid>,
128 page: i64,
129 per_page: i64,
130 ) -> Result<Vec<ServiceProvider>, String> {
131 if page < 1 || per_page < 1 {
132 return Err("Page and per_page must be >= 1".to_string());
133 }
134
135 let offset = (page - 1) * per_page;
136
137 let query = if let Some(org_id) = organization_id {
138 sqlx::query(
139 r#"
140 SELECT id, organization_id, company_name, trade_category,
141 specializations, service_zone_postal_codes, certifications,
142 ipi_registration, bce_number, rating_avg, reviews_count,
143 is_verified, public_profile_slug, created_at, updated_at
144 FROM service_providers
145 WHERE organization_id = $1
146 ORDER BY created_at DESC
147 LIMIT $2 OFFSET $3
148 "#,
149 )
150 .bind(org_id)
151 .bind(per_page)
152 .bind(offset)
153 .fetch_all(&self.pool)
154 .await
155 } else {
156 sqlx::query(
157 r#"
158 SELECT id, organization_id, company_name, trade_category,
159 specializations, service_zone_postal_codes, certifications,
160 ipi_registration, bce_number, rating_avg, reviews_count,
161 is_verified, public_profile_slug, created_at, updated_at
162 FROM service_providers
163 ORDER BY created_at DESC
164 LIMIT $1 OFFSET $2
165 "#,
166 )
167 .bind(per_page)
168 .bind(offset)
169 .fetch_all(&self.pool)
170 .await
171 };
172
173 let rows = query.map_err(|e| format!("Database error: {}", e))?;
174
175 Ok(rows
176 .iter()
177 .map(|row| ServiceProvider {
178 id: row.get("id"),
179 organization_id: row.get("organization_id"),
180 company_name: row.get("company_name"),
181 trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
182 .unwrap_or(TradeCategory::Syndic),
183 specializations: row.get("specializations"),
184 service_zone_postal_codes: row.get("service_zone_postal_codes"),
185 certifications: row.get("certifications"),
186 ipi_registration: row.get("ipi_registration"),
187 bce_number: row.get("bce_number"),
188 rating_avg: row.get("rating_avg"),
189 reviews_count: row.get("reviews_count"),
190 is_verified: row.get("is_verified"),
191 public_profile_slug: row.get("public_profile_slug"),
192 created_at: row.get("created_at"),
193 updated_at: row.get("updated_at"),
194 })
195 .collect())
196 }
197
198 async fn find_by_trade_category(
199 &self,
200 category: TradeCategory,
201 page: i64,
202 per_page: i64,
203 ) -> Result<Vec<ServiceProvider>, String> {
204 if page < 1 || per_page < 1 {
205 return Err("Page and per_page must be >= 1".to_string());
206 }
207
208 let offset = (page - 1) * per_page;
209 let category_str = category.to_sql();
210
211 let rows = sqlx::query(
212 r#"
213 SELECT id, organization_id, company_name, trade_category,
214 specializations, service_zone_postal_codes, certifications,
215 ipi_registration, bce_number, rating_avg, reviews_count,
216 is_verified, public_profile_slug, created_at, updated_at
217 FROM service_providers
218 WHERE trade_category = $1
219 ORDER BY rating_avg DESC NULLS LAST, created_at DESC
220 LIMIT $2 OFFSET $3
221 "#,
222 )
223 .bind(category_str)
224 .bind(per_page)
225 .bind(offset)
226 .fetch_all(&self.pool)
227 .await
228 .map_err(|e| format!("Database error: {}", e))?;
229
230 Ok(rows
231 .iter()
232 .map(|row| ServiceProvider {
233 id: row.get("id"),
234 organization_id: row.get("organization_id"),
235 company_name: row.get("company_name"),
236 trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
237 .unwrap_or(TradeCategory::Syndic),
238 specializations: row.get("specializations"),
239 service_zone_postal_codes: row.get("service_zone_postal_codes"),
240 certifications: row.get("certifications"),
241 ipi_registration: row.get("ipi_registration"),
242 bce_number: row.get("bce_number"),
243 rating_avg: row.get("rating_avg"),
244 reviews_count: row.get("reviews_count"),
245 is_verified: row.get("is_verified"),
246 public_profile_slug: row.get("public_profile_slug"),
247 created_at: row.get("created_at"),
248 updated_at: row.get("updated_at"),
249 })
250 .collect())
251 }
252
253 async fn search(
254 &self,
255 query: &str,
256 postal_code: Option<&str>,
257 page: i64,
258 per_page: i64,
259 ) -> Result<Vec<ServiceProvider>, String> {
260 if page < 1 || per_page < 1 {
261 return Err("Page and per_page must be >= 1".to_string());
262 }
263
264 let offset = (page - 1) * per_page;
265 let search_query = format!("%{}%", query);
266
267 let rows = if let Some(postal) = postal_code {
268 sqlx::query(
269 r#"
270 SELECT id, organization_id, company_name, trade_category,
271 specializations, service_zone_postal_codes, certifications,
272 ipi_registration, bce_number, rating_avg, reviews_count,
273 is_verified, public_profile_slug, created_at, updated_at
274 FROM service_providers
275 WHERE (company_name ILIKE $1 OR specializations::text ILIKE $1)
276 AND service_zone_postal_codes @> ARRAY[$2]
277 ORDER BY rating_avg DESC NULLS LAST, created_at DESC
278 LIMIT $3 OFFSET $4
279 "#,
280 )
281 .bind(&search_query)
282 .bind(postal)
283 .bind(per_page)
284 .bind(offset)
285 .fetch_all(&self.pool)
286 .await
287 } else {
288 sqlx::query(
289 r#"
290 SELECT id, organization_id, company_name, trade_category,
291 specializations, service_zone_postal_codes, certifications,
292 ipi_registration, bce_number, rating_avg, reviews_count,
293 is_verified, public_profile_slug, created_at, updated_at
294 FROM service_providers
295 WHERE company_name ILIKE $1 OR specializations::text ILIKE $1
296 ORDER BY rating_avg DESC NULLS LAST, created_at DESC
297 LIMIT $2 OFFSET $3
298 "#,
299 )
300 .bind(&search_query)
301 .bind(per_page)
302 .bind(offset)
303 .fetch_all(&self.pool)
304 .await
305 };
306
307 let rows = rows.map_err(|e| format!("Database error: {}", e))?;
308
309 Ok(rows
310 .iter()
311 .map(|row| ServiceProvider {
312 id: row.get("id"),
313 organization_id: row.get("organization_id"),
314 company_name: row.get("company_name"),
315 trade_category: TradeCategory::from_sql(&row.get::<String, _>("trade_category"))
316 .unwrap_or(TradeCategory::Syndic),
317 specializations: row.get("specializations"),
318 service_zone_postal_codes: row.get("service_zone_postal_codes"),
319 certifications: row.get("certifications"),
320 ipi_registration: row.get("ipi_registration"),
321 bce_number: row.get("bce_number"),
322 rating_avg: row.get("rating_avg"),
323 reviews_count: row.get("reviews_count"),
324 is_verified: row.get("is_verified"),
325 public_profile_slug: row.get("public_profile_slug"),
326 created_at: row.get("created_at"),
327 updated_at: row.get("updated_at"),
328 })
329 .collect())
330 }
331
332 async fn update(&self, provider: &ServiceProvider) -> Result<ServiceProvider, String> {
333 sqlx::query(
334 r#"
335 UPDATE service_providers
336 SET company_name = $1,
337 specializations = $2,
338 service_zone_postal_codes = $3,
339 certifications = $4,
340 ipi_registration = $5,
341 bce_number = $6,
342 rating_avg = $7,
343 reviews_count = $8,
344 is_verified = $9,
345 updated_at = $10
346 WHERE id = $11
347 "#,
348 )
349 .bind(&provider.company_name)
350 .bind(&provider.specializations)
351 .bind(&provider.service_zone_postal_codes)
352 .bind(&provider.certifications)
353 .bind(&provider.ipi_registration)
354 .bind(&provider.bce_number)
355 .bind(provider.rating_avg)
356 .bind(provider.reviews_count)
357 .bind(provider.is_verified)
358 .bind(provider.updated_at)
359 .bind(provider.id)
360 .execute(&self.pool)
361 .await
362 .map_err(|e| format!("Database error updating service provider: {}", e))?;
363
364 Ok(provider.clone())
365 }
366
367 async fn delete(&self, id: Uuid) -> Result<(), String> {
368 sqlx::query("DELETE FROM service_providers WHERE id = $1")
369 .bind(id)
370 .execute(&self.pool)
371 .await
372 .map_err(|e| format!("Database error deleting service provider: {}", e))?;
373
374 Ok(())
375 }
376
377 async fn update_rating(
378 &self,
379 id: Uuid,
380 rating_avg: f64,
381 reviews_count: i32,
382 ) -> Result<(), String> {
383 sqlx::query(
384 r#"
385 UPDATE service_providers
386 SET rating_avg = $1, reviews_count = $2, updated_at = NOW()
387 WHERE id = $3
388 "#,
389 )
390 .bind(rating_avg)
391 .bind(reviews_count)
392 .bind(id)
393 .execute(&self.pool)
394 .await
395 .map_err(|e| format!("Database error updating rating: {}", e))?;
396
397 Ok(())
398 }
399
400 async fn count_by_organization(&self, organization_id: Uuid) -> Result<i64, String> {
401 let row = sqlx::query(
402 "SELECT COUNT(*) as count FROM service_providers WHERE organization_id = $1",
403 )
404 .bind(organization_id)
405 .fetch_one(&self.pool)
406 .await
407 .map_err(|e| format!("Database error: {}", e))?;
408
409 Ok(row.get::<i64, _>("count"))
410 }
411}