aws_sdk_s3/
http_request_checksum.rs
1#![allow(dead_code)]
8
9use aws_runtime::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
12use aws_runtime::{auth::SigV4OperationSigningConfig, content_encoding::header_value::AWS_CHUNKED};
13use aws_sigv4::http_request::SignableBody;
14use aws_smithy_checksums::ChecksumAlgorithm;
15use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
16use aws_smithy_runtime_api::box_error::BoxError;
17use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut, Input};
18use aws_smithy_runtime_api::client::interceptors::Intercept;
19use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
20use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
21use aws_smithy_types::body::SdkBody;
22use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
23use aws_smithy_types::error::operation::BuildError;
24use http::HeaderValue;
25use http_body::Body;
26use std::{fmt, mem};
27
/// Errors raised while attaching a checksum to an outgoing request.
#[derive(Debug)]
pub(crate) enum Error {
    /// The request body did not report an exact size, so it cannot be checksummed as a stream.
    UnsizedRequestBody,
    /// Checksum header insertion was requested for a streaming body; such checksums must be
    /// sent as trailers instead.
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    /// Writes the fixed, human-readable description of the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every message is a constant, so pick the text first and emit it with a single write.
        let description = match self {
            Self::UnsizedRequestBody => "Only request bodies with a known size can be checksum validated.",
            Self::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                 To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(description)
    }
}

impl std::error::Error for Error {}
50
51#[derive(Debug)]
52struct RequestChecksumInterceptorState {
53 checksum_algorithm: Option<ChecksumAlgorithm>,
54}
55impl Storable for RequestChecksumInterceptorState {
56 type Storer = StoreReplace<Self>;
57}
58
59type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;
60
61pub(crate) struct DefaultRequestChecksumOverride {
62 custom_default: CustomDefaultFn,
63}
64impl fmt::Debug for DefaultRequestChecksumOverride {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("DefaultRequestChecksumOverride").finish()
67 }
68}
69impl Storable for DefaultRequestChecksumOverride {
70 type Storer = StoreReplace<Self>;
71}
72impl DefaultRequestChecksumOverride {
73 pub(crate) fn new<F>(custom_default: F) -> Self
74 where
75 F: Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static,
76 {
77 Self {
78 custom_default: Box::new(custom_default),
79 }
80 }
81 pub(crate) fn custom_default(&self, original: Option<ChecksumAlgorithm>, config_bag: &ConfigBag) -> Option<ChecksumAlgorithm> {
82 (self.custom_default)(original, config_bag)
83 }
84}
85
/// Interceptor that resolves a checksum algorithm from the operation input and attaches the
/// corresponding checksum to the request before it is signed.
///
/// `AP` is the algorithm provider; the `Intercept` implementation requires it to be callable
/// with the operation input.
pub(crate) struct RequestChecksumInterceptor<AP> {
    algorithm_provider: AP,
}

impl<AP> RequestChecksumInterceptor<AP> {
    /// Creates an interceptor that asks `algorithm_provider` which checksum algorithm to use.
    pub(crate) fn new(algorithm_provider: AP) -> Self {
        RequestChecksumInterceptor { algorithm_provider }
    }
}

