aws_smithy_eventstream/
frame.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Event Stream message frame types and serialization/deserialization logic.
7
8use crate::buf::count::CountBuf;
9use crate::buf::crc::{CrcBuf, CrcBufMut};
10use crate::error::{Error, ErrorKind};
11use aws_smithy_types::config_bag::{Storable, StoreReplace};
12use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
13use aws_smithy_types::str_bytes::StrBytes;
14use aws_smithy_types::DateTime;
15use bytes::{Buf, BufMut};
16use std::convert::{TryFrom, TryInto};
17use std::error::Error as StdError;
18use std::fmt;
19use std::mem::size_of;
20use std::sync::{mpsc, Mutex};
21
// Prelude layout: total length (u32) + headers length (u32) + prelude CRC (u32).
const PRELUDE_LENGTH_BYTES: u32 = (size_of::<u32>() * 3) as u32;
const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize;
// The message ends with a single u32 checksum.
const MESSAGE_CRC_LENGTH_BYTES: u32 = size_of::<u32>() as u32;
// Header names are length-prefixed with one byte, so they can be at most u8::MAX long.
const MAX_HEADER_NAME_LEN: usize = u8::MAX as usize;
// Smallest possible header: one name-length byte plus one value-type byte.
const MIN_HEADER_LEN: usize = 2;

