1use std::error::Error as StdError;
2use std::fmt;
3use std::iter;
4use std::num;
5use std::str;
6
7use serde::de::value::BorrowedBytesDeserializer;
8use serde::de::{
9 Deserialize, DeserializeSeed, Deserializer, EnumAccess,
10 Error as SerdeError, IntoDeserializer, MapAccess, SeqAccess, Unexpected,
11 VariantAccess, Visitor,
12};
13use serde::serde_if_integer128;
14
15use crate::byte_record::{ByteRecord, ByteRecordIter};
16use crate::error::{Error, ErrorKind};
17use crate::string_record::{StringRecord, StringRecordIter};
18
19use self::DeserializeErrorKind as DEK;
20
21pub fn deserialize_string_record<'de, D: Deserialize<'de>>(
22 record: &'de StringRecord,
23 headers: Option<&'de StringRecord>,
24) -> Result<D, Error> {
25 let mut deser = DeRecordWrap(DeStringRecord {
26 it: record.iter().peekable(),
27 headers: headers.map(|r| r.iter()),
28 field: 0,
29 });
30 D::deserialize(&mut deser).map_err(|err| {
31 Error::new(ErrorKind::Deserialize {
32 pos: record.position().map(Clone::clone),
33 err,
34 })
35 })
36}
37
38pub fn deserialize_byte_record<'de, D: Deserialize<'de>>(
39 record: &'de ByteRecord,
40 headers: Option<&'de ByteRecord>,
41) -> Result<D, Error> {
42 let mut deser = DeRecordWrap(DeByteRecord {
43 it: record.iter().peekable(),
44 headers: headers.map(|r| r.iter()),
45 field: 0,
46 });
47 D::deserialize(&mut deser).map_err(|err| {
48 Error::new(ErrorKind::Deserialize {
49 pos: record.position().map(Clone::clone),
50 err,
51 })
52 })
53}
54
55trait DeRecord<'r> {
72 fn has_headers(&self) -> bool;
74
75 fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError>;
77
78 fn next_header_bytes(
80 &mut self,
81 ) -> Result<Option<&'r [u8]>, DeserializeError>;
82
83 fn next_field(&mut self) -> Result<&'r str, DeserializeError>;
85
86 fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError>;
88
89 fn peek_field(&mut self) -> Option<&'r [u8]>;
91
92 fn error(&self, kind: DeserializeErrorKind) -> DeserializeError;
94
95 fn infer_deserialize<'de, V: Visitor<'de>>(
97 &mut self,
98 visitor: V,
99 ) -> Result<V::Value, DeserializeError>;
100}
101
102struct DeRecordWrap<T>(T);
103
104impl<'r, T: DeRecord<'r>> DeRecord<'r> for DeRecordWrap<T> {
105 #[inline]
106 fn has_headers(&self) -> bool {
107 self.0.has_headers()
108 }
109
110 #[inline]
111 fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
112 self.0.next_header()
113 }
114
115 #[inline]
116 fn next_header_bytes(
117 &mut self,
118 ) -> Result<Option<&'r [u8]>, DeserializeError> {
119 self.0.next_header_bytes()
120 }
121
122 #[inline]
123 fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
124 self.0.next_field()
125 }
126
127 #[inline]
128 fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
129 self.0.next_field_bytes()
130 }
131
132 #[inline]
133 fn peek_field(&mut self) -> Option<&'r [u8]> {
134 self.0.peek_field()
135 }
136
137 #[inline]
138 fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
139 self.0.error(kind)
140 }
141
142 #[inline]
143 fn infer_deserialize<'de, V: Visitor<'de>>(
144 &mut self,
145 visitor: V,
146 ) -> Result<V::Value, DeserializeError> {
147 self.0.infer_deserialize(visitor)
148 }
149}
150
151struct DeStringRecord<'r> {
152 it: iter::Peekable<StringRecordIter<'r>>,
153 headers: Option<StringRecordIter<'r>>,
154 field: u64,
155}
156
157impl<'r> DeRecord<'r> for DeStringRecord<'r> {
158 #[inline]
159 fn has_headers(&self) -> bool {
160 self.headers.is_some()
161 }
162
163 #[inline]
164 fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
165 Ok(self.headers.as_mut().and_then(|it| it.next()))
166 }
167
168 #[inline]
169 fn next_header_bytes(
170 &mut self,
171 ) -> Result<Option<&'r [u8]>, DeserializeError> {
172 Ok(self.next_header()?.map(|s| s.as_bytes()))
173 }
174
175 #[inline]
176 fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
177 match self.it.next() {
178 Some(field) => {
179 self.field += 1;
180 Ok(field)
181 }
182 None => Err(DeserializeError {
183 field: None,
184 kind: DEK::UnexpectedEndOfRow,
185 }),
186 }
187 }
188
189 #[inline]
190 fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
191 self.next_field().map(|s| s.as_bytes())
192 }
193
194 #[inline]
195 fn peek_field(&mut self) -> Option<&'r [u8]> {
196 self.it.peek().map(|s| s.as_bytes())
197 }
198
199 fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
200 DeserializeError {
201 field: Some(self.field.saturating_sub(1)),
202 kind,
203 }
204 }
205
206 fn infer_deserialize<'de, V: Visitor<'de>>(
207 &mut self,
208 visitor: V,
209 ) -> Result<V::Value, DeserializeError> {
210 let x = self.next_field()?;
211 if x == "true" {
212 return visitor.visit_bool(true);
213 } else if x == "false" {
214 return visitor.visit_bool(false);
215 } else if let Some(n) = try_positive_integer64(x) {
216 return visitor.visit_u64(n);
217 } else if let Some(n) = try_negative_integer64(x) {
218 return visitor.visit_i64(n);
219 }
220 serde_if_integer128! {
221 if let Some(n) = try_positive_integer128(x) {
222 return visitor.visit_u128(n);
223 } else if let Some(n) = try_negative_integer128(x) {
224 return visitor.visit_i128(n);
225 }
226 }
227 if let Some(n) = try_float(x) {
228 visitor.visit_f64(n)
229 } else {
230 visitor.visit_str(x)
231 }
232 }
233}
234
235struct DeByteRecord<'r> {
236 it: iter::Peekable<ByteRecordIter<'r>>,
237 headers: Option<ByteRecordIter<'r>>,
238 field: u64,
239}
240
241impl<'r> DeRecord<'r> for DeByteRecord<'r> {
242 #[inline]
243 fn has_headers(&self) -> bool {
244 self.headers.is_some()
245 }
246
247 #[inline]
248 fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
249 match self.next_header_bytes() {
250 Ok(Some(field)) => Ok(Some(
251 str::from_utf8(field)
252 .map_err(|err| self.error(DEK::InvalidUtf8(err)))?,
253 )),
254 Ok(None) => Ok(None),
255 Err(err) => Err(err),
256 }
257 }
258
259 #[inline]
260 fn next_header_bytes(
261 &mut self,
262 ) -> Result<Option<&'r [u8]>, DeserializeError> {
263 Ok(self.headers.as_mut().and_then(|it| it.next()))
264 }
265
266 #[inline]
267 fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
268 self.next_field_bytes().and_then(|field| {
269 str::from_utf8(field)
270 .map_err(|err| self.error(DEK::InvalidUtf8(err)))
271 })
272 }
273
274 #[inline]
275 fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
276 match self.it.next() {
277 Some(field) => {
278 self.field += 1;
279 Ok(field)
280 }
281 None => Err(DeserializeError {
282 field: None,
283 kind: DEK::UnexpectedEndOfRow,
284 }),
285 }
286 }
287
288 #[inline]
289 fn peek_field(&mut self) -> Option<&'r [u8]> {
290 self.it.peek().copied()
291 }
292
293 fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
294 DeserializeError {
295 field: Some(self.field.saturating_sub(1)),
296 kind,
297 }
298 }
299
300 fn infer_deserialize<'de, V: Visitor<'de>>(
301 &mut self,
302 visitor: V,
303 ) -> Result<V::Value, DeserializeError> {
304 let x = self.next_field_bytes()?;
305 if x == b"true" {
306 return visitor.visit_bool(true);
307 } else if x == b"false" {
308 return visitor.visit_bool(false);
309 } else if let Some(n) = try_positive_integer64_bytes(x) {
310 return visitor.visit_u64(n);
311 } else if let Some(n) = try_negative_integer64_bytes(x) {
312 return visitor.visit_i64(n);
313 }
314 serde_if_integer128! {
315 if let Some(n) = try_positive_integer128_bytes(x) {
316 return visitor.visit_u128(n);
317 } else if let Some(n) = try_negative_integer128_bytes(x) {
318 return visitor.visit_i128(n);
319 }
320 }
321 if let Some(n) = try_float_bytes(x) {
322 visitor.visit_f64(n)
323 } else if let Ok(s) = str::from_utf8(x) {
324 visitor.visit_str(s)
325 } else {
326 visitor.visit_bytes(x)
327 }
328 }
329}
330
/// Generates one integer `Deserializer` method.
///
/// The next field is parsed as hexadecimal when it starts with `0x` and as
/// decimal otherwise. A parse failure becomes a `ParseInt` error tagged
/// with that field's index. Intended for use inside the `Deserializer`
/// impl, where `'de`, `Self::Error` and `DEK` are in scope.
macro_rules! deserialize_int {
    ($method:ident, $visit:ident, $inttype:ty) => {
        fn $method<V: Visitor<'de>>(
            self,
            visitor: V,
        ) -> Result<V::Value, Self::Error> {
            let text = self.next_field()?;
            let parsed = match text.strip_prefix("0x") {
                Some(digits) => <$inttype>::from_str_radix(digits, 16),
                None => text.parse(),
            };
            match parsed {
                Ok(value) => visitor.$visit(value),
                Err(err) => Err(self.error(DEK::ParseInt(err))),
            }
        }
    };
}
347
348impl<'a, 'de: 'a, T: DeRecord<'de>> Deserializer<'de>
349 for &'a mut DeRecordWrap<T>
350{
351 type Error = DeserializeError;
352
353 fn deserialize_any<V: Visitor<'de>>(
354 self,
355 visitor: V,
356 ) -> Result<V::Value, Self::Error> {
357 self.infer_deserialize(visitor)
358 .map_err(|err| self.error(err.kind))
359 }
360
361 fn deserialize_bool<V: Visitor<'de>>(
362 self,
363 visitor: V,
364 ) -> Result<V::Value, Self::Error> {
365 visitor.visit_bool(
366 self.next_field()?
367 .parse()
368 .map_err(|err| self.error(DEK::ParseBool(err)))?,
369 )
370 }
371
372 deserialize_int!(deserialize_u8, visit_u8, u8);
373 deserialize_int!(deserialize_u16, visit_u16, u16);
374 deserialize_int!(deserialize_u32, visit_u32, u32);
375 deserialize_int!(deserialize_u64, visit_u64, u64);
376 serde_if_integer128! {
377 deserialize_int!(deserialize_u128, visit_u128, u128);
378 }
379 deserialize_int!(deserialize_i8, visit_i8, i8);
380 deserialize_int!(deserialize_i16, visit_i16, i16);
381 deserialize_int!(deserialize_i32, visit_i32, i32);
382 deserialize_int!(deserialize_i64, visit_i64, i64);
383 serde_if_integer128! {
384 deserialize_int!(deserialize_i128, visit_i128, i128);
385 }
386
387 fn deserialize_f32<V: Visitor<'de>>(
388 self,
389 visitor: V,
390 ) -> Result<V::Value, Self::Error> {
391 visitor.visit_f32(
392 self.next_field()?
393 .parse()
394 .map_err(|err| self.error(DEK::ParseFloat(err)))?,
395 )
396 }
397
398 fn deserialize_f64<V: Visitor<'de>>(
399 self,
400 visitor: V,
401 ) -> Result<V::Value, Self::Error> {
402 visitor.visit_f64(
403 self.next_field()?
404 .parse()
405 .map_err(|err| self.error(DEK::ParseFloat(err)))?,
406 )
407 }
408
409 fn deserialize_char<V: Visitor<'de>>(
410 self,
411 visitor: V,
412 ) -> Result<V::Value, Self::Error> {
413 let field = self.next_field()?;
414 let len = field.chars().count();
415 if len != 1 {
416 return Err(self.error(DEK::Message(format!(
417 "expected single character but got {} characters in '{}'",
418 len, field
419 ))));
420 }
421 visitor.visit_char(field.chars().next().unwrap())
422 }
423
424 fn deserialize_str<V: Visitor<'de>>(
425 self,
426 visitor: V,
427 ) -> Result<V::Value, Self::Error> {
428 self.next_field().and_then(|f| visitor.visit_borrowed_str(f))
429 .map_err(|err| self.error(err.kind))
430 }
431
432 fn deserialize_string<V: Visitor<'de>>(
433 self,
434 visitor: V,
435 ) -> Result<V::Value, Self::Error> {
436 self.next_field().and_then(|f| visitor.visit_str(f))
437 .map_err(|err| self.error(err.kind))
438 }
439
440 fn deserialize_bytes<V: Visitor<'de>>(
441 self,
442 visitor: V,
443 ) -> Result<V::Value, Self::Error> {
444 self.next_field_bytes().and_then(|f| visitor.visit_borrowed_bytes(f))
445 .map_err(|err| self.error(err.kind))
446 }
447
448 fn deserialize_byte_buf<V: Visitor<'de>>(
449 self,
450 visitor: V,
451 ) -> Result<V::Value, Self::Error> {
452 self.next_field_bytes()
453 .and_then(|f| visitor.visit_byte_buf(f.to_vec()))
454 .map_err(|err| self.error(err.kind))
455 }
456
457 fn deserialize_option<V: Visitor<'de>>(
458 self,
459 visitor: V,
460 ) -> Result<V::Value, Self::Error> {
461 match self.peek_field() {
462 None => visitor.visit_none(),
463 Some([]) => {
464 self.next_field().expect("empty field");
465 visitor.visit_none()
466 }
467 Some(_) => visitor.visit_some(self),
468 }
469 }
470
471 fn deserialize_unit<V: Visitor<'de>>(
472 self,
473 visitor: V,
474 ) -> Result<V::Value, Self::Error> {
475 visitor.visit_unit()
476 }
477
478 fn deserialize_unit_struct<V: Visitor<'de>>(
479 self,
480 _name: &'static str,
481 visitor: V,
482 ) -> Result<V::Value, Self::Error> {
483 visitor.visit_unit()
484 }
485
486 fn deserialize_newtype_struct<V: Visitor<'de>>(
487 self,
488 _name: &'static str,
489 visitor: V,
490 ) -> Result<V::Value, Self::Error> {
491 visitor.visit_newtype_struct(self)
492 }
493
494 fn deserialize_seq<V: Visitor<'de>>(
495 self,
496 visitor: V,
497 ) -> Result<V::Value, Self::Error> {
498 visitor.visit_seq(self)
499 }
500
501 fn deserialize_tuple<V: Visitor<'de>>(
502 self,
503 _len: usize,
504 visitor: V,
505 ) -> Result<V::Value, Self::Error> {
506 visitor.visit_seq(self)
507 }
508
509 fn deserialize_tuple_struct<V: Visitor<'de>>(
510 self,
511 _name: &'static str,
512 _len: usize,
513 visitor: V,
514 ) -> Result<V::Value, Self::Error> {
515 visitor.visit_seq(self)
516 }
517
518 fn deserialize_map<V: Visitor<'de>>(
519 self,
520 visitor: V,
521 ) -> Result<V::Value, Self::Error> {
522 if !self.has_headers() {
523 visitor.visit_seq(self)
524 } else {
525 visitor.visit_map(self)
526 }
527 }
528
529 fn deserialize_struct<V: Visitor<'de>>(
530 self,
531 _name: &'static str,
532 _fields: &'static [&'static str],
533 visitor: V,
534 ) -> Result<V::Value, Self::Error> {
535 if !self.has_headers() {
536 visitor.visit_seq(self)
537 } else {
538 visitor.visit_map(self)
539 }
540 }
541
542 fn deserialize_identifier<V: Visitor<'de>>(
543 self,
544 _visitor: V,
545 ) -> Result<V::Value, Self::Error> {
546 Err(self.error(DEK::Unsupported("deserialize_identifier".into())))
547 }
548
549 fn deserialize_enum<V: Visitor<'de>>(
550 self,
551 _name: &'static str,
552 _variants: &'static [&'static str],
553 visitor: V,
554 ) -> Result<V::Value, Self::Error> {
555 visitor.visit_enum(self)
556 }
557
558 fn deserialize_ignored_any<V: Visitor<'de>>(
559 self,
560 visitor: V,
561 ) -> Result<V::Value, Self::Error> {
562 let _ = self.next_field_bytes()?;
566 visitor.visit_unit()
567 }
568}
569
570impl<'a, 'de: 'a, T: DeRecord<'de>> EnumAccess<'de>
571 for &'a mut DeRecordWrap<T>
572{
573 type Error = DeserializeError;
574 type Variant = Self;
575
576 fn variant_seed<V: DeserializeSeed<'de>>(
577 self,
578 seed: V,
579 ) -> Result<(V::Value, Self::Variant), Self::Error> {
580 let variant_name = self.next_field()?;
581 seed.deserialize(variant_name.into_deserializer()).map(|v| (v, self))
582 }
583}
584
585impl<'a, 'de: 'a, T: DeRecord<'de>> VariantAccess<'de>
586 for &'a mut DeRecordWrap<T>
587{
588 type Error = DeserializeError;
589
590 fn unit_variant(self) -> Result<(), Self::Error> {
591 Ok(())
592 }
593
594 fn newtype_variant_seed<U: DeserializeSeed<'de>>(
595 self,
596 _seed: U,
597 ) -> Result<U::Value, Self::Error> {
598 let unexp = Unexpected::UnitVariant;
599 Err(DeserializeError::invalid_type(unexp, &"newtype variant"))
600 }
601
602 fn tuple_variant<V: Visitor<'de>>(
603 self,
604 _len: usize,
605 _visitor: V,
606 ) -> Result<V::Value, Self::Error> {
607 let unexp = Unexpected::UnitVariant;
608 Err(DeserializeError::invalid_type(unexp, &"tuple variant"))
609 }
610
611 fn struct_variant<V: Visitor<'de>>(
612 self,
613 _fields: &'static [&'static str],
614 _visitor: V,
615 ) -> Result<V::Value, Self::Error> {
616 let unexp = Unexpected::UnitVariant;
617 Err(DeserializeError::invalid_type(unexp, &"struct variant"))
618 }
619}
620
621impl<'a, 'de: 'a, T: DeRecord<'de>> SeqAccess<'de>
622 for &'a mut DeRecordWrap<T>
623{
624 type Error = DeserializeError;
625
626 fn next_element_seed<U: DeserializeSeed<'de>>(
627 &mut self,
628 seed: U,
629 ) -> Result<Option<U::Value>, Self::Error> {
630 if self.peek_field().is_none() {
631 Ok(None)
632 } else {
633 seed.deserialize(&mut **self).map(Some)
634 }
635 }
636}
637
638impl<'a, 'de: 'a, T: DeRecord<'de>> MapAccess<'de>
639 for &'a mut DeRecordWrap<T>
640{
641 type Error = DeserializeError;
642
643 fn next_key_seed<K: DeserializeSeed<'de>>(
644 &mut self,
645 seed: K,
646 ) -> Result<Option<K::Value>, Self::Error> {
647 assert!(self.has_headers());
648 let field = match self.next_header_bytes()? {
649 None => return Ok(None),
650 Some(field) => field,
651 };
652 seed.deserialize(BorrowedBytesDeserializer::new(field)).map(Some)
653 }
654
655 fn next_value_seed<K: DeserializeSeed<'de>>(
656 &mut self,
657 seed: K,
658 ) -> Result<K::Value, Self::Error> {
659 seed.deserialize(&mut **self)
660 }
661}
662
/// An error produced while deserializing the fields of a single record.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeserializeError {
    /// Zero-based index of the field this error refers to, if known. It is
    /// `None`, for instance, for end-of-row errors and for errors built via
    /// `serde::de::Error::custom`.
    field: Option<u64>,
    /// What went wrong.
    kind: DeserializeErrorKind,
}
671
/// The specific kind of a `DeserializeError`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeserializeErrorKind {
    /// A free-form message, e.g. from `serde::de::Error::custom` or from a
    /// field rejected as a `char`.
    Message(String),
    /// A deserializer method that is not supported; holds the method name.
    Unsupported(String),
    /// A field was requested but the record had no fields left.
    UnexpectedEndOfRow,
    /// A field or header was not valid UTF-8 where a string was required.
    InvalidUtf8(str::Utf8Error),
    /// A field could not be parsed as a `bool`.
    ParseBool(str::ParseBoolError),
    /// A field could not be parsed as an integer.
    ParseInt(num::ParseIntError),
    /// A field could not be parsed as a float.
    ParseFloat(num::ParseFloatError),
}
693
694impl SerdeError for DeserializeError {
695 fn custom<T: fmt::Display>(msg: T) -> DeserializeError {
696 DeserializeError { field: None, kind: DEK::Message(msg.to_string()) }
697 }
698}
699
700impl StdError for DeserializeError {
701 fn description(&self) -> &str {
702 self.kind.description()
703 }
704}
705
706impl fmt::Display for DeserializeError {
707 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
708 if let Some(field) = self.field {
709 write!(f, "field {}: {}", field + 1, self.kind)
710 } else {
711 write!(f, "{}", self.kind)
712 }
713 }
714}
715
716impl fmt::Display for DeserializeErrorKind {
717 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
718 use self::DeserializeErrorKind::*;
719
720 match *self {
721 Message(ref msg) => write!(f, "{}", msg),
722 Unsupported(ref which) => {
723 write!(f, "unsupported deserializer method: {}", which)
724 }
725 UnexpectedEndOfRow => write!(f, "{}", self.description()),
726 InvalidUtf8(ref err) => err.fmt(f),
727 ParseBool(ref err) => err.fmt(f),
728 ParseInt(ref err) => err.fmt(f),
729 ParseFloat(ref err) => err.fmt(f),
730 }
731 }
732}
733
734impl DeserializeError {
735 pub fn field(&self) -> Option<u64> {
737 self.field
738 }
739
740 pub fn kind(&self) -> &DeserializeErrorKind {
742 &self.kind
743 }
744}
745
746impl DeserializeErrorKind {
747 #[allow(deprecated)]
748 fn description(&self) -> &str {
749 use self::DeserializeErrorKind::*;
750
751 match *self {
752 Message(_) => "deserialization error",
753 Unsupported(_) => "unsupported deserializer method",
754 UnexpectedEndOfRow => "expected field, but got end of row",
755 InvalidUtf8(ref err) => err.description(),
756 ParseBool(ref err) => err.description(),
757 ParseInt(ref err) => err.description(),
758 ParseFloat(ref err) => err.description(),
759 }
760 }
761}
762
763serde_if_integer128! {
764 fn try_positive_integer128(s: &str) -> Option<u128> {
765 s.parse().ok()
766 }
767
768 fn try_negative_integer128(s: &str) -> Option<i128> {
769 s.parse().ok()
770 }
771}
772
/// Parses `s` as a `u64`, or `None` if it is not one.
fn try_positive_integer64(s: &str) -> Option<u64> {
    match s.parse::<u64>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
776
/// Parses `s` as an `i64`, or `None` if it is not one.
fn try_negative_integer64(s: &str) -> Option<i64> {
    if let Ok(n) = s.parse::<i64>() {
        Some(n)
    } else {
        None
    }
}
780
/// Parses `s` as an `f64`, or `None` if it is not one.
fn try_float(s: &str) -> Option<f64> {
    match s.parse::<f64>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
784
/// Parses UTF-8 bytes as a `u64`; `None` for invalid UTF-8 or a non-`u64`.
fn try_positive_integer64_bytes(s: &[u8]) -> Option<u64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
788
/// Parses UTF-8 bytes as an `i64`; `None` for invalid UTF-8 or a non-`i64`.
fn try_negative_integer64_bytes(s: &[u8]) -> Option<i64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
792
793serde_if_integer128! {
794 fn try_positive_integer128_bytes(s: &[u8]) -> Option<u128> {
795 str::from_utf8(s).ok().and_then(|s| s.parse().ok())
796 }
797
798 fn try_negative_integer128_bytes(s: &[u8]) -> Option<i128> {
799 str::from_utf8(s).ok().and_then(|s| s.parse().ok())
800 }
801}
802
/// Parses UTF-8 bytes as an `f64`; `None` for invalid UTF-8 or a non-float.
fn try_float_bytes(s: &[u8]) -> Option<f64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
806
// Unit tests for record deserialization. `de` and `de_headers` build
// `StringRecord`s from string slices and deserialize them without and with
// a header row respectively; `b` views any byte-like value as `&[u8]`.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use bstr::BString;
    use serde::{de::DeserializeOwned, serde_if_integer128, Deserialize};

    use super::{deserialize_byte_record, deserialize_string_record};
    use crate::byte_record::ByteRecord;
    use crate::error::{Error, ErrorKind};
    use crate::string_record::StringRecord;

    /// Deserializes `fields` as a record without headers.
    fn de<D: DeserializeOwned>(fields: &[&str]) -> Result<D, Error> {
        let record = StringRecord::from(fields);
        deserialize_string_record(&record, None)
    }

    /// Deserializes `fields` using `headers` as the header row.
    fn de_headers<D: DeserializeOwned>(
        headers: &[&str],
        fields: &[&str],
    ) -> Result<D, Error> {
        let headers = StringRecord::from(headers);
        let record = StringRecord::from(fields);
        deserialize_string_record(&record, Some(&headers))
    }

    /// Views any byte-like value as a byte slice.
    fn b<'a, T: AsRef<[u8]> + ?Sized>(bytes: &'a T) -> &'a [u8] {
        bytes.as_ref()
    }

    #[test]
    fn with_header() {
        // Header names, not positions, decide which field fills which member.
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }

        let got: Foo =
            de_headers(&["x", "y", "z"], &["hi", "42", "1.3"]).unwrap();
        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
    }

    #[test]
    fn with_header_unknown() {
        #[derive(Deserialize, Debug, PartialEq)]
        #[serde(deny_unknown_fields)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }
        assert!(de_headers::<Foo>(
            &["a", "x", "y", "z"],
            &["foo", "hi", "42", "1.3"],
        )
        .is_err());
    }

    #[test]
    fn with_header_missing() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }
        let got = de_headers::<Foo>(&["y", "z"], &["42", "1.3"],);
        assert!(got.is_err());
        let got = got.unwrap_err();
        assert!(match got.kind() { ErrorKind::Deserialize {..} => true, _ => false });
        assert!(got.to_string().starts_with("CSV deserialize error:"));
    }

    #[test]
    fn with_header_missing_ok() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: Option<String>,
        }

        let got: Foo = de_headers(&["y", "z"], &["42", "1.3"]).unwrap();
        assert_eq!(got, Foo { x: None, y: 42, z: 1.3 });
    }

    #[test]
    fn with_header_no_fields() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: Option<String>,
        }

        let got = de_headers::<Foo>(&["y", "z"], &[]);
        assert!(got.is_err());
    }

    #[test]
    fn with_header_empty() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: Option<String>,
        }

        let got = de_headers::<Foo>(&[], &[]);
        assert!(got.is_err());
    }

    #[test]
    fn with_header_empty_ok() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo;

        #[derive(Deserialize, Debug, PartialEq)]
        struct Bar {}

        let got = de_headers::<Foo>(&[], &[]);
        assert_eq!(got.unwrap(), Foo);

        let got = de_headers::<Bar>(&[], &[]);
        assert_eq!(got.unwrap(), Bar {});

        let got = de_headers::<()>(&[], &[]);
        assert_eq!(got.unwrap(), ());
    }

    #[test]
    fn without_header() {
        // Without headers, struct members are filled positionally.
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }

        let got: Foo = de(&["1.3", "42", "hi"]).unwrap();
        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
    }

    #[test]
    fn no_fields() {
        assert!(de::<String>(&[]).is_err());
    }

    #[test]
    fn one_field() {
        let got: i32 = de(&["42"]).unwrap();
        assert_eq!(got, 42);
    }

    serde_if_integer128! {
        #[test]
        fn one_field_128() {
            let got: i128 = de(&["2010223372036854775808"]).unwrap();
            assert_eq!(got, 2010223372036854775808);
        }
    }

    #[test]
    fn two_fields() {
        let got: (i32, bool) = de(&["42", "true"]).unwrap();
        assert_eq!(got, (42, true));

        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo(i32, bool);

        let got: Foo = de(&["42", "true"]).unwrap();
        assert_eq!(got, Foo(42, true));
    }

    #[test]
    fn two_fields_too_many() {
        // Trailing fields beyond the tuple's arity are left unread.
        let got: (i32, bool) = de(&["42", "true", "z", "z"]).unwrap();
        assert_eq!(got, (42, true));
    }

    #[test]
    fn two_fields_too_few() {
        assert!(de::<(i32, bool)>(&["42"]).is_err());
    }

    #[test]
    fn one_char() {
        let got: char = de(&["a"]).unwrap();
        assert_eq!(got, 'a');
    }

    #[test]
    fn no_chars() {
        assert!(de::<char>(&[""]).is_err());
    }

    #[test]
    fn too_many_chars() {
        assert!(de::<char>(&["ab"]).is_err());
    }

    #[test]
    fn simple_seq() {
        let got: Vec<i32> = de(&["1", "5", "10"]).unwrap();
        assert_eq!(got, vec![1, 5, 10]);
    }

    #[test]
    fn simple_hex_seq() {
        // A `0x` prefix selects hexadecimal parsing for integer fields.
        let got: Vec<i32> = de(&["0x7F", "0xA9", "0x10"]).unwrap();
        assert_eq!(got, vec![0x7F, 0xA9, 0x10]);
    }

    #[test]
    fn mixed_hex_seq() {
        let got: Vec<i32> = de(&["0x7F", "0xA9", "10"]).unwrap();
        assert_eq!(got, vec![0x7F, 0xA9, 10]);
    }

    #[test]
    fn bad_hex_seq() {
        // Hex digits without the `0x` prefix are parsed as decimal and fail.
        assert!(de::<Vec<u8>>(&["7F", "0xA9", "10"]).is_err());
    }

    #[test]
    fn seq_in_struct() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            xs: Vec<i32>,
        }
        let got: Foo = de(&["1", "5", "10"]).unwrap();
        assert_eq!(got, Foo { xs: vec![1, 5, 10] });
    }

    #[test]
    fn seq_in_struct_tail() {
        // A trailing `Vec` member absorbs all remaining fields.
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            label: String,
            xs: Vec<i32>,
        }
        let got: Foo = de(&["foo", "1", "5", "10"]).unwrap();
        assert_eq!(got, Foo { label: "foo".into(), xs: vec![1, 5, 10] });
    }

    #[test]
    fn map_headers() {
        let got: HashMap<String, i32> =
            de_headers(&["a", "b", "c"], &["1", "5", "10"]).unwrap();
        assert_eq!(got.len(), 3);
        assert_eq!(got["a"], 1);
        assert_eq!(got["b"], 5);
        assert_eq!(got["c"], 10);
    }

    #[test]
    fn map_no_headers() {
        // Without headers a map is offered as a sequence, which HashMap
        // rejects.
        let got = de::<HashMap<String, i32>>(&["1", "5", "10"]);
        assert!(got.is_err());
    }

    #[test]
    fn bytes() {
        let got: Vec<u8> = de::<BString>(&["foobar"]).unwrap().into();
        assert_eq!(got, b"foobar".to_vec());
    }

    #[test]
    fn adjacent_fixed_arrays() {
        let got: ([u32; 2], [u32; 2]) = de(&["1", "5", "10", "15"]).unwrap();
        assert_eq!(got, ([1, 5], [10, 15]));
    }

    #[test]
    fn enum_label_simple_tagged() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Row {
            label: Label,
            x: f64,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        #[serde(rename_all = "snake_case")]
        enum Label {
            Foo,
            Bar,
            Baz,
        }

        let got: Row = de_headers(&["label", "x"], &["bar", "5"]).unwrap();
        assert_eq!(got, Row { label: Label::Bar, x: 5.0 });
    }

    #[test]
    fn enum_untagged() {
        // Untagged enums are resolved by inferring each field's type from
        // its text.
        #[derive(Deserialize, Debug, PartialEq)]
        struct Row {
            x: Boolish,
            y: Boolish,
            z: Boolish,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        #[serde(rename_all = "snake_case")]
        #[serde(untagged)]
        enum Boolish {
            Bool(bool),
            Number(i64),
            String(String),
        }

        let got: Row =
            de_headers(&["x", "y", "z"], &["true", "null", "1"]).unwrap();
        assert_eq!(
            got,
            Row {
                x: Boolish::Bool(true),
                y: Boolish::String("null".into()),
                z: Boolish::Number(1),
            }
        );
    }

    #[test]
    fn option_empty_field() {
        // An empty field deserializes to `None`.
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            a: Option<i32>,
            b: String,
            c: Option<i32>,
        }

        let got: Foo =
            de_headers(&["a", "b", "c"], &["", "foo", "5"]).unwrap();
        assert_eq!(got, Foo { a: None, b: "foo".into(), c: Some(5) });
    }

    #[test]
    fn borrowed() {
        // `&str` members borrow directly from the record.
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo<'a, 'c> {
            a: &'a str,
            b: i32,
            c: &'c str,
        }

        let headers = StringRecord::from(vec!["a", "b", "c"]);
        let record = StringRecord::from(vec!["foo", "5", "bar"]);
        let got: Foo =
            deserialize_string_record(&record, Some(&headers)).unwrap();
        assert_eq!(got, Foo { a: "foo", b: 5, c: "bar" });
    }

    #[test]
    fn borrowed_map() {
        use std::collections::HashMap;

        let headers = StringRecord::from(vec!["a", "b", "c"]);
        let record = StringRecord::from(vec!["aardvark", "bee", "cat"]);
        let got: HashMap<&str, &str> =
            deserialize_string_record(&record, Some(&headers)).unwrap();

        let expected: HashMap<&str, &str> =
            headers.iter().zip(&record).collect();
        assert_eq!(got, expected);
    }

    #[test]
    fn borrowed_map_bytes() {
        // Non-UTF-8 header bytes are fine when keys are deserialized as
        // `&[u8]`.
        use std::collections::HashMap;

        let headers = ByteRecord::from(vec![b"a", b"\xFF", b"c"]);
        let record = ByteRecord::from(vec!["aardvark", "bee", "cat"]);
        let got: HashMap<&[u8], &[u8]> =
            deserialize_byte_record(&record, Some(&headers)).unwrap();

        let expected: HashMap<&[u8], &[u8]> =
            headers.iter().zip(&record).collect();
        assert_eq!(got, expected);
    }

    #[test]
    fn flatten() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Input {
            x: f64,
            y: f64,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        struct Properties {
            prop1: f64,
            prop2: f64,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        struct Row {
            #[serde(flatten)]
            input: Input,
            #[serde(flatten)]
            properties: Properties,
        }

        let header = StringRecord::from(vec!["x", "y", "prop1", "prop2"]);
        let record = StringRecord::from(vec!["1", "2", "3", "4"]);
        let got: Row = record.deserialize(Some(&header)).unwrap();
        assert_eq!(
            got,
            Row {
                input: Input { x: 1.0, y: 2.0 },
                properties: Properties { prop1: 3.0, prop2: 4.0 },
            }
        );
    }

    #[test]
    fn partially_invalid_utf8() {
        // Invalid UTF-8 in one field only matters for members that need a
        // string; the `BString` member accepts it.
        #[derive(Debug, Deserialize, PartialEq)]
        struct Row {
            h1: String,
            h2: BString,
            h3: String,
        }

        let headers = ByteRecord::from(vec![b"h1", b"h2", b"h3"]);
        let record =
            ByteRecord::from(vec![b(b"baz"), b(b"foo\xFFbar"), b(b"quux")]);
        let got: Row =
            deserialize_byte_record(&record, Some(&headers)).unwrap();
        assert_eq!(
            got,
            Row {
                h1: "baz".to_string(),
                h2: BString::from(b"foo\xFFbar".to_vec()),
                h3: "quux".to_string(),
            }
        );
    }
}