1use crate::buf::count::CountBuf;
9use crate::buf::crc::{CrcBuf, CrcBufMut};
10use crate::error::{Error, ErrorKind};
11use aws_smithy_types::config_bag::{Storable, StoreReplace};
12use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
13use aws_smithy_types::str_bytes::StrBytes;
14use aws_smithy_types::DateTime;
15use bytes::{Buf, BufMut};
16use std::convert::{TryFrom, TryInto};
17use std::error::Error as StdError;
18use std::fmt;
19use std::mem::size_of;
20use std::sync::{mpsc, Mutex};
21
/// Length of the message prelude: total length, headers length and prelude CRC,
/// each encoded as a `u32`.
const PRELUDE_LENGTH_BYTES: u32 = (size_of::<u32>() * 3) as u32;
/// [`PRELUDE_LENGTH_BYTES`] as a `usize`, for buffer and array sizing.
const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize;
/// Length of the trailing message CRC (one `u32`).
const MESSAGE_CRC_LENGTH_BYTES: u32 = u32::BITS / 8;
/// Longest header name that can be encoded: the name length is written as a single `u8`.
const MAX_HEADER_NAME_LEN: usize = u8::MAX as usize;
/// Smallest possible encoded header: one name-length byte plus one value-type byte.
const MIN_HEADER_LEN: usize = 2;
27
// Wire type tags. Every encoded header value starts with one of these bytes, which
// `read_header_value_from` / `write_header_value_to` map to `HeaderValue` variants.

/// Boolean `true`; the tag alone encodes the value (no payload follows).
pub(crate) const TYPE_TRUE: u8 = 0;
/// Boolean `false`; the tag alone encodes the value (no payload follows).
pub(crate) const TYPE_FALSE: u8 = 1;
/// Signed 8-bit integer.
pub(crate) const TYPE_BYTE: u8 = 2;
/// Signed 16-bit big-endian integer.
pub(crate) const TYPE_INT16: u8 = 3;
/// Signed 32-bit big-endian integer.
pub(crate) const TYPE_INT32: u8 = 4;
/// Signed 64-bit big-endian integer.
pub(crate) const TYPE_INT64: u8 = 5;
/// Raw bytes, prefixed with a `u16` length.
pub(crate) const TYPE_BYTE_ARRAY: u8 = 6;
/// UTF-8 string, prefixed with a `u16` length.
pub(crate) const TYPE_STRING: u8 = 7;
/// Timestamp encoded as signed 64-bit epoch milliseconds.
pub(crate) const TYPE_TIMESTAMP: u8 = 8;
/// UUID encoded as a 128-bit big-endian integer.
pub(crate) const TYPE_UUID: u8 = 9;
38
/// Boxed, thread-safe error returned by [`SignMessage`] implementations.
pub type SignMessageError = Box<dyn StdError + Send + Sync + 'static>;

/// Signs event stream messages.
///
/// Implementors must also implement `Debug`.
pub trait SignMessage: fmt::Debug {
    /// Signs `message`, returning the signed message or an error.
    ///
    /// The returned message may differ from the input (for example, an implementation may
    /// add headers); [`NoOpSigner`] returns it unchanged.
    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError>;

