// mz_persist/src/s3.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An S3 implementation of [Blob] storage.
11
12use std::cmp;
13use std::fmt::{Debug, Formatter};
14use std::ops::Range;
15use std::sync::Arc;
16use std::sync::atomic::{self, AtomicU64};
17use std::time::{Duration, Instant};
18
19use anyhow::{Context, anyhow};
20use async_trait::async_trait;
21use aws_config::sts::AssumeRoleProvider;
22use aws_config::timeout::TimeoutConfig;
23use aws_credential_types::Credentials;
24use aws_sdk_s3::Client as S3Client;
25use aws_sdk_s3::config::{AsyncSleep, Sleep};
26use aws_sdk_s3::error::{ProvideErrorMetadata, SdkError};
27use aws_sdk_s3::primitives::ByteStream;
28use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
29use aws_types::region::Region;
30use bytes::{Bytes, BytesMut};
31use futures_util::stream::FuturesOrdered;
32use futures_util::{FutureExt, StreamExt};
33use mz_ore::bytes::SegmentedBytes;
34use mz_ore::cast::CastFrom;
35use mz_ore::metrics::MetricsRegistry;
36use mz_ore::task::RuntimeExt;
37use tokio::runtime::Handle as AsyncHandle;
38use tracing::{Instrument, debug, debug_span, trace, trace_span};
39use uuid::Uuid;
40
41use crate::cfg::BlobKnobs;
42use crate::error::Error;
43use crate::location::{Blob, BlobMetadata, Determinate, ExternalError};
44use crate::metrics::S3BlobMetrics;
45
46/// Configuration for opening an [S3Blob].
47#[derive(Clone, Debug)]
48pub struct S3BlobConfig {
49    metrics: S3BlobMetrics,
50    client: S3Client,
51    bucket: String,
52    prefix: String,
53}
54
// There is no simple way to hook into the S3 client to capture when its various timeouts
// are hit. Instead, we pass along marker values that inform our [MetricsSleep] impl which
// type of timeout was requested so it can substitute in a dynamic value set by config
// from the caller.
//
// The (seconds, nanoseconds) pairs below are arbitrary but mutually distinct, so each
// marker can be matched unambiguously in [MetricsSleep::sleep]. These durations are
// never actually slept for; they are sentinels replaced by the real knob values.
const OPERATION_TIMEOUT_MARKER: Duration = Duration::new(111, 1111);
const OPERATION_ATTEMPT_TIMEOUT_MARKER: Duration = Duration::new(222, 2222);
const CONNECT_TIMEOUT_MARKER: Duration = Duration::new(333, 3333);
const READ_TIMEOUT_MARKER: Duration = Duration::new(444, 4444);
63
/// An `AsyncSleep` impl that swaps the timeout marker durations above for the
/// real, dynamically-configured values from [BlobKnobs], and counts which
/// timeouts actually elapse.
#[derive(Debug)]
struct MetricsSleep {
    // Source of the real timeout durations substituted for the markers.
    knobs: Box<dyn BlobKnobs>,
    // Holds the counters incremented when a timeout sleep runs to completion.
    metrics: S3BlobMetrics,
}
69
impl AsyncSleep for MetricsSleep {
    /// Returns a sleep future for `duration`, first translating any of the
    /// timeout marker constants into the corresponding dynamic knob value and
    /// remembering which timeout counter to bump if the sleep completes.
    fn sleep(&self, duration: Duration) -> Sleep {
        // Map each marker to (real duration, counter). A non-marker duration
        // (e.g. an SDK retry backoff) is passed through unchanged with no
        // counter attached.
        let (duration, metric) = match duration {
            OPERATION_TIMEOUT_MARKER => (
                self.knobs.operation_timeout(),
                Some(self.metrics.operation_timeouts.clone()),
            ),
            OPERATION_ATTEMPT_TIMEOUT_MARKER => (
                self.knobs.operation_attempt_timeout(),
                Some(self.metrics.operation_attempt_timeouts.clone()),
            ),
            CONNECT_TIMEOUT_MARKER => (
                self.knobs.connect_timeout(),
                Some(self.metrics.connect_timeouts.clone()),
            ),
            READ_TIMEOUT_MARKER => (
                self.knobs.read_timeout(),
                Some(self.metrics.read_timeouts.clone()),
            ),
            duration => (duration, None),
        };

        // the sleep future we return here will only be polled to
        // completion if its corresponding http request to S3 times
        // out, meaning we can chain incrementing the appropriate
        // timeout counter to when it finishes
        Sleep::new(tokio::time::sleep(duration).map(|x| {
            if let Some(counter) = metric {
                counter.inc();
            }
            x
        }))
    }
}
104
impl S3BlobConfig {
    // Name of the env var that both enables the external-storage unit tests
    // and names the bucket they run against.
    const EXTERNAL_TESTS_S3_BUCKET: &'static str = "MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET";

