1pub use flate2::Compression;
12
13use byteorder::{ByteOrder, LittleEndian};
14use bytes::{Buf, BufMut, BytesMut};
15use flate2::read::{ZlibDecoder, ZlibEncoder};
16
17use std::{
18 cmp::{max, min},
19 io::Read,
20 mem,
21 num::NonZeroUsize,
22 ptr::slice_from_raw_parts_mut,
23};
24
25use self::error::PacketCodecError;
26use crate::constants::{DEFAULT_MAX_ALLOWED_PACKET, MAX_PAYLOAD_LEN, MIN_COMPRESS_LENGTH};
27
28pub mod error;
29
30pub fn packet_to_chunks<T: Buf>(mut seq_id: u8, packet: &mut T, dst: &mut BytesMut) -> u8 {
36 let extra_packet = packet.remaining() % MAX_PAYLOAD_LEN == 0;
37 dst.reserve(packet.remaining() + (packet.remaining() / MAX_PAYLOAD_LEN) * 4 + 4);
38
39 while packet.has_remaining() {
40 let mut chunk_len = min(packet.remaining(), MAX_PAYLOAD_LEN);
41 dst.put_u32_le(chunk_len as u32 | (u32::from(seq_id) << 24));
42 while chunk_len > 0 {
43 let chunk = packet.chunk();
44 let count = min(chunk.len(), chunk_len);
45 dst.put(&chunk[..count]);
46 chunk_len -= count;
47 packet.advance(count);
48 }
49 seq_id = seq_id.wrapping_add(1);
50 }
51
52 if extra_packet {
53 dst.put_u32_le(u32::from(seq_id) << 24);
54 seq_id = seq_id.wrapping_add(1);
55 }
56
57 seq_id
58}
59
60pub fn compress(
64 mut seq_id: u8,
65 compression: Compression,
66 max_allowed_packet: usize,
67 src: &mut BytesMut,
68 dst: &mut BytesMut,
69) -> Result<u8, PacketCodecError> {
70 if src.is_empty() {
71 return Ok(0);
72 }
73
74 for chunk in src.chunks(min(MAX_PAYLOAD_LEN, max_allowed_packet)) {
75 dst.reserve(7 + chunk.len());
76
77 if compression != Compression::none() && chunk.len() >= MIN_COMPRESS_LENGTH {
78 unsafe {
79 let mut encoder = ZlibEncoder::new(chunk, compression);
80 let mut read = 0;
81 loop {
82 dst.reserve(max(chunk.len().saturating_sub(read), 1));
83 let dst_buf = &mut dst.chunk_mut()[7 + read..];
84 match encoder.read(&mut *slice_from_raw_parts_mut(
85 dst_buf.as_mut_ptr(),
86 dst_buf.len(),
87 ))? {
88 0 => break,
89 count => read += count,
90 }
91 }
92
93 dst.put_uint_le(read as u64, 3);
94 dst.put_u8(seq_id);
95 dst.put_uint_le(chunk.len() as u64, 3);
96 dst.advance_mut(read);
97 }
98 } else {
99 dst.put_uint_le(chunk.len() as u64, 3);
100 dst.put_u8(seq_id);
101 dst.put_uint_le(0, 3);
102 dst.put_slice(chunk);
103 }
104
105 seq_id = seq_id.wrapping_add(1);
106 }
107
108 src.clear();
109
110 Ok(seq_id)
111}
112
/// Result of decoding one chunk: whether the packet continues, plus the chunk's sequence id.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum ChunkInfo {
    /// A non-final chunk: more chunks of the same packet are expected. Holds its sequence id.
    Middle(u8),
    /// The final chunk of a packet. Holds its sequence id.
    Last(u8),
}