    /// Optionally produces a signed message that is not derived from any input message.
    ///
    /// `None` means the signer has nothing to emit, which is what [`NoOpSigner`] returns.
    /// NOTE(review): presumably called once at the end of a stream; confirm with callers.
    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>>;
}
51
52#[derive(Debug)]
54#[non_exhaustive]
55pub struct DeferredSignerSender(Mutex<mpsc::Sender<Box<dyn SignMessage + Send + Sync>>>);
56
57impl DeferredSignerSender {
58 fn new(tx: mpsc::Sender<Box<dyn SignMessage + Send + Sync>>) -> Self {
60 Self(Mutex::new(tx))
61 }
62
63 pub fn send(
65 &self,
66 signer: Box<dyn SignMessage + Send + Sync>,
67 ) -> Result<(), mpsc::SendError<Box<dyn SignMessage + Send + Sync>>> {
68 self.0.lock().unwrap().send(signer)
69 }
70}
71
72impl Storable for DeferredSignerSender {
73 type Storer = StoreReplace<Self>;
74}
75
76#[derive(Debug)]
92pub struct DeferredSigner {
93 rx: Option<Mutex<mpsc::Receiver<Box<dyn SignMessage + Send + Sync>>>>,
94 signer: Option<Box<dyn SignMessage + Send + Sync>>,
95}
96
97impl DeferredSigner {
98 pub fn new() -> (Self, DeferredSignerSender) {
99 let (tx, rx) = mpsc::channel();
100 (
101 Self {
102 rx: Some(Mutex::new(rx)),
103 signer: None,
104 },
105 DeferredSignerSender::new(tx),
106 )
107 }
108
109 fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) {
110 if self.signer.is_some() {
112 return self.signer.as_mut().unwrap().as_mut();
113 } else {
114 self.signer = Some(
115 self.rx
116 .take()
117 .expect("only taken once")
118 .lock()
119 .unwrap()
120 .try_recv()
121 .ok()
122 .unwrap_or_else(|| Box::new(NoOpSigner {}) as _),
129 );
130 self.acquire()
131 }
132 }
133}
134
135impl SignMessage for DeferredSigner {
136 fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
137 self.acquire().sign(message)
138 }
139
140 fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
141 self.acquire().sign_empty()
142 }
143}
144
145#[derive(Debug)]
146pub struct NoOpSigner {}
147impl SignMessage for NoOpSigner {
148 fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
149 Ok(message)
150 }
151
152 fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
153 None
154 }
155}
156
/// Converts a typed input value into an event stream [`Message`].
pub trait MarshallMessage: fmt::Debug {
    /// The type that is marshalled into a message.
    type Input;

    /// Marshalls `input` into a [`Message`], or returns an [`Error`] if it cannot be encoded.
    fn marshall(&self, input: Self::Input) -> Result<Message, Error>;
}
164
/// A successfully unmarshalled [`Message`]: either an event or an error value.
#[derive(Debug)]
pub enum UnmarshalledMessage<T, E> {
    /// The message carried an event of type `T`.
    Event(T),
    /// The message carried an error value of type `E`.
    Error(E),
}
171
/// Converts an event stream [`Message`] into a typed event or error value.
pub trait UnmarshallMessage: fmt::Debug {
    /// Event type produced for event messages.
    type Output;
    /// Error type produced for messages that carry an error.
    type Error;

