// mz_persist/s3.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An S3 implementation of [Blob] storage.
11
12use std::cmp;
13use std::fmt::{Debug, Formatter};
14use std::ops::Range;
15use std::sync::Arc;
16use std::sync::atomic::{self, AtomicU64};
17use std::time::{Duration, Instant};
18
19use anyhow::{Context, anyhow};
20use async_trait::async_trait;
21use aws_config::sts::AssumeRoleProvider;
22use aws_config::timeout::TimeoutConfig;
23use aws_credential_types::Credentials;
24use aws_sdk_s3::Client as S3Client;
25use aws_sdk_s3::config::{AsyncSleep, Sleep};
26use aws_sdk_s3::error::{ProvideErrorMetadata, SdkError};
27use aws_sdk_s3::primitives::ByteStream;
28use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
29use aws_types::region::Region;
30use bytes::Bytes;
31use futures_util::stream::FuturesOrdered;
32use futures_util::{FutureExt, StreamExt};
33use mz_ore::bytes::SegmentedBytes;
34use mz_ore::cast::CastFrom;
35use mz_ore::metrics::MetricsRegistry;
36use mz_ore::task::RuntimeExt;
37use tokio::runtime::Handle as AsyncHandle;
38use tracing::{Instrument, debug, debug_span, trace, trace_span};
39use uuid::Uuid;
40
41use crate::cfg::BlobKnobs;
42use crate::error::Error;
43use crate::location::{Blob, BlobMetadata, Determinate, ExternalError};
44use crate::metrics::S3BlobMetrics;
45
/// Configuration for opening an [S3Blob].
#[derive(Clone, Debug)]
pub struct S3BlobConfig {
    // Metrics shared with the [S3Blob] this config opens.
    metrics: S3BlobMetrics,
    // Fully-configured S3 client (credentials, region, endpoint, timeouts).
    client: S3Client,
    // Name of the S3 bucket objects are stored in.
    bucket: String,
    // Key prefix (possibly empty) under which all of this blob's objects live.
    prefix: String,
}
54
// There is no simple way to hook into the S3 client to capture when its various timeouts
// are hit. Instead, we pass along marker values that inform our [MetricsSleep] impl which
// type of timeout was requested so it can substitute in a dynamic value set by config
// from the caller.
//
// The (secs, nanos) pairs are intentionally distinctive so that a genuine sleep request
// is effectively never mistaken for one of these markers.
const OPERATION_TIMEOUT_MARKER: Duration = Duration::new(111, 1111);
const OPERATION_ATTEMPT_TIMEOUT_MARKER: Duration = Duration::new(222, 2222);
const CONNECT_TIMEOUT_MARKER: Duration = Duration::new(333, 3333);
const READ_TIMEOUT_MARKER: Duration = Duration::new(444, 4444);
63
/// An [AsyncSleep] impl that recognizes the timeout marker [Duration]s above,
/// substitutes the dynamic timeout values from [BlobKnobs], and counts each
/// timeout that actually fires into [S3BlobMetrics].
#[derive(Debug)]
struct MetricsSleep {
    // Source of the dynamic timeout durations substituted for the markers.
    knobs: Box<dyn BlobKnobs>,
    // Holds the per-timeout-type counters incremented when a sleep completes.
    metrics: S3BlobMetrics,
}
69
70impl AsyncSleep for MetricsSleep {
71    fn sleep(&self, duration: Duration) -> Sleep {
72        let (duration, metric) = match duration {
73            OPERATION_TIMEOUT_MARKER => (
74                self.knobs.operation_timeout(),
75                Some(self.metrics.operation_timeouts.clone()),
76            ),
77            OPERATION_ATTEMPT_TIMEOUT_MARKER => (
78                self.knobs.operation_attempt_timeout(),
79                Some(self.metrics.operation_attempt_timeouts.clone()),
80            ),
81            CONNECT_TIMEOUT_MARKER => (
82                self.knobs.connect_timeout(),
83                Some(self.metrics.connect_timeouts.clone()),
84            ),
85            READ_TIMEOUT_MARKER => (
86                self.knobs.read_timeout(),
87                Some(self.metrics.read_timeouts.clone()),
88            ),
89            duration => (duration, None),
90        };
91
92        // the sleep future we return here will only be polled to
93        // completion if its corresponding http request to S3 times
94        // out, meaning we can chain incrementing the appropriate
95        // timeout counter to when it finishes
96        Sleep::new(tokio::time::sleep(duration).map(|x| {
97            if let Some(counter) = metric {
98                counter.inc();
99            }
100            x
101        }))
102    }
103}
104
impl S3BlobConfig {
    // Env var that names the bucket used by the external-storage unit tests;
    // its absence makes those tests no-ops.
    const EXTERNAL_TESTS_S3_BUCKET: &'static str = "MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET";

