1use std::{error::Error, fmt};
2
3use cast;
4use rust_decimal::RoundingStrategy::MidpointAwayFromZero;
5
6use super::{TimeUnit, Value, ValueRef};
7
/// Reasons a DuckDB value could not be converted into the requested Rust type.
#[derive(Debug)]
#[non_exhaustive]
pub enum FromSqlError {
    /// The value cannot be converted to the requested Rust type (wrong kind of
    /// value, or text that does not parse).
    InvalidType,

    /// The value is numeric but outside the range of the requested Rust type.
    /// The payload is the offending value widened to `i128`.
    OutOfRange(i128),

    /// A UUID could not be built from the value. The payload is the length of
    /// the rejected text or blob.
    #[cfg(feature = "uuid")]
    InvalidUuidSize(usize),

    /// Any other boxed error.
    Other(Box<dyn Error + Send + Sync>),
}
28
29impl PartialEq for FromSqlError {
30 fn eq(&self, other: &Self) -> bool {
31 match (self, other) {
32 (Self::InvalidType, Self::InvalidType) => true,
33 (Self::OutOfRange(n1), Self::OutOfRange(n2)) => n1 == n2,
34 #[cfg(feature = "uuid")]
35 (Self::InvalidUuidSize(s1), Self::InvalidUuidSize(s2)) => s1 == s2,
36 (..) => false,
37 }
38 }
39}
40
41impl fmt::Display for FromSqlError {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match *self {
44 Self::InvalidType => write!(f, "Invalid type"),
45 Self::OutOfRange(i) => write!(f, "Value {i} out of range"),
46 #[cfg(feature = "uuid")]
47 Self::InvalidUuidSize(s) => {
48 write!(f, "Cannot read UUID value out of {s} byte blob")
49 }
50 Self::Other(ref err) => err.fmt(f),
51 }
52 }
53}
54
55impl Error for FromSqlError {
56 fn source(&self) -> Option<&(dyn Error + 'static)> {
57 if let Self::Other(ref err) = self {
58 Some(&**err)
59 } else {
60 None
61 }
62 }
63}
64
65pub type FromSqlResult<T> = Result<T, FromSqlError>;
67
68pub trait FromSql: Sized {
70 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self>;
72}
73
/// Implements [`FromSql`] for a primitive numeric type.
///
/// Integer, floating-point, raw temporal (`Timestamp`, `Date32`, microsecond
/// `Time64`) and numeric `Text` values are converted with the `cast` crate.
/// A value the target type cannot hold yields [`FromSqlError::OutOfRange`]
/// carrying the source value as an `i128`. Any other value yields
/// [`FromSqlError::InvalidType`].
///
/// `DECIMAL` handling depends on the target type:
/// * integer targets round to the nearest integer, ties away from zero
///   (`2.5 -> 3`, `-2.5 -> -3`). When the rounded integer does not fit,
///   `OutOfRange` reports that rounded integer, not the scaled mantissa.
/// * `f32` / `f64` keep the fractional part (`1.25 -> 1.25`) instead of being
///   rounded to a whole number first.
///
/// The literal `f32` / `f64` arms select the float behaviour, and any other
/// identifier gets the integer behaviour. Both delegate to the internal
/// `@impl` arm, which binds the decimal payload to `$d` and evaluates
/// `$from_decimal` to convert it.
macro_rules! from_sql_integral(
    (@impl $t:ident, $d:ident => $from_decimal:expr) => (
        impl FromSql for $t {
            #[inline]
            fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
                match value {
                    ValueRef::TinyInt(i) => <$t as cast::From<i8>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::SmallInt(i) => <$t as cast::From<i16>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Int(i) => <$t as cast::From<i32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::BigInt(i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::HugeInt(i) => <$t as cast::From<i128>>::cast(i).into_result(FromSqlError::OutOfRange(i)),

                    ValueRef::UTinyInt(i) => <$t as cast::From<u8>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::USmallInt(i) => <$t as cast::From<u16>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::UInt(i) => <$t as cast::From<u32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::UBigInt(i) => <$t as cast::From<u64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),

                    ValueRef::Float(i) => <$t as cast::From<f32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Double(i) => <$t as cast::From<f64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),

                    // The conversion is chosen per target type by the public arms below.
                    ValueRef::Decimal($d) => $from_decimal,

                    ValueRef::Timestamp(_, i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Date32(i) => <$t as cast::From<i32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Time64(TimeUnit::Microsecond, i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Text(_) => {
                        let s = value.as_str()?;
                        // Text that does not parse as the target but is still an integer is
                        // reported as out of range. Anything else is not a number at all.
                        s.parse::<$t>().or_else(|_| {
                            s.parse::<i128>()
                                .map_err(|_| FromSqlError::InvalidType)
                                .and_then(|i| Err(FromSqlError::OutOfRange(i)))
                        })
                    }
                    _ => Err(FromSqlError::InvalidType),
                }
            }
        }
    );
    // Floats keep the fractional part of a DECIMAL. The error branch is a
    // fallback in case `rust_decimal` cannot produce a float; it then reports
    // the rounded integer value.
    (f32) => (
        from_sql_integral!(@impl f32, d => rust_decimal::prelude::ToPrimitive::to_f32(&d)
            .ok_or_else(|| FromSqlError::OutOfRange(
                d.round_dp_with_strategy(0, MidpointAwayFromZero).mantissa(),
            )));
    );
    (f64) => (
        from_sql_integral!(@impl f64, d => rust_decimal::prelude::ToPrimitive::to_f64(&d)
            .ok_or_else(|| FromSqlError::OutOfRange(
                d.round_dp_with_strategy(0, MidpointAwayFromZero).mantissa(),
            )));
    );
    // Integer targets round half away from zero. After rounding to 0 decimal
    // places the scale is 0, so the mantissa *is* the integer value. That
    // value is what gets cast and what gets reported on overflow.
    ($t:ident) => (
        from_sql_integral!(@impl $t, d => {
            let rounded = d.round_dp_with_strategy(0, MidpointAwayFromZero).mantissa();
            <$t as cast::From<i128>>::cast(rounded).into_result(FromSqlError::OutOfRange(rounded))
        });
    );
);
117
/// Gives conversion outputs a single `Result` shape.
///
/// The `cast::From::cast` calls in `from_sql_integral!` yield either a bare
/// value or a `Result`. This module implements the trait for both shapes, so
/// every call site can end with `.into_result(err)`.
trait IntoResult {
    /// The type produced by a successful conversion.
    type Value;

