aws_sdk_s3/
http_response_checksum.rs

1// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT.
2/*
3 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7#![allow(dead_code)]
8
9//! Interceptor for handling Smithy `@httpChecksum` response checksumming
10
11use aws_smithy_checksums::ChecksumAlgorithm;
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::{
14    BeforeDeserializationInterceptorContextMut, BeforeSerializationInterceptorContextRef, Input,
15};
16use aws_smithy_runtime_api::client::interceptors::Intercept;
17use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
18use aws_smithy_runtime_api::http::Headers;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
21use std::{fmt, mem};
22
/// Per-operation state recorded by `ResponseChecksumInterceptor` during
/// `read_before_serialization` and read back in `modify_before_deserialization`.
#[derive(Debug)]
struct ResponseChecksumInterceptorState {
    /// Result of the interceptor's `validation_enabled` predicate for this operation's input.
    validation_enabled: bool,
}
27impl Storable for ResponseChecksumInterceptorState {
28    type Storer = StoreReplace<Self>;
29}
30
/// Interceptor that, when enabled for an operation, wraps the response body so it is checked
/// against a checksum header sent by the server.
pub(crate) struct ResponseChecksumInterceptor<VE> {
    /// Checksum algorithm names the model marks as supported for this operation's response.
    response_algorithms: &'static [&'static str],
    /// Predicate deciding, from the operation input, whether response validation should run.
    validation_enabled: VE,
}
35
36impl<VE> fmt::Debug for ResponseChecksumInterceptor<VE> {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        f.debug_struct("ResponseChecksumInterceptor")
39            .field("response_algorithms", &self.response_algorithms)
40            .finish()
41    }
42}
43
44impl<VE> ResponseChecksumInterceptor<VE> {
45    pub(crate) fn new(response_algorithms: &'static [&'static str], validation_enabled: VE) -> Self {
46        Self {
47            response_algorithms,
48            validation_enabled,
49        }
50    }
51}
52
53impl<VE> Intercept for ResponseChecksumInterceptor<VE>
54where
55    VE: Fn(&Input) -> bool + Send + Sync,
56{
57    fn name(&self) -> &'static str {
58        "ResponseChecksumInterceptor"
59    }
60
61    fn read_before_serialization(
62        &self,
63        context: &BeforeSerializationInterceptorContextRef<'_>,
64        _runtime_components: &RuntimeComponents,
65        cfg: &mut ConfigBag,
66    ) -> Result<(), BoxError> {
67        let validation_enabled = (self.validation_enabled)(context.input());
68
69        let mut layer = Layer::new("ResponseChecksumInterceptor");
70        layer.store_put(ResponseChecksumInterceptorState { validation_enabled });
71        cfg.push_layer(layer);
72
73        Ok(())
74    }
75
76    fn modify_before_deserialization(
77        &self,
78        context: &mut BeforeDeserializationInterceptorContextMut<'_>,
79        _runtime_components: &RuntimeComponents,
80        cfg: &mut ConfigBag,
81    ) -> Result<(), BoxError> {
82        let state = cfg
83            .load::<ResponseChecksumInterceptorState>()
84            .expect("set in `read_before_serialization`");
85
86        if state.validation_enabled {
87            let response = context.response_mut();
88            let maybe_checksum_headers = check_headers_for_precalculated_checksum(response.headers(), self.response_algorithms);
89            if let Some((checksum_algorithm, precalculated_checksum)) = maybe_checksum_headers {
90                let mut body = SdkBody::taken();
91                mem::swap(&mut body, response.body_mut());
92
93                let mut body = wrap_body_with_checksum_validator(body, checksum_algorithm, precalculated_checksum);
94                mem::swap(&mut body, response.body_mut());
95            }
96        }
97
98        Ok(())
99    }
100}
101
102/// Given an `SdkBody`, a `aws_smithy_checksums::ChecksumAlgorithm`, and a pre-calculated checksum,
103/// return an `SdkBody` where the body will processed with the checksum algorithm and checked
104/// against the pre-calculated checksum.
105pub(crate) fn wrap_body_with_checksum_validator(
106    body: SdkBody,
107    checksum_algorithm: ChecksumAlgorithm,
108    precalculated_checksum: bytes::Bytes,
109) -> SdkBody {
110    use aws_smithy_checksums::body::validate;
111
112    body.map(move |body| {
113        SdkBody::from_body_0_4(validate::ChecksumBody::new(
114            body,
115            checksum_algorithm.into_impl(),
116            precalculated_checksum.clone(),
117        ))
118    })
119}
120
121/// Given a `HeaderMap`, extract any checksum included in the headers as `Some(Bytes)`.
122/// If no checksum header is set, return `None`. If multiple checksum headers are set, the one that
123/// is fastest to compute will be chosen.
124pub(crate) fn check_headers_for_precalculated_checksum(headers: &Headers, response_algorithms: &[&str]) -> Option<(ChecksumAlgorithm, bytes::Bytes)> {
125    let checksum_algorithms_to_check = aws_smithy_checksums::http::CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER
126        .into_iter()
127        // Process list of algorithms, from fastest to slowest, that may have been used to checksum
128        // the response body, ignoring any that aren't marked as supported algorithms by the model.
129        .flat_map(|algo| {
130            // For loop is necessary b/c the compiler doesn't infer the correct lifetimes for iter().find()
131            for res_algo in response_algorithms {
132                if algo.eq_ignore_ascii_case(res_algo) {
133                    return Some(algo);
134                }
135            }
136
137            None
138        });
139
140    for checksum_algorithm in checksum_algorithms_to_check {
141        let checksum_algorithm: ChecksumAlgorithm = checksum_algorithm
142            .parse()
143            .expect("CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER only contains valid checksum algorithm names");
144        if let Some(base64_encoded_precalculated_checksum) = headers.get(checksum_algorithm.into_impl().header_name()) {
145            // S3 needs special handling for checksums of objects uploaded with `MultiPartUpload`.
146            if is_part_level_checksum(base64_encoded_precalculated_checksum) {
147                tracing::warn!(
148                      more_info = "See https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html#large-object-checksums for more information.",
149                      "This checksum is a part-level checksum which can't be validated by the Rust SDK. Disable checksum validation for this request to fix this warning.",
150                  );
151
152                return None;
153            }
154
155            let precalculated_checksum = match aws_smithy_types::base64::decode(base64_encoded_precalculated_checksum) {
156                Ok(decoded_checksum) => decoded_checksum.into(),
157                Err(_) => {
158                    tracing::error!("Checksum received from server could not be base64 decoded. No checksum validation will be performed.");
159                    return None;
160                }
161            };
162
163            return Some((checksum_algorithm, precalculated_checksum));
164        }
165    }
166
167    None
168}
169
/// Returns `true` if `checksum` looks like an S3 part-level (multipart) checksum: a checksum
/// value followed by `-<part count>`, such as `abcd==-12`.
///
/// The text after the last `-` must be a non-empty run of ASCII digits, so `abcd12-` is not a
/// part-level checksum. A checksum is also rejected when another `-` appears in the trailing
/// run of digits and dashes, e.g. `abcd==--11`.
fn is_part_level_checksum(checksum: &str) -> bool {
    match checksum.rsplit_once('-') {
        Some((checksum_part, part_count)) => {
            let is_part_count =
                !part_count.is_empty() && part_count.bytes().all(|b| b.is_ascii_digit());
            // A second dash separated from the last one only by digits means this isn't a
            // well-formed `<checksum>-<count>` value.
            let has_second_dash = checksum_part
                .trim_end_matches(|c: char| c.is_ascii_digit())
                .ends_with('-');
            is_part_count && !has_second_dash
        }
        None => false,
    }
}
197
#[cfg(test)]
mod tests {
    use super::{is_part_level_checksum, wrap_body_with_checksum_validator};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;

    /// Streams "Hello world" through a CRC32-validating body whose expected checksum matches,
    /// and checks the bytes come out unchanged without a validation error.
    #[tokio::test]
    async fn test_build_checksum_validated_body_works() {
        let checksum_algorithm = "crc32".parse().unwrap();
        let input_text = "Hello world";
        let expected_checksum = Bytes::from_static(&[0x8b, 0xd6, 0x9e, 0x52]);

        let validated_stream = ByteStream::new(SdkBody::from(input_text)).map(move |inner| {
            wrap_body_with_checksum_validator(inner, checksum_algorithm, expected_checksum.clone())
        });

        let mut output = Vec::new();
        let mut reader = validated_stream.into_async_read();
        match tokio::io::copy(&mut reader, &mut output).await {
            Ok(_) => {}
            Err(e) => {
                tracing::error!("{}", DisplayErrorContext(&e));
                panic!("checksum validation has failed");
            }
        }

        let output_text = std::str::from_utf8(&output).unwrap();
        assert_eq!(input_text, output_text);
    }

    #[test]
    fn test_is_multipart_object_checksum() {
        // Plain checksums without a `-<part count>` suffix.
        let not_part_level = ["abcd", "abcd=", "abcd==", "1234", "1234=", "1234=="];
        for checksum in not_part_level.iter() {
            assert!(!is_part_level_checksum(checksum));
        }

        // Checksums followed by `-<part count>`.
        let part_level = ["abcd-1", "abcd=-12", "abcd12-134", "abcd==-10000"];
        for checksum in part_level.iter() {
            assert!(is_part_level_checksum(checksum));
        }

        // Gibberish that must not be mistaken for a part-level checksum.
        let gibberish = [
            "",
            "Spaces? In my header values?",
            "abcd==-134!#{!#",
            "abcd==-",
            "abcd==--11",
            "abcd==-AA",
        ];
        for checksum in gibberish.iter() {
            assert!(!is_part_level_checksum(checksum));
        }
    }
}