    /// Unmarshalls `message` into an [`UnmarshalledMessage`].
    ///
    /// The outer `Err` (this crate's [`Error`]) is separate from an
    /// `UnmarshalledMessage::Error`, which is a successfully decoded error message.
    fn unmarshall(
        &self,
        message: &Message,
    ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, Error>;
}
184
/// Reads a fixed-width header value with `$buf.$getter()` and wraps it in
/// `HeaderValue::$variant`.
///
/// Evaluates to an `Err(InvalidHeaderValue)` when `$buf` holds fewer than
/// `size_of::<$width>()` bytes.
macro_rules! read_value {
    ($buf:ident, $variant:ident, $width:ident, $getter:ident) => {
        if $buf.remaining() < size_of::<$width>() {
            Err(ErrorKind::InvalidHeaderValue.into())
        } else {
            Ok(HeaderValue::$variant($buf.$getter()))
        }
    };
}
194
195fn read_header_value_from<B: Buf>(mut buffer: B) -> Result<HeaderValue, Error> {
196 let value_type = buffer.get_u8();
197 match value_type {
198 TYPE_TRUE => Ok(HeaderValue::Bool(true)),
199 TYPE_FALSE => Ok(HeaderValue::Bool(false)),
200 TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8),
201 TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16),
202 TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32),
203 TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64),
204 TYPE_BYTE_ARRAY | TYPE_STRING => {
205 if buffer.remaining() > size_of::<u16>() {
206 let len = buffer.get_u16() as usize;
207 if buffer.remaining() < len {
208 return Err(ErrorKind::InvalidHeaderValue.into());
209 }
210 let bytes = buffer.copy_to_bytes(len);
211 if value_type == TYPE_STRING {
212 Ok(HeaderValue::String(
213 bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?,
214 ))
215 } else {
216 Ok(HeaderValue::ByteArray(bytes))
217 }
218 } else {
219 Err(ErrorKind::InvalidHeaderValue.into())
220 }
221 }
222 TYPE_TIMESTAMP => {
223 if buffer.remaining() >= size_of::<i64>() {
224 let epoch_millis = buffer.get_i64();
225 Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis)))
226 } else {
227 Err(ErrorKind::InvalidHeaderValue.into())
228 }
229 }
230 TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128),
231 _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()),
232 }
233}
234
235fn write_header_value_to<B: BufMut>(value: &HeaderValue, mut buffer: B) -> Result<(), Error> {
236 use HeaderValue::*;
237 match value {
238 Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }),
239 Byte(val) => {
240 buffer.put_u8(TYPE_BYTE);
241 buffer.put_i8(*val);
242 }
243 Int16(val) => {
244 buffer.put_u8(TYPE_INT16);
245 buffer.put_i16(*val);
246 }
247 Int32(val) => {
248 buffer.put_u8(TYPE_INT32);
249 buffer.put_i32(*val);
250 }
251 Int64(val) => {
252 buffer.put_u8(TYPE_INT64);
253 buffer.put_i64(*val);
254 }
255 ByteArray(val) => {
256 buffer.put_u8(TYPE_BYTE_ARRAY);
257 buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?);
258 buffer.put_slice(&val[..]);
259 }
260 String(val) => {
261 buffer.put_u8(TYPE_STRING);
262 buffer.put_u16(checked(
263 val.as_bytes().len(),
264 ErrorKind::HeaderValueTooLong.into(),
265 )?);
266 buffer.put_slice(&val.as_bytes()[..]);
267 }
268 Timestamp(time) => {
269 buffer.put_u8(TYPE_TIMESTAMP);
270 buffer.put_i64(
271 time.to_millis()
272 .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?,
273 );
274 }
275 Uuid(val) => {
276 buffer.put_u8(TYPE_UUID);
277 buffer.put_u128(*val);
278 }
279 _ => {
280 panic!("matched on unexpected variant in `aws_smithy_types::event_stream::HeaderValue`")
281 }
282 }
283 Ok(())
284}
285
286fn read_header_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> {
288 if buffer.remaining() < MIN_HEADER_LEN {
289 return Err(ErrorKind::InvalidHeadersLength.into());
290 }
291
292 let mut counting_buf = CountBuf::new(&mut buffer);
293 let name_len = counting_buf.get_u8();
294 if name_len as usize >= counting_buf.remaining() {
295 return Err(ErrorKind::InvalidHeaderNameLength.into());
296 }
297
298 let name: StrBytes = counting_buf
299 .copy_to_bytes(name_len as usize)
300 .try_into()
301 .map_err(|_| ErrorKind::InvalidUtf8String)?;
302 let value = read_header_value_from(&mut counting_buf)?;
303 Ok((Header::new(name, value), counting_buf.into_count()))
304}
305
306fn write_header_to<B: BufMut>(header: &Header, mut buffer: B) -> Result<(), Error> {
308 if header.name().as_bytes().len() > MAX_HEADER_NAME_LEN {
309 return Err(ErrorKind::InvalidHeaderNameLength.into());
310 }
311
312 buffer.put_u8(u8::try_from(header.name().as_bytes().len()).expect("bounds check above"));
313 buffer.put_slice(&header.name().as_bytes()[..]);
314 write_header_value_to(header.value(), buffer)
315}
316
317pub fn write_headers_to<B: BufMut>(headers: &[Header], mut buffer: B) -> Result<(), Error> {
319 for header in headers {
320 write_header_to(header, &mut buffer)?;
321 }
322 Ok(())
323}
324
325fn read_prelude_from<B: Buf>(mut buffer: B) -> Result<(u32, u32), Error> {
327 let mut crc_buffer = CrcBuf::new(&mut buffer);
328
329 let total_len = crc_buffer.get_u32();
331 if crc_buffer.remaining() + size_of::<u32>() < total_len as usize {
332 return Err(ErrorKind::InvalidMessageLength.into());
333 }
334
335 let header_len = crc_buffer.get_u32();
337 let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32());
338 if expected_crc != prelude_crc {
339 return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into());
340 }
341 if header_len == 1 || header_len > max_header_len(total_len)? {
343 return Err(ErrorKind::InvalidHeadersLength.into());
344 }
345 Ok((total_len, header_len))
346}
347
348pub fn read_message_from<B: Buf>(mut buffer: B) -> Result<Message, Error> {
351 if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE {
352 return Err(ErrorKind::InvalidMessageLength.into());
353 }
354
355 let mut crc_buffer = CrcBuf::new(&mut buffer);
357 let (total_len, header_len) = read_prelude_from(&mut crc_buffer)?;
358
359 let remaining_len = total_len
361 .checked_sub(PRELUDE_LENGTH_BYTES)
362 .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
363 if crc_buffer.remaining() < remaining_len as usize {
364 return Err(ErrorKind::InvalidMessageLength.into());
365 }
366
367 let mut header_bytes_read = 0;
369 let mut headers = Vec::new();
370 while header_bytes_read < header_len as usize {
371 let (header, bytes_read) = read_header_from(&mut crc_buffer)?;
372 header_bytes_read += bytes_read;
373 if header_bytes_read > header_len as usize {
374 return Err(ErrorKind::InvalidHeaderValue.into());
375 }
376 headers.push(header);
377 }
378
379 let payload_len = payload_len(total_len, header_len)?;
381 let payload = crc_buffer.copy_to_bytes(payload_len as usize);
382
383 let expected_crc = crc_buffer.into_crc();
384 let message_crc = buffer.get_u32();
385 if expected_crc != message_crc {
386 return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into());
387 }
388
389 Ok(Message::new_from_parts(headers, payload))
390}
391
392pub fn write_message_to(message: &Message, buffer: &mut dyn BufMut) -> Result<(), Error> {
394 let mut headers = Vec::new();
395 for header in message.headers() {
396 write_header_to(header, &mut headers)?;
397 }
398
399 let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?;
400 let payload_len = checked(message.payload().len(), ErrorKind::PayloadTooLong.into())?;
401 let message_len = [
402 PRELUDE_LENGTH_BYTES,
403 headers_len,
404 payload_len,
405 MESSAGE_CRC_LENGTH_BYTES,
406 ]
407 .iter()
408 .try_fold(0u32, |acc, v| {
409 acc.checked_add(*v)
410 .ok_or_else(|| Error::from(ErrorKind::MessageTooLong))
411 })?;
412
413 let mut crc_buffer = CrcBufMut::new(buffer);
414 crc_buffer.put_u32(message_len);
415 crc_buffer.put_u32(headers_len);
416 crc_buffer.put_crc();
417 crc_buffer.put(&headers[..]);
418 crc_buffer.put(&message.payload()[..]);
419 crc_buffer.put_crc();
420 Ok(())
421}
422
423fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> {
424 T::try_from(from).map_err(|_| err)
425}
426
427fn max_header_len(total_len: u32) -> Result<u32, Error> {
428 total_len
429 .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
430 .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
431}
432
433fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> {
434 total_len
435 .checked_sub(
436 header_len
437 .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
438 .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?,
439 )
440 .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
441}
442
/// Tests for whole-message encoding and decoding (`read_message_from` / `write_message_to`).
#[cfg(test)]
mod message_tests {
    use super::read_message_from;
    use crate::error::ErrorKind;
    use crate::frame::{write_message_to, Header, HeaderValue, Message};
    use aws_smithy_types::DateTime;
    use bytes::Bytes;

