ssh_format/
ser.rs

1use serde::{ser, Serialize};
2use std::convert::TryInto;
3
4use crate::{Error, Result, SerOutput};
5
6fn usize_to_u32(v: usize) -> Result<u32> {
7    v.try_into().map_err(|_| Error::TooLong)
8}
9
10#[derive(Clone, Debug)]
11pub struct Serializer<T: SerOutput = Vec<u8>> {
12    pub output: T,
13    len: usize,
14}
15
16impl<T: SerOutput + Default> Default for Serializer<T> {
17    fn default() -> Self {
18        Self::new(Default::default())
19    }
20}
21
22impl<T: SerOutput> Serializer<T> {
23    pub fn new(output: T) -> Self {
24        Self { output, len: 0 }
25    }
26
27    pub fn reserve(&mut self, additional: usize) {
28        self.output.reserve(additional);
29    }
30
31    /// * `len` - length of additional data included in the packet.
32    pub fn create_header(&self, len: u32) -> Result<[u8; 4]> {
33        let len: u32 = usize_to_u32(self.len + len as usize)?;
34
35        Ok(len.to_be_bytes())
36    }
37
38    /// Reset the internal counter.
39    /// This would cause [`Self::create_header`] to return `Ok([0, 0, 0, 0])`
40    /// until you call [`Serialize::serialize`] again.
41    pub fn reset_counter(&mut self) {
42        self.len = 0;
43    }
44
45    fn extend_from_slice(&mut self, other: &[u8]) {
46        self.output.extend_from_slice(other);
47        self.len += other.len();
48    }
49
50    fn push(&mut self, byte: u8) {
51        self.output.push(byte);
52        self.len += 1;
53    }
54
55    fn serialize_usize(&mut self, v: usize) -> Result<()> {
56        ser::Serializer::serialize_u32(self, usize_to_u32(v)?)
57    }
58}
59
60/// Return a byte array with the first 4 bytes representing the size
61/// of the rest of the serialized message.
62///
63/// See doc of `from_bytes` for examples.
64pub fn to_bytes<T>(value: &T) -> Result<Vec<u8>>
65where
66    T: Serialize,
67{
68    let mut buffer = vec![0, 0, 0, 0];
69
70    let mut serializer = Serializer::new(&mut buffer);
71    value.serialize(&mut serializer)?;
72    let header = serializer.create_header(0)?;
73
74    buffer[..4].copy_from_slice(&header);
75
76    Ok(buffer)
77}
78
/// Generates a `serde::Serializer` method named `$method` that writes a value
/// of type `$prim` using its big-endian byte representation.
macro_rules! impl_for_serialize_primitive {
    ( $method:ident, $prim:ty ) => {
        fn $method(self, value: $prim) -> Result<()> {
            let encoded = value.to_be_bytes();
            self.extend_from_slice(&encoded);
            Ok(())
        }
    };
}
87
88impl<'a, Container: SerOutput> ser::Serializer for &'a mut Serializer<Container> {
89    type Ok = ();
90    type Error = Error;
91
92    type SerializeSeq = Self;
93    type SerializeTuple = Self;
94    type SerializeTupleStruct = Self;
95    type SerializeTupleVariant = Self;
96    type SerializeMap = Self;
97    type SerializeStruct = Self;
98    type SerializeStructVariant = Self;
99
100    fn serialize_bool(self, v: bool) -> Result<()> {
101        self.serialize_u32(v as u32)
102    }
103
104    fn serialize_u8(self, v: u8) -> Result<()> {
105        self.push(v);
106        Ok(())
107    }
108
109    fn serialize_i8(self, v: i8) -> Result<()> {
110        self.push(v as u8);
111        Ok(())
112    }
113
114    impl_for_serialize_primitive!(serialize_i16, i16);
115    impl_for_serialize_primitive!(serialize_i32, i32);
116    impl_for_serialize_primitive!(serialize_i64, i64);
117
118    impl_for_serialize_primitive!(serialize_u16, u16);
119    impl_for_serialize_primitive!(serialize_u32, u32);
120    impl_for_serialize_primitive!(serialize_u64, u64);
121
122    impl_for_serialize_primitive!(serialize_f32, f32);
123    impl_for_serialize_primitive!(serialize_f64, f64);
124
125    fn serialize_char(self, v: char) -> Result<()> {
126        self.serialize_u32(v as u32)
127    }
128
129    fn serialize_str(self, v: &str) -> Result<()> {
130        fn is_null_byte(byte: &u8) -> bool {
131            *byte == b'\0'
132        }
133
134        let bytes = v.as_bytes();
135
136        let null_byte_counts = bytes.iter().copied().filter(is_null_byte).count();
137
138        let len = bytes.len() - null_byte_counts;
139
140        // Reserve bytes
141        self.reserve(4 + len);
142
143        self.serialize_usize(len)?;
144
145        if null_byte_counts == 0 {
146            self.extend_from_slice(v.as_bytes());
147        } else {
148            bytes
149                .split(is_null_byte)
150                .filter(|slice| !slice.is_empty())
151                .for_each(|slice| {
152                    self.extend_from_slice(slice);
153                });
154        }
155
156        Ok(())
157    }
158
159    fn serialize_bytes(self, v: &[u8]) -> Result<()> {
160        self.reserve(4 + v.len());
161
162        self.serialize_usize(v.len())?;
163
164        self.extend_from_slice(v);
165
166        Ok(())
167    }
168
169    fn serialize_none(self) -> Result<()> {
170        Ok(())
171    }
172
173    fn serialize_some<T>(self, value: &T) -> Result<()>
174    where
175        T: ?Sized + Serialize,
176    {
177        value.serialize(self)
178    }
179
180    fn serialize_unit(self) -> Result<()> {
181        Ok(())
182    }
183
184    fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
185        self.serialize_unit()
186    }
187
188    fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
189    where
190        T: ?Sized + Serialize,
191    {
192        value.serialize(self)
193    }
194
195    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
196        if let Some(len) = len {
197            self.reserve(4 + len as usize);
198
199            self.serialize_usize(len)?;
200        }
201        Ok(self)
202    }
203
204    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
205        Ok(self)
206    }
207
208    fn serialize_tuple_struct(
209        self,
210        _name: &'static str,
211        len: usize,
212    ) -> Result<Self::SerializeTupleStruct> {
213        self.serialize_tuple(len)
214    }
215
216    fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
217        self.serialize_tuple(len)
218    }
219
220    fn serialize_unit_variant(
221        self,
222        _name: &'static str,
223        variant_index: u32,
224        _variant: &'static str,
225    ) -> Result<()> {
226        self.serialize_u32(variant_index)
227    }
228
229    fn serialize_newtype_variant<T>(
230        self,
231        name: &'static str,
232        variant_index: u32,
233        variant: &'static str,
234        value: &T,
235    ) -> Result<()>
236    where
237        T: ?Sized + Serialize,
238    {
239        self.serialize_unit_variant(name, variant_index, variant)?;
240        value.serialize(self)
241    }
242
243    fn serialize_tuple_variant(
244        self,
245        name: &'static str,
246        variant_index: u32,
247        variant: &'static str,
248        len: usize,
249    ) -> Result<Self::SerializeTupleVariant> {
250        self.serialize_unit_variant(name, variant_index, variant)?;
251        self.serialize_tuple(len)
252    }
253
254    fn serialize_struct_variant(
255        self,
256        name: &'static str,
257        variant_index: u32,
258        variant: &'static str,
259        len: usize,
260    ) -> Result<Self::SerializeStructVariant> {
261        self.serialize_unit_variant(name, variant_index, variant)?;
262        self.serialize_tuple(len)
263    }
264
265    #[cfg(feature = "is_human_readable")]
266    /// Always return false
267    fn is_human_readable(&self) -> bool {
268        false
269    }
270
271    /// Unsupported
272    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
273        Err(Error::Unsupported(&"serialize_map"))
274    }
275}
276
277macro_rules! impl_serialize_trait {
278    ( $name:ident, $function_name:ident ) => {
279        impl<'a, Container: SerOutput> ser::$name for &'a mut Serializer<Container> {
280            type Ok = ();
281            type Error = Error;
282
283            fn $function_name<T>(&mut self, value: &T) -> Result<()>
284            where
285                T: ?Sized + Serialize,
286            {
287                value.serialize(&mut **self)
288            }
289
290            fn end(self) -> Result<()> {
291                Ok(())
292            }
293        }
294    };
295}
296
297impl_serialize_trait!(SerializeSeq, serialize_element);
298impl_serialize_trait!(SerializeTuple, serialize_element);
299impl_serialize_trait!(SerializeTupleStruct, serialize_field);
300impl_serialize_trait!(SerializeTupleVariant, serialize_field);
301
302/// Unsupported
303impl<'a, Container: SerOutput> ser::SerializeMap for &'a mut Serializer<Container> {
304    type Ok = ();
305    type Error = Error;
306
307    /// Unsupported
308    fn serialize_key<T>(&mut self, _key: &T) -> Result<()>
309    where
310        T: ?Sized + Serialize,
311    {
312        Err(Error::Unsupported(&"serialize_map"))
313    }
314
315    /// Unsupported
316    fn serialize_value<T>(&mut self, _value: &T) -> Result<()>
317    where
318        T: ?Sized + Serialize,
319    {
320        Err(Error::Unsupported(&"serialize_map"))
321    }
322
323    /// Unsupported
324    fn end(self) -> Result<()> {
325        Err(Error::Unsupported(&"serialize_map"))
326    }
327}
328
329impl<'a, Container: SerOutput> ser::SerializeStruct for &'a mut Serializer<Container> {
330    type Ok = ();
331    type Error = Error;
332
333    fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<()>
334    where
335        T: ?Sized + Serialize,
336    {
337        value.serialize(&mut **self)
338    }
339
340    fn end(self) -> Result<()> {
341        Ok(())
342    }
343}
344impl<'a, Container: SerOutput> ser::SerializeStructVariant for &'a mut Serializer<Container> {
345    type Ok = ();
346    type Error = Error;
347
348    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
349    where
350        T: ?Sized + Serialize,
351    {
352        ser::SerializeStruct::serialize_field(self, key, value)
353    }
354
355    fn end(self) -> Result<()> {
356        ser::SerializeStruct::end(self)
357    }
358}
359
#[cfg(test)]
mod tests {
    use crate::{to_bytes, Serializer};
    use serde::{ser, Serialize};
    use std::convert::TryInto;

