Skip to main content

mz_storage_operators/oneshot_source/
aws_source.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! AWS S3 [`OneshotSource`].
11
12use std::path::Path;
13use std::str::FromStr;
14use std::sync::Arc;
15
16use aws_sdk_s3::error::DisplayErrorContext;
17use derivative::Derivative;
18use futures::StreamExt;
19use futures::stream::{BoxStream, TryStreamExt};
20use mz_ore::future::InTask;
21use mz_repr::CatalogItemId;
22use mz_storage_types::connections::ConnectionContext;
23use mz_storage_types::connections::aws::AwsConnection;
24use serde::{Deserialize, Serialize};
25
26use crate::oneshot_source::util::IntoRangeHeaderValue;
27use crate::oneshot_source::{
28    OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
29};
30
31#[derive(Clone, Derivative)]
32#[derivative(Debug)]
33pub struct AwsS3Source {
34    // Only used for initialization.
35    #[derivative(Debug = "ignore")]
36    connection: Arc<AwsConnection>,
37    connection_id: CatalogItemId,
38    #[derivative(Debug = "ignore")]
39    context: Arc<ConnectionContext>,
40
41    /// Name of the S3 bucket we'll list from.
42    bucket: String,
43    /// Optional prefix that can be specified via an S3 URI.
44    prefix: Option<String>,
45    /// S3 client that is lazily initialized.
46    #[derivative(Debug = "ignore")]
47    client: std::sync::OnceLock<mz_aws_util::s3::Client>,
48    use_checksum: bool,
49}
50
51impl AwsS3Source {
52    pub fn new(
53        connection: AwsConnection,
54        connection_id: CatalogItemId,
55        context: ConnectionContext,
56        uri: String,
57        use_checksum: bool,
58    ) -> Self {
59        let uri = http::Uri::from_str(&uri).expect("validated URI in sequencing");
60
61        let bucket = uri
62            .host()
63            .expect("validated host in sequencing")
64            .to_string();
65        let prefix = if uri.path().is_empty() || uri.path() == "/" {
66            None
67        } else {
68            // The S3 client expects a trailing `/` but no leading `/`.
69            let mut prefix = uri.path().to_string();
70
71            if let Some(suffix) = prefix.strip_prefix('/') {
72                prefix = suffix.to_string();
73            }
74            if !prefix.ends_with('/') {
75                prefix = format!("{prefix}/");
76            }
77
78            Some(prefix)
79        };
80
81        AwsS3Source {
82            connection: Arc::new(connection),
83            context: Arc::new(context),
84            connection_id,
85            bucket,
86            prefix,
87            client: std::sync::OnceLock::new(),
88            use_checksum,
89        }
90    }
91
92    pub async fn initialize(&self) -> Result<mz_aws_util::s3::Client, anyhow::Error> {
93        let sdk_config = self
94            .connection
95            .load_sdk_config(&self.context, self.connection_id, InTask::Yes)
96            .await?;
97
98        let mut s3_config_builder = aws_sdk_s3::config::Builder::from(&sdk_config)
99            .force_path_style(sdk_config.endpoint_url().is_some());
100
101        if !self.use_checksum {
102            s3_config_builder = s3_config_builder.response_checksum_validation(
103                aws_smithy_types::checksum_config::ResponseChecksumValidation::WhenRequired,
104            );
105        }
106        let s3_config = s3_config_builder.build();
107        let s3_client = aws_sdk_s3::Client::from_conf(s3_config);
108
109        Ok(s3_client)
110    }
111
112    pub async fn client(&self) -> Result<&mz_aws_util::s3::Client, anyhow::Error> {
113        if self.client.get().is_none() {
114            let client = self.initialize().await?;
115            let _ = self.client.set(client);
116        }
117
118        Ok(self.client.get().expect("just initialized"))
119    }
120}
121
122#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
123pub struct S3Object {
124    /// Key from S3 list operation.
125    key: String,
126    /// Name of the object, generally the last component of the key.
127    name: String,
128    /// Size of the object in bytes.
129    size: usize,
130}
131
132impl OneshotObject for S3Object {
133    fn name(&self) -> &str {
134        &self.name
135    }
136
137    fn path(&self) -> &str {
138        &self.key
139    }
140
141    fn size(&self) -> usize {
142        self.size
143    }
144
145    fn encodings(&self) -> &[super::Encoding] {
146        &[]
147    }
148}
149
150#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
151pub struct S3Checksum {
152    e_tag: Option<String>,
153}
154
155impl OneshotSource for AwsS3Source {
156    type Object = S3Object;
157    type Checksum = S3Checksum;
158
159    async fn list<'a>(
160        &'a self,
161    ) -> Result<Vec<(Self::Object, Self::Checksum)>, super::StorageErrorX> {
162        let client = self.client().await.map_err(StorageErrorXKind::generic)?;
163        let mut objects_request = client.list_objects_v2().bucket(&self.bucket);
164
165        // Users can optionally specify a prefix via the S3 uri they originally specify.
166        if let Some(prefix) = &self.prefix {
167            objects_request = objects_request.prefix(prefix);
168        }
169
170        let objects = objects_request
171            .into_paginator()
172            .send()
173            .try_collect()
174            .await
175            .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))
176            .context("list_objects_v2")?;
177
178        let objects: Vec<_> = objects
179            .iter()
180            .flat_map(aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Output::contents)
181            .map(|o| {
182                let key = o
183                    .key()
184                    .ok_or_else(|| StorageErrorXKind::MissingField("key".into()))?
185                    .to_owned();
186                let name = Path::new(&key)
187                    .file_name()
188                    .and_then(|os_name| os_name.to_str())
189                    .ok_or_else(|| StorageErrorXKind::Generic(format!("malformed key: {key}")))?
190                    .to_string();
191                let size = o
192                    .size()
193                    .ok_or_else(|| StorageErrorXKind::MissingField("size".into()))?;
194                let size: usize = size.try_into().map_err(StorageErrorXKind::generic)?;
195
196                let object = S3Object { key, name, size };
197                let checksum = S3Checksum {
198                    e_tag: o.e_tag().map(|x| x.to_owned()),
199                };
200
201                Ok::<_, StorageErrorXKind>((object, checksum))
202            })
203            .collect::<Result<_, _>>()
204            .context("list")?;
205
206        Ok(objects)
207    }
208
209    fn get<'s>(
210        &'s self,
211        object: Self::Object,
212        _checksum: Self::Checksum,
213        range: Option<std::ops::RangeInclusive<usize>>,
214    ) -> BoxStream<'s, Result<bytes::Bytes, StorageErrorX>> {
215        let initial_response = async move {
216            tracing::info!(name = %object.name(), ?range, "fetching object");
217
218            let client = self.client().await.map_err(StorageErrorXKind::generic)?;
219
220            let mut request = client.get_object().bucket(&self.bucket).key(&object.key);
221            if let Some(range) = range {
222                let value = range.into_range_header_value();
223                request = request.range(value);
224            }
225            let object = request
226                .send()
227                .await
228                .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))?;
229            // AWS's ByteStream doesn't implement the Stream trait.
230            let stream = mz_aws_util::s3::ByteStreamAdapter::new(object.body)
231                .err_into()
232                .boxed();
233
234            Ok::<_, StorageErrorX>(stream)
235        };
236
237        futures::stream::once(initial_response)
238            .try_flatten()
239            .boxed()
240    }
241}