// aws_sdk_s3/http_request_checksum.rs
// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT.
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
6
7#![allow(dead_code)]
8
//! Interceptor for handling Smithy `@httpChecksum` request checksumming with AWS SigV4
10
11use aws_runtime::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
12use aws_runtime::{auth::SigV4OperationSigningConfig, content_encoding::header_value::AWS_CHUNKED};
13use aws_sigv4::http_request::SignableBody;
14use aws_smithy_checksums::ChecksumAlgorithm;
15use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
16use aws_smithy_runtime_api::box_error::BoxError;
17use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut, Input};
18use aws_smithy_runtime_api::client::interceptors::Intercept;
19use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
20use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
21use aws_smithy_types::body::SdkBody;
22use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
23use aws_smithy_types::error::operation::BuildError;
24use http::HeaderValue;
25use http_body::Body;
26use std::{fmt, mem};
27
/// Errors that can occur while building a checksum-validated HTTP request.
#[derive(Debug)]
pub(crate) enum Error {
    /// The request body has no exact size, so it cannot be checksum validated.
    UnsizedRequestBody,
    /// A checksum header was requested for a streaming body; streaming bodies can only
    /// carry their checksum as a trailer.
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::UnsizedRequestBody => "Only request bodies with a known size can be checksum validated.",
            Error::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                   To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
50
51#[derive(Debug)]
52struct RequestChecksumInterceptorState {
53    checksum_algorithm: Option<ChecksumAlgorithm>,
54}
55impl Storable for RequestChecksumInterceptorState {
56    type Storer = StoreReplace<Self>;
57}
58
59type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;
60
61pub(crate) struct DefaultRequestChecksumOverride {
62    custom_default: CustomDefaultFn,
63}
64impl fmt::Debug for DefaultRequestChecksumOverride {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("DefaultRequestChecksumOverride").finish()
67    }
68}
69impl Storable for DefaultRequestChecksumOverride {
70    type Storer = StoreReplace<Self>;
71}
72impl DefaultRequestChecksumOverride {
73    pub(crate) fn new<F>(custom_default: F) -> Self
74    where
75        F: Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static,
76    {
77        Self {
78            custom_default: Box::new(custom_default),
79        }
80    }
81    pub(crate) fn custom_default(&self, original: Option<ChecksumAlgorithm>, config_bag: &ConfigBag) -> Option<ChecksumAlgorithm> {
82        (self.custom_default)(original, config_bag)
83    }
84}
85
/// Interceptor that picks a checksum algorithm via `algorithm_provider` and adds the
/// resulting checksum to the outgoing request (as a header or a trailer).
pub(crate) struct RequestChecksumInterceptor<AP> {
    /// Chooses the checksum algorithm (if any) from the operation input.
    algorithm_provider: AP,
}

impl<AP> fmt::Debug for RequestChecksumInterceptor<AP> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Equivalent to a field-less `debug_struct(..).finish()`: the provider is not `Debug`.
        f.write_str("RequestChecksumInterceptor")
    }
}

