aws_smithy_http/event_stream/
receiver.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{
7    DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21/// Wrapper around SegmentedBuf that tracks the state of the stream.
22#[derive(Debug)]
23enum RecvBuf {
24    /// Nothing has been buffered yet.
25    Empty,
26    /// Some data has been buffered.
27    /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary.
28    Partial(SegmentedBuf<Bytes>),
29    /// The end of the stream has been reached, but there may still be some buffered data.
30    EosPartial(SegmentedBuf<Bytes>),
31    /// An exception terminated this stream.
32    Terminated,
33}
34
35impl RecvBuf {
36    /// Returns true if there's more buffered data.
37    fn has_data(&self) -> bool {
38        match self {
39            RecvBuf::Empty | RecvBuf::Terminated => false,
40            RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41        }
42    }
43
44    /// Returns true if the stream has ended.
45    fn is_eos(&self) -> bool {
46        matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47    }
48
49    /// Returns a mutable reference to the underlying buffered data.
50    fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51        match self {
52            RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53            RecvBuf::Partial(segmented) => segmented,
54            RecvBuf::EosPartial(segmented) => segmented,
55            RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56        }
57    }
58
59    /// Returns a new `RecvBuf` with additional data buffered. This will only allocate
60    /// if the `RecvBuf` was previously empty.
61    fn with_partial(self, partial: Bytes) -> Self {
62        match self {
63            RecvBuf::Empty => {
64                let mut segmented = SegmentedBuf::new();
65                segmented.push(partial);
66                RecvBuf::Partial(segmented)
67            }
68            RecvBuf::Partial(mut segmented) => {
69                segmented.push(partial);
70                RecvBuf::Partial(segmented)
71            }
72            RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73                panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74            }
75        }
76    }
77
78    /// Returns a `RecvBuf` that has reached end of stream.
79    fn ended(self) -> Self {
80        match self {
81            RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82            RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83            RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84            RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85        }
86    }
87}
88
/// The distinct failure conditions a [`ReceiverError`] can describe.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The stream finished before a whole message frame had been received.
    UnexpectedEndOfStream,
}

/// An error raised by the event stream receiver itself.
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the human-readable text first, then emit it in a single write.
        let description = match &self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
