1use bytes::{Bytes, BytesMut};
7use http_02x::{HeaderMap, HeaderValue};
8use http_body_04x::{Body, SizeHint};
9use pin_project_lite::pin_project;
10
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
/// Line terminator used between every element of the `aws-chunked` encoding.
const CRLF: &str = "\r\n";
/// Size line of the zero-length chunk that ends the chunk data.
const CHUNK_TERMINATOR: &str = "0\r\n";
/// Byte(s) placed between a trailer's name and its value when rendered.
const TRAILER_SEPARATOR: &[u8] = &[b':'];
17
/// Header values associated with the `aws-chunked` encoding.
pub mod header_value {
    /// Header value naming the `aws-chunked` encoding.
    pub const AWS_CHUNKED: &str = "aws-chunked";
}
23
/// Lengths that an [`AwsChunkedBody`] is told about up front: the size of the wrapped stream and
/// the size of each trailer that follows it. Both are re-checked while the body is encoded.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct AwsChunkedBodyOptions {
    /// Number of data bytes the wrapped body is expected to yield.
    stream_length: u64,
    /// Length of each trailer line. Each entry is expected to exclude the trailing CRLF, which is
    /// accounted for separately when the total trailer length is computed.
    trailer_lengths: Vec<u64>,
}
36
37impl AwsChunkedBodyOptions {
38 pub fn new(stream_length: u64, trailer_lengths: Vec<u64>) -> Self {
40 Self {
41 stream_length,
42 trailer_lengths,
43 }
44 }
45
46 fn total_trailer_length(&self) -> u64 {
47 self.trailer_lengths.iter().sum::<u64>()
48 + (self.trailer_lengths.len() * CRLF.len()) as u64
50 }
51
52 pub fn with_trailer_len(mut self, trailer_len: u64) -> Self {
54 self.trailer_lengths.push(trailer_len);
55 self
56 }
57}
58
/// Position of an [`AwsChunkedBody`] in its encoding sequence.
#[derive(Debug, PartialEq, Eq)]
enum AwsChunkedBodyState {
    /// Next frame is the hex chunk-size line (or the terminator for an empty stream).
    WritingChunkSize,
    /// Forwarding the inner body's data frames.
    WritingChunk,
    /// Next frame is the rendered trailers followed by the closing CRLF.
    WritingTrailers,
    /// Everything has been emitted; no more frames follow.
    Closed,
}
74
pin_project! {
    /// Wraps an inner body and emits it using the `aws-chunked` encoding.
    ///
    /// The whole inner body is written as a single chunk whose size is taken from
    /// [`AwsChunkedBodyOptions`]. The chunk is followed by the zero-length terminating chunk, the
    /// inner body's trailers rendered as `name:value` lines, and a final CRLF. No chunk
    /// signatures are written.
    #[derive(Debug)]
    pub struct AwsChunkedBody<InnerBody> {
        // The wrapped body whose data and trailers are encoded.
        #[pin]
        inner: InnerBody,
        // Current step of the encoding state machine advanced by `poll_data`.
        #[pin]
        state: AwsChunkedBodyState,
        // Stream and trailer lengths reported up front; checked against what is actually read.
        options: AwsChunkedBodyOptions,
        // Running total of data bytes taken from `inner`, compared with the reported stream
        // length once `inner` is exhausted.
        inner_body_bytes_read_so_far: usize,
    }
}
106
107impl<Inner> AwsChunkedBody<Inner> {
108 pub fn new(body: Inner, options: AwsChunkedBodyOptions) -> Self {
110 Self {
111 inner: body,
112 state: AwsChunkedBodyState::WritingChunkSize,
113 options,
114 inner_body_bytes_read_so_far: 0,
115 }
116 }
117
118 fn encoded_length(&self) -> u64 {
119 let mut length = 0;
120 if self.options.stream_length != 0 {
121 length += get_unsigned_chunk_bytes_length(self.options.stream_length);
122 }
123
124 length += CHUNK_TERMINATOR.len() as u64;
126
127 for len in self.options.trailer_lengths.iter() {
129 length += len + CRLF.len() as u64;
130 }
131
132 length += CRLF.len() as u64;
134
135 length
136 }
137}
138
139fn get_unsigned_chunk_bytes_length(payload_length: u64) -> u64 {
140 let hex_repr_len = int_log16(payload_length);
141 hex_repr_len + CRLF.len() as u64 + payload_length + CRLF.len() as u64
142}
143
144fn trailers_as_aws_chunked_bytes(
151 trailer_map: Option<HeaderMap>,
152 estimated_length: u64,
153) -> BytesMut {
154 if let Some(trailer_map) = trailer_map {
155 let mut current_header_name = None;
156 let mut trailers = BytesMut::with_capacity(estimated_length.try_into().unwrap_or_default());
157
158 for (header_name, header_value) in trailer_map.into_iter() {
159 current_header_name = header_name.or(current_header_name);
163
164 if let Some(header_name) = current_header_name.as_ref() {
166 trailers.extend_from_slice(header_name.as_ref());
167 trailers.extend_from_slice(TRAILER_SEPARATOR);
168 trailers.extend_from_slice(header_value.as_bytes());
169 trailers.extend_from_slice(CRLF.as_bytes());
170 }
171 }
172
173 trailers
174 } else {
175 BytesMut::new()
176 }
177}
178
179fn total_rendered_length_of_trailers(trailer_map: Option<&HeaderMap>) -> u64 {
186 match trailer_map {
187 Some(trailer_map) => trailer_map
188 .iter()
189 .map(|(trailer_name, trailer_value)| {
190 trailer_name.as_str().len()
191 + TRAILER_SEPARATOR.len()
192 + trailer_value.len()
193 + CRLF.len()
194 })
195 .sum::<usize>() as u64,
196 None => 0,
197 }
198}
199
200impl<Inner> Body for AwsChunkedBody<Inner>
201where
202 Inner: Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
203{
204 type Data = Bytes;
205 type Error = aws_smithy_types::body::Error;
206
207 fn poll_data(
208 self: Pin<&mut Self>,
209 cx: &mut Context<'_>,
210 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
211 tracing::trace!(state = ?self.state, "polling AwsChunkedBody");
212 let mut this = self.project();
213
214 match *this.state {
215 AwsChunkedBodyState::WritingChunkSize => {
216 if this.options.stream_length == 0 {
217 *this.state = AwsChunkedBodyState::WritingTrailers;
219 tracing::trace!("stream is empty, writing chunk terminator");
220 Poll::Ready(Some(Ok(Bytes::from([CHUNK_TERMINATOR].concat()))))
221 } else {
222 *this.state = AwsChunkedBodyState::WritingChunk;
223 let chunk_size = format!("{:X?}{CRLF}", this.options.stream_length);
225 tracing::trace!(%chunk_size, "writing chunk size");
226 let chunk_size = Bytes::from(chunk_size);
227 Poll::Ready(Some(Ok(chunk_size)))
228 }
229 }
230 AwsChunkedBodyState::WritingChunk => match this.inner.poll_data(cx) {
231 Poll::Ready(Some(Ok(data))) => {
232 tracing::trace!(len = data.len(), "writing chunk data");
233 *this.inner_body_bytes_read_so_far += data.len();
234 Poll::Ready(Some(Ok(data)))
235 }
236 Poll::Ready(None) => {
237 let actual_stream_length = *this.inner_body_bytes_read_so_far as u64;
238 let expected_stream_length = this.options.stream_length;
239 if actual_stream_length != expected_stream_length {
240 let err = Box::new(AwsChunkedBodyError::StreamLengthMismatch {
241 actual: actual_stream_length,
242 expected: expected_stream_length,
243 });
244 return Poll::Ready(Some(Err(err)));
245 };
246
247 tracing::trace!("no more chunk data, writing CRLF and chunk terminator");
248 *this.state = AwsChunkedBodyState::WritingTrailers;
249 Poll::Ready(Some(Ok(Bytes::from([CRLF, CHUNK_TERMINATOR].concat()))))
252 }
253 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
254 Poll::Pending => Poll::Pending,
255 },
256 AwsChunkedBodyState::WritingTrailers => {
257 return match this.inner.poll_trailers(cx) {
258 Poll::Ready(Ok(trailers)) => {
259 *this.state = AwsChunkedBodyState::Closed;
260 let expected_length = total_rendered_length_of_trailers(trailers.as_ref());
261 let actual_length = this.options.total_trailer_length();
262
263 if expected_length != actual_length {
264 let err =
265 Box::new(AwsChunkedBodyError::ReportedTrailerLengthMismatch {
266 actual: actual_length,
267 expected: expected_length,
268 });
269 return Poll::Ready(Some(Err(err)));
270 }
271
272 let mut trailers =
273 trailers_as_aws_chunked_bytes(trailers, actual_length + 1);
274 trailers.extend_from_slice(CRLF.as_bytes());
276
277 Poll::Ready(Some(Ok(trailers.into())))
278 }
279 Poll::Pending => Poll::Pending,
280 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
281 };
282 }
283 AwsChunkedBodyState::Closed => Poll::Ready(None),
284 }
285 }
286
287 fn poll_trailers(
288 self: Pin<&mut Self>,
289 _cx: &mut Context<'_>,
290 ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
291 Poll::Ready(Ok(None))
293 }
294
295 fn is_end_stream(&self) -> bool {
296 self.state == AwsChunkedBodyState::Closed
297 }
298
299 fn size_hint(&self) -> SizeHint {
300 SizeHint::with_exact(self.encoded_length())
301 }
302}
303
/// Errors produced while encoding an [`AwsChunkedBody`].
#[derive(Debug)]
enum AwsChunkedBodyError {
    /// The trailer lengths reported through [`AwsChunkedBodyOptions`] disagree with the length
    /// the trailers actually render to.
    ///
    /// `actual` is the total reported through the options (one CRLF per trailer included) and
    /// `expected` is the rendered length of the trailers.
    ReportedTrailerLengthMismatch { actual: u64, expected: u64 },
    /// The inner body yielded a different number of bytes than the reported stream length.
    ///
    /// `actual` is the number of bytes read from the inner body and `expected` is the stream
    /// length reported through [`AwsChunkedBodyOptions`].
    StreamLengthMismatch { actual: u64, expected: u64 },
}