impl<AP> RequestChecksumInterceptor<AP> {
    /// Creates an interceptor that consults `algorithm_provider` for each operation.
    pub(crate) fn new(algorithm_provider: AP) -> Self {
        RequestChecksumInterceptor { algorithm_provider }
    }
}
101
102impl<AP> Intercept for RequestChecksumInterceptor<AP>
103where
104    AP: Fn(&Input) -> Result<Option<ChecksumAlgorithm>, BoxError> + Send + Sync,
105{
106    fn name(&self) -> &'static str {
107        "RequestChecksumInterceptor"
108    }
109
110    fn read_before_serialization(
111        &self,
112        context: &BeforeSerializationInterceptorContextRef<'_>,
113        _runtime_components: &RuntimeComponents,
114        cfg: &mut ConfigBag,
115    ) -> Result<(), BoxError> {
116        let checksum_algorithm = (self.algorithm_provider)(context.input())?;
117
118        let mut layer = Layer::new("RequestChecksumInterceptor");
119        layer.store_put(RequestChecksumInterceptorState { checksum_algorithm });
120        cfg.push_layer(layer);
121
122        Ok(())
123    }
124
125    /// Calculate a checksum and modify the request to include the checksum as a header
126    /// (for in-memory request bodies) or a trailer (for streaming request bodies).
127    /// Streaming bodies must be sized or this will return an error.
128    fn modify_before_signing(
129        &self,
130        context: &mut BeforeTransmitInterceptorContextMut<'_>,
131        _runtime_components: &RuntimeComponents,
132        cfg: &mut ConfigBag,
133    ) -> Result<(), BoxError> {
134        let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
135
136        let checksum_algorithm = incorporate_custom_default(state.checksum_algorithm, cfg);
137        if let Some(checksum_algorithm) = checksum_algorithm {
138            let request = context.request_mut();
139            add_checksum_for_request_body(request, checksum_algorithm, cfg)?;
140        }
141
142        Ok(())
143    }
144}
145
146fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
147    match cfg.load::<DefaultRequestChecksumOverride>() {
148        Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
149        None => checksum,
150    }
151}
152
153fn add_checksum_for_request_body(request: &mut HttpRequest, checksum_algorithm: ChecksumAlgorithm, cfg: &mut ConfigBag) -> Result<(), BoxError> {
154    match request.body().bytes() {
155        // Body is in-memory: read it and insert the checksum as a header.
156        Some(data) => {
157            tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
158            let mut checksum = checksum_algorithm.into_impl();
159            checksum.update(data);
160
161            request.headers_mut().insert(checksum.header_name(), checksum.header_value());
162        }
163        // Body is streaming: wrap the body so it will emit a checksum as a trailer.
164        None => {
165            tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
166            if let Some(mut signing_config) = cfg.load::<SigV4OperationSigningConfig>().cloned() {
167                signing_config.signing_options.payload_override = Some(SignableBody::StreamingUnsignedPayloadTrailer);
168                cfg.interceptor_state().store_put(signing_config);
169            }
170            wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm)?;
171        }
172    }
173    Ok(())
174}
175
176fn wrap_streaming_request_body_in_checksum_calculating_body(
177    request: &mut HttpRequest,
178    checksum_algorithm: ChecksumAlgorithm,
179) -> Result<(), BuildError> {
180    let original_body_size = request
181        .body()
182        .size_hint()
183        .exact()
184        .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
185
186    let mut body = {
187        let body = mem::replace(request.body_mut(), SdkBody::taken());
188
189        body.map(move |body| {
190            let checksum = checksum_algorithm.into_impl();
191            let trailer_len = HttpChecksum::size(checksum.as_ref());
192            let body = calculate::ChecksumBody::new(body, checksum);
193            let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
194
195            let body = AwsChunkedBody::new(body, aws_chunked_body_options);
196
197            SdkBody::from_body_0_4(body)
198        })
199    };
200
201    let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
202
203    let headers = request.headers_mut();
204
205    headers.insert(
206        http::header::HeaderName::from_static("x-amz-trailer"),
207        checksum_algorithm.into_impl().header_name(),
208    );
209
210    headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
211    headers.insert(
212        http::header::HeaderName::from_static("x-amz-decoded-content-length"),
213        HeaderValue::from(original_body_size),
214    );
215    headers.insert(
216        http::header::CONTENT_ENCODING,
217        HeaderValue::from_str(AWS_CHUNKED)
218            .map_err(BuildError::other)
219            .expect("\"aws-chunked\" will always be a valid HeaderValue"),
220    );
221
222    mem::swap(request.body_mut(), &mut body);
223
224    Ok(())
225}
226
#[cfg(test)]
mod tests {
    use crate::http_request_checksum::wrap_streaming_request_body_in_checksum_calculating_body;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Clones the request body (panicking if it is not retryable) and collects every data
    /// frame of the clone into one buffer.
    async fn drain_cloned_body(request: &HttpRequest) -> BytesMut {
        let mut body = request.body().try_clone().expect("body is retryable");
        let mut body_data = BytesMut::new();
        while let Some(data) = body.data().await {
            body_data.extend_from_slice(&data.unwrap());
        }
        body_data
    }

    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        let input_text = "Hello world";
        let chunk_len_hex = format!("{:X}", input_text.len());
        let mut request: HttpRequest = http::Request::builder()
            .body(SdkBody::retryable(move || SdkBody::from(input_text)))
            .unwrap()
            .try_into()
            .unwrap();

        // The body must be cloneable before wrapping...
        assert!(request.body().try_clone().is_some());

        let checksum_algorithm: ChecksumAlgorithm = "crc32".parse().unwrap();
        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        // ...and still cloneable after wrapping.
        let body_data = drain_cloned_body(&request).await;
        let body = std::str::from_utf8(&body_data).unwrap();
        assert_eq!(
            format!("{chunk_len_hex}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n"),
            body
        );
    }

    #[tokio::test]
    async fn test_checksum_body_from_file_is_retryable() {
        use std::io::Write;

        let checksum_algorithm: ChecksumAlgorithm = "crc32c".parse().unwrap();
        let mut file = NamedTempFile::new().unwrap();
        let mut hasher = checksum_algorithm.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {}", i);
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
            hasher.update(line.as_bytes());
        }
        let expected_checksum = base64::encode(&hasher.finalize());

        let streamed = ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap();
        let mut request = HttpRequest::new(streamed.into_inner());

        // The body must be cloneable before wrapping...
        assert!(request.body().try_clone().is_some());

        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        // ...and still cloneable after wrapping.
        let body_data = drain_cloned_body(&request).await;
        let body = std::str::from_utf8(&body_data).unwrap();
        let expected = format!("This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n");
        assert!(body.ends_with(&expected), "expected {body} to end with '{expected}'");
    }
}