impl ChunkInfo {
    /// Returns the sequence id carried by either variant.
    fn seq_id(self) -> u8 {
        let (ChunkInfo::Middle(id) | ChunkInfo::Last(id)) = self;
        id
    }
}
133
134#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
136pub enum ChunkDecoder {
137 #[default]
141 Idle,
142 Chunk {
144 seq_id: u8,
146 needed: NonZeroUsize,
148 },
149}
150
151impl ChunkDecoder {
152 pub fn decode<T>(
160 &mut self,
161 src: &mut BytesMut,
162 dst: &mut T,
163 max_allowed_packet: usize,
164 ) -> Result<Option<ChunkInfo>, PacketCodecError>
165 where
166 T: AsRef<[u8]>,
167 T: BufMut,
168 {
169 match *self {
170 ChunkDecoder::Idle => {
171 if src.len() < 4 {
172 Ok(None)
174 } else {
175 let raw_chunk_len = LittleEndian::read_u24(&*src) as usize;
176 let seq_id = src[3];
177
178 match NonZeroUsize::new(raw_chunk_len) {
179 Some(chunk_len) => {
180 if dst.as_ref().len() + chunk_len.get() > max_allowed_packet {
181 return Err(PacketCodecError::PacketTooLarge);
182 }
183
184 *self = ChunkDecoder::Chunk {
185 seq_id,
186 needed: chunk_len,
187 };
188
189 if src.len() > 4 {
190 self.decode(src, dst, max_allowed_packet)
191 } else {
192 Ok(None)
193 }
194 }
195 None => {
196 src.advance(4);
197 Ok(Some(ChunkInfo::Last(seq_id)))
198 }
199 }
200 }
201 }
202 ChunkDecoder::Chunk { seq_id, needed } => {
203 if src.len() >= 4 + needed.get() {
204 src.advance(4);
205
206 dst.put_slice(&src[..needed.get()]);
207 src.advance(needed.get());
208
209 *self = ChunkDecoder::Idle;
210
211 if dst.as_ref().len() % MAX_PAYLOAD_LEN == 0 {
212 Ok(Some(ChunkInfo::Middle(seq_id)))
213 } else {
214 Ok(Some(ChunkInfo::Last(seq_id)))
215 }
216 } else {
217 Ok(None)
218 }
219 }
220 }
221 }
222}
223
224#[derive(Debug, Clone, Copy, Eq, PartialEq)]
226pub enum CompData {
227 Compressed(NonZeroUsize, NonZeroUsize),
229 Uncompressed(NonZeroUsize),
231}
232
233impl CompData {
234 fn new(
236 compressed_len: usize,
237 uncompressed_len: usize,
238 max_allowed_packet: usize,
239 ) -> Result<Option<Self>, PacketCodecError> {
240 if max(compressed_len, uncompressed_len) > max_allowed_packet {
242 return Err(PacketCodecError::PacketTooLarge);
243 }
244
245 let compressed_len = NonZeroUsize::new(compressed_len);
246 let uncompressed_len = NonZeroUsize::new(uncompressed_len);
247
248 match (compressed_len, uncompressed_len) {
249 (Some(needed), Some(plain_len)) => Ok(Some(CompData::Compressed(needed, plain_len))),
250 (Some(needed), None) => Ok(Some(CompData::Uncompressed(needed))),
251 (None, Some(_)) => {
252 Err(PacketCodecError::BadCompressedPacketHeader)
255 }
256 (None, None) => Ok(None),
257 }
258 }
259
260 fn needed(&self) -> usize {
262 match *self {
263 CompData::Compressed(needed, _) | CompData::Uncompressed(needed) => needed.get(),
264 }
265 }
266}
267
/// Incremental decoder for MySQL compressed-protocol packets.
///
/// Every compressed packet starts with a 7-byte header: a 3-byte little-endian compressed
/// length, a 1-byte sequence id and a 3-byte little-endian uncompressed length.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CompDecoder {
    /// No header has been parsed yet.
    Idle,
    /// A header has been parsed and validated; waiting for the whole packet to arrive.
    Packet {
        /// Sequence id taken from the header.
        seq_id: u8,
        /// Payload length(s) taken from the header.
        needed: CompData,
    },
}

