use anyhow::anyhow;
use aws_sdk_s3::Client;
use aws_sdk_s3::error::SdkError;
use aws_sdk_s3::operation::complete_multipart_upload::CompleteMultipartUploadError;
use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadError;
use aws_sdk_s3::operation::upload_part::UploadPartError;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use aws_types::sdk_config::SdkConfig;
use bytes::{Bytes, BytesMut};
use bytesize::ByteSize;
use mz_ore::cast::CastFrom;
use mz_ore::error::ErrorExt;
use mz_ore::task::{JoinHandle, JoinHandleExt, spawn};

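/// A multi-part uploader which buffers chunks locally and uploads them to S3 as
/// parts of a single object. One uploader instance is tied to one multipart
/// upload for its whole lifetime.
///
/// Illustrative call sequence (a sketch only; the bucket and key shown are
/// placeholders and an already-loaded `SdkConfig` is assumed):
///
/// ```ignore
/// let mut uploader = S3MultiPartUploader::try_new(
///     &sdk_config,
///     "some-bucket".to_string(),
///     "some/key".to_string(),
///     S3MultiPartUploaderConfig::default(),
/// )
/// .await?;
/// uploader.buffer_chunk(b"hello ")?;
/// uploader.buffer_chunk(b"world")?;
/// let CompletedUpload { part_count, total_bytes_uploaded, .. } = uploader.finish().await?;
/// ```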
#[derive(Debug)]
pub struct S3MultiPartUploader {
    client: Client,
    config: S3MultiPartUploaderConfig,
    /// The bucket the object is uploaded to.
    bucket: String,
    /// The key of the uploaded object.
    key: String,
    /// The upload ID returned by `create_multipart_upload`.
    upload_id: String,
    /// The number of parts uploaded (or queued for upload) so far.
    part_count: i32,
    /// The number of bytes handed off to part uploads so far.
    total_bytes_uploaded: u64,
    /// Data buffered locally until it exceeds `part_size_limit`.
    buffer: BytesMut,
    /// Handles of the spawned part-upload tasks, awaited in `finish`.
    upload_handles: Vec<JoinHandle<Result<(Option<String>, i32), S3MultiPartUploadError>>>,
}

/// The maximum number of parts AWS S3 allows in a multipart upload.
pub const AWS_S3_MAX_PART_COUNT: i32 = 10_000;
/// The minimum size of a part in a multipart upload (the last part may be smaller).
const AWS_S3_MIN_PART_SIZE: ByteSize = ByteSize::mib(5);
/// The maximum size of a single part in a multipart upload.
const AWS_S3_MAX_PART_SIZE: ByteSize = ByteSize::gib(5);
/// The maximum size of a single S3 object.
const AWS_S3_MAX_OBJECT_SIZE: ByteSize = ByteSize::tib(5);

/// Summary of a finished multipart upload, returned by
/// [`S3MultiPartUploader::finish`].
#[derive(Debug)]
pub struct CompletedUpload {
    pub part_count: u32,
    pub total_bytes_uploaded: u64,
    pub bucket: String,
    pub key: String,
}

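/// Configuration for an [`S3MultiPartUploader`].
///
/// `part_size_limit` controls how much data is buffered before a part is
/// uploaded, and `file_size_limit` caps the total size of the uploaded object.
/// [`S3MultiPartUploader::try_new`] validates that both limits stay within what
/// AWS S3 allows, including the 10,000 part maximum.
///
/// An illustrative non-default configuration (the values are just an example,
/// mirroring the tests below):
///
/// ```ignore
/// let config = S3MultiPartUploaderConfig {
///     // Upload parts of 5 MiB each ...
///     part_size_limit: ByteSize::mib(5).as_u64(),
///     // ... and reject uploads larger than 10 MiB in total.
///     file_size_limit: ByteSize::mib(10).as_u64(),
/// };
/// ```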
#[derive(Debug)]
pub struct S3MultiPartUploaderConfig {
    /// The size of each uploaded part (the final part may be smaller).
    pub part_size_limit: u64,
    /// The maximum total size of the uploaded object.
    pub file_size_limit: u64,
}

impl S3MultiPartUploaderConfig {
    /// Default `file_size_limit`: 5 GiB.
    const DEFAULT_MAX_FILE_SIZE: ByteSize = ByteSize::gib(5);
    /// Default `part_size_limit`: 10 MiB.
    const DEFAULT_PART_SIZE_LIMIT: ByteSize = ByteSize::mib(10);

    /// Checks that the configured limits are within the bounds AWS S3 allows
    /// for a multipart upload.
    fn validate(&self) -> Result<(), anyhow::Error> {
        let S3MultiPartUploaderConfig {
            part_size_limit,
            file_size_limit,
        } = self;
        if part_size_limit < &AWS_S3_MIN_PART_SIZE.as_u64()
            || part_size_limit > &AWS_S3_MAX_PART_SIZE.as_u64()
        {
            return Err(anyhow!(format!(
                "invalid part size: {}, should be between {} and {} bytes",
                part_size_limit,
                AWS_S3_MIN_PART_SIZE.as_u64(),
                AWS_S3_MAX_PART_SIZE.as_u64()
            )));
        }
        if file_size_limit > &AWS_S3_MAX_OBJECT_SIZE.as_u64() {
            return Err(anyhow!(format!(
                "invalid file size: {}, cannot exceed {} bytes",
                file_size_limit,
                AWS_S3_MAX_OBJECT_SIZE.as_u64()
            )));
        }
        let max_parts_count: u64 = AWS_S3_MAX_PART_COUNT.try_into().expect("i32 to u64");
        let estimated_parts_count = file_size_limit.div_ceil(*part_size_limit);
        if estimated_parts_count > max_parts_count {
            return Err(anyhow!(format!(
                "total number of possible parts (file_size_limit / part_size_limit): {}, cannot exceed {}",
                estimated_parts_count, AWS_S3_MAX_PART_COUNT
            )));
        }
        Ok(())
    }
}

impl Default for S3MultiPartUploaderConfig {
    fn default() -> Self {
        Self {
            part_size_limit: Self::DEFAULT_PART_SIZE_LIMIT.as_u64(),
            file_size_limit: Self::DEFAULT_MAX_FILE_SIZE.as_u64(),
        }
    }
}

impl S3MultiPartUploader {
    /// Validates the given config and initiates a multipart upload for the
    /// given `bucket` and `key`, returning an uploader tied to that upload.
    pub async fn try_new(
        sdk_config: &SdkConfig,
        bucket: String,
        key: String,
        config: S3MultiPartUploaderConfig,
    ) -> Result<S3MultiPartUploader, S3MultiPartUploadError> {
        config.validate()?;

        let client = crate::s3::new_client(sdk_config);
        let res = client
            .create_multipart_upload()
            .bucket(&bucket)
            .key(&key)
            .send()
            .await?;
        let upload_id = res
            .upload_id()
            .ok_or(anyhow!(
                "create_multipart_upload response missing upload id"
            ))?
            .to_string();
        Ok(S3MultiPartUploader {
            client,
            bucket,
            key,
            upload_id,
            part_count: 0,
            total_bytes_uploaded: 0,
            buffer: Default::default(),
            config,
            upload_handles: Default::default(),
        })
    }

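    /// Buffers `data` locally, uploading parts in background tasks whenever the
    /// buffered bytes exceed the configured `part_size_limit`.
    ///
    /// Returns `S3MultiPartUploadError::UploadExceedsMaxFileLimit` if adding
    /// `data` would push the upload past `file_size_limit`, unless this is the
    /// very first chunk and S3 itself could still accept it as a single object.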
    pub fn buffer_chunk(&mut self, data: &[u8]) -> Result<(), S3MultiPartUploadError> {
        let data_len = u64::cast_from(data.len());

        let aws_max_part_count: u64 = AWS_S3_MAX_PART_COUNT.try_into().expect("i32 to u64");
        let absolute_max_file_limit = std::cmp::min(
            self.config.part_size_limit * aws_max_part_count,
            AWS_S3_MAX_OBJECT_SIZE.as_u64(),
        );

        let can_force_first_upload = self.added_bytes() == 0 && data_len <= absolute_max_file_limit;

        if data_len <= self.remaining_bytes_limit() || can_force_first_upload {
            self.buffer.extend_from_slice(data);
            self.flush_chunks()?;
            Ok(())
        } else {
            Err(S3MultiPartUploadError::UploadExceedsMaxFileLimit(
                self.config.file_size_limit,
            ))
        }
    }

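    /// Uploads any remaining buffered data as the final part, waits for all
    /// in-flight part uploads to finish, and completes the multipart upload,
    /// returning a [`CompletedUpload`] summary.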
    pub async fn finish(mut self) -> Result<CompletedUpload, S3MultiPartUploadError> {
        let remaining = self.buffer.split();
        self.upload_part_internal(remaining.freeze())?;

        let mut parts: Vec<CompletedPart> = Vec::with_capacity(self.upload_handles.len());
        for handle in self.upload_handles {
            let (etag, part_num) = handle.wait_and_assert_finished().await?;
            match etag {
                Some(etag) => {
                    parts.push(
                        CompletedPart::builder()
                            .e_tag(etag)
                            .part_number(part_num)
                            .build(),
                    );
                }
                None => Err(anyhow!("etag for part {part_num} is None"))?,
            }
        }

        self.client
            .complete_multipart_upload()
            .bucket(&self.bucket)
            .key(&self.key)
            .upload_id(self.upload_id.clone())
            .multipart_upload(
                CompletedMultipartUpload::builder()
                    .set_parts(Some(parts))
                    .build(),
            )
            .send()
            .await?;
        Ok(CompletedUpload {
            part_count: self.part_count.try_into().expect("i32 to u32"),
            total_bytes_uploaded: self.total_bytes_uploaded,
            bucket: self.bucket,
            key: self.key,
        })
    }

    fn buffer_size(&self) -> u64 {
        u64::cast_from(self.buffer.len())
    }

    /// Number of bytes that can still be added before hitting `file_size_limit`.
    fn remaining_bytes_limit(&self) -> u64 {
        self.config
            .file_size_limit
            .saturating_sub(self.added_bytes())
    }

    /// Number of bytes uploaded or buffered by this uploader so far.
    pub fn added_bytes(&self) -> u64 {
        self.total_bytes_uploaded + self.buffer_size()
    }

    /// Uploads a part whenever the buffered data exceeds `part_size_limit`.
    fn flush_chunks(&mut self) -> Result<(), S3MultiPartUploadError> {
        let part_size_limit = self.config.part_size_limit;
        while self.buffer_size() > part_size_limit {
            let data = self.buffer.split_to(usize::cast_from(part_size_limit));
            self.upload_part_internal(data.freeze())?;
        }
        Ok(())
    }

    /// Spawns a background task to upload `data` as the next part of the
    /// multipart upload, recording its handle so `finish` can await it.
    fn upload_part_internal(&mut self, data: Bytes) -> Result<(), S3MultiPartUploadError> {
        let num_of_bytes: u64 = u64::cast_from(data.len());

        let next_part_number = self.part_count + 1;
        if next_part_number > AWS_S3_MAX_PART_COUNT {
            return Err(S3MultiPartUploadError::ExceedsMaxPartNumber);
        }
        let client = self.client.clone();
        let bucket = self.bucket.clone();
        let key = self.key.clone();
        let upload_id = self.upload_id.clone();

        let handle = spawn(|| "s3::upload_part", async move {
            let res = client
                .upload_part()
                .bucket(&bucket)
                .key(&key)
                .upload_id(upload_id)
                .part_number(next_part_number)
                .body(ByteStream::from(data))
                .send()
                .await?;
            Ok((res.e_tag, next_part_number))
        });
        self.upload_handles.push(handle);

        self.part_count = next_part_number;
        self.total_bytes_uploaded += num_of_bytes;
        Ok(())
    }
}

#[derive(thiserror::Error, Debug)]
pub enum S3MultiPartUploadError {
    #[error(
        "multi-part upload cannot have more than {} parts",
        AWS_S3_MAX_PART_COUNT
    )]
    ExceedsMaxPartNumber,
    #[error("multi-part upload will exceed configured file_size_limit: {} bytes", .0)]
    UploadExceedsMaxFileLimit(u64),
    #[error("{}", .0.display_with_causes())]
    CreateMultipartUploadError(#[from] SdkError<CreateMultipartUploadError>),
    #[error("{}", .0.display_with_causes())]
    UploadPartError(#[from] SdkError<UploadPartError>),
    #[error("{}", .0.display_with_causes())]
    CompleteMultipartUploadError(#[from] SdkError<CompleteMultipartUploadError>),
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}

#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use uuid::Uuid;

    use super::*;
    use crate::{defaults, s3};

    fn s3_bucket_key_for_test() -> Option<(String, String)> {
        let bucket = match std::env::var("MZ_S3_UPLOADER_TEST_S3_BUCKET") {
            Ok(bucket) => bucket,
            Err(_) => {
                if mz_ore::env::is_var_truthy("CI") {
                    panic!("CI is supposed to run this test but something has gone wrong!");
                }
                return None;
            }
        };

        let prefix = Uuid::new_v4().to_string();
        let key = format!("cargo_test/{}/file", prefix);
        Some((bucket, key))
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_success() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let (bucket, key) = match s3_bucket_key_for_test() {
            Some(tuple) => tuple,
            None => return Ok(()),
        };

        let config = S3MultiPartUploaderConfig::default();
        let mut uploader =
            S3MultiPartUploader::try_new(&sdk_config, bucket.clone(), key.clone(), config).await?;

        let expected_data = "onetwothree";
        uploader.buffer_chunk(b"one")?;
        uploader.buffer_chunk(b"two")?;
        uploader.buffer_chunk(b"three")?;

        let CompletedUpload {
            part_count,
            total_bytes_uploaded,
            bucket: _,
            key: _,
        } = uploader.finish().await?;

        let s3_client = s3::new_client(&sdk_config);
        let uploaded_object = s3_client
            .get_object()
            .bucket(bucket)
            .key(key)
            // Request an explicit part so the response also reports `parts_count`.
            .part_number(1)
            .send()
            .await
            .unwrap();

        let uploaded_parts_count: u32 = uploaded_object.parts_count().unwrap().try_into().unwrap();
        assert_eq!(uploaded_parts_count, part_count);
        assert_eq!(part_count, 1);

        let body = uploaded_object.body.collect().await.unwrap().into_bytes();
        assert_eq!(body, expected_data);

        let expected_bytes: u64 = Bytes::from(expected_data).len().try_into().unwrap();
        assert_eq!(total_bytes_uploaded, expected_bytes);

        Ok(())
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_buffer() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let (bucket, key) = match s3_bucket_key_for_test() {
            Some(tuple) => tuple,
            None => return Ok(()),
        };

        let config = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64(),
            file_size_limit: ByteSize::mib(10).as_u64(),
        };
        let mut uploader =
            S3MultiPartUploader::try_new(&sdk_config, bucket.clone(), key.clone(), config).await?;

        // 6 MiB of data, which exceeds the 5 MiB part_size_limit.
        let expected_data = vec![97; 6291456];
        let expected_bytes: u64 = u64::cast_from(expected_data.len());
        uploader.buffer_chunk(&expected_data)?;

        assert_eq!(uploader.remaining_bytes_limit(), ByteSize::mib(4).as_u64());

        // Buffering another 6 MiB would exceed the 10 MiB file_size_limit.
        let error = uploader.buffer_chunk(&expected_data).unwrap_err();
        assert!(matches!(
            error,
            S3MultiPartUploadError::UploadExceedsMaxFileLimit(_)
        ));

        let CompletedUpload {
            part_count,
            total_bytes_uploaded,
            bucket: _,
            key: _,
        } = uploader.finish().await?;

        let s3_client = s3::new_client(&sdk_config);
        let uploaded_object = s3_client
            .get_object()
            .bucket(bucket)
            .key(key)
            .send()
            .await
            .unwrap();

        // The 6 MiB of data should have been split into a 5 MiB part and a 1 MiB part.
        assert_eq!(part_count, 2);

        let body = uploaded_object.body.collect().await.unwrap().into_bytes();
        assert_eq!(body, *expected_data);

        assert_eq!(total_bytes_uploaded, expected_bytes);

        Ok(())
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn multi_part_upload_no_data() -> Result<(), S3MultiPartUploadError> {
        let sdk_config = defaults().load().await;
        let (bucket, key) = match s3_bucket_key_for_test() {
            Some(tuple) => tuple,
            None => return Ok(()),
        };

        let config = Default::default();
        let uploader =
            S3MultiPartUploader::try_new(&sdk_config, bucket.clone(), key.clone(), config).await?;

        uploader.finish().await.unwrap();

        let s3_client = s3::new_client(&sdk_config);
        let uploaded_object = s3_client
            .get_object()
            .bucket(bucket)
            .key(key)
            .send()
            .await
            .unwrap();

        assert_eq!(uploaded_object.content_length(), Some(0));

        Ok(())
    }

    #[mz_ore::test]
    fn test_invalid_configs() {
        let config = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64() - 1,
            file_size_limit: ByteSize::gib(5).as_u64(),
        };
        let error = config.validate().unwrap_err();

        assert_eq!(
            error.to_string(),
            "invalid part size: 5242879, should be between 5242880 and 5368709120 bytes"
        );

        let config = S3MultiPartUploaderConfig {
            part_size_limit: ByteSize::mib(5).as_u64(),
            // Implies 10001 parts of 5 MiB each, exceeding AWS_S3_MAX_PART_COUNT.
            file_size_limit: (ByteSize::mib(5).as_u64() * 10001) - 1,
        };
        let error = config.validate().unwrap_err();
        assert_eq!(
            error.to_string(),
            "total number of possible parts (file_size_limit / part_size_limit): 10001, cannot exceed 10000",
        );
    }
}