// aws_runtime/content_encoding.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
6use bytes::{Bytes, BytesMut};
7use http_02x::{HeaderMap, HeaderValue};
8use http_body_04x::{Body, SizeHint};
9use pin_project_lite::pin_project;
10
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
/// Line terminator used throughout the aws-chunked encoding.
const CRLF: &str = "\r\n";
/// Size line of the zero-length last chunk, which ends the chunked payload.
const CHUNK_TERMINATOR: &str = "0\r\n";
/// Separator written between a trailer's name and its value (no surrounding whitespace).
const TRAILER_SEPARATOR: &[u8] = b":";
17
/// Constants for `Content-Encoding` header values.
pub mod header_value {
    /// The `aws-chunked` content encoding, the format produced by `AwsChunkedBody`.
    pub const AWS_CHUNKED: &str = "aws-chunked";
}
23
24/// Options used when constructing an [`AwsChunkedBody`].
25#[derive(Debug, Default)]
26#[non_exhaustive]
27pub struct AwsChunkedBodyOptions {
28    /// The total size of the stream. Because we only support unsigned encoding
29    /// this implies that there will only be a single chunk containing the
30    /// underlying payload.
31    stream_length: u64,
32    /// The length of each trailer sent within an `AwsChunkedBody`. Necessary in
33    /// order to correctly calculate the total size of the body accurately.
34    trailer_lengths: Vec<u64>,
35}
36
37impl AwsChunkedBodyOptions {
38    /// Create a new [`AwsChunkedBodyOptions`].
39    pub fn new(stream_length: u64, trailer_lengths: Vec<u64>) -> Self {
40        Self {
41            stream_length,
42            trailer_lengths,
43        }
44    }
45
46    fn total_trailer_length(&self) -> u64 {
47        self.trailer_lengths.iter().sum::<u64>()
48            // We need to account for a CRLF after each trailer name/value pair
49            + (self.trailer_lengths.len() * CRLF.len()) as u64
50    }
51
52    /// Set a trailer len
53    pub fn with_trailer_len(mut self, trailer_len: u64) -> Self {
54        self.trailer_lengths.push(trailer_len);
55        self
56    }
57}
58
/// Steps of the `AwsChunkedBody` encoder; `poll_data` advances through them in order.
#[derive(Debug, PartialEq, Eq)]
enum AwsChunkedBodyState {
    /// Write the hexadecimal size line of the single data chunk, then transition into the
    /// `WritingChunk` state. When the stream length is zero there is no data chunk: the chunk
    /// terminator is written instead and the state moves straight to `WritingTrailers`.
    WritingChunkSize,
    /// Forward the inner body's data. Multiple polls of the inner body may need to occur before
    /// all data is written out. Once the inner body is exhausted, write the CRLF that closes the
    /// chunk together with the chunk terminator, then transition into `WritingTrailers`.
    WritingChunk,
    /// Write out all trailers of the inner body followed by the final CRLF that ends the
    /// encoded body, then transition into the `Closed` state.
    WritingTrailers,
    /// Final state: everything has already been written, so polling yields no more data.
    Closed,
}
74
75pin_project! {
76    /// A request body compatible with `Content-Encoding: aws-chunked`. This implementation is only
77    /// capable of writing a single chunk and does not support signed chunks.
78    ///
79    /// Chunked-Body grammar is defined in [ABNF] as:
80    ///
81    /// ```txt
82    /// Chunked-Body    = *chunk
83    ///                   last-chunk
84    ///                   chunked-trailer
85    ///                   CRLF
86    ///
87    /// chunk           = chunk-size CRLF chunk-data CRLF
88    /// chunk-size      = 1*HEXDIG
89    /// last-chunk      = 1*("0") CRLF
90    /// chunked-trailer = *( entity-header CRLF )
91    /// entity-header   = field-name ":" OWS field-value OWS
92    /// ```
93    /// For more info on what the abbreviations mean, see https://datatracker.ietf.org/doc/html/rfc7230#section-1.2
94    ///
95    /// [ABNF]:https://en.wikipedia.org/wiki/Augmented_Backus%E2%80%93Naur_form
96    #[derive(Debug)]
97    pub struct AwsChunkedBody<InnerBody> {
98        #[pin]
99        inner: InnerBody,
100        #[pin]
101        state: AwsChunkedBodyState,
102        options: AwsChunkedBodyOptions,
103        inner_body_bytes_read_so_far: usize,
104    }
105}
106
107impl<Inner> AwsChunkedBody<Inner> {
108    /// Wrap the given body in an outer body compatible with `Content-Encoding: aws-chunked`
109    pub fn new(body: Inner, options: AwsChunkedBodyOptions) -> Self {
110        Self {
111            inner: body,
112            state: AwsChunkedBodyState::WritingChunkSize,
113            options,
114            inner_body_bytes_read_so_far: 0,
115        }
116    }
117
118    fn encoded_length(&self) -> u64 {
119        let mut length = 0;
120        if self.options.stream_length != 0 {
121            length += get_unsigned_chunk_bytes_length(self.options.stream_length);
122        }
123
124        // End chunk
125        length += CHUNK_TERMINATOR.len() as u64;
126
127        // Trailers
128        for len in self.options.trailer_lengths.iter() {
129            length += len + CRLF.len() as u64;
130        }
131
132        // Encoding terminator
133        length += CRLF.len() as u64;
134
135        length
136    }
137}
138
139fn get_unsigned_chunk_bytes_length(payload_length: u64) -> u64 {
140    let hex_repr_len = int_log16(payload_length);
141    hex_repr_len + CRLF.len() as u64 + payload_length + CRLF.len() as u64
142}
143
144/// Writes trailers out into a `string` and then converts that `String` to a `Bytes` before
145/// returning.
146///
147/// - Trailer names are separated by a single colon only, no space.
148/// - Trailer names with multiple values will be written out one line per value, with the name
149///   appearing on each line.
150fn trailers_as_aws_chunked_bytes(
151    trailer_map: Option<HeaderMap>,
152    estimated_length: u64,
153) -> BytesMut {
154    if let Some(trailer_map) = trailer_map {
155        let mut current_header_name = None;
156        let mut trailers = BytesMut::with_capacity(estimated_length.try_into().unwrap_or_default());
157
158        for (header_name, header_value) in trailer_map.into_iter() {
159            // When a header has multiple values, the name only comes up in iteration the first time
160            // we see it. Therefore, we need to keep track of the last name we saw and fall back to
161            // it when `header_name == None`.
162            current_header_name = header_name.or(current_header_name);
163
164            // In practice, this will always exist, but `if let` is nicer than unwrap
165            if let Some(header_name) = current_header_name.as_ref() {
166                trailers.extend_from_slice(header_name.as_ref());
167                trailers.extend_from_slice(TRAILER_SEPARATOR);
168                trailers.extend_from_slice(header_value.as_bytes());
169                trailers.extend_from_slice(CRLF.as_bytes());
170            }
171        }
172
173        trailers
174    } else {
175        BytesMut::new()
176    }
177}
178
179/// Given an optional `HeaderMap`, calculate the total number of bytes required to represent the
180/// `HeaderMap`. If no `HeaderMap` is given as input, return 0.
181///
182/// - Trailer names are separated by a single colon only, no space.
183/// - Trailer names with multiple values will be written out one line per value, with the name
184///   appearing on each line.
185fn total_rendered_length_of_trailers(trailer_map: Option<&HeaderMap>) -> u64 {
186    match trailer_map {
187        Some(trailer_map) => trailer_map
188            .iter()
189            .map(|(trailer_name, trailer_value)| {
190                trailer_name.as_str().len()
191                    + TRAILER_SEPARATOR.len()
192                    + trailer_value.len()
193                    + CRLF.len()
194            })
195            .sum::<usize>() as u64,
196        None => 0,
197    }
198}
199
200impl<Inner> Body for AwsChunkedBody<Inner>
201where
202    Inner: Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
203{
204    type Data = Bytes;
205    type Error = aws_smithy_types::body::Error;
206
207    fn poll_data(
208        self: Pin<&mut Self>,
209        cx: &mut Context<'_>,
210    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
211        tracing::trace!(state = ?self.state, "polling AwsChunkedBody");
212        let mut this = self.project();
213
214        match *this.state {
215            AwsChunkedBodyState::WritingChunkSize => {
216                if this.options.stream_length == 0 {
217                    // If the stream is empty, we skip to writing trailers after writing the CHUNK_TERMINATOR.
218                    *this.state = AwsChunkedBodyState::WritingTrailers;
219                    tracing::trace!("stream is empty, writing chunk terminator");
220                    Poll::Ready(Some(Ok(Bytes::from([CHUNK_TERMINATOR].concat()))))
221                } else {
222                    *this.state = AwsChunkedBodyState::WritingChunk;
223                    // A chunk must be prefixed by chunk size in hexadecimal
224                    let chunk_size = format!("{:X?}{CRLF}", this.options.stream_length);
225                    tracing::trace!(%chunk_size, "writing chunk size");
226                    let chunk_size = Bytes::from(chunk_size);
227                    Poll::Ready(Some(Ok(chunk_size)))
228                }
229            }
230            AwsChunkedBodyState::WritingChunk => match this.inner.poll_data(cx) {
231                Poll::Ready(Some(Ok(data))) => {
232                    tracing::trace!(len = data.len(), "writing chunk data");
233                    *this.inner_body_bytes_read_so_far += data.len();
234                    Poll::Ready(Some(Ok(data)))
235                }
236                Poll::Ready(None) => {
237                    let actual_stream_length = *this.inner_body_bytes_read_so_far as u64;
238                    let expected_stream_length = this.options.stream_length;
239                    if actual_stream_length != expected_stream_length {
240                        let err = Box::new(AwsChunkedBodyError::StreamLengthMismatch {
241                            actual: actual_stream_length,
242                            expected: expected_stream_length,
243                        });
244                        return Poll::Ready(Some(Err(err)));
245                    };
246
247                    tracing::trace!("no more chunk data, writing CRLF and chunk terminator");
248                    *this.state = AwsChunkedBodyState::WritingTrailers;
249                    // Since we wrote chunk data, we end it with a CRLF and since we only write
250                    // a single chunk, we write the CHUNK_TERMINATOR immediately after
251                    Poll::Ready(Some(Ok(Bytes::from([CRLF, CHUNK_TERMINATOR].concat()))))
252                }
253                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
254                Poll::Pending => Poll::Pending,
255            },
256            AwsChunkedBodyState::WritingTrailers => {
257                return match this.inner.poll_trailers(cx) {
258                    Poll::Ready(Ok(trailers)) => {
259                        *this.state = AwsChunkedBodyState::Closed;
260                        let expected_length = total_rendered_length_of_trailers(trailers.as_ref());
261                        let actual_length = this.options.total_trailer_length();
262
263                        if expected_length != actual_length {
264                            let err =
265                                Box::new(AwsChunkedBodyError::ReportedTrailerLengthMismatch {
266                                    actual: actual_length,
267                                    expected: expected_length,
268                                });
269                            return Poll::Ready(Some(Err(err)));
270                        }
271
272                        let mut trailers =
273                            trailers_as_aws_chunked_bytes(trailers, actual_length + 1);
274                        // Insert the final CRLF to close the body
275                        trailers.extend_from_slice(CRLF.as_bytes());
276
277                        Poll::Ready(Some(Ok(trailers.into())))
278                    }
279                    Poll::Pending => Poll::Pending,
280                    Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
281                };
282            }
283            AwsChunkedBodyState::Closed => Poll::Ready(None),
284        }
285    }
286
287    fn poll_trailers(
288        self: Pin<&mut Self>,
289        _cx: &mut Context<'_>,
290    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
291        // Trailers were already appended to the body because of the content encoding scheme
292        Poll::Ready(Ok(None))
293    }
294
295    fn is_end_stream(&self) -> bool {
296        self.state == AwsChunkedBodyState::Closed
297    }
298
299    fn size_hint(&self) -> SizeHint {
300        SizeHint::with_exact(self.encoded_length())
301    }
302}
303
/// Errors related to `AwsChunkedBody`
#[derive(Debug)]
enum AwsChunkedBodyError {
    /// Error that occurs when the sum of `trailer_lengths` set when creating an `AwsChunkedBody` is
    /// not equal to the actual length of the trailers returned by the inner `http_body::Body`
    /// implementor. These trailer lengths are necessary in order to correctly calculate the total
    /// size of the body for setting the content length header.
    ///
    /// `actual` holds the trailer length reported through `AwsChunkedBodyOptions` (including one
    /// CRLF per trailer line); `expected` holds the length the inner body's trailers actually
    /// render to.
    ReportedTrailerLengthMismatch { actual: u64, expected: u64 },
    /// Error that occurs when the `stream_length` set when creating an `AwsChunkedBody` is not
    /// equal to the actual length of the body returned by the inner `http_body::Body` implementor.
    /// `stream_length` must be correct in order to set an accurate content length header.
    ///
    /// `actual` is the number of bytes read from the inner body; `expected` is the reported
    /// `stream_length`.
    StreamLengthMismatch { actual: u64, expected: u64 },
}