    /// Returns a new [S3BlobConfig] for use in production.
    ///
    /// Stores objects in the given bucket prepended with the (possibly empty)
    /// prefix. S3 credentials and region must be available in the process or
    /// environment.
    pub async fn new(
        bucket: String,
        prefix: String,
        role_arn: Option<String>,
        endpoint: Option<String>,
        region: Option<String>,
        credentials: Option<(String, String)>,
        knobs: Box<dyn BlobKnobs>,
        metrics: S3BlobMetrics,
    ) -> Result<Self, Error> {
        let mut loader = mz_aws_util::defaults();

        if let Some(region) = region {
            loader = loader.region(Region::new(region));
        };

        if let Some(role_arn) = role_arn {
            // The assume-role provider needs its own base SDK config (with
            // ambient credentials) to perform the STS call.
            let assume_role_sdk_config = mz_aws_util::defaults().load().await;
            let role_provider = AssumeRoleProvider::builder(role_arn)
                .configure(&assume_role_sdk_config)
                .session_name("persist")
                .build()
                .await;
            loader = loader.credentials_provider(role_provider);
        }

        // NOTE(review): if both `role_arn` and static `credentials` are
        // supplied, this second `credentials_provider` call replaces the
        // assume-role provider set above — confirm that precedence is
        // intentional.
        if let Some((access_key_id, secret_access_key)) = credentials {
            loader = loader.credentials_provider(Credentials::from_keys(
                access_key_id,
                secret_access_key,
                None,
            ));
        }

        if let Some(endpoint) = endpoint {
            loader = loader.endpoint_url(endpoint)
        }

        // NB: we must always use the custom sleep impl if we use the timeout marker values
        loader = loader.sleep_impl(MetricsSleep {
            knobs,
            metrics: metrics.clone(),
        });
        loader = loader.timeout_config(
            TimeoutConfig::builder()
                // maximum time allowed for a top-level S3 API call (including internal retries)
                .operation_timeout(OPERATION_TIMEOUT_MARKER)
                // maximum time allowed for a single network call
                .operation_attempt_timeout(OPERATION_ATTEMPT_TIMEOUT_MARKER)
                // maximum time until a connection succeeds
                .connect_timeout(CONNECT_TIMEOUT_MARKER)
                // maximum time to read the first byte of a response
                .read_timeout(READ_TIMEOUT_MARKER)
                .build(),
        );

        let client = mz_aws_util::s3::new_client(&loader.load().await);
        Ok(S3BlobConfig {
            metrics,
            client,
            bucket,
            prefix,
        })
    }

    /// Returns a new [S3BlobConfig] for use in unit tests.
    ///
    /// By default, persist tests that use external storage (like s3) are
    /// no-ops, so that `cargo test` does the right thing without any
    /// configuration. To activate the tests, set the
    /// `MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET` environment variable and
    /// ensure you have valid AWS credentials available in a location where the
    /// AWS Rust SDK can discover them.
    ///
    /// This intentionally uses the `MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET`
    /// env as the switch for test no-op-ness instead of the presence of a valid
    /// AWS authentication configuration envs because a developer might have
    /// valid credentials present and this isn't an explicit enough signal from
    /// a developer running `cargo test` that it's okay to use these
    /// credentials. It also intentionally does not use the local drop-in s3
    /// replacement to keep persist unit tests light.
    ///
    /// On CI, these tests are enabled by adding the scratch-aws-access plugin
    /// to the `cargo-test` step in `ci/test/pipeline.template.yml` and setting
    /// `MZ_PERSIST_EXTERNAL_STORAGE_TEST_S3_BUCKET` in
    /// `ci/test/cargo-test/mzcompose.py`.
    ///
    /// For a Materialize developer, to opt in to these tests locally for
    /// development, follow the AWS access guide:
    ///
    /// ```text
    /// https://github.com/MaterializeInc/i2/blob/main/doc/aws-access.md
    /// ```
    ///
    /// then running `source src/persist/s3_test_env_mz.sh`. You will also have
    /// to run `aws sso login` if you haven't recently.
    ///
    /// Non-Materialize developers will have to set up their own auto-deleting
    /// bucket and export the same env vars that s3_test_env_mz.sh does.
    ///
    /// Only public for use in src/benches.
    pub async fn new_for_test() -> Result<Option<Self>, Error> {
        let bucket = match std::env::var(Self::EXTERNAL_TESTS_S3_BUCKET) {
            Ok(bucket) => bucket,
            Err(_) => {
                if mz_ore::env::is_var_truthy("CI") {
                    panic!("CI is supposed to run this test but something has gone wrong!");
                }
                return Ok(None);
            }
        };

        // Test knobs that hand the marker constants straight back, so the
        // MetricsSleep substitution path is exercised with its defaults.
        struct TestBlobKnobs;
        impl std::fmt::Debug for TestBlobKnobs {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("TestBlobKnobs").finish_non_exhaustive()
            }
        }
        impl BlobKnobs for TestBlobKnobs {
            fn operation_timeout(&self) -> Duration {
                OPERATION_TIMEOUT_MARKER
            }

            fn operation_attempt_timeout(&self) -> Duration {
                OPERATION_ATTEMPT_TIMEOUT_MARKER
            }

            fn connect_timeout(&self) -> Duration {
                CONNECT_TIMEOUT_MARKER
            }

            fn read_timeout(&self) -> Duration {
                READ_TIMEOUT_MARKER
            }

            fn is_cc_active(&self) -> bool {
                false
            }
        }

