mz_storage_operators/oneshot_source/
http_source.rs1use std::net::SocketAddr;
13use std::sync::Arc;
14
15use bytes::Bytes;
16use derivative::Derivative;
17use futures::TryStreamExt;
18use futures::stream::{BoxStream, StreamExt};
19use reqwest::Client;
20use reqwest::dns::{Addrs, Name, Resolve, Resolving};
21use serde::{Deserialize, Serialize};
22use url::Url;
23
24use crate::oneshot_source::util::IntoRangeHeaderValue;
25use crate::oneshot_source::{
26 Encoding, OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
27};
28
/// DNS resolver plugged into the `reqwest` client built by this module.
///
/// Name resolution is delegated to `mz_ore::netio::resolve_address`; this
/// struct only carries the flag that is forwarded to it.
#[derive(Debug)]
struct MzHttpResolver {
    /// Forwarded verbatim to `mz_ore::netio::resolve_address`. The tests in
    /// this file show that, when `true`, resolving `localhost` fails with a
    /// "private" address error.
    enforce_external_addresses: bool,
}
38
39impl Resolve for MzHttpResolver {
40 fn resolve(&self, name: Name) -> Resolving {
41 let enforce = self.enforce_external_addresses;
42 Box::pin(async move {
43 let ips = mz_ore::netio::resolve_address(name.as_str(), enforce).await?;
44 let addrs: Addrs = Box::new(
47 ips.into_iter()
48 .map(|ip| SocketAddr::new(ip, 0))
49 .collect::<Vec<_>>()
50 .into_iter(),
51 );
52 Ok(addrs)
53 })
54 }
55}
56
57pub fn build_http_client(enforce_external_addresses: bool) -> Result<Client, reqwest::Error> {
65 Client::builder()
66 .dns_resolver(Arc::new(MzHttpResolver {
67 enforce_external_addresses,
68 }))
69 .redirect(reqwest::redirect::Policy::none())
70 .build()
71}
72
73fn check_not_redirect(response: &reqwest::Response) -> Result<(), StorageErrorX> {
78 if response.status().is_redirection() {
79 return Err(StorageErrorXKind::Redirect(response.status().as_u16()).into());
80 }
81 Ok(())
82}
83
84#[derive(Clone, Derivative)]
86#[derivative(Debug)]
87pub struct HttpOneshotSource {
88 #[derivative(Debug = "ignore")]
89 client: Client,
90 origin: Url,
91}
92
93impl HttpOneshotSource {
94 pub fn new(client: Client, origin: Url) -> Self {
95 HttpOneshotSource { client, origin }
96 }
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct HttpObject {
102 url: Url,
104 filename: String,
106 size: usize,
108 content_encoding: Vec<Encoding>,
110}
111
112impl OneshotObject for HttpObject {
113 fn name(&self) -> &str {
114 &self.filename
115 }
116
117 fn path(&self) -> &str {
118 &self.filename
119 }
120
121 fn size(&self) -> usize {
122 self.size
123 }
124
125 fn encodings(&self) -> &[Encoding] {
126 &self.content_encoding
127 }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub enum HttpChecksum {
132 None,
134 ETag(String),
136 LastModified(String),
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` if `err`, or any error reachable through its `source()`
    /// chain, has a message containing "private" (case-insensitive).
    fn chain_mentions_private(err: &(dyn std::error::Error + 'static)) -> bool {
        std::iter::successors(Some(err), |cause| cause.source())
            .any(|cause| cause.to_string().to_lowercase().contains("private"))
    }

    /// With enforcement on, a request to `localhost` must fail with an error
    /// whose chain mentions a "private" address.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn build_http_client_rejects_localhost_when_enforced() {
        let client = build_http_client(true).expect("build client");
        let err = client
            .get("http://localhost:1/")
            .send()
            .await
            .expect_err("request must fail at DNS resolution");

        assert!(
            chain_mentions_private(&err),
            "expected private-address rejection, got: {err:?}"
        );
    }

    /// With enforcement off, a request to `localhost:1` must still fail, but
    /// nothing in the error chain may mention a "private" address.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn build_http_client_allows_localhost_when_not_enforced() {
        let client = build_http_client(false).expect("build client");
        let err = client
            .get("http://localhost:1/")
            .send()
            .await
            .expect_err("port 1 should not be listening");

        assert!(
            !chain_mentions_private(&err),
            "expected a connect error, not a DNS rejection: {err:?}"
        );
    }
}
200
201impl OneshotSource for HttpOneshotSource {
202 type Object = HttpObject;
203 type Checksum = HttpChecksum;
204
205 async fn list<'a>(&'a self) -> Result<Vec<(Self::Object, Self::Checksum)>, StorageErrorX> {
206 let response = self
211 .client
212 .head(self.origin.clone())
213 .send()
214 .await
215 .context("HEAD request")?;
216
217 check_not_redirect(&response)?;
218
219 let headers = match response.error_for_status() {
222 Ok(response) => response.headers().clone(),
223 Err(err) => {
224 tracing::warn!(status = ?err.status(), "HEAD request failed");
225
226 let response = self
227 .client
228 .get(self.origin.clone())
229 .send()
230 .await
231 .context("GET request")?;
232
233 check_not_redirect(&response)?;
234
235 let headers = response.headers().clone();
236
237 drop(response);
239
240 headers
241 }
242 };
243
244 let get_header = |name: &reqwest::header::HeaderName| {
245 let header = headers.get(name)?;
246 match header.to_str() {
247 Err(e) => {
248 tracing::warn!("failed to deserialize header '{name}', err: {e}");
249 None
250 }
251 Ok(value) => Some(value),
252 }
253 };
254
255 let checksum = if let Some(etag) = get_header(&reqwest::header::ETAG) {
257 HttpChecksum::ETag(etag.to_string())
258 } else if let Some(last_modified) = get_header(&reqwest::header::LAST_MODIFIED) {
259 let last_modified = last_modified.to_string();
260 HttpChecksum::LastModified(last_modified.to_string())
261 } else {
262 HttpChecksum::None
263 };
264
265 let size = get_header(&reqwest::header::CONTENT_LENGTH)
267 .ok_or(StorageErrorXKind::MissingSize)
268 .and_then(|s| s.parse::<usize>().map_err(StorageErrorXKind::generic))
269 .context("content-length header")?;
270
271 let filename = self
274 .origin
275 .path_segments()
276 .and_then(|segments| segments.rev().next())
277 .map(|s| s.to_string())
278 .unwrap_or_default();
279 let object = HttpObject {
280 url: self.origin.clone(),
281 filename,
282 size,
283 content_encoding: vec![],
284 };
285 tracing::info!(?object, "found objects");
286
287 Ok(vec![(object, checksum)])
288 }
289
290 fn get<'s>(
291 &'s self,
292 object: Self::Object,
293 _checksum: Self::Checksum,
294 range: Option<std::ops::RangeInclusive<usize>>,
295 ) -> BoxStream<'s, Result<Bytes, StorageErrorX>> {
296 let initial_response = async move {
299 let mut request = self.client.get(object.url);
300
301 if let Some(range) = &range {
302 let value = range.into_range_header_value();
303 request = request.header(&reqwest::header::RANGE, value);
304 }
305
306 let response = request.send().await.context("get")?;
311 check_not_redirect(&response)?;
312 let bytes_stream = response.bytes_stream().err_into();
313
314 Ok::<_, StorageErrorX>(bytes_stream)
315 };
316
317 futures::stream::once(initial_response)
318 .try_flatten()
319 .boxed()
320 }
321}