    /// Returns a new [S3BlobConfig] for use in production.
    ///
    /// Stores objects in the given bucket prepended with the (possibly empty)
    /// prefix. S3 credentials and region must be available in the process or
    /// environment.
    ///
    /// Optional overrides:
    /// - `role_arn`: assume this role (using the ambient default credential
    ///   chain) instead of using the ambient credentials directly.
    /// - `endpoint`: custom endpoint URL (e.g. an S3-compatible store).
    /// - `region`: explicit AWS region.
    /// - `credentials`: static `(access_key_id, secret_access_key)` pair.
    pub async fn new(
        bucket: String,
        prefix: String,
        role_arn: Option<String>,
        endpoint: Option<String>,
        region: Option<String>,
        credentials: Option<(String, String)>,
        knobs: Box<dyn BlobKnobs>,
        metrics: S3BlobMetrics,
    ) -> Result<Self, Error> {
        let mut loader = mz_aws_util::defaults();

        if let Some(region) = region {
            loader = loader.region(Region::new(region));
        };

        if let Some(role_arn) = role_arn {
            // The role is assumed using whatever credentials the default
            // chain discovers in the environment.
            let assume_role_sdk_config = mz_aws_util::defaults().load().await;
            let role_provider = AssumeRoleProvider::builder(role_arn)
                .configure(&assume_role_sdk_config)
                .session_name("persist")
                .build()
                .await;
            loader = loader.credentials_provider(role_provider);
        }

        if let Some((access_key_id, secret_access_key)) = credentials {
            // NOTE(review): when both `role_arn` and static `credentials` are
            // supplied, this later `credentials_provider` call appears to
            // replace the assume-role provider set above — confirm intended.
            loader = loader.credentials_provider(Credentials::from_keys(
                access_key_id,
                secret_access_key,
                None,
            ));
        }

        if let Some(endpoint) = endpoint {
            loader = loader.endpoint_url(endpoint)
        }

        // NB: we must always use the custom sleep impl if we use the timeout
        // marker values, since [MetricsSleep] is what translates the markers
        // into the real (dynamic) timeout durations.
        loader = loader.sleep_impl(MetricsSleep {
            knobs,
            metrics: metrics.clone(),
        });
        loader = loader.timeout_config(
            TimeoutConfig::builder()
                // maximum time allowed for a top-level S3 API call (including internal retries)
                .operation_timeout(OPERATION_TIMEOUT_MARKER)
                // maximum time allowed for a single network call
                .operation_attempt_timeout(OPERATION_ATTEMPT_TIMEOUT_MARKER)
                // maximum time until a connection succeeds
                .connect_timeout(CONNECT_TIMEOUT_MARKER)
                // maximum time to read the first byte of a response
                .read_timeout(READ_TIMEOUT_MARKER)
                .build(),
        );

        let client = mz_aws_util::s3::new_client(&loader.load().await);
        Ok(S3BlobConfig {
            metrics,
            client,
            bucket,
            prefix,
        })
    }

    /// Returns a new [S3BlobConfig] for use in unit tests.
    ///
    /// By default, persist tests that use external storage (like s3) are
    /// no-ops, so that `cargo test` does the right thing without any
    /// configuration. To activate the tests, set the
    /// `MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET` environment variable and
    /// ensure you have valid AWS credentials available in a location where the
    /// AWS Rust SDK can discover them.
    ///
    /// This intentionally uses the `MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET`
    /// env as the switch for test no-op-ness instead of the presence of a valid
    /// AWS authentication configuration envs because developers might have
    /// valid credentials present and this isn't an explicit enough signal from
    /// a developer running `cargo test` that it's okay to use these
    /// credentials. It also intentionally does not use the local drop-in s3
    /// replacement to keep persist unit tests light.
    ///
    /// On CI, these tests are enabled by adding the scratch-aws-access plugin
    /// to the `cargo-test` step in `ci/test/pipeline.template.yml` and setting
    /// `MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET` in
    /// `ci/test/cargo-test/mzcompose.py`.
    ///
    /// For a Materialize developer, to opt in to these tests locally for
    /// development, follow the AWS access guide:
    ///
    /// ```text
    /// https://github.com/MaterializeInc/i2/blob/main/doc/aws-access.md
    /// ```
    ///
    /// then running `source src/persist/s3_test_env_mz.sh`. You will also have
    /// to run `aws sso login` if you haven't recently.
    ///
    /// Non-Materialize developers will have to set up their own auto-deleting
    /// bucket and export the same env vars that s3_test_env_mz.sh does.
    ///
    /// Only public for use in src/benches.
    pub async fn new_for_test() -> Result<Option<Self>, Error> {
        let bucket = match std::env::var(Self::EXTERNAL_TESTS_S3_BUCKET) {
            Ok(bucket) => bucket,
            Err(_) => {
                // CI is expected to always set the env var; failing loudly
                // here prevents the test from silently becoming a no-op.
                if mz_ore::env::is_var_truthy("CI") {
                    panic!("CI is supposed to run this test but something has gone wrong!");
                }
                return Ok(None);
            }
        };

        // Knobs whose timeouts are exactly the marker values, i.e. the SDK's
        // builder defaults pass through [MetricsSleep] unchanged.
        struct TestBlobKnobs;
        impl std::fmt::Debug for TestBlobKnobs {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("TestBlobKnobs").finish_non_exhaustive()
            }
        }
        impl BlobKnobs for TestBlobKnobs {
            fn operation_timeout(&self) -> Duration {
                OPERATION_TIMEOUT_MARKER
            }

            fn operation_attempt_timeout(&self) -> Duration {
                OPERATION_ATTEMPT_TIMEOUT_MARKER
            }

            fn connect_timeout(&self) -> Duration {
                CONNECT_TIMEOUT_MARKER
            }

            fn read_timeout(&self) -> Duration {
                READ_TIMEOUT_MARKER
            }

            fn is_cc_active(&self) -> bool {
                false
            }
        }

