ssh_format/
de.rs

1use std::{borrow::Cow, convert::TryInto, iter, str};
2
3use serde::de::{self, DeserializeSeed, IntoDeserializer, SeqAccess, VariantAccess, Visitor};
4use serde::Deserialize;
5
6use crate::{Error, Result};
7
/// Deserializer over a sequence of byte chunks.
///
/// Bytes are taken from `slice` first; once it is empty the next chunk is
/// fetched from `iter`.
#[derive(Clone, Copy, Debug)]
pub struct Deserializer<'de, It> {
    /// The chunk that is currently being consumed.
    slice: &'de [u8],
    /// Supplier of the chunks that follow `slice`.
    iter: It,
}

impl<'de, It> Deserializer<'de, It> {
    /// Build a deserializer whose whole input comes from `iter`.
    ///
    /// The current chunk starts out empty.
    pub const fn new(iter: It) -> Self {
        Self { slice: &[], iter }
    }

    /// Consume the deserializer, handing back the unread remainder of the
    /// current chunk together with the chunk iterator.
    pub fn into_inner(self) -> (&'de [u8], It) {
        let Self { slice, iter } = self;
        (slice, iter)
    }
}

impl<'de> Deserializer<'de, iter::Empty<&'de [u8]>> {
    /// Build a deserializer over a single contiguous buffer; no further
    /// chunks follow it.
    pub const fn from_bytes(slice: &'de [u8]) -> Self {
        Self {
            iter: iter::empty(),
            slice,
        }
    }
}
32
33/// Return a deserialized value and trailing bytes.
34///
35/// # Example
36///
37/// Simple Usage:
38///
39/// ```ignore
40/// let serialized = to_bytes(value).unwrap();
41/// // Ignore the size
42/// let (new_value, _trailing_bytes) = from_bytes::<T>(&serialized[4..]).unwrap();
43///
44/// assert_eq!(value, new_value);
45/// ```
46///
47/// Replace `T` with type of `value`.
48///
49/// More complicated one (sending over socket):
50///
51/// ```ignore
52/// let buffer = [0, 0, 0, 4];
53/// let (size: u32, _trailing_bytes) = from_bytes(&buffer).unwrap();
54///
55/// let buffer = [0, 0, 4, 0];
56/// let (val: <T>, _trailing_bytes) = from_bytes(&buffer).unwrap();
57/// ```
58///
59/// Replace `T` with your own type.
60pub fn from_bytes<'a, T>(s: &'a [u8]) -> Result<(T, &'a [u8])>
61where
62    T: Deserialize<'a>,
63{
64    let mut deserializer = Deserializer::from_bytes(s);
65    let t = T::deserialize(&mut deserializer)?;
66    Ok((t, deserializer.slice))
67}
68
69impl<'de, It> Deserializer<'de, It>
70where
71    It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
72{
73    /// Extract the loop as a separate function so that `Self::update_slice`
74    /// can be trivally inlined.
75    fn update_slice_inner(&mut self) {
76        self.slice = self.iter.find(|slice| !slice.is_empty()).unwrap_or(&[]);
77    }
78
79    #[inline]
80    fn update_slice(&mut self) {
81        if self.slice.is_empty() {
82            self.update_slice_inner();
83        }
84    }
85
86    fn next_byte(&mut self) -> Result<u8> {
87        self.update_slice();
88
89        let byte = self.slice.first().copied().ok_or(Error::Eof)?;
90        self.slice = &self.slice[1..];
91
92        Ok(byte)
93    }
94
95    fn fill_buffer(&mut self, mut buffer: &mut [u8]) -> Result<()> {
96        loop {
97            if buffer.is_empty() {
98                break Ok(());
99            }
100
101            self.update_slice();
102
103            if self.slice.is_empty() {
104                break Err(Error::Eof);
105            }
106
107            let n = self.slice.len().min(buffer.len());
108
109            buffer[..n].copy_from_slice(&self.slice[..n]);
110
111            self.slice = &self.slice[n..];
112            buffer = &mut buffer[n..];
113        }
114    }
115
116    /// * `SIZE` - must not be 0!
117    fn next_bytes_const<const SIZE: usize>(&mut self) -> Result<[u8; SIZE]> {
118        assert_ne!(SIZE, 0);
119
120        let mut bytes = [0_u8; SIZE];
121        self.fill_buffer(&mut bytes)?;
122
123        Ok(bytes)
124    }
125
126    fn next_u32(&mut self) -> Result<u32> {
127        Ok(u32::from_be_bytes(self.next_bytes_const()?))
128    }
129
130    fn next_bytes(&mut self, size: usize) -> Result<Cow<'de, [u8]>> {
131        self.update_slice();
132
133        if self.slice.len() >= size {
134            let slice = &self.slice[..size];
135            self.slice = &self.slice[size..];
136
137            Ok(Cow::Borrowed(slice))
138        } else {
139            let mut bytes = vec![0_u8; size];
140            self.fill_buffer(&mut bytes)?;
141            Ok(Cow::Owned(bytes))
142        }
143    }
144
145    /// Parse &str and &[u8]
146    fn parse_bytes(&mut self) -> Result<Cow<'de, [u8]>> {
147        let len: usize = self.next_u32()?.try_into().map_err(|_| Error::TooLong)?;
148        self.next_bytes(len)
149    }
150
151    /// Is there any remaining data.
152    pub fn has_remaining_data(&mut self) -> bool {
153        self.update_slice();
154        !self.slice.is_empty()
155    }
156}
157
/// Generate a `deserialize_*` method for a fixed-width primitive.
///
/// The generated method reads exactly as many bytes as `$prim` occupies,
/// decodes them as big-endian and passes the value to the visitor method
/// `$visit`.
macro_rules! impl_for_deserialize_primitive {
    ( $method:ident, $visit:ident, $prim:ty ) => {
        fn $method<V>(self, visitor: V) -> Result<V::Value>
        where
            V: Visitor<'de>,
        {
            // The array length is inferred from `from_be_bytes` below.
            let raw = self.next_bytes_const()?;
            visitor.$visit(<$prim>::from_be_bytes(raw))
        }
    };
}
168
169impl<'de, 'a, It> de::Deserializer<'de> for &'a mut Deserializer<'de, It>
170where
171    It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
172{
173    type Error = Error;
174
175    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
176    where
177        V: Visitor<'de>,
178    {
179        match self.next_u32()? {
180            1 => visitor.visit_bool(true),
181            0 => visitor.visit_bool(false),
182            _ => Err(Error::InvalidBoolEncoding),
183        }
184    }
185
186    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
187    where
188        V: Visitor<'de>,
189    {
190        visitor.visit_u8(self.next_byte()?)
191    }
192
193    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
194    where
195        V: Visitor<'de>,
196    {
197        visitor.visit_i8(self.next_byte()? as i8)
198    }
199
200    impl_for_deserialize_primitive!(deserialize_i16, visit_i16, i16);
201    impl_for_deserialize_primitive!(deserialize_i32, visit_i32, i32);
202    impl_for_deserialize_primitive!(deserialize_i64, visit_i64, i64);
203
204    impl_for_deserialize_primitive!(deserialize_u16, visit_u16, u16);
205    impl_for_deserialize_primitive!(deserialize_u32, visit_u32, u32);
206    impl_for_deserialize_primitive!(deserialize_u64, visit_u64, u64);
207
208    impl_for_deserialize_primitive!(deserialize_f32, visit_f32, f32);
209    impl_for_deserialize_primitive!(deserialize_f64, visit_f64, f64);
210
211    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
212    where
213        V: Visitor<'de>,
214    {
215        match char::from_u32(self.next_u32()?) {
216            Some(ch) => visitor.visit_char(ch),
217            None => Err(Error::InvalidChar),
218        }
219    }
220
221    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
222    where
223        V: Visitor<'de>,
224    {
225        match self.parse_bytes()? {
226            Cow::Owned(owned_bytes) => visitor.visit_string(String::from_utf8(owned_bytes)?),
227            Cow::Borrowed(bytes) => visitor.visit_borrowed_str(str::from_utf8(bytes)?),
228        }
229    }
230
231    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
232    where
233        V: Visitor<'de>,
234    {
235        self.deserialize_str(visitor)
236    }
237
238    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
239    where
240        V: Visitor<'de>,
241    {
242        match self.parse_bytes()? {
243            Cow::Owned(owned_bytes) => visitor.visit_byte_buf(owned_bytes),
244            Cow::Borrowed(bytes) => visitor.visit_borrowed_bytes(bytes),
245        }
246    }
247
248    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
249    where
250        V: Visitor<'de>,
251    {
252        self.deserialize_bytes(visitor)
253    }
254
255    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
256    where
257        V: Visitor<'de>,
258    {
259        visitor.visit_unit()
260    }
261
262    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
263    where
264        V: Visitor<'de>,
265    {
266        self.deserialize_unit(visitor)
267    }
268
269    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
270    where
271        V: Visitor<'de>,
272    {
273        visitor.visit_newtype_struct(self)
274    }
275
276    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
277    where
278        V: Visitor<'de>,
279    {
280        visitor.visit_seq(Access {
281            deserializer: self,
282            len,
283        })
284    }
285
286    fn deserialize_tuple_struct<V>(
287        self,
288        _name: &'static str,
289        len: usize,
290        visitor: V,
291    ) -> Result<V::Value>
292    where
293        V: Visitor<'de>,
294    {
295        self.deserialize_tuple(len, visitor)
296    }
297
298    fn deserialize_struct<V>(
299        self,
300        _name: &'static str,
301        fields: &'static [&'static str],
302        visitor: V,
303    ) -> Result<V::Value>
304    where
305        V: Visitor<'de>,
306    {
307        self.deserialize_tuple(fields.len(), visitor)
308    }
309
310    fn deserialize_enum<V>(
311        self,
312        _name: &'static str,
313        _variants: &'static [&'static str],
314        visitor: V,
315    ) -> Result<V::Value>
316    where
317        V: Visitor<'de>,
318    {
319        impl<'a, 'de, It> serde::de::EnumAccess<'de> for &'a mut Deserializer<'de, It>
320        where
321            It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
322        {
323            type Error = Error;
324            type Variant = Self;
325
326            fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
327            where
328                V: de::DeserializeSeed<'de>,
329            {
330                let idx: u32 = self.next_u32()?;
331                let val: Result<_> = seed.deserialize(idx.into_deserializer());
332                Ok((val?, self))
333            }
334        }
335
336        visitor.visit_enum(self)
337    }
338
339    #[cfg(feature = "is_human_readable")]
340    /// Always return `false`
341    fn is_human_readable(&self) -> bool {
342        false
343    }
344
345    /// Unsupported
346    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
347    where
348        V: Visitor<'de>,
349    {
350        let len = self.next_u32()? as usize;
351        visitor.visit_seq(Access {
352            deserializer: self,
353            len,
354        })
355    }
356
357    /// Unsupported
358    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
359    where
360        V: Visitor<'de>,
361    {
362        Err(Error::Unsupported(&"deserialize_any"))
363    }
364
365    /// Unsupported
366    fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value>
367    where
368        V: Visitor<'de>,
369    {
370        Err(Error::Unsupported(&"deserialize_option"))
371    }
372
373    /// Unsupported
374    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
375    where
376        V: Visitor<'de>,
377    {
378        Err(Error::Unsupported(&"deserialize_map"))
379    }
380
381    /// Unsupported
382    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
383    where
384        V: Visitor<'de>,
385    {
386        Err(Error::Unsupported(&"deserialize_identifier"))
387    }
388
389    /// Unsupported
390    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
391    where
392        V: Visitor<'de>,
393    {
394        Err(Error::Unsupported(&"deserialize_ignored_any"))
395    }
396}
397
398impl<'a, 'de, It> VariantAccess<'de> for &'a mut Deserializer<'de, It>
399where
400    It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
401{
402    type Error = Error;
403
404    fn unit_variant(self) -> Result<()> {
405        Ok(())
406    }
407
408    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
409    where
410        T: DeserializeSeed<'de>,
411    {
412        DeserializeSeed::deserialize(seed, self)
413    }
414
415    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
416    where
417        V: Visitor<'de>,
418    {
419        de::Deserializer::deserialize_tuple(self, len, visitor)
420    }
421
422    fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
423    where
424        V: Visitor<'de>,
425    {
426        de::Deserializer::deserialize_tuple(self, fields.len(), visitor)
427    }
428}
429
430struct Access<'a, 'de, It> {
431    deserializer: &'a mut Deserializer<'de, It>,
432    len: usize,
433}
434
435impl<'a, 'de, It> SeqAccess<'de> for Access<'a, 'de, It>
436where
437    It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
438{
439    type Error = Error;
440
441    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
442    where
443        T: DeserializeSeed<'de>,
444    {
445        if self.len > 0 {
446            self.len -= 1;
447            let value = seed.deserialize(&mut *self.deserializer)?;
448            Ok(Some(value))
449        } else {
450            Ok(None)
451        }
452    }
453
454    fn size_hint(&self) -> Option<usize> {
455        Some(self.len)
456    }
457}
458
/// Test deserialization
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use assert_matches::assert_matches;
    use generator::{done, Gn};
    use itertools::Itertools;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::to_bytes;

    /// Generate subslices, plus stuffing empty slices into the returned
    /// iterator.
    ///
    /// Every chunk of at most `chunk_size` bytes is preceded by eight empty
    /// slices, so the deserializer's skipping of empty chunks is exercised.
    fn generate_subslices(mut bytes: &[u8], chunk_size: usize) -> impl Iterator<Item = &[u8]> {
        assert_ne!(chunk_size, 0);

        Gn::new_scoped(move |mut s| loop {
            for _ in 0..8 {
                // Stuffing empty slices
                s.yield_(&bytes[..0]);
            }

            let n = bytes.len().min(chunk_size);
            s.yield_(&bytes[..n]);
            bytes = &bytes[n..];

            if bytes.is_empty() {
                done!();
            }
        })
    }

    /// First serialize value, then deserialize it.
    ///
    /// The value is deserialized once from the contiguous buffer and once
    /// for every chunk size from 1 up to and including the whole buffer.
    fn test_roundtrip<'de, T>(value: &T)
    where
        T: Debug + Eq + Serialize + Deserialize<'de>,
    {
        // Leaked so that the buffer outlives any `'de` chosen by the caller.
        let serialized = to_bytes(value).unwrap().leak();
        // Ignore the size
        let serialized = &serialized[4..];

        // Test from_bytes
        assert_eq!(from_bytes::<T>(serialized).unwrap().0, *value);

        // Test cutting it into multiple small vectors.
        //
        // The range is inclusive so that the "everything in one chunk" case
        // also goes through the iterator; with an exclusive range a 1-byte
        // payload (e.g. `u8`) would never exercise the iterator path at all.
        for chunk_size in 1..=serialized.len() {
            let mut deserializer =
                Deserializer::new(generate_subslices(serialized, chunk_size).fuse());
            let val = T::deserialize(&mut deserializer).unwrap();
            assert_eq!(val, *value);

            let (slice, mut iter) = deserializer.into_inner();

            // All input must have been consumed.
            assert_eq!(slice, &[]);
            assert_eq!(iter.next(), None);
        }
    }

    #[test]
    fn test_integer() {
        test_roundtrip(&0x12_u8);
        test_roundtrip(&0x1234_u16);
        test_roundtrip(&0x12345678_u32);
        test_roundtrip(&0x1234567887654321_u64);
    }

    #[test]
    fn test_boolean() {
        test_roundtrip(&true);
        test_roundtrip(&false);
    }

    #[test]
    fn test_str() {
        let s = "Hello, world!";
        let serialized = to_bytes(&s).unwrap();
        // Ignore the size
        let deserialized: &str = from_bytes(&serialized[4..]).unwrap().0;
        assert_eq!(deserialized, s);
    }

    #[test]
    fn test_seq() {
        test_roundtrip(&vec![0x00_u8, 0x01_u8, 0x10_u8, 0x78_u8]);
        test_roundtrip(&vec![0x0010_u16, 0x0100_u16, 0x1034_u16, 0x7812_u16]);
    }

    #[test]
    fn test_tuple() {
        test_roundtrip(&(0x00_u8, 0x0100_u16, 0x1034_u16, 0x7812_u16));
    }

    #[test]
    fn test_struct() {
        #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
        struct S {
            v1: u8,
            v2: u16,
            v3: u16,
            v4: u16,
        }
        test_roundtrip(&S {
            v1: 0x00,
            v2: 0x0100,
            v3: 0x1034,
            v4: 0x7812,
        });
    }

    /// Like `test_struct`, with an additional borrowed string field.
    #[test]
    fn test_struct2() {
        #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
        struct S<'a> {
            v1: u8,
            v2: u16,
            v3: u16,
            v4: u16,
            #[serde(borrow)]
            v5: Cow<'a, str>,
        }
        test_roundtrip(&S {
            v1: 0x00,
            v2: 0x0100,
            v3: 0x1034,
            v4: 0x7812,
            v5: Cow::Owned((0..100).join(", ")),
        });
    }

    /// Test EOF error
    #[test]
    fn test_eof_error() {
        assert_matches!(from_bytes::<u8>(&[]), Err(Error::Eof));

        let s = "Hello, world!";
        let serialized = to_bytes(&s).unwrap();
        assert_matches!(
            from_bytes::<String>(&serialized[0..serialized.len() - 1]),
            Err(Error::Eof)
        );
    }
}