duckdb/types/
from_sql.rs

1use std::{error::Error, fmt};
2
3use cast;
4use rust_decimal::RoundingStrategy::MidpointAwayFromZero;
5
6use super::{TimeUnit, Value, ValueRef};
7
/// Errors that can be produced while converting a DuckDB value through the
/// [`FromSql`] trait.
#[non_exhaustive]
#[derive(Debug)]
pub enum FromSqlError {
    /// A DuckDB value was requested, but its type cannot be converted into
    /// the requested Rust type.
    InvalidType,

    /// The DuckDB value cannot be represented by the requested Rust type.
    /// Carries the offending value converted to `i128`.
    OutOfRange(i128),

    /// `feature = "uuid"`: a `uuid` could not be read. Carries the size of the
    /// blob (or text) it was read from. Only available when the `uuid`
    /// feature is enabled.
    #[cfg(feature = "uuid")]
    InvalidUuidSize(usize),

    /// A catch-all variant for implementors of the [`FromSql`] trait.
    Other(Box<dyn Error + Send + Sync + 'static>),
}
28
29impl PartialEq for FromSqlError {
30    fn eq(&self, other: &Self) -> bool {
31        match (self, other) {
32            (Self::InvalidType, Self::InvalidType) => true,
33            (Self::OutOfRange(n1), Self::OutOfRange(n2)) => n1 == n2,
34            #[cfg(feature = "uuid")]
35            (Self::InvalidUuidSize(s1), Self::InvalidUuidSize(s2)) => s1 == s2,
36            (..) => false,
37        }
38    }
39}
40
41impl fmt::Display for FromSqlError {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match *self {
44            Self::InvalidType => write!(f, "Invalid type"),
45            Self::OutOfRange(i) => write!(f, "Value {i} out of range"),
46            #[cfg(feature = "uuid")]
47            Self::InvalidUuidSize(s) => {
48                write!(f, "Cannot read UUID value out of {s} byte blob")
49            }
50            Self::Other(ref err) => err.fmt(f),
51        }
52    }
53}
54
55impl Error for FromSqlError {
56    fn source(&self) -> Option<&(dyn Error + 'static)> {
57        if let Self::Other(ref err) = self {
58            Some(&**err)
59        } else {
60            None
61        }
62    }
63}
64
65/// Result type for implementors of the [`FromSql`] trait.
66pub type FromSqlResult<T> = Result<T, FromSqlError>;
67
68/// A trait for types that can be created from a DuckDB value.
69pub trait FromSql: Sized {
70    /// Converts DuckDB value into Rust value.
71    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self>;
72}
73
/// Implements [`FromSql`] for the primitive numeric type `$t`.
///
/// Values are converted with `cast`. This covers integer, float and `DECIMAL`
/// values, plus the raw integer payloads of `TIMESTAMP`, `DATE` and
/// microsecond `TIME` values. A value that does not fit in `$t` yields
/// [`FromSqlError::OutOfRange`] carrying that value.
///
/// `TEXT` is parsed. An integer literal that overflows `$t` is `OutOfRange`;
/// any other unparsable text is `InvalidType`. Every other value kind is
/// `InvalidType`.
macro_rules! from_sql_integral(
    ($t:ident) => (
        impl FromSql for $t {
            #[inline]
            fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
                match value {
                    ValueRef::TinyInt(i) => <$t as cast::From<i8>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::SmallInt(i) => <$t as cast::From<i16>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Int(i) => <$t as cast::From<i32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::BigInt(i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::HugeInt(i) => <$t as cast::From<i128>>::cast(i).into_result(FromSqlError::OutOfRange(i)),

                    ValueRef::UTinyInt(i) => <$t as cast::From<u8>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::USmallInt(i) => <$t as cast::From<u16>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::UInt(i) => <$t as cast::From<u32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::UBigInt(i) => <$t as cast::From<u64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),

                    // `as i128` saturates (and maps NaN to 0), so the value reported for
                    // floats outside the `i128` range is only approximate.
                    ValueRef::Float(i) => <$t as cast::From<f32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Double(i) => <$t as cast::From<f64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),

                    ValueRef::Decimal(d) => {
                        // DuckDB rounds DECIMAL to INTEGER half away from zero
                        // (following PostgreSQL behavior).
                        // NOTE(review): this macro is also instantiated for f32/f64, so a
                        // DECIMAL read as a float is rounded to a whole number too —
                        // confirm that is intended.
                        let rounded = d.round_dp_with_strategy(0, MidpointAwayFromZero);
                        // With zero decimal places the mantissa is the integer value itself.
                        // Report that value on overflow, not `d.mantissa()`: the latter is
                        // scaled by 10^scale (e.g. 1000.00 has mantissa 100000).
                        let integral = rounded.mantissa();
                        <$t as cast::From<i128>>::cast(integral).into_result(FromSqlError::OutOfRange(integral))
                    }

                    // Temporal values expose their raw integer representation.
                    ValueRef::Timestamp(_, i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Date32(i) => <$t as cast::From<i32>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),
                    ValueRef::Time64(TimeUnit::Microsecond, i) => <$t as cast::From<i64>>::cast(i).into_result(FromSqlError::OutOfRange(i as i128)),

                    ValueRef::Text(_) => {
                        let s = value.as_str()?;
                        match s.parse::<$t>() {
                            Ok(parsed) => Ok(parsed),
                            // An integer literal that merely overflows `$t` is a range
                            // error; anything else that fails to parse is a type error.
                            Err(_) => match s.parse::<i128>() {
                                Ok(i) => Err(FromSqlError::OutOfRange(i)),
                                Err(_) => Err(FromSqlError::InvalidType),
                            },
                        }
                    }
                    _ => Err(FromSqlError::InvalidType),
                }
            }
        }
    )
);
117
/// Normalises the output of `cast::From::cast` into a `Result`.
///
/// Depending on the source/target pair, `cast` returns either the bare target
/// value (the conversion cannot fail) or a `Result` (it can). This trait lets
/// the conversion code treat both shapes the same way and substitute its own
/// error on failure.
trait IntoResult {
    /// The successfully converted value.
    type Value;

