use std::any::Any;
use std::cell::RefCell;
use std::rc::Rc;
use differential_dataflow::{AsCollection, Collection, Hashable};
use mz_compute_client::protocol::response::CopyToResponse;
use mz_compute_types::dyncfgs::{
COPY_TO_S3_ARROW_BUILDER_BUFFER_RATIO, COPY_TO_S3_MULTIPART_PART_SIZE_BYTES,
COPY_TO_S3_PARQUET_ROW_GROUP_FILE_RATIO,
};
use mz_compute_types::sinks::{ComputeSinkDesc, CopyToS3OneshotSinkConnection};
use mz_repr::{Diff, GlobalId, Row, Timestamp};
use mz_storage_types::controller::CollectionMetadata;
use mz_storage_types::errors::DataflowError;
use mz_timely_util::operator::consolidate_pact;
use timely::dataflow::channels::pact::{Exchange, Pipeline};
use timely::dataflow::operators::Operator;
use timely::dataflow::Scope;
use timely::progress::Antichain;
use crate::render::sinks::SinkRender;
use crate::render::StartSignal;
use crate::typedefs::KeyBatcher;
impl<G> SinkRender<G> for CopyToS3OneshotSinkConnection
where
G: Scope<Timestamp = Timestamp>,
{
fn render_sink(
&self,
compute_state: &mut crate::compute_state::ComputeState,
sink: &ComputeSinkDesc<CollectionMetadata>,
sink_id: GlobalId,
_as_of: Antichain<Timestamp>,
_start_signal: StartSignal,
sinked_collection: Collection<G, Row, Diff>,
err_collection: Collection<G, DataflowError, Diff>,
_ct_times: Option<Collection<G, (), Diff>>,
) -> Option<Rc<dyn Any>> {
let mut response_protocol = ResponseProtocol {
sink_id,
response_buffer: Some(Rc::clone(&compute_state.copy_to_response_buffer)),
};
let result_callback = move |count: Result<u64, String>| {
response_protocol.send(count);
};
let batch_count = self.output_batch_count;
let input = consolidate_pact::<KeyBatcher<_, _, _>, _, _, _, _>(
&sinked_collection
.map(move |row| {
let batch = row.hashed() % batch_count;
((row, batch), ())
})
.inner,
Exchange::new(move |(((_, batch), _), _, _)| *batch),
"Consolidated COPY TO S3 input",
)
.as_collection();
let error = consolidate_pact::<KeyBatcher<_, _, _>, _, _, _, _>(
&err_collection
.map(move |row| {
let batch = row.hashed() % batch_count;
((row, batch), ())
})
.inner,
Exchange::new(move |(((_, batch), _), _, _)| *batch),
"Consolidated COPY TO S3 errors",
)
.container::<Vec<_>>();
let error_stream =
error.unary_frontier(Pipeline, "COPY TO S3 error filtering", |_cap, _info| {
let up_to = sink.up_to.clone();
let mut received_one = false;
move |input, output| {
while let Some((time, data)) = input.next() {
if !up_to.less_equal(time.time()) && !received_one {
received_one = true;
output.session(&time).give_iterator(data.drain(..1));
}
}
}
});
let params = mz_storage_operators::s3_oneshot_sink::CopyToParameters {
parquet_row_group_ratio: COPY_TO_S3_PARQUET_ROW_GROUP_FILE_RATIO
.get(&compute_state.worker_config),
arrow_builder_buffer_ratio: COPY_TO_S3_ARROW_BUILDER_BUFFER_RATIO
.get(&compute_state.worker_config),
s3_multipart_part_size_bytes: COPY_TO_S3_MULTIPART_PART_SIZE_BYTES
.get(&compute_state.worker_config),
};
let token = mz_storage_operators::s3_oneshot_sink::copy_to(
input,
error_stream,
sink.up_to.clone(),
self.upload_info.clone(),
compute_state.context.connection_context.clone(),
self.aws_connection.clone(),
sink_id,
self.connection_id,
params,
result_callback,
);
Some(token)
}
}
/// Guides the delivery of a COPY TO S3 sink's outcome into the shared
/// response buffer.
///
/// Exactly one response is pushed per instance: either the outcome passed to
/// `send`, or `CopyToResponse::Dropped` if the value is dropped while it still
/// holds the buffer handle.
struct ResponseProtocol {
/// Identifier of the sink; paired with every response pushed to the buffer.
pub sink_id: GlobalId,
/// Handle to the shared response buffer. `Some` until a response has been
/// pushed (by `send` or by `Drop`), `None` afterwards.
pub response_buffer: Option<Rc<RefCell<Vec<(GlobalId, CopyToResponse)>>>>,
}
impl ResponseProtocol {
fn send(&mut self, count: Result<u64, String>) {
let buffer = self.response_buffer.take().expect("expect response buffer");
let response = match count {
Ok(count) => CopyToResponse::RowCount(count),
Err(error) => CopyToResponse::Error(error),
};
buffer.borrow_mut().push((self.sink_id, response));
}
}
impl Drop for ResponseProtocol {
    /// Pushes `CopyToResponse::Dropped` for this sink if the buffer handle is
    /// still held, i.e. if no outcome was pushed through `send`.
    fn drop(&mut self) {
        let buffer = match self.response_buffer.take() {
            Some(buffer) => buffer,
            // An outcome was already delivered; nothing left to report.
            None => return,
        };
        let mut pending = buffer.borrow_mut();
        pending.push((self.sink_id, CopyToResponse::Dropped));
    }
}