        // Give each test a unique prefix so they don't conflict. We don't have
        // to worry about deleting any data that we create because the bucket is
        // set to auto-delete after 1 day.
        let prefix = Uuid::new_v4().to_string();
        let role_arn = None;
        let metrics = S3BlobMetrics::new(&MetricsRegistry::new());
        let config = S3BlobConfig::new(
            bucket,
            prefix,
            role_arn,
            None,
            None,
            None,
            Box::new(TestBlobKnobs),
            metrics,
        )
        .await?;
        Ok(Some(config))
    }

    /// Returns a clone of Self with a new v4 uuid prefix.
    pub fn clone_with_new_uuid_prefix(&self) -> Self {
        let mut ret = self.clone();
        ret.prefix = Uuid::new_v4().to_string();
        ret
    }
}
280
/// Implementation of [Blob] backed by S3.
#[derive(Debug)]
pub struct S3Blob {
    // Counters and histograms for S3 operations, timeouts, and errors.
    metrics: S3BlobMetrics,
    // Fully-configured S3 client, taken from [S3BlobConfig].
    client: S3Client,
    // Bucket all of this blob's objects live in.
    bucket: String,
    // Key prefix joined (with a `/`) onto every blob key; see `get_path`.
    prefix: String,
    // Maximum number of keys we get information about per list-objects request.
    //
    // Defaults to 1000 which is the current AWS max.
    max_keys: i32,
    // Tuning for when and how writes are split into multipart uploads.
    multipart_config: MultipartConfig,
}
294
295impl S3Blob {
296    /// Opens the given location for non-exclusive read-write access.
297    pub async fn open(config: S3BlobConfig) -> Result<Self, ExternalError> {
298        let ret = S3Blob {
299            metrics: config.metrics,
300            client: config.client,
301            bucket: config.bucket,
302            prefix: config.prefix,
303            max_keys: 1_000,
304            multipart_config: MultipartConfig::default(),
305        };
306        // Connect before returning success. We don't particularly care about
307        // what's stored in this blob (nothing writes to it, so presumably it's
308        // empty) just that we were able and allowed to fetch it.
309        let _ = ret.get("HEALTH_CHECK").await?;
310        Ok(ret)
311    }
312
313    fn get_path(&self, key: &str) -> String {
314        format!("{}/{}", self.prefix, key)
315    }
316}
317
#[async_trait]
impl Blob for S3Blob {
    /// Fetches the value for `key`, reassembling multipart uploads by fetching
    /// each original part concurrently. Returns `Ok(None)` if the key does not
    /// exist.
    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
        let start_overall = Instant::now();
        let path = self.get_path(key);

        // S3 advises that it's fastest to download large objects along the part
        // boundaries they were originally uploaded with [1].
        //
        // [1]: https://docs.aws.amazon.com/whitepapers/latest/s3-optimizing-performance-best-practices/use-byte-range-fetches.html
        //
        // One option is to run the same logic as multipart does and do the
        // requests using the resulting byte ranges, but if we ever changed the
        // multipart chunking logic, they wouldn't line up for old blobs written
        // by a previous version.
        //
        // Another option is to store the part boundaries in the metadata we
        // keep about the batch, but this would be large and wasteful.
        //
        // Luckily, s3 exposes a part_number param on GetObject requests that we
        // can use. If an object was created with multipart, it allows
        // requesting each part as they were originally uploaded by the part
        // number index. With this, we can simply send off requests for part
        // number 1..=num_parts and reassemble the results.
        //
        // We could roundtrip the number of parts through persist batch
        // metadata, but with some cleverness, we can avoid even this. Turns
        // out, if multipart upload wasn't used (it was just a normal PutObject
        // request), s3 will still happily return it for a request specifying a
        // part_number of 1. This lets us fire off a first request, which
        // contains the metadata we need to determine how many additional parts
        // we need, if any.
        //
        // So, the following call sends this first request. The SDK even returns
        // the headers before the full data body has completed. This gives us
        // the number of parts. We can then proceed to fetch the body of the
        // first request concurrently with the rest of the parts of the object.

        // For each header and body that we fetch, we track the fastest, and
        // any large deviations from it.
        let min_body_elapsed = Arc::new(MinElapsed::default());
        let min_header_elapsed = Arc::new(MinElapsed::default());
        self.metrics.get_part.inc();

        // Fetch our first header, this tells us how many more are left.
        let header_start = Instant::now();
        let object = self
            .client
            .get_object()
            .bucket(&self.bucket)
            .key(&path)
            .part_number(1)
            .send()
            .await;
        let elapsed = header_start.elapsed();
        min_header_elapsed.observe(elapsed, "s3 download first part header");

        let first_part = match object {
            Ok(object) => object,
            // A missing key maps to `None` rather than an error.
            Err(SdkError::ServiceError(err)) if err.err().is_no_such_key() => return Ok(None),
            Err(err) => {
                self.update_error_metrics("GetObject", &err);
                Err(anyhow!(err).context("s3 get meta err"))?
            }
        };

        // Get the remaining number of parts
        let num_parts = match first_part.parts_count() {
            // For a non-multipart upload, parts_count will be None. The rest of  the code works
            // perfectly well if we just pretend this was a multipart upload of 1 part.
            None => 1,
            // For any positive value greater than 0, just return it.
            Some(parts @ 1..) => parts,
            // A non-positive value is invalid.
            Some(bad) => {
                assert!(bad <= 0);
                return Err(anyhow!("unexpected number of s3 object parts: {}", bad).into());
            }
        };

        trace!(
            "s3 download first header took {:?} ({num_parts} parts)",
            start_overall.elapsed(),
        );

        let mut body_futures = FuturesOrdered::new();
        let mut first_part = Some(first_part);

