1use serde::{ser, Serialize};
2use std::convert::TryInto;
3
4use crate::{Error, Result, SerOutput};
5
6fn usize_to_u32(v: usize) -> Result<u32> {
7 v.try_into().map_err(|_| Error::TooLong)
8}
9
/// Serializer that writes encoded values into a `SerOutput` sink.
///
/// Besides the sink it keeps a running count of the bytes it has written
/// since construction or the last `reset_counter` call; `create_header`
/// turns that count into a 4-byte big-endian length header.
#[derive(Clone, Debug)]
pub struct Serializer<T: SerOutput = Vec<u8>> {
    /// The sink that receives the serialized bytes.
    pub output: T,
    /// Number of bytes written through this serializer since creation or the
    /// last `reset_counter` call.
    len: usize,
}
15
16impl<T: SerOutput + Default> Default for Serializer<T> {
17 fn default() -> Self {
18 Self::new(Default::default())
19 }
20}
21
22impl<T: SerOutput> Serializer<T> {
23 pub fn new(output: T) -> Self {
24 Self { output, len: 0 }
25 }
26
27 pub fn reserve(&mut self, additional: usize) {
28 self.output.reserve(additional);
29 }
30
31 pub fn create_header(&self, len: u32) -> Result<[u8; 4]> {
33 let len: u32 = usize_to_u32(self.len + len as usize)?;
34
35 Ok(len.to_be_bytes())
36 }
37
38 pub fn reset_counter(&mut self) {
42 self.len = 0;
43 }
44
45 fn extend_from_slice(&mut self, other: &[u8]) {
46 self.output.extend_from_slice(other);
47 self.len += other.len();
48 }
49
50 fn push(&mut self, byte: u8) {
51 self.output.push(byte);
52 self.len += 1;
53 }
54
55 fn serialize_usize(&mut self, v: usize) -> Result<()> {
56 ser::Serializer::serialize_u32(self, usize_to_u32(v)?)
57 }
58}
59
60pub fn to_bytes<T>(value: &T) -> Result<Vec<u8>>
65where
66 T: Serialize,
67{
68 let mut buffer = vec![0, 0, 0, 0];
69
70 let mut serializer = Serializer::new(&mut buffer);
71 value.serialize(&mut serializer)?;
72 let header = serializer.create_header(0)?;
73
74 buffer[..4].copy_from_slice(&header);
75
76 Ok(buffer)
77}
78
/// Generates a serializer method that writes a fixed-width primitive as its
/// big-endian byte representation via `self.extend_from_slice`.
macro_rules! impl_for_serialize_primitive {
    ( $method:ident, $prim:ty ) => {
        fn $method(self, v: $prim) -> Result<()> {
            let be_bytes = v.to_be_bytes();
            self.extend_from_slice(&be_bytes);
            Ok(())
        }
    };
}
87
88impl<'a, Container: SerOutput> ser::Serializer for &'a mut Serializer<Container> {
89 type Ok = ();
90 type Error = Error;
91
92 type SerializeSeq = Self;
93 type SerializeTuple = Self;
94 type SerializeTupleStruct = Self;
95 type SerializeTupleVariant = Self;
96 type SerializeMap = Self;
97 type SerializeStruct = Self;
98 type SerializeStructVariant = Self;
99
100 fn serialize_bool(self, v: bool) -> Result<()> {
101 self.serialize_u32(v as u32)
102 }
103
104 fn serialize_u8(self, v: u8) -> Result<()> {
105 self.push(v);
106 Ok(())
107 }
108
109 fn serialize_i8(self, v: i8) -> Result<()> {
110 self.push(v as u8);
111 Ok(())
112 }
113
114 impl_for_serialize_primitive!(serialize_i16, i16);
115 impl_for_serialize_primitive!(serialize_i32, i32);
116 impl_for_serialize_primitive!(serialize_i64, i64);
117
118 impl_for_serialize_primitive!(serialize_u16, u16);
119 impl_for_serialize_primitive!(serialize_u32, u32);
120 impl_for_serialize_primitive!(serialize_u64, u64);
121
122 impl_for_serialize_primitive!(serialize_f32, f32);
123 impl_for_serialize_primitive!(serialize_f64, f64);
124
125 fn serialize_char(self, v: char) -> Result<()> {
126 self.serialize_u32(v as u32)
127 }
128
129 fn serialize_str(self, v: &str) -> Result<()> {
130 fn is_null_byte(byte: &u8) -> bool {
131 *byte == b'\0'
132 }
133
134 let bytes = v.as_bytes();
135
136 let null_byte_counts = bytes.iter().copied().filter(is_null_byte).count();
137
138 let len = bytes.len() - null_byte_counts;
139
140 self.reserve(4 + len);
142
143 self.serialize_usize(len)?;
144
145 if null_byte_counts == 0 {
146 self.extend_from_slice(v.as_bytes());
147 } else {
148 bytes
149 .split(is_null_byte)
150 .filter(|slice| !slice.is_empty())
151 .for_each(|slice| {
152 self.extend_from_slice(slice);
153 });
154 }
155
156 Ok(())
157 }
158
159 fn serialize_bytes(self, v: &[u8]) -> Result<()> {
160 self.reserve(4 + v.len());
161
162 self.serialize_usize(v.len())?;
163
164 self.extend_from_slice(v);
165
166 Ok(())
167 }
168
169 fn serialize_none(self) -> Result<()> {
170 Ok(())
171 }
172
173 fn serialize_some<T>(self, value: &T) -> Result<()>
174 where
175 T: ?Sized + Serialize,
176 {
177 value.serialize(self)
178 }
179
180 fn serialize_unit(self) -> Result<()> {
181 Ok(())
182 }
183
184 fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
185 self.serialize_unit()
186 }
187
188 fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
189 where
190 T: ?Sized + Serialize,
191 {
192 value.serialize(self)
193 }
194
195 fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
196 if let Some(len) = len {
197 self.reserve(4 + len as usize);
198
199 self.serialize_usize(len)?;
200 }
201 Ok(self)
202 }
203
204 fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
205 Ok(self)
206 }
207
208 fn serialize_tuple_struct(
209 self,
210 _name: &'static str,
211 len: usize,
212 ) -> Result<Self::SerializeTupleStruct> {
213 self.serialize_tuple(len)
214 }
215
216 fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
217 self.serialize_tuple(len)
218 }
219
220 fn serialize_unit_variant(
221 self,
222 _name: &'static str,
223 variant_index: u32,
224 _variant: &'static str,
225 ) -> Result<()> {
226 self.serialize_u32(variant_index)
227 }
228
229 fn serialize_newtype_variant<T>(
230 self,
231 name: &'static str,
232 variant_index: u32,
233 variant: &'static str,
234 value: &T,
235 ) -> Result<()>
236 where
237 T: ?Sized + Serialize,
238 {
239 self.serialize_unit_variant(name, variant_index, variant)?;
240 value.serialize(self)
241 }
242
243 fn serialize_tuple_variant(
244 self,
245 name: &'static str,
246 variant_index: u32,
247 variant: &'static str,
248 len: usize,
249 ) -> Result<Self::SerializeTupleVariant> {
250 self.serialize_unit_variant(name, variant_index, variant)?;
251 self.serialize_tuple(len)
252 }
253
254 fn serialize_struct_variant(
255 self,
256 name: &'static str,
257 variant_index: u32,
258 variant: &'static str,
259 len: usize,
260 ) -> Result<Self::SerializeStructVariant> {
261 self.serialize_unit_variant(name, variant_index, variant)?;
262 self.serialize_tuple(len)
263 }
264
265 #[cfg(feature = "is_human_readable")]
266 fn is_human_readable(&self) -> bool {
268 false
269 }
270
271 fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
273 Err(Error::Unsupported(&"serialize_map"))
274 }
275}
276
277macro_rules! impl_serialize_trait {
278 ( $name:ident, $function_name:ident ) => {
279 impl<'a, Container: SerOutput> ser::$name for &'a mut Serializer<Container> {
280 type Ok = ();
281 type Error = Error;
282
283 fn $function_name<T>(&mut self, value: &T) -> Result<()>
284 where
285 T: ?Sized + Serialize,
286 {
287 value.serialize(&mut **self)
288 }
289
290 fn end(self) -> Result<()> {
291 Ok(())
292 }
293 }
294 };
295}
296
297impl_serialize_trait!(SerializeSeq, serialize_element);
298impl_serialize_trait!(SerializeTuple, serialize_element);
299impl_serialize_trait!(SerializeTupleStruct, serialize_field);
300impl_serialize_trait!(SerializeTupleVariant, serialize_field);
301
302impl<'a, Container: SerOutput> ser::SerializeMap for &'a mut Serializer<Container> {
304 type Ok = ();
305 type Error = Error;
306
307 fn serialize_key<T>(&mut self, _key: &T) -> Result<()>
309 where
310 T: ?Sized + Serialize,
311 {
312 Err(Error::Unsupported(&"serialize_map"))
313 }
314
315 fn serialize_value<T>(&mut self, _value: &T) -> Result<()>
317 where
318 T: ?Sized + Serialize,
319 {
320 Err(Error::Unsupported(&"serialize_map"))
321 }
322
323 fn end(self) -> Result<()> {
325 Err(Error::Unsupported(&"serialize_map"))
326 }
327}
328
329impl<'a, Container: SerOutput> ser::SerializeStruct for &'a mut Serializer<Container> {
330 type Ok = ();
331 type Error = Error;
332
333 fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<()>
334 where
335 T: ?Sized + Serialize,
336 {
337 value.serialize(&mut **self)
338 }
339
340 fn end(self) -> Result<()> {
341 Ok(())
342 }
343}
344impl<'a, Container: SerOutput> ser::SerializeStructVariant for &'a mut Serializer<Container> {
345 type Ok = ();
346 type Error = Error;
347
348 fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
349 where
350 T: ?Sized + Serialize,
351 {
352 ser::SerializeStruct::serialize_field(self, key, value)
353 }
354
355 fn end(self) -> Result<()> {
356 ser::SerializeStruct::end(self)
357 }
358}
359
#[cfg(test)]
mod tests {
    use crate::{to_bytes, Serializer};
    use serde::{ser, Serialize};
    use std::convert::TryInto;