        // Give each test a unique prefix so they don't conflict. We don't have
        // to worry about deleting any data that we create because the bucket is
        // set to auto-delete after 1 day.
        let prefix = Uuid::new_v4().to_string();
        let role_arn = None;
        let metrics = S3BlobMetrics::new(&MetricsRegistry::new());
        let config = S3BlobConfig::new(
            bucket,
            prefix,
            role_arn,
            None,
            None,
            None,
            Box::new(TestBlobKnobs),
            metrics,
        )
        .await?;
        Ok(Some(config))
    }

    /// Returns a clone of Self with a new v4 uuid prefix.
    pub fn clone_with_new_uuid_prefix(&self) -> Self {
        let mut ret = self.clone();
        ret.prefix = Uuid::new_v4().to_string();
        ret
    }
}
280
/// Implementation of [Blob] backed by S3.
#[derive(Debug)]
pub struct S3Blob {
    // Counters and histograms for S3 operations, shared with the sleep impl.
    metrics: S3BlobMetrics,
    // Fully-configured S3 client taken from [S3BlobConfig].
    client: S3Client,
    // Bucket that all blobs are stored in.
    bucket: String,
    // Key prefix prepended to every blob key (see `get_path`).
    prefix: String,
    // Maximum number of keys we get information about per list-objects request.
    //
    // Defaults to 1000 which is the current AWS max.
    max_keys: i32,
    // Thresholds and sizes governing when/how `set` switches to multipart upload.
    multipart_config: MultipartConfig,
}
294
295impl S3Blob {
296    /// Opens the given location for non-exclusive read-write access.
297    pub async fn open(config: S3BlobConfig) -> Result<Self, ExternalError> {
298        let ret = S3Blob {
299            metrics: config.metrics,
300            client: config.client,
301            bucket: config.bucket,
302            prefix: config.prefix,
303            max_keys: 1_000,
304            multipart_config: MultipartConfig::default(),
305        };
306        // Connect before returning success. We don't particularly care about
307        // what's stored in this blob (nothing writes to it, so presumably it's
308        // empty) just that we were able and allowed to fetch it.
309        let _ = ret.get("HEALTH_CHECK").await?;
310        Ok(ret)
311    }
312
313    fn get_path(&self, key: &str) -> String {
314        format!("{}/{}", self.prefix, key)
315    }
316}
317
#[async_trait]
impl Blob for S3Blob {
    /// Fetches the blob at `key`, downloading all of its multipart parts
    /// concurrently, or returns `None` if no such key exists.
    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
        let start_overall = Instant::now();
        let path = self.get_path(key);

        // S3 advises that it's fastest to download large objects along the part
        // boundaries they were originally uploaded with [1].
        //
        // [1]: https://docs.aws.amazon.com/whitepapers/latest/s3-optimizing-performance-best-practices/use-byte-range-fetches.html
        //
        // One option is to run the same logic as multipart does and do the
        // requests using the resulting byte ranges, but if we ever changed the
        // multipart chunking logic, they wouldn't line up for old blobs written
        // by a previous version.
        //
        // Another option is to store the part boundaries in the metadata we
        // keep about the batch, but this would be large and wasteful.
        //
        // Luckily, s3 exposes a part_number param on GetObject requests that we
        // can use. If an object was created with multipart, it allows
        // requesting each part as they were originally uploaded by the part
        // number index. With this, we can simply send off requests for part
        // number 1..=num_parts and reassemble the results.
        //
        // We could roundtrip the number of parts through persist batch
        // metadata, but with some cleverness, we can avoid even this. Turns
        // out, if multipart upload wasn't used (it was just a normal PutObject
        // request), s3 will still happily return it for a request specifying a
        // part_number of 1. This lets us fire off a first request, which
        // contains the metadata we need to determine how many additional parts
        // we need, if any.
        //
        // So, the following call sends this first request. The SDK even returns
        // the headers before the full data body has completed. This gives us
        // the number of parts. We can then proceed to fetch the body of the
        // first request concurrently with the rest of the parts of the object.

        // For each header and body that we fetch, we track the fastest, and
        // any large deviations from it.
        let min_body_elapsed = Arc::new(MinElapsed::default());
        let min_header_elapsed = Arc::new(MinElapsed::default());
        self.metrics.get_part.inc();

        // Fetch our first header, this tells us how many more are left.
        let header_start = Instant::now();
        let object = self
            .client
            .get_object()
            .bucket(&self.bucket)
            .key(&path)
            .part_number(1)
            .send()
            .await;
        let elapsed = header_start.elapsed();
        min_header_elapsed.observe(elapsed, "s3 download first part header");

        let first_part = match object {
            Ok(object) => object,
            // A missing key is the `None` case of the Blob contract, not an error.
            Err(SdkError::ServiceError(err)) if err.err().is_no_such_key() => return Ok(None),
            Err(err) => {
                self.update_error_metrics("GetObject", &err);
                Err(anyhow!(err).context("s3 get meta err"))?
            }
        };

        // Get the remaining number of parts
        let num_parts = match first_part.parts_count() {
            // For a non-multipart upload, parts_count will be None. The rest of  the code works
            // perfectly well if we just pretend this was a multipart upload of 1 part.
            None => 1,
            // For any positive value greater than 0, just return it.
            Some(parts @ 1..) => parts,
            // A non-positive value is invalid.
            Some(bad) => {
                assert!(bad <= 0);
                return Err(anyhow!("unexpected number of s3 object parts: {}", bad).into());
            }
        };

        trace!(
            "s3 download first header took {:?} ({num_parts} parts)",
            start_overall.elapsed(),
        );

        let mut body_futures = FuturesOrdered::new();
        let mut first_part = Some(first_part);

