mz_storage_operators/oneshot_source/
aws_source.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! AWS S3 [`OneshotSource`].
11
12use std::path::Path;
13use std::str::FromStr;
14use std::sync::Arc;
15
16use derivative::Derivative;
17use futures::StreamExt;
18use futures::stream::{BoxStream, TryStreamExt};
19use mz_ore::future::InTask;
20use mz_repr::CatalogItemId;
21use mz_storage_types::connections::ConnectionContext;
22use mz_storage_types::connections::aws::AwsConnection;
23use serde::{Deserialize, Serialize};
24
25use crate::oneshot_source::util::IntoRangeHeaderValue;
26use crate::oneshot_source::{
27    OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
28};
29
30#[derive(Clone, Derivative)]
31#[derivative(Debug)]
32pub struct AwsS3Source {
33    // Only used for initialization.
34    #[derivative(Debug = "ignore")]
35    connection: Arc<AwsConnection>,
36    connection_id: CatalogItemId,
37    #[derivative(Debug = "ignore")]
38    context: Arc<ConnectionContext>,
39
40    /// Name of the S3 bucket we'll list from.
41    bucket: String,
42    /// Optional prefix that can be specified via an S3 URI.
43    prefix: Option<String>,
44    /// S3 client that is lazily initialized.
45    #[derivative(Debug = "ignore")]
46    client: std::sync::OnceLock<mz_aws_util::s3::Client>,
47}
48
49impl AwsS3Source {
50    pub fn new(
51        connection: AwsConnection,
52        connection_id: CatalogItemId,
53        context: ConnectionContext,
54        uri: String,
55    ) -> Self {
56        let uri = http::Uri::from_str(&uri).expect("validated URI in sequencing");
57
58        let bucket = uri
59            .host()
60            .expect("validated host in sequencing")
61            .to_string();
62        let prefix = if uri.path().is_empty() || uri.path() == "/" {
63            None
64        } else {
65            // The S3 client expects a trailing `/` but no leading `/`.
66            let mut prefix = uri.path().to_string();
67
68            if let Some(suffix) = prefix.strip_prefix('/') {
69                prefix = suffix.to_string();
70            }
71            if !prefix.ends_with('/') {
72                prefix = format!("{prefix}/");
73            }
74
75            Some(prefix)
76        };
77
78        AwsS3Source {
79            connection: Arc::new(connection),
80            context: Arc::new(context),
81            connection_id,
82            bucket,
83            prefix,
84            client: std::sync::OnceLock::new(),
85        }
86    }
87
88    pub async fn initialize(&self) -> Result<mz_aws_util::s3::Client, anyhow::Error> {
89        let sdk_config = self
90            .connection
91            .load_sdk_config(&self.context, self.connection_id, InTask::Yes)
92            .await?;
93        let s3_client = mz_aws_util::s3::new_client(&sdk_config);
94
95        Ok(s3_client)
96    }
97
98    pub async fn client(&self) -> Result<&mz_aws_util::s3::Client, anyhow::Error> {
99        if self.client.get().is_none() {
100            let client = self.initialize().await?;
101            let _ = self.client.set(client);
102        }
103
104        Ok(self.client.get().expect("just initialized"))
105    }
106}
107
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
109pub struct S3Object {
110    /// Key from S3 list operation.
111    key: String,
112    /// Name of the object, generally the last component of the key.
113    name: String,
114    /// Size of the object in bytes.
115    size: usize,
116}
117
118impl OneshotObject for S3Object {
119    fn name(&self) -> &str {
120        &self.name
121    }
122
123    fn path(&self) -> &str {
124        &self.key
125    }
126
127    fn size(&self) -> usize {
128        self.size
129    }
130
131    fn encodings(&self) -> &[super::Encoding] {
132        &[]
133    }
134}
135
136#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
137pub struct S3Checksum {
138    e_tag: Option<String>,
139}
140
141impl OneshotSource for AwsS3Source {
142    type Object = S3Object;
143    type Checksum = S3Checksum;
144
145    async fn list<'a>(
146        &'a self,
147    ) -> Result<Vec<(Self::Object, Self::Checksum)>, super::StorageErrorX> {
148        let client = self.client().await.map_err(StorageErrorXKind::generic)?;
149        let mut objects_request = client.list_objects_v2().bucket(&self.bucket);
150
151        // Users can optionally specify a prefix via the S3 uri they originally specify.
152        if let Some(prefix) = &self.prefix {
153            objects_request = objects_request.prefix(prefix);
154        }
155
156        let objects = objects_request
157            .send()
158            .await
159            .map_err(StorageErrorXKind::generic)
160            .context("list_objects_v2")?;
161
162        // TODO(cf1): Pagination.
163
164        let objects: Vec<_> = objects
165            .contents()
166            .iter()
167            .map(|o| {
168                let key = o
169                    .key()
170                    .ok_or_else(|| StorageErrorXKind::MissingField("key".into()))?
171                    .to_owned();
172                let name = Path::new(&key)
173                    .file_name()
174                    .and_then(|os_name| os_name.to_str())
175                    .ok_or_else(|| StorageErrorXKind::Generic(format!("malformed key: {key}")))?
176                    .to_string();
177                let size = o
178                    .size()
179                    .ok_or_else(|| StorageErrorXKind::MissingField("size".into()))?;
180                let size: usize = size.try_into().map_err(StorageErrorXKind::generic)?;
181
182                let object = S3Object { key, name, size };
183                let checksum = S3Checksum {
184                    e_tag: o.e_tag().map(|x| x.to_owned()),
185                };
186
187                Ok::<_, StorageErrorXKind>((object, checksum))
188            })
189            .collect::<Result<_, _>>()
190            .context("list")?;
191
192        Ok(objects)
193    }
194
195    fn get<'s>(
196        &'s self,
197        object: Self::Object,
198        _checksum: Self::Checksum,
199        range: Option<std::ops::RangeInclusive<usize>>,
200    ) -> BoxStream<'s, Result<bytes::Bytes, StorageErrorX>> {
201        let initial_response = async move {
202            tracing::info!(name = %object.name(), ?range, "fetching object");
203
204            // TODO(cf1): Validate our checksum.
205            let client = self.client().await.map_err(StorageErrorXKind::generic)?;
206
207            let mut request = client.get_object().bucket(&self.bucket).key(&object.key);
208            if let Some(range) = range {
209                let value = range.into_range_header_value();
210                request = request.range(value);
211            }
212
213            let object = request
214                .send()
215                .await
216                .map_err(|err| StorageErrorXKind::AwsS3Request(err.to_string()))?;
217            // AWS's ByteStream doesn't implement the Stream trait.
218            let stream = mz_aws_util::s3::ByteStreamAdapter::new(object.body)
219                .err_into()
220                .boxed();
221
222            Ok::<_, StorageErrorX>(stream)
223        };
224
225        futures::stream::once(initial_response)
226            .try_flatten()
227            .boxed()
228    }
229}