// Wire tags identifying the type of an encoded header value.
pub(crate) const TYPE_TRUE: u8 = 0x00;
pub(crate) const TYPE_FALSE: u8 = 0x01;
pub(crate) const TYPE_BYTE: u8 = 0x02;
pub(crate) const TYPE_INT16: u8 = 0x03;
pub(crate) const TYPE_INT32: u8 = 0x04;
pub(crate) const TYPE_INT64: u8 = 0x05;
pub(crate) const TYPE_BYTE_ARRAY: u8 = 0x06;
pub(crate) const TYPE_STRING: u8 = 0x07;
pub(crate) const TYPE_TIMESTAMP: u8 = 0x08;
pub(crate) const TYPE_UUID: u8 = 0x09;
38
39pub type SignMessageError = Box<dyn StdError + Send + Sync + 'static>;
40
41/// Signs an Event Stream message.
42pub trait SignMessage: fmt::Debug {
43    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError>;
44
45    /// SigV4 requires an empty last signed message to be sent.
46    /// Other protocols do not require one.
47    /// Return `Some(_)` to send a signed last empty message, before completing the stream.
48    /// Return `None` to not send one and terminate the stream immediately.
49    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>>;
50}
51
52/// A sender that gets placed in the request config to wire up an event stream signer after signing.
53#[derive(Debug)]
54#[non_exhaustive]
55pub struct DeferredSignerSender(Mutex<mpsc::Sender<Box<dyn SignMessage + Send + Sync>>>);
56
57impl DeferredSignerSender {
58    /// Creates a new `DeferredSignerSender`
59    fn new(tx: mpsc::Sender<Box<dyn SignMessage + Send + Sync>>) -> Self {
60        Self(Mutex::new(tx))
61    }
62
63    /// Sends a signer on the channel
64    pub fn send(
65        &self,
66        signer: Box<dyn SignMessage + Send + Sync>,
67    ) -> Result<(), mpsc::SendError<Box<dyn SignMessage + Send + Sync>>> {
68        self.0.lock().unwrap().send(signer)
69    }
70}
71
72impl Storable for DeferredSignerSender {
73    type Storer = StoreReplace<Self>;
74}
75
76/// Deferred event stream signer to allow a signer to be wired up later.
77///
78/// HTTP request signing takes place after serialization, and the event stream
79/// message stream body is established during serialization. Since event stream
80/// signing may need context from the initial HTTP signing operation, this
81/// [`DeferredSigner`] is needed to wire up the signer later in the request lifecycle.
82///
83/// This signer basically just establishes a MPSC channel so that the sender can
84/// be placed in the request's config. Then the HTTP signer implementation can
85/// retrieve the sender from that config and send an actual signing implementation
86/// with all the context needed.
87///
88/// When an event stream implementation needs to sign a message, the first call to
89/// sign will acquire a signing implementation off of the channel and cache it
90/// for the remainder of the operation.
91#[derive(Debug)]
92pub struct DeferredSigner {
93    rx: Option<Mutex<mpsc::Receiver<Box<dyn SignMessage + Send + Sync>>>>,
94    signer: Option<Box<dyn SignMessage + Send + Sync>>,
95}
96
97impl DeferredSigner {
98    pub fn new() -> (Self, DeferredSignerSender) {
99        let (tx, rx) = mpsc::channel();
100        (
101            Self {
102                rx: Some(Mutex::new(rx)),
103                signer: None,
104            },
105            DeferredSignerSender::new(tx),
106        )
107    }
108
109    fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) {
110        // Can't use `if let Some(signer) = &mut self.signer` because the borrow checker isn't smart enough
111        if self.signer.is_some() {
112            return self.signer.as_mut().unwrap().as_mut();
113        } else {
114            self.signer = Some(
115                self.rx
116                    .take()
117                    .expect("only taken once")
118                    .lock()
119                    .unwrap()
120                    .try_recv()
121                    .ok()
122                    // TODO(enableNewSmithyRuntimeCleanup): When the middleware implementation is removed,
123                    // this should panic rather than default to the `NoOpSigner`. The reason it defaults
124                    // is because middleware-based generic clients don't have any default middleware,
125                    // so there is no way to send a `NoOpSigner` by default when there is no other
126                    // auth scheme. The orchestrator auth setup is a lot more robust and will make
127                    // this problem trivial.
128                    .unwrap_or_else(|| Box::new(NoOpSigner {}) as _),
129            );
130            self.acquire()
131        }
132    }
133}
134
135impl SignMessage for DeferredSigner {
136    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
137        self.acquire().sign(message)
138    }
139
140    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
141        self.acquire().sign_empty()
142    }
143}
144
145#[derive(Debug)]
146pub struct NoOpSigner {}
147impl SignMessage for NoOpSigner {
148    fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
149        Ok(message)
150    }
151
152    fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
153        None
154    }
155}
156
157/// Converts a Smithy modeled Event Stream type into a [`Message`].
158pub trait MarshallMessage: fmt::Debug {
159    /// Smithy modeled input type to convert from.
160    type Input;
161
162    fn marshall(&self, input: Self::Input) -> Result<Message, Error>;
163}
164
165/// A successfully unmarshalled message that is either an `Event` or an `Error`.
166#[derive(Debug)]
167pub enum UnmarshalledMessage<T, E> {
168    Event(T),
169    Error(E),
170}
171
172/// Converts an Event Stream [`Message`] into a Smithy modeled type.
173pub trait UnmarshallMessage: fmt::Debug {
174    /// Smithy modeled type to convert into.
175    type Output;
176    /// Smithy modeled error to convert into.
177    type Error;
178
179    fn unmarshall(
180        &self,
181        message: &Message,
182    ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, Error>;
183}
184
// Reads a fixed-width value of type `$size_typ` from `$buf` with `$read_fn` and wraps
// it in `HeaderValue::$typ`. Evaluates to `Err(InvalidHeaderValue)` when fewer than
// `size_of::<$size_typ>()` bytes remain, so the read itself never runs short.
macro_rules! read_value {
    ($buf:ident, $typ:ident, $size_typ:ident, $read_fn:ident) => {
        if $buf.remaining() < size_of::<$size_typ>() {
            Err(ErrorKind::InvalidHeaderValue.into())
        } else {
            Ok(HeaderValue::$typ($buf.$read_fn()))
        }
    };
}
194
195fn read_header_value_from<B: Buf>(mut buffer: B) -> Result<HeaderValue, Error> {
196    let value_type = buffer.get_u8();
197    match value_type {
198        TYPE_TRUE => Ok(HeaderValue::Bool(true)),
199        TYPE_FALSE => Ok(HeaderValue::Bool(false)),
200        TYPE_BYTE => read_value!(buffer, Byte, i8, get_i8),
201        TYPE_INT16 => read_value!(buffer, Int16, i16, get_i16),
202        TYPE_INT32 => read_value!(buffer, Int32, i32, get_i32),
203        TYPE_INT64 => read_value!(buffer, Int64, i64, get_i64),
204        TYPE_BYTE_ARRAY | TYPE_STRING => {
205            if buffer.remaining() > size_of::<u16>() {
206                let len = buffer.get_u16() as usize;
207                if buffer.remaining() < len {
208                    return Err(ErrorKind::InvalidHeaderValue.into());
209                }
210                let bytes = buffer.copy_to_bytes(len);
211                if value_type == TYPE_STRING {
212                    Ok(HeaderValue::String(
213                        bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?,
214                    ))
215                } else {
216                    Ok(HeaderValue::ByteArray(bytes))
217                }
218            } else {
219                Err(ErrorKind::InvalidHeaderValue.into())
220            }
221        }
222        TYPE_TIMESTAMP => {
223            if buffer.remaining() >= size_of::<i64>() {
224                let epoch_millis = buffer.get_i64();
225                Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis)))
226            } else {
227                Err(ErrorKind::InvalidHeaderValue.into())
228            }
229        }
230        TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128),
231        _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()),
232    }
233}
234
235fn write_header_value_to<B: BufMut>(value: &HeaderValue, mut buffer: B) -> Result<(), Error> {
236    use HeaderValue::*;
237    match value {
238        Bool(val) => buffer.put_u8(if *val { TYPE_TRUE } else { TYPE_FALSE }),
239        Byte(val) => {
240            buffer.put_u8(TYPE_BYTE);
241            buffer.put_i8(*val);
242        }
243        Int16(val) => {
244            buffer.put_u8(TYPE_INT16);
245            buffer.put_i16(*val);
246        }
247        Int32(val) => {
248            buffer.put_u8(TYPE_INT32);
249            buffer.put_i32(*val);
250        }
251        Int64(val) => {
252            buffer.put_u8(TYPE_INT64);
253            buffer.put_i64(*val);
254        }
255        ByteArray(val) => {
256            buffer.put_u8(TYPE_BYTE_ARRAY);
257            buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?);
258            buffer.put_slice(&val[..]);
259        }
260        String(val) => {
261            buffer.put_u8(TYPE_STRING);
262            buffer.put_u16(checked(
263                val.as_bytes().len(),
264                ErrorKind::HeaderValueTooLong.into(),
265            )?);
266            buffer.put_slice(&val.as_bytes()[..]);
267        }
268        Timestamp(time) => {
269            buffer.put_u8(TYPE_TIMESTAMP);
270            buffer.put_i64(
271                time.to_millis()
272                    .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?,
273            );
274        }
275        Uuid(val) => {
276            buffer.put_u8(TYPE_UUID);
277            buffer.put_u128(*val);
278        }
279        _ => {
280            panic!("matched on unexpected variant in `aws_smithy_types::event_stream::HeaderValue`")
281        }
282    }
283    Ok(())
284}
285
286/// Reads a header from the given `buffer`.
287fn read_header_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> {
288    if buffer.remaining() < MIN_HEADER_LEN {
289        return Err(ErrorKind::InvalidHeadersLength.into());
290    }
291
292    let mut counting_buf = CountBuf::new(&mut buffer);
293    let name_len = counting_buf.get_u8();
294    if name_len as usize >= counting_buf.remaining() {
295        return Err(ErrorKind::InvalidHeaderNameLength.into());
296    }
297
298    let name: StrBytes = counting_buf
299        .copy_to_bytes(name_len as usize)
300        .try_into()
301        .map_err(|_| ErrorKind::InvalidUtf8String)?;
302    let value = read_header_value_from(&mut counting_buf)?;
303    Ok((Header::new(name, value), counting_buf.into_count()))
304}
305
306/// Writes the header to the given `buffer`.
307fn write_header_to<B: BufMut>(header: &Header, mut buffer: B) -> Result<(), Error> {
308    if header.name().as_bytes().len() > MAX_HEADER_NAME_LEN {
309        return Err(ErrorKind::InvalidHeaderNameLength.into());
310    }
311
312    buffer.put_u8(u8::try_from(header.name().as_bytes().len()).expect("bounds check above"));
313    buffer.put_slice(&header.name().as_bytes()[..]);
314    write_header_value_to(header.value(), buffer)
315}
316
317/// Writes the given `headers` to a `buffer`.
318pub fn write_headers_to<B: BufMut>(headers: &[Header], mut buffer: B) -> Result<(), Error> {
319    for header in headers {
320        write_header_to(header, &mut buffer)?;
321    }
322    Ok(())
323}
324
325// Returns (total_len, header_len)
326fn read_prelude_from<B: Buf>(mut buffer: B) -> Result<(u32, u32), Error> {
327    let mut crc_buffer = CrcBuf::new(&mut buffer);
328
329    // If the buffer doesn't have the entire, then error
330    let total_len = crc_buffer.get_u32();
331    if crc_buffer.remaining() + size_of::<u32>() < total_len as usize {
332        return Err(ErrorKind::InvalidMessageLength.into());
333    }
334
335    // Validate the prelude
336    let header_len = crc_buffer.get_u32();
337    let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32());
338    if expected_crc != prelude_crc {
339        return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into());
340    }
341    // The header length can be 0 or >= 2, but must fit within the frame size
342    if header_len == 1 || header_len > max_header_len(total_len)? {
343        return Err(ErrorKind::InvalidHeadersLength.into());
344    }
345    Ok((total_len, header_len))
346}
347
348/// Reads a message from the given `buffer`. For streaming use cases, use
349/// the [`MessageFrameDecoder`] instead of this.
350pub fn read_message_from<B: Buf>(mut buffer: B) -> Result<Message, Error> {
351    if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE {
352        return Err(ErrorKind::InvalidMessageLength.into());
353    }
354
355    // Calculate a CRC as we go and read the prelude
356    let mut crc_buffer = CrcBuf::new(&mut buffer);
357    let (total_len, header_len) = read_prelude_from(&mut crc_buffer)?;
358
359    // Verify we have the full frame before continuing
360    let remaining_len = total_len
361        .checked_sub(PRELUDE_LENGTH_BYTES)
362        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
363    if crc_buffer.remaining() < remaining_len as usize {
364        return Err(ErrorKind::InvalidMessageLength.into());
365    }
366
367    // Read headers
368    let mut header_bytes_read = 0;
369    let mut headers = Vec::new();
370    while header_bytes_read < header_len as usize {
371        let (header, bytes_read) = read_header_from(&mut crc_buffer)?;
372        header_bytes_read += bytes_read;
373        if header_bytes_read > header_len as usize {
374            return Err(ErrorKind::InvalidHeaderValue.into());
375        }
376        headers.push(header);
377    }
378
379    // Read payload
380    let payload_len = payload_len(total_len, header_len)?;
381    let payload = crc_buffer.copy_to_bytes(payload_len as usize);
382
383    let expected_crc = crc_buffer.into_crc();
384    let message_crc = buffer.get_u32();
385    if expected_crc != message_crc {
386        return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into());
387    }
388
389    Ok(Message::new_from_parts(headers, payload))
390}
391
392/// Writes the `message` to the given `buffer`.
393pub fn write_message_to(message: &Message, buffer: &mut dyn BufMut) -> Result<(), Error> {
394    let mut headers = Vec::new();
395    for header in message.headers() {
396        write_header_to(header, &mut headers)?;
397    }
398
399    let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?;
400    let payload_len = checked(message.payload().len(), ErrorKind::PayloadTooLong.into())?;
401    let message_len = [
402        PRELUDE_LENGTH_BYTES,
403        headers_len,
404        payload_len,
405        MESSAGE_CRC_LENGTH_BYTES,
406    ]
407    .iter()
408    .try_fold(0u32, |acc, v| {
409        acc.checked_add(*v)
410            .ok_or_else(|| Error::from(ErrorKind::MessageTooLong))
411    })?;
412
413    let mut crc_buffer = CrcBufMut::new(buffer);
414    crc_buffer.put_u32(message_len);
415    crc_buffer.put_u32(headers_len);
416    crc_buffer.put_crc();
417    crc_buffer.put(&headers[..]);
418    crc_buffer.put(&message.payload()[..]);
419    crc_buffer.put_crc();
420    Ok(())
421}
422
423fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> {
424    T::try_from(from).map_err(|_| err)
425}
426
427fn max_header_len(total_len: u32) -> Result<u32, Error> {
428    total_len
429        .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
430        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
431}
432
433fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> {
434    total_len
435        .checked_sub(
436            header_len
437                .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
438                .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?,
439        )
440        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
441}
442
// Tests for whole-message parsing (`read_message_from`) and serialization
// (`write_message_to`) against fixed byte fixtures.
#[cfg(test)]
mod message_tests {
    use super::read_message_from;
    use crate::error::ErrorKind;
    use crate::frame::{write_message_to, Header, HeaderValue, Message};
    use aws_smithy_types::DateTime;
    use bytes::Bytes;

