aws_sdk_s3/
http_response_checksum.rs
1#![allow(dead_code)]
8
9use aws_smithy_checksums::ChecksumAlgorithm;
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::{
14 BeforeDeserializationInterceptorContextMut, BeforeSerializationInterceptorContextRef, Input,
15};
16use aws_smithy_runtime_api::client::interceptors::Intercept;
17use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
18use aws_smithy_runtime_api::http::Headers;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
21use std::{fmt, mem};
22
23#[derive(Debug)]
24struct ResponseChecksumInterceptorState {
25 validation_enabled: bool,
26}
27impl Storable for ResponseChecksumInterceptorState {
28 type Storer = StoreReplace<Self>;
29}
30
/// Interceptor that, when enabled for a request, wraps the response body so it
/// is validated against a checksum header sent by the service.
pub(crate) struct ResponseChecksumInterceptor<VE> {
    /// Checksum algorithm names the response may carry. They are matched
    /// case-insensitively against the SDK's prioritized algorithm list.
    response_algorithms: &'static [&'static str],
    /// Predicate over the operation input that decides whether response
    /// checksum validation is performed.
    validation_enabled: VE,
}

// `VE` is only required to be callable (see the `Intercept` bound), so it cannot
// be assumed to implement `Debug`; only the algorithm list is printed.
impl<VE> fmt::Debug for ResponseChecksumInterceptor<VE> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ResponseChecksumInterceptor");
        builder.field("response_algorithms", &self.response_algorithms);
        builder.finish()
    }
}

impl<VE> ResponseChecksumInterceptor<VE> {
    /// Creates an interceptor for an operation whose response may carry any of
    /// `response_algorithms`. Validation is gated per request by `validation_enabled`.
    pub(crate) fn new(response_algorithms: &'static [&'static str], validation_enabled: VE) -> Self {
        ResponseChecksumInterceptor {
            validation_enabled,
            response_algorithms,
        }
    }
}
52
53impl<VE> Intercept for ResponseChecksumInterceptor<VE>
54where
55 VE: Fn(&Input) -> bool + Send + Sync,
56{
57 fn name(&self) -> &'static str {
58 "ResponseChecksumInterceptor"
59 }
60
61 fn read_before_serialization(
62 &self,
63 context: &BeforeSerializationInterceptorContextRef<'_>,
64 _runtime_components: &RuntimeComponents,
65 cfg: &mut ConfigBag,
66 ) -> Result<(), BoxError> {
67 let validation_enabled = (self.validation_enabled)(context.input());
68
69 let mut layer = Layer::new("ResponseChecksumInterceptor");
70 layer.store_put(ResponseChecksumInterceptorState { validation_enabled });
71 cfg.push_layer(layer);
72
73 Ok(())
74 }
75
76 fn modify_before_deserialization(
77 &self,
78 context: &mut BeforeDeserializationInterceptorContextMut<'_>,
79 _runtime_components: &RuntimeComponents,
80 cfg: &mut ConfigBag,
81 ) -> Result<(), BoxError> {
82 let state = cfg
83 .load::<ResponseChecksumInterceptorState>()
84 .expect("set in `read_before_serialization`");
85
86 if state.validation_enabled {
87 let response = context.response_mut();
88 let maybe_checksum_headers = check_headers_for_precalculated_checksum(response.headers(), self.response_algorithms);
89 if let Some((checksum_algorithm, precalculated_checksum)) = maybe_checksum_headers {
90 let mut body = SdkBody::taken();
91 mem::swap(&mut body, response.body_mut());
92
93 let mut body = wrap_body_with_checksum_validator(body, checksum_algorithm, precalculated_checksum);
94 mem::swap(&mut body, response.body_mut());
95 }
96 }
97
98 Ok(())
99 }
100}
101
102pub(crate) fn wrap_body_with_checksum_validator(
106 body: SdkBody,
107 checksum_algorithm: ChecksumAlgorithm,
108 precalculated_checksum: bytes::Bytes,
109) -> SdkBody {
110 use aws_smithy_checksums::body::validate;
111
112 body.map(move |body| {
113 SdkBody::from_body_0_4(validate::ChecksumBody::new(
114 body,
115 checksum_algorithm.into_impl(),
116 precalculated_checksum.clone(),
117 ))
118 })
119}
120
121pub(crate) fn check_headers_for_precalculated_checksum(headers: &Headers, response_algorithms: &[&str]) -> Option<(ChecksumAlgorithm, bytes::Bytes)> {
125 let checksum_algorithms_to_check = aws_smithy_checksums::http::CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER
126 .into_iter()
127 .flat_map(|algo| {
130 for res_algo in response_algorithms {
132 if algo.eq_ignore_ascii_case(res_algo) {
133 return Some(algo);
134 }
135 }
136
137 None
138 });
139
140 for checksum_algorithm in checksum_algorithms_to_check {
141 let checksum_algorithm: ChecksumAlgorithm = checksum_algorithm
142 .parse()
143 .expect("CHECKSUM_ALGORITHMS_IN_PRIORITY_ORDER only contains valid checksum algorithm names");
144 if let Some(base64_encoded_precalculated_checksum) = headers.get(checksum_algorithm.into_impl().header_name()) {
145 if is_part_level_checksum(base64_encoded_precalculated_checksum) {
147 tracing::warn!(
148 more_info = "See https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html#large-object-checksums for more information.",
149 "This checksum is a part-level checksum which can't be validated by the Rust SDK. Disable checksum validation for this request to fix this warning.",
150 );
151
152 return None;
153 }
154
155 let precalculated_checksum = match aws_smithy_types::base64::decode(base64_encoded_precalculated_checksum) {
156 Ok(decoded_checksum) => decoded_checksum.into(),
157 Err(_) => {
158 tracing::error!("Checksum received from server could not be base64 decoded. No checksum validation will be performed.");
159 return None;
160 }
161 };
162
163 return Some((checksum_algorithm, precalculated_checksum));
164 }
165 }
166
167 None
168}
169
/// Returns `true` if `checksum` looks like a part-level (multipart) checksum.
///
/// A part-level checksum is a checksum value followed by `-<part count>`, e.g.
/// `abcd==-12`. Concretely, the trailing run of ASCII digits and dashes must
/// contain exactly one dash with at least one digit *after* it.
///
/// Digits that appear before the dash belong to the checksum value itself and
/// are not a part count. So `abcd12-` (dash with nothing after it) is rejected,
/// just like `abcd==-`.
fn is_part_level_checksum(checksum: &str) -> bool {
    // Split off the trailing run of `[0-9-]` characters. `head` is a prefix of
    // `checksum`, so slicing at its length is always on a char boundary.
    let head = checksum.trim_end_matches(|c: char| c.is_ascii_digit() || c == '-');
    let tail = &checksum[head.len()..];

    match tail.split_once('-') {
        // `tail` contains only digits and dashes. Requiring a non-empty suffix
        // without another dash therefore means "exactly one dash, followed by
        // one or more digits".
        Some((_, part_count)) => !part_count.is_empty() && !part_count.contains('-'),
        // No dash at all: a plain full-object checksum (or gibberish).
        None => false,
    }
}
197
#[cfg(test)]
mod tests {
    use super::{is_part_level_checksum, wrap_body_with_checksum_validator};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;