        // Fetch the body of every part. Part 1 reuses the response we already
        // have in hand (taken out of `first_part` on the first iteration);
        // parts 2..=num_parts issue fresh GetObject requests.
        for part_num in 1..=num_parts {
            // Clone a handle to our MinElapsed trackers so we can give one to
            // each download task.
            let min_header_elapsed = Arc::clone(&min_header_elapsed);
            let min_body_elapsed = Arc::clone(&min_body_elapsed);
            let get_invalid_resp = self.metrics.get_invalid_resp.clone();
            let first_part = first_part.take();
            let path = &path;
            let request_future = async move {
                // Use the prefetched response for part 1; fetch fresh headers
                // for every later part.
                let mut object = match first_part {
                    Some(first_part) => {
                        assert_eq!(part_num, 1, "only the first part should be prefetched");
                        first_part
                    }
                    None => {
                        assert_ne!(part_num, 1, "first part should be prefetched");
                        // Request our headers.
                        let header_start = Instant::now();
                        let object = self
                            .client
                            .get_object()
                            .bucket(&self.bucket)
                            .key(path)
                            .part_number(part_num)
                            .send()
                            .await
                            .inspect_err(|err| self.update_error_metrics("GetObject", err))
                            .context("s3 get meta err")?;
                        min_header_elapsed
                            .observe(header_start.elapsed(), "s3 download part header");
                        object
                    }
                };

                // Request the body.
                let body_start = Instant::now();
                let mut body_parts: Vec<Bytes> = Vec::new();

                // A negative content-length is a malformed response; count it
                // but still attempt to stream the body.
                if let Some(len @ ..=-1) = object.content_length() {
                    tracing::trace!(?len, "found invalid content-length");
                    get_invalid_resp.inc();
                }

                while let Some(data) = object.body.next().await {
                    body_parts.push(data.context("s3 get body err")?);
                }

                let body_elapsed = body_start.elapsed();
                min_body_elapsed.observe(body_elapsed, "s3 download part body");

                Ok::<_, anyhow::Error>(body_parts)
            };

            body_futures.push_back(request_future);
        }

        // Await on all of our parts requests. `FuturesOrdered` yields results
        // in part order, so the reassembled segments are in object order.
        let mut segments = vec![];
        while let Some(result) = body_futures.next().await {
            // Download failure, we failed to fetch the body from S3.
            let mut part_body = result
                .inspect_err(|e| {
                    self.metrics
                        .error_counts
                        .with_label_values(&["GetObjectStream", e.to_string().as_str()])
                        .inc()
                })
                .context("s3 get body err")?;

            // Collect all of our segments.
            segments.append(&mut part_body);
        }