    #[test]
    fn test_integer() {
        assert_eq!(to_bytes(&0x12_u8).unwrap(), [0, 0, 0, 1, 0x12]);
        assert_eq!(to_bytes(&0x1234_u16).unwrap(), [0, 0, 0, 2, 0x12, 0x34]);
        assert_eq!(
            to_bytes(&0x12345678_u32).unwrap(),
            [0, 0, 0, 4, 0x12, 0x34, 0x56, 0x78]
        );
        assert_eq!(
            to_bytes(&0x1234567887654321_u64).unwrap(),
            [0, 0, 0, 8, 0x12, 0x34, 0x56, 0x78, 0x87, 0x65, 0x43, 0x21]
        );
    }

    /// Signed integers are written as their two's-complement bytes.
    #[test]
    fn test_signed_integer() {
        assert_eq!(to_bytes(&-2_i8).unwrap(), [0, 0, 0, 1, 0xfe]);
        assert_eq!(to_bytes(&-2_i16).unwrap(), [0, 0, 0, 2, 0xff, 0xfe]);
    }

    /// Floats are written as their IEEE-754 bits, big-endian.
    #[test]
    fn test_float() {
        assert_eq!(to_bytes(&1.0_f32).unwrap(), [0, 0, 0, 4, 0x3f, 0x80, 0, 0]);
        assert_eq!(
            to_bytes(&1.0_f64).unwrap(),
            [0, 0, 0, 8, 0x3f, 0xf0, 0, 0, 0, 0, 0, 0]
        );
    }

