mz_storage_operators/oneshot_source/
aws_source.rs1use std::path::Path;
13use std::str::FromStr;
14use std::sync::Arc;
15
16use aws_sdk_s3::error::DisplayErrorContext;
17use derivative::Derivative;
18use futures::StreamExt;
19use futures::stream::{BoxStream, TryStreamExt};
20use mz_ore::future::InTask;
21use mz_repr::CatalogItemId;
22use mz_storage_types::connections::ConnectionContext;
23use mz_storage_types::connections::aws::AwsConnection;
24use serde::{Deserialize, Serialize};
25
26use crate::oneshot_source::util::IntoRangeHeaderValue;
27use crate::oneshot_source::{
28 OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
29};
30
31#[derive(Clone, Derivative)]
32#[derivative(Debug)]
33pub struct AwsS3Source {
34 #[derivative(Debug = "ignore")]
36 connection: Arc<AwsConnection>,
37 connection_id: CatalogItemId,
38 #[derivative(Debug = "ignore")]
39 context: Arc<ConnectionContext>,
40
41 bucket: String,
43 prefix: Option<String>,
45 #[derivative(Debug = "ignore")]
47 client: std::sync::OnceLock<mz_aws_util::s3::Client>,
48 use_checksum: bool,
49 enforce_external_addresses: bool,
50}
51
52impl AwsS3Source {
53 pub fn new(
54 connection: AwsConnection,
55 connection_id: CatalogItemId,
56 context: ConnectionContext,
57 uri: String,
58 use_checksum: bool,
59 enforce_external_addresses: bool,
60 ) -> Self {
61 let uri = http::Uri::from_str(&uri).expect("validated URI in sequencing");
62
63 let bucket = uri
64 .host()
65 .expect("validated host in sequencing")
66 .to_string();
67 let prefix = if uri.path().is_empty() || uri.path() == "/" {
68 None
69 } else {
70 let mut prefix = uri.path().to_string();
72
73 if let Some(suffix) = prefix.strip_prefix('/') {
74 prefix = suffix.to_string();
75 }
76 if !prefix.ends_with('/') {
77 prefix = format!("{prefix}/");
78 }
79
80 Some(prefix)
81 };
82
83 AwsS3Source {
84 connection: Arc::new(connection),
85 context: Arc::new(context),
86 connection_id,
87 bucket,
88 prefix,
89 client: std::sync::OnceLock::new(),
90 use_checksum,
91 enforce_external_addresses,
92 }
93 }
94
95 pub async fn initialize(&self) -> Result<mz_aws_util::s3::Client, anyhow::Error> {
96 let sdk_config = self
97 .connection
98 .load_sdk_config(
99 &self.context,
100 self.connection_id,
101 InTask::Yes,
102 self.enforce_external_addresses,
103 )
104 .await?;
105
106 let mut s3_config_builder = aws_sdk_s3::config::Builder::from(&sdk_config)
107 .force_path_style(sdk_config.endpoint_url().is_some());
108
109 if !self.use_checksum {
110 s3_config_builder = s3_config_builder.response_checksum_validation(
111 aws_smithy_types::checksum_config::ResponseChecksumValidation::WhenRequired,
112 );
113 }
114 let s3_config = s3_config_builder.build();
115 let s3_client = aws_sdk_s3::Client::from_conf(s3_config);
116
117 Ok(s3_client)
118 }
119
120 pub async fn client(&self) -> Result<&mz_aws_util::s3::Client, anyhow::Error> {
121 if self.client.get().is_none() {
122 let client = self.initialize().await?;
123 let _ = self.client.set(client);
124 }
125
126 Ok(self.client.get().expect("just initialized"))
127 }
128}
129
130#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
131pub struct S3Object {
132 key: String,
134 name: String,
136 size: usize,
138}
139
140impl OneshotObject for S3Object {
141 fn name(&self) -> &str {
142 &self.name
143 }
144
145 fn path(&self) -> &str {
146 &self.key
147 }
148
149 fn size(&self) -> usize {
150 self.size
151 }
152
153 fn encodings(&self) -> &[super::Encoding] {
154 &[]
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct S3Checksum {
160 e_tag: Option<String>,
161}
162
163impl OneshotSource for AwsS3Source {
164 type Object = S3Object;
165 type Checksum = S3Checksum;
166
167 async fn list<'a>(
168 &'a self,
169 ) -> Result<Vec<(Self::Object, Self::Checksum)>, super::StorageErrorX> {
170 let client = self.client().await.map_err(StorageErrorXKind::generic)?;
171 let mut objects_request = client.list_objects_v2().bucket(&self.bucket);
172
173 if let Some(prefix) = &self.prefix {
175 objects_request = objects_request.prefix(prefix);
176 }
177
178 let objects = objects_request
179 .into_paginator()
180 .send()
181 .try_collect()
182 .await
183 .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))
184 .context("list_objects_v2")?;
185
186 let objects: Vec<_> = objects
187 .iter()
188 .flat_map(aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Output::contents)
189 .map(|o| {
190 let key = o
191 .key()
192 .ok_or_else(|| StorageErrorXKind::MissingField("key".into()))?
193 .to_owned();
194 let name = Path::new(&key)
195 .file_name()
196 .and_then(|os_name| os_name.to_str())
197 .ok_or_else(|| StorageErrorXKind::Generic(format!("malformed key: {key}")))?
198 .to_string();
199 let size = o
200 .size()
201 .ok_or_else(|| StorageErrorXKind::MissingField("size".into()))?;
202 let size: usize = size.try_into().map_err(StorageErrorXKind::generic)?;
203
204 let object = S3Object { key, name, size };
205 let checksum = S3Checksum {
206 e_tag: o.e_tag().map(|x| x.to_owned()),
207 };
208
209 Ok::<_, StorageErrorXKind>((object, checksum))
210 })
211 .collect::<Result<_, _>>()
212 .context("list")?;
213
214 Ok(objects)
215 }
216
217 fn get<'s>(
218 &'s self,
219 object: Self::Object,
220 _checksum: Self::Checksum,
221 range: Option<std::ops::RangeInclusive<usize>>,
222 ) -> BoxStream<'s, Result<bytes::Bytes, StorageErrorX>> {
223 let initial_response = async move {
224 tracing::info!(name = %object.name(), ?range, "fetching object");
225
226 let client = self.client().await.map_err(StorageErrorXKind::generic)?;
227
228 let mut request = client.get_object().bucket(&self.bucket).key(&object.key);
229 if let Some(range) = range {
230 let value = range.into_range_header_value();
231 request = request.range(value);
232 }
233 let object = request
234 .send()
235 .await
236 .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))?;
237 let stream = mz_aws_util::s3::ByteStreamAdapter::new(object.body)
239 .err_into()
240 .boxed();
241
242 Ok::<_, StorageErrorX>(stream)
243 };
244
245 futures::stream::once(initial_response)
246 .try_flatten()
247 .boxed()
248 }
249}