    // Decodes `$bytes` and asserts that it fails with an error whose kind matches `$err`.
    macro_rules! read_message_expect_err {
        ($bytes:expr, $err:pat) => {
            let result = read_message_from(&mut Bytes::from_static($bytes));
            let result = result.as_ref();
            assert!(result.is_err(), "Expected error, got {:?}", result);
            assert!(
                matches!(result.err().unwrap().kind(), $err),
                "Expected {}, got {:?}",
                stringify!($err),
                result
            );
        };
    }

    // Each fixture file is a deliberately malformed message paired with the error kind it
    // must produce.
    #[test]
    fn invalid_messages() {
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_value_length"),
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_length_cut_off"),
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_value_type"),
            ErrorKind::InvalidHeaderValueType(0x60)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length"),
            ErrorKind::InvalidHeaderNameLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_headers_length"),
            ErrorKind::InvalidHeadersLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_prelude_checksum"),
            ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_message_checksum"),
            ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length_too_long"),
            ErrorKind::InvalidUtf8String
        );
    }

    #[test]
    fn read_message_no_headers() {
        // Prelude: total length 0x1D (29 bytes), headers length 0, prelude CRC (4 bytes).
        // Then the 13-byte payload `{'foo':'bar'}` and the 4-byte message CRC.
        let data: &'static [u8] = &[
            0x00, 0x00, 0x00, 0x1D, 0x00, 0x00, 0x00, 0x00, 0xfd, 0x52, 0x8c, 0x5a, 0x7b, 0x27,
            0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27, 0x7d, 0xc3, 0x65, 0x39,
            0x36,
        ];

        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
        assert_eq!(result.headers(), Vec::new());

        let expected_payload = b"{'foo':'bar'}";
        assert_eq!(expected_payload, result.payload().as_ref());
    }

    #[test]
    fn read_message_one_header() {
        // Prelude: total length 0x3D (61 bytes), headers length 0x20 (32 bytes), prelude CRC.
        // Header: name length 0x0C + "content-type", type 0x07 (string), length 0x0010 +
        // "application/json". Then the 13-byte payload and the 4-byte message CRC.
        let data: &'static [u8] = &[
            0x00, 0x00, 0x00, 0x3D, 0x00, 0x00, 0x00, 0x20, 0x07, 0xFD, 0x83, 0x96, 0x0C, b'c',
            b'o', b'n', b't', b'e', b'n', b't', b'-', b't', b'y', b'p', b'e', 0x07, 0x00, 0x10,
            b'a', b'p', b'p', b'l', b'i', b'c', b'a', b't', b'i', b'o', b'n', b'/', b'j', b's',
            b'o', b'n', 0x7b, 0x27, 0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27,
            0x7d, 0x8D, 0x9C, 0x08, 0xB1,
        ];

        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
        assert_eq!(
            result.headers(),
            vec![Header::new(
                "content-type",
                HeaderValue::String("application/json".into())
            )]
        );

        let expected_payload = b"{'foo':'bar'}";
        assert_eq!(expected_payload, result.payload().as_ref());
    }

    // Decodes a fixture that exercises every header value type.
    #[test]
    fn read_all_headers_and_payload() {
        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");
        let result = read_message_from(&mut Bytes::from_static(message)).unwrap();
        assert_eq!(
            result.headers(),
            vec![
                Header::new("true", HeaderValue::Bool(true)),
                Header::new("false", HeaderValue::Bool(false)),
                Header::new("byte", HeaderValue::Byte(50)),
                Header::new("short", HeaderValue::Int16(20_000)),
                Header::new("int", HeaderValue::Int32(500_000)),
                Header::new("long", HeaderValue::Int64(50_000_000_000)),
                Header::new(
                    "bytes",
                    HeaderValue::ByteArray(Bytes::from(&b"some bytes"[..]))
                ),
                Header::new("str", HeaderValue::String("some str".into())),
                Header::new(
                    "time",
                    HeaderValue::Timestamp(DateTime::from_secs(5_000_000))
                ),
                Header::new(
                    "uuid",
                    HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b)
                ),
            ]
        );

        assert_eq!(b"some payload", result.payload().as_ref());
    }

    // Encoding must reproduce the fixture byte-for-byte, and decoding the output must give
    // back the same headers and payload.
    #[test]
    fn round_trip_all_headers_payload() {
        let message = Message::new(&b"some payload"[..])
            .add_header(Header::new("true", HeaderValue::Bool(true)))
            .add_header(Header::new("false", HeaderValue::Bool(false)))
            .add_header(Header::new("byte", HeaderValue::Byte(50)))
            .add_header(Header::new("short", HeaderValue::Int16(20_000)))
            .add_header(Header::new("int", HeaderValue::Int32(500_000)))
            .add_header(Header::new("long", HeaderValue::Int64(50_000_000_000)))
            .add_header(Header::new(
                "bytes",
                HeaderValue::ByteArray((&b"some bytes"[..]).into()),
            ))
            .add_header(Header::new("str", HeaderValue::String("some str".into())))
            .add_header(Header::new(
                "time",
                HeaderValue::Timestamp(DateTime::from_secs(5_000_000)),
            ))
            .add_header(Header::new(
                "uuid",
                HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b),
            ));

        let mut actual = Vec::new();
        write_message_to(&message, &mut actual).unwrap();

        let expected = include_bytes!("../test_data/valid_with_all_headers_and_payload").to_vec();
        assert_eq!(expected, actual);

        let result = read_message_from(&mut Bytes::from(actual)).unwrap();
        assert_eq!(message.headers(), result.headers());
        assert_eq!(message.payload().as_ref(), result.payload().as_ref());
    }
}
609
/// Outcome of a single [`MessageFrameDecoder::decode_frame`] call.
#[derive(Debug)]
pub enum DecodedFrame {
    /// Not enough data has been buffered yet to decode a complete message.
    Incomplete,
    /// A complete message was decoded.
    Complete(Message),
}
618
619#[non_exhaustive]
621#[derive(Default, Debug)]
622pub struct MessageFrameDecoder {
623 prelude: [u8; PRELUDE_LENGTH_BYTES_USIZE],
624 prelude_read: bool,
625}
626
627impl MessageFrameDecoder {
628 pub fn new() -> Self {
630 Default::default()
631 }
632
633 fn remaining_bytes_if_frame_available<B: Buf>(
638 &self,
639 buffer: &B,
640 ) -> Result<Option<usize>, Error> {
641 if self.prelude_read {
642 let remaining_len = (&self.prelude[..])
643 .get_u32()
644 .checked_sub(PRELUDE_LENGTH_BYTES)
645 .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
646 if buffer.remaining() >= remaining_len as usize {
647 return Ok(Some(remaining_len as usize));
648 }
649 }
650 Ok(None)
651 }
652
653 fn reset(&mut self) {
655 self.prelude_read = false;
656 self.prelude = [0u8; PRELUDE_LENGTH_BYTES_USIZE];
657 }
658
659 pub fn decode_frame<B: Buf>(&mut self, mut buffer: B) -> Result<DecodedFrame, Error> {
668 if !self.prelude_read && buffer.remaining() >= PRELUDE_LENGTH_BYTES_USIZE {
669 buffer.copy_to_slice(&mut self.prelude);
670 self.prelude_read = true;
671 }
672
673 if let Some(remaining_len) = self.remaining_bytes_if_frame_available(&buffer)? {
674 let mut message_buf = (&self.prelude[..]).chain(buffer.take(remaining_len));
675 let result = read_message_from(&mut message_buf).map(DecodedFrame::Complete);
676 self.reset();
677 return result;
678 }
679
680 Ok(DecodedFrame::Incomplete)
681 }
682}
683
/// Tests for `MessageFrameDecoder` fed with data split into arbitrary pieces.
#[cfg(test)]
mod message_frame_decoder_tests {
    use super::{DecodedFrame, MessageFrameDecoder};
    use crate::frame::read_message_from;
    use bytes::Bytes;
    use bytes_utils::SegmentedBuf;

