csv/
deserializer.rs

1use std::{error::Error as StdError, fmt, iter, num, str};
2
3use serde::{
4    de::value::BorrowedBytesDeserializer,
5    de::{
6        Deserialize, DeserializeSeed, Deserializer, EnumAccess,
7        Error as SerdeError, IntoDeserializer, MapAccess, SeqAccess,
8        Unexpected, VariantAccess, Visitor,
9    },
10    serde_if_integer128,
11};
12
13use crate::{
14    byte_record::{ByteRecord, ByteRecordIter},
15    error::{Error, ErrorKind},
16    string_record::{StringRecord, StringRecordIter},
17};
18
19use self::DeserializeErrorKind as DEK;
20
21pub fn deserialize_string_record<'de, D: Deserialize<'de>>(
22    record: &'de StringRecord,
23    headers: Option<&'de StringRecord>,
24) -> Result<D, Error> {
25    let mut deser = DeRecordWrap(DeStringRecord {
26        it: record.iter().peekable(),
27        headers: headers.map(|r| r.iter()),
28        field: 0,
29    });
30    D::deserialize(&mut deser).map_err(|err| {
31        Error::new(ErrorKind::Deserialize {
32            pos: record.position().map(Clone::clone),
33            err,
34        })
35    })
36}
37
38pub fn deserialize_byte_record<'de, D: Deserialize<'de>>(
39    record: &'de ByteRecord,
40    headers: Option<&'de ByteRecord>,
41) -> Result<D, Error> {
42    let mut deser = DeRecordWrap(DeByteRecord {
43        it: record.iter().peekable(),
44        headers: headers.map(|r| r.iter()),
45        field: 0,
46    });
47    D::deserialize(&mut deser).map_err(|err| {
48        Error::new(ErrorKind::Deserialize {
49            pos: record.position().map(Clone::clone),
50            err,
51        })
52    })
53}
54
/// An over-engineered internal trait that permits writing a single Serde
/// deserializer that works on both ByteRecord and StringRecord.
///
/// We *could* implement a single deserializer on `ByteRecord` and simply
/// convert `StringRecord`s to `ByteRecord`s, but then the implementation
/// would be required to redo UTF-8 validation checks in certain places.
///
/// How does this work? We create a new `DeRecordWrap` type that wraps
/// either a `DeStringRecord` or a `DeByteRecord`, each of which implements
/// `DeRecord` for its own kind of record. `DeRecordWrap<T>` forwards
/// `DeRecord` to the value it wraps, and finally we impl
/// `serde::Deserializer` for `&mut DeRecordWrap<T>` where `T: DeRecord`.
/// That is, the `DeRecord` type corresponds to the differences between
/// deserializing from a `ByteRecord` and deserializing from a
/// `StringRecord`.
///
/// The lifetime `'r` refers to the lifetime of the underlying record.
trait DeRecord<'r> {
    /// Returns true if and only if this deserializer has access to headers.
    fn has_headers(&self) -> bool;

    /// Extracts the next string header value from the underlying record.
    ///
    /// Returns `Ok(None)` when there are no headers or when the headers
    /// have been exhausted.
    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError>;

    /// Extracts the next raw byte header value from the underlying record.
    ///
    /// Returns `Ok(None)` when there are no headers or when the headers
    /// have been exhausted.
    fn next_header_bytes(
        &mut self,
    ) -> Result<Option<&'r [u8]>, DeserializeError>;

    /// Extracts the next string field from the underlying record.
    ///
    /// The implementations in this module report an `UnexpectedEndOfRow`
    /// error when no fields remain.
    fn next_field(&mut self) -> Result<&'r str, DeserializeError>;

    /// Extracts the next raw byte field from the underlying record.
    ///
    /// The implementations in this module report an `UnexpectedEndOfRow`
    /// error when no fields remain.
    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError>;

    /// Peeks at the next field from the underlying record without
    /// consuming it. Returns `None` when no fields remain.
    fn peek_field(&mut self) -> Option<&'r [u8]>;

    /// Returns an error corresponding to the most recently extracted field.
    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError;

    /// Infer the type of the next field and deserialize it.
    ///
    /// The next field is consumed and the visitor method matching the
    /// inferred type is invoked with the parsed value.
    fn infer_deserialize<'de, V: Visitor<'de>>(
        &mut self,
        visitor: V,
    ) -> Result<V::Value, DeserializeError>;
}
101
/// A newtype around a `DeRecord` implementation.
///
/// The Serde access traits used by this module (`Deserializer`,
/// `EnumAccess`, `VariantAccess`, `SeqAccess` and `MapAccess`) are
/// implemented for `&mut DeRecordWrap<T>`.
struct DeRecordWrap<T>(T);
103
104impl<'r, T: DeRecord<'r>> DeRecord<'r> for DeRecordWrap<T> {
105    #[inline]
106    fn has_headers(&self) -> bool {
107        self.0.has_headers()
108    }
109
110    #[inline]
111    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
112        self.0.next_header()
113    }
114
115    #[inline]
116    fn next_header_bytes(
117        &mut self,
118    ) -> Result<Option<&'r [u8]>, DeserializeError> {
119        self.0.next_header_bytes()
120    }
121
122    #[inline]
123    fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
124        self.0.next_field()
125    }
126
127    #[inline]
128    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
129        self.0.next_field_bytes()
130    }
131
132    #[inline]
133    fn peek_field(&mut self) -> Option<&'r [u8]> {
134        self.0.peek_field()
135    }
136
137    #[inline]
138    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
139        self.0.error(kind)
140    }
141
142    #[inline]
143    fn infer_deserialize<'de, V: Visitor<'de>>(
144        &mut self,
145        visitor: V,
146    ) -> Result<V::Value, DeserializeError> {
147        self.0.infer_deserialize(visitor)
148    }
149}
150
/// The `DeRecord` state for deserializing from a `StringRecord`.
struct DeStringRecord<'r> {
    /// Peekable iterator over the record's fields.
    it: iter::Peekable<StringRecordIter<'r>>,
    /// Iterator over the header names, if headers were supplied.
    headers: Option<StringRecordIter<'r>>,
    /// Number of fields extracted so far; incremented by `next_field`.
    field: u64,
}
156
157impl<'r> DeRecord<'r> for DeStringRecord<'r> {
158    #[inline]
159    fn has_headers(&self) -> bool {
160        self.headers.is_some()
161    }
162
163    #[inline]
164    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
165        Ok(self.headers.as_mut().and_then(|it| it.next()))
166    }
167
168    #[inline]
169    fn next_header_bytes(
170        &mut self,
171    ) -> Result<Option<&'r [u8]>, DeserializeError> {
172        Ok(self.next_header()?.map(|s| s.as_bytes()))
173    }
174
175    #[inline]
176    fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
177        match self.it.next() {
178            Some(field) => {
179                self.field += 1;
180                Ok(field)
181            }
182            None => Err(DeserializeError {
183                field: None,
184                kind: DEK::UnexpectedEndOfRow,
185            }),
186        }
187    }
188
189    #[inline]
190    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
191        self.next_field().map(|s| s.as_bytes())
192    }
193
194    #[inline]
195    fn peek_field(&mut self) -> Option<&'r [u8]> {
196        self.it.peek().map(|s| s.as_bytes())
197    }
198
199    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
200        DeserializeError { field: Some(self.field.saturating_sub(1)), kind }
201    }
202
203    fn infer_deserialize<'de, V: Visitor<'de>>(
204        &mut self,
205        visitor: V,
206    ) -> Result<V::Value, DeserializeError> {
207        let x = self.next_field()?;
208        if x == "true" {
209            return visitor.visit_bool(true);
210        } else if x == "false" {
211            return visitor.visit_bool(false);
212        } else if let Some(n) = try_positive_integer64(x) {
213            return visitor.visit_u64(n);
214        } else if let Some(n) = try_negative_integer64(x) {
215            return visitor.visit_i64(n);
216        }
217        serde_if_integer128! {
218            if let Some(n) = try_positive_integer128(x) {
219                return visitor.visit_u128(n);
220            } else if let Some(n) = try_negative_integer128(x) {
221                return visitor.visit_i128(n);
222            }
223        }
224        if let Some(n) = try_float(x) {
225            visitor.visit_f64(n)
226        } else {
227            visitor.visit_str(x)
228        }
229    }
230}
231
/// The `DeRecord` state for deserializing from a `ByteRecord`.
struct DeByteRecord<'r> {
    /// Peekable iterator over the record's raw fields.
    it: iter::Peekable<ByteRecordIter<'r>>,
    /// Iterator over the raw header names, if headers were supplied.
    headers: Option<ByteRecordIter<'r>>,
    /// Number of fields extracted so far; incremented by
    /// `next_field_bytes`.
    field: u64,
}
237
238impl<'r> DeRecord<'r> for DeByteRecord<'r> {
239    #[inline]
240    fn has_headers(&self) -> bool {
241        self.headers.is_some()
242    }
243
244    #[inline]
245    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
246        match self.next_header_bytes() {
247            Ok(Some(field)) => Ok(Some(
248                str::from_utf8(field)
249                    .map_err(|err| self.error(DEK::InvalidUtf8(err)))?,
250            )),
251            Ok(None) => Ok(None),
252            Err(err) => Err(err),
253        }
254    }
255
256    #[inline]
257    fn next_header_bytes(
258        &mut self,
259    ) -> Result<Option<&'r [u8]>, DeserializeError> {
260        Ok(self.headers.as_mut().and_then(|it| it.next()))
261    }
262
263    #[inline]
264    fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
265        self.next_field_bytes().and_then(|field| {
266            str::from_utf8(field)
267                .map_err(|err| self.error(DEK::InvalidUtf8(err)))
268        })
269    }
270
271    #[inline]
272    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
273        match self.it.next() {
274            Some(field) => {
275                self.field += 1;
276                Ok(field)
277            }
278            None => Err(DeserializeError {
279                field: None,
280                kind: DEK::UnexpectedEndOfRow,
281            }),
282        }
283    }
284
285    #[inline]
286    fn peek_field(&mut self) -> Option<&'r [u8]> {
287        self.it.peek().map(|s| *s)
288    }
289
290    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
291        DeserializeError { field: Some(self.field.saturating_sub(1)), kind }
292    }
293
294    fn infer_deserialize<'de, V: Visitor<'de>>(
295        &mut self,
296        visitor: V,
297    ) -> Result<V::Value, DeserializeError> {
298        let x = self.next_field_bytes()?;
299        if x == b"true" {
300            return visitor.visit_bool(true);
301        } else if x == b"false" {
302            return visitor.visit_bool(false);
303        } else if let Some(n) = try_positive_integer64_bytes(x) {
304            return visitor.visit_u64(n);
305        } else if let Some(n) = try_negative_integer64_bytes(x) {
306            return visitor.visit_i64(n);
307        }
308        serde_if_integer128! {
309            if let Some(n) = try_positive_integer128_bytes(x) {
310                return visitor.visit_u128(n);
311            } else if let Some(n) = try_negative_integer128_bytes(x) {
312                return visitor.visit_i128(n);
313            }
314        }
315        if let Some(n) = try_float_bytes(x) {
316            visitor.visit_f64(n)
317        } else if let Ok(s) = str::from_utf8(x) {
318            visitor.visit_str(s)
319        } else {
320            visitor.visit_bytes(x)
321        }
322    }
323}
324
/// Generates a `Deserializer` integer method named `$method` that consumes
/// the next field, parses it as `$inttype` and passes the value to the
/// visitor's `$visit` method.
///
/// A field starting with `0x` is parsed as hexadecimal; anything else is
/// parsed as decimal. Parse failures are reported as `ParseInt` errors.
macro_rules! deserialize_int {
    ($method:ident, $visit:ident, $inttype:ty) => {
        fn $method<V: Visitor<'de>>(
            self,
            visitor: V,
        ) -> Result<V::Value, Self::Error> {
            let field = self.next_field()?;
            let parsed = match field.strip_prefix("0x") {
                Some(hex) => <$inttype>::from_str_radix(hex, 16),
                None => field.parse::<$inttype>(),
            };
            match parsed {
                Ok(n) => visitor.$visit(n),
                Err(err) => Err(self.error(DEK::ParseInt(err))),
            }
        }
    };
}
341
342impl<'a, 'de: 'a, T: DeRecord<'de>> Deserializer<'de>
343    for &'a mut DeRecordWrap<T>
344{
345    type Error = DeserializeError;
346
347    fn deserialize_any<V: Visitor<'de>>(
348        self,
349        visitor: V,
350    ) -> Result<V::Value, Self::Error> {
351        self.infer_deserialize(visitor)
352    }
353
354    fn deserialize_bool<V: Visitor<'de>>(
355        self,
356        visitor: V,
357    ) -> Result<V::Value, Self::Error> {
358        visitor.visit_bool(
359            self.next_field()?
360                .parse()
361                .map_err(|err| self.error(DEK::ParseBool(err)))?,
362        )
363    }
364
365    deserialize_int!(deserialize_u8, visit_u8, u8);
366    deserialize_int!(deserialize_u16, visit_u16, u16);
367    deserialize_int!(deserialize_u32, visit_u32, u32);
368    deserialize_int!(deserialize_u64, visit_u64, u64);
369    serde_if_integer128! {
370        deserialize_int!(deserialize_u128, visit_u128, u128);
371    }
372    deserialize_int!(deserialize_i8, visit_i8, i8);
373    deserialize_int!(deserialize_i16, visit_i16, i16);
374    deserialize_int!(deserialize_i32, visit_i32, i32);
375    deserialize_int!(deserialize_i64, visit_i64, i64);
376    serde_if_integer128! {
377        deserialize_int!(deserialize_i128, visit_i128, i128);
378    }
379
380    fn deserialize_f32<V: Visitor<'de>>(
381        self,
382        visitor: V,
383    ) -> Result<V::Value, Self::Error> {
384        visitor.visit_f32(
385            self.next_field()?
386                .parse()
387                .map_err(|err| self.error(DEK::ParseFloat(err)))?,
388        )
389    }
390
391    fn deserialize_f64<V: Visitor<'de>>(
392        self,
393        visitor: V,
394    ) -> Result<V::Value, Self::Error> {
395        visitor.visit_f64(
396            self.next_field()?
397                .parse()
398                .map_err(|err| self.error(DEK::ParseFloat(err)))?,
399        )
400    }
401
402    fn deserialize_char<V: Visitor<'de>>(
403        self,
404        visitor: V,
405    ) -> Result<V::Value, Self::Error> {
406        let field = self.next_field()?;
407        let len = field.chars().count();
408        if len != 1 {
409            return Err(self.error(DEK::Message(format!(
410                "expected single character but got {} characters in '{}'",
411                len, field
412            ))));
413        }
414        visitor.visit_char(field.chars().next().unwrap())
415    }
416
417    fn deserialize_str<V: Visitor<'de>>(
418        self,
419        visitor: V,
420    ) -> Result<V::Value, Self::Error> {
421        self.next_field().and_then(|f| visitor.visit_borrowed_str(f))
422    }
423
424    fn deserialize_string<V: Visitor<'de>>(
425        self,
426        visitor: V,
427    ) -> Result<V::Value, Self::Error> {
428        self.next_field().and_then(|f| visitor.visit_str(f.into()))
429    }
430
431    fn deserialize_bytes<V: Visitor<'de>>(
432        self,
433        visitor: V,
434    ) -> Result<V::Value, Self::Error> {
435        self.next_field_bytes().and_then(|f| visitor.visit_borrowed_bytes(f))
436    }
437
438    fn deserialize_byte_buf<V: Visitor<'de>>(
439        self,
440        visitor: V,
441    ) -> Result<V::Value, Self::Error> {
442        self.next_field_bytes()
443            .and_then(|f| visitor.visit_byte_buf(f.to_vec()))
444    }
445
446    fn deserialize_option<V: Visitor<'de>>(
447        self,
448        visitor: V,
449    ) -> Result<V::Value, Self::Error> {
450        match self.peek_field() {
451            None => visitor.visit_none(),
452            Some(f) if f.is_empty() => {
453                self.next_field().expect("empty field");
454                visitor.visit_none()
455            }
456            Some(_) => visitor.visit_some(self),
457        }
458    }
459
460    fn deserialize_unit<V: Visitor<'de>>(
461        self,
462        visitor: V,
463    ) -> Result<V::Value, Self::Error> {
464        visitor.visit_unit()
465    }
466
467    fn deserialize_unit_struct<V: Visitor<'de>>(
468        self,
469        _name: &'static str,
470        visitor: V,
471    ) -> Result<V::Value, Self::Error> {
472        visitor.visit_unit()
473    }
474
475    fn deserialize_newtype_struct<V: Visitor<'de>>(
476        self,
477        _name: &'static str,
478        visitor: V,
479    ) -> Result<V::Value, Self::Error> {
480        visitor.visit_newtype_struct(self)
481    }
482
483    fn deserialize_seq<V: Visitor<'de>>(
484        self,
485        visitor: V,
486    ) -> Result<V::Value, Self::Error> {
487        visitor.visit_seq(self)
488    }
489
490    fn deserialize_tuple<V: Visitor<'de>>(
491        self,
492        _len: usize,
493        visitor: V,
494    ) -> Result<V::Value, Self::Error> {
495        visitor.visit_seq(self)
496    }
497
498    fn deserialize_tuple_struct<V: Visitor<'de>>(
499        self,
500        _name: &'static str,
501        _len: usize,
502        visitor: V,
503    ) -> Result<V::Value, Self::Error> {
504        visitor.visit_seq(self)
505    }
506
507    fn deserialize_map<V: Visitor<'de>>(
508        self,
509        visitor: V,
510    ) -> Result<V::Value, Self::Error> {
511        if !self.has_headers() {
512            visitor.visit_seq(self)
513        } else {
514            visitor.visit_map(self)
515        }
516    }
517
518    fn deserialize_struct<V: Visitor<'de>>(
519        self,
520        _name: &'static str,
521        _fields: &'static [&'static str],
522        visitor: V,
523    ) -> Result<V::Value, Self::Error> {
524        if !self.has_headers() {
525            visitor.visit_seq(self)
526        } else {
527            visitor.visit_map(self)
528        }
529    }
530
531    fn deserialize_identifier<V: Visitor<'de>>(
532        self,
533        _visitor: V,
534    ) -> Result<V::Value, Self::Error> {
535        Err(self.error(DEK::Unsupported("deserialize_identifier".into())))
536    }
537
538    fn deserialize_enum<V: Visitor<'de>>(
539        self,
540        _name: &'static str,
541        _variants: &'static [&'static str],
542        visitor: V,
543    ) -> Result<V::Value, Self::Error> {
544        visitor.visit_enum(self)
545    }
546
547    fn deserialize_ignored_any<V: Visitor<'de>>(
548        self,
549        visitor: V,
550    ) -> Result<V::Value, Self::Error> {
551        // Read and drop the next field.
552        // This code is reached, e.g., when trying to deserialize a header
553        // that doesn't exist in the destination struct.
554        let _ = self.next_field_bytes()?;
555        visitor.visit_unit()
556    }
557}
558
559impl<'a, 'de: 'a, T: DeRecord<'de>> EnumAccess<'de>
560    for &'a mut DeRecordWrap<T>
561{
562    type Error = DeserializeError;
563    type Variant = Self;
564
565    fn variant_seed<V: DeserializeSeed<'de>>(
566        self,
567        seed: V,
568    ) -> Result<(V::Value, Self::Variant), Self::Error> {
569        let variant_name = self.next_field()?;
570        seed.deserialize(variant_name.into_deserializer()).map(|v| (v, self))
571    }
572}
573
574impl<'a, 'de: 'a, T: DeRecord<'de>> VariantAccess<'de>
575    for &'a mut DeRecordWrap<T>
576{
577    type Error = DeserializeError;
578
579    fn unit_variant(self) -> Result<(), Self::Error> {
580        Ok(())
581    }
582
583    fn newtype_variant_seed<U: DeserializeSeed<'de>>(
584        self,
585        _seed: U,
586    ) -> Result<U::Value, Self::Error> {
587        let unexp = Unexpected::UnitVariant;
588        Err(DeserializeError::invalid_type(unexp, &"newtype variant"))
589    }
590
591    fn tuple_variant<V: Visitor<'de>>(
592        self,
593        _len: usize,
594        _visitor: V,
595    ) -> Result<V::Value, Self::Error> {
596        let unexp = Unexpected::UnitVariant;
597        Err(DeserializeError::invalid_type(unexp, &"tuple variant"))
598    }
599
600    fn struct_variant<V: Visitor<'de>>(
601        self,
602        _fields: &'static [&'static str],
603        _visitor: V,
604    ) -> Result<V::Value, Self::Error> {
605        let unexp = Unexpected::UnitVariant;
606        Err(DeserializeError::invalid_type(unexp, &"struct variant"))
607    }
608}
609
610impl<'a, 'de: 'a, T: DeRecord<'de>> SeqAccess<'de>
611    for &'a mut DeRecordWrap<T>
612{
613    type Error = DeserializeError;
614
615    fn next_element_seed<U: DeserializeSeed<'de>>(
616        &mut self,
617        seed: U,
618    ) -> Result<Option<U::Value>, Self::Error> {
619        if self.peek_field().is_none() {
620            Ok(None)
621        } else {
622            seed.deserialize(&mut **self).map(Some)
623        }
624    }
625}
626
627impl<'a, 'de: 'a, T: DeRecord<'de>> MapAccess<'de>
628    for &'a mut DeRecordWrap<T>
629{
630    type Error = DeserializeError;
631
632    fn next_key_seed<K: DeserializeSeed<'de>>(
633        &mut self,
634        seed: K,
635    ) -> Result<Option<K::Value>, Self::Error> {
636        assert!(self.has_headers());
637        let field = match self.next_header_bytes()? {
638            None => return Ok(None),
639            Some(field) => field,
640        };
641        seed.deserialize(BorrowedBytesDeserializer::new(field)).map(Some)
642    }
643
644    fn next_value_seed<K: DeserializeSeed<'de>>(
645        &mut self,
646        seed: K,
647    ) -> Result<K::Value, Self::Error> {
648        seed.deserialize(&mut **self)
649    }
650}
651
/// A Serde deserialization error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeserializeError {
    /// Index (starting at 0) of the field the error refers to, if known.
    field: Option<u64>,
    /// What went wrong.
    kind: DeserializeErrorKind,
}
658
/// The type of a Serde deserialization error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeserializeErrorKind {
    /// A generic Serde deserialization error, carrying its message.
    Message(String),
    /// A generic Serde unsupported error, carrying the name of the
    /// unsupported deserializer method.
    Unsupported(String),
    /// This error occurs when a Rust type expects to decode another field
    /// from a row, but no more fields exist.
    UnexpectedEndOfRow,
    /// This error occurs when UTF-8 validation on a field fails. UTF-8
    /// validation is only performed when the Rust type requires it (e.g.,
    /// a `String` or `&str` type).
    InvalidUtf8(str::Utf8Error),
    /// This error occurs when a boolean value fails to parse.
    ParseBool(str::ParseBoolError),
    /// This error occurs when an integer value fails to parse.
    ParseInt(num::ParseIntError),
    /// This error occurs when a float value fails to parse.
    ParseFloat(num::ParseFloatError),
}
680
681impl SerdeError for DeserializeError {
682    fn custom<T: fmt::Display>(msg: T) -> DeserializeError {
683        DeserializeError { field: None, kind: DEK::Message(msg.to_string()) }
684    }
685}
686
impl StdError for DeserializeError {
    /// Returns the static description of the underlying error kind.
    fn description(&self) -> &str {
        self.kind.description()
    }
}
692
693impl fmt::Display for DeserializeError {
694    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
695        if let Some(field) = self.field {
696            write!(f, "field {}: {}", field, self.kind)
697        } else {
698            write!(f, "{}", self.kind)
699        }
700    }
701}
702
703impl fmt::Display for DeserializeErrorKind {
704    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
705        use self::DeserializeErrorKind::*;
706
707        match *self {
708            Message(ref msg) => write!(f, "{}", msg),
709            Unsupported(ref which) => {
710                write!(f, "unsupported deserializer method: {}", which)
711            }
712            UnexpectedEndOfRow => write!(f, "{}", self.description()),
713            InvalidUtf8(ref err) => err.fmt(f),
714            ParseBool(ref err) => err.fmt(f),
715            ParseInt(ref err) => err.fmt(f),
716            ParseFloat(ref err) => err.fmt(f),
717        }
718    }
719}
720
impl DeserializeError {
    /// Return the field index (starting at 0) of this error, if available.
    ///
    /// This is `None` when the error is not associated with a particular
    /// field.
    pub fn field(&self) -> Option<u64> {
        self.field
    }

