mz_storage_operators/oneshot_source/
aws_source.rs1use std::path::Path;
13use std::str::FromStr;
14use std::sync::Arc;
15
16use aws_sdk_s3::error::DisplayErrorContext;
17use derivative::Derivative;
18use futures::StreamExt;
19use futures::stream::{BoxStream, TryStreamExt};
20use mz_ore::future::InTask;
21use mz_repr::CatalogItemId;
22use mz_storage_types::connections::ConnectionContext;
23use mz_storage_types::connections::aws::AwsConnection;
24use serde::{Deserialize, Serialize};
25
26use crate::oneshot_source::util::IntoRangeHeaderValue;
27use crate::oneshot_source::{
28 OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
29};
30
31#[derive(Clone, Derivative)]
32#[derivative(Debug)]
33pub struct AwsS3Source {
34 #[derivative(Debug = "ignore")]
36 connection: Arc<AwsConnection>,
37 connection_id: CatalogItemId,
38 #[derivative(Debug = "ignore")]
39 context: Arc<ConnectionContext>,
40
41 bucket: String,
43 prefix: Option<String>,
45 #[derivative(Debug = "ignore")]
47 client: std::sync::OnceLock<mz_aws_util::s3::Client>,
48 use_checksum: bool,
49}
50
51impl AwsS3Source {
52 pub fn new(
53 connection: AwsConnection,
54 connection_id: CatalogItemId,
55 context: ConnectionContext,
56 uri: String,
57 use_checksum: bool,
58 ) -> Self {
59 let uri = http::Uri::from_str(&uri).expect("validated URI in sequencing");
60
61 let bucket = uri
62 .host()
63 .expect("validated host in sequencing")
64 .to_string();
65 let prefix = if uri.path().is_empty() || uri.path() == "/" {
66 None
67 } else {
68 let mut prefix = uri.path().to_string();
70
71 if let Some(suffix) = prefix.strip_prefix('/') {
72 prefix = suffix.to_string();
73 }
74 if !prefix.ends_with('/') {
75 prefix = format!("{prefix}/");
76 }
77
78 Some(prefix)
79 };
80
81 AwsS3Source {
82 connection: Arc::new(connection),
83 context: Arc::new(context),
84 connection_id,
85 bucket,
86 prefix,
87 client: std::sync::OnceLock::new(),
88 use_checksum,
89 }
90 }
91
92 pub async fn initialize(&self) -> Result<mz_aws_util::s3::Client, anyhow::Error> {
93 let sdk_config = self
94 .connection
95 .load_sdk_config(&self.context, self.connection_id, InTask::Yes)
96 .await?;
97
98 let mut s3_config_builder = aws_sdk_s3::config::Builder::from(&sdk_config)
99 .force_path_style(sdk_config.endpoint_url().is_some());
100
101 if !self.use_checksum {
102 s3_config_builder = s3_config_builder.response_checksum_validation(
103 aws_smithy_types::checksum_config::ResponseChecksumValidation::WhenRequired,
104 );
105 }
106 let s3_config = s3_config_builder.build();
107 let s3_client = aws_sdk_s3::Client::from_conf(s3_config);
108
109 Ok(s3_client)
110 }
111
112 pub async fn client(&self) -> Result<&mz_aws_util::s3::Client, anyhow::Error> {
113 if self.client.get().is_none() {
114 let client = self.initialize().await?;
115 let _ = self.client.set(client);
116 }
117
118 Ok(self.client.get().expect("just initialized"))
119 }
120}
121
122#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
123pub struct S3Object {
124 key: String,
126 name: String,
128 size: usize,
130}
131
132impl OneshotObject for S3Object {
133 fn name(&self) -> &str {
134 &self.name
135 }
136
137 fn path(&self) -> &str {
138 &self.key
139 }
140
141 fn size(&self) -> usize {
142 self.size
143 }
144
145 fn encodings(&self) -> &[super::Encoding] {
146 &[]
147 }
148}
149
150#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
151pub struct S3Checksum {
152 e_tag: Option<String>,
153}
154
155impl OneshotSource for AwsS3Source {
156 type Object = S3Object;
157 type Checksum = S3Checksum;
158
159 async fn list<'a>(
160 &'a self,
161 ) -> Result<Vec<(Self::Object, Self::Checksum)>, super::StorageErrorX> {
162 let client = self.client().await.map_err(StorageErrorXKind::generic)?;
163 let mut objects_request = client.list_objects_v2().bucket(&self.bucket);
164
165 if let Some(prefix) = &self.prefix {
167 objects_request = objects_request.prefix(prefix);
168 }
169
170 let objects = objects_request
171 .into_paginator()
172 .send()
173 .try_collect()
174 .await
175 .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))
176 .context("list_objects_v2")?;
177
178 let objects: Vec<_> = objects
179 .iter()
180 .flat_map(aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Output::contents)
181 .map(|o| {
182 let key = o
183 .key()
184 .ok_or_else(|| StorageErrorXKind::MissingField("key".into()))?
185 .to_owned();
186 let name = Path::new(&key)
187 .file_name()
188 .and_then(|os_name| os_name.to_str())
189 .ok_or_else(|| StorageErrorXKind::Generic(format!("malformed key: {key}")))?
190 .to_string();
191 let size = o
192 .size()
193 .ok_or_else(|| StorageErrorXKind::MissingField("size".into()))?;
194 let size: usize = size.try_into().map_err(StorageErrorXKind::generic)?;
195
196 let object = S3Object { key, name, size };
197 let checksum = S3Checksum {
198 e_tag: o.e_tag().map(|x| x.to_owned()),
199 };
200
201 Ok::<_, StorageErrorXKind>((object, checksum))
202 })
203 .collect::<Result<_, _>>()
204 .context("list")?;
205
206 Ok(objects)
207 }
208
209 fn get<'s>(
210 &'s self,
211 object: Self::Object,
212 _checksum: Self::Checksum,
213 range: Option<std::ops::RangeInclusive<usize>>,
214 ) -> BoxStream<'s, Result<bytes::Bytes, StorageErrorX>> {
215 let initial_response = async move {
216 tracing::info!(name = %object.name(), ?range, "fetching object");
217
218 let client = self.client().await.map_err(StorageErrorXKind::generic)?;
219
220 let mut request = client.get_object().bucket(&self.bucket).key(&object.key);
221 if let Some(range) = range {
222 let value = range.into_range_header_value();
223 request = request.range(value);
224 }
225 let object = request
226 .send()
227 .await
228 .map_err(|err| StorageErrorXKind::generic(DisplayErrorContext(err)))?;
229 let stream = mz_aws_util::s3::ByteStreamAdapter::new(object.body)
231 .err_into()
232 .boxed();
233
234 Ok::<_, StorageErrorX>(stream)
235 };
236
237 futures::stream::once(initial_response)
238 .try_flatten()
239 .boxed()
240 }
241}