1use std::{borrow::Cow, convert::TryInto, iter, str};
2
3use serde::de::{self, DeserializeSeed, IntoDeserializer, SeqAccess, VariantAccess, Visitor};
4use serde::Deserialize;
5
6use crate::{Error, Result};
7
/// Deserializer for this crate's big-endian binary encoding.
///
/// Input is consumed from `slice` (the chunk currently being read); once that
/// chunk is exhausted, further chunks are pulled from `iter`.
#[derive(Copy, Clone, Debug)]
pub struct Deserializer<'de, It> {
    /// Unconsumed remainder of the current input chunk.
    slice: &'de [u8],
    /// Source of further input chunks, consulted when `slice` runs empty.
    iter: It,
}
13
14impl<'de, It> Deserializer<'de, It> {
15 pub const fn new(iter: It) -> Self {
16 Self { iter, slice: &[] }
17 }
18
19 pub fn into_inner(self) -> (&'de [u8], It) {
20 (self.slice, self.iter)
21 }
22}
23
24impl<'de> Deserializer<'de, iter::Empty<&'de [u8]>> {
25 pub const fn from_bytes(slice: &'de [u8]) -> Self {
26 Self {
27 slice,
28 iter: iter::empty(),
29 }
30 }
31}
32
33pub fn from_bytes<'a, T>(s: &'a [u8]) -> Result<(T, &'a [u8])>
61where
62 T: Deserialize<'a>,
63{
64 let mut deserializer = Deserializer::from_bytes(s);
65 let t = T::deserialize(&mut deserializer)?;
66 Ok((t, deserializer.slice))
67}
68
69impl<'de, It> Deserializer<'de, It>
70where
71 It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
72{
73 fn update_slice_inner(&mut self) {
76 self.slice = self.iter.find(|slice| !slice.is_empty()).unwrap_or(&[]);
77 }
78
79 #[inline]
80 fn update_slice(&mut self) {
81 if self.slice.is_empty() {
82 self.update_slice_inner();
83 }
84 }
85
86 fn next_byte(&mut self) -> Result<u8> {
87 self.update_slice();
88
89 let byte = self.slice.first().copied().ok_or(Error::Eof)?;
90 self.slice = &self.slice[1..];
91
92 Ok(byte)
93 }
94
95 fn fill_buffer(&mut self, mut buffer: &mut [u8]) -> Result<()> {
96 loop {
97 if buffer.is_empty() {
98 break Ok(());
99 }
100
101 self.update_slice();
102
103 if self.slice.is_empty() {
104 break Err(Error::Eof);
105 }
106
107 let n = self.slice.len().min(buffer.len());
108
109 buffer[..n].copy_from_slice(&self.slice[..n]);
110
111 self.slice = &self.slice[n..];
112 buffer = &mut buffer[n..];
113 }
114 }
115
116 fn next_bytes_const<const SIZE: usize>(&mut self) -> Result<[u8; SIZE]> {
118 assert_ne!(SIZE, 0);
119
120 let mut bytes = [0_u8; SIZE];
121 self.fill_buffer(&mut bytes)?;
122
123 Ok(bytes)
124 }
125
126 fn next_u32(&mut self) -> Result<u32> {
127 Ok(u32::from_be_bytes(self.next_bytes_const()?))
128 }
129
130 fn next_bytes(&mut self, size: usize) -> Result<Cow<'de, [u8]>> {
131 self.update_slice();
132
133 if self.slice.len() >= size {
134 let slice = &self.slice[..size];
135 self.slice = &self.slice[size..];
136
137 Ok(Cow::Borrowed(slice))
138 } else {
139 let mut bytes = vec![0_u8; size];
140 self.fill_buffer(&mut bytes)?;
141 Ok(Cow::Owned(bytes))
142 }
143 }
144
145 fn parse_bytes(&mut self) -> Result<Cow<'de, [u8]>> {
147 let len: usize = self.next_u32()?.try_into().map_err(|_| Error::TooLong)?;
148 self.next_bytes(len)
149 }
150
151 pub fn has_remaining_data(&mut self) -> bool {
153 self.update_slice();
154 !self.slice.is_empty()
155 }
156}
157
/// Expands to a `deserialize_*` method that reads `size_of::<$ty>()` bytes,
/// decodes them as a big-endian `$ty`, and passes the result to `visitor.$visit`.
macro_rules! impl_for_deserialize_primitive {
    ($method:ident, $visit:ident, $ty:ty) => {
        fn $method<V>(self, visitor: V) -> Result<V::Value>
        where
            V: Visitor<'de>,
        {
            let raw: [u8; ::std::mem::size_of::<$ty>()] = self.next_bytes_const()?;
            visitor.$visit(<$ty>::from_be_bytes(raw))
        }
    };
}
168
169impl<'de, 'a, It> de::Deserializer<'de> for &'a mut Deserializer<'de, It>
170where
171 It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
172{
173 type Error = Error;
174
175 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
176 where
177 V: Visitor<'de>,
178 {
179 match self.next_u32()? {
180 1 => visitor.visit_bool(true),
181 0 => visitor.visit_bool(false),
182 _ => Err(Error::InvalidBoolEncoding),
183 }
184 }
185
186 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
187 where
188 V: Visitor<'de>,
189 {
190 visitor.visit_u8(self.next_byte()?)
191 }
192
193 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
194 where
195 V: Visitor<'de>,
196 {
197 visitor.visit_i8(self.next_byte()? as i8)
198 }
199
200 impl_for_deserialize_primitive!(deserialize_i16, visit_i16, i16);
201 impl_for_deserialize_primitive!(deserialize_i32, visit_i32, i32);
202 impl_for_deserialize_primitive!(deserialize_i64, visit_i64, i64);
203
204 impl_for_deserialize_primitive!(deserialize_u16, visit_u16, u16);
205 impl_for_deserialize_primitive!(deserialize_u32, visit_u32, u32);
206 impl_for_deserialize_primitive!(deserialize_u64, visit_u64, u64);
207
208 impl_for_deserialize_primitive!(deserialize_f32, visit_f32, f32);
209 impl_for_deserialize_primitive!(deserialize_f64, visit_f64, f64);
210
211 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
212 where
213 V: Visitor<'de>,
214 {
215 match char::from_u32(self.next_u32()?) {
216 Some(ch) => visitor.visit_char(ch),
217 None => Err(Error::InvalidChar),
218 }
219 }
220
221 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
222 where
223 V: Visitor<'de>,
224 {
225 match self.parse_bytes()? {
226 Cow::Owned(owned_bytes) => visitor.visit_string(String::from_utf8(owned_bytes)?),
227 Cow::Borrowed(bytes) => visitor.visit_borrowed_str(str::from_utf8(bytes)?),
228 }
229 }
230
231 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
232 where
233 V: Visitor<'de>,
234 {
235 self.deserialize_str(visitor)
236 }
237
238 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
239 where
240 V: Visitor<'de>,
241 {
242 match self.parse_bytes()? {
243 Cow::Owned(owned_bytes) => visitor.visit_byte_buf(owned_bytes),
244 Cow::Borrowed(bytes) => visitor.visit_borrowed_bytes(bytes),
245 }
246 }
247
248 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
249 where
250 V: Visitor<'de>,
251 {
252 self.deserialize_bytes(visitor)
253 }
254
255 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
256 where
257 V: Visitor<'de>,
258 {
259 visitor.visit_unit()
260 }
261
262 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
263 where
264 V: Visitor<'de>,
265 {
266 self.deserialize_unit(visitor)
267 }
268
269 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
270 where
271 V: Visitor<'de>,
272 {
273 visitor.visit_newtype_struct(self)
274 }
275
276 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
277 where
278 V: Visitor<'de>,
279 {
280 visitor.visit_seq(Access {
281 deserializer: self,
282 len,
283 })
284 }
285
286 fn deserialize_tuple_struct<V>(
287 self,
288 _name: &'static str,
289 len: usize,
290 visitor: V,
291 ) -> Result<V::Value>
292 where
293 V: Visitor<'de>,
294 {
295 self.deserialize_tuple(len, visitor)
296 }
297
298 fn deserialize_struct<V>(
299 self,
300 _name: &'static str,
301 fields: &'static [&'static str],
302 visitor: V,
303 ) -> Result<V::Value>
304 where
305 V: Visitor<'de>,
306 {
307 self.deserialize_tuple(fields.len(), visitor)
308 }
309
310 fn deserialize_enum<V>(
311 self,
312 _name: &'static str,
313 _variants: &'static [&'static str],
314 visitor: V,
315 ) -> Result<V::Value>
316 where
317 V: Visitor<'de>,
318 {
319 impl<'a, 'de, It> serde::de::EnumAccess<'de> for &'a mut Deserializer<'de, It>
320 where
321 It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
322 {
323 type Error = Error;
324 type Variant = Self;
325
326 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
327 where
328 V: de::DeserializeSeed<'de>,
329 {
330 let idx: u32 = self.next_u32()?;
331 let val: Result<_> = seed.deserialize(idx.into_deserializer());
332 Ok((val?, self))
333 }
334 }
335
336 visitor.visit_enum(self)
337 }
338
339 #[cfg(feature = "is_human_readable")]
340 fn is_human_readable(&self) -> bool {
342 false
343 }
344
345 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
347 where
348 V: Visitor<'de>,
349 {
350 let len = self.next_u32()? as usize;
351 visitor.visit_seq(Access {
352 deserializer: self,
353 len,
354 })
355 }
356
357 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
359 where
360 V: Visitor<'de>,
361 {
362 Err(Error::Unsupported(&"deserialize_any"))
363 }
364
365 fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value>
367 where
368 V: Visitor<'de>,
369 {
370 Err(Error::Unsupported(&"deserialize_option"))
371 }
372
373 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
375 where
376 V: Visitor<'de>,
377 {
378 Err(Error::Unsupported(&"deserialize_map"))
379 }
380
381 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
383 where
384 V: Visitor<'de>,
385 {
386 Err(Error::Unsupported(&"deserialize_identifier"))
387 }
388
389 fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
391 where
392 V: Visitor<'de>,
393 {
394 Err(Error::Unsupported(&"deserialize_ignored_any"))
395 }
396}
397
398impl<'a, 'de, It> VariantAccess<'de> for &'a mut Deserializer<'de, It>
399where
400 It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
401{
402 type Error = Error;
403
404 fn unit_variant(self) -> Result<()> {
405 Ok(())
406 }
407
408 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
409 where
410 T: DeserializeSeed<'de>,
411 {
412 DeserializeSeed::deserialize(seed, self)
413 }
414
415 fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
416 where
417 V: Visitor<'de>,
418 {
419 de::Deserializer::deserialize_tuple(self, len, visitor)
420 }
421
422 fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
423 where
424 V: Visitor<'de>,
425 {
426 de::Deserializer::deserialize_tuple(self, fields.len(), visitor)
427 }
428}
429
/// `SeqAccess` adapter that yields a fixed number of consecutive elements
/// from the underlying deserializer.
struct Access<'a, 'de, It> {
    /// Deserializer each element is read from.
    deserializer: &'a mut Deserializer<'de, It>,
    /// Number of elements still to be yielded; decremented per element.
    len: usize,
}
434
435impl<'a, 'de, It> SeqAccess<'de> for Access<'a, 'de, It>
436where
437 It: iter::FusedIterator + Iterator<Item = &'de [u8]>,
438{
439 type Error = Error;
440
441 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
442 where
443 T: DeserializeSeed<'de>,
444 {
445 if self.len > 0 {
446 self.len -= 1;
447 let value = seed.deserialize(&mut *self.deserializer)?;
448 Ok(Some(value))
449 } else {
450 Ok(None)
451 }
452 }
453
454 fn size_hint(&self) -> Option<usize> {
455 Some(self.len)
456 }
457}
458
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use assert_matches::assert_matches;
    use generator::{done, Gn};
    use itertools::Itertools;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::to_bytes;

    /// Splits `bytes` into chunks of at most `chunk_size` bytes. Eight empty
    /// slices are yielded before every chunk so that the deserializer's
    /// skipping of empty chunks is exercised.
    fn generate_subslices(mut bytes: &[u8], chunk_size: usize) -> impl Iterator<Item = &[u8]> {
        assert_ne!(chunk_size, 0);

        Gn::new_scoped(move |mut s| loop {
            for _ in 0..8 {
                s.yield_(&bytes[..0]);
            }

            let n = bytes.len().min(chunk_size);
            s.yield_(&bytes[..n]);
            bytes = &bytes[n..];

            if bytes.is_empty() {
                done!();
            }
        })
    }

    /// Serializes `value` and drops the first 4 bytes of the `to_bytes` output,
    /// as every test here does before decoding. It then checks that decoding
    /// from one slice and from each chunking of it yields `value` and consumes
    /// all input.
    fn test_roundtrip<'de, T>(value: &T)
    where
        T: Debug + Eq + Serialize + Deserialize<'de>,
    {
        let serialized = to_bytes(value).unwrap().leak();
        let serialized = &serialized[4..];

        assert_eq!(from_bytes::<T>(serialized).unwrap().0, *value);

        for chunk_size in 1..serialized.len() {
            let mut deserializer =
                Deserializer::new(generate_subslices(serialized, chunk_size).fuse());
            let val = T::deserialize(&mut deserializer).unwrap();
            assert_eq!(val, *value);

            let (slice, mut iter) = deserializer.into_inner();

            assert_eq!(slice, &[]);
            assert_eq!(iter.next(), None);
        }
    }

    #[test]
    fn test_integer() {
        test_roundtrip(&0x12_u8);
        test_roundtrip(&0x1234_u16);
        test_roundtrip(&0x12345678_u32);
        test_roundtrip(&0x1234567887654321_u64);
    }

    #[test]
    fn test_boolean() {
        test_roundtrip(&true);
        test_roundtrip(&false);
    }

    #[test]
    fn test_str() {
        let s = "Hello, world!";
        let serialized = to_bytes(&s).unwrap();
        let deserialized: &str = from_bytes(&serialized[4..]).unwrap().0;
        assert_eq!(deserialized, s);
    }

    #[test]
    fn test_seq() {
        test_roundtrip(&vec![0x00_u8, 0x01_u8, 0x10_u8, 0x78_u8]);
        test_roundtrip(&vec![0x0010_u16, 0x0100_u16, 0x1034_u16, 0x7812_u16]);
    }

    #[test]
    fn test_tuple() {
        test_roundtrip(&(0x00_u8, 0x0100_u16, 0x1034_u16, 0x7812_u16));
    }

    #[test]
    fn test_struct() {
        #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
        struct S {
            v1: u8,
            v2: u16,
            v3: u16,
            v4: u16,
        }
        test_roundtrip(&S {
            v1: 0x00,
            v2: 0x0100,
            v3: 0x1034,
            v4: 0x7812,
        });
    }

    #[test]
    fn test_struct2() {
        #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
        struct S<'a> {
            v1: u8,
            v2: u16,
            v3: u16,
            v4: u16,
            #[serde(borrow)]
            v5: Cow<'a, str>,
        }
        test_roundtrip(&S {
            v1: 0x00,
            v2: 0x0100,
            v3: 0x1034,
            v4: 0x7812,
            v5: Cow::Owned((0..100).join(", ")),
        });
    }

    #[test]
    fn test_eof_error() {
        assert_matches!(from_bytes::<u8>(&[]), Err(Error::Eof));

        let s = "Hello, world!";
        let serialized = to_bytes(&s).unwrap();

        // Drop the 4-byte prefix exactly as the round-trip tests do, so that the
        // string's own length prefix is what gets decoded. Slicing from 0 would
        // instead decode that outer prefix as the string length and only hit
        // `Eof` by coincidence.
        let truncated = &serialized[4..serialized.len() - 1];
        assert_matches!(from_bytes::<String>(truncated), Err(Error::Eof));

        // The same truncation must be reported when the input arrives in chunks.
        for chunk_size in 1..=truncated.len() {
            let mut deserializer =
                Deserializer::new(generate_subslices(truncated, chunk_size).fuse());
            assert_matches!(String::deserialize(&mut deserializer), Err(Error::Eof));
        }
    }
}