    /// Returns `Ok` with the converted value, or `Err(on_failure)` if the
    /// conversion failed.
    fn into_result<F>(self, on_failure: F) -> Result<Self::Value, F>;
}
124
/// Implements `IntoResult` as an infallible pass-through for a primitive
/// numeric type (integer or float).
///
/// This is the shape `cast` uses for conversions that cannot fail: the bare
/// value is wrapped in `Ok`, and the supplied error is never used.
macro_rules! into_result_integral {
    ($num:ident) => {
        impl IntoResult for $num {
            type Value = $num;

            #[inline]
            fn into_result<E>(self, _: E) -> Result<$num, E> {
                Ok(self)
            }
        }
    };
}
138
139into_result_integral!(i8);
140into_result_integral!(i16);
141into_result_integral!(i32);
142into_result_integral!(i64);
143into_result_integral!(i128);
144into_result_integral!(isize);
145into_result_integral!(u8);
146into_result_integral!(u16);
147into_result_integral!(u32);
148into_result_integral!(u64);
149into_result_integral!(usize);
150into_result_integral!(f32);
151into_result_integral!(f64);
152
153impl<T, E> IntoResult for Result<T, E> {
154    type Value = T;
155
156    #[inline]
157    fn into_result<E2>(self, err: E2) -> Result<Self::Value, E2> {
158        self.map_err(|_| err)
159    }
160}
161
162from_sql_integral!(i8);
163from_sql_integral!(i16);
164from_sql_integral!(i32);
165from_sql_integral!(i64);
166from_sql_integral!(i128);
167from_sql_integral!(isize);
168from_sql_integral!(u8);
169from_sql_integral!(u16);
170from_sql_integral!(u32);
171from_sql_integral!(u64);
172from_sql_integral!(usize);
173from_sql_integral!(f32);
174from_sql_integral!(f64);
175
176impl FromSql for bool {
177    #[inline]
178    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
179        match value {
180            ValueRef::Boolean(b) => Ok(b),
181            _ => i8::column_result(value).map(|i| i != 0),
182        }
183    }
184}
185
186impl FromSql for String {
187    #[inline]
188    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
189        match value {
190            #[cfg(feature = "chrono")]
191            ValueRef::Date32(_) => Ok(chrono::NaiveDate::column_result(value)?.format("%F").to_string()),
192            #[cfg(feature = "chrono")]
193            ValueRef::Time64(..) => Ok(chrono::NaiveTime::column_result(value)?.format("%T%.f").to_string()),
194            #[cfg(feature = "chrono")]
195            ValueRef::Timestamp(..) => Ok(chrono::NaiveDateTime::column_result(value)?
196                .format("%F %T%.f")
197                .to_string()),
198            _ => value.as_str().map(ToString::to_string),
199        }
200    }
201}
202
203impl FromSql for Box<str> {
204    #[inline]
205    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
206        value.as_str().map(Into::into)
207    }
208}
209
210impl FromSql for std::rc::Rc<str> {
211    #[inline]
212    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
213        value.as_str().map(Into::into)
214    }
215}
216
217impl FromSql for std::sync::Arc<str> {
218    #[inline]
219    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
220        value.as_str().map(Into::into)
221    }
222}
223
224impl FromSql for Vec<u8> {
225    #[inline]
226    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
227        value.as_blob().map(|b| b.to_vec())
228    }
229}
230
/// Reads a UUID from either its textual form or a 16-byte blob.
///
/// NOTE(review): a malformed UUID *string* is reported as `InvalidUuidSize`
/// with the text length, even when that length is the canonical 36. A
/// dedicated error might be clearer.
#[cfg(feature = "uuid")]
impl FromSql for uuid::Uuid {
    #[inline]
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        match value {
            ValueRef::Text(..) => {
                let text = value.as_str()?;
                Self::parse_str(text).map_err(|_| FromSqlError::InvalidUuidSize(text.len()))
            }
            ValueRef::Blob(..) => {
                let bytes = value.as_blob()?;
                let builder =
                    uuid::Builder::from_slice(bytes).map_err(|_| FromSqlError::InvalidUuidSize(bytes.len()))?;
                Ok(builder.into_uuid())
            }
            _ => Err(FromSqlError::InvalidType),
        }
    }
}
249
250impl<T: FromSql> FromSql for Option<T> {
251    #[inline]
252    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
253        match value {
254            ValueRef::Null => Ok(None),
255            _ => FromSql::column_result(value).map(Some),
256        }
257    }
258}
259
260impl FromSql for Value {
261    #[inline]
262    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
263        Ok(value.into())
264    }
265}
266
#[cfg(test)]
mod test {
    use super::FromSql;
    use crate::{Connection, Error, Result};

