1use std::future::Future;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use aws_config::{BehaviorVersion, ConfigLoader};
16use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
17use aws_smithy_runtime_api::client::http::{HttpClient, SharedHttpClient};
18use hyper_0_14::client::HttpConnector;
19use hyper_0_14::client::connect::dns::Name;
20use hyper_tls::HttpsConnector;
21use tower_service::Service;
22
23#[cfg(feature = "s3")]
24pub mod s3;
25#[cfg(feature = "s3")]
26pub mod s3_uploader;
27
28pub fn defaults() -> ConfigLoader {
31 let behavior_version = BehaviorVersion::latest();
37
38 #[allow(clippy::disallowed_methods)]
40 let loader = aws_config::defaults(behavior_version);
41
42 let loader = loader.http_client(http_client());
44
45 loader
46}
47
48pub fn http_client() -> impl HttpClient {
51 HyperClientBuilder::new().build(HttpsConnector::new())
54}
55
56pub fn http_client_with_resolver(enforce_external_addresses: bool) -> SharedHttpClient {
63 let resolver = MzAwsResolver {
64 enforce_external_addresses,
65 };
66 let mut http = HttpConnector::new_with_resolver(resolver);
67 http.enforce_http(false);
71 let https = HttpsConnector::new_with_connector(http);
72 HyperClientBuilder::new().build(https)
73}
74
/// DNS resolver for hyper's `HttpConnector` that delegates lookups to
/// `mz_ore::netio::resolve_address`.
///
/// When `enforce_external_addresses` is `true`, lookups that yield addresses
/// `resolve_address` considers private (e.g. loopback) fail. In that case the
/// error is `DnsResolutionError::PrivateAddress`.
#[derive(Clone, Debug)]
struct MzAwsResolver {
    /// Whether to reject hosts that resolve to non-external addresses.
    enforce_external_addresses: bool,
}
81
82impl Service<Name> for MzAwsResolver {
83 type Response = std::vec::IntoIter<SocketAddr>;
84 type Error = mz_ore::netio::DnsResolutionError;
85 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
86
87 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88 Poll::Ready(Ok(()))
89 }
90
91 fn call(&mut self, name: Name) -> Self::Future {
92 let enforce = self.enforce_external_addresses;
93 let host = name.as_str().to_string();
94 Box::pin(async move {
95 let ips = mz_ore::netio::resolve_address(&host, enforce).await?;
96 Ok(ips
99 .into_iter()
100 .map(|ip| SocketAddr::new(ip, 0))
101 .collect::<Vec<_>>()
102 .into_iter())
103 })
104 }
105}
106
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::str::FromStr;

    use mz_ore::netio::DnsResolutionError;

    use super::*;

    /// Runs a single `MzAwsResolver` lookup of `host` and collects the
    /// resulting socket addresses.
    async fn resolve(
        enforce_external_addresses: bool,
        host: &str,
    ) -> Result<Vec<SocketAddr>, DnsResolutionError> {
        let mut resolver = MzAwsResolver {
            enforce_external_addresses,
        };
        let name = Name::from_str(host).unwrap();
        resolver.call(name).await.map(|addrs| addrs.collect())
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn resolver_rejects_loopback_when_enforced() {
        let err = resolve(true, "127.0.0.1")
            .await
            .expect_err("must reject loopback");
        assert!(
            matches!(err, DnsResolutionError::PrivateAddress),
            "got {err:?}"
        );
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn resolver_allows_loopback_when_not_enforced() {
        let addrs = resolve(false, "127.0.0.1")
            .await
            .expect("loopback should resolve when enforcement is off");
        let loopback = IpAddr::from([127, 0, 0, 1]);
        assert!(addrs.iter().any(|a| a.ip() == loopback));
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn resolver_allows_public_when_enforced() {
        let addrs = resolve(true, "8.8.8.8")
            .await
            .expect("public IP should resolve");
        let public = IpAddr::from([8, 8, 8, 8]);
        assert!(addrs.iter().any(|a| a.ip() == public));
    }
}