    /// Feeds one message a byte at a time. Every call before the final byte must report an
    /// incomplete frame, and the final byte must yield the same message as a one-shot decode.
    #[test]
    fn single_streaming_message() {
        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");

        let mut decoder = MessageFrameDecoder::new();
        let mut segmented = SegmentedBuf::new();
        for i in 0..(message.len() - 1) {
            segmented.push(&message[i..(i + 1)]);
            if let DecodedFrame::Complete(_) = decoder.decode_frame(&mut segmented).unwrap() {
                panic!("incomplete frame shouldn't result in message");
            }
        }

        segmented.push(&message[(message.len() - 1)..]);
        match decoder.decode_frame(&mut segmented).unwrap() {
            DecodedFrame::Incomplete => panic!("frame should be complete now"),
            DecodedFrame::Complete(actual) => {
                let expected = read_message_from(&mut Bytes::from_static(message)).unwrap();
                assert_eq!(expected, actual);
            }
        }
    }

    /// Streams three concatenated messages in `chunk_size`-byte pieces, calling the decoder
    /// once per piece. Checks that exactly those three messages are decoded, in order.
    fn multiple_streaming_messages_chunk_size(chunk_size: usize) {
        let message1 = include_bytes!("../test_data/valid_with_all_headers_and_payload");
        let message2 = include_bytes!("../test_data/valid_empty_payload");
        let message3 = include_bytes!("../test_data/valid_no_headers");
        let mut repeated = message1.to_vec();
        repeated.extend_from_slice(message2);
        repeated.extend_from_slice(message3);

        let mut decoder = MessageFrameDecoder::new();
        let mut segmented = SegmentedBuf::new();
        let mut decoded = Vec::new();
        for window in repeated.chunks(chunk_size) {
            segmented.push(window);
            // Report the chunk size on failure instead of `dbg!`-printing every decode result.
            let frame = decoder.decode_frame(&mut segmented).unwrap_or_else(|err| {
                panic!("chunk size {}: decode failed: {:?}", chunk_size, err)
            });
            if let DecodedFrame::Complete(message) = frame {
                decoded.push(message);
            }
        }

        let expected1 = read_message_from(&mut Bytes::from_static(message1)).unwrap();
        let expected2 = read_message_from(&mut Bytes::from_static(message2)).unwrap();
        let expected3 = read_message_from(&mut Bytes::from_static(message3)).unwrap();
        assert_eq!(3, decoded.len(), "chunk size {}", chunk_size);
        assert_eq!(expected1, decoded[0], "chunk size {}", chunk_size);
        assert_eq!(expected2, decoded[1], "chunk size {}", chunk_size);
        assert_eq!(expected3, decoded[2], "chunk size {}", chunk_size);
    }