impl<AP> fmt::Debug for RequestChecksumInterceptor<AP> {
    /// Prints only the type name so that `AP` is not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RequestChecksumInterceptor")
    }
}
101
102impl<AP> Intercept for RequestChecksumInterceptor<AP>
103where
104 AP: Fn(&Input) -> Result<Option<ChecksumAlgorithm>, BoxError> + Send + Sync,
105{
106 fn name(&self) -> &'static str {
107 "RequestChecksumInterceptor"
108 }
109
110 fn read_before_serialization(
111 &self,
112 context: &BeforeSerializationInterceptorContextRef<'_>,
113 _runtime_components: &RuntimeComponents,
114 cfg: &mut ConfigBag,
115 ) -> Result<(), BoxError> {
116 let checksum_algorithm = (self.algorithm_provider)(context.input())?;
117
118 let mut layer = Layer::new("RequestChecksumInterceptor");
119 layer.store_put(RequestChecksumInterceptorState { checksum_algorithm });
120 cfg.push_layer(layer);
121
122 Ok(())
123 }
124
125 fn modify_before_signing(
129 &self,
130 context: &mut BeforeTransmitInterceptorContextMut<'_>,
131 _runtime_components: &RuntimeComponents,
132 cfg: &mut ConfigBag,
133 ) -> Result<(), BoxError> {
134 let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
135
136 let checksum_algorithm = incorporate_custom_default(state.checksum_algorithm, cfg);
137 if let Some(checksum_algorithm) = checksum_algorithm {
138 let request = context.request_mut();
139 add_checksum_for_request_body(request, checksum_algorithm, cfg)?;
140 }
141
142 Ok(())
143 }
144}
145
146fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
147 match cfg.load::<DefaultRequestChecksumOverride>() {
148 Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
149 None => checksum,
150 }
151}
152
153fn add_checksum_for_request_body(request: &mut HttpRequest, checksum_algorithm: ChecksumAlgorithm, cfg: &mut ConfigBag) -> Result<(), BoxError> {
154 match request.body().bytes() {
155 Some(data) => {
157 tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
158 let mut checksum = checksum_algorithm.into_impl();
159 checksum.update(data);
160
161 request.headers_mut().insert(checksum.header_name(), checksum.header_value());
162 }
163 None => {
165 tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
166 if let Some(mut signing_config) = cfg.load::<SigV4OperationSigningConfig>().cloned() {
167 signing_config.signing_options.payload_override = Some(SignableBody::StreamingUnsignedPayloadTrailer);
168 cfg.interceptor_state().store_put(signing_config);
169 }
170 wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm)?;
171 }
172 }
173 Ok(())
174}
175
176fn wrap_streaming_request_body_in_checksum_calculating_body(
177 request: &mut HttpRequest,
178 checksum_algorithm: ChecksumAlgorithm,
179) -> Result<(), BuildError> {
180 let original_body_size = request
181 .body()
182 .size_hint()
183 .exact()
184 .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
185
186 let mut body = {
187 let body = mem::replace(request.body_mut(), SdkBody::taken());
188
189 body.map(move |body| {
190 let checksum = checksum_algorithm.into_impl();
191 let trailer_len = HttpChecksum::size(checksum.as_ref());
192 let body = calculate::ChecksumBody::new(body, checksum);
193 let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
194
195 let body = AwsChunkedBody::new(body, aws_chunked_body_options);
196
197 SdkBody::from_body_0_4(body)
198 })
199 };
200
201 let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
202
203 let headers = request.headers_mut();
204
205 headers.insert(
206 http::header::HeaderName::from_static("x-amz-trailer"),
207 checksum_algorithm.into_impl().header_name(),
208 );
209
210 headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
211 headers.insert(
212 http::header::HeaderName::from_static("x-amz-decoded-content-length"),
213 HeaderValue::from(original_body_size),
214 );
215 headers.insert(
216 http::header::CONTENT_ENCODING,
217 HeaderValue::from_str(AWS_CHUNKED)
218 .map_err(BuildError::other)
219 .expect("\"aws-chunked\" will always be a valid HeaderValue"),
220 );
221
222 mem::swap(request.body_mut(), &mut body);
223
224 Ok(())
225}
226
#[cfg(test)]
mod tests {
    use crate::http_request_checksum::wrap_streaming_request_body_in_checksum_calculating_body;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Clones the request body (which must be cloneable) and reads the clone to completion.
    async fn read_cloned_body(request: &HttpRequest) -> BytesMut {
        let mut body = request.body().try_clone().expect("body is retryable");

        let mut collected = BytesMut::new();
        while let Some(chunk) = body.data().await {
            collected.extend_from_slice(&chunk.unwrap());
        }
        collected
    }

    /// A retryable in-memory body must still be cloneable after wrapping, and the clone must
    /// yield the chunk-encoded payload followed by the crc32 trailer.
    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        let input_text = "Hello world";
        let mut request: HttpRequest = http::Request::builder()
            .body(SdkBody::retryable(move || SdkBody::from(input_text)))
            .unwrap()
            .try_into()
            .unwrap();

        // Precondition: the body is cloneable before it is wrapped.
        assert!(request.body().try_clone().is_some());

        let checksum_algorithm: ChecksumAlgorithm = "crc32".parse().unwrap();
        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        let body_data = read_cloned_body(&request).await;
        let actual = std::str::from_utf8(&body_data).unwrap();

        let chunk_len_hex = format!("{:X}", input_text.len());
        let expected = format!("{chunk_len_hex}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n");
        assert_eq!(expected, actual);
    }

    /// A file-backed body must still be cloneable after wrapping, and the clone must end with
    /// the last line of the file followed by the crc32c trailer computed over the file contents.
    #[tokio::test]
    async fn test_checksum_body_from_file_is_retryable() {
        use std::io::Write;

        let checksum_algorithm: ChecksumAlgorithm = "crc32c".parse().unwrap();
        let mut file = NamedTempFile::new().unwrap();

        // Write the fixture while feeding the same bytes to a reference checksum.
        let mut reference_checksum = checksum_algorithm.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {i}");
            let line_bytes = line.as_bytes();
            file.as_file_mut().write_all(line_bytes).unwrap();
            reference_checksum.update(line_bytes);
        }
        let reference_checksum = reference_checksum.finalize();

        let file_body = ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner();
        let mut request = HttpRequest::new(file_body);

        // Precondition: the body is cloneable before it is wrapped.
        assert!(request.body().try_clone().is_some());

        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        let body_data = read_cloned_body(&request).await;
        let actual = std::str::from_utf8(&body_data).unwrap();

        let expected_checksum = base64::encode(&reference_checksum);
        let expected = format!("This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n");
        assert!(actual.ends_with(&expected), "expected {actual} to end with '{expected}'");
    }
}