    /// Returns `Ok` with the converted value, or `Err(on_error)` if `self`
    /// represents a failed conversion.
    fn into_result<E>(self, on_error: E) -> Result<Self::Value, E>;
}
124
/// Implements [`IntoResult`] for a type that a conversion returns directly.
///
/// Such a value always becomes `Ok(value)`; the supplied error is discarded.
macro_rules! into_result_integral(
    ($ty:ident) => (
        impl IntoResult for $ty {
            type Value = Self;

            #[inline]
            fn into_result<E>(self, _unused: E) -> Result<Self::Value, E> {
                Ok(self)
            }
        }
    )
);
138
139into_result_integral!(i8);
140into_result_integral!(i16);
141into_result_integral!(i32);
142into_result_integral!(i64);
143into_result_integral!(i128);
144into_result_integral!(isize);
145into_result_integral!(u8);
146into_result_integral!(u16);
147into_result_integral!(u32);
148into_result_integral!(u64);
149into_result_integral!(usize);
150into_result_integral!(f32);
151into_result_integral!(f64);
152
153impl<T, E> IntoResult for Result<T, E> {
154 type Value = T;
155
156 #[inline]
157 fn into_result<E2>(self, err: E2) -> Result<Self::Value, E2> {
158 self.map_err(|_| err)
159 }
160}
161
162from_sql_integral!(i8);
163from_sql_integral!(i16);
164from_sql_integral!(i32);
165from_sql_integral!(i64);
166from_sql_integral!(i128);
167from_sql_integral!(isize);
168from_sql_integral!(u8);
169from_sql_integral!(u16);
170from_sql_integral!(u32);
171from_sql_integral!(u64);
172from_sql_integral!(usize);
173from_sql_integral!(f32);
174from_sql_integral!(f64);
175
176impl FromSql for bool {
177 #[inline]
178 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
179 match value {
180 ValueRef::Boolean(b) => Ok(b),
181 _ => i8::column_result(value).map(|i| i != 0),
182 }
183 }
184}
185
186impl FromSql for String {
187 #[inline]
188 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
189 match value {
190 #[cfg(feature = "chrono")]
191 ValueRef::Date32(_) => Ok(chrono::NaiveDate::column_result(value)?.format("%F").to_string()),
192 #[cfg(feature = "chrono")]
193 ValueRef::Time64(..) => Ok(chrono::NaiveTime::column_result(value)?.format("%T%.f").to_string()),
194 #[cfg(feature = "chrono")]
195 ValueRef::Timestamp(..) => Ok(chrono::NaiveDateTime::column_result(value)?
196 .format("%F %T%.f")
197 .to_string()),
198 _ => value.as_str().map(ToString::to_string),
199 }
200 }
201}
202
203impl FromSql for Box<str> {
204 #[inline]
205 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
206 value.as_str().map(Into::into)
207 }
208}
209
210impl FromSql for std::rc::Rc<str> {
211 #[inline]
212 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
213 value.as_str().map(Into::into)
214 }
215}
216
217impl FromSql for std::sync::Arc<str> {
218 #[inline]
219 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
220 value.as_str().map(Into::into)
221 }
222}
223
224impl FromSql for Vec<u8> {
225 #[inline]
226 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
227 value.as_blob().map(|b| b.to_vec())
228 }
229}
230
/// Reads a UUID from text (parsed with `Uuid::parse_str`) or from a blob (via
/// `uuid::Builder::from_slice`). Any other value is `InvalidType`.
///
/// NOTE(review): text that fails to parse is reported as `InvalidUuidSize`
/// with the text length, even when the length is not what is wrong.
#[cfg(feature = "uuid")]
impl FromSql for uuid::Uuid {
    #[inline]
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        match value {
            ValueRef::Text(..) => {
                let text = value.as_str()?;
                Self::parse_str(text).map_err(|_| FromSqlError::InvalidUuidSize(text.len()))
            }
            ValueRef::Blob(..) => {
                let bytes = value.as_blob()?;
                let builder = uuid::Builder::from_slice(bytes)
                    .map_err(|_| FromSqlError::InvalidUuidSize(bytes.len()))?;
                Ok(builder.into_uuid())
            }
            _ => Err(FromSqlError::InvalidType),
        }
    }
}
249
250impl<T: FromSql> FromSql for Option<T> {
251 #[inline]
252 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
253 match value {
254 ValueRef::Null => Ok(None),
255 _ => FromSql::column_result(value).map(Some),
256 }
257 }
258}
259
260impl FromSql for Value {
261 #[inline]
262 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
263 Ok(value.into())
264 }
265}
266
#[cfg(test)]
mod test {
    use super::FromSql;
    use crate::{Connection, Error, Result};