    #[test]
    fn test_boolean() {
        assert_eq!(to_bytes(&true).unwrap(), [0, 0, 0, 4, 0, 0, 0, 1]);
        assert_eq!(to_bytes(&false).unwrap(), [0, 0, 0, 4, 0, 0, 0, 0]);
    }

    /// `char` is widened to its `u32` scalar value.
    #[test]
    fn test_char() {
        assert_eq!(to_bytes(&'A').unwrap(), [0, 0, 0, 4, 0, 0, 0, 0x41]);
    }

    #[test]
    fn test_str() {
        let s = "Hello, world!";
        let serialized = to_bytes(&s).unwrap();
        let len: u32 = (serialized.len() - 4).try_into().unwrap();
        assert_eq!(&serialized[..4], len.to_be_bytes());
        assert_eq!(&serialized[4..8], (s.len() as u32).to_be_bytes());
        assert_eq!(&serialized[8..], s.as_bytes());
    }

    #[test]
    fn test_empty_str() {
        assert_eq!(to_bytes(&"").unwrap(), [0, 0, 0, 4, 0, 0, 0, 0]);
    }

    #[test]
    fn test_str_with_null() {
        let s = "\0Hello, world!";
        let serialized = to_bytes(&s).unwrap();
        let len: u32 = (serialized.len() - 4).try_into().unwrap();
        assert_eq!(&serialized[..4], len.to_be_bytes());
        assert_eq!(&serialized[4..8], ((s.len() - 1) as u32).to_be_bytes());

        assert_eq!(&serialized[8..], &s.as_bytes()[1..]);
    }

    /// NUL bytes are stripped wherever they appear, and the string length
    /// prefix counts only the bytes kept.
    #[test]
    fn test_str_with_interior_and_trailing_null() {
        assert_eq!(
            to_bytes(&"a\0b\0").unwrap(),
            [0, 0, 0, 6, 0, 0, 0, 2, b'a', b'b']
        );
    }

    #[test]
    fn test_str_only_null() {
        assert_eq!(to_bytes(&"\0\0").unwrap(), [0, 0, 0, 4, 0, 0, 0, 0]);
    }