        // Kick off a download task for every part. Part 1 reuses the response
        // we already have in hand; parts 2..=num_parts issue fresh requests.
        for part_num in 1..=num_parts {
            // Clone a handle to our MinElapsed trackers so we can give one to
            // each download task.
            let min_header_elapsed = Arc::clone(&min_header_elapsed);
            let min_body_elapsed = Arc::clone(&min_body_elapsed);
            let get_invalid_resp = self.metrics.get_invalid_resp.clone();
            let first_part = first_part.take();
            let path = &path;
            let request_future = async move {
                // Reuse the already-fetched response for part 1; fetch the
                // headers for every other part.
                let mut object = match first_part {
                    Some(first_part) => {
                        assert_eq!(part_num, 1, "only the first part should be prefetched");
                        first_part
                    }
                    None => {
                        assert_ne!(part_num, 1, "first part should be prefetched");
                        // Request our headers.
                        let header_start = Instant::now();
                        let object = self
                            .client
                            .get_object()
                            .bucket(&self.bucket)
                            .key(path)
                            .part_number(part_num)
                            .send()
                            .await
                            .inspect_err(|err| self.update_error_metrics("GetObject", err))
                            .context("s3 get meta err")?;
                        min_header_elapsed
                            .observe(header_start.elapsed(), "s3 download part header");
                        object
                    }
                };

                // Request the body.
                let body_start = Instant::now();

                // Coalesce all hyper chunks for this part into a single contiguous
                // allocation. Pushing each SDK `Bytes` chunk separately into
                // `SegmentedBytes` yields hundreds of segments per blob, which makes
                // every parquet `ChunkReader::get_bytes` call O(N) and dominates CPU
                // in `SegmentedBytes::advance`/`get_bytes` during decode. Copying
                // also releases the hyper pool buffer so it doesn't stay pinned for
                // the lifetime of the blob.
                let mut buf = match object.content_length() {
                    // Positive length: pre-size the buffer to avoid regrowth.
                    Some(len @ 1..) => BytesMut::with_capacity(usize::cast_from(
                        u64::try_from(len).expect("positive integer"),
                    )),
                    // Negative length: invalid response; count it and fall back
                    // to an empty (growable) buffer.
                    Some(len @ ..=-1) => {
                        tracing::trace!(?len, "found invalid content-length");
                        get_invalid_resp.inc();
                        BytesMut::new()
                    }
                    Some(0) | None => BytesMut::new(),
                };

                while let Some(data) = object.body.next().await {
                    let data = data.context("s3 get body err")?;
                    buf.extend_from_slice(&data);
                }

                let body_elapsed = body_start.elapsed();
                min_body_elapsed.observe(body_elapsed, "s3 download part body");

                let body_parts = if buf.is_empty() {
                    Vec::new()
                } else {
                    vec![buf.freeze()]
                };
                Ok::<_, anyhow::Error>(body_parts)
            };

            body_futures.push_back(request_future);
        }

        // Await on all of our parts requests.
        let mut segments = vec![];
        while let Some(result) = body_futures.next().await {
            // Download failure, we failed to fetch the body from S3.
            let mut part_body = result
                .inspect_err(|e| {
                    self.metrics
                        .error_counts
                        .with_label_values(&["GetObjectStream", e.to_string().as_str()])
                        .inc()
                })
                .context("s3 get body err")?;

            // Collect all of our segments.
            segments.append(&mut part_body);
        }

