Skip to main content

mz_storage_operators/oneshot_source/
http_source.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Generic HTTP oneshot source that will fetch a file from the public internet.
11
12use std::net::SocketAddr;
13use std::sync::Arc;
14
15use bytes::Bytes;
16use derivative::Derivative;
17use futures::TryStreamExt;
18use futures::stream::{BoxStream, StreamExt};
19use reqwest::Client;
20use reqwest::dns::{Addrs, Name, Resolve, Resolving};
21use serde::{Deserialize, Serialize};
22use url::Url;
23
24use crate::oneshot_source::util::IntoRangeHeaderValue;
25use crate::oneshot_source::{
26    Encoding, OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
27};
28
29/// reqwest DNS resolver that delegates to [`mz_ore::netio::resolve_address`].
30///
31/// Only the IP resolution step is overridden — reqwest still uses the URL's
32/// original hostname for SNI and TLS certificate validation, so HTTPS works
33/// normally.
34#[derive(Debug)]
35struct MzHttpResolver {
36    enforce_external_addresses: bool,
37}
38
39impl Resolve for MzHttpResolver {
40    fn resolve(&self, name: Name) -> Resolving {
41        let enforce = self.enforce_external_addresses;
42        Box::pin(async move {
43            let ips = mz_ore::netio::resolve_address(name.as_str(), enforce).await?;
44            // reqwest substitutes the conventional port (80/443) when the
45            // SocketAddr's port is 0 and no explicit port was given in the URL.
46            let addrs: Addrs = Box::new(
47                ips.into_iter()
48                    .map(|ip| SocketAddr::new(ip, 0))
49                    .collect::<Vec<_>>()
50                    .into_iter(),
51            );
52            Ok(addrs)
53        })
54    }
55}
56
57/// Build a reqwest [`Client`] for fetching `COPY FROM` URLs. This uses
58/// [`mz_ore::netio::resolve_address`] for DNS resolution.
59///
60/// Redirects are disabled: the custom DNS resolver re-validates hostnames on
61/// every hop, but reqwest skips DNS for IP-literal targets, so a redirect to
62/// `http://127.0.0.1/` would bypass the SSRF check. Refusing to follow
63/// redirects closes that hole.
64pub fn build_http_client(enforce_external_addresses: bool) -> Result<Client, reqwest::Error> {
65    Client::builder()
66        .dns_resolver(Arc::new(MzHttpResolver {
67            enforce_external_addresses,
68        }))
69        .redirect(reqwest::redirect::Policy::none())
70        .build()
71}
72
73/// Returns an error if `response` is a 3xx redirect. Materialize disables
74/// redirect following on the HTTP client (see `build_http_client`) to close
75/// an SSRF hole, so callers must surface a meaningful error rather than
76/// letting the response fall through to header parsing.
77fn check_not_redirect(response: &reqwest::Response) -> Result<(), StorageErrorX> {
78    if response.status().is_redirection() {
79        return Err(StorageErrorXKind::Redirect(response.status().as_u16()).into());
80    }
81    Ok(())
82}
83
84/// Generic oneshot source that fetches a file from a URL on the public internet.
85#[derive(Clone, Derivative)]
86#[derivative(Debug)]
87pub struct HttpOneshotSource {
88    #[derivative(Debug = "ignore")]
89    client: Client,
90    origin: Url,
91}
92
93impl HttpOneshotSource {
94    pub fn new(client: Client, origin: Url) -> Self {
95        HttpOneshotSource { client, origin }
96    }
97}
98
99/// Object returned from an [`HttpOneshotSource`].
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct HttpObject {
102    /// [`Url`] to access the file.
103    url: Url,
104    /// Name of the file.
105    filename: String,
106    /// Size of this file reported by the [`Content-Length`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Length) header
107    size: usize,
108    /// Any values reporting from the [`Content-Encoding`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding) header.
109    content_encoding: Vec<Encoding>,
110}
111
112impl OneshotObject for HttpObject {
113    fn name(&self) -> &str {
114        &self.filename
115    }
116
117    fn path(&self) -> &str {
118        &self.filename
119    }
120
121    fn size(&self) -> usize {
122        self.size
123    }
124
125    fn encodings(&self) -> &[Encoding] {
126        &self.content_encoding
127    }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub enum HttpChecksum {
132    /// No checksumming is requested.
133    None,
134    /// The HTTP [`ETag`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag) header.
135    ETag(String),
136    /// The HTTP [`Last-Modified`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified) header.
137    LastModified(String),
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether `err`, or any error reachable through its `source()`
    /// chain, has a message containing "private" (case-insensitive).
    ///
    /// reqwest wraps the resolver's rejection inside its own connect error,
    /// so looking only at the top-level message is not enough.
    fn chain_mentions_private(err: &(dyn std::error::Error + 'static)) -> bool {
        let mut link = Some(err);
        while let Some(cause) = link {
            if cause.to_string().to_lowercase().contains("private") {
                return true;
            }
            link = cause.source();
        }
        false
    }

    /// `reqwest::dns::Name` has no public constructor, so [`MzHttpResolver`]
    /// is exercised through a fully-built [`Client`]. `localhost` resolves via
    /// /etc/hosts on supported platforms and stays inside the resolver path
    /// (an IP literal would short-circuit DNS entirely).
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn build_http_client_rejects_localhost_when_enforced() {
        let client = build_http_client(true).expect("build client");
        let err = client
            .get("http://localhost:1/")
            .send()
            .await
            .expect_err("request must fail at DNS resolution");
        assert!(
            chain_mentions_private(&err),
            "expected private-address rejection, got: {err:?}"
        );
    }

    /// With enforcement off the resolver hands back the loopback IP, so the
    /// request gets as far as connecting and fails for a different reason
    /// (port 1 is not listening). What matters is that DNS does *not* fail.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn build_http_client_allows_localhost_when_not_enforced() {
        let client = build_http_client(false).expect("build client");
        let err = client
            .get("http://localhost:1/")
            .send()
            .await
            .expect_err("port 1 should not be listening");
        assert!(
            !chain_mentions_private(&err),
            "expected a connect error, not a DNS rejection: {err:?}"
        );
    }
}
200
201impl OneshotSource for HttpOneshotSource {
202    type Object = HttpObject;
203    type Checksum = HttpChecksum;
204
205    async fn list<'a>(&'a self) -> Result<Vec<(Self::Object, Self::Checksum)>, StorageErrorX> {
206        // TODO(cf3): Support listing files from a directory index.
207
208        // To get metadata about a file we'll first try issuing a `HEAD` request, which
209        // canonically is the right thing do.
210        let response = self
211            .client
212            .head(self.origin.clone())
213            .send()
214            .await
215            .context("HEAD request")?;
216
217        check_not_redirect(&response)?;
218
219        // Not all servers accept `HEAD` requests though, so we'll fallback to a `GET`
220        // request and skip fetching the body.
221        let headers = match response.error_for_status() {
222            Ok(response) => response.headers().clone(),
223            Err(err) => {
224                tracing::warn!(status = ?err.status(), "HEAD request failed");
225
226                let response = self
227                    .client
228                    .get(self.origin.clone())
229                    .send()
230                    .await
231                    .context("GET request")?;
232
233                check_not_redirect(&response)?;
234
235                let headers = response.headers().clone();
236
237                // Immediately drop the response so we don't attempt to fetch the body.
238                drop(response);
239
240                headers
241            }
242        };
243
244        let get_header = |name: &reqwest::header::HeaderName| {
245            let header = headers.get(name)?;
246            match header.to_str() {
247                Err(e) => {
248                    tracing::warn!("failed to deserialize header '{name}', err: {e}");
249                    None
250                }
251                Ok(value) => Some(value),
252            }
253        };
254
255        // Get a checksum from the content.
256        let checksum = if let Some(etag) = get_header(&reqwest::header::ETAG) {
257            HttpChecksum::ETag(etag.to_string())
258        } else if let Some(last_modified) = get_header(&reqwest::header::LAST_MODIFIED) {
259            let last_modified = last_modified.to_string();
260            HttpChecksum::LastModified(last_modified.to_string())
261        } else {
262            HttpChecksum::None
263        };
264
265        // Get the size of the object from the Conent-Length header.
266        let size = get_header(&reqwest::header::CONTENT_LENGTH)
267            .ok_or(StorageErrorXKind::MissingSize)
268            .and_then(|s| s.parse::<usize>().map_err(StorageErrorXKind::generic))
269            .context("content-length header")?;
270
271        // TODO(cf1): We should probably check the content-type as well. At least for advisory purposes.
272
273        let filename = self
274            .origin
275            .path_segments()
276            .and_then(|segments| segments.rev().next())
277            .map(|s| s.to_string())
278            .unwrap_or_default();
279        let object = HttpObject {
280            url: self.origin.clone(),
281            filename,
282            size,
283            content_encoding: vec![],
284        };
285        tracing::info!(?object, "found objects");
286
287        Ok(vec![(object, checksum)])
288    }
289
290    fn get<'s>(
291        &'s self,
292        object: Self::Object,
293        _checksum: Self::Checksum,
294        range: Option<std::ops::RangeInclusive<usize>>,
295    ) -> BoxStream<'s, Result<Bytes, StorageErrorX>> {
296        // TODO(cf1): Validate our checksum.
297
298        let initial_response = async move {
299            let mut request = self.client.get(object.url);
300
301            if let Some(range) = &range {
302                let value = range.into_range_header_value();
303                request = request.header(&reqwest::header::RANGE, value);
304            }
305
306            // TODO(parkmycar): We should probably assert that the response contains
307            // an appropriate Content-Range header in the response, and maybe that we
308            // got back an HTTP 206?
309
310            let response = request.send().await.context("get")?;
311            check_not_redirect(&response)?;
312            let bytes_stream = response.bytes_stream().err_into();
313
314            Ok::<_, StorageErrorX>(bytes_stream)
315        };
316
317        futures::stream::once(initial_response)
318            .try_flatten()
319            .boxed()
320    }
321}