        debug!(
            "s3 GetObject took {:?} ({} parts)",
            start_overall.elapsed(),
            num_parts
        );
        Ok(Some(SegmentedBytes::from(segments)))
    }

    /// Calls `f` once per key (and size) under `key_prefix`, paging through
    /// ListObjectsV2 with continuation tokens. Keys are passed to `f` with the
    /// blob's root prefix stripped.
    async fn list_keys_and_metadata(
        &self,
        key_prefix: &str,
        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
    ) -> Result<(), ExternalError> {
        let mut continuation_token = None;
        // we only want to return keys that match the specified blob key prefix
        let blob_key_prefix = self.get_path(key_prefix);
        // but we want to exclude the shared root prefix from our returned keys,
        // so only the blob key itself is passed in to `f`
        let strippable_root_prefix = format!("{}/", self.prefix);

        loop {
            self.metrics.list_objects.inc();
            let resp = self
                .client
                .list_objects_v2()
                .bucket(&self.bucket)
                .prefix(&blob_key_prefix)
                .max_keys(self.max_keys)
                .set_continuation_token(continuation_token)
                .send()
                .await
                .inspect_err(|err| self.update_error_metrics("ListObjectsV2", err))
                .context("list bucket error")?;
            if let Some(contents) = resp.contents {
                for object in contents.iter() {
                    if let Some(key) = object.key.as_ref() {
                        if let Some(key) = key.strip_prefix(&strippable_root_prefix) {
                            let size_in_bytes = match object.size {
                                None => {
                                    return Err(ExternalError::from(anyhow!(
                                        "object missing size: {key}"
                                    )));
                                }
                                Some(size) => size
                                    .try_into()
                                    .expect("file in S3 cannot have negative size"),
                            };
                            f(BlobMetadata { key, size_in_bytes });
                        } else {
                            // Every listed key should start with our root
                            // prefix since we listed under it; anything else
                            // indicates a bug or corruption.
                            return Err(ExternalError::from(anyhow!(
                                "found key with invalid prefix: {}",
                                key
                            )));
                        }
                    }
                }
            }

            // Keep paging until S3 stops handing back continuation tokens.
            if resp.next_continuation_token.is_some() {
                continuation_token = resp.next_continuation_token;
            } else {
                break;
            }
        }

        Ok(())
    }

    /// Writes `value` at `key`, choosing single-request PutObject or a
    /// multipart upload based on the payload size and [MultipartConfig].
    async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
        let value_len = value.len();
        if self
            .multipart_config
            .should_multipart(value_len)
            .map_err(anyhow::Error::msg)?
        {
            self.set_multi_part(key, value)
                .instrument(debug_span!("s3set_multi", payload_len = value_len))
                .await
        } else {
            self.set_single_part(key, value).await
        }
    }

    /// Deletes `key`, returning the size of the deleted object, or `None` if
    /// it did not exist.
    async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
        // There is a race condition here where, if two delete calls for the
        // same key occur simultaneously, both might think they did the actual
        // deletion. This return value is only used for metrics, so it's
        // unfortunate, but fine.
        let path = self.get_path(key);
        self.metrics.delete_head.inc();
        // HEAD the object first so we can report the deleted size.
        let head_res = self
            .client
            .head_object()
            .bucket(&self.bucket)
            .key(&path)
            .send()
            .await;
        let size_bytes = match head_res {
            Ok(x) => match x.content_length {
                None => {
                    return Err(ExternalError::from(anyhow!(
                        "s3 delete content length was none"
                    )));
                }
                Some(content_length) => {
                    u64::try_from(content_length).expect("file in S3 cannot have negative size")
                }
            },
            // A missing object is not an error for delete.
            Err(SdkError::ServiceError(err)) if err.err().is_not_found() => return Ok(None),
            Err(err) => {
                self.update_error_metrics("HeadObject", &err);
                return Err(ExternalError::from(
                    anyhow!(err).context("s3 delete head err"),
                ));
            }
        };
        self.metrics.delete_object.inc();
        let _ = self
            .client
            .delete_object()
            .bucket(&self.bucket)
            .key(&path)
            .send()
            .await
            .inspect_err(|err| self.update_error_metrics("DeleteObject", err))
            .context("s3 delete object err")?;
        Ok(Some(usize::cast_from(size_bytes)))
    }

    /// Undeletes `key` on a versioned bucket by peeling back delete markers
    /// until a real object version is current again. Errors if no valid
    /// version exists.
    async fn restore(&self, key: &str) -> Result<(), ExternalError> {
        let path = self.get_path(key);
        // Fetch the latest version of the object. If it's a normal version, return true;
        // if it's a delete marker, delete it and loop; if there is no such version,
        // return false.
        // TODO: limit the number of delete markers we'll peel back?
        loop {
            // S3 only lets us fetch the versions of an object with a list requests.
            // Seems a bit wasteful to just fetch one at a time, but otherwise we can only
            // guess the order of versions via the timestamp, and that feels brittle.
            let list_res = self
                .client
                .list_object_versions()
                .bucket(&self.bucket)
                .prefix(&path)
                .max_keys(1)
                .send()
                .await
                .inspect_err(|err| self.update_error_metrics("ListObjectVersions", err))
                .context("listing object versions during restore")?;

            // Look for a delete marker that is both for exactly our key and
            // the latest version of it.
            let current_delete = list_res
                .delete_markers()
                .into_iter()
                .filter(|d| {
                    // We need to check that any versions we're looking at have the right key,
                    // not just a key with our key as a prefix.
                    d.key() == Some(path.as_str())
                })
                .find(|d| d.is_latest().unwrap_or(false))
                .and_then(|d| d.version_id());

            if let Some(version) = current_delete {
                // Deleting the delete marker re-exposes the previous version;
                // loop to check what's current now.
                let deleted = self
                    .client
                    .delete_object()
                    .bucket(&self.bucket)
                    .key(&path)
                    .version_id(version)
                    .send()
                    .await
                    .inspect_err(|err| self.update_error_metrics("DeleteObject", err))
                    .context("deleting a delete marker")?;
                assert!(
                    deleted.delete_marker().unwrap_or(false),
                    "deleting a delete marker"
                );
            } else {
                // No delete marker on top: either a real version is current
                // (success) or the object never existed / has no versions.
                let has_current_version = list_res
                    .versions()
                    .into_iter()
                    .filter(|d| d.key() == Some(path.as_str()))
                    .any(|v| v.is_latest().unwrap_or(false));

                if !has_current_version {
                    return Err(Determinate::new(anyhow!(
                        "unable to restore {key} in s3: no valid version exists"
                    ))
                    .into());
                }
                return Ok(());
            }
        }
    }
}
677
678impl S3Blob {
679    async fn set_single_part(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
680        let start_overall = Instant::now();
681        let path = self.get_path(key);
682
683        let value_len = value.len();
684        let part_span = trace_span!("s3set_single", payload_len = value_len);
685        self.metrics.set_single.inc();
686        self.client
687            .put_object()
688            .bucket(&self.bucket)
689            .key(path)
690            .body(ByteStream::from(value))
691            .send()
692            .instrument(part_span)
693            .await
694            .inspect_err(|err| self.update_error_metrics("PutObject", err))
695            .context("set single part")?;
696        debug!(
697            "s3 PutObject single done {}b / {:?}",
698            value_len,
699            start_overall.elapsed()
700        );
701        Ok(())
702    }
703
704    // TODO(benesch): remove this once this function no longer makes use of
705    // potentially dangerous `as` conversions.
706    #[allow(clippy::as_conversions)]
707    async fn set_multi_part(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
708        let start_overall = Instant::now();
709        let path = self.get_path(key);
710
711        // Start the multi part request and get an upload id.
712        trace!("s3 PutObject multi start {}b", value.len());
713        self.metrics.set_multi_create.inc();
714        let upload_res = self
715            .client
716            .create_multipart_upload()
717            .bucket(&self.bucket)
718            .key(&path)
719            .customize()
720            .mutate_request(|req| {
721                // By default the Rust AWS SDK does not set the Content-Length
722                // header on POST calls with empty bodies. This is fine for S3,
723                // but when running against GCS's S3 interop mode these calls
724                // will be rejected unless we set this header manually.
725                req.headers_mut().insert("Content-Length", "0");
726            })
727            .send()
728            .instrument(debug_span!("s3set_multi_start"))
729            .await
730            .inspect_err(|err| self.update_error_metrics("CreateMultipartUpload", err))
731            .context("create_multipart_upload err")?;
732        let upload_id = upload_res
733            .upload_id()
734            .ok_or_else(|| anyhow!("create_multipart_upload response missing upload_id"))?;
735        trace!(
736            "s3 create_multipart_upload took {:?}",
737            start_overall.elapsed()
738        );
739
740        let async_runtime = AsyncHandle::try_current().map_err(anyhow::Error::new)?;
741
742        // Fire off all the individual parts.
743        //
744        // TODO: The aws cli throttles how many of these are outstanding at any
745        // given point. We'll likely want to do the same at some point.
746        let start_parts = Instant::now();
747        let mut part_futs = Vec::new();
748        for (part_num, part_range) in self.multipart_config.part_iter(value.len()) {
749            // NB: Without this spawn, these will execute serially. This is rust
750            // async 101 stuff, but there isn't much async in the persist
751            // codebase (yet?) so I thought it worth calling out.
752            let part_span = debug_span!("s3set_multi_part", payload_len = part_range.len());
753            let part_fut = async_runtime.spawn_named(
754                // TODO: Add the key and part number once this can be annotated
755                // with metadata.
756                || "persist_s3blob_put_part",
757                {
758                    self.metrics.set_multi_part.inc();
759                    self.client
760                        .upload_part()
761                        .bucket(&self.bucket)
762                        .key(&path)
763                        .upload_id(upload_id)
764                        .part_number(part_num as i32)
765                        .body(ByteStream::from(value.slice(part_range)))
766                        .send()
767                        .instrument(part_span)
768                        .map(move |res| (start_parts.elapsed(), res))
769                },
770            );
771            part_futs.push((part_num, part_fut));
772        }
773        let parts_len = part_futs.len();
774
775        // Wait on all the parts to finish. This is done in part order, no need
776        // for joining them in the order they finish.
777        //
778        // TODO: Consider using something like futures::future::join_all() for
779        // this. That would cancel outstanding requests for us if any of them
780        // fails. However, it might not play well with using retries for tail
781        // latencies. Investigate.
782        let min_part_elapsed = MinElapsed::default();
783        let mut parts = Vec::with_capacity(parts_len);
784        for (part_num, part_fut) in part_futs.into_iter() {
785            let (this_part_elapsed, part_res) = part_fut.await;
786            let part_res = part_res
787                .inspect_err(|err| self.update_error_metrics("UploadPart", err))
788                .context("s3 upload_part err")?;
789            let part_e_tag = part_res.e_tag().ok_or_else(|| {
790                self.metrics
791                    .error_counts
792                    .with_label_values(&["UploadPart", "MissingEtag"])
793                    .inc();
794                anyhow!("s3 upload part missing e_tag")
795            })?;
796            parts.push(
797                CompletedPart::builder()
798                    .e_tag(part_e_tag)
799                    .part_number(part_num as i32)
800                    .build(),
801            );
802            min_part_elapsed.observe(this_part_elapsed, "s3 upload_part took");
803        }
804        trace!(
805            "s3 upload_parts overall took {:?} ({} parts)",
806            start_parts.elapsed(),
807            parts_len
808        );
809
810        // Complete the upload.
811        //
812        // Currently, we early return if any of the individual parts fail. This
813        // permanently orphans any parts that succeeded. One fix is to call
814        // abort_multipart_upload, which deletes them. However, there's also an
815        // option for an s3 bucket to auto-delete parts that haven't been
816        // completed or aborted after a given amount of time. This latter is
817        // simpler and also resilient to ill-timed mz restarts, so we use it for
818        // now. We could likely add the accounting necessary to make
819        // abort_multipart_upload work, but it would be complex and affect perf.
820        // Let's see how far we can get without it.
821        let start_complete = Instant::now();
822        self.metrics.set_multi_complete.inc();
823        self.client
824            .complete_multipart_upload()
825            .bucket(&self.bucket)
826            .key(&path)
827            .upload_id(upload_id)
828            .multipart_upload(
829                CompletedMultipartUpload::builder()
830                    .set_parts(Some(parts))
831                    .build(),
832            )
833            .send()
834            .instrument(debug_span!("s3set_multi_complete", num_parts = parts_len))
835            .await
836            .inspect_err(|err| self.update_error_metrics("CompleteMultipartUpload", err))
837            .context("complete_multipart_upload err")?;
838        trace!(
839            "s3 complete_multipart_upload took {:?}",
840            start_complete.elapsed()
841        );
842
843        debug!(
844            "s3 PutObject multi done {}b / {:?} ({} parts)",
845            value.len(),
846            start_overall.elapsed(),
847            parts_len
848        );
849        Ok(())
850    }
851
852    fn update_error_metrics<E, R>(&self, op: &str, err: &SdkError<E, R>)
853    where
854        E: ProvideErrorMetadata,
855    {
856        let code = match err {
857            SdkError::ServiceError(e) => match e.err().code() {
858                Some(code) => code,
859                None => "UnknownServiceError",
860            },
861            SdkError::DispatchFailure(e) => {
862                if let Some(other_error) = e.as_other() {
863                    match other_error {
864                        aws_config::retry::ErrorKind::TransientError => "TransientError",
865                        aws_config::retry::ErrorKind::ThrottlingError => "ThrottlingError",
866                        aws_config::retry::ErrorKind::ServerError => "ServerError",
867                        aws_config::retry::ErrorKind::ClientError => "ClientError",
868                        _ => "UnknownDispatchFailure",
869                    }
870                } else if e.is_timeout() {
871                    "TimeoutError"
872                } else if e.is_io() {
873                    "IOError"
874                } else if e.is_user() {
875                    "UserError"
876                } else {
877                    "UnknownDispathFailure"
878                }
879            }
880            SdkError::ResponseError(_) => "ResponseError",
881            SdkError::ConstructionFailure(_) => "ConstructionFailure",
882            // There is some overlap with MetricsSleep. MetricsSleep is more granular
883            // but does not contain the operation.
884            SdkError::TimeoutError(_) => "TimeoutError",
885            // an error was added at some point in the future
886            _ => "UnknownSdkError",
887        };
888        self.metrics
889            .error_counts
890            .with_label_values(&[op, code])
891            .inc();
892    }
893}
894
/// Tuning knobs controlling when and how a blob is written with S3's
/// multipart-upload API.
#[derive(Clone, Debug)]
struct MultipartConfig {
    // Blobs strictly larger than this many bytes use multipart upload.
    multipart_threshold: usize,
    // The size in bytes of each uploaded part (except possibly the last).
    multipart_chunk_size: usize,
}
900
901impl Default for MultipartConfig {
902    fn default() -> Self {
903        Self {
904            multipart_threshold: Self::DEFAULT_MULTIPART_THRESHOLD,
905            multipart_chunk_size: Self::DEFAULT_MULTIPART_CHUNK_SIZE,
906        }
907    }
908}
909
/// One mebibyte, in bytes.
const MB: usize = 1024 * 1024;
/// One tebibyte, in bytes.
const TB: usize = 1024 * 1024 * MB;
912
913impl MultipartConfig {
914    /// The minimum object size for which we start using multipart upload.
915    ///
916    /// From the official `aws cli` tool implementation:
917    ///
918    /// <https://github.com/aws/aws-cli/blob/2.4.14/awscli/customizations/s3/transferconfig.py#L18-L29>
919    const DEFAULT_MULTIPART_THRESHOLD: usize = 8 * MB;
920    /// The size of each part (except the last) in a multipart upload.
921    ///
922    /// From the official `aws cli` tool implementation:
923    ///
924    /// <https://github.com/aws/aws-cli/blob/2.4.14/awscli/customizations/s3/transferconfig.py#L18-L29>
925    const DEFAULT_MULTIPART_CHUNK_SIZE: usize = 8 * MB;
926
927    /// The largest size object creatable in S3.
928    ///
929    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
930    const MAX_SINGLE_UPLOAD_SIZE: usize = 5 * TB;
931    /// The minimum size of a part in a multipart upload.
932    ///
933    /// This minimum doesn't apply to the last chunk, which can be any size.
934    ///
935    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
936    const MIN_UPLOAD_CHUNK_SIZE: usize = 5 * MB;
937    /// The smallest allowable part number (inclusive).
938    ///
939    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
940    const MIN_PART_NUM: u32 = 1;
941    /// The largest allowable part number (inclusive).
942    ///
943    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
944    const MAX_PART_NUM: u32 = 10_000;
945
946    fn should_multipart(&self, blob_len: usize) -> Result<bool, String> {
947        if blob_len > Self::MAX_SINGLE_UPLOAD_SIZE {
948            return Err(format!(
949                "S3 does not support blobs larger than {} bytes got: {}",
950                Self::MAX_SINGLE_UPLOAD_SIZE,
951                blob_len
952            ));
953        }
954        Ok(blob_len > self.multipart_threshold)
955    }
956
957    fn part_iter(&self, blob_len: usize) -> MultipartChunkIter {
958        debug_assert!(self.multipart_chunk_size >= MultipartConfig::MIN_UPLOAD_CHUNK_SIZE);
959        MultipartChunkIter::new(self.multipart_chunk_size, blob_len)
960    }
961}
962
/// An iterator yielding the S3 part number and byte range of each part of a
/// multipart upload.
#[derive(Clone, Debug)]
struct MultipartChunkIter {
    // Total length of the blob being chunked, in bytes.
    total_len: usize,
    // Size of each part (except possibly the last), in bytes.
    part_size: usize,
    // Zero-indexed position of the next part to yield.
    part_idx: u32,
}
969
970impl MultipartChunkIter {
971    fn new(default_part_size: usize, blob_len: usize) -> Self {
972        let max_parts: usize = usize::cast_from(MultipartConfig::MAX_PART_NUM);
973
974        // Compute the minimum part size we can use without going over the max
975        // number of parts that S3 allows: `ceil(blob_len / max_parts)`.This
976        // will end up getting thrown away by the `cmp::max` for anything
977        // smaller than `max_parts * default_part_size = 80GiB`.
978        let min_part_size = (blob_len + max_parts - 1) / max_parts;
979        let part_size = cmp::max(min_part_size, default_part_size);
980
981        // Part nums are 1-indexed in S3. Convert back to 0-indexed to make the
982        // range math easier to follow.
983        let part_idx = MultipartConfig::MIN_PART_NUM - 1;
984        MultipartChunkIter {
985            total_len: blob_len,
986            part_size,
987            part_idx,
988        }
989    }
990}
991
992impl Iterator for MultipartChunkIter {
993    type Item = (u32, Range<usize>);
994
995    fn next(&mut self) -> Option<Self::Item> {
996        let part_idx = self.part_idx;
997        self.part_idx += 1;
998
999        let start = usize::cast_from(part_idx) * self.part_size;
1000        if start >= self.total_len {
1001            return None;
1002        }
1003        let end = cmp::min(start + self.part_size, self.total_len);
1004        let part_num = part_idx + 1;
1005        Some((part_num, start..end))
1006    }
1007}
1008
/// A helper for tracking the minimum of a set of Durations.
#[derive(Debug)]
struct MinElapsed {
    // The smallest duration observed so far, in nanoseconds; `u64::MAX` until
    // the first observation.
    min: AtomicU64,
    // Observations more than `alert_factor` times the minimum are logged at
    // debug level instead of trace.
    alert_factor: u64,
}
1015
1016impl Default for MinElapsed {
1017    fn default() -> Self {
1018        MinElapsed {
1019            min: AtomicU64::new(u64::MAX),
1020            alert_factor: 8,
1021        }
1022    }
1023}
1024
1025impl MinElapsed {
1026    fn observe(&self, x: Duration, msg: &'static str) {
1027        let nanos = x.as_nanos();
1028        let nanos = u64::try_from(nanos).unwrap_or(u64::MAX);
1029
1030        // Possibly set a new minimum.
1031        let prev_min = self.min.fetch_min(nanos, atomic::Ordering::SeqCst);
1032
1033        // Trace if our provided duration was much larger than our minimum.
1034        let new_min = std::cmp::min(prev_min, nanos);
1035        if nanos > new_min.saturating_mul(self.alert_factor) {
1036            let min_duration = Duration::from_nanos(new_min);
1037            let factor = self.alert_factor;
1038            debug!("{msg} took {x:?} more than {factor}x the min {min_duration:?}");
1039        } else {
1040            trace!("{msg} took {x:?}");
1041        }
1042    }
1043}
1044
// Make sure the "vendored" feature of the openssl_sys crate makes it into the
// transitive dep graph of persist, so that we don't attempt to link against the
// system OpenSSL library. Fake a usage of the crate here so that a good
// samaritan doesn't remove our unused dep.
#[allow(dead_code)]
fn openssl_sys_hack() {
    // Never actually called; merely referencing `openssl_sys` is the point.
    openssl_sys::init();
}
1053
#[cfg(test)]
mod tests {
    use tracing::info;

