use anyhow::anyhow;
use aws_sdk_s3::Client;
use aws_sdk_s3::error::SdkError;
use aws_sdk_s3::operation::complete_multipart_upload::CompleteMultipartUploadError;
use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadError;
use aws_sdk_s3::operation::upload_part::UploadPartError;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use aws_types::sdk_config::SdkConfig;
use bytes::{Bytes, BytesMut};
use bytesize::ByteSize;
use mz_ore::cast::CastFrom;
use mz_ore::error::ErrorExt;
use mz_ore::task::{JoinHandle, JoinHandleExt, spawn};

/// A multi-part uploader which can upload a single object across multiple
/// parts, keeping track of the state needed to eventually finish the upload.
#[derive(Debug)]
pub struct S3MultiPartUploader {
    client: Client,
    // Configured limits for this particular multi-part upload.
    config: S3MultiPartUploaderConfig,
    // The S3 bucket the object is uploaded to.
    bucket: String,
    // The S3 key of the uploaded object.
    key: String,
    // The upload ID returned by S3 when the multi-part upload was created.
    upload_id: String,
    // The number of parts uploaded (or queued for upload) so far.
    part_count: i32,
    // The number of bytes uploaded (or queued for upload) so far.
    total_bytes_uploaded: u64,
    // Data accumulates here until it reaches `part_size_limit`, at which
    // point it is split off and uploaded as a part.
    buffer: BytesMut,
    // Handles for the in-flight part-upload tasks; awaited in `finish`.
    upload_handles: Vec<JoinHandle<Result<(Option<String>, i32), S3MultiPartUploadError>>>,
}

/// The maximum number of parts permitted in an AWS S3 multi-part upload.
pub const AWS_S3_MAX_PART_COUNT: i32 = 10_000;
/// The minimum size of a part in a multi-part upload (applies to all parts
/// except the last).
const AWS_S3_MIN_PART_SIZE: ByteSize = ByteSize::mib(5);
/// The maximum size of a single part in a multi-part upload.
const AWS_S3_MAX_PART_SIZE: ByteSize = ByteSize::gib(5);
/// The maximum size of an object in an S3 bucket.
const AWS_S3_MAX_OBJECT_SIZE: ByteSize = ByteSize::tib(5);

/// Information about a completed multi-part upload, returned by `finish`.
#[derive(Debug)]
pub struct CompletedUpload {
    /// The total number of parts in the multi-part upload.
    pub part_count: u32,
    /// The total number of bytes uploaded in the multi-part upload.
    pub total_bytes_uploaded: u64,
    /// The bucket the object was uploaded to.
    pub bucket: String,
    /// The key of the uploaded object.
    pub key: String,
}

/// Configuration object for `S3MultiPartUploader`.
#[derive(Debug)]
pub struct S3MultiPartUploaderConfig {
    /// The size at which the internal buffer is flushed and uploaded as a part.
    pub part_size_limit: u64,
    /// The maximum allowed size of the object being uploaded.
    pub file_size_limit: u64,
}

impl S3MultiPartUploaderConfig {
    /// The default maximum size of the uploaded object.
    const DEFAULT_MAX_FILE_SIZE: ByteSize = ByteSize::gib(5);
    /// The default size of each part in the multi-part upload.
    const DEFAULT_PART_SIZE_LIMIT: ByteSize = ByteSize::mib(10);

    /// Checks the configured limits against S3's documented bounds: the part
    /// size must lie between `AWS_S3_MIN_PART_SIZE` and `AWS_S3_MAX_PART_SIZE`,
    /// the file size cannot exceed `AWS_S3_MAX_OBJECT_SIZE`, and the implied
    /// number of parts cannot exceed `AWS_S3_MAX_PART_COUNT`.
    fn validate(&self) -> Result<(), anyhow::Error> {
        let S3MultiPartUploaderConfig {
            part_size_limit,
            file_size_limit,
        } = self;
        if *part_size_limit < AWS_S3_MIN_PART_SIZE.as_u64()
            || *part_size_limit > AWS_S3_MAX_PART_SIZE.as_u64()
        {
            return Err(anyhow!(
                "invalid part size: {}, should be between {} and {} bytes",
                part_size_limit,
                AWS_S3_MIN_PART_SIZE.as_u64(),
                AWS_S3_MAX_PART_SIZE.as_u64()
            ));
        }
        if *file_size_limit > AWS_S3_MAX_OBJECT_SIZE.as_u64() {
            return Err(anyhow!(
                "invalid file size: {}, cannot exceed {} bytes",
                file_size_limit,
                AWS_S3_MAX_OBJECT_SIZE.as_u64()
            ));
        }
        let max_parts_count: u64 = AWS_S3_MAX_PART_COUNT.try_into().expect("i32 to u64");
        let estimated_parts_count = file_size_limit.div_ceil(*part_size_limit);
        if estimated_parts_count > max_parts_count {
            return Err(anyhow!(
                "total number of possible parts (file_size_limit / part_size_limit): {}, cannot exceed {}",
                estimated_parts_count,
                AWS_S3_MAX_PART_COUNT
            ));
        }
        Ok(())
    }
}

impl Default for S3MultiPartUploaderConfig {
    fn default() -> Self {
        Self {
            part_size_limit: Self::DEFAULT_PART_SIZE_LIMIT.as_u64(),
            file_size_limit: Self::DEFAULT_MAX_FILE_SIZE.as_u64(),
        }
    }
}

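// With the defaults above, the implied part count is
// file_size_limit / part_size_limit = 5 GiB / 10 MiB = 512 parts, comfortably
// under `AWS_S3_MAX_PART_COUNT`. A custom config is validated the same way; a
// minimal sketch (the limits here are illustrative, not from the original):
//
//     let config = S3MultiPartUploaderConfig {
//         part_size_limit: ByteSize::mib(8).as_u64(),
//         file_size_limit: ByteSize::gib(1).as_u64(),
//     };
//     assert!(config.validate().is_ok()); // 1 GiB / 8 MiB = 128 parts
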
impl S3MultiPartUploader {
    /// Creates a multi-part uploader, validating the config and issuing the
    /// initial `create_multipart_upload` call to S3.
    pub async fn try_new(
        sdk_config: &SdkConfig,
        bucket: String,
        key: String,
        config: S3MultiPartUploaderConfig,
    ) -> Result<S3MultiPartUploader, S3MultiPartUploadError> {
        config.validate()?;

        let client = crate::s3::new_client(sdk_config);
        let res = client
            .create_multipart_upload()
            .bucket(&bucket)
            .key(&key)
            .customize()
            .mutate_request(|req| {
                // Some S3-compatible providers require an explicit
                // `Content-Length: 0` header on this empty-bodied request.
                req.headers_mut().insert("Content-Length", "0");
            })
            .send()
            .await?;
        let upload_id = res
            .upload_id()
            .ok_or_else(|| anyhow!("create_multipart_upload response missing upload id"))?
            .to_string();
        Ok(S3MultiPartUploader {
            client,
            bucket,
            key,
            upload_id,
            part_count: 0,
            total_bytes_uploaded: 0,
            buffer: Default::default(),
            config,
            upload_handles: Default::default(),
        })
    }

    /// Adds `data` to the internal buffer, flushing the buffer as parts
    /// whenever it grows past `part_size_limit`.
    ///
    /// Returns an `UploadExceedsMaxFileLimit` error if the upload would exceed
    /// the configured `file_size_limit`, unless no data has been added yet, in
    /// which case a single chunk of up to
    /// `min(part_size_limit * AWS_S3_MAX_PART_COUNT, AWS_S3_MAX_OBJECT_SIZE)`
    /// bytes is accepted.
    pub fn buffer_chunk(&mut self, data: &[u8]) -> Result<(), S3MultiPartUploadError> {
        let data_len = u64::cast_from(data.len());

        let aws_max_part_count: u64 = AWS_S3_MAX_PART_COUNT.try_into().expect("i32 to u64");
        let absolute_max_file_limit = std::cmp::min(
            self.config.part_size_limit * aws_max_part_count,
            AWS_S3_MAX_OBJECT_SIZE.as_u64(),
        );

        // The first chunk is allowed to exceed `file_size_limit`, as long as
        // it fits within the hard S3 limits above.
        let can_force_first_upload = self.added_bytes() == 0 && data_len <= absolute_max_file_limit;

        if data_len <= self.remaining_bytes_limit() || can_force_first_upload {
            self.buffer.extend_from_slice(data);
            self.flush_chunks()?;
            Ok(())
        } else {
            Err(S3MultiPartUploadError::UploadExceedsMaxFileLimit(
                self.config.file_size_limit,
            ))
        }
    }

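    // To make the limit above concrete: with the default config
    // (part_size_limit = 10 MiB, file_size_limit = 5 GiB), the first chunk may
    // be as large as min(10 MiB * 10_000, 5 TiB), roughly 97.7 GiB, while
    // every later chunk must fit within the remaining `file_size_limit`
    // budget. A minimal sketch (sizes are illustrative):
    //
    //     uploader.buffer_chunk(&vec![0u8; 1024])?; // ok: well under the limit
    //     // a later chunk that would push `added_bytes` past `file_size_limit`
    //     // fails with `UploadExceedsMaxFileLimit`.
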
    /// Finishes the multi-part upload: uploads whatever remains in the buffer
    /// as the final part, waits for all in-flight part uploads, and issues the
    /// `complete_multipart_upload` call. Returns the part count and the number
    /// of bytes uploaded.
    pub async fn finish(mut self) -> Result<CompletedUpload, S3MultiPartUploadError> {
        let remaining = self.buffer.split();
        self.upload_part_internal(remaining.freeze())?;

        let mut parts: Vec<CompletedPart> = Vec::with_capacity(self.upload_handles.len());
        for handle in self.upload_handles {
            let (etag, part_num) = handle.wait_and_assert_finished().await?;
            match etag {
                Some(etag) => {
                    parts.push(
                        CompletedPart::builder()
                            .e_tag(etag)
                            .part_number(part_num)
                            .build(),
                    );
                }
                None => Err(anyhow!("etag for part {part_num} is None"))?,
            }
        }

        self.client
            .complete_multipart_upload()
            .bucket(&self.bucket)
            .key(&self.key)
            .upload_id(self.upload_id.clone())
            .multipart_upload(
                CompletedMultipartUpload::builder()
                    .set_parts(Some(parts))
                    .build(),
            )
            .send()
            .await?;
        Ok(CompletedUpload {
            part_count: self.part_count.try_into().expect("i32 to u32"),
            total_bytes_uploaded: self.total_bytes_uploaded,
            bucket: self.bucket,
            key: self.key,
        })
    }

    fn buffer_size(&self) -> u64 {
        u64::cast_from(self.buffer.len())
    }

    /// The number of bytes that can still be added without exceeding
    /// `file_size_limit`.
    fn remaining_bytes_limit(&self) -> u64 {
        self.config
            .file_size_limit
            .saturating_sub(self.added_bytes())
    }

    /// The number of bytes handed to the uploader so far, counting both
    /// already-uploaded parts and the not-yet-flushed buffer.
    pub fn added_bytes(&self) -> u64 {
        self.total_bytes_uploaded + self.buffer_size()
    }

    /// Splits `part_size_limit`-sized chunks off the front of the buffer and
    /// uploads each one as a part.
    fn flush_chunks(&mut self) -> Result<(), S3MultiPartUploadError> {
        let part_size_limit = self.config.part_size_limit;
        while self.buffer_size() > part_size_limit {
            let data = self.buffer.split_to(usize::cast_from(part_size_limit));
            self.upload_part_internal(data.freeze())?;
        }
        Ok(())
    }

    /// Spawns a task to upload `data` as the next part. The task handle is
    /// pushed onto `upload_handles` and awaited in `finish`.
    fn upload_part_internal(&mut self, data: Bytes) -> Result<(), S3MultiPartUploadError> {
        let num_of_bytes: u64 = u64::cast_from(data.len());

        let next_part_number = self.part_count + 1;
        if next_part_number > AWS_S3_MAX_PART_COUNT {
            return Err(S3MultiPartUploadError::ExceedsMaxPartNumber);
        }
        let client = self.client.clone();
        let bucket = self.bucket.clone();
        let key = self.key.clone();
        let upload_id = self.upload_id.clone();

        let handle = spawn(|| "s3::upload_part", async move {
            let res = client
                .upload_part()
                .bucket(&bucket)
                .key(&key)
                .upload_id(upload_id)
                .part_number(next_part_number)
                .body(ByteStream::from(data))
                .send()
                .await?;
            Ok((res.e_tag, next_part_number))
        });
        self.upload_handles.push(handle);

        self.part_count = next_part_number;
        self.total_bytes_uploaded += num_of_bytes;
        Ok(())
    }
}

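// End-to-end, the uploader is driven as in this minimal sketch (the bucket,
// key, and payload below are illustrative, not from the original source):
//
//     let config = S3MultiPartUploaderConfig::default();
//     let mut uploader = S3MultiPartUploader::try_new(
//         &sdk_config,
//         "my-bucket".to_string(),
//         "path/to/object".to_string(),
//         config,
//     )
//     .await?;
//     uploader.buffer_chunk(b"some bytes")?;
//     let done = uploader.finish().await?;
//     assert_eq!(done.total_bytes_uploaded, 10);
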
#[derive(thiserror::Error, Debug)]
pub enum S3MultiPartUploadError {
    #[error(
        "multi-part upload cannot have more than {} parts",
        AWS_S3_MAX_PART_COUNT
    )]
    ExceedsMaxPartNumber,
    #[error("multi-part upload will exceed configured file_size_limit: {} bytes", .0)]
    UploadExceedsMaxFileLimit(u64),
    #[error("{}", .0.display_with_causes())]
    CreateMultipartUploadError(#[from] SdkError<CreateMultipartUploadError>),
    #[error("{}", .0.display_with_causes())]
    UploadPartError(#[from] SdkError<UploadPartError>),
    #[error("{}", .0.display_with_causes())]
    CompleteMultipartUploadError(#[from] SdkError<CompleteMultipartUploadError>),
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}

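// Callers can branch on these variants; for instance, treating the size-limit
// error as a signal to rotate to a new object (a sketch; the rotation policy
// is an assumption, not part of this module):
//
//     match uploader.buffer_chunk(&chunk) {
//         Err(S3MultiPartUploadError::UploadExceedsMaxFileLimit(_)) => {
//             // finish this object and start a new multi-part upload
//         }
//         other => other?,
//     }
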
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use uuid::Uuid;

    use super::*;
    use crate::{defaults, s3};

    fn s3_bucket_key_for_test() -> Option<(String, String)> {
        let bucket = match std::env::var("MZ_S3_UPLOADER_TEST_S3_BUCKET") {
            Ok(bucket) => bucket,
            Err(_) => {
                if mz_ore::env::is_var_truthy("CI") {
                    panic!("CI is supposed to run this test but something has gone wrong!");
                }
                return None;
            }
        };

        let prefix = Uuid::new_v4().to_string();
        let key = format!("cargo_test/{}/file", prefix);
        Some((bucket, key))
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_success() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let (bucket, key) = match s3_bucket_key_for_test() {
            Some(tuple) => tuple,
            None => return Ok(()),
        };

        let config = S3MultiPartUploaderConfig::default();
        let mut uploader =
            S3MultiPartUploader::try_new(&sdk_config, bucket.clone(), key.clone(), config).await?;

        let expected_data = "onetwothree";
        uploader.buffer_chunk(b"one")?;
        uploader.buffer_chunk(b"two")?;
        uploader.buffer_chunk(b"three")?;

        let CompletedUpload {
            part_count,
            total_bytes_uploaded,
            bucket: _,
            key: _,
        } = uploader.finish().await?;

        let s3_client = s3::new_client(&sdk_config);
        let uploaded_object = s3_client
            .get_object()
            .bucket(bucket)
            .key(key)
            // Requesting a specific part also returns the total parts count.
            .part_number(1)
            .send()
            .await
            .unwrap();

        let uploaded_parts_count: u32 = uploaded_object.parts_count().unwrap().try_into().unwrap();
        assert_eq!(uploaded_parts_count, part_count);
        assert_eq!(part_count, 1);

        let body = uploaded_object.body.collect().await.unwrap().into_bytes();
        assert_eq!(body, expected_data);

        let expected_bytes: u64 = Bytes::from(expected_data).len().try_into().unwrap();
        assert_eq!(total_bytes_uploaded, expected_bytes);

        Ok(())
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_buffer() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let (bucket, key) = match s3_bucket_key_for_test() {
            Some(tuple) => tuple,
            None => return Ok(()),
        };

        let config = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64(),
            file_size_limit: ByteSize::mib(10).as_u64(),
        };
        let mut uploader =
            S3MultiPartUploader::try_new(&sdk_config, bucket.clone(), key.clone(), config).await?;

        // 6 MiB of data, enough to exceed the 5 MiB part size limit.
        let expected_data = vec![97; 6291456];
        let expected_bytes: u64 = u64::cast_from(expected_data.len());
        uploader.buffer_chunk(&expected_data)?;

        // 6 MiB of the 10 MiB file size limit is used up.
        assert_eq!(uploader.remaining_bytes_limit(), ByteSize::mib(4).as_u64());

        // Another 6 MiB chunk would exceed the file size limit.
        let error = uploader.buffer_chunk(&expected_data).unwrap_err();
        assert!(matches!(
            error,
            S3MultiPartUploadError::UploadExceedsMaxFileLimit(_)
        ));

        let CompletedUpload {
            part_count,
            total_bytes_uploaded,
            bucket: _,
            key: _,
        } = uploader.finish().await?;

        let s3_client = s3::new_client(&sdk_config);
        let uploaded_object = s3_client
            .get_object()
            .bucket(bucket)
            .key(key)
            .send()
            .await
            .unwrap();

        // The 6 MiB buffer is flushed as one 5 MiB part; `finish` uploads the
        // remaining 1 MiB as a second part.
        assert_eq!(part_count, 2);

        let body = uploaded_object.body.collect().await.unwrap().into_bytes();
        assert_eq!(body, *expected_data);

        assert_eq!(total_bytes_uploaded, expected_bytes);

        Ok(())
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_no_data() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let (bucket, key) = match s3_bucket_key_for_test() {
            Some(tuple) => tuple,
            None => return Ok(()),
        };

        let config = Default::default();
        let uploader =
            S3MultiPartUploader::try_new(&sdk_config, bucket.clone(), key.clone(), config).await?;

        // Even without any buffered data, `finish` uploads an empty part and
        // completes the upload, producing an empty object.
        uploader.finish().await.unwrap();

        let s3_client = s3::new_client(&sdk_config);
        let uploaded_object = s3_client
            .get_object()
            .bucket(bucket)
            .key(key)
            .send()
            .await
            .unwrap();

        assert_eq!(uploaded_object.content_length(), Some(0));

        Ok(())
    }

    #[mz_ore::test]
    fn test_invalid_configs() {
        // Part size below the 5 MiB minimum.
        let config = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64() - 1,
            file_size_limit: ByteSize::gib(5).as_u64(),
        };
        let error = config.validate().unwrap_err();

        assert_eq!(
            error.to_string(),
            "invalid part size: 5242879, should be between 5242880 and 5368709120 bytes"
        );

        // A file size limit which would require more than 10,000 parts.
        let config = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64(),
            file_size_limit: (ByteSize::mib(5).as_u64() * 10001) - 1,
        };
        let error = config.validate().unwrap_err();
        assert_eq!(
            error.to_string(),
            "total number of possible parts (file_size_limit / part_size_limit): 10001, cannot exceed 10000",
        );
    }
}