aws_runtime/
request_info.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::service_clock_skew::ServiceClockSkew;
7use aws_smithy_async::time::TimeSource;
8use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
10use aws_smithy_runtime_api::client::interceptors::Intercept;
11use aws_smithy_runtime_api::client::retries::RequestAttempts;
12use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
13use aws_smithy_types::config_bag::ConfigBag;
14use aws_smithy_types::date_time::Format;
15use aws_smithy_types::retry::RetryConfig;
16use aws_smithy_types::timeout::TimeoutConfig;
17use aws_smithy_types::DateTime;
18use http_02x::{HeaderName, HeaderValue};
19use std::borrow::Cow;
20use std::time::Duration;
21
22#[allow(clippy::declare_interior_mutable_const)] // we will never mutate this
23const AMZ_SDK_REQUEST: HeaderName = HeaderName::from_static("amz-sdk-request");
24
/// Interceptor that builds a request header describing the request itself and attaches it
/// before the request is transmitted.
///
/// The metadata it can convey includes:
///
/// - The point in time at which the client will give up on (time out) the request.
/// - The number of times the request has been retried so far.
/// - The maximum number of retries the client is willing to make.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct RequestInfoInterceptor {}
34
35impl RequestInfoInterceptor {
36    /// Creates a new `RequestInfoInterceptor`
37    pub fn new() -> Self {
38        RequestInfoInterceptor {}
39    }
40}
41
42impl RequestInfoInterceptor {
43    fn build_attempts_pair(
44        &self,
45        cfg: &ConfigBag,
46    ) -> Option<(Cow<'static, str>, Cow<'static, str>)> {
47        let request_attempts = cfg
48            .load::<RequestAttempts>()
49            .map(|r_a| r_a.attempts())
50            .unwrap_or(0);
51        let request_attempts = request_attempts.to_string();
52        Some((Cow::Borrowed("attempt"), Cow::Owned(request_attempts)))
53    }
54
55    fn build_max_attempts_pair(
56        &self,
57        cfg: &ConfigBag,
58    ) -> Option<(Cow<'static, str>, Cow<'static, str>)> {
59        if let Some(retry_config) = cfg.load::<RetryConfig>() {
60            let max_attempts = retry_config.max_attempts().to_string();
61            Some((Cow::Borrowed("max"), Cow::Owned(max_attempts)))
62        } else {
63            None
64        }
65    }
66
67    fn build_ttl_pair(
68        &self,
69        cfg: &ConfigBag,
70        timesource: impl TimeSource,
71    ) -> Option<(Cow<'static, str>, Cow<'static, str>)> {
72        let timeout_config = cfg.load::<TimeoutConfig>()?;
73        let socket_read = timeout_config.read_timeout()?;
74        let estimated_skew: Duration = cfg.load::<ServiceClockSkew>().cloned()?.into();
75        let current_time = timesource.now();
76        let ttl = current_time.checked_add(socket_read + estimated_skew)?;
77        let mut timestamp = DateTime::from(ttl);
78        // Set subsec_nanos to 0 so that the formatted `DateTime` won't have fractional seconds.
79        timestamp.set_subsec_nanos(0);
80        let mut formatted_timestamp = timestamp
81            .fmt(Format::DateTime)
82            .expect("the resulting DateTime will always be valid");
83
84        // Remove dashes and colons
85        formatted_timestamp = formatted_timestamp
86            .chars()
87            .filter(|&c| c != '-' && c != ':')
88            .collect();
89
90        Some((Cow::Borrowed("ttl"), Cow::Owned(formatted_timestamp)))
91    }
92}
93
94impl Intercept for RequestInfoInterceptor {
95    fn name(&self) -> &'static str {
96        "RequestInfoInterceptor"
97    }
98
99    fn modify_before_transmit(
100        &self,
101        context: &mut BeforeTransmitInterceptorContextMut<'_>,
102        runtime_components: &RuntimeComponents,
103        cfg: &mut ConfigBag,
104    ) -> Result<(), BoxError> {
105        let mut pairs = RequestPairs::new();
106        if let Some(pair) = self.build_ttl_pair(
107            cfg,
108            runtime_components
109                .time_source()
110                .ok_or("A timesource must be provided")?,
111        ) {
112            pairs = pairs.with_pair(pair);
113        }
114        if let Some(pair) = self.build_attempts_pair(cfg) {
115            pairs = pairs.with_pair(pair);
116        }
117        if let Some(pair) = self.build_max_attempts_pair(cfg) {
118            pairs = pairs.with_pair(pair);
119        }
120
121        let headers = context.request_mut().headers_mut();
122        headers.insert(AMZ_SDK_REQUEST, pairs.try_into_header_value()?);
123
124        Ok(())
125    }
126}
127
/// Builder that accumulates key/value pairs and turns them into a single header value.
///
/// The resulting header carries retry information and is sent with every request, letting a
/// service anticipate whether the client is about to time out or retry.
#[derive(Debug, Default)]
struct RequestPairs {
    /// Pairs in insertion order, as `(key, value)`.
    inner: Vec<(Cow<'static, str>, Cow<'static, str>)>,
}
135
136impl RequestPairs {
137    /// Creates a new `RequestPairs` builder.
138    fn new() -> Self {
139        Default::default()
140    }
141
142    /// Adds a pair to the `RequestPairs` builder.
143    /// Only strings that can be converted to header values are considered valid.
144    fn with_pair(
145        mut self,
146        pair: (impl Into<Cow<'static, str>>, impl Into<Cow<'static, str>>),
147    ) -> Self {
148        let pair = (pair.0.into(), pair.1.into());
149        self.inner.push(pair);
150        self
151    }
152
153    /// Converts the `RequestPairs` builder into a `HeaderValue`.
154    fn try_into_header_value(self) -> Result<HeaderValue, BoxError> {
155        self.try_into()
156    }
157}
158
159impl TryFrom<RequestPairs> for HeaderValue {
160    type Error = BoxError;
161
162    fn try_from(value: RequestPairs) -> Result<Self, BoxError> {
163        let mut pairs = String::new();
164        for (key, value) in value.inner {
165            if !pairs.is_empty() {
166                pairs.push_str("; ");
167            }
168
169            pairs.push_str(&key);
170            pairs.push('=');
171            pairs.push_str(&value);
172            continue;
173        }
174        HeaderValue::from_str(&pairs).map_err(Into::into)
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::{RequestInfoInterceptor, RequestPairs};
    use aws_smithy_runtime_api::client::interceptors::context::{Input, InterceptorContext};
    use aws_smithy_runtime_api::client::interceptors::Intercept;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::config_bag::{ConfigBag, Layer};
    use aws_smithy_types::retry::RetryConfig;
    use aws_smithy_types::timeout::TimeoutConfig;
    use http_02x::HeaderValue;
    use std::time::Duration;

    /// Fetches the value of `header_name` from the request held by `context`, panicking if
    /// the request or the header is missing.
    fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str {
        let request = context.request().expect("request is set");
        request.headers().get(header_name).unwrap()
    }

    #[test]
    fn test_request_pairs_for_initial_attempt() {
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();

        // Config: standard retries plus a 30 second read timeout.
        let timeout_config = TimeoutConfig::builder()
            .read_timeout(Duration::from_secs(30))
            .build();
        let mut layer = Layer::new("test");
        layer.store_put(RetryConfig::standard());
        layer.store_put(timeout_config);
        let mut config = ConfigBag::of_layers(vec![layer]);

        // Walk the context through its phases up to "before transmit".
        let mut context = InterceptorContext::new(Input::doesnt_matter());
        context.enter_serialization_phase();
        context.set_request(HttpRequest::empty());
        let _ = context.take_input();
        context.enter_before_transmit_phase();

        let mut ctx = (&mut context).into();
        RequestInfoInterceptor::new()
            .modify_before_transmit(&mut ctx, &rc, &mut config)
            .unwrap();

        let header = expect_header(&context, "amz-sdk-request");
        assert_eq!(header, "attempt=0; max=3");
    }

    #[test]
    fn test_header_value_from_request_pairs_supports_all_valid_characters() {
        // The list of valid characters is defined by an internal-only spec.
        let symbols = ("allowed-symbols", "!#$&'*+-.^_`|~");
        let digits = ("allowed-digits", "01234567890");
        let characters = (
            "allowed-characters",
            "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
        );
        let whitespace = ("allowed-whitespace", " \t");

        let rp = RequestPairs::new()
            .with_pair(symbols)
            .with_pair(digits)
            .with_pair(characters)
            .with_pair(whitespace);

        let _header_value = HeaderValue::try_from(rp)
            .expect("request pairs can be converted into valid header value.");
    }
}
248}