        debug!(
            "s3 GetObject took {:?} ({} parts)",
            start_overall.elapsed(),
            num_parts
        );
        Ok(Some(SegmentedBytes::from(segments)))
    }

    /// Calls `f` with the metadata of every blob whose key starts with
    /// `key_prefix`, paginating through ListObjectsV2 as needed.
    async fn list_keys_and_metadata(
        &self,
        key_prefix: &str,
        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
    ) -> Result<(), ExternalError> {
        let mut continuation_token = None;
        // we only want to return keys that match the specified blob key prefix
        let blob_key_prefix = self.get_path(key_prefix);
        // but we want to exclude the shared root prefix from our returned keys,
        // so only the blob key itself is passed in to `f`
        let strippable_root_prefix = format!("{}/", self.prefix);

        loop {
            self.metrics.list_objects.inc();
            let resp = self
                .client
                .list_objects_v2()
                .bucket(&self.bucket)
                .prefix(&blob_key_prefix)
                .max_keys(self.max_keys)
                .set_continuation_token(continuation_token)
                .send()
                .await
                .inspect_err(|err| self.update_error_metrics("ListObjectsV2", err))
                .context("list bucket error")?;
            if let Some(contents) = resp.contents {
                for object in contents.iter() {
                    if let Some(key) = object.key.as_ref() {
                        if let Some(key) = key.strip_prefix(&strippable_root_prefix) {
                            let size_in_bytes = match object.size {
                                None => {
                                    return Err(ExternalError::from(anyhow!(
                                        "object missing size: {key}"
                                    )));
                                }
                                Some(size) => size
                                    .try_into()
                                    .expect("file in S3 cannot have negative size"),
                            };
                            f(BlobMetadata { key, size_in_bytes });
                        } else {
                            // Shouldn't happen: ListObjectsV2 was scoped to our
                            // prefix, so every returned key should carry it.
                            return Err(ExternalError::from(anyhow!(
                                "found key with invalid prefix: {}",
                                key
                            )));
                        }
                    }
                }
            }

            // Keep paginating until S3 stops handing back a continuation token.
            if resp.next_continuation_token.is_some() {
                continuation_token = resp.next_continuation_token;
            } else {
                break;
            }
        }

        Ok(())
    }

    /// Writes `value` under `key`, choosing single-shot vs. multipart upload
    /// based on the payload size and the configured multipart thresholds.
    async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
        let value_len = value.len();
        if self
            .multipart_config
            .should_multipart(value_len)
            .map_err(anyhow::Error::msg)?
        {
            self.set_multi_part(key, value)
                .instrument(debug_span!("s3set_multi", payload_len = value_len))
                .await
        } else {
            self.set_single_part(key, value).await
        }
    }

    /// Deletes `key`, returning the size of the deleted blob (for metrics), or
    /// `None` if the key didn't exist.
    async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
        // There is a race condition here where, if two delete calls for the
        // same key occur simultaneously, both might think they did the actual
        // deletion. This return value is only used for metrics, so it's
        // unfortunate, but fine.
        let path = self.get_path(key);
        self.metrics.delete_head.inc();
        // HeadObject first: we need the object's size for the return value.
        let head_res = self
            .client
            .head_object()
            .bucket(&self.bucket)
            .key(&path)
            .send()
            .await;
        let size_bytes = match head_res {
            Ok(x) => match x.content_length {
                None => {
                    return Err(ExternalError::from(anyhow!(
                        "s3 delete content length was none"
                    )));
                }
                Some(content_length) => {
                    u64::try_from(content_length).expect("file in S3 cannot have negative size")
                }
            },
            // Nothing to delete.
            Err(SdkError::ServiceError(err)) if err.err().is_not_found() => return Ok(None),
            Err(err) => {
                self.update_error_metrics("HeadObject", &err);
                return Err(ExternalError::from(
                    anyhow!(err).context("s3 delete head err"),
                ));
            }
        };
        self.metrics.delete_object.inc();
        let _ = self
            .client
            .delete_object()
            .bucket(&self.bucket)
            .key(&path)
            .send()
            .await
            .inspect_err(|err| self.update_error_metrics("DeleteObject", err))
            .context("s3 delete object err")?;
        Ok(Some(usize::cast_from(size_bytes)))
    }

    /// Un-deletes `key` by peeling back delete markers (in a versioned bucket)
    /// until a live version is current again.
    async fn restore(&self, key: &str) -> Result<(), ExternalError> {
        let path = self.get_path(key);
        // Fetch the latest version of the object. If it's a normal version, return true;
        // if it's a delete marker, delete it and loop; if there is no such version,
        // return false.
        // TODO: limit the number of delete markers we'll peel back?
        loop {
            // S3 only lets us fetch the versions of an object with a list requests.
            // Seems a bit wasteful to just fetch one at a time, but otherwise we can only
            // guess the order of versions via the timestamp, and that feels brittle.
            let list_res = self
                .client
                .list_object_versions()
                .bucket(&self.bucket)
                .prefix(&path)
                .max_keys(1)
                .send()
                .await
                .inspect_err(|err| self.update_error_metrics("ListObjectVersions", err))
                .context("listing object versions during restore")?;

            let current_delete = list_res
                .delete_markers()
                .into_iter()
                .filter(|d| {
                    // We need to check that any versions we're looking at have the right key,
                    // not just a key with our key as a prefix.
                    d.key() == Some(path.as_str())
                })
                .find(|d| d.is_latest().unwrap_or(false))
                .and_then(|d| d.version_id());

            if let Some(version) = current_delete {
                // Deleting the delete marker exposes the previous version (if
                // any) as current; loop to re-check.
                let deleted = self
                    .client
                    .delete_object()
                    .bucket(&self.bucket)
                    .key(&path)
                    .version_id(version)
                    .send()
                    .await
                    .inspect_err(|err| self.update_error_metrics("DeleteObject", err))
                    .context("deleting a delete marker")?;
                assert!(
                    deleted.delete_marker().unwrap_or(false),
                    "deleting a delete marker"
                );
            } else {
                let has_current_version = list_res
                    .versions()
                    .into_iter()
                    .filter(|d| d.key() == Some(path.as_str()))
                    .any(|v| v.is_latest().unwrap_or(false));

                if !has_current_version {
                    return Err(Determinate::new(anyhow!(
                        "unable to restore {key} in s3: no valid version exists"
                    ))
                    .into());
                }
                return Ok(());
            }
        }
    }
}
696
697impl S3Blob {
698    async fn set_single_part(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
699        let start_overall = Instant::now();
700        let path = self.get_path(key);
701
702        let value_len = value.len();
703        let part_span = trace_span!("s3set_single", payload_len = value_len);
704        self.metrics.set_single.inc();
705        self.client
706            .put_object()
707            .bucket(&self.bucket)
708            .key(path)
709            .body(ByteStream::from(value))
710            .send()
711            .instrument(part_span)
712            .await
713            .inspect_err(|err| self.update_error_metrics("PutObject", err))
714            .context("set single part")?;
715        debug!(
716            "s3 PutObject single done {}b / {:?}",
717            value_len,
718            start_overall.elapsed()
719        );
720        Ok(())
721    }
722
    // TODO(benesch): remove this once this function no longer makes use of
    // potentially dangerous `as` conversions.
    #[allow(clippy::as_conversions)]
    /// Writes `value` under `key` using S3's multipart upload API: one
    /// `CreateMultipartUpload`, a concurrent `UploadPart` task per chunk
    /// produced by [`MultipartConfig::part_iter`], then a single
    /// `CompleteMultipartUpload`.
    ///
    /// On any part failure this early-returns, orphaning already-uploaded
    /// parts; see the comment above `complete_multipart_upload` below for why
    /// that is acceptable.
    async fn set_multi_part(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
        let start_overall = Instant::now();
        let path = self.get_path(key);

        // Start the multi part request and get an upload id.
        trace!("s3 PutObject multi start {}b", value.len());
        self.metrics.set_multi_create.inc();
        let upload_res = self
            .client
            .create_multipart_upload()
            .bucket(&self.bucket)
            .key(&path)
            .customize()
            .mutate_request(|req| {
                // By default the Rust AWS SDK does not set the Content-Length
                // header on POST calls with empty bodies. This is fine for S3,
                // but when running against GCS's S3 interop mode these calls
                // will be rejected unless we set this header manually.
                req.headers_mut().insert("Content-Length", "0");
            })
            .send()
            .instrument(debug_span!("s3set_multi_start"))
            .await
            .inspect_err(|err| self.update_error_metrics("CreateMultipartUpload", err))
            .context("create_multipart_upload err")?;
        // Every subsequent part/complete call must carry this upload id.
        let upload_id = upload_res
            .upload_id()
            .ok_or_else(|| anyhow!("create_multipart_upload response missing upload_id"))?;
        trace!(
            "s3 create_multipart_upload took {:?}",
            start_overall.elapsed()
        );

        let async_runtime = AsyncHandle::try_current().map_err(anyhow::Error::new)?;

        // Fire off all the individual parts.
        //
        // TODO: The aws cli throttles how many of these are outstanding at any
        // given point. We'll likely want to do the same at some point.
        let start_parts = Instant::now();
        let mut part_futs = Vec::new();
        for (part_num, part_range) in self.multipart_config.part_iter(value.len()) {
            // NB: Without this spawn, these will execute serially. This is rust
            // async 101 stuff, but there isn't much async in the persist
            // codebase (yet?) so I thought it worth calling out.
            let part_span = debug_span!("s3set_multi_part", payload_len = part_range.len());
            let part_fut = async_runtime.spawn_named(
                // TODO: Add the key and part number once this can be annotated
                // with metadata.
                || "persist_s3blob_put_part",
                {
                    self.metrics.set_multi_part.inc();
                    self.client
                        .upload_part()
                        .bucket(&self.bucket)
                        .key(&path)
                        .upload_id(upload_id)
                        .part_number(part_num as i32)
                        // NB: `Bytes::slice` is a cheap refcounted view, so
                        // each part shares the one underlying buffer.
                        .body(ByteStream::from(value.slice(part_range)))
                        .send()
                        .instrument(part_span)
                        // Tag the response with how long we waited for it so
                        // the loop below can feed `MinElapsed`.
                        .map(move |res| (start_parts.elapsed(), res))
                },
            );
            part_futs.push((part_num, part_fut));
        }
        let parts_len = part_futs.len();

        // Wait on all the parts to finish. This is done in part order, no need
        // for joining them in the order they finish.
        //
        // TODO: Consider using something like futures::future::join_all() for
        // this. That would cancel outstanding requests for us if any of them
        // fails. However, it might not play well with using retries for tail
        // latencies. Investigate.
        let min_part_elapsed = MinElapsed::default();
        let mut parts = Vec::with_capacity(parts_len);
        for (part_num, part_fut) in part_futs.into_iter() {
            let (this_part_elapsed, part_res) = part_fut.await;
            let part_res = part_res
                .inspect_err(|err| self.update_error_metrics("UploadPart", err))
                .context("s3 upload_part err")?;
            // S3 requires each part's ETag to be echoed back in the final
            // CompleteMultipartUpload call.
            let part_e_tag = part_res.e_tag().ok_or_else(|| {
                self.metrics
                    .error_counts
                    .with_label_values(&["UploadPart", "MissingEtag"])
                    .inc();
                anyhow!("s3 upload part missing e_tag")
            })?;
            parts.push(
                CompletedPart::builder()
                    .e_tag(part_e_tag)
                    .part_number(part_num as i32)
                    .build(),
            );
            min_part_elapsed.observe(this_part_elapsed, "s3 upload_part took");
        }
        trace!(
            "s3 upload_parts overall took {:?} ({} parts)",
            start_parts.elapsed(),
            parts_len
        );

        // Complete the upload.
        //
        // Currently, we early return if any of the individual parts fail. This
        // permanently orphans any parts that succeeded. One fix is to call
        // abort_multipart_upload, which deletes them. However, there's also an
        // option for an s3 bucket to auto-delete parts that haven't been
        // completed or aborted after a given amount of time. This latter is
        // simpler and also resilient to ill-timed mz restarts, so we use it for
        // now. We could likely add the accounting necessary to make
        // abort_multipart_upload work, but it would be complex and affect perf.
        // Let's see how far we can get without it.
        let start_complete = Instant::now();
        self.metrics.set_multi_complete.inc();
        self.client
            .complete_multipart_upload()
            .bucket(&self.bucket)
            .key(&path)
            .upload_id(upload_id)
            .multipart_upload(
                CompletedMultipartUpload::builder()
                    .set_parts(Some(parts))
                    .build(),
            )
            .send()
            .instrument(debug_span!("s3set_multi_complete", num_parts = parts_len))
            .await
            .inspect_err(|err| self.update_error_metrics("CompleteMultipartUpload", err))
            .context("complete_multipart_upload err")?;
        trace!(
            "s3 complete_multipart_upload took {:?}",
            start_complete.elapsed()
        );

        debug!(
            "s3 PutObject multi done {}b / {:?} ({} parts)",
            value.len(),
            start_overall.elapsed(),
            parts_len
        );
        Ok(())
    }
870
871    fn update_error_metrics<E, R>(&self, op: &str, err: &SdkError<E, R>)
872    where
873        E: ProvideErrorMetadata,
874    {
875        let code = match err {
876            SdkError::ServiceError(e) => match e.err().code() {
877                Some(code) => code,
878                None => "UnknownServiceError",
879            },
880            SdkError::DispatchFailure(e) => {
881                if let Some(other_error) = e.as_other() {
882                    match other_error {
883                        aws_config::retry::ErrorKind::TransientError => "TransientError",
884                        aws_config::retry::ErrorKind::ThrottlingError => "ThrottlingError",
885                        aws_config::retry::ErrorKind::ServerError => "ServerError",
886                        aws_config::retry::ErrorKind::ClientError => "ClientError",
887                        _ => "UnknownDispatchFailure",
888                    }
889                } else if e.is_timeout() {
890                    "TimeoutError"
891                } else if e.is_io() {
892                    "IOError"
893                } else if e.is_user() {
894                    "UserError"
895                } else {
896                    "UnknownDispathFailure"
897                }
898            }
899            SdkError::ResponseError(_) => "ResponseError",
900            SdkError::ConstructionFailure(_) => "ConstructionFailure",
901            // There is some overlap with MetricsSleep. MetricsSleep is more granular
902            // but does not contain the operation.
903            SdkError::TimeoutError(_) => "TimeoutError",
904            // an error was added at some point in the future
905            _ => "UnknownSdkError",
906        };
907        self.metrics
908            .error_counts
909            .with_label_values(&[op, code])
910            .inc();
911    }
912}
913
/// Tuning knobs for when and how to split a blob into an S3 multipart upload.
#[derive(Clone, Debug)]
struct MultipartConfig {
    // Blobs strictly larger than this many bytes take the multipart path; see
    // `should_multipart`.
    multipart_threshold: usize,
    // Target size in bytes of each uploaded part (the last part may be
    // smaller); see `part_iter`.
    multipart_chunk_size: usize,
}
919
920impl Default for MultipartConfig {
921    fn default() -> Self {
922        Self {
923            multipart_threshold: Self::DEFAULT_MULTIPART_THRESHOLD,
924            multipart_chunk_size: Self::DEFAULT_MULTIPART_CHUNK_SIZE,
925        }
926    }
927}
928
/// One mebibyte (2^20 bytes), named `MB` for brevity.
const MB: usize = 1024 * 1024;
/// One tebibyte (2^40 bytes), named `TB` for brevity.
const TB: usize = 1024 * 1024 * MB;
931
932impl MultipartConfig {
933    /// The minimum object size for which we start using multipart upload.
934    ///
935    /// From the official `aws cli` tool implementation:
936    ///
937    /// <https://github.com/aws/aws-cli/blob/2.4.14/awscli/customizations/s3/transferconfig.py#L18-L29>
938    const DEFAULT_MULTIPART_THRESHOLD: usize = 8 * MB;
939    /// The size of each part (except the last) in a multipart upload.
940    ///
941    /// From the official `aws cli` tool implementation:
942    ///
943    /// <https://github.com/aws/aws-cli/blob/2.4.14/awscli/customizations/s3/transferconfig.py#L18-L29>
944    const DEFAULT_MULTIPART_CHUNK_SIZE: usize = 8 * MB;
945
946    /// The largest size object creatable in S3.
947    ///
948    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
949    const MAX_SINGLE_UPLOAD_SIZE: usize = 5 * TB;
950    /// The minimum size of a part in a multipart upload.
951    ///
952    /// This minimum doesn't apply to the last chunk, which can be any size.
953    ///
954    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
955    const MIN_UPLOAD_CHUNK_SIZE: usize = 5 * MB;
956    /// The smallest allowable part number (inclusive).
957    ///
958    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
959    const MIN_PART_NUM: u32 = 1;
960    /// The largest allowable part number (inclusive).
961    ///
962    /// From <https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html>
963    const MAX_PART_NUM: u32 = 10_000;
964
965    fn should_multipart(&self, blob_len: usize) -> Result<bool, String> {
966        if blob_len > Self::MAX_SINGLE_UPLOAD_SIZE {
967            return Err(format!(
968                "S3 does not support blobs larger than {} bytes got: {}",
969                Self::MAX_SINGLE_UPLOAD_SIZE,
970                blob_len
971            ));
972        }
973        Ok(blob_len > self.multipart_threshold)
974    }
975
976    fn part_iter(&self, blob_len: usize) -> MultipartChunkIter {
977        debug_assert!(self.multipart_chunk_size >= MultipartConfig::MIN_UPLOAD_CHUNK_SIZE);
978        MultipartChunkIter::new(self.multipart_chunk_size, blob_len)
979    }
980}
981
/// Iterator over the 1-indexed `(part_number, byte_range)` pairs that cover a
/// blob of `total_len` bytes in `part_size`-byte chunks.
#[derive(Clone, Debug)]
struct MultipartChunkIter {
    // Total length of the blob being chunked, in bytes.
    total_len: usize,
    // Size of every emitted range except possibly the last, in bytes.
    part_size: usize,
    // 0-indexed position of the next part to yield; the emitted part number
    // is `part_idx + 1`.
    part_idx: u32,
}
988
989impl MultipartChunkIter {
990    fn new(default_part_size: usize, blob_len: usize) -> Self {
991        let max_parts: usize = usize::cast_from(MultipartConfig::MAX_PART_NUM);
992
993        // Compute the minimum part size we can use without going over the max
994        // number of parts that S3 allows: `ceil(blob_len / max_parts)`.This
995        // will end up getting thrown away by the `cmp::max` for anything
996        // smaller than `max_parts * default_part_size = 80GiB`.
997        let min_part_size = (blob_len + max_parts - 1) / max_parts;
998        let part_size = cmp::max(min_part_size, default_part_size);
999
1000        // Part nums are 1-indexed in S3. Convert back to 0-indexed to make the
1001        // range math easier to follow.
1002        let part_idx = MultipartConfig::MIN_PART_NUM - 1;
1003        MultipartChunkIter {
1004            total_len: blob_len,
1005            part_size,
1006            part_idx,
1007        }
1008    }
1009}
1010
1011impl Iterator for MultipartChunkIter {
1012    type Item = (u32, Range<usize>);
1013
1014    fn next(&mut self) -> Option<Self::Item> {
1015        let part_idx = self.part_idx;
1016        self.part_idx += 1;
1017
1018        let start = usize::cast_from(part_idx) * self.part_size;
1019        if start >= self.total_len {
1020            return None;
1021        }
1022        let end = cmp::min(start + self.part_size, self.total_len);
1023        let part_num = part_idx + 1;
1024        Some((part_num, start..end))
1025    }
1026}
1027
/// A helper for tracking the minimum of a set of Durations.
#[derive(Debug)]
struct MinElapsed {
    // Smallest observed duration so far, in nanoseconds; `u64::MAX` until the
    // first observation.
    min: AtomicU64,
    // Observations larger than `alert_factor` times the minimum are logged at
    // `debug!` instead of `trace!`.
    alert_factor: u64,
}
1034
1035impl Default for MinElapsed {
1036    fn default() -> Self {
1037        MinElapsed {
1038            min: AtomicU64::new(u64::MAX),
1039            alert_factor: 8,
1040        }
1041    }
1042}
1043
1044impl MinElapsed {
1045    fn observe(&self, x: Duration, msg: &'static str) {
1046        let nanos = x.as_nanos();
1047        let nanos = u64::try_from(nanos).unwrap_or(u64::MAX);
1048
1049        // Possibly set a new minimum.
1050        let prev_min = self.min.fetch_min(nanos, atomic::Ordering::SeqCst);
1051
1052        // Trace if our provided duration was much larger than our minimum.
1053        let new_min = std::cmp::min(prev_min, nanos);
1054        if nanos > new_min.saturating_mul(self.alert_factor) {
1055            let min_duration = Duration::from_nanos(new_min);
1056            let factor = self.alert_factor;
1057            debug!("{msg} took {x:?} more than {factor}x the min {min_duration:?}");
1058        } else {
1059            trace!("{msg} took {x:?}");
1060        }
1061    }
1062}
1063
// Make sure the "vendored" feature of the openssl_sys crate makes it into the
// transitive dep graph of persist, so that we don't attempt to link against the
// system OpenSSL library. Fake a usage of the crate here so that a good
// samaritan doesn't remove our unused dep.
#[allow(dead_code)]
fn openssl_sys_hack() {
    // Never called; exists solely to keep the `openssl_sys` dependency alive.
    openssl_sys::init();
}
1072
#[cfg(test)]
mod tests {
    use tracing::info;