impl CompDecoder {
    /// Tries to decode one compressed packet from `src`, appending its payload to `dst`.
    /// The payload is decompressed first when the header says it is compressed.
    ///
    /// Returns `Ok(None)` while `src` does not yet hold the whole packet. Nothing is consumed
    /// from `src` in that case, but a parsed header is remembered in `self`. On success the
    /// header and payload are consumed, `self` goes back to `Idle`, and
    /// `Ok(Some(ChunkInfo::Last(seq_id)))` is returned; this decoder never yields `Middle`.
    ///
    /// # Errors
    ///
    /// - Header validation errors from `CompData::new`.
    /// - I/O errors from zlib decompression, including a payload that decompresses to fewer
    ///   bytes than the header announced.
    pub fn decode(
        &mut self,
        src: &mut BytesMut,
        dst: &mut BytesMut,
        max_allowed_packet: usize,
    ) -> Result<Option<ChunkInfo>, PacketCodecError> {
        match *self {
            CompDecoder::Idle => {
                if src.len() < 7 {
                    // The header is not complete yet.
                    Ok(None)
                } else {
                    let compressed_len = LittleEndian::read_u24(&*src) as usize;
                    let seq_id = src[3];
                    let uncompressed_len = LittleEndian::read_u24(&src[4..]) as usize;

                    match CompData::new(compressed_len, uncompressed_len, max_allowed_packet)? {
                        Some(needed) => {
                            // The header stays in `src`; the `Packet` branch skips it once the
                            // payload has arrived as well.
                            *self = CompDecoder::Packet { seq_id, needed };
                            self.decode(src, dst, max_allowed_packet)
                        }
                        None => {
                            // Empty packet (both lengths zero): only the header is consumed.
                            src.advance(7);
                            Ok(Some(ChunkInfo::Last(seq_id)))
                        }
                    }
                }
            }
            CompDecoder::Packet { seq_id, needed } => {
                if src.len() >= 7 + needed.needed() {
                    src.advance(7);
                    // NOTE(review): if decompression below fails, the header has already been
                    // consumed while `self` is still `Packet`, so the decoder is left
                    // inconsistent. Callers presumably treat such errors as fatal; confirm.
                    match needed {
                        CompData::Uncompressed(needed) => {
                            dst.extend_from_slice(&src[..needed.get()]);
                        }
                        CompData::Compressed(needed, plain_len) => {
                            dst.reserve(plain_len.get());
                            // `reserve` guarantees at least `plain_len` bytes of spare capacity,
                            // so the slice below is in bounds. `advance_mut` only runs after
                            // `read_exact` has written all `plain_len` bytes.
                            // NOTE(review): this views uninitialized spare capacity as
                            // `&mut [u8]`, which Rust's unsafe-code rules do not guarantee to be
                            // sound; zero-filling the region first would avoid that.
                            // Decompressed output beyond `plain_len` is also not detected.
                            unsafe {
                                let mut decoder = ZlibDecoder::new(&src[..needed.get()]);
                                let dst_buf = &mut dst.chunk_mut()[..plain_len.get()];
                                decoder.read_exact(&mut *slice_from_raw_parts_mut(
                                    dst_buf.as_mut_ptr(),
                                    dst_buf.len(),
                                ))?;
                                dst.advance_mut(plain_len.get());
                            }
                        }
                    }
                    // Skip the payload (compressed length) that was just consumed.
                    src.advance(needed.needed());
                    *self = CompDecoder::Idle;
                    Ok(Some(ChunkInfo::Last(seq_id)))
                } else {
                    Ok(None)
                }
            }
        }
    }
}
344
345#[derive(Debug)]
349pub struct PacketCodec {
350 pub max_allowed_packet: usize,
352 inner: PacketCodecInner,
354}
355
356impl PacketCodec {
357 pub fn reset_seq_id(&mut self) {
359 self.inner.reset_seq_id();
360 }
361
362 pub fn sync_seq_id(&mut self) {
364 self.inner.sync_seq_id();
365 }
366
367 pub fn compress(&mut self, level: Compression) {
369 self.inner.compress(level);
370 }
371
372 pub fn decode<T>(&mut self, src: &mut BytesMut, dst: &mut T) -> Result<bool, PacketCodecError>
379 where
380 T: AsRef<[u8]>,
381 T: BufMut,
382 {
383 self.inner.decode(src, dst, self.max_allowed_packet)
384 }
385
386 pub fn encode<T: Buf>(
388 &mut self,
389 src: &mut T,
390 dst: &mut BytesMut,
391 ) -> Result<(), PacketCodecError> {
392 self.inner.encode(src, dst, self.max_allowed_packet)
393 }
394}
395
396impl Default for PacketCodec {
397 fn default() -> Self {
398 Self {
399 max_allowed_packet: DEFAULT_MAX_ALLOWED_PACKET,
400 inner: Default::default(),
401 }
402 }
403}
404
405#[derive(Debug)]
407enum PacketCodecInner {
408 Plain(PlainPacketCodec),
410 Comp(CompPacketCodec),
412}
413
414impl PacketCodecInner {
415 fn reset_seq_id(&mut self) {
417 match self {
418 PacketCodecInner::Plain(c) => c.reset_seq_id(),
419 PacketCodecInner::Comp(c) => c.reset_seq_id(),
420 }
421 }
422
423 fn sync_seq_id(&mut self) {
425 match self {
426 PacketCodecInner::Plain(_) => (),
427 PacketCodecInner::Comp(c) => c.sync_seq_id(),
428 }
429 }
430
431 fn compress(&mut self, level: Compression) {
433 match self {
434 PacketCodecInner::Plain(c) => {
435 *self = PacketCodecInner::Comp(CompPacketCodec {
436 level,
437 comp_seq_id: 0,
438 in_buf: BytesMut::with_capacity(DEFAULT_MAX_ALLOWED_PACKET),
439 out_buf: BytesMut::with_capacity(DEFAULT_MAX_ALLOWED_PACKET),
440 comp_decoder: CompDecoder::Idle,
441 plain_codec: mem::take(c),
442 })
443 }
444 PacketCodecInner::Comp(c) => c.level = level,
445 }
446 }
447
448 fn decode<T>(
452 &mut self,
453 src: &mut BytesMut,
454 dst: &mut T,
455 max_allowed_packet: usize,
456 ) -> Result<bool, PacketCodecError>
457 where
458 T: AsRef<[u8]>,
459 T: BufMut,
460 {
461 match self {
462 PacketCodecInner::Plain(codec) => codec.decode(src, dst, max_allowed_packet, None),
463 PacketCodecInner::Comp(codec) => codec.decode(src, dst, max_allowed_packet),
464 }
465 }
466
467 fn encode<T: Buf>(
469 &mut self,
470 packet: &mut T,
471 dst: &mut BytesMut,
472 max_allowed_packet: usize,
473 ) -> Result<(), PacketCodecError> {
474 match self {
475 PacketCodecInner::Plain(codec) => codec.encode(packet, dst, max_allowed_packet),
476 PacketCodecInner::Comp(codec) => codec.encode(packet, dst, max_allowed_packet),
477 }
478 }
479}
480
481impl Default for PacketCodecInner {
482 fn default() -> Self {
483 PacketCodecInner::Plain(Default::default())
484 }
485}
486
487#[derive(Debug, Clone, Eq, PartialEq, Default)]
489struct PlainPacketCodec {
490 pub seq_id: u8,
492 chunk_decoder: ChunkDecoder,
494}
495
496impl PlainPacketCodec {
497 fn reset_seq_id(&mut self) {
499 self.seq_id = 0;
500 }
501
502 fn decode<T>(
508 &mut self,
509 src: &mut BytesMut,
510 dst: &mut T,
511 max_allowed_packet: usize,
512 comp_seq_id: Option<u8>,
513 ) -> Result<bool, PacketCodecError>
514 where
515 T: AsRef<[u8]>,
516 T: BufMut,
517 {
518 match self.chunk_decoder.decode(src, dst, max_allowed_packet)? {
519 Some(chunk_info) => {
520 if self.seq_id != chunk_info.seq_id() {
521 match comp_seq_id {
522 Some(seq_id) if seq_id == chunk_info.seq_id() => {
523 self.seq_id = seq_id;
525 }
526 _ => {
527 return Err(PacketCodecError::PacketsOutOfSync);
528 }
529 }
530 }
531
532 self.seq_id = self.seq_id.wrapping_add(1);
533
534 match chunk_info {
535 ChunkInfo::Middle(_) => {
536 if !src.is_empty() {
537 self.decode(src, dst, max_allowed_packet, comp_seq_id)
538 } else {
539 Ok(false)
540 }
541 }
542 ChunkInfo::Last(_) => Ok(true),
543 }
544 }
545 None => Ok(false),
546 }
547 }
548
549 fn encode<T: Buf>(
551 &mut self,
552 packet: &mut T,
553 dst: &mut BytesMut,
554 max_allowed_packet: usize,
555 ) -> Result<(), PacketCodecError> {
556 if packet.remaining() > max_allowed_packet {
557 return Err(PacketCodecError::PacketTooLarge);
558 }
559
560 self.seq_id = packet_to_chunks(self.seq_id, packet, dst);
561
562 Ok(())
563 }
564}
565
566#[derive(Debug)]
568struct CompPacketCodec {
569 level: Compression,
571 comp_seq_id: u8,
573 in_buf: BytesMut,
575 out_buf: BytesMut,
577 comp_decoder: CompDecoder,
579 plain_codec: PlainPacketCodec,
581}
582
583impl CompPacketCodec {
584 fn reset_seq_id(&mut self) {
586 self.comp_seq_id = 0;
587 self.plain_codec.reset_seq_id();
588 }
589
590 fn sync_seq_id(&mut self) {
593 if self.in_buf.is_empty() {
594 self.plain_codec.seq_id = self.comp_seq_id;
595 }
596 }
597
598 fn decode<T>(
602 &mut self,
603 src: &mut BytesMut,
604 dst: &mut T,
605 max_allowed_packet: usize,
606 ) -> Result<bool, PacketCodecError>
607 where
608 T: AsRef<[u8]>,
609 T: BufMut,
610 {
611 if !self.in_buf.is_empty()
612 && self.plain_codec.decode(
613 &mut self.in_buf,
614 dst,
615 max_allowed_packet,
616 Some(self.comp_seq_id.wrapping_sub(1)),
619 )?
620 {
621 return Ok(true);
622 }
623
624 match self
625 .comp_decoder
626 .decode(src, &mut self.in_buf, max_allowed_packet)?
627 {
628 Some(chunk_info) => {
629 if self.comp_seq_id != chunk_info.seq_id() {
630 return Err(PacketCodecError::PacketsOutOfSync);
631 }
632
633 self.comp_seq_id = self.comp_seq_id.wrapping_add(1);
634
635 self.decode(src, dst, max_allowed_packet)
636 }
637 None => Ok(false),
638 }
639 }
640
641 fn encode<T: Buf>(
643 &mut self,
644 packet: &mut T,
645 dst: &mut BytesMut,
646 max_allowed_packet: usize,
647 ) -> Result<(), PacketCodecError> {
648 self.plain_codec
649 .encode(packet, &mut self.out_buf, max_allowed_packet)?;
650
651 self.comp_seq_id = compress(
652 self.comp_seq_id,
653 self.level,
654 max_allowed_packet,
655 &mut self.out_buf,
656 dst,
657 )?;
658
659 self.plain_codec.seq_id = self.comp_seq_id;
661
662 Ok(())
663 }
664}
665
#[cfg(test)]
mod tests {
    use super::*;