    #[test]
    fn multiple_streaming_messages() {
        for chunk_size in 1..=11 {
            multiple_streaming_messages_chunk_size(chunk_size);
        }
    }
}
752
/// Tests for `DeferredSigner`: delegation to a sent signer and the no-op fallback.
#[cfg(test)]
mod deferred_signer_tests {
    use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage};
    use bytes::Bytes;

    /// Compile-time check that `T` is `Send + Sync`; returns its argument unchanged.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn deferred_signer() {
        /// Test signer that stamps each signed message with a running call counter.
        #[derive(Default, Debug)]
        struct TestSigner {
            call_num: i32,
        }

        impl SignMessage for TestSigner {
            fn sign(
                &mut self,
                message: Message,
            ) -> Result<Message, crate::frame::SignMessageError> {
                self.call_num += 1;
                let stamp = Header::new("call_num", HeaderValue::Int32(self.call_num));
                Ok(message.add_header(stamp))
            }

            fn sign_empty(&mut self) -> Option<Result<Message, crate::frame::SignMessageError>> {
                None
            }
        }

        let (mut signer, sender) = check_send_sync(DeferredSigner::new());
        sender.send(Box::<TestSigner>::default()).expect("success");

        // The same sent signer must be reused, so the counter keeps increasing.
        for expected_call in 1..=2 {
            let signed = signer.sign(Message::new(Bytes::new())).expect("success");
            assert_eq!(expected_call, signed.headers()[0].value().as_int32().unwrap());
        }

        assert!(signer.sign_empty().is_none());
    }

    #[test]
    fn deferred_signer_defaults_to_noop_signer() {
        // Keep the sender alive but never send anything: the no-op fallback must be used.
        let (mut signer, _sender) = DeferredSigner::new();
        let unsigned = Message::new(Bytes::new());
        let signed = signer.sign(Message::new(Bytes::new())).unwrap();
        assert_eq!(unsigned, signed);
        assert!(signer.sign_empty().is_none());
    }
}
804}