    /// Unsigned integers are written big-endian after the 4-byte size header.
    #[test]
    fn test_integer() {
        let one_byte = to_bytes(&0x12_u8).unwrap();
        assert_eq!(one_byte, [0, 0, 0, 1, 0x12]);

        let two_bytes = to_bytes(&0x1234_u16).unwrap();
        assert_eq!(two_bytes, [0, 0, 0, 2, 0x12, 0x34]);

        let four_bytes = to_bytes(&0x12345678_u32).unwrap();
        assert_eq!(four_bytes, [0, 0, 0, 4, 0x12, 0x34, 0x56, 0x78]);

        let eight_bytes = to_bytes(&0x1234567887654321_u64).unwrap();
        assert_eq!(
            eight_bytes,
            [0, 0, 0, 8, 0x12, 0x34, 0x56, 0x78, 0x87, 0x65, 0x43, 0x21]
        );
    }

    /// Booleans occupy four bytes.
    #[test]
    fn test_boolean() {
        let encoded_true = to_bytes(&true).unwrap();
        let encoded_false = to_bytes(&false).unwrap();

        assert_eq!(encoded_true, [0, 0, 0, 4, 0, 0, 0, 1]);
        assert_eq!(encoded_false, [0, 0, 0, 4, 0, 0, 0, 0]);
    }