impl std::fmt::Display for AwsChunkedBodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // For this variant `actual` is the *reported* length and `expected` the *rendered*
            // one, so bind them by meaning to keep the message the right way round.
            Self::ReportedTrailerLengthMismatch {
                actual: reported,
                expected: rendered,
            } => {
                write!(f, "When creating this AwsChunkedBody, length of trailers was reported as {reported}. However, when double checking during trailer encoding, length was found to be {rendered} instead.")
            }
            Self::StreamLengthMismatch { actual, expected } => {
                write!(f, "When creating this AwsChunkedBody, stream length was reported as {expected}. However, when double checking during body encoding, length was found to be {actual} instead.")
            }
        }
    }
}

impl std::error::Error for AwsChunkedBodyError {}
332
/// Counts how many hexadecimal digits are needed to write `value`, by repeatedly dividing it by
/// 16 for as long as the quotient stays greater than zero.
///
/// Note that an input of 0 (or anything not greater than zero) yields 0, not 1.
fn int_log16<T>(value: T) -> u64
where
    T: std::ops::DivAssign + PartialOrd + From<u8> + Copy,
{
    let zero = T::from(0u8);
    let sixteen = T::from(16u8);

    // value, value / 16, value / 256, ... — each positive quotient is one more hex digit.
    let quotients = std::iter::successors(Some(value), |&previous| {
        let mut next = previous;
        next /= sixteen;
        Some(next)
    });

    quotients.take_while(|quotient| *quotient > zero).count() as u64
}
349
#[cfg(test)]
mod tests {
    use super::{
        total_rendered_length_of_trailers, trailers_as_aws_chunked_bytes, AwsChunkedBody,
        AwsChunkedBodyOptions, CHUNK_TERMINATOR, CRLF,
    };

    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_02x::{HeaderMap, HeaderValue};
    use http_body_04x::{Body, SizeHint};
    use pin_project_lite::pin_project;

