// koprogo_api/infrastructure/totp/totp_generator.rs
use std::collections::HashSet;

use aes_gcm::{
    aead::{Aead, KeyInit},
    Aes256Gcm, Nonce,
};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use image::Luma;
use qrcode::QrCode;
use rand::Rng;
use totp_lite::{totp_custom, Sha1};

11pub struct TotpGenerator;
24
25impl TotpGenerator {
26 pub fn generate_secret() -> String {
36 let mut rng = rand::rng();
37 let bytes: Vec<u8> = (0..32).map(|_| rng.random()).collect();
38 Self::base32_encode(&bytes)
39 }
40
41 pub fn generate_qr_code(
56 secret: &str,
57 issuer: &str,
58 account_name: &str,
59 ) -> Result<String, String> {
60 let uri = format!(
62 "otpauth://totp/{}:{}?secret={}&issuer={}",
63 urlencoding::encode(issuer),
64 urlencoding::encode(account_name),
65 secret,
66 urlencoding::encode(issuer)
67 );
68
69 let code = QrCode::new(uri.as_bytes())
71 .map_err(|e| format!("Failed to generate QR code: {}", e))?;
72
73 let image = code.render::<Luma<u8>>().build();
75 let mut png_bytes = Vec::new();
76 image
77 .write_to(
78 &mut std::io::Cursor::new(&mut png_bytes),
79 image::ImageFormat::Png,
80 )
81 .map_err(|e| format!("Failed to encode PNG: {}", e))?;
82
83 let base64_image = BASE64.encode(&png_bytes);
85 Ok(format!("data:image/png;base64,{}", base64_image))
86 }
87
88 pub fn verify_code(secret: &str, code: &str) -> Result<bool, String> {
107 if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
109 return Ok(false);
110 }
111
112 let secret_bytes = Self::base32_decode(secret)?;
114
115 let now = std::time::SystemTime::now()
117 .duration_since(std::time::UNIX_EPOCH)
118 .map_err(|e| format!("System time error: {}", e))?
119 .as_secs();
120
121 for time_offset in [-1, 0, 1] {
123 let time_step = (now as i64 + time_offset * 30) as u64;
124 let expected_code = totp_custom::<Sha1>(30, 6, &secret_bytes, time_step);
125
126 if code == expected_code {
127 return Ok(true);
128 }
129 }
130
131 Ok(false)
132 }
133
134 pub fn generate_backup_codes() -> Vec<String> {
149 let mut rng = rand::rng();
150 const CHARSET: &[u8] = b"23456789ABCDEFGHJKMNPQRSTUVWXYZ"; const CODE_LENGTH: usize = 8;
152
153 (0..10)
154 .map(|_| {
155 let code: String = (0..CODE_LENGTH)
156 .map(|_| {
157 let idx = rng.random_range(0..CHARSET.len());
158 CHARSET[idx] as char
159 })
160 .collect();
161
162 format!("{}-{}", &code[0..4], &code[4..8])
164 })
165 .collect()
166 }
167
168 pub fn hash_backup_code(code: &str) -> Result<String, String> {
179 bcrypt::hash(code, 12).map_err(|e| format!("Failed to hash backup code: {}", e))
180 }
181
182 pub fn verify_backup_code(code: &str, hash: &str) -> Result<bool, String> {
191 bcrypt::verify(code, hash).map_err(|e| format!("Failed to verify backup code: {}", e))
192 }
193
194 pub fn encrypt_secret(secret: &str, key: &[u8; 32]) -> Result<String, String> {
211 let cipher = Aes256Gcm::new(key.into());
212
213 let mut nonce_bytes = [0u8; 12];
215 rand::rng().fill(&mut nonce_bytes);
216 #[allow(deprecated)]
217 let nonce = Nonce::from_slice(&nonce_bytes);
218
219 let ciphertext = cipher
221 .encrypt(nonce, secret.as_bytes())
222 .map_err(|e| format!("Encryption failed: {}", e))?;
223
224 let mut encrypted = nonce_bytes.to_vec();
226 encrypted.extend_from_slice(&ciphertext);
227 Ok(BASE64.encode(&encrypted))
228 }
229
230 pub fn decrypt_secret(encrypted: &str, key: &[u8; 32]) -> Result<String, String> {
242 let cipher = Aes256Gcm::new(key.into());
243
244 let encrypted_bytes = BASE64
246 .decode(encrypted)
247 .map_err(|e| format!("Invalid Base64: {}", e))?;
248
249 if encrypted_bytes.len() < 12 {
250 return Err("Encrypted data too short".to_string());
251 }
252
253 let (nonce_bytes, ciphertext) = encrypted_bytes.split_at(12);
255 #[allow(deprecated)]
256 let nonce = Nonce::from_slice(nonce_bytes);
257
258 let plaintext = cipher
260 .decrypt(nonce, ciphertext)
261 .map_err(|e| format!("Decryption failed: {}", e))?;
262
263 String::from_utf8(plaintext).map_err(|e| format!("Invalid UTF-8: {}", e))
264 }
265
266 fn base32_encode(bytes: &[u8]) -> String {
272 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
273 let mut result = String::new();
274 let mut bits = 0u32;
275 let mut bit_count = 0;
276
277 for &byte in bytes {
278 bits = (bits << 8) | byte as u32;
279 bit_count += 8;
280
281 while bit_count >= 5 {
282 bit_count -= 5;
283 let index = ((bits >> bit_count) & 0x1F) as usize;
284 result.push(ALPHABET[index] as char);
285 }
286 }
287
288 if bit_count > 0 {
289 let index = ((bits << (5 - bit_count)) & 0x1F) as usize;
290 result.push(ALPHABET[index] as char);
291 }
292
293 result
294 }
295
296 fn base32_decode(encoded: &str) -> Result<Vec<u8>, String> {
298 let encoded = encoded.to_uppercase();
299 let mut result = Vec::new();
300 let mut bits = 0u32;
301 let mut bit_count = 0;
302
303 for ch in encoded.chars() {
304 if ch == '=' {
305 break; }
307
308 let value = match ch {
309 'A'..='Z' => (ch as u32) - ('A' as u32),
310 '2'..='7' => 26 + (ch as u32) - ('2' as u32),
311 _ => return Err(format!("Invalid Base32 character: {}", ch)),
312 };
313
314 bits = (bits << 5) | value;
315 bit_count += 5;
316
317 if bit_count >= 8 {
318 bit_count -= 8;
319 result.push((bits >> bit_count) as u8);
320 bits &= (1 << bit_count) - 1;
321 }
322 }
323
324 Ok(result)
325 }
326
327 #[cfg(test)]
329 pub fn generate_current_code(secret: &str) -> Result<String, String> {
330 let secret_bytes = Self::base32_decode(secret)?;
331 let now = std::time::SystemTime::now()
332 .duration_since(std::time::UNIX_EPOCH)
333 .unwrap()
334 .as_secs();
335 Ok(totp_custom::<Sha1>(30, 6, &secret_bytes, now))
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws a fresh random AES-256 key for the encryption tests.
    fn random_key() -> [u8; 32] {
        rand::rng().random()
    }

    #[test]
    fn test_generate_secret() {
        let secret = TotpGenerator::generate_secret();
        // 32 bytes = 256 bits -> ceil(256 / 5) = 52 Base32 symbols.
        assert_eq!(secret.len(), 52);

        let is_base32_symbol = |c: char| c.is_ascii_uppercase() || c.is_ascii_digit();
        assert!(secret.chars().all(is_base32_symbol));
    }

    #[test]
    fn test_verify_code_valid() {
        let secret = TotpGenerator::generate_secret();
        let current = TotpGenerator::generate_current_code(&secret).unwrap();
        let accepted = TotpGenerator::verify_code(&secret, &current).unwrap();
        assert!(accepted);
    }

    #[test]
    fn test_verify_code_invalid_format() {
        let secret = TotpGenerator::generate_secret();
        // Too short, too long, and non-numeric inputs are all rejected.
        for malformed in ["12345", "1234567", "ABCDEF"] {
            assert!(!TotpGenerator::verify_code(&secret, malformed).unwrap());
        }
    }

    #[test]
    fn test_generate_backup_codes() {
        let codes = TotpGenerator::generate_backup_codes();
        assert_eq!(codes.len(), 10);

        for code in &codes {
            // Shape: XXXX-XXXX
            assert_eq!(code.len(), 9);
            assert!(code.contains('-'));

            let halves: Vec<&str> = code.split('-').collect();
            assert_eq!(halves.len(), 2);
            assert_eq!(halves[0].len(), 4);
            assert_eq!(halves[1].len(), 4);

            // Visually ambiguous characters never appear.
            for banned in ['0', 'O', '1', 'I', 'L'] {
                assert!(!code.contains(banned));
            }
        }

        let distinct: std::collections::HashSet<_> = codes.iter().collect();
        assert_eq!(distinct.len(), 10);
    }

    #[test]
    fn test_hash_and_verify_backup_code() {
        let code = "ABCD-EFGH";
        let hash = TotpGenerator::hash_backup_code(code).unwrap();

        // bcrypt output in modular-crypt format is 60 characters long.
        assert_eq!(hash.len(), 60);
        assert!(TotpGenerator::verify_backup_code(code, &hash).unwrap());

        let other_code_matches = TotpGenerator::verify_backup_code("WXYZ-1234", &hash).unwrap();
        assert!(!other_code_matches);
    }

    #[test]
    fn test_encrypt_decrypt_secret() {
        let secret = TotpGenerator::generate_secret();
        let key = random_key();

        let encrypted = TotpGenerator::encrypt_secret(&secret, &key).unwrap();
        assert_ne!(encrypted, secret);

        let round_tripped = TotpGenerator::decrypt_secret(&encrypted, &key).unwrap();
        assert_eq!(round_tripped, secret);
    }

    #[test]
    fn test_decrypt_with_wrong_key() {
        let secret = TotpGenerator::generate_secret();
        let encrypt_key = random_key();
        let other_key = random_key();

        let encrypted = TotpGenerator::encrypt_secret(&secret, &encrypt_key).unwrap();
        assert!(TotpGenerator::decrypt_secret(&encrypted, &other_key).is_err());
    }

    #[test]
    fn test_base32_encode_decode() {
        // "Hello" in ASCII.
        let bytes = vec![0x48, 0x65, 0x6C, 0x6C, 0x6F];

        let encoded = TotpGenerator::base32_encode(&bytes);
        assert_eq!(encoded, "JBSWY3DP");

        let decoded = TotpGenerator::base32_decode(&encoded).unwrap();
        assert_eq!(decoded, bytes);
    }

    #[test]
    fn test_generate_qr_code() {
        let secret = TotpGenerator::generate_secret();
        let data_url =
            TotpGenerator::generate_qr_code(&secret, "KoproGo", "user@example.com").unwrap();

        assert!(data_url.starts_with("data:image/png;base64,"));
        assert!(data_url.len() > 100);
    }
}