    #[test]
    fn test_timestamp_raw() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE ts (sec TIMESTAMP_S, milli TIMESTAMP_MS, micro TIMESTAMP_US, nano TIMESTAMP_NS );
                   INSERT INTO ts VALUES (NULL,NULL,NULL,NULL );
                   INSERT INTO ts VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01.594','2008-01-01 00:00:01.88926','2008-01-01 00:00:01.889268000' );
                   -- INSERT INTO ts VALUES (NULL,NULL,NULL,1199145601889268321 );
                   END;";
        db.execute_batch(sql)?;
        // Each raw value is the integer count in its column's own unit.
        let v = db.query_row(
            "SELECT sec, milli, micro, nano FROM ts WHERE sec is not null",
            [],
            |row| <(i64, i64, i64, i64)>::try_from(row),
        )?;
        assert_eq!(v, (1199145601, 1199145601594, 1199145601889260, 1199145601889268000));
        Ok(())
    }

    #[test]
    fn test_time64_raw() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE time64 (t time);
                   INSERT INTO time64 VALUES ('20:08:10.998');
                   END;";
        db.execute_batch(sql)?;
        // Microseconds since midnight.
        let v = db.query_row("SELECT * FROM time64", [], |row| <(i64,)>::try_from(row))?;
        assert_eq!(v, (72490998000,));
        Ok(())
    }

    #[test]
    fn test_date32_raw() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE date32 (d date);
                   INSERT INTO date32 VALUES ('2008-01-01');
                   END;";
        db.execute_batch(sql)?;
        // Days since 1970-01-01.
        let v = db.query_row("SELECT * FROM date32", [], |row| <(i32,)>::try_from(row))?;
        assert_eq!(v, (13879,));
        Ok(())
    }

    #[test]
    fn test_unsigned_integer() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE unsigned_int (u1 utinyint, u2 usmallint, u4 uinteger, u8 ubigint);
                   INSERT INTO unsigned_int VALUES (255, 65535, 4294967295, 18446744073709551615);
                   END;";
        db.execute_batch(sql)?;
        let v = db.query_row("SELECT * FROM unsigned_int", [], |row| {
            <(u8, u16, u32, u64)>::try_from(row)
        })?;
        assert_eq!(v, (255, 65535, 4294967295, 18446744073709551615));
        Ok(())
    }

    // Asserts that i128 values beyond the i64 max/min can be written and read back correctly.
    #[test]
    fn test_hugeint_max_min() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute("CREATE TABLE huge_int (u1 hugeint, u2 hugeint);", [])?;
        // HUGEINT's range is documented at https://duckdb.org/docs/sql/data_types/numeric;
        // its minimum is one above i128::MIN, hence the `+ 1`.
        let i128max: i128 = i128::MAX;
        let i128min: i128 = i128::MIN + 1;
        db.execute("INSERT INTO huge_int VALUES (?, ?);", [&i128max, &i128min])?;
        let v = db.query_row("SELECT * FROM huge_int", [], |row| <(i128, i128)>::try_from(row))?;
        assert_eq!(v, (i128max, i128min));
        Ok(())
    }

    #[test]
    fn test_integral_ranges() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn check_ranges<T>(db: &Connection, out_of_range: &[i128], in_range: &[i128])
        where
            T: Into<i128> + FromSql + ::std::fmt::Debug,
        {
            for n in out_of_range {
                let err = db.query_row("SELECT ?", [n], |r| r.get::<_, T>(0)).unwrap_err();
                match err {
                    Error::IntegralValueOutOfRange(_, value) => assert_eq!(*n, value),
                    _ => panic!("unexpected error: {err}"),
                }
            }
            for n in in_range {
                assert_eq!(*n, db.query_row("SELECT ?", [n], |r| r.get::<_, T>(0)).unwrap().into());
            }
        }

        check_ranges::<i8>(&db, &[-129, 128], &[-128, 0, 1, 127]);
        check_ranges::<i16>(&db, &[-32769, 32768], &[-32768, -1, 0, 1, 32767]);
        check_ranges::<i32>(
            &db,
            &[-2_147_483_649, 2_147_483_648],
            &[-2_147_483_648, -1, 0, 1, 2_147_483_647],
        );
        check_ranges::<u8>(&db, &[-2, -1, 256], &[0, 1, 255]);
        check_ranges::<u16>(&db, &[-2, -1, 65536], &[0, 1, 65535]);
        check_ranges::<u32>(&db, &[-2, -1, 4_294_967_296], &[0, 1, 4_294_967_295]);
        Ok(())
    }

    // Don't need uuid crate if we only care about the string value of uuid
    #[test]
    fn test_uuid_string() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE uuid (u uuid);
                   INSERT INTO uuid VALUES ('10203040-5060-7080-0102-030405060708'),(NULL),('47183823-2574-4bfd-b411-99ed177d3e43');
                   END;";
        db.execute_batch(sql)?;
        let v = db.query_row("SELECT u FROM uuid order by u desc nulls last limit 1", [], |row| {
            <(String,)>::try_from(row)
        })?;
        assert_eq!(v, ("47183823-2574-4bfd-b411-99ed177d3e43".to_string(),));
        let v = db.query_row(
            "SELECT u FROM uuid where u>?::UUID",
            ["10203040-5060-7080-0102-030405060708"],
            |row| <(String,)>::try_from(row),
        )?;
        assert_eq!(v, ("47183823-2574-4bfd-b411-99ed177d3e43".to_string(),));
        Ok(())
    }

    #[cfg(feature = "uuid")]
    #[test]
    fn test_uuid_from_string() -> crate::Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE uuid (u uuid);
                   INSERT INTO uuid VALUES ('10203040-5060-7080-0102-030405060708'),(NULL),('47183823-2574-4bfd-b411-99ed177d3e43');
                   END;";
        db.execute_batch(sql)?;
        let v = db.query_row("SELECT u FROM uuid order by u desc nulls last limit 1", [], |row| {
            <(uuid::Uuid,)>::try_from(row)
        })?;
        assert_eq!(v.0.to_string(), "47183823-2574-4bfd-b411-99ed177d3e43");
        Ok(())
    }

    #[test]
    fn test_decimal_to_integer() -> Result<()> {
        /// Evaluates `SELECT <expr>` for every `(expr, expected)` pair and checks
        /// the value read back as `T`.
        fn check<T>(db: &Connection, cases: &[(&str, T)]) -> Result<()>
        where
            T: FromSql + PartialEq + ::std::fmt::Debug,
        {
            for (expr, expected) in cases {
                let sql = format!("SELECT {expr}");
                let actual = db.query_row(&sql, [], |row| row.get::<_, T>(0))?;
                assert_eq!(&actual, expected, "{sql}");
            }
            Ok(())
        }

        let db = Connection::open_in_memory()?;

        // Rounding is half away from zero (PostgreSQL behavior), for both signs.
        check::<i32>(
            &db,
            &[
                ("0.1::DECIMAL(10,2)", 0),
                ("0.4::DECIMAL(10,2)", 0),
                ("0.5::DECIMAL(10,2)", 1),
                ("0.6::DECIMAL(10,2)", 1),
                ("0.9::DECIMAL(10,2)", 1),
                ("1.5::DECIMAL(10,2)", 2),
                ("2.5::DECIMAL(10,2)", 3),
                ("3.5::DECIMAL(10,2)", 4),
                ("4.5::DECIMAL(10,2)", 5),
                ("5.5::DECIMAL(10,2)", 6),
                ("10.5::DECIMAL(10,2)", 11),
                ("99.5::DECIMAL(10,2)", 100),
                ("-0.5::DECIMAL(10,2)", -1),
                ("-1.5::DECIMAL(10,2)", -2),
                ("-2.5::DECIMAL(10,2)", -3),
                ("-3.5::DECIMAL(10,2)", -4),
                ("-4.5::DECIMAL(10,2)", -5),
                ("-0.1::DECIMAL(10,2)", 0),
                ("-0.4::DECIMAL(10,2)", 0),
                ("-0.6::DECIMAL(10,2)", -1),
                ("-0.9::DECIMAL(10,2)", -1),
                ("999.4::DECIMAL(10,2)", 999),
                ("999.5::DECIMAL(10,2)", 1000),
                ("999.6::DECIMAL(10,2)", 1000),
            ],
        )?;

        // Wider decimals read into a 64-bit integer.
        check::<i64>(
            &db,
            &[
                ("123456.49::DECIMAL(18,3)", 123456),
                ("123456.50::DECIMAL(18,3)", 123457),
                ("123456.51::DECIMAL(18,3)", 123457),
            ],
        )?;

        // Values one hundredth either side of the .5 midpoint.
        check::<i32>(
            &db,
            &[
                ("0.49::DECIMAL(10,2)", 0),
                ("0.50::DECIMAL(10,2)", 1),
                ("0.51::DECIMAL(10,2)", 1),
                ("-0.49::DECIMAL(10,2)", 0),
                ("-0.50::DECIMAL(10,2)", -1),
                ("-0.51::DECIMAL(10,2)", -1),
            ],
        )?;

        // Rounding up to the top of the target type's range still fits.
        check::<i8>(&db, &[("126.4::DECIMAL(5,1)", 126), ("126.6::DECIMAL(5,1)", 127)])?;

        // A rounded value that does not fit is rejected, and the error carries
        // the value that overflowed.
        let err = db
            .query_row("SELECT 999::DECIMAL(10,0)", [], |row| row.get::<_, i8>(0))
            .unwrap_err();
        match err {
            Error::IntegralValueOutOfRange(_, value) => assert_eq!(value, 999),
            _ => panic!("Expected IntegralValueOutOfRange error, got: {err}"),
        }

        Ok(())
    }
}