    /// Strings are a `u32` length followed by the raw bytes.
    #[test]
    fn test_str() {
        let text = "Hello, world!";
        let encoded = to_bytes(&text).unwrap();

        let payload_len: u32 = (encoded.len() - 4).try_into().unwrap();
        assert_eq!(&encoded[..4], payload_len.to_be_bytes());

        assert_eq!(&encoded[4..8], (text.len() as u32).to_be_bytes());
        assert_eq!(&encoded[8..], text.as_bytes());
    }

    /// Null bytes are stripped and not counted in the string length.
    #[test]
    fn test_str_with_null() {
        let text = "\0Hello, world!";
        let encoded = to_bytes(&text).unwrap();

        let payload_len: u32 = (encoded.len() - 4).try_into().unwrap();
        assert_eq!(&encoded[..4], payload_len.to_be_bytes());

        assert_eq!(&encoded[4..8], ((text.len() - 1) as u32).to_be_bytes());

        assert_eq!(&encoded[8..], &text.as_bytes()[1..]);
    }

    /// Slices are prefixed with their element count.
    #[test]
    fn test_array() {
        let bytes = [0x00_u8, 0x01_u8, 0x10_u8, 0x78_u8];
        let byte_slice: &[_] = &bytes;

        let encoded = to_bytes(&byte_slice).unwrap();
        assert_eq!(encoded[..4], [0, 0, 0, 8]);
        assert_eq!(encoded[4..8], [0, 0, 0, 4]);
        assert_eq!(encoded[8..], bytes);

        let word_slice: &[_] = &[0x0010_u16, 0x0100_u16, 0x1034_u16, 0x7812_u16];

        assert_eq!(
            to_bytes(&word_slice).unwrap(),
            &[0, 0, 0, 12, 0, 0, 0, 4, 0x00, 0x10, 0x01, 0x00, 0x10, 0x34, 0x78, 0x12_u8]
        );
    }

