Skip to main content

mz_storage_operators/oneshot_source/
aws_source.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! AWS S3 [`OneshotSource`].
11
12use std::path::Path;
13use std::str::FromStr;
14use std::sync::Arc;
15
16use aws_sdk_s3::error::DisplayErrorContext;
17use derivative::Derivative;
18use futures::StreamExt;
19use futures::stream::{BoxStream, TryStreamExt};
20use mz_ore::future::InTask;
21use mz_repr::CatalogItemId;
22use mz_storage_types::connections::ConnectionContext;
23use mz_storage_types::connections::aws::AwsConnection;
24use serde::{Deserialize, Serialize};
25
26use crate::oneshot_source::util::IntoRangeHeaderValue;
27use crate::oneshot_source::{
28    OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
29};
30
31#[derive(Clone, Derivative)]
32#[derivative(Debug)]
33pub struct AwsS3Source {
34    // Only used for initialization.
35    #[derivative(Debug = "ignore")]
36    connection: Arc<AwsConnection>,
37    connection_id: CatalogItemId,
38    #[derivative(Debug = "ignore")]
39    context: Arc<ConnectionContext>,
40
41    /// Name of the S3 bucket we'll list from.
42    bucket: String,
43    /// Optional prefix that can be specified via an S3 URI.
44    prefix: Option<String>,
45    /// S3 client that is lazily initialized.
46    #[derivative(Debug = "ignore")]
47    client: std::sync::OnceLock<mz_aws_util::s3::Client>,
48    use_checksum: bool,
49    enforce_external_addresses: bool,
50}
51
52impl AwsS3Source {
53    pub fn new(
54        connection: AwsConnection,
55        connection_id: CatalogItemId,
56        context: ConnectionContext,
57        uri: String,
58        use_checksum: bool,
59        enforce_external_addresses: bool,
60    ) -> Self {
61        let uri = http::Uri::from_str(&uri).expect("validated URI in sequencing");
62
63        let bucket = uri
64            .host()
65            .expect("validated host in sequencing")
66            .to_string();
67        let prefix = if uri.path().is_empty() || uri.path() == "/" {
68            None
69        } else {
70            // The S3 client expects a trailing `/` but no leading `/`.
71            let mut prefix = uri.path().to_string();
72
73            if let Some(suffix) = prefix.strip_prefix('/') {
74                prefix = suffix.to_string();
75            }
76            if !prefix.ends_with('/') {
77                prefix = format!("{prefix}/");
78            }
79
80            Some(prefix)
81        };
82
83        AwsS3Source {
84            connection: Arc::new(connection),
85            context: Arc::new(context),
86            connection_id,
87            bucket,
88            prefix,
89            client: std::sync::OnceLock::new(),
90            use_checksum,
91            enforce_external_addresses,
92        }
93    }
94
95    pub async fn initialize(&self) -> Result<mz_aws_util::s3::Client, anyhow::Error> {
96        let sdk_config = self
97            .connection
98            .load_sdk_config(
99                &self.context,
100                self.connection_id,
101                InTask::Yes,
102                self.enforce_external_addresses,
103            )
104            .await?;
105
106        let mut s3_config_builder = aws_sdk_s3::config::Builder::from(&sdk_config)
107            .force_path_style(sdk_config.endpoint_url().is_some());
108
109        if !self.use_checksum {
110            s3_config_builder = s3_config_builder.response_checksum_validation(
111                aws_smithy_types::checksum_config::ResponseChecksumValidation::WhenRequired,
112            );
113        }
114        let s3_config = s3_config_builder.build();
115        let s3_client = aws_sdk_s3::Client::from_conf(s3_config);
116
117        Ok(s3_client)
118    }
119
120    pub async fn client(&self) -> Result<&mz_aws_util::s3::Client, anyhow::Error> {
121        if self.client.get().is_none() {
122            let client = self.initialize().await?;
123            let _ = self.client.set(client);
124        }
125
126        Ok(self.client.get().expect("just initialized"))
127    }
128}
129
130#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
131pub struct S3Object {
132    /// Key from S3 list operation.
133    key: String,
134    /// Name of the object, generally the last component of the key.
135    name: String,
136    /// Size of the object in bytes.
137    size: usize,
138}
139
140impl OneshotObject for S3Object {
141    fn name(&self) -> &str {
142        &self.name
143    }
144
145    fn path(&self) -> &str {
146        &self.key
147    }
148
149    fn size(&self) -> usize {
150        self.size
151    }
152
153    fn encodings(&self) -> &[super::Encoding] {
154        &[]
155    }
156}
157
158#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct S3Checksum {
160    e_tag: Option<String>,
161}
162
163impl OneshotSource for AwsS3Source {
164    type Object = S3Object;
165    type Checksum = S3Checksum;
166
167    async fn list<'a>(
168        &'a self,
169    ) -> Result<Vec<(Self::Object, Self::Checksum)>, super::StorageErrorX> {
170        let client = self.client().await.map_err(StorageErrorXKind::generic)?;
171        let mut objects_request = client.list_objects_v2().bucket(&self.bucket);
172
173        // Users can optionally specify a prefix via the S3 uri they originally specify.
174        if let Some(prefix) = &self.prefix {
175            objects_request = objects_request.prefix(prefix);
176        }
177
178        let objects = objects_request
179            .into_paginator()
180            .send()
181            .try_collect()
182            .await
183            .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))
184            .context("list_objects_v2")?;
185
186        let objects: Vec<_> = objects
187            .iter()
188            .flat_map(aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Output::contents)
189            .map(|o| {
190                let key = o
191                    .key()
192                    .ok_or_else(|| StorageErrorXKind::MissingField("key".into()))?
193                    .to_owned();
194                let name = Path::new(&key)
195                    .file_name()
196                    .and_then(|os_name| os_name.to_str())
197                    .ok_or_else(|| StorageErrorXKind::Generic(format!("malformed key: {key}")))?
198                    .to_string();
199                let size = o
200                    .size()
201                    .ok_or_else(|| StorageErrorXKind::MissingField("size".into()))?;
202                let size: usize = size.try_into().map_err(StorageErrorXKind::generic)?;
203
204                let object = S3Object { key, name, size };
205                let checksum = S3Checksum {
206                    e_tag: o.e_tag().map(|x| x.to_owned()),
207                };
208
209                Ok::<_, StorageErrorXKind>((object, checksum))
210            })
211            .collect::<Result<_, _>>()
212            .context("list")?;
213
214        Ok(objects)
215    }
216
217    fn get<'s>(
218        &'s self,
219        object: Self::Object,
220        _checksum: Self::Checksum,
221        range: Option<std::ops::RangeInclusive<usize>>,
222    ) -> BoxStream<'s, Result<bytes::Bytes, StorageErrorX>> {
223        let initial_response = async move {
224            tracing::info!(name = %object.name(), ?range, "fetching object");
225
226            let client = self.client().await.map_err(StorageErrorXKind::generic)?;
227
228            let mut request = client.get_object().bucket(&self.bucket).key(&object.key);
229            if let Some(range) = range {
230                let value = range.into_range_header_value();
231                request = request.range(value);
232            }
233            let object = request
234                .send()
235                .await
236                .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))?;
237            // AWS's ByteStream doesn't implement the Stream trait.
238            let stream = mz_aws_util::s3::ByteStreamAdapter::new(object.body)
239                .err_into()
240                .boxed();
241
242            Ok::<_, StorageErrorX>(stream)
243        };
244
245        futures::stream::once(initial_response)
246            .try_flatten()
247            .boxed()
248    }
249}