    /// Return the underlying error kind.
    pub fn kind(&self) -> &DeserializeErrorKind {
        &self.kind
    }
}
732
733impl DeserializeErrorKind {
734    #[allow(deprecated)]
735    fn description(&self) -> &str {
736        use self::DeserializeErrorKind::*;
737
738        match *self {
739            Message(_) => "deserialization error",
740            Unsupported(_) => "unsupported deserializer method",
741            UnexpectedEndOfRow => "expected field, but got end of row",
742            InvalidUtf8(ref err) => err.description(),
743            ParseBool(ref err) => err.description(),
744            ParseInt(ref err) => err.description(),
745            ParseFloat(ref err) => err.description(),
746        }
747    }
748}
749
750serde_if_integer128! {
751    fn try_positive_integer128(s: &str) -> Option<u128> {
752        s.parse().ok()
753    }
754
755    fn try_negative_integer128(s: &str) -> Option<i128> {
756        s.parse().ok()
757    }
758}
759
/// Attempts to parse `s` as a `u64`, returning `None` on failure.
fn try_positive_integer64(s: &str) -> Option<u64> {
    match s.parse::<u64>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
763
/// Attempts to parse `s` as an `i64`, returning `None` on failure.
fn try_negative_integer64(s: &str) -> Option<i64> {
    match s.parse::<i64>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
767
/// Attempts to parse `s` as an `f64`, returning `None` on failure.
fn try_float(s: &str) -> Option<f64> {
    match s.parse::<f64>() {
        Ok(n) => Some(n),
        Err(_) => None,
    }
}
771
/// Attempts to parse `s` as a `u64`, returning `None` when `s` is not valid
/// UTF-8 or is not a valid `u64`.
fn try_positive_integer64_bytes(s: &[u8]) -> Option<u64> {
    match str::from_utf8(s) {
        Ok(text) => text.parse::<u64>().ok(),
        Err(_) => None,
    }
}
775
/// Attempts to parse `s` as an `i64`, returning `None` when `s` is not
/// valid UTF-8 or is not a valid `i64`.
fn try_negative_integer64_bytes(s: &[u8]) -> Option<i64> {
    match str::from_utf8(s) {
        Ok(text) => text.parse::<i64>().ok(),
        Err(_) => None,
    }
}
779
780serde_if_integer128! {
781    fn try_positive_integer128_bytes(s: &[u8]) -> Option<u128> {
782        str::from_utf8(s).ok().and_then(|s| s.parse().ok())
783    }
784
785    fn try_negative_integer128_bytes(s: &[u8]) -> Option<i128> {
786        str::from_utf8(s).ok().and_then(|s| s.parse().ok())
787    }
788}
789
/// Attempts to parse `s` as an `f64`, returning `None` when `s` is not
/// valid UTF-8 or is not a valid float.
fn try_float_bytes(s: &[u8]) -> Option<f64> {
    match str::from_utf8(s) {
        Ok(text) => text.parse::<f64>().ok(),
        Err(_) => None,
    }
}
793
794#[cfg(test)]
795mod tests {
796    use std::collections::HashMap;
797
798    use {
799        bstr::BString,
800        serde::{de::DeserializeOwned, serde_if_integer128, Deserialize},
801    };
802
803    use crate::{
804        byte_record::ByteRecord, error::Error, string_record::StringRecord,
805    };
806
807    use super::{deserialize_byte_record, deserialize_string_record};
808
809    fn de<D: DeserializeOwned>(fields: &[&str]) -> Result<D, Error> {
810        let record = StringRecord::from(fields);
811        deserialize_string_record(&record, None)
812    }
813
814    fn de_headers<D: DeserializeOwned>(
815        headers: &[&str],
816        fields: &[&str],
817    ) -> Result<D, Error> {
818        let headers = StringRecord::from(headers);
819        let record = StringRecord::from(fields);
820        deserialize_string_record(&record, Some(&headers))
821    }
822
823    fn b<'a, T: AsRef<[u8]> + ?Sized>(bytes: &'a T) -> &'a [u8] {
824        bytes.as_ref()
825    }
826
827    #[test]
828    fn with_header() {
829        #[derive(Deserialize, Debug, PartialEq)]
830        struct Foo {
831            z: f64,
832            y: i32,
833            x: String,
834        }
835
836        let got: Foo =
837            de_headers(&["x", "y", "z"], &["hi", "42", "1.3"]).unwrap();
838        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
839    }
840
841    #[test]
842    fn with_header_unknown() {
843        #[derive(Deserialize, Debug, PartialEq)]
844        #[serde(deny_unknown_fields)]
845        struct Foo {
846            z: f64,
847            y: i32,
848            x: String,
849        }
850        assert!(de_headers::<Foo>(
851            &["a", "x", "y", "z"],
852            &["foo", "hi", "42", "1.3"],
853        )
854        .is_err());
855    }
856
857    #[test]
858    fn with_header_missing() {
859        #[derive(Deserialize, Debug, PartialEq)]
860        struct Foo {
861            z: f64,
862            y: i32,
863            x: String,
864        }
865        assert!(de_headers::<Foo>(&["y", "z"], &["42", "1.3"],).is_err());
866    }
867
868    #[test]
869    fn with_header_missing_ok() {
870        #[derive(Deserialize, Debug, PartialEq)]
871        struct Foo {
872            z: f64,
873            y: i32,
874            x: Option<String>,
875        }
876
877        let got: Foo = de_headers(&["y", "z"], &["42", "1.3"]).unwrap();
878        assert_eq!(got, Foo { x: None, y: 42, z: 1.3 });
879    }
880
881    #[test]
882    fn with_header_no_fields() {
883        #[derive(Deserialize, Debug, PartialEq)]
884        struct Foo {
885            z: f64,
886            y: i32,
887            x: Option<String>,
888        }
889
890        let got = de_headers::<Foo>(&["y", "z"], &[]);
891        assert!(got.is_err());
892    }
893
894    #[test]
895    fn with_header_empty() {
896        #[derive(Deserialize, Debug, PartialEq)]
897        struct Foo {
898            z: f64,
899            y: i32,
900            x: Option<String>,
901        }
902
903        let got = de_headers::<Foo>(&[], &[]);
904        assert!(got.is_err());
905    }
906
907    #[test]
908    fn with_header_empty_ok() {
909        #[derive(Deserialize, Debug, PartialEq)]
910        struct Foo;
911
912        #[derive(Deserialize, Debug, PartialEq)]
913        struct Bar {}
914
915        let got = de_headers::<Foo>(&[], &[]);
916        assert_eq!(got.unwrap(), Foo);
917
918        let got = de_headers::<Bar>(&[], &[]);
919        assert_eq!(got.unwrap(), Bar {});
920
921        let got = de_headers::<()>(&[], &[]);
922        assert_eq!(got.unwrap(), ());
923    }
924
925    #[test]
926    fn without_header() {
927        #[derive(Deserialize, Debug, PartialEq)]
928        struct Foo {
929            z: f64,
930            y: i32,
931            x: String,
932        }
933
934        let got: Foo = de(&["1.3", "42", "hi"]).unwrap();
935        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
936    }
937
938    #[test]
939    fn no_fields() {
940        assert!(de::<String>(&[]).is_err());
941    }
942
943    #[test]
944    fn one_field() {
945        let got: i32 = de(&["42"]).unwrap();
946        assert_eq!(got, 42);
947    }
948
949    serde_if_integer128! {
950        #[test]
951        fn one_field_128() {
952            let got: i128 = de(&["2010223372036854775808"]).unwrap();
953            assert_eq!(got, 2010223372036854775808);
954        }
955    }
956
957    #[test]
958    fn two_fields() {
959        let got: (i32, bool) = de(&["42", "true"]).unwrap();
960        assert_eq!(got, (42, true));
961
962        #[derive(Deserialize, Debug, PartialEq)]
963        struct Foo(i32, bool);
964
965        let got: Foo = de(&["42", "true"]).unwrap();
966        assert_eq!(got, Foo(42, true));
967    }
968
969    #[test]
970    fn two_fields_too_many() {
971        let got: (i32, bool) = de(&["42", "true", "z", "z"]).unwrap();
972        assert_eq!(got, (42, true));
973    }
974
975    #[test]
976    fn two_fields_too_few() {
977        assert!(de::<(i32, bool)>(&["42"]).is_err());
978    }
979
980    #[test]
981    fn one_char() {
982        let got: char = de(&["a"]).unwrap();
983        assert_eq!(got, 'a');
984    }
985
986    #[test]
987    fn no_chars() {
988        assert!(de::<char>(&[""]).is_err());
989    }
990
991    #[test]
992    fn too_many_chars() {
993        assert!(de::<char>(&["ab"]).is_err());
994    }
995
996    #[test]
997    fn simple_seq() {
998        let got: Vec<i32> = de(&["1", "5", "10"]).unwrap();
999        assert_eq!(got, vec![1, 5, 10]);
1000    }
1001
1002    #[test]
1003    fn simple_hex_seq() {
1004        let got: Vec<i32> = de(&["0x7F", "0xA9", "0x10"]).unwrap();
1005        assert_eq!(got, vec![0x7F, 0xA9, 0x10]);
1006    }
1007
1008    #[test]
1009    fn mixed_hex_seq() {
1010        let got: Vec<i32> = de(&["0x7F", "0xA9", "10"]).unwrap();
1011        assert_eq!(got, vec![0x7F, 0xA9, 10]);
1012    }
1013
1014    #[test]
1015    fn bad_hex_seq() {
1016        assert!(de::<Vec<u8>>(&["7F", "0xA9", "10"]).is_err());
1017    }
1018
1019    #[test]
1020    fn seq_in_struct() {
1021        #[derive(Deserialize, Debug, PartialEq)]
1022        struct Foo {
1023            xs: Vec<i32>,
1024        }
1025        let got: Foo = de(&["1", "5", "10"]).unwrap();
1026        assert_eq!(got, Foo { xs: vec![1, 5, 10] });
1027    }
1028
1029    #[test]
1030    fn seq_in_struct_tail() {
1031        #[derive(Deserialize, Debug, PartialEq)]
1032        struct Foo {
1033            label: String,
1034            xs: Vec<i32>,
1035        }
1036        let got: Foo = de(&["foo", "1", "5", "10"]).unwrap();
1037        assert_eq!(got, Foo { label: "foo".into(), xs: vec![1, 5, 10] });
1038    }
1039
1040    #[test]
1041    fn map_headers() {
1042        let got: HashMap<String, i32> =
1043            de_headers(&["a", "b", "c"], &["1", "5", "10"]).unwrap();
1044        assert_eq!(got.len(), 3);
1045        assert_eq!(got["a"], 1);
1046        assert_eq!(got["b"], 5);
1047        assert_eq!(got["c"], 10);
1048    }
1049
1050    #[test]
1051    fn map_no_headers() {
1052        let got = de::<HashMap<String, i32>>(&["1", "5", "10"]);
1053        assert!(got.is_err());
1054    }
1055
1056    #[test]
1057    fn bytes() {
1058        let got: Vec<u8> = de::<BString>(&["foobar"]).unwrap().into();
1059        assert_eq!(got, b"foobar".to_vec());
1060    }
1061
1062    #[test]
1063    fn adjacent_fixed_arrays() {
1064        let got: ([u32; 2], [u32; 2]) = de(&["1", "5", "10", "15"]).unwrap();
1065        assert_eq!(got, ([1, 5], [10, 15]));
1066    }
1067
1068    #[test]
1069    fn enum_label_simple_tagged() {
1070        #[derive(Deserialize, Debug, PartialEq)]
1071        struct Row {
1072            label: Label,
1073            x: f64,
1074        }
1075
1076        #[derive(Deserialize, Debug, PartialEq)]
1077        #[serde(rename_all = "snake_case")]
1078        enum Label {
1079            Foo,
1080            Bar,
1081            Baz,
1082        }
1083
1084        let got: Row = de_headers(&["label", "x"], &["bar", "5"]).unwrap();
1085        assert_eq!(got, Row { label: Label::Bar, x: 5.0 });
1086    }
1087
1088    #[test]
1089    fn enum_untagged() {
1090        #[derive(Deserialize, Debug, PartialEq)]
1091        struct Row {
1092            x: Boolish,
1093            y: Boolish,
1094            z: Boolish,
1095        }
1096
1097        #[derive(Deserialize, Debug, PartialEq)]
1098        #[serde(rename_all = "snake_case")]
1099        #[serde(untagged)]
1100        enum Boolish {
1101            Bool(bool),
1102            Number(i64),
1103            String(String),
1104        }
1105
1106        let got: Row =
1107            de_headers(&["x", "y", "z"], &["true", "null", "1"]).unwrap();
1108        assert_eq!(
1109            got,
1110            Row {
1111                x: Boolish::Bool(true),
1112                y: Boolish::String("null".into()),
1113                z: Boolish::Number(1),
1114            }
1115        );
1116    }
1117
1118    #[test]
1119    fn option_empty_field() {
1120        #[derive(Deserialize, Debug, PartialEq)]
1121        struct Foo {
1122            a: Option<i32>,
1123            b: String,
1124            c: Option<i32>,
1125        }
1126
1127        let got: Foo =
1128            de_headers(&["a", "b", "c"], &["", "foo", "5"]).unwrap();
1129        assert_eq!(got, Foo { a: None, b: "foo".into(), c: Some(5) });
1130    }
1131
1132    #[test]
1133    fn option_invalid_field() {
1134        #[derive(Deserialize, Debug, PartialEq)]
1135        struct Foo {
1136            #[serde(deserialize_with = "crate::invalid_option")]
1137            a: Option<i32>,
1138            #[serde(deserialize_with = "crate::invalid_option")]
1139            b: Option<i32>,
1140            #[serde(deserialize_with = "crate::invalid_option")]
1141            c: Option<i32>,
1142        }
1143
1144        let got: Foo =
1145            de_headers(&["a", "b", "c"], &["xyz", "", "5"]).unwrap();
1146        assert_eq!(got, Foo { a: None, b: None, c: Some(5) });
1147    }
1148
1149    #[test]
1150    fn borrowed() {
1151        #[derive(Deserialize, Debug, PartialEq)]
1152        struct Foo<'a, 'c> {
1153            a: &'a str,
1154            b: i32,
1155            c: &'c str,
1156        }
1157
1158        let headers = StringRecord::from(vec!["a", "b", "c"]);
1159        let record = StringRecord::from(vec!["foo", "5", "bar"]);
1160        let got: Foo =
1161            deserialize_string_record(&record, Some(&headers)).unwrap();
1162        assert_eq!(got, Foo { a: "foo", b: 5, c: "bar" });
1163    }
1164
1165    #[test]
1166    fn borrowed_map() {
1167        use std::collections::HashMap;
1168
1169        let headers = StringRecord::from(vec!["a", "b", "c"]);
1170        let record = StringRecord::from(vec!["aardvark", "bee", "cat"]);
1171        let got: HashMap<&str, &str> =
1172            deserialize_string_record(&record, Some(&headers)).unwrap();
1173
1174        let expected: HashMap<&str, &str> =
1175            headers.iter().zip(&record).collect();
1176        assert_eq!(got, expected);
1177    }
1178
1179    #[test]
1180    fn borrowed_map_bytes() {
1181        use std::collections::HashMap;
1182
1183        let headers = ByteRecord::from(vec![b"a", b"\xFF", b"c"]);
1184        let record = ByteRecord::from(vec!["aardvark", "bee", "cat"]);
1185        let got: HashMap<&[u8], &[u8]> =
1186            deserialize_byte_record(&record, Some(&headers)).unwrap();
1187
1188        let expected: HashMap<&[u8], &[u8]> =
1189            headers.iter().zip(&record).collect();
1190        assert_eq!(got, expected);
1191    }
1192
1193    #[test]
1194    fn flatten() {
1195        #[derive(Deserialize, Debug, PartialEq)]
1196        struct Input {
1197            x: f64,
1198            y: f64,
1199        }
1200
1201        #[derive(Deserialize, Debug, PartialEq)]
1202        struct Properties {
1203            prop1: f64,
1204            prop2: f64,
1205        }
1206
1207        #[derive(Deserialize, Debug, PartialEq)]
1208        struct Row {
1209            #[serde(flatten)]
1210            input: Input,
1211            #[serde(flatten)]
1212            properties: Properties,
1213        }
1214
1215        let header = StringRecord::from(vec!["x", "y", "prop1", "prop2"]);
1216        let record = StringRecord::from(vec!["1", "2", "3", "4"]);
1217        let got: Row = record.deserialize(Some(&header)).unwrap();
1218        assert_eq!(
1219            got,
1220            Row {
1221                input: Input { x: 1.0, y: 2.0 },
1222                properties: Properties { prop1: 3.0, prop2: 4.0 },
1223            }
1224        );
1225    }
1226
1227    #[test]
1228    fn partially_invalid_utf8() {
1229        #[derive(Debug, Deserialize, PartialEq)]
1230        struct Row {
1231            h1: String,
1232            h2: BString,
1233            h3: String,
1234        }
1235
1236        let headers = ByteRecord::from(vec![b"h1", b"h2", b"h3"]);
1237        let record =
1238            ByteRecord::from(vec![b(b"baz"), b(b"foo\xFFbar"), b(b"quux")]);
1239        let got: Row =
1240            deserialize_byte_record(&record, Some(&headers)).unwrap();
1241        assert_eq!(
1242            got,
1243            Row {
1244                h1: "baz".to_string(),
1245                h2: BString::from(b"foo\xFFbar".to_vec()),
1246                h3: "quux".to_string(),
1247            }
1248        );
1249    }
1250}