impl std::fmt::Display for AwsChunkedBodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // For this variant `actual` holds the reported length and `expected` the rendered
            // one, so they are interpolated in that order.
            Self::ReportedTrailerLengthMismatch { actual, expected } => {
                write!(f, "When creating this AwsChunkedBody, length of trailers was reported as {actual}. However, when double checking during trailer encoding, length was found to be {expected} instead.")
            }
            Self::StreamLengthMismatch { actual, expected } => {
                write!(f, "When creating this AwsChunkedBody, stream length was reported as {expected}. However, when double checking during body encoding, length was found to be {actual} instead.")
            }
        }
    }
}

impl std::error::Error for AwsChunkedBodyError {}
332
/// Number of hexadecimal digits needed to write `i`, i.e. how many times it can be divided by 16
/// before reaching zero. Note that this returns `0` (not `1`) for an input of zero.
fn int_log16<T>(i: T) -> u64
where
    T: std::ops::DivAssign + PartialOrd + From<u8> + Copy,
{
    let zero = T::from(0);
    let sixteen = T::from(16);

    // Walk i, i/16, i/256, ... and count the terms that are still non-zero.
    std::iter::successors(Some(i), move |previous: &T| {
        let mut next = *previous;
        next /= sixteen;
        Some(next)
    })
    .take_while(|value| *value > zero)
    .count() as u64
}
349
#[cfg(test)]
mod tests {
    use super::{
        total_rendered_length_of_trailers, trailers_as_aws_chunked_bytes, AwsChunkedBody,
        AwsChunkedBodyOptions, CHUNK_TERMINATOR, CRLF,
    };

    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_02x::{HeaderMap, HeaderValue};
    use http_body_04x::{Body, SizeHint};
    use pin_project_lite::pin_project;

