Skip to main content

mz_aws_util/
lib.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::future::Future;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use aws_config::{BehaviorVersion, ConfigLoader};
16use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
17use aws_smithy_runtime_api::client::http::{HttpClient, SharedHttpClient};
18use hyper_0_14::client::HttpConnector;
19use hyper_0_14::client::connect::dns::Name;
20use hyper_tls::HttpsConnector;
21use tower_service::Service;
22
23#[cfg(feature = "s3")]
24pub mod s3;
25#[cfg(feature = "s3")]
26pub mod s3_uploader;
27
28/// Creates an AWS SDK configuration loader with the defaults for the latest
29/// behavior version plus some Materialize-specific overrides.
30pub fn defaults() -> ConfigLoader {
31    // Use the SDK's latest behavior version. We already pin the crate versions,
32    // and CI puts version upgrades through rigorous testing, so we're happy to
33    // take the latest behavior version. We can adjust this in the future as
34    // necessary, if the AWS SDK ships a new behavior version that causes
35    // trouble.
36    let behavior_version = BehaviorVersion::latest();
37
38    // This is the only method allowed to call `aws_config::defaults`.
39    #[allow(clippy::disallowed_methods)]
40    let loader = aws_config::defaults(behavior_version);
41
42    // Install our custom HTTP client.
43    let loader = loader.http_client(http_client());
44
45    loader
46}
47
48/// Returns an HTTP client for use with the AWS SDK that is appropriately
49/// configured for Materialize.
50pub fn http_client() -> impl HttpClient {
51    // The default AWS HTTP client uses rustls, while our company policy is to
52    // use native TLS.
53    HyperClientBuilder::new().build(HttpsConnector::new())
54}
55
56/// Returns an AWS SDK HTTP client whose DNS resolver delegates to
57/// [`mz_ore::netio::resolve_address`].
58///
59/// Only the IP resolution step is overridden — the SDK still uses the original
60/// hostname for SNI and TLS certificate validation, so HTTPS endpoints work
61/// unchanged.
62pub fn http_client_with_resolver(enforce_external_addresses: bool) -> SharedHttpClient {
63    let resolver = MzAwsResolver {
64        enforce_external_addresses,
65    };
66    let mut http = HttpConnector::new_with_resolver(resolver);
67    // The SDK speaks HTTPS to the public AWS API; the wrapper we build below
68    // handles TLS, but the underlying HTTP connector must allow non-`http://`
69    // schemes.
70    http.enforce_http(false);
71    let https = HttpsConnector::new_with_connector(http);
72    HyperClientBuilder::new().build(https)
73}
74
75/// A `tower_service::Service<Name>` resolver that delegates to
76/// [`mz_ore::netio::resolve_address`], used by [`http_client_with_resolver`].
77#[derive(Clone)]
78struct MzAwsResolver {
79    enforce_external_addresses: bool,
80}
81
82impl Service<Name> for MzAwsResolver {
83    type Response = std::vec::IntoIter<SocketAddr>;
84    type Error = mz_ore::netio::DnsResolutionError;
85    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
86
87    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88        Poll::Ready(Ok(()))
89    }
90
91    fn call(&mut self, name: Name) -> Self::Future {
92        let enforce = self.enforce_external_addresses;
93        let host = name.as_str().to_string();
94        Box::pin(async move {
95            let ips = mz_ore::netio::resolve_address(&host, enforce).await?;
96            // Hyper substitutes the URL's port (or the default for the scheme)
97            // when the SocketAddr's port is 0.
98            Ok(ips
99                .into_iter()
100                .map(|ip| SocketAddr::new(ip, 0))
101                .collect::<Vec<_>>()
102                .into_iter())
103        })
104    }
105}
106
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::str::FromStr;

    use mz_ore::netio::DnsResolutionError;

    use super::*;

    /// Performs one lookup of `host` through a freshly built
    /// [`MzAwsResolver`] and collects the resulting socket addresses.
    async fn lookup(
        enforce_external_addresses: bool,
        host: &str,
    ) -> Result<Vec<SocketAddr>, DnsResolutionError> {
        let mut resolver = MzAwsResolver {
            enforce_external_addresses,
        };
        let name = Name::from_str(host).unwrap();
        resolver.call(name).await.map(|addrs| addrs.collect())
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn resolver_rejects_loopback_when_enforced() {
        let err = lookup(true, "127.0.0.1")
            .await
            .expect_err("must reject loopback");
        assert!(
            matches!(err, DnsResolutionError::PrivateAddress),
            "got {err:?}"
        );
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn resolver_allows_loopback_when_not_enforced() {
        let addrs = lookup(false, "127.0.0.1")
            .await
            .expect("loopback should resolve when enforcement is off");
        let loopback = IpAddr::from([127, 0, 0, 1]);
        assert!(addrs.iter().any(|addr| addr.ip() == loopback));
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn resolver_allows_public_when_enforced() {
        let addrs = lookup(true, "8.8.8.8")
            .await
            .expect("public IP should resolve");
        let public = IpAddr::from([8, 8, 8, 8]);
        assert!(addrs.iter().any(|addr| addr.ip() == public));
    }
}