//! AWS S3 [`OneshotSource`].
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use derivative::Derivative;
use futures::stream::{BoxStream, TryStreamExt};
use futures::StreamExt;
use mz_ore::future::InTask;
use mz_repr::CatalogItemId;
use mz_storage_types::connections::aws::AwsConnection;
use mz_storage_types::connections::ConnectionContext;
use serde::{Deserialize, Serialize};
use crate::oneshot_source::util::IntoRangeHeaderValue;
use crate::oneshot_source::{
OneshotObject, OneshotSource, StorageErrorX, StorageErrorXContext, StorageErrorXKind,
/// A oneshot source that reads objects from an AWS S3 bucket.
///
/// Stores the AWS connection and context used to build an S3 client, the
/// bucket and optional key prefix to operate on, and the client itself once
/// it has been created.
#[derive(Clone, Derivative)]
pub struct AwsS3Source {
// Only used for initialization.
#[derivative(Debug = "ignore")]
connection: Arc<AwsConnection>,
/// Catalog ID of the connection, passed along when loading the AWS SDK config.
connection_id: CatalogItemId,
/// Connection context, passed along when loading the AWS SDK config.
#[derivative(Debug = "ignore")]
context: Arc<ConnectionContext>,
/// Name of the S3 bucket we list from and fetch objects from.
bucket: String,
/// Optional prefix that can be specified via an S3 URI.
prefix: Option<String>,
/// S3 client that is lazily initialized.
#[derivative(Debug = "ignore")]
client: std::sync::OnceLock<mz_aws_util::s3::Client>,
impl AwsS3Source {
/// Creates a new [`AwsS3Source`] from an AWS connection and an S3 URI.
///
/// The bucket name and the optional key prefix are derived from `uri`.
/// No S3 client is created here; the client cell starts out empty and is
/// filled on first use.
///
/// # Panics
///
/// Panics if `uri` does not parse as an [`http::Uri`], or if the
/// "validated host" expectation on the bucket fails. Per the `expect`
/// messages, both are assumed to have been validated during sequencing.
pub fn new(
connection: AwsConnection,
connection_id: CatalogItemId,
context: ConnectionContext,
uri: String,
) -> Self {
let uri = http::Uri::from_str(&uri).expect("validated URI in sequencing");
// The bucket name is taken from the parsed URI; per the `expect` message
// its host is assumed to have been validated during sequencing.
let bucket = uri
.expect("validated host in sequencing")
// An empty path or a bare `/` is handled separately from a real prefix
// (presumably meaning "no prefix" -- TODO confirm).
let prefix = if uri.path().is_empty() || uri.path() == "/" {
} else {
// The S3 client expects a trailing `/` but no leading `/`.
let mut prefix = uri.path().to_string();
// Drop the leading `/`, if there is one.
if let Some(suffix) = prefix.strip_prefix('/') {
prefix = suffix.to_string();
// Append the trailing `/` if it is missing.
if !prefix.ends_with('/') {
prefix = format!("{prefix}/");
AwsS3Source {
connection: Arc::new(connection),
context: Arc::new(context),
// Start with an empty cell; the client is built on first use.
client: std::sync::OnceLock::new(),
/// Builds a new S3 client for this source.
///
/// Loads an AWS SDK config using this source's connection context and
/// connection ID, then constructs the S3 client from that config.
///
/// # Errors
///
/// Failures are reported as an [`anyhow::Error`].
pub async fn initialize(&self) -> Result<mz_aws_util::s3::Client, anyhow::Error> {
let sdk_config = self
.load_sdk_config(&self.context, self.connection_id, InTask::Yes)
let s3_client = mz_aws_util::s3::new_client(&sdk_config);
/// Returns a reference to the S3 client, creating it on first use.
///
/// # Errors
///
/// Propagates any error from [`Self::initialize`].
pub async fn client(&self) -> Result<&mz_aws_util::s3::Client, anyhow::Error> {
// Only build a client if the cell has not been filled yet.
if self.client.get().is_none() {
let client = self.initialize().await?;
// `set` returns an error if the cell was filled while we were
// initializing; in that case the existing client is kept and the one
// built here is dropped, so the result is deliberately ignored.
let _ = self.client.set(client);
// The cell is populated at this point, either by the code above or by an
// earlier caller.
Ok(self.client.get().expect("just initialized"))
/// Metadata describing a single object returned by an S3 list operation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct S3Object {
/// Key from S3 list operation.
key: String,
/// Name of the object, generally the last component of the key.
name: String,
/// Size of the object in bytes.
size: usize,
/// Exposes an [`S3Object`] through the [`OneshotObject`] interface.
impl OneshotObject for S3Object {
/// Returns the name of the object.
fn name(&self) -> &str {
/// Returns the path of the object (presumably its S3 key -- TODO confirm).
fn path(&self) -> &str {
/// Returns the size of the object.
fn size(&self) -> usize {
/// Returns the encodings associated with the object.
fn encodings(&self) -> &[super::Encoding] {
/// Checksum information recorded for an [`S3Object`] when it is listed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct S3Checksum {
/// ETag reported for the object by the S3 list operation, if any.
e_tag: Option<String>,
/// [`OneshotSource`] implementation that lists and fetches objects from S3.
impl OneshotSource for AwsS3Source {
type Object = S3Object;
type Checksum = S3Checksum;
/// Lists the objects in the bucket, restricted to the configured prefix
/// if one was provided.
///
/// Each listed object is returned together with an [`S3Checksum`] that
/// carries its ETag, when S3 reported one.
///
/// # Errors
///
/// Fails if the S3 client cannot be obtained, if a listed object is missing
/// its key or size, if no name can be derived from the key, or if the size
/// does not convert to a `usize`.
///
/// NOTE: pagination is not handled yet (see the TODO in the body).
async fn list<'a>(
&'a self,
) -> Result<Vec<(Self::Object, Self::Checksum)>, super::StorageErrorX> {
let client = self.client().await.map_err(StorageErrorXKind::generic)?;
let mut objects_request = client.list_objects_v2().bucket(&self.bucket);
// Users can optionally specify a prefix via the S3 uri they originally specify.
if let Some(prefix) = &self.prefix {
objects_request = objects_request.prefix(prefix);
let objects = objects_request
// TODO(cf1): Pagination.
// Convert every listed entry into an `(S3Object, S3Checksum)` pair; the
// first entry that fails validation aborts the whole listing.
let objects: Vec<_> = objects
.map(|o| {
// The key is required; its absence is reported as a `MissingField` error.
let key = o
.ok_or_else(|| StorageErrorXKind::MissingField("key".into()))?
// Derive the object's name by treating the key as a path; if no name
// can be extracted as a `&str` the key is reported as malformed.
let name = Path::new(&key)
.and_then(|os_name| os_name.to_str())
.ok_or_else(|| StorageErrorXKind::Generic(format!("malformed key: {key}")))?
// The size is required as well and must convert to a `usize`.
let size = o
.ok_or_else(|| StorageErrorXKind::MissingField("size".into()))?;
let size: usize = size.try_into().map_err(StorageErrorXKind::generic)?;
let object = S3Object { key, name, size };
// The ETag is optional and is copied into an owned `String` when present.
let checksum = S3Checksum {
e_tag: o.e_tag().map(|x| x.to_owned()),
Ok::<_, StorageErrorXKind>((object, checksum))
.collect::<Result<_, _>>()
/// Streams the contents of `object` from S3, optionally limited to the
/// inclusive byte `range`.
///
/// `_checksum` is currently unused; validating it is still a TODO.
fn get<'s>(
&'s self,
object: Self::Object,
_checksum: Self::Checksum,
range: Option<std::ops::RangeInclusive<usize>>,
) -> BoxStream<'s, Result<bytes::Bytes, StorageErrorX>> {
// The request is issued inside an async block, since obtaining the client
// and sending the request are async while this function itself is not.
let initial_response = async move {
tracing::info!(name = %object.name(), ?range, "fetching object");
// TODO(cf1): Validate our checksum.
let client = self.client().await.map_err(StorageErrorXKind::generic)?;
let mut request = client.get_object().bucket(&self.bucket).key(&object.key);
// Restrict the download to the requested range, if one was given.
if let Some(range) = range {
let value = range.into_range_header_value();
request = request.range(value);
// NOTE: this shadows the `object` argument with the S3 response.
// A failed request is reported as an `AwsS3Request` error carrying the
// error's string form.
let object = request
.map_err(|err| StorageErrorXKind::AwsS3Request(err.to_string()))?;
// AWS's ByteStream doesn't implement the Stream trait, so the response
// body is wrapped in an adapter.
let stream = mz_aws_util::s3::ByteStreamAdapter::new(object.body)
Ok::<_, StorageErrorX>(stream)