    // One compressed-protocol packet. It starts with a 7-byte header:
    // compressed length 0x22 = 34, sequence id 0, uncompressed length 0x32 = 50.
    // The header is followed by 34 bytes of zlib data. The 50 uncompressed bytes are a
    // 4-byte plain chunk header plus the 46 bytes of `PLAIN` (see `compressed_roundtrip`).
    const COMPRESSED: &[u8] = &[
        0x22, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x78, 0x9c, 0xd3, 0x63, 0x60, 0x60, 0x60, 0x2e,
        0x4e, 0xcd, 0x49, 0x4d, 0x2e, 0x51, 0x50, 0x32, 0x30, 0x34, 0x32, 0x36, 0x31, 0x35, 0x33,
        0xb7, 0xb0, 0xc4, 0xcd, 0x52, 0x02, 0x00, 0x0c, 0xd1, 0x0a, 0x6c,
    ];

    // Plain payload carried by `COMPRESSED`: the byte 0x03 (presumably COM_QUERY) followed by
    // the ASCII text `select "012345678901234567890123456789012345"`.
    const PLAIN: [u8; 46] = [
        0x03, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x22, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35,
        0x36, 0x37, 0x38, 0x39, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x30,
        0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35,
        0x22,
    ];

    // An empty packet is encoded as a single zero-length chunk and decodes to an empty payload.
    #[test]
    fn zero_len_packet() -> Result<(), error::PacketCodecError> {
        let mut encoder = PacketCodec::default();
        let mut empty: &[u8] = &[];
        let mut src = BytesMut::new();
        encoder.encode(&mut empty, &mut src)?;

        let mut dst = vec![];
        let mut decoder = PacketCodec::default();
        let result = decoder.decode(&mut src, &mut dst)?;
        assert!(result);
        assert_eq!(dst, vec![0_u8; 0]);

        Ok(())
    }