    use crate::location::tests::blob_impl_test;

    use super::*;

    // Exercises [S3Blob] end-to-end against a real S3 (or S3-compatible)
    // service; skipped unless the external-tests bucket env var is set.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)] // https://github.com/MaterializeInc/database-issues/issues/5586
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `TLS_method` on OS `linux`
    #[ignore] // TODO: Reenable against minio so it can run locally
    async fn s3_blob() -> Result<(), ExternalError> {
        let config = match S3BlobConfig::new_for_test().await? {
            Some(client) => client,
            None => {
                info!(
                    "{} env not set: skipping test that uses external service",
                    S3BlobConfig::EXTERNAL_TESTS_S3_BUCKET
                );
                return Ok(());
            }
        };
        // Give the multipart check below its own prefix so the two runs can't
        // interfere with each other.
        let config_multipart = config.clone_with_new_uuid_prefix();

        blob_impl_test(move |path| {
            let path = path.to_owned();
            let config = config.clone();
            async move {
                // Each test case gets a path-specific prefix for isolation.
                let config = S3BlobConfig {
                    metrics: config.metrics.clone(),
                    client: config.client.clone(),
                    bucket: config.bucket.clone(),
                    prefix: format!("{}/s3_blob_impl_test/{}", config.prefix, path),
                };
                let mut blob = S3Blob::open(config).await?;
                // NOTE(review): a tiny max_keys presumably forces paginated
                // list responses — confirm against S3Blob's list impl.
                blob.max_keys = 2;
                Ok(blob)
            }
        })
        .await?;

        // Also specifically test multipart. S3 requires all parts but the last
        // to be at least 5MB, which we don't want to do from a test, so this
        // uses the multipart code path but only writes a single part.
        {
            let blob = S3Blob::open(config_multipart).await?;
            blob.set_multi_part("multipart", "foobar".into()).await?;
            assert_eq!(
                blob.get("multipart").await?,
                Some(b"foobar".to_vec().into())
            );
        }

        Ok(())
    }