    // Parses `$bytes` as a message and asserts that parsing fails with an error
    // whose kind matches the pattern `$err`.
    macro_rules! read_message_expect_err {
        ($bytes:expr, $err:pat) => {
            let result = read_message_from(&mut Bytes::from_static($bytes));
            let result = result.as_ref();
            assert!(result.is_err(), "Expected error, got {:?}", result);
            assert!(
                matches!(result.err().unwrap().kind(), $err),
                "Expected {}, got {:?}",
                stringify!($err),
                result
            );
        };
    }

    // Each malformed fixture under `test_data/` must be rejected with the error kind listed.
    #[test]
    fn invalid_messages() {
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_value_length"),
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_length_cut_off"),
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_value_type"),
            ErrorKind::InvalidHeaderValueType(0x60)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length"),
            ErrorKind::InvalidHeaderNameLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_headers_length"),
            ErrorKind::InvalidHeadersLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_prelude_checksum"),
            ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_message_checksum"),
            ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length_too_long"),
            ErrorKind::InvalidUtf8String
        );
    }

    // A message with a zero-length header block parses to no headers and the expected payload.
    #[test]
    fn read_message_no_headers() {
        // Test message taken from the CRT:
        // https://github.com/awslabs/aws-c-event-stream/blob/main/tests/message_deserializer_test.c
        let data: &'static [u8] = &[
            0x00, 0x00, 0x00, 0x1D, 0x00, 0x00, 0x00, 0x00, 0xfd, 0x52, 0x8c, 0x5a, 0x7b, 0x27,
            0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27, 0x7d, 0xc3, 0x65, 0x39,
            0x36,
        ];

        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
        assert_eq!(result.headers(), Vec::new());

        let expected_payload = b"{'foo':'bar'}";
        assert_eq!(expected_payload, result.payload().as_ref());
    }

    // A message with a single string header parses to that header and the expected payload.
    #[test]
    fn read_message_one_header() {
        // Test message taken from the CRT:
        // https://github.com/awslabs/aws-c-event-stream/blob/main/tests/message_deserializer_test.c
        let data: &'static [u8] = &[
            0x00, 0x00, 0x00, 0x3D, 0x00, 0x00, 0x00, 0x20, 0x07, 0xFD, 0x83, 0x96, 0x0C, b'c',
            b'o', b'n', b't', b'e', b'n', b't', b'-', b't', b'y', b'p', b'e', 0x07, 0x00, 0x10,
            b'a', b'p', b'p', b'l', b'i', b'c', b'a', b't', b'i', b'o', b'n', b'/', b'j', b's',
            b'o', b'n', 0x7b, 0x27, 0x66, 0x6f, 0x6f, 0x27, 0x3a, 0x27, 0x62, 0x61, 0x72, 0x27,
            0x7d, 0x8D, 0x9C, 0x08, 0xB1,
        ];

        let result = read_message_from(&mut Bytes::from_static(data)).unwrap();
        assert_eq!(
            result.headers(),
            vec![Header::new(
                "content-type",
                HeaderValue::String("application/json".into())
            )]
        );

        let expected_payload = b"{'foo':'bar'}";
        assert_eq!(expected_payload, result.payload().as_ref());
    }

    // The fixture exercising every header value type decodes to the expected headers, in order.
    #[test]
    fn read_all_headers_and_payload() {
        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");
        let result = read_message_from(&mut Bytes::from_static(message)).unwrap();
        assert_eq!(
            result.headers(),
            vec![
                Header::new("true", HeaderValue::Bool(true)),
                Header::new("false", HeaderValue::Bool(false)),
                Header::new("byte", HeaderValue::Byte(50)),
                Header::new("short", HeaderValue::Int16(20_000)),
                Header::new("int", HeaderValue::Int32(500_000)),
                Header::new("long", HeaderValue::Int64(50_000_000_000)),
                Header::new(
                    "bytes",
                    HeaderValue::ByteArray(Bytes::from(&b"some bytes"[..]))
                ),
                Header::new("str", HeaderValue::String("some str".into())),
                Header::new(
                    "time",
                    HeaderValue::Timestamp(DateTime::from_secs(5_000_000))
                ),
                Header::new(
                    "uuid",
                    HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b)
                ),
            ]
        );

        assert_eq!(b"some payload", result.payload().as_ref());
    }

    // Serializing the all-header-types message must reproduce the fixture byte-for-byte,
    // and parsing the result must give back the same headers and payload.
    #[test]
    fn round_trip_all_headers_payload() {
        let message = Message::new(&b"some payload"[..])
            .add_header(Header::new("true", HeaderValue::Bool(true)))
            .add_header(Header::new("false", HeaderValue::Bool(false)))
            .add_header(Header::new("byte", HeaderValue::Byte(50)))
            .add_header(Header::new("short", HeaderValue::Int16(20_000)))
            .add_header(Header::new("int", HeaderValue::Int32(500_000)))
            .add_header(Header::new("long", HeaderValue::Int64(50_000_000_000)))
            .add_header(Header::new(
                "bytes",
                HeaderValue::ByteArray((&b"some bytes"[..]).into()),
            ))
            .add_header(Header::new("str", HeaderValue::String("some str".into())))
            .add_header(Header::new(
                "time",
                HeaderValue::Timestamp(DateTime::from_secs(5_000_000)),
            ))
            .add_header(Header::new(
                "uuid",
                HeaderValue::Uuid(0xb79bc914_de21_4e13_b8b2_bc47e85b7f0b),
            ));

        let mut actual = Vec::new();
        write_message_to(&message, &mut actual).unwrap();

        let expected = include_bytes!("../test_data/valid_with_all_headers_and_payload").to_vec();
        assert_eq!(expected, actual);

        let result = read_message_from(&mut Bytes::from(actual)).unwrap();
        assert_eq!(message.headers(), result.headers());
        assert_eq!(message.payload().as_ref(), result.payload().as_ref());
    }
}
609
610/// Return value from [`MessageFrameDecoder`].
611#[derive(Debug)]
612pub enum DecodedFrame {
613    /// There wasn't enough data in the buffer to decode a full message.
614    Incomplete,
615    /// There was enough data in the buffer to decode.
616    Complete(Message),
617}
618
619/// Streaming decoder for decoding a [`Message`] from a stream.
620#[non_exhaustive]
621#[derive(Default, Debug)]
622pub struct MessageFrameDecoder {
623    prelude: [u8; PRELUDE_LENGTH_BYTES_USIZE],
624    prelude_read: bool,
625}
626
627impl MessageFrameDecoder {
628    /// Returns a new `MessageFrameDecoder`.
629    pub fn new() -> Self {
630        Default::default()
631    }
632
633    /// Determines if the `buffer` has enough data in it to read a full frame.
634    /// Returns `Ok(None)` if there's not enough data, or `Some(remaining)` where
635    /// `remaining` is the number of bytes after the prelude that belong to the
636    /// message that's in the buffer.
637    fn remaining_bytes_if_frame_available<B: Buf>(
638        &self,
639        buffer: &B,
640    ) -> Result<Option<usize>, Error> {
641        if self.prelude_read {
642            let remaining_len = (&self.prelude[..])
643                .get_u32()
644                .checked_sub(PRELUDE_LENGTH_BYTES)
645                .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
646            if buffer.remaining() >= remaining_len as usize {
647                return Ok(Some(remaining_len as usize));
648            }
649        }
650        Ok(None)
651    }
652
653    /// Resets the decoder.
654    fn reset(&mut self) {
655        self.prelude_read = false;
656        self.prelude = [0u8; PRELUDE_LENGTH_BYTES_USIZE];
657    }
658
659    /// Attempts to decode a [`Message`] from the given `buffer`. This function expects
660    /// to be called over and over again with more data in the buffer each time its called.
661    /// When there's not enough data to decode a message, it returns `Ok(None)`.
662    ///
663    /// Once there is enough data to read a message prelude, then it will mutate the `Buf`
664    /// position. The state from the reading of the prelude is stored in the decoder so that
665    /// the next call will be able to decode the entire message, even though the prelude
666    /// is no longer available in the `Buf`.
667    pub fn decode_frame<B: Buf>(&mut self, mut buffer: B) -> Result<DecodedFrame, Error> {
668        if !self.prelude_read && buffer.remaining() >= PRELUDE_LENGTH_BYTES_USIZE {
669            buffer.copy_to_slice(&mut self.prelude);
670            self.prelude_read = true;
671        }
672
673        if let Some(remaining_len) = self.remaining_bytes_if_frame_available(&buffer)? {
674            let mut message_buf = (&self.prelude[..]).chain(buffer.take(remaining_len));
675            let result = read_message_from(&mut message_buf).map(DecodedFrame::Complete);
676            self.reset();
677            return result;
678        }
679
680        Ok(DecodedFrame::Incomplete)
681    }
682}
683
#[cfg(test)]
mod message_frame_decoder_tests {
    use super::{DecodedFrame, MessageFrameDecoder};
    use crate::frame::read_message_from;
    use bytes::Bytes;
    use bytes_utils::SegmentedBuf;