    // A small packet survives a plain encode/decode round trip.
    #[test]
    fn regular_packet() -> Result<(), error::PacketCodecError> {
        let mut encoder = PacketCodec::default();
        let mut src = BytesMut::new();
        encoder.encode(&mut &[0x31_u8, 0x32, 0x33][..], &mut src)?;

        let mut dst = vec![];
        let mut decoder = PacketCodec::default();
        let result = decoder.decode(&mut src, &mut dst)?;
        assert!(result);
        assert_eq!(dst, vec![0x31, 0x32, 0x33]);

        Ok(())
    }

    // Sequence ids stay in step between one encoder and one decoder over many packets,
    // including wrap-around of the u8 sequence id.
    #[test]
    fn packet_sequence() -> Result<(), error::PacketCodecError> {
        let mut encoder = PacketCodec::default();
        let mut decoder = PacketCodec::default();
        let mut src = BytesMut::new();

        for i in 0..1024_usize {
            encoder.encode(&mut &*vec![0; i], &mut src)?;
            let mut dst = vec![];
            let result = decoder.decode(&mut src, &mut dst)?;
            assert!(result);
            assert_eq!(dst, vec![0; i]);
        }

        Ok(())
    }

    // Packets at and around `MAX_PAYLOAD_LEN` are split into multiple chunks (plus a trailing
    // empty chunk for exact multiples). Several of them are buffered in `src` and decoded in
    // order.
    #[test]
    fn large_packets() -> Result<(), error::PacketCodecError> {
        let lengths = vec![MAX_PAYLOAD_LEN, MAX_PAYLOAD_LEN + 1, MAX_PAYLOAD_LEN * 2];
        let mut encoder = PacketCodec::default();
        let mut decoder = PacketCodec::default();
        let mut src = BytesMut::new();

        decoder.max_allowed_packet = *lengths.iter().max().unwrap();
        encoder.max_allowed_packet = *lengths.iter().max().unwrap();

        for &len in &lengths {
            encoder.encode(&mut &*vec![0x42_u8; len], &mut src)?;
        }

        for &len in &lengths {
            let mut dst = vec![];
            let result = decoder.decode(&mut src, &mut dst)?;
            assert!(result);
            assert_eq!(dst, vec![0x42; len]);
        }

        Ok(())
    }