    use std::io::Read;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    pin_project! {
        /// Test body that yields one part per poll. Each `None` part makes it return
        /// `Poll::Pending` and schedule a wake-up after `delay_in_millis`, simulating a slow
        /// stream.
        struct SputteringBody {
            parts: Vec<Option<Bytes>>,
            cursor: usize,
            delay_in_millis: u64,
        }
    }

    impl SputteringBody {
        fn len(&self) -> usize {
            self.parts.iter().flatten().map(|b| b.len()).sum()
        }
    }

    impl Body for SputteringBody {
        type Data = Bytes;
        type Error = aws_smithy_types::body::Error;

        fn poll_data(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            if self.cursor == self.parts.len() {
                return Poll::Ready(None);
            }

            let this = self.project();
            let delay_in_millis = *this.delay_in_millis;
            let next_part = this.parts.get_mut(*this.cursor).unwrap().take();

            match next_part {
                None => {
                    *this.cursor += 1;
                    let waker = cx.waker().clone();
                    tokio::spawn(async move {
                        tokio::time::sleep(Duration::from_millis(delay_in_millis)).await;
                        waker.wake();
                    });
                    Poll::Pending
                }
                Some(data) => {
                    *this.cursor += 1;
                    Poll::Ready(Some(Ok(data)))
                }
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
            Poll::Ready(Ok(None))
        }

        fn is_end_stream(&self) -> bool {
            false
        }

        fn size_hint(&self) -> SizeHint {
            SizeHint::new()
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding() {
        let test_fut = async {
            let input_str = "Hello world";
            let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output = "B\r\nHello world\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );

            // You can insert a `tokio::time::sleep` here to verify the timeout works as intended
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!("test_aws_chunked_encoding timed out after {timeout_duration:?}");
        }
    }

    #[test]
    fn test_aws_chunked_encoding_size_hint() {
        let input_str = "Hello world";
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
        let body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        // "B\r\nHello world\r\n0\r\n\r\n" is 21 bytes long
        assert_eq!(Some(21), body.size_hint().exact());
    }

    #[test]
    fn test_aws_chunked_encoding_empty_body_size_hint() {
        let opts = AwsChunkedBodyOptions::new(0, Vec::new());
        let body = AwsChunkedBody::new(SdkBody::from(""), opts);

        // Only the chunk terminator and the final CRLF are written for an empty body.
        let expected_len = (CHUNK_TERMINATOR.len() + CRLF.len()) as u64;
        assert_eq!(Some(expected_len), body.size_hint().exact());
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_sputtering_body() {
        let test_fut = async {
            let input = SputteringBody {
                parts: vec![
                    Some(Bytes::from_static(b"chunk 1, ")),
                    None,
                    Some(Bytes::from_static(b"chunk 2, ")),
                    Some(Bytes::from_static(b"chunk 3, ")),
                    None,
                    None,
                    Some(Bytes::from_static(b"chunk 4, ")),
                    Some(Bytes::from_static(b"chunk 5, ")),
                    Some(Bytes::from_static(b"chunk 6")),
                ],
                cursor: 0,
                delay_in_millis: 500,
            };
            let opts = AwsChunkedBodyOptions::new(input.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(input, opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output =
                "34\r\nchunk 1, chunk 2, chunk 3, chunk 4, chunk 5, chunk 6\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!(
                "test_aws_chunked_encoding_sputtering_body timed out after {timeout_duration:?}"
            );
        }
    }

    #[tokio::test]
    #[should_panic = "called `Result::unwrap()` on an `Err` value: ReportedTrailerLengthMismatch { actual: 44, expected: 0 }"]
    async fn test_aws_chunked_encoding_incorrect_trailer_length_panic() {
        let input_str = "Hello world";
        // The test body has no trailers, so this reported length is wrong and triggers the panic.
        // The reported length becomes 44 because, with aws-chunked encoding, every trailer line
        // ends with a 2-byte CRLF; the trailers actually render to 0 bytes.
        let wrong_trailer_len = 42;
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, vec![wrong_trailer_len]);
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        // We don't care about the body contents but we have to read it all before checking for trailers
        while let Some(buf) = body.data().await {
            drop(buf.unwrap());
        }

        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    #[tokio::test]
    #[should_panic = "called `Result::unwrap()` on an `Err` value: StreamLengthMismatch { actual: 11, expected: 5 }"]
    async fn test_aws_chunked_encoding_incorrect_stream_length_panic() {
        let input_str = "Hello world";
        // The body is 11 bytes long, so reporting 5 must be rejected once the body is exhausted.
        let opts = AwsChunkedBodyOptions::new(5, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        while let Some(buf) = body.data().await {
            drop(buf.unwrap());
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_empty_body() {
        let input_str = "";
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let mut output = SegmentedBuf::new();
        while let Some(buf) = body.data().await {
            output.push(buf.unwrap());
        }

        let mut actual_output = String::new();
        output
            .reader()
            .read_to_string(&mut actual_output)
            .expect("Doesn't cause IO errors");

        let expected_output = [CHUNK_TERMINATOR, CRLF].concat();

        assert_eq!(expected_output, actual_output);
        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    #[test]
    fn test_total_rendered_length_of_trailers() {
        let mut trailers = HeaderMap::new();

        trailers.insert("empty_value", HeaderValue::from_static(""));

        trailers.insert("single_value", HeaderValue::from_static("value 1"));

        trailers.insert("two_values", HeaderValue::from_static("value 1"));
        trailers.append("two_values", HeaderValue::from_static("value 2"));

        trailers.insert("three_values", HeaderValue::from_static("value 1"));
        trailers.append("three_values", HeaderValue::from_static("value 2"));
        trailers.append("three_values", HeaderValue::from_static("value 3"));

        let trailers = Some(trailers);
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }

    #[test]
    fn test_total_rendered_length_of_empty_trailers() {
        let trailers = Some(HeaderMap::new());
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }
}