    /// Tuples are the concatenation of their fields with no length prefix.
    #[test]
    fn test_tuple() {
        let value = (0x00_u8, 0x0100_u16, 0x1034_u16, 0x7812_u16);

        assert_eq!(
            to_bytes(&value).unwrap(),
            &[0, 0, 0, 7, 0x00_u8, 0x01_u8, 0x00_u8, 0x10_u8, 0x34_u8, 0x78_u8, 0x12_u8]
        );
    }

    /// Structs are encoded the same way as tuples of their fields.
    #[test]
    fn test_struct() {
        #[derive(Serialize)]
        struct S {
            v1: u8,
            v2: u16,
            v3: u16,
            v4: u16,
        }

        let value = S {
            v1: 0x00,
            v2: 0x0100,
            v3: 0x1034,
            v4: 0x7812,
        };

        assert_eq!(
            to_bytes(&value).unwrap(),
            &[0, 0, 0, 7, 0x00_u8, 0x01_u8, 0x00_u8, 0x10_u8, 0x34_u8, 0x78_u8, 0x12_u8]
        );
    }

    /// Enum variants are identified by their index written as a `u32`.
    #[test]
    fn test_enum() {
        use ser::Serializer as SerdeSerializerTrait;

        let mut ser: Serializer<Vec<u8>> = Serializer::default();

        // Unit variant: only the index is written.
        ser.serialize_unit_variant("", 1, "").unwrap();
        assert_eq!(ser.create_header(0).unwrap(), [0, 0, 0, 4]);
        assert_eq!(ser.output, [0, 0, 0, 1]);

        // Reset serializer
        ser.reset_counter();
        ser.output.clear();

        // Newtype variant: index followed by the payload.
        ser.serialize_newtype_variant("", 0, "", &3).unwrap();
        assert_eq!(ser.create_header(0).unwrap(), [0, 0, 0, 8]);
        assert_eq!(ser.output, [0, 0, 0, 0, 0, 0, 0, 3]);
    }
}