    /// Feeds a message to the decoder one byte at a time. Every call before the final
    /// byte must report an incomplete frame; the final call must yield the same message
    /// that `read_message_from` produces.
    #[test]
    fn single_streaming_message() {
        let message = include_bytes!("../test_data/valid_with_all_headers_and_payload");

        let mut decoder = MessageFrameDecoder::new();
        let mut segmented = SegmentedBuf::new();
        for i in 0..(message.len() - 1) {
            segmented.push(&message[i..(i + 1)]);
            if let DecodedFrame::Complete(_) = decoder.decode_frame(&mut segmented).unwrap() {
                panic!("incomplete frame shouldn't result in message");
            }
        }

        segmented.push(&message[(message.len() - 1)..]);
        match decoder.decode_frame(&mut segmented).unwrap() {
            DecodedFrame::Incomplete => panic!("frame should be complete now"),
            DecodedFrame::Complete(actual) => {
                let expected = read_message_from(&mut Bytes::from_static(message)).unwrap();
                assert_eq!(expected, actual);
            }
        }
    }

    /// Streams three back-to-back messages through the decoder in chunks of
    /// `chunk_size` bytes. Checks that exactly those three messages are decoded, in order.
    fn multiple_streaming_messages_chunk_size(chunk_size: usize) {
        let message1 = include_bytes!("../test_data/valid_with_all_headers_and_payload");
        let message2 = include_bytes!("../test_data/valid_empty_payload");
        let message3 = include_bytes!("../test_data/valid_no_headers");
        let mut repeated = message1.to_vec();
        repeated.extend_from_slice(message2);
        repeated.extend_from_slice(message3);

        let mut decoder = MessageFrameDecoder::new();
        let mut segmented = SegmentedBuf::new();
        let mut decoded = Vec::new();
        for window in repeated.chunks(chunk_size) {
            segmented.push(window);
            let frame = decoder.decode_frame(&mut segmented).unwrap_or_else(|err| {
                panic!("chunk size {}: decoding failed: {:?}", chunk_size, err)
            });
            if let DecodedFrame::Complete(message) = frame {
                decoded.push(message);
            }
        }

        let expected1 = read_message_from(&mut Bytes::from_static(message1)).unwrap();
        let expected2 = read_message_from(&mut Bytes::from_static(message2)).unwrap();
        let expected3 = read_message_from(&mut Bytes::from_static(message3)).unwrap();
        assert_eq!(3, decoded.len(), "chunk size {}", chunk_size);
        assert_eq!(expected1, decoded[0], "chunk size {}", chunk_size);
        assert_eq!(expected2, decoded[1], "chunk size {}", chunk_size);
        assert_eq!(expected3, decoded[2], "chunk size {}", chunk_size);
    }