    use crate::location::tests::blob_impl_test;

    use super::*;

    /// Runs the generic `Blob` contract suite against a real S3 (or
    /// S3-compatible) bucket, then exercises the multipart code path
    /// specifically. Skips itself when the external-test env var is unset.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)] // https://github.com/MaterializeInc/database-issues/issues/5586
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `TLS_method` on OS `linux`
    #[ignore] // TODO: Reenable against minio so it can run locally
    async fn s3_blob() -> Result<(), ExternalError> {
        let config = match S3BlobConfig::new_for_test().await? {
            Some(client) => client,
            None => {
                info!(
                    "{} env not set: skipping test that uses external service",
                    S3BlobConfig::EXTERNAL_TESTS_S3_BUCKET
                );
                return Ok(());
            }
        };
        // A fresh (UUID) prefix so the multipart check below doesn't collide
        // with keys written by `blob_impl_test`.
        let config_multipart = config.clone_with_new_uuid_prefix();

        blob_impl_test(move |path| {
            let path = path.to_owned();
            let config = config.clone();
            async move {
                // Give each sub-test its own prefix so they don't interfere.
                let config = S3BlobConfig {
                    metrics: config.metrics.clone(),
                    client: config.client.clone(),
                    bucket: config.bucket.clone(),
                    prefix: format!("{}/s3_blob_impl_test/{}", config.prefix, path),
                };
                let mut blob = S3Blob::open(config).await?;
                // NOTE(review): a tiny max_keys looks intended to exercise
                // paginated listing — confirm against the list implementation.
                blob.max_keys = 2;
                Ok(blob)
            }
        })
        .await?;

        // Also specifically test multipart. S3 requires all parts but the last
        // to be at least 5MB, which we don't want to do from a test, so this
        // uses the multipart code path but only writes a single part.
        {
            let blob = S3Blob::open(config_multipart).await?;
            blob.set_multi_part("multipart", "foobar".into()).await?;
            assert_eq!(
                blob.get("multipart").await?,
                Some(b"foobar".to_vec().into())
            );
        }

        Ok(())
    }