    #[test]
    fn test_array() {
        let array = [0x00_u8, 0x01_u8, 0x10_u8, 0x78_u8];
        let slice: &[_] = &array;

        let serialized = to_bytes(&slice).unwrap();
        assert_eq!(serialized[..4], [0, 0, 0, 8]);
        assert_eq!(serialized[4..8], [0, 0, 0, 4]);
        assert_eq!(serialized[8..], array);

        let slice: &[_] = &[0x0010_u16, 0x0100_u16, 0x1034_u16, 0x7812_u16];

        assert_eq!(
            to_bytes(&slice).unwrap(),
            &[0, 0, 0, 12, 0, 0, 0, 4, 0x00, 0x10, 0x01, 0x00, 0x10, 0x34, 0x78, 0x12_u8]
        );
    }

    #[test]
    fn test_tuple() {
        assert_eq!(
            to_bytes(&(0x00_u8, 0x0100_u16, 0x1034_u16, 0x7812_u16)).unwrap(),
            &[0, 0, 0, 7, 0x00_u8, 0x01_u8, 0x00_u8, 0x10_u8, 0x34_u8, 0x78_u8, 0x12_u8]
        );
    }

    #[test]
    fn test_struct() {
        #[derive(Serialize)]
        struct S {
            v1: u8,
            v2: u16,
            v3: u16,
            v4: u16,
        }
        let v = S {
            v1: 0x00,
            v2: 0x0100,
            v3: 0x1034,
            v4: 0x7812,
        };
        assert_eq!(
            to_bytes(&v).unwrap(),
            &[0, 0, 0, 7, 0x00_u8, 0x01_u8, 0x00_u8, 0x10_u8, 0x34_u8, 0x78_u8, 0x12_u8]
        );
    }

    /// `None` and `()` write nothing, and `Some(v)` writes just `v`.
    #[test]
    fn test_option_and_unit() {
        assert_eq!(to_bytes(&Some(0x12_u8)).unwrap(), [0, 0, 0, 1, 0x12]);
        assert_eq!(to_bytes(&None::<u8>).unwrap(), [0, 0, 0, 0]);
        assert_eq!(to_bytes(&()).unwrap(), [0, 0, 0, 0]);
    }

    #[test]
    fn test_enum() {
        use ser::Serializer as SerdeSerializerTrait;

        let mut serializer: Serializer<Vec<u8>> = Serializer::default();

        serializer.serialize_unit_variant("", 1, "").unwrap();
        assert_eq!(serializer.create_header(0).unwrap(), [0, 0, 0, 4]);
        assert_eq!(serializer.output, [0, 0, 0, 1]);

        serializer.reset_counter();
        serializer.output.clear();

        serializer.serialize_newtype_variant("", 0, "", &3).unwrap();
        assert_eq!(serializer.create_header(0).unwrap(), [0, 0, 0, 8]);
        assert_eq!(serializer.output, [0, 0, 0, 0, 0, 0, 0, 3]);
    }

    /// Every enum variant kind starts with its `u32` variant index.
    #[test]
    fn test_derived_enum() {
        #[derive(Serialize)]
        enum E {
            Unit,
            Newtype(u16),
            Tuple(u8, u8),
            Struct { x: u8 },
        }

        assert_eq!(to_bytes(&E::Unit).unwrap(), [0, 0, 0, 4, 0, 0, 0, 0]);
        assert_eq!(
            to_bytes(&E::Newtype(0x0102)).unwrap(),
            [0, 0, 0, 6, 0, 0, 0, 1, 0x01, 0x02]
        );
        assert_eq!(
            to_bytes(&E::Tuple(1, 2)).unwrap(),
            [0, 0, 0, 6, 0, 0, 0, 2, 1, 2]
        );
        assert_eq!(
            to_bytes(&E::Struct { x: 9 }).unwrap(),
            [0, 0, 0, 5, 0, 0, 0, 3, 9]
        );
    }

    /// Byte slices get a length prefix; `create_header` adds its `len` argument.
    #[test]
    fn test_bytes_and_header_extra_len() {
        use ser::Serializer as SerdeSerializerTrait;

        let mut serializer: Serializer<Vec<u8>> = Serializer::default();

        serializer.serialize_bytes(&[1, 2, 3]).unwrap();
        assert_eq!(serializer.output, [0, 0, 0, 3, 1, 2, 3]);
        assert_eq!(serializer.create_header(0).unwrap(), [0, 0, 0, 7]);
        assert_eq!(serializer.create_header(2).unwrap(), [0, 0, 0, 9]);
    }

    /// Maps are rejected rather than silently mis-encoded.
    #[test]
    fn test_map_unsupported() {
        let map: std::collections::BTreeMap<u8, u8> = std::collections::BTreeMap::new();
        assert!(to_bytes(&map).is_err());
    }
}