1use anyhow::anyhow;
11use aws_sdk_s3::Client;
12use aws_sdk_s3::error::SdkError;
13use aws_sdk_s3::operation::complete_multipart_upload::CompleteMultipartUploadError;
14use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadError;
15use aws_sdk_s3::operation::upload_part::UploadPartError;
16use aws_sdk_s3::primitives::ByteStream;
17use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
18use aws_types::sdk_config::SdkConfig;
19use bytes::{Bytes, BytesMut};
20use bytesize::ByteSize;
21use mz_ore::cast::CastFrom;
22use mz_ore::error::ErrorExt;
23use mz_ore::task::{JoinHandle, JoinHandleExt, spawn};
24
/// Uploads a single S3 object through the multi-part upload API.
///
/// Data is appended with `buffer_chunk` and kept in an in-memory buffer. Whenever the buffer
/// grows beyond `config.part_size_limit`, parts of exactly that size are split off and
/// uploaded in spawned tasks. `finish` uploads whatever is left in the buffer, waits for all
/// part uploads and then completes the multi-part upload.
#[derive(Debug)]
pub struct S3MultiPartUploader {
    /// Client used for every S3 request issued by this uploader.
    client: Client,
    /// Part-size and file-size limits; validated in `try_new`.
    config: S3MultiPartUploaderConfig,
    /// Destination bucket.
    bucket: String,
    /// Destination object key.
    key: String,
    /// Upload id taken from the `create_multipart_upload` response.
    upload_id: String,
    /// Number of parts whose upload has been started; also the most recently assigned
    /// part number (part numbers start at 1).
    part_count: i32,
    /// Bytes handed to part-upload tasks so far. Incremented when a task is spawned,
    /// not when the upload is confirmed.
    total_bytes_uploaded: u64,
    /// Bytes accepted by `buffer_chunk` that have not yet been split off into a part.
    buffer: BytesMut,
    /// One handle per spawned part upload, in part-number order. Each resolves to the
    /// part's ETag (if S3 returned one) together with its part number.
    upload_handles: Vec<JoinHandle<Result<(Option<String>, i32), S3MultiPartUploadError>>>,
}
53
/// Highest part number this module will assign within one multi-part upload.
/// Mirrors the AWS S3 multi-part limit of 10,000 parts per upload.
pub const AWS_S3_MAX_PART_COUNT: i32 = 10_000;
/// Smallest `part_size_limit` accepted by `S3MultiPartUploaderConfig::validate` (5 MiB).
const AWS_S3_MIN_PART_SIZE: ByteSize = ByteSize::mib(5);
/// Largest `part_size_limit` accepted by `S3MultiPartUploaderConfig::validate` (5 GiB).
const AWS_S3_MAX_PART_SIZE: ByteSize = ByteSize::gib(5);
/// Largest `file_size_limit` accepted by `S3MultiPartUploaderConfig::validate` (5 TiB).
// NOTE(review): intended to match the AWS maximum object size; confirm against the current
// AWS documentation.
const AWS_S3_MAX_OBJECT_SIZE: ByteSize = ByteSize::tib(5);
72
/// Summary returned by `S3MultiPartUploader::finish` once the upload has been completed.
#[derive(Debug)]
pub struct CompletedUpload {
    /// Number of parts that were uploaded.
    pub part_count: u32,
    /// Total number of bytes across all uploaded parts.
    pub total_bytes_uploaded: u64,
    /// Bucket the object was written to.
    pub bucket: String,
    /// Key of the object that was written.
    pub key: String,
}
83
/// Size limits that control how `S3MultiPartUploader` splits and bounds an upload.
#[derive(Debug)]
pub struct S3MultiPartUploaderConfig {
    /// Size in bytes of each part split off from the buffer. A part is only split off once
    /// the buffer holds more than this many bytes, so the final part uploaded by `finish`
    /// may be smaller.
    pub part_size_limit: u64,
    /// Upper bound in bytes on the data accepted by `buffer_chunk`. The very first chunk
    /// is exempt and may exceed this limit (see `buffer_chunk`).
    pub file_size_limit: u64,
}
92
93impl S3MultiPartUploaderConfig {
94 const DEFAULT_MAX_FILE_SIZE: ByteSize = ByteSize::gib(5);
98 const DEFAULT_PART_SIZE_LIMIT: ByteSize = ByteSize::mib(10);
101
102 fn validate(&self) -> Result<(), anyhow::Error> {
105 let S3MultiPartUploaderConfig {
106 part_size_limit,
107 file_size_limit,
108 } = self;
109 if part_size_limit < &AWS_S3_MIN_PART_SIZE.as_u64()
110 || part_size_limit > &AWS_S3_MAX_PART_SIZE.as_u64()
111 {
112 return Err(anyhow!(format!(
113 "invalid part size: {}, should be between {} and {} bytes",
114 part_size_limit,
115 AWS_S3_MIN_PART_SIZE.as_u64(),
116 AWS_S3_MAX_PART_SIZE.as_u64()
117 )));
118 }
119 if file_size_limit > &AWS_S3_MAX_OBJECT_SIZE.as_u64() {
120 return Err(anyhow!(format!(
121 "invalid file size: {}, cannot exceed {} bytes",
122 file_size_limit,
123 AWS_S3_MAX_OBJECT_SIZE.as_u64()
124 )));
125 }
126 let max_parts_count: u64 = AWS_S3_MAX_PART_COUNT.try_into().expect("i32 to u64");
127 let estimated_parts_count = file_size_limit.div_ceil(*part_size_limit);
130 if estimated_parts_count > max_parts_count {
131 return Err(anyhow!(format!(
132 "total number of possible parts (file_size_limit / part_size_limit): {}, cannot exceed {}",
133 estimated_parts_count, AWS_S3_MAX_PART_COUNT
134 )));
135 }
136 Ok(())
137 }
138}
139
140impl Default for S3MultiPartUploaderConfig {
141 fn default() -> Self {
142 Self {
143 part_size_limit: Self::DEFAULT_PART_SIZE_LIMIT.as_u64(),
144 file_size_limit: Self::DEFAULT_MAX_FILE_SIZE.as_u64(),
145 }
146 }
147}
148
149impl S3MultiPartUploader {
150 pub async fn try_new(
154 sdk_config: &SdkConfig,
155 bucket: String,
156 key: String,
157 config: S3MultiPartUploaderConfig,
158 ) -> Result<S3MultiPartUploader, S3MultiPartUploadError> {
159 config.validate()?;
161
162 let client = crate::s3::new_client(sdk_config);
163 let res = client
164 .create_multipart_upload()
165 .bucket(&bucket)
166 .key(&key)
167 .send()
168 .await?;
169 let upload_id = res
170 .upload_id()
171 .ok_or_else(|| anyhow!("create_multipart_upload response missing upload id"))?
172 .to_string();
173 Ok(S3MultiPartUploader {
174 client,
175 bucket,
176 key,
177 upload_id,
178 part_count: 0,
179 total_bytes_uploaded: 0,
180 buffer: Default::default(),
181 config,
182 upload_handles: Default::default(),
183 })
184 }
185
186 pub fn buffer_chunk(&mut self, data: &[u8]) -> Result<(), S3MultiPartUploadError> {
192 let data_len = u64::cast_from(data.len());
193
194 let aws_max_part_count: u64 = AWS_S3_MAX_PART_COUNT.try_into().expect("i32 to u64");
195 let absolute_max_file_limit = std::cmp::min(
196 self.config.part_size_limit * aws_max_part_count,
197 AWS_S3_MAX_OBJECT_SIZE.as_u64(),
198 );
199
200 let can_force_first_upload = self.added_bytes() == 0 && data_len <= absolute_max_file_limit;
202
203 if data_len <= self.remaining_bytes_limit() || can_force_first_upload {
204 self.buffer.extend_from_slice(data);
205 self.flush_chunks()?;
206 Ok(())
207 } else {
208 Err(S3MultiPartUploadError::UploadExceedsMaxFileLimit(
209 self.config.file_size_limit,
210 ))
211 }
212 }
213
214 pub async fn finish(mut self) -> Result<CompletedUpload, S3MultiPartUploadError> {
218 let remaining = self.buffer.split();
219 self.upload_part_internal(remaining.freeze())?;
220
221 let mut parts: Vec<CompletedPart> = Vec::with_capacity(self.upload_handles.len());
222 for handle in self.upload_handles {
223 let (etag, part_num) = handle.wait_and_assert_finished().await?;
224 match etag {
225 Some(etag) => {
226 parts.push(
227 CompletedPart::builder()
228 .e_tag(etag)
229 .part_number(part_num)
230 .build(),
231 );
232 }
233 None => Err(anyhow!("etag for part {part_num} is None"))?,
234 }
235 }
236
237 self.client
238 .complete_multipart_upload()
239 .bucket(&self.bucket)
240 .key(&self.key)
241 .upload_id(self.upload_id.clone())
242 .multipart_upload(
243 CompletedMultipartUpload::builder()
244 .set_parts(Some(parts))
245 .build(),
246 )
247 .send()
248 .await?;
249 Ok(CompletedUpload {
250 part_count: self.part_count.try_into().expect("i32 to u32"),
251 total_bytes_uploaded: self.total_bytes_uploaded,
252 bucket: self.bucket,
253 key: self.key,
254 })
255 }
256
257 fn buffer_size(&self) -> u64 {
258 u64::cast_from(self.buffer.len())
259 }
260
261 fn remaining_bytes_limit(&self) -> u64 {
264 self.config
265 .file_size_limit
266 .saturating_sub(self.added_bytes())
267 }
268
269 pub fn added_bytes(&self) -> u64 {
271 self.total_bytes_uploaded + self.buffer_size()
272 }
273
274 fn flush_chunks(&mut self) -> Result<(), S3MultiPartUploadError> {
277 let part_size_limit = self.config.part_size_limit;
278 while self.buffer_size() > part_size_limit {
280 let data = self.buffer.split_to(usize::cast_from(part_size_limit));
281 self.upload_part_internal(data.freeze())?;
282 }
283 Ok(())
284 }
285
286 fn upload_part_internal(&mut self, data: Bytes) -> Result<(), S3MultiPartUploadError> {
288 let num_of_bytes: u64 = u64::cast_from(data.len());
289
290 let next_part_number = self.part_count + 1;
291 if next_part_number > AWS_S3_MAX_PART_COUNT {
292 return Err(S3MultiPartUploadError::ExceedsMaxPartNumber);
293 }
294 let client = self.client.clone();
295 let bucket = self.bucket.clone();
296 let key = self.key.clone();
297 let upload_id = self.upload_id.clone();
298
299 let handle = spawn(|| "s3::upload_part", async move {
300 let res = client
301 .upload_part()
302 .bucket(&bucket)
303 .key(&key)
304 .upload_id(upload_id)
305 .part_number(next_part_number)
306 .body(ByteStream::from(data))
307 .send()
308 .await?;
309 Ok((res.e_tag, next_part_number))
310 });
311 self.upload_handles.push(handle);
312
313 self.part_count = next_part_number;
314 self.total_bytes_uploaded += num_of_bytes;
315 Ok(())
316 }
317}
318
/// Errors produced by `S3MultiPartUploader`.
#[derive(thiserror::Error, Debug)]
pub enum S3MultiPartUploadError {
    /// Starting another part would exceed `AWS_S3_MAX_PART_COUNT`.
    #[error(
        "multi-part upload cannot have more than {} parts",
        AWS_S3_MAX_PART_COUNT
    )]
    ExceedsMaxPartNumber,
    /// A chunk was rejected because it does not fit into the configured
    /// `file_size_limit`; the payload is that limit in bytes.
    #[error("multi-part upload will exceed configured file_size_limit: {} bytes", .0)]
    UploadExceedsMaxFileLimit(u64),
    /// The `create_multipart_upload` request failed.
    #[error("{}", .0.display_with_causes())]
    CreateMultipartUploadError(#[from] SdkError<CreateMultipartUploadError>),
    /// An `upload_part` request failed.
    #[error("{}", .0.display_with_causes())]
    UploadPartError(#[from] SdkError<UploadPartError>),
    /// The `complete_multipart_upload` request failed.
    #[error("{}", .0.display_with_causes())]
    CompleteMultipartUploadError(#[from] SdkError<CompleteMultipartUploadError>),
    /// Any other failure, e.g. an invalid configuration or a response that lacks an
    /// upload id or ETag.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
337
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use uuid::Uuid;

    use super::*;
    use crate::{defaults, s3};

    /// Returns the bucket named by `MZ_S3_UPLOADER_TEST_S3_BUCKET` together with a fresh,
    /// randomly prefixed object key.
    ///
    /// Returns `None` when the variable is not set, unless `CI` is truthy, in which case
    /// the missing variable is treated as a setup error and the function panics.
    fn s3_bucket_key_for_test() -> Option<(String, String)> {
        let Ok(bucket) = std::env::var("MZ_S3_UPLOADER_TEST_S3_BUCKET") else {
            if mz_ore::env::is_var_truthy("CI") {
                panic!("CI is supposed to run this test but something has gone wrong!");
            }
            return None;
        };

        // A random prefix keeps separate test runs from touching the same object.
        let key = format!("cargo_test/{}/file", Uuid::new_v4());
        Some((bucket, key))
    }

    /// Small chunks that stay below the part size end up in a single part containing
    /// their concatenation.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_success() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let Some((bucket, key)) = s3_bucket_key_for_test() else {
            return Ok(());
        };

        let mut uploader = S3MultiPartUploader::try_new(
            &sdk_config,
            bucket.clone(),
            key.clone(),
            S3MultiPartUploaderConfig::default(),
        )
        .await?;

        let expected_data = "onetwothree";
        let pieces: [&[u8]; 3] = [b"one", b"two", b"three"];
        for piece in pieces {
            uploader.buffer_chunk(piece)?;
        }

        let CompletedUpload {
            part_count,
            total_bytes_uploaded,
            ..
        } = uploader.finish().await?;

        // Requesting part 1 makes the response report the object's part count.
        let fetched = s3::new_client(&sdk_config)
            .get_object()
            .bucket(bucket)
            .key(key)
            .part_number(1)
            .send()
            .await
            .unwrap();

        let fetched_parts_count = u32::try_from(fetched.parts_count().unwrap()).unwrap();
        assert_eq!(fetched_parts_count, part_count);
        assert_eq!(part_count, 1);

        let body = fetched.body.collect().await.unwrap().into_bytes();
        assert_eq!(body, expected_data);

        let expected_bytes = u64::try_from(Bytes::from(expected_data).len()).unwrap();
        assert_eq!(total_bytes_uploaded, expected_bytes);

        Ok(())
    }

    /// A 6 MiB chunk with a 5 MiB part size is split into two parts, and a second chunk
    /// that no longer fits into the 10 MiB file limit is rejected.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_buffer() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let Some((bucket, key)) = s3_bucket_key_for_test() else {
            return Ok(());
        };

        let limits = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64(),
            file_size_limit: ByteSize::mib(10).as_u64(),
        };
        let mut uploader =
            S3MultiPartUploader::try_new(&sdk_config, bucket.clone(), key.clone(), limits).await?;

        // 6 MiB of the byte `a`.
        let payload = vec![b'a'; 6 * 1024 * 1024];
        let payload_len = u64::cast_from(payload.len());
        uploader.buffer_chunk(&payload)?;

        assert_eq!(uploader.remaining_bytes_limit(), ByteSize::mib(4).as_u64());

        // Only 4 MiB are left, so another 6 MiB must be refused.
        let rejection = uploader.buffer_chunk(&payload).unwrap_err();
        assert!(matches!(
            rejection,
            S3MultiPartUploadError::UploadExceedsMaxFileLimit(_)
        ));

        let CompletedUpload {
            part_count,
            total_bytes_uploaded,
            ..
        } = uploader.finish().await?;

        let fetched = s3::new_client(&sdk_config)
            .get_object()
            .bucket(bucket)
            .key(key)
            .send()
            .await
            .unwrap();

        // One full 5 MiB part plus the 1 MiB remainder.
        assert_eq!(part_count, 2);
        let body = fetched.body.collect().await.unwrap().into_bytes();
        assert_eq!(body, *payload);

        assert_eq!(total_bytes_uploaded, payload_len);

        Ok(())
    }

    /// Finishing without buffering anything yields an empty object.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_no_data() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let Some((bucket, key)) = s3_bucket_key_for_test() else {
            return Ok(());
        };

        let uploader = S3MultiPartUploader::try_new(
            &sdk_config,
            bucket.clone(),
            key.clone(),
            S3MultiPartUploaderConfig::default(),
        )
        .await?;

        uploader.finish().await.unwrap();

        let fetched = s3::new_client(&sdk_config)
            .get_object()
            .bucket(bucket)
            .key(key)
            .send()
            .await
            .unwrap();

        assert_eq!(fetched.content_length(), Some(0));

        Ok(())
    }

    /// `validate` rejects a part size below the minimum and a combination of limits that
    /// would need more than the maximum number of parts.
    #[mz_ore::test]
    fn test_invalid_configs() {
        // One byte below the minimum part size.
        let too_small_part = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64() - 1,
            file_size_limit: ByteSize::gib(5).as_u64(),
        };
        let message = too_small_part.validate().unwrap_err().to_string();
        assert_eq!(
            message,
            "invalid part size: 5242879, should be between 5242880 and 5368709120 bytes"
        );

        // One byte short of 10001 full parts still rounds up to 10001 parts.
        let too_many_parts = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64(),
            file_size_limit: (ByteSize::mib(5).as_u64() * 10001) - 1,
        };
        let message = too_many_parts.validate().unwrap_err().to_string();
        assert_eq!(
            message,
            "total number of possible parts (file_size_limit / part_size_limit): 10001, cannot exceed 10000",
        );
    }
}