    /// Boundary tests for `MultipartConfig::should_multipart`: at, just above,
    /// and far above the threshold, plus the max-object-size error case.
    #[mz_ore::test]
    fn should_multipart() {
        let config = MultipartConfig::default();
        assert_eq!(config.should_multipart(0), Ok(false));
        assert_eq!(config.should_multipart(1), Ok(false));
        assert_eq!(
            config.should_multipart(MultipartConfig::DEFAULT_MULTIPART_THRESHOLD),
            Ok(false)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::DEFAULT_MULTIPART_THRESHOLD + 1),
            Ok(true)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::DEFAULT_MULTIPART_THRESHOLD * 2),
            Ok(true)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::MAX_SINGLE_UPLOAD_SIZE),
            Ok(true)
        );
        assert_eq!(
            config.should_multipart(MultipartConfig::MAX_SINGLE_UPLOAD_SIZE + 1),
            Err(
                "S3 does not support blobs larger than 5497558138880 bytes got: 5497558138881"
                    .into()
            )
        );
    }

    /// Tests `MultipartChunkIter` range arithmetic around exact-multiple,
    /// one-under, and one-over chunk boundaries.
    #[mz_ore::test]
    fn multipart_iter() {
        let iter = MultipartChunkIter::new(10, 0);
        assert_eq!(iter.collect::<Vec<_>>(), vec![]);

        let iter = MultipartChunkIter::new(10, 9);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..9)]);

        let iter = MultipartChunkIter::new(10, 10);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10)]);

        let iter = MultipartChunkIter::new(10, 11);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10), (2, 10..11)]);

        let iter = MultipartChunkIter::new(10, 19);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10), (2, 10..19)]);

        let iter = MultipartChunkIter::new(10, 20);
        assert_eq!(iter.collect::<Vec<_>>(), vec![(1, 0..10), (2, 10..20)]);

        let iter = MultipartChunkIter::new(10, 21);
        assert_eq!(
            iter.collect::<Vec<_>>(),
            vec![(1, 0..10), (2, 10..20), (3, 20..21)]
        );
    }
}