    /// Reading a body wrapped with a matching CRC32 checksum must succeed and
    /// yield the original bytes unchanged.
    #[tokio::test]
    async fn test_build_checksum_validated_body_works() {
        let input_text = "Hello world";
        // Expected CRC32 of `input_text`; the test fails if validation rejects it.
        let precalculated_checksum = Bytes::from_static(&[0x8b, 0xd6, 0x9e, 0x52]);
        let checksum_algorithm = "crc32".parse().unwrap();

        let validated_stream = ByteStream::new(SdkBody::from(input_text)).map(move |inner| {
            wrap_body_with_checksum_validator(inner, checksum_algorithm, precalculated_checksum.clone())
        });

        let mut output = Vec::new();
        let mut reader = validated_stream.into_async_read();
        if let Err(err) = tokio::io::copy(&mut reader, &mut output).await {
            tracing::error!("{}", DisplayErrorContext(&err));
            panic!("checksum validation has failed");
        }

        let output = std::str::from_utf8(&output).unwrap();
        assert_eq!(input_text, output);
    }

    #[test]
    fn test_is_multipart_object_checksum() {
        // Full-object checksums: no `-<part count>` suffix.
        for not_part_level in ["abcd", "abcd=", "abcd==", "1234", "1234=", "1234=="] {
            assert!(!is_part_level_checksum(not_part_level));
        }

        // Part-level checksums: `<checksum>-<part count>`.
        for part_level in ["abcd-1", "abcd=-12", "abcd12-134", "abcd==-10000"] {
            assert!(is_part_level_checksum(part_level));
        }

        // Gibberish that must not be mistaken for a part-level checksum.
        for gibberish in [
            "",
            "Spaces? In my header values?",
            "abcd==-134!#{!#",
            "abcd==-",
            "abcd==--11",
            "abcd==-AA",
        ] {
            assert!(!is_part_level_checksum(gibberish));
        }
    }
}