    /// Runs the multi-message streaming check across a range of chunk sizes so message
    /// boundaries land at many different offsets within a chunk.
    #[test]
    fn multiple_streaming_messages() {
        for chunk_size in 1..=11 {
            multiple_streaming_messages_chunk_size(chunk_size);
        }
    }
}
752
#[cfg(test)]
mod deferred_signer_tests {
    use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage};
    use bytes::Bytes;

    /// Compile-time check that `T` is `Send + Sync`; returns its argument unchanged.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    /// A signer delivered through the channel is used, and its state persists across calls.
    #[test]
    fn deferred_signer() {
        /// Test signer that stamps each message with how many times it has been called.
        #[derive(Default, Debug)]
        struct TestSigner {
            call_num: i32,
        }

        impl SignMessage for TestSigner {
            fn sign(
                &mut self,
                message: Message,
            ) -> Result<Message, crate::frame::SignMessageError> {
                self.call_num += 1;
                let stamp = Header::new("call_num", HeaderValue::Int32(self.call_num));
                Ok(message.add_header(stamp))
            }

            fn sign_empty(&mut self) -> Option<Result<Message, crate::frame::SignMessageError>> {
                None
            }
        }

        let (mut signer, sender) = check_send_sync(DeferredSigner::new());
        sender.send(Box::<TestSigner>::default()).expect("success");

        // The same TestSigner instance must handle both calls, so the counter advances.
        for expected in 1..=2 {
            let message = signer.sign(Message::new(Bytes::new())).expect("success");
            assert_eq!(expected, message.headers()[0].value().as_int32().unwrap());
        }

        assert!(signer.sign_empty().is_none());
    }

    /// With no signer ever sent, messages pass through unchanged and no empty message is made.
    #[test]
    fn deferred_signer_defaults_to_noop_signer() {
        // Keep the sender alive so the channel is still connected on first use.
        let (mut signer, _sender) = DeferredSigner::new();
        let expected = Message::new(Bytes::new());
        let signed = signer.sign(Message::new(Bytes::new())).unwrap();
        assert_eq!(expected, signed);
        assert!(signer.sign_empty().is_none());
    }
}