111/// Receives Smithy-modeled messages out of an Event Stream.
112#[derive(Debug)]
113pub struct Receiver<T, E> {
114    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
115    decoder: MessageFrameDecoder,
116    buffer: RecvBuf,
117    body: SdkBody,
118    /// Event Stream has optional initial response frames an with `:message-type` of
119    /// `initial-response`. If `try_recv_initial()` is called and the next message isn't an
120    /// initial response, then the message will be stored in `buffered_message` so that it can
121    /// be returned with the next call of `recv()`.
122    buffered_message: Option<Message>,
123    _phantom: PhantomData<E>,
124}
125
126impl<T, E> Receiver<T, E> {
127    /// Creates a new `Receiver` with the given message unmarshaller and SDK body.
128    pub fn new(
129        unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
130        body: SdkBody,
131    ) -> Self {
132        Receiver {
133            unmarshaller: Box::new(unmarshaller),
134            decoder: MessageFrameDecoder::new(),
135            buffer: RecvBuf::Empty,
136            body,
137            buffered_message: None,
138            _phantom: Default::default(),
139        }
140    }
141
142    fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
143        match self.unmarshaller.unmarshall(&message) {
144            Ok(unmarshalled) => match unmarshalled {
145                UnmarshalledMessage::Event(event) => Ok(Some(event)),
146                UnmarshalledMessage::Error(err) => {
147                    Err(SdkError::service_error(err, RawMessage::Decoded(message)))
148                }
149            },
150            Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
151        }
152    }
153
154    async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
155        use http_body_04x::Body;
156
157        if !self.buffer.is_eos() {
158            let next_chunk = self
159                .body
160                .data()
161                .await
162                .transpose()
163                .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
164            let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
165            if let Some(chunk) = next_chunk {
166                self.buffer = buffer.with_partial(chunk);
167            } else {
168                self.buffer = buffer.ended();
169            }
170        }
171        Ok(())
172    }
173
174    async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
175        while !self.buffer.is_eos() {
176            if self.buffer.has_data() {
177                if let DecodedFrame::Complete(message) = self
178                    .decoder
179                    .decode_frame(self.buffer.buffered())
180                    .map_err(|err| {
181                        SdkError::response_error(
182                            err,
183                            // the buffer has been consumed
184                            RawMessage::Invalid(None),
185                        )
186                    })?
187                {
188                    trace!(message = ?message, "received complete event stream message");
189                    return Ok(Some(message));
190                }
191            }
192
193            self.buffer_next_chunk().await?;
194        }
195        if self.buffer.has_data() {
196            trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
197            let buf = self.buffer.buffered();
198            return Err(SdkError::response_error(
199                ReceiverError {
200                    kind: ReceiverErrorKind::UnexpectedEndOfStream,
201                },
202                RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
203            ));
204        }
205        Ok(None)
206    }
207
208    /// Tries to receive the initial response message that has `:event-type` of `initial-response`.
209    /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
210    #[doc(hidden)]
211    pub async fn try_recv_initial(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
212        if let Some(message) = self.next_message().await? {
213            if let Some(event_type) = message
214                .headers()
215                .iter()
216                .find(|h| h.name().as_str() == ":event-type")
217            {
218                if event_type
219                    .value()
220                    .as_string()
221                    .map(|s| s.as_str() == "initial-response")
222                    .unwrap_or(false)
223                {
224                    return Ok(Some(message));
225                }
226            } else {
227                // Buffer the message so that it can be returned by the next call to `recv()`
228                self.buffered_message = Some(message);
229            }
230        }
231        Ok(None)
232    }
233
234    /// Asynchronously tries to receive a message from the stream. If the stream has ended,
235    /// it returns an `Ok(None)`. If there is a transport layer error, it will return
236    /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
237    /// messages.
238    pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
239        if let Some(buffered) = self.buffered_message.take() {
240            return match self.unmarshall(buffered) {
241                Ok(message) => Ok(message),
242                Err(error) => {
243                    self.buffer = RecvBuf::Terminated;
244                    Err(error)
245                }
246            };
247        }
248        if let Some(message) = self.next_message().await? {
249            match self.unmarshall(message) {
250                Ok(message) => Ok(message),
251                Err(error) => {
252                    self.buffer = RecvBuf::Terminated;
253                    Err(error)
254                }
255            }
256        } else {
257            Ok(None)
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::{Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use hyper::body::Body;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    /// The receiver type exercised by every test in this module.
    type TestReceiver = Receiver<TestMessage, EventStreamError>;

    /// Serializes `message` into its event stream wire representation.
    fn encode(message: &Message) -> Bytes {
        let mut encoded = Vec::new();
        write_message_to(message, &mut encoded).unwrap();
        encoded.into()
    }

    /// Encodes an empty message whose `:event-type` header is `initial-response`.
    fn encode_initial_response() -> Bytes {
        let message_type = Header::new(":message-type", HeaderValue::String("event".into()));
        let event_type = Header::new(
            ":event-type",
            HeaderValue::String("initial-response".into()),
        );
        encode(
            &Message::new(Bytes::new())
                .add_header(message_type)
                .add_header(event_type),
        )
    }

    /// Encodes a header-less message whose payload is the given text.
    fn encode_message(message: &str) -> Bytes {
        encode(&Message::new(Bytes::copy_from_slice(message.as_bytes())))
    }

    /// Builds a receiver whose body yields `chunks` in order.
    fn receiver_for(chunks: Vec<Result<Bytes, IOError>>) -> TestReceiver {
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        Receiver::new(Unmarshaller, body)
    }

    /// Asserts that the next received message carries `payload`.
    async fn expect_message(receiver: &mut TestReceiver, payload: &str) {
        assert_eq!(
            TestMessage(payload.into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    /// Asserts that the receiver reports a clean end of stream.
    async fn expect_end_of_stream(receiver: &mut TestReceiver) {
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[derive(Debug)]
    struct FakeError;
    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "FakeError")
        }
    }
    impl StdError for FakeError {}

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Unmarshaller that turns every message payload into a `TestMessage`.
    #[derive(Debug)]
    struct Unmarshaller;
    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            let text = std::str::from_utf8(&message.payload()[..]).unwrap();
            Ok(UnmarshalledMessage::Event(TestMessage(text.into())))
        }
    }

    #[tokio::test]
    async fn receive_success() {
        let mut receiver = receiver_for(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        expect_message(&mut receiver, "one").await;
        expect_message(&mut receiver, "two").await;
        expect_end_of_stream(&mut receiver).await;
    }

    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ]);
        expect_message(&mut receiver, "one").await;
        expect_message(&mut receiver, "two").await;
        expect_end_of_stream(&mut receiver).await;
    }

    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ]);
        expect_message(&mut receiver, "one").await;
        expect_message(&mut receiver, "two").await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from(
                [encode_message("three"), encode_message("four")].concat(),
            )),
        ]);
        for payload in ["one", "two", "three", "four"].iter() {
            expect_message(&mut receiver, *payload).await;
        }
        expect_end_of_stream(&mut receiver).await;
    }

    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let payloads = ["one", "two", "three", "four", "five", "six", "seven", "eight"];
            let encoded: Vec<Bytes> = payloads.iter().map(|payload| encode_message(payload)).collect();
            let combined = Bytes::from(encoded.concat());

            // One split point lands in each half of the combined bytes.
            let midpoint = combined.len() / 2;
            let start = 0;
            let boundary1 = b1 % midpoint;
            let boundary2 = midpoint + b2 % midpoint;
            let end = combined.len();
            println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end);

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let mut receiver = receiver_for(vec![
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                ]);
                for payload in payloads.iter() {
                    expect_message(&mut receiver, *payload).await;
                }
                expect_end_of_stream(&mut receiver).await;
            });
        }
    }

    #[tokio::test]
    async fn receive_network_failure() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ]);
        expect_message(&mut receiver, "one").await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    #[tokio::test]
    async fn receive_message_parse_failure() {
        let mut receiver = receiver_for(vec![
            Ok(encode_message("one")),
            // A zero length message will be invalid. We need to provide a minimum of 12 bytes
            // for the MessageFrameDecoder to actually start parsing it.
            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
        ]);
        expect_message(&mut receiver, "one").await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    #[tokio::test]
    async fn receive_initial_response() {
        let mut receiver = receiver_for(vec![
            Ok(encode_initial_response()),
            Ok(encode_message("one")),
        ]);
        assert!(receiver.try_recv_initial().await.unwrap().is_some());
        expect_message(&mut receiver, "one").await;
    }

    #[tokio::test]
    async fn receive_no_initial_response() {
        let mut receiver = receiver_for(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        assert!(receiver.try_recv_initial().await.unwrap().is_none());
        expect_message(&mut receiver, "one").await;
        expect_message(&mut receiver, "two").await;
    }

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn assert_send_and_sync<T: Send + Sync>() {}

    #[tokio::test]
    async fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}