mz_storage_operators/oneshot_source/
aws_source.rs1use std::path::Path;
13use std::str::FromStr;
14use std::sync::Arc;
15
16use derivative::Derivative;
17use futures::StreamExt;
18use futures::stream::{BoxStream, TryStreamExt};
19use mz_ore::future::InTask;
20use mz_repr::CatalogItemId;
21use mz_storage_types::connections::ConnectionContext;
22use mz_storage_types::connections::aws::AwsConnection;
23use serde::{Deserialize, Serialize};
24
25use crate::oneshot_source::util::IntoRangeHeaderValue;
26use crate::oneshot_source::{
27 OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
28};
29
/// A [`OneshotSource`] that lists and fetches objects from an AWS S3 bucket.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct AwsS3Source {
    /// AWS connection used to load the SDK configuration for the S3 client.
    #[derivative(Debug = "ignore")]
    connection: Arc<AwsConnection>,
    /// Catalog ID of `connection`, passed along when loading its SDK config.
    connection_id: CatalogItemId,
    /// Context handed to `connection` when loading its SDK config.
    #[derivative(Debug = "ignore")]
    context: Arc<ConnectionContext>,

    /// Name of the bucket to read from (the host component of the source URI).
    bucket: String,
    /// Optional key prefix that listing is restricted to, derived from the
    /// source URI's path: one leading `/` is removed and a trailing `/` is
    /// guaranteed.
    prefix: Option<String>,
    /// S3 client, built lazily on first use and then reused.
    #[derivative(Debug = "ignore")]
    client: std::sync::OnceLock<mz_aws_util::s3::Client>,
}
48
49impl AwsS3Source {
50 pub fn new(
51 connection: AwsConnection,
52 connection_id: CatalogItemId,
53 context: ConnectionContext,
54 uri: String,
55 ) -> Self {
56 let uri = http::Uri::from_str(&uri).expect("validated URI in sequencing");
57
58 let bucket = uri
59 .host()
60 .expect("validated host in sequencing")
61 .to_string();
62 let prefix = if uri.path().is_empty() || uri.path() == "/" {
63 None
64 } else {
65 let mut prefix = uri.path().to_string();
67
68 if let Some(suffix) = prefix.strip_prefix('/') {
69 prefix = suffix.to_string();
70 }
71 if !prefix.ends_with('/') {
72 prefix = format!("{prefix}/");
73 }
74
75 Some(prefix)
76 };
77
78 AwsS3Source {
79 connection: Arc::new(connection),
80 context: Arc::new(context),
81 connection_id,
82 bucket,
83 prefix,
84 client: std::sync::OnceLock::new(),
85 }
86 }
87
88 pub async fn initialize(&self) -> Result<mz_aws_util::s3::Client, anyhow::Error> {
89 let sdk_config = self
90 .connection
91 .load_sdk_config(&self.context, self.connection_id, InTask::Yes)
92 .await?;
93 let s3_client = mz_aws_util::s3::new_client(&sdk_config);
94
95 Ok(s3_client)
96 }
97
98 pub async fn client(&self) -> Result<&mz_aws_util::s3::Client, anyhow::Error> {
99 if self.client.get().is_none() {
100 let client = self.initialize().await?;
101 let _ = self.client.set(client);
102 }
103
104 Ok(self.client.get().expect("just initialized"))
105 }
106}
107
/// Metadata for a single object returned by an S3 listing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct S3Object {
    /// Full key of the object within its bucket.
    key: String,
    /// Final path component of `key`, used as the object's display name.
    name: String,
    /// Size of the object in bytes, as reported by the listing.
    size: usize,
}
117
118impl OneshotObject for S3Object {
119 fn name(&self) -> &str {
120 &self.name
121 }
122
123 fn path(&self) -> &str {
124 &self.key
125 }
126
127 fn size(&self) -> usize {
128 self.size
129 }
130
131 fn encodings(&self) -> &[super::Encoding] {
132 &[]
133 }
134}
135
/// Checksum information captured for an [`S3Object`] when it was listed.
///
/// NOTE(review): `AwsS3Source::get` currently ignores this value, so fetched
/// bytes are not validated against it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct S3Checksum {
    /// The object's ETag, if the listing returned one.
    e_tag: Option<String>,
}
140
141impl OneshotSource for AwsS3Source {
142 type Object = S3Object;
143 type Checksum = S3Checksum;
144
145 async fn list<'a>(
146 &'a self,
147 ) -> Result<Vec<(Self::Object, Self::Checksum)>, super::StorageErrorX> {
148 let client = self.client().await.map_err(StorageErrorXKind::generic)?;
149 let mut objects_request = client.list_objects_v2().bucket(&self.bucket);
150
151 if let Some(prefix) = &self.prefix {
153 objects_request = objects_request.prefix(prefix);
154 }
155
156 let objects = objects_request
157 .send()
158 .await
159 .map_err(StorageErrorXKind::generic)
160 .context("list_objects_v2")?;
161
162 let objects: Vec<_> = objects
165 .contents()
166 .iter()
167 .map(|o| {
168 let key = o
169 .key()
170 .ok_or_else(|| StorageErrorXKind::MissingField("key".into()))?
171 .to_owned();
172 let name = Path::new(&key)
173 .file_name()
174 .and_then(|os_name| os_name.to_str())
175 .ok_or_else(|| StorageErrorXKind::Generic(format!("malformed key: {key}")))?
176 .to_string();
177 let size = o
178 .size()
179 .ok_or_else(|| StorageErrorXKind::MissingField("size".into()))?;
180 let size: usize = size.try_into().map_err(StorageErrorXKind::generic)?;
181
182 let object = S3Object { key, name, size };
183 let checksum = S3Checksum {
184 e_tag: o.e_tag().map(|x| x.to_owned()),
185 };
186
187 Ok::<_, StorageErrorXKind>((object, checksum))
188 })
189 .collect::<Result<_, _>>()
190 .context("list")?;
191
192 Ok(objects)
193 }
194
195 fn get<'s>(
196 &'s self,
197 object: Self::Object,
198 _checksum: Self::Checksum,
199 range: Option<std::ops::RangeInclusive<usize>>,
200 ) -> BoxStream<'s, Result<bytes::Bytes, StorageErrorX>> {
201 let initial_response = async move {
202 tracing::info!(name = %object.name(), ?range, "fetching object");
203
204 let client = self.client().await.map_err(StorageErrorXKind::generic)?;
206
207 let mut request = client.get_object().bucket(&self.bucket).key(&object.key);
208 if let Some(range) = range {
209 let value = range.into_range_header_value();
210 request = request.range(value);
211 }
212
213 let object = request
214 .send()
215 .await
216 .map_err(|err| StorageErrorXKind::AwsS3Request(err.to_string()))?;
217 let stream = mz_aws_util::s3::ByteStreamAdapter::new(object.body)
219 .err_into()
220 .boxed();
221
222 Ok::<_, StorageErrorX>(stream)
223 };
224
225 futures::stream::once(initial_response)
226 .try_flatten()
227 .boxed()
228 }
229}