    // Decodes a captured compressed packet, re-encodes the payload with compression and
    // decodes it again. The decoder's sequence ids are reset before the second decode because
    // the encoder starts again from sequence id 0.
    #[test]
    fn compressed_roundtrip() {
        let mut encoder = PacketCodec::default();
        let mut decoder = PacketCodec::default();
        let mut src = BytesMut::from(COMPRESSED);

        encoder.compress(Compression::best());
        decoder.compress(Compression::best());

        let mut dst = vec![];
        let result = decoder.decode(&mut src, &mut dst).unwrap();
        assert!(result);
        assert_eq!(&*dst, PLAIN);
        encoder.encode(&mut &*dst, &mut src).unwrap();

        let mut dst = vec![];
        decoder.reset_seq_id();
        let result = decoder.decode(&mut src, &mut dst).unwrap();
        assert!(result);
        assert_eq!(&*dst, PLAIN);
    }

    // With `Compression::none()` the compressed framing is still used, but payloads are sent
    // uncompressed.
    #[test]
    fn compression_none() {
        let mut encoder = PacketCodec::default();
        let mut decoder = PacketCodec::default();
        let mut src = BytesMut::new();

        encoder.compress(Compression::none());
        decoder.compress(Compression::none());

        encoder.encode(&mut (&PLAIN[..]), &mut src).unwrap();
        let mut dst = vec![];
        let result = decoder.decode(&mut src, &mut dst).unwrap();
        assert!(result);
        assert_eq!(&*dst, PLAIN);
    }

    // An empty chunk with sequence id 1, received by a fresh codec that expects 0, must be
    // rejected as out of sync.
    #[test]
    #[should_panic(expected = "PacketsOutOfSync")]
    fn out_of_sync() {
        let mut src = BytesMut::from(&b"\x00\x00\x00\x01"[..]);
        let mut codec = PacketCodec::default();
        let mut dst = vec![];
        codec.decode(&mut src, &mut dst).unwrap();
    }

    // A packet one byte over `max_allowed_packet` is rejected. The encoder already refuses it,
    // so the panic comes from the `unwrap` on `encode`, before decoding starts.
    #[test]
    #[should_panic(expected = "PacketTooLarge")]
    fn packet_too_large() {
        let mut encoder = PacketCodec::default();
        let mut decoder = PacketCodec::default();
        let mut src = BytesMut::new();

        encoder
            .encode(&mut &*vec![0; encoder.max_allowed_packet + 1], &mut src)
            .unwrap();
        let mut dst = vec![];
        decoder.decode(&mut src, &mut dst).unwrap();
    }
}