use anyhow::anyhow;
use aws_types::sdk_config::SdkConfig;
use mz_aws_util::s3_uploader::{
CompletedUpload, S3MultiPartUploadError, S3MultiPartUploader, S3MultiPartUploaderConfig,
};
use mz_ore::assert_none;
use mz_ore::cast::CastFrom;
use mz_ore::task::JoinHandleExt;
use mz_pgcopy::{encode_copy_format, encode_copy_format_header, CopyFormatParams};
use mz_repr::{GlobalId, RelationDesc, Row};
use mz_storage_types::sinks::{S3SinkFormat, S3UploadInfo};
use tracing::info;

use super::{CopyToParameters, CopyToS3Uploader, S3KeyManager};
/// State needed to upload a single batch's output files to S3 in PgCopy format.
pub(super) struct PgCopyUploader {
    /// Description of the rows being uploaded.
    desc: RelationDesc,
    /// Encoding parameters for the PgCopy (CSV or text) output.
    format: CopyFormatParams<'static>,
    /// Index of the current file within the batch, used to derive object keys.
    file_index: usize,
    /// Provides the bucket and object keys to use for uploads.
    key_manager: S3KeyManager,
    /// Identifies the batch that files uploaded by this uploader belong to.
    batch: u64,
    /// Desired maximum size of each uploaded file; rows roll over to a new file past this.
    max_file_size: u64,
    /// AWS SDK config, kept in an `Option` so it can be moved into spawned tasks and put back.
    sdk_config: Option<SdkConfig>,
    /// Multi-part uploader for the file currently being written, if any.
    current_file_uploader: Option<S3MultiPartUploader>,
    /// Tuning parameters for the upload.
    params: CopyToParameters,
}
impl CopyToS3Uploader for PgCopyUploader {
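    /// Creates an uploader for a single batch of files, failing if the sink's
    /// format is anything other than PgCopy.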
fn new(
sdk_config: SdkConfig,
connection_details: S3UploadInfo,
sink_id: &GlobalId,
batch: u64,
params: CopyToParameters,
) -> Result<PgCopyUploader, anyhow::Error> {
match connection_details.format {
S3SinkFormat::PgCopy(format_params) => Ok(PgCopyUploader {
desc: connection_details.desc,
sdk_config: Some(sdk_config),
format: format_params,
key_manager: S3KeyManager::new(sink_id, &connection_details.uri),
batch,
max_file_size: connection_details.max_file_size,
file_index: 0,
current_file_uploader: None,
params,
}),
_ => anyhow::bail!("Expected PgCopy format"),
}
}
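    /// Completes the in-progress file upload, if there is one, and logs the result.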
async fn finish(&mut self) -> Result<(), anyhow::Error> {
if let Some(uploader) = self.current_file_uploader.take() {
let handle =
mz_ore::task::spawn(|| "s3_uploader::finish", async { uploader.finish().await });
let CompletedUpload {
part_count,
total_bytes_uploaded,
bucket,
key,
} = handle.wait_and_assert_finished().await?;
info!(
"finished upload: bucket {}, key {}, bytes_uploaded {}, parts_uploaded {}",
bucket, key, total_bytes_uploaded, part_count
);
}
Ok(())
}
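    /// Encodes `row` in the configured PgCopy format and buffers it into the current
    /// file's upload, rolling over to a new file when the size limit is reached.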
async fn append_row(&mut self, row: &Row) -> Result<(), anyhow::Error> {
let mut buf: Vec<u8> = vec![];
encode_copy_format(&self.format, row, self.desc.typ(), &mut buf)
.map_err(|_| anyhow!("error encoding row"))?;
if self.current_file_uploader.is_none() {
self.start_new_file_upload().await?;
}
let mut uploader = self.current_file_uploader.as_mut().expect("known exists");
match uploader.buffer_chunk(&buf) {
Ok(_) => Ok(()),
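            // The current file has reached its size limit: roll over to a new file
            // and buffer this row there instead.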
Err(S3MultiPartUploadError::UploadExceedsMaxFileLimit(_)) => {
self.start_new_file_upload().await?;
uploader = self.current_file_uploader.as_mut().expect("known exists");
uploader.buffer_chunk(&buf)?;
Ok(())
}
Err(e) => Err(e.into()),
}
}
}
impl PgCopyUploader {
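    /// Finishes any current file upload and starts a multi-part upload for the next
    /// file in the batch, writing a format header if one is required.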
async fn start_new_file_upload(&mut self) -> Result<(), anyhow::Error> {
self.finish().await?;
assert_none!(self.current_file_uploader);
self.file_index += 1;
let object_key =
self.key_manager
.data_key(self.batch, self.file_index, self.format.file_extension());
let bucket = self.key_manager.bucket.clone();
info!("starting upload: bucket {}, key {}", &bucket, &object_key);
let sdk_config = self
.sdk_config
.take()
.expect("sdk_config should always be present");
let max_file_size = self.max_file_size;
let part_size_limit = u64::cast_from(self.params.s3_multipart_part_size_bytes);
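        // Start the multi-part upload on a separate task; the SDK config is moved into
        // the task and handed back so it can be reused for subsequent files.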
let handle = mz_ore::task::spawn(|| "s3_uploader::try_new", async move {
let uploader = S3MultiPartUploader::try_new(
&sdk_config,
bucket,
object_key,
S3MultiPartUploaderConfig {
part_size_limit,
file_size_limit: max_file_size,
},
)
.await;
(uploader, sdk_config)
});
let (uploader, sdk_config) = handle.wait_and_assert_finished().await;
self.sdk_config = Some(sdk_config);
let mut uploader = uploader?;
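        // Each new file starts with the format's header (e.g. CSV column names), if any.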
if self.format.requires_header() {
let mut buf: Vec<u8> = vec![];
encode_copy_format_header(&self.format, &self.desc, &mut buf)
.map_err(|_| anyhow!("error encoding header"))?;
uploader.buffer_chunk(&buf)?;
}
self.current_file_uploader = Some(uploader);
Ok(())
}
}
#[cfg(test)]
mod tests {
use bytesize::ByteSize;
use mz_pgcopy::CopyFormatParams;
use mz_repr::{ColumnName, ColumnType, Datum, RelationType};
use uuid::Uuid;
use super::*;
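    /// Returns the S3 bucket and a fresh object prefix to upload to, read from the
    /// `MZ_S3_UPLOADER_TEST_S3_BUCKET` environment variable. Returns `None` (so the
    /// test can be skipped) when the variable is unset, except on CI where it panics.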
fn s3_bucket_path_for_test() -> Option<(String, String)> {
let bucket = match std::env::var("MZ_S3_UPLOADER_TEST_S3_BUCKET") {
Ok(bucket) => bucket,
Err(_) => {
if mz_ore::env::is_var_truthy("CI") {
panic!("CI is supposed to run this test but something has gone wrong!");
}
return None;
}
};
let prefix = Uuid::new_v4().to_string();
let path = format!("cargo_test/{}/file", prefix);
Some((bucket, path))
}
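    /// Verifies that appended rows are split across multiple files once `max_file_size`
    /// is exceeded. Only runs when `MZ_S3_UPLOADER_TEST_S3_BUCKET` is set, since it
    /// uploads to and reads back from a real S3 bucket.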
#[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(coverage, ignore)]
    #[cfg_attr(miri, ignore)]
    async fn test_multiple_files() -> Result<(), anyhow::Error> {
let sdk_config = mz_aws_util::defaults().load().await;
let (bucket, path) = match s3_bucket_path_for_test() {
Some(tuple) => tuple,
None => return Ok(()),
};
let sink_id = GlobalId::User(123);
let batch = 456;
let typ: RelationType = RelationType::new(vec![ColumnType {
scalar_type: mz_repr::ScalarType::String,
nullable: true,
}]);
let column_names = vec![ColumnName::from("col1")];
let desc = RelationDesc::new(typ, column_names.into_iter());
let mut uploader = PgCopyUploader::new(
sdk_config.clone(),
S3UploadInfo {
uri: format!("s3://{}/{}", bucket, path),
max_file_size: ByteSize::b(6).as_u64(),
desc,
format: S3SinkFormat::PgCopy(CopyFormatParams::Csv(Default::default())),
},
&sink_id,
batch,
CopyToParameters {
s3_multipart_part_size_bytes: 10 * 1024 * 1024,
arrow_builder_buffer_ratio: 100,
parquet_row_group_ratio: 100,
},
)?;
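        // With a 6-byte file size limit, these rows cannot all fit in one file and
        // should be split across two uploads.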
let mut row = Row::default();
row.packer().push(Datum::from("1234567"));
uploader.append_row(&row).await?;
row.packer().push(Datum::Null);
uploader.append_row(&row).await?;
row.packer().push(Datum::from("5678"));
uploader.append_row(&row).await?;
uploader.finish().await?;
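        // Read both files back and check that the rows landed in the expected files.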
let s3_client = mz_aws_util::s3::new_client(&sdk_config);
let first_file = s3_client
.get_object()
.bucket(bucket.clone())
.key(format!(
"{}/mz-{}-batch-{:04}-0001.csv",
path, sink_id, batch
))
.send()
.await
.unwrap();
let body = first_file.body.collect().await.unwrap().into_bytes();
let expected_body: &[u8] = b"1234567\n";
assert_eq!(body, *expected_body);
let second_file = s3_client
.get_object()
.bucket(bucket)
.key(format!(
"{}/mz-{}-batch-{:04}-0002.csv",
path, sink_id, batch
))
.send()
.await
.unwrap();
let body = second_file.body.collect().await.unwrap().into_bytes();
let expected_body: &[u8] = b"\n5678\n";
assert_eq!(body, *expected_body);
Ok(())
}
}