1use aws_smithy_eventstream::frame::{
7 DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21#[derive(Debug)]
23enum RecvBuf {
24 Empty,
26 Partial(SegmentedBuf<Bytes>),
29 EosPartial(SegmentedBuf<Bytes>),
31 Terminated,
33}
34
35impl RecvBuf {
36 fn has_data(&self) -> bool {
38 match self {
39 RecvBuf::Empty | RecvBuf::Terminated => false,
40 RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41 }
42 }
43
44 fn is_eos(&self) -> bool {
46 matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47 }
48
49 fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51 match self {
52 RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53 RecvBuf::Partial(segmented) => segmented,
54 RecvBuf::EosPartial(segmented) => segmented,
55 RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56 }
57 }
58
59 fn with_partial(self, partial: Bytes) -> Self {
62 match self {
63 RecvBuf::Empty => {
64 let mut segmented = SegmentedBuf::new();
65 segmented.push(partial);
66 RecvBuf::Partial(segmented)
67 }
68 RecvBuf::Partial(mut segmented) => {
69 segmented.push(partial);
70 RecvBuf::Partial(segmented)
71 }
72 RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73 panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74 }
75 }
76 }
77
78 fn ended(self) -> Self {
80 match self {
81 RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82 RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83 RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84 RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85 }
86 }
87}
88
/// The specific reason a [`ReceiverError`] was produced.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The body ended while bytes of an incomplete message were still buffered.
    UnexpectedEndOfStream,
}