    use std::io::Read;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    pin_project! {
        // Yields `parts` in order; each `None` entry makes one poll return `Pending` and
        // schedules a wake-up after `delay_in_millis`.
        struct SputteringBody {
            parts: Vec<Option<Bytes>>,
            cursor: usize,
            delay_in_millis: u64,
        }
    }

    impl SputteringBody {
        /// Total number of data bytes this body will yield.
        fn len(&self) -> usize {
            self.parts.iter().flatten().map(|b| b.len()).sum()
        }
    }

    impl Body for SputteringBody {
        type Data = Bytes;
        type Error = aws_smithy_types::body::Error;

        fn poll_data(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            if self.cursor == self.parts.len() {
                return Poll::Ready(None);
            }

            let this = self.project();
            let delay_in_millis = *this.delay_in_millis;
            let next_part = this.parts.get_mut(*this.cursor).unwrap().take();

            match next_part {
                None => {
                    *this.cursor += 1;
                    let waker = cx.waker().clone();
                    tokio::spawn(async move {
                        tokio::time::sleep(Duration::from_millis(delay_in_millis)).await;
                        waker.wake();
                    });
                    Poll::Pending
                }
                Some(data) => {
                    *this.cursor += 1;
                    Poll::Ready(Some(Ok(data)))
                }
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
            Poll::Ready(Ok(None))
        }

        fn is_end_stream(&self) -> bool {
            false
        }

        fn size_hint(&self) -> SizeHint {
            SizeHint::new()
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding() {
        let test_fut = async {
            let input_str = "Hello world";
            let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output = "B\r\nHello world\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!("test_aws_chunked_encoding timed out after {timeout_duration:?}");
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_sputtering_body() {
        let test_fut = async {
            let input = SputteringBody {
                parts: vec![
                    Some(Bytes::from_static(b"chunk 1, ")),
                    None,
                    Some(Bytes::from_static(b"chunk 2, ")),
                    Some(Bytes::from_static(b"chunk 3, ")),
                    None,
                    None,
                    Some(Bytes::from_static(b"chunk 4, ")),
                    Some(Bytes::from_static(b"chunk 5, ")),
                    Some(Bytes::from_static(b"chunk 6")),
                ],
                cursor: 0,
                delay_in_millis: 500,
            };
            let opts = AwsChunkedBodyOptions::new(input.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(input, opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output =
                "34\r\nchunk 1, chunk 2, chunk 3, chunk 4, chunk 5, chunk 6\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!(
                "test_aws_chunked_encoding_sputtering_body timed out after {timeout_duration:?}"
            );
        }
    }

    #[tokio::test]
    #[should_panic = "called `Result::unwrap()` on an `Err` value: ReportedTrailerLengthMismatch { actual: 44, expected: 0 }"]
    async fn test_aws_chunked_encoding_incorrect_trailer_length_panic() {
        let input_str = "Hello world";
        // The inner body has no trailers, so any non-zero reported length is wrong.
        let wrong_trailer_len = 42;
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, vec![wrong_trailer_len]);
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        while let Some(buf) = body.data().await {
            // Unwrapping surfaces the trailer length mismatch as a panic.
            drop(buf.unwrap());
        }

        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_stream_length_mismatch() {
        let input_str = "Hello world";
        // Report one byte more than the inner body actually contains.
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64 + 1, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let mut first_error = None;
        while let Some(frame) = body.data().await {
            if let Err(err) = frame {
                first_error = Some(err);
                break;
            }
        }

        let err = first_error.expect("a stream length mismatch is reported as an error");
        assert_eq!(
            format!("{err:?}"),
            "StreamLengthMismatch { actual: 11, expected: 12 }"
        );
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_empty_body() {
        let input_str = "";
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let mut output = SegmentedBuf::new();
        while let Some(buf) = body.data().await {
            output.push(buf.unwrap());
        }

        let mut actual_output = String::new();
        output
            .reader()
            .read_to_string(&mut actual_output)
            .expect("Doesn't cause IO errors");

        let expected_output = [CHUNK_TERMINATOR, CRLF].concat();

        assert_eq!(expected_output, actual_output);
        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    #[test]
    fn test_aws_chunked_size_hint_matches_encoded_output() {
        let input_str = "Hello world";
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
        let body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let expected_output = "B\r\nHello world\r\n0\r\n\r\n";
        assert_eq!(body.size_hint().exact(), Some(expected_output.len() as u64));

        let empty = AwsChunkedBody::new(SdkBody::from(""), AwsChunkedBodyOptions::new(0, Vec::new()));
        let expected_empty = [CHUNK_TERMINATOR, CRLF].concat();
        assert_eq!(empty.size_hint().exact(), Some(expected_empty.len() as u64));
    }

    #[test]
    fn test_total_rendered_length_of_trailers() {
        let mut trailers = HeaderMap::new();

        trailers.insert("empty_value", HeaderValue::from_static(""));

        trailers.insert("single_value", HeaderValue::from_static("value 1"));

        trailers.insert("two_values", HeaderValue::from_static("value 1"));
        trailers.append("two_values", HeaderValue::from_static("value 2"));

        trailers.insert("three_values", HeaderValue::from_static("value 1"));
        trailers.append("three_values", HeaderValue::from_static("value 2"));
        trailers.append("three_values", HeaderValue::from_static("value 3"));

        let trailers = Some(trailers);
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }

    #[test]
    fn test_total_rendered_length_of_empty_trailers() {
        let trailers = Some(HeaderMap::new());
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }
}
613}