    // Checks the multipart-vs-single-upload decision at the threshold and at
    // the hard S3 object-size limit.
    #[mz_ore::test]
    fn should_multipart() {
        let config = MultipartConfig::default();
        assert_eq!(config.should_multipart(0), Ok(false));
        assert_eq!(config.should_multipart(1), Ok(false));
        assert_eq!(
            config.should_multipart(MultipartConfig::DEFAULT_MULTIPART_THRESHOLD),
            Ok(false)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::DEFAULT_MULTIPART_THRESHOLD + 1),
            Ok(true)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::DEFAULT_MULTIPART_THRESHOLD * 2),
            Ok(true)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::MAX_SINGLE_UPLOAD_SIZE),
            Ok(true)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::MAX_SINGLE_UPLOAD_SIZE + 1),
            Err(
                "S3 does not support blobs larger than 5497558138880 bytes got: 5497558138881"
                    .into()
            )
        );
    }

    // Checks the part numbers and byte ranges produced by [MultipartChunkIter]
    // for empty, exact-multiple, and remainder blob lengths.
    #[mz_ore::test]
    fn multipart_iter() {
        let iter = MultipartChunkIter::new(10, 0);
        assert_eq!(iter.collect::<Vec<_>>(), vec![]);

        let iter = MultipartChunkIter::new(10, 9);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..9)]);

        let iter = MultipartChunkIter::new(10, 10);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10)]);

        let iter = MultipartChunkIter::new(10, 11);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10), (2, 10..11)]);

        let iter = MultipartChunkIter::new(10, 19);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10), (2, 10..19)]);

        let iter = MultipartChunkIter::new(10, 20);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10), (2, 10..20)]);

        let iter = MultipartChunkIter::new(10, 21);
        assert_eq!(
            iter.collect::<Vec<_>>(),
            vec![(1, 0..10), (2, 10..20), (3, 20..21)]
        );
    }
}