    #[test]
    fn test_timestamp_raw() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch(
            "BEGIN;
            CREATE TABLE ts (sec TIMESTAMP_S, milli TIMESTAMP_MS, micro TIMESTAMP_US, nano TIMESTAMP_NS );
            INSERT INTO ts VALUES (NULL,NULL,NULL,NULL );
            INSERT INTO ts VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01.594','2008-01-01 00:00:01.88926','2008-01-01 00:00:01.889268000' );
            -- INSERT INTO ts VALUES (NULL,NULL,NULL,1199145601889268321 );
            END;",
        )?;
        // Raw integer reads: seconds, milliseconds, microseconds and
        // nanoseconds after the Unix epoch.
        let (sec, milli, micro, nano) = conn.query_row(
            "SELECT sec, milli, micro, nano FROM ts WHERE sec is not null",
            [],
            |row| <(i64, i64, i64, i64)>::try_from(row),
        )?;
        assert_eq!(sec, 1199145601);
        assert_eq!(milli, 1199145601594);
        assert_eq!(micro, 1199145601889260);
        assert_eq!(nano, 1199145601889268000);
        Ok(())
    }

    #[test]
    fn test_time64_raw() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch(
            "BEGIN;
            CREATE TABLE time64 (t time);
            INSERT INTO time64 VALUES ('20:08:10.998');
            END;",
        )?;
        // 20:08:10.998 expressed in microseconds after midnight.
        let (micros,) = conn.query_row("SELECT * FROM time64", [], |row| <(i64,)>::try_from(row))?;
        assert_eq!(micros, 72490998000);
        Ok(())
    }

    #[test]
    fn test_date32_raw() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch(
            "BEGIN;
            CREATE TABLE date32 (d date);
            INSERT INTO date32 VALUES ('2008-01-01');
            END;",
        )?;
        // 2008-01-01 expressed in days after 1970-01-01.
        let (days,) = conn.query_row("SELECT * FROM date32", [], |row| <(i32,)>::try_from(row))?;
        assert_eq!(days, 13879);
        Ok(())
    }

    #[test]
    fn test_unsigned_integer() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch(
            "BEGIN;
            CREATE TABLE unsigned_int (u1 utinyint, u2 usmallint, u4 uinteger, u8 ubigint);
            INSERT INTO unsigned_int VALUES (255, 65535, 4294967295, 18446744073709551615);
            END;",
        )?;
        let maxima = conn.query_row("SELECT * FROM unsigned_int", [], |row| {
            <(u8, u16, u32, u64)>::try_from(row)
        })?;
        let expected: (u8, u16, u32, u64) = (u8::MAX, u16::MAX, u32::MAX, u64::MAX);
        assert_eq!(maxima, expected);
        Ok(())
    }

    #[test]
    fn test_hugeint_max_min() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute("CREATE TABLE huge_int (u1 hugeint, u2 hugeint);", [])?;
        let largest: i128 = i128::MAX;
        // NOTE(review): the lower bound is one above i128::MIN, presumably
        // because i128::MIN itself does not round-trip. Confirm.
        let smallest: i128 = i128::MIN + 1;
        conn.execute("INSERT INTO huge_int VALUES (?, ?);", [&largest, &smallest])?;
        let stored = conn.query_row("SELECT * FROM huge_int", [], |row| <(i128, i128)>::try_from(row))?;
        assert_eq!(stored, (largest, smallest));
        Ok(())
    }

    #[test]
    fn test_integral_ranges() -> Result<()> {
        let conn = Connection::open_in_memory()?;

        /// Asserts that each value in `out_of_range` fails to convert to `T`
        /// with `IntegralValueOutOfRange` carrying that value, and that each
        /// value in `in_range` round-trips unchanged.
        fn check_ranges<T>(db: &Connection, out_of_range: &[i128], in_range: &[i128])
        where
            T: Into<i128> + FromSql + ::std::fmt::Debug,
        {
            for value in out_of_range {
                let err = db.query_row("SELECT ?", [value], |r| r.get::<_, T>(0)).unwrap_err();
                match &err {
                    Error::IntegralValueOutOfRange(_, reported) => assert_eq!(*value, *reported),
                    _ => panic!("unexpected error: {err}"),
                }
            }
            for value in in_range {
                let round_tripped: i128 = db.query_row("SELECT ?", [value], |r| r.get::<_, T>(0)).unwrap().into();
                assert_eq!(*value, round_tripped);
            }
        }

        check_ranges::<i8>(&conn, &[-129, 128], &[-128, 0, 1, 127]);
        check_ranges::<i16>(&conn, &[-32769, 32768], &[-32768, -1, 0, 1, 32767]);
        check_ranges::<i32>(
            &conn,
            &[-2_147_483_649, 2_147_483_648],
            &[-2_147_483_648, -1, 0, 1, 2_147_483_647],
        );
        check_ranges::<u8>(&conn, &[-2, -1, 256], &[0, 1, 255]);
        check_ranges::<u16>(&conn, &[-2, -1, 65536], &[0, 1, 65535]);
        check_ranges::<u32>(&conn, &[-2, -1, 4_294_967_296], &[0, 1, 4_294_967_295]);
        Ok(())
    }

    #[test]
    fn test_uuid_string() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch(
            "BEGIN;
            CREATE TABLE uuid (u uuid);
            INSERT INTO uuid VALUES ('10203040-5060-7080-0102-030405060708'),(NULL),('47183823-2574-4bfd-b411-99ed177d3e43');
            END;",
        )?;
        let expected = ("47183823-2574-4bfd-b411-99ed177d3e43".to_string(),);

        let largest = conn.query_row("SELECT u FROM uuid order by u desc nulls last limit 1", [], |row| {
            <(String,)>::try_from(row)
        })?;
        assert_eq!(largest, expected);

        let after_first = conn.query_row(
            "SELECT u FROM uuid where u>?::UUID",
            ["10203040-5060-7080-0102-030405060708"],
            |row| <(String,)>::try_from(row),
        )?;
        assert_eq!(after_first, expected);
        Ok(())
    }

    #[cfg(feature = "uuid")]
    #[test]
    fn test_uuid_from_string() -> crate::Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch(
            "BEGIN;
            CREATE TABLE uuid (u uuid);
            INSERT INTO uuid VALUES ('10203040-5060-7080-0102-030405060708'),(NULL),('47183823-2574-4bfd-b411-99ed177d3e43');
            END;",
        )?;
        let (largest,) = conn.query_row("SELECT u FROM uuid order by u desc nulls last limit 1", [], |row| {
            <(uuid::Uuid,)>::try_from(row)
        })?;
        assert_eq!(largest.to_string(), "47183823-2574-4bfd-b411-99ed177d3e43");
        Ok(())
    }

    #[test]
    fn test_decimal_to_integer() -> Result<()> {
        let conn = Connection::open_in_memory()?;

        // DECIMAL -> integer rounds to the nearest integer, ties away from zero.
        let i32_cases: &[(&str, i32)] = &[
            ("SELECT 0.1::DECIMAL(10,2)", 0),
            ("SELECT 0.4::DECIMAL(10,2)", 0),
            ("SELECT 0.5::DECIMAL(10,2)", 1),
            ("SELECT 0.6::DECIMAL(10,2)", 1),
            ("SELECT 0.9::DECIMAL(10,2)", 1),
            ("SELECT 1.5::DECIMAL(10,2)", 2),
            ("SELECT 2.5::DECIMAL(10,2)", 3),
            ("SELECT 3.5::DECIMAL(10,2)", 4),
            ("SELECT 4.5::DECIMAL(10,2)", 5),
            ("SELECT 5.5::DECIMAL(10,2)", 6),
            ("SELECT 10.5::DECIMAL(10,2)", 11),
            ("SELECT 99.5::DECIMAL(10,2)", 100),
            ("SELECT -0.5::DECIMAL(10,2)", -1),
            ("SELECT -1.5::DECIMAL(10,2)", -2),
            ("SELECT -2.5::DECIMAL(10,2)", -3),
            ("SELECT -3.5::DECIMAL(10,2)", -4),
            ("SELECT -4.5::DECIMAL(10,2)", -5),
            ("SELECT -0.1::DECIMAL(10,2)", 0),
            ("SELECT -0.4::DECIMAL(10,2)", 0),
            ("SELECT -0.6::DECIMAL(10,2)", -1),
            ("SELECT -0.9::DECIMAL(10,2)", -1),
            ("SELECT 999.4::DECIMAL(10,2)", 999),
            ("SELECT 999.5::DECIMAL(10,2)", 1000),
            ("SELECT 999.6::DECIMAL(10,2)", 1000),
            ("SELECT 0.49::DECIMAL(10,2)", 0),
            ("SELECT 0.50::DECIMAL(10,2)", 1),
            ("SELECT 0.51::DECIMAL(10,2)", 1),
            ("SELECT -0.49::DECIMAL(10,2)", 0),
            ("SELECT -0.50::DECIMAL(10,2)", -1),
            ("SELECT -0.51::DECIMAL(10,2)", -1),
        ];
        for &(sql, expected) in i32_cases {
            assert_eq!(conn.query_row(sql, [], |row| row.get::<_, i32>(0))?, expected, "{sql}");
        }

        let i64_cases: &[(&str, i64)] = &[
            ("SELECT 123456.49::DECIMAL(18,3)", 123456),
            ("SELECT 123456.50::DECIMAL(18,3)", 123457),
            ("SELECT 123456.51::DECIMAL(18,3)", 123457),
        ];
        for &(sql, expected) in i64_cases {
            assert_eq!(conn.query_row(sql, [], |row| row.get::<_, i64>(0))?, expected, "{sql}");
        }

        let i8_cases: &[(&str, i8)] = &[
            ("SELECT 126.4::DECIMAL(5,1)", 126),
            ("SELECT 126.6::DECIMAL(5,1)", 127),
        ];
        for &(sql, expected) in i8_cases {
            assert_eq!(conn.query_row(sql, [], |row| row.get::<_, i8>(0))?, expected, "{sql}");
        }

        // A decimal whose rounded value does not fit the target must be rejected.
        let err = conn
            .query_row("SELECT 999::DECIMAL(10,0)", [], |row| row.get::<_, i8>(0))
            .unwrap_err();
        assert!(
            matches!(err, Error::IntegralValueOutOfRange(_, _)),
            "Expected IntegralValueOutOfRange error, got: {err}"
        );

        Ok(())
    }
}