/// Error raised by the event stream receiver itself (as opposed to decoding or
/// unmarshalling errors).
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
111#[derive(Debug)]
113pub struct Receiver<T, E> {
114 unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
115 decoder: MessageFrameDecoder,
116 buffer: RecvBuf,
117 body: SdkBody,
118 buffered_message: Option<Message>,
123 _phantom: PhantomData<E>,
124}
125
126impl<T, E> Receiver<T, E> {
127 pub fn new(
129 unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
130 body: SdkBody,
131 ) -> Self {
132 Receiver {
133 unmarshaller: Box::new(unmarshaller),
134 decoder: MessageFrameDecoder::new(),
135 buffer: RecvBuf::Empty,
136 body,
137 buffered_message: None,
138 _phantom: Default::default(),
139 }
140 }
141
142 fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
143 match self.unmarshaller.unmarshall(&message) {
144 Ok(unmarshalled) => match unmarshalled {
145 UnmarshalledMessage::Event(event) => Ok(Some(event)),
146 UnmarshalledMessage::Error(err) => {
147 Err(SdkError::service_error(err, RawMessage::Decoded(message)))
148 }
149 },
150 Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
151 }
152 }
153
154 async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
155 use http_body_04x::Body;
156
157 if !self.buffer.is_eos() {
158 let next_chunk = self
159 .body
160 .data()
161 .await
162 .transpose()
163 .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
164 let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
165 if let Some(chunk) = next_chunk {
166 self.buffer = buffer.with_partial(chunk);
167 } else {
168 self.buffer = buffer.ended();
169 }
170 }
171 Ok(())
172 }
173
174 async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
175 while !self.buffer.is_eos() {
176 if self.buffer.has_data() {
177 if let DecodedFrame::Complete(message) = self
178 .decoder
179 .decode_frame(self.buffer.buffered())
180 .map_err(|err| {
181 SdkError::response_error(
182 err,
183 RawMessage::Invalid(None),
185 )
186 })?
187 {
188 trace!(message = ?message, "received complete event stream message");
189 return Ok(Some(message));
190 }
191 }
192
193 self.buffer_next_chunk().await?;
194 }
195 if self.buffer.has_data() {
196 trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
197 let buf = self.buffer.buffered();
198 return Err(SdkError::response_error(
199 ReceiverError {
200 kind: ReceiverErrorKind::UnexpectedEndOfStream,
201 },
202 RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
203 ));
204 }
205 Ok(None)
206 }
207
208 #[doc(hidden)]
211 pub async fn try_recv_initial(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
212 if let Some(message) = self.next_message().await? {
213 if let Some(event_type) = message
214 .headers()
215 .iter()
216 .find(|h| h.name().as_str() == ":event-type")
217 {
218 if event_type
219 .value()
220 .as_string()
221 .map(|s| s.as_str() == "initial-response")
222 .unwrap_or(false)
223 {
224 return Ok(Some(message));
225 }
226 } else {
227 self.buffered_message = Some(message);
229 }
230 }
231 Ok(None)
232 }
233
234 pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
239 if let Some(buffered) = self.buffered_message.take() {
240 return match self.unmarshall(buffered) {
241 Ok(message) => Ok(message),
242 Err(error) => {
243 self.buffer = RecvBuf::Terminated;
244 Err(error)
245 }
246 };
247 }
248 if let Some(message) = self.next_message().await? {
249 match self.unmarshall(message) {
250 Ok(message) => Ok(message),
251 Err(error) => {
252 self.buffer = RecvBuf::Terminated;
253 Err(error)
254 }
255 }
256 } else {
257 Ok(None)
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::{Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use hyper::body::Body;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    // Encodes an empty-payload message with `:message-type` = `event` and
    // `:event-type` = `initial-response`.
    fn encode_initial_response() -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::new())
            .add_header(Header::new(
                ":message-type",
                HeaderValue::String("event".into()),
            ))
            .add_header(Header::new(
                ":event-type",
                HeaderValue::String("initial-response".into()),
            ));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    // Encodes a message with no headers whose payload is the UTF-8 bytes of `message`.
    fn encode_message(message: &str) -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::copy_from_slice(message.as_bytes()));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    // Inner error used to build a simulated I/O failure from the body stream.
    #[derive(Debug)]
    struct FakeError;
    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "FakeError")
        }
    }
    impl StdError for FakeError {}

    // Event produced by the test unmarshaller: the message payload as a `String`.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    // Unmarshaller that turns every message into a `TestMessage`. It never reports a
    // modeled error, and it panics (via `unwrap`) if the payload is not valid UTF-8.
    #[derive(Debug)]
    struct Unmarshaller;
    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            Ok(UnmarshalledMessage::Event(TestMessage(
                std::str::from_utf8(&message.payload()[..]).unwrap().into(),
            )))
        }
    }

    // Two complete messages, one per chunk, followed by a clean end of stream.
    #[tokio::test]
    async fn receive_success() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_message("one")), Ok(encode_message("two"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    // A trailing empty chunk must not be reported as leftover data at end of stream.
    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    // The final chunk holds only the first 10 bytes of a message, so the body ends
    // mid-message and `recv` must return a response error.
    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    // A single chunk carrying two complete messages must yield both, in order.
    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from(
                [encode_message("three"), encode_message("four")].concat(),
            )),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("three".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("four".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    // Property test: eight concatenated messages are split into three chunks at
    // arbitrary boundaries. The first boundary lies in the first half and the second in
    // the second half; the first chunk may be empty. Every message must still be
    // decoded in order, followed by a clean end of stream.
    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let combined = Bytes::from([
                encode_message("one"),
                encode_message("two"),
                encode_message("three"),
                encode_message("four"),
                encode_message("five"),
                encode_message("six"),
                encode_message("seven"),
                encode_message("eight"),
            ].concat());

            let midpoint = combined.len() / 2;
            let (start, boundary1, boundary2, end) = (
                0,
                b1 % midpoint,
                midpoint + b2 % midpoint,
                combined.len()
            );
            println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end);

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let chunks: Vec<Result<_, IOError>> = vec![
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                ];

                let chunk_stream = futures_util::stream::iter(chunks);
                let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
                let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
                for payload in &["one", "two", "three", "four", "five", "six", "seven", "eight"] {
                    assert_eq!(
                        TestMessage((*payload).into()),
                        receiver.recv().await.unwrap().unwrap()
                    );
                }
                assert_eq!(None, receiver.recv().await.unwrap());
            });
        }
    }

    // An I/O error from the body after the first message surfaces as `DispatchFailure`.
    #[tokio::test]
    async fn receive_network_failure() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    // Twelve zero bytes are not a valid message. `recv` must return a response error
    // in either case: the decoder rejects them, or they are left over at end of stream.
    #[tokio::test]
    async fn receive_message_parse_failure() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    // An `initial-response` message is returned by `try_recv_initial`, and the
    // following event is still delivered by `recv`.
    #[tokio::test]
    async fn receive_initial_response() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_initial_response()), Ok(encode_message("one"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert!(receiver.try_recv_initial().await.unwrap().is_some());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    // The first message has no `:event-type` header, so `try_recv_initial` returns
    // `None` and that message is still delivered by the next `recv`.
    #[tokio::test]
    async fn receive_no_initial_response() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_message("one")), Ok(encode_message("two"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert!(receiver.try_recv_initial().await.unwrap().is_none());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    // Compile-time check: only compiles when `T: Send + Sync`.
    fn assert_send_and_sync<T: Send + Sync>() {}

    #[tokio::test]
    async fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}