Skip to main content

mz_adapter/coord/sequencer/inner/
copy_from.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::str::FromStr;
11use std::sync::Arc;
12
13use mz_adapter_types::connection::ConnectionId;
14use mz_expr::Eval;
15use mz_ore::cast::CastInto;
16use mz_persist_client::Diagnostics;
17use mz_persist_client::batch::ProtoBatch;
18use mz_persist_types::codec_impls::UnitSchema;
19use mz_pgcopy::CopyFormatParams;
20use mz_repr::{CatalogItemId, ColumnIndex, Datum, RelationDesc, Row, RowArena};
21use mz_sql::catalog::SessionCatalog;
22use mz_sql::plan::{self, CopyFromFilter, CopyFromSource, HirScalarExpr};
23use mz_sql::session::metadata::SessionMetadata;
24use mz_storage_client::client::TableData;
25use mz_storage_types::StorageDiff;
26use mz_storage_types::oneshot_sources::{ContentShape, OneshotIngestionRequest};
27use mz_storage_types::sources::SourceData;
28use smallvec::SmallVec;
29use timely::progress::Antichain;
30use tokio::sync::{mpsc, oneshot};
31use url::Url;
32use uuid::Uuid;
33
34use crate::command::CopyFromStdinWriter;
35use crate::coord::sequencer::inner::return_if_err;
36use crate::coord::{ActiveCopyFrom, Coordinator, TargetCluster};
37use crate::optimize;
38use crate::optimize::dataflows::{EvalTime, ExprPrep, ExprPrepOneShot};
39use crate::session::{Session, TransactionOps, WriteOp};
40use crate::{AdapterError, ExecuteContext, ExecuteResponse};
41
/// Raw-input byte threshold for COPY FROM STDIN batch building. Once a worker
/// has consumed at least this many bytes for its current persist batch, the
/// batch is finalized and a new one is started, so that no single batch grows
/// without bound in memory.
const COPY_FROM_STDIN_MAX_BATCH_BYTES: usize = 32 << 20; // 32 MiB

/// Upper limit on the number of parallel decode workers spawned for one
/// COPY FROM STDIN.
///
/// A single network-bound stream gains little beyond a handful of decoders,
/// and the cap bounds how much of the blocking pool one COPY can occupy while
/// it is actively decoding.
const COPY_FROM_STDIN_MAX_WORKERS: usize = 8;
51
52impl Coordinator {
53    pub(crate) async fn sequence_copy_from(
54        &mut self,
55        ctx: ExecuteContext,
56        plan: plan::CopyFromPlan,
57        target_cluster: TargetCluster,
58    ) {
59        if ctx
60            .session()
61            .vars()
62            .transaction_isolation()
63            .is_bounded_staleness()
64        {
65            return ctx.retire(Err(AdapterError::BoundedStalenessReadOnly));
66        }
67
68        // STDIN is sequenced by handing control back to pgwire, which drives the
69        // CopyData/CopyDone exchange. URL/S3 sources stage a one-shot ingestion
70        // server-side and fall through to the rest of this function.
71        if let CopyFromSource::Stdin = plan.source {
72            let (tx, _, session, ctx_extra) = ctx.into_parts();
73            tx.send(
74                Ok(ExecuteResponse::CopyFrom {
75                    target_id: plan.target_id,
76                    target_name: plan.target_name,
77                    columns: plan.columns,
78                    params: plan.params,
79                    ctx_extra,
80                }),
81                session,
82            );
83            return;
84        }
85
86        let plan::CopyFromPlan {
87            target_name: _,
88            target_id,
89            source,
90            columns: _,
91            source_desc,
92            mfp,
93            params,
94            filter,
95        } = plan;
96
97        let eval_uri = |from: HirScalarExpr| -> Result<String, AdapterError> {
98            let style = ExprPrepOneShot {
99                logical_time: EvalTime::NotAvailable,
100                session: ctx.session(),
101                catalog_state: self.catalog().state(),
102            };
103            let mut from = from.lower_uncorrelated(self.catalog().state().system_config())?;
104            style.prep_scalar_expr(&mut from)?;
105
106            // TODO(cf3): Add structured errors for the below uses of `coord_bail!`
107            // and AdapterError::Unstructured.
108            let temp_storage = RowArena::new();
109            let eval_result = from.eval(&[], &temp_storage)?;
110            let eval_string = match eval_result {
111                Datum::Null => coord_bail!("COPY FROM target value cannot be NULL"),
112                Datum::String(url_str) => url_str,
113                other => coord_bail!("programming error! COPY FROM target cannot be {other}"),
114            };
115
116            Ok(eval_string.to_string())
117        };
118
119        // We check in planning that we're copying into a Table, but be defensive.
120        let Some(entry) = self.catalog().try_get_entry(&target_id) else {
121            return ctx.retire(Err(AdapterError::ConcurrentDependencyDrop {
122                dependency_kind: "table",
123                dependency_id: target_id.to_string(),
124            }));
125        };
126        let Some(dest_table) = entry.table() else {
127            let typ = entry.item().typ();
128            let msg = format!("programming error: expected a Table found {typ:?}");
129            return ctx.retire(Err(AdapterError::Unstructured(anyhow::anyhow!(msg))));
130        };
131
132        // Generate a unique UUID for our ingestion.
133        let ingestion_id = Uuid::new_v4();
134        let collection_id = dest_table.global_id_writes();
135
136        let format = match params {
137            CopyFormatParams::Csv(csv) => {
138                mz_storage_types::oneshot_sources::ContentFormat::Csv(csv.to_owned())
139            }
140            CopyFormatParams::Parquet => mz_storage_types::oneshot_sources::ContentFormat::Parquet,
141            CopyFormatParams::Text(_) | CopyFormatParams::Binary => {
142                mz_ore::soft_panic_or_log!("unsupported formats should be rejected in planning");
143                ctx.retire(Err(AdapterError::Unsupported("COPY FROM URL/S3 format")));
144                return;
145            }
146        };
147
148        let source = match source {
149            CopyFromSource::Url(from_expr) => {
150                let url = return_if_err!(eval_uri(from_expr), ctx);
151                // TODO(cf2): Structured errors.
152                let result = Url::parse(&url)
153                    .map_err(|err| AdapterError::Unstructured(anyhow::anyhow!("{err}")));
154                let url = return_if_err!(result, ctx);
155
156                // Only allow http(s) schemes. Technically we would fail later, as the current
157                // crate (reqwest) doesn't support other schemes.
158                // Prefer to fail early and explicitly in case the downstream ever changes.
159                // DNS resolution for hostnames is performed at execution time to avoid stalls
160                // during sequencing; IP-literal hosts are validated here because reqwest's
161                // custom DNS resolver is only invoked for hostnames.
162                match url.scheme() {
163                    "http" | "https" => {}
164                    other => {
165                        return ctx.retire(Err(AdapterError::Unstructured(anyhow::anyhow!(
166                            "only 'http://' and 'https://' urls are supported as COPY FROM \
167                             target, got '{other}://'"
168                        ))));
169                    }
170                }
171                let enforce_external_addresses =
172                    mz_storage_types::dyncfgs::ENFORCE_EXTERNAL_ADDRESSES
173                        .get(self.controller.storage.config().config_set());
174                if enforce_external_addresses {
175                    if let Err(err) = mz_ore::netio::ensure_url_ip_global(&url) {
176                        return ctx
177                            .retire(Err(AdapterError::Unstructured(anyhow::anyhow!("{err}"))));
178                    }
179                }
180                mz_storage_types::oneshot_sources::ContentSource::Http {
181                    url: mz_ore::url::SensitiveUrl(url),
182                }
183            }
184            CopyFromSource::AwsS3 {
185                uri,
186                connection,
187                connection_id,
188            } => {
189                let uri = return_if_err!(eval_uri(uri), ctx);
190
191                // Validate the URI is an S3 URI, with a bucket name. We rely on validating here
192                // and expect it in clusterd.
193                //
194                // TODO(cf2): Structured errors.
195                let result = http::Uri::from_str(&uri)
196                    .map_err(|err| {
197                        AdapterError::Unstructured(anyhow::anyhow!("expected S3 uri: {err}"))
198                    })
199                    .and_then(|uri| {
200                        if uri.scheme_str() != Some("s3") && uri.scheme_str() != Some("gs") {
201                            coord_bail!("only 's3://...' and 'gs://...' urls are supported as COPY FROM target");
202                        }
203                        Ok(uri)
204                    })
205                    .and_then(|uri| {
206                        if uri.host().is_none() {
207                            coord_bail!("missing bucket name from 's3://...' url");
208                        }
209                        Ok(uri)
210                    });
211                let uri = return_if_err!(result, ctx);
212
213                mz_storage_types::oneshot_sources::ContentSource::AwsS3 {
214                    connection,
215                    connection_id,
216                    uri: uri.to_string(),
217                }
218            }
219            CopyFromSource::Stdin => {
220                unreachable!("STDIN handled by the early return above")
221            }
222        };
223
224        let filter = match filter {
225            None => mz_storage_types::oneshot_sources::ContentFilter::None,
226            Some(CopyFromFilter::Files(files)) => {
227                mz_storage_types::oneshot_sources::ContentFilter::Files(files)
228            }
229            Some(CopyFromFilter::Pattern(pattern)) => {
230                mz_storage_types::oneshot_sources::ContentFilter::Pattern(pattern)
231            }
232        };
233
234        let source_mfp = mfp
235            .into_plan()
236            .map_err(|s| AdapterError::internal("copy_from", s))
237            .and_then(|mfp| {
238                mfp.into_nontemporal().map_err(|_| {
239                    AdapterError::internal("copy_from", "temporal MFP not allowed in copy from")
240                })
241            });
242        let source_mfp = return_if_err!(source_mfp, ctx);
243
244        let shape = ContentShape {
245            source_desc,
246            source_mfp,
247        };
248
249        let request = OneshotIngestionRequest {
250            source,
251            format,
252            filter,
253            shape,
254        };
255
256        let target_cluster = match self
257            .catalog()
258            .resolve_target_cluster(target_cluster, ctx.session())
259        {
260            Ok(cluster) => cluster,
261            Err(err) => {
262                return ctx.retire(Err(err));
263            }
264        };
265        let cluster_id = target_cluster.id;
266
267        // When we finish staging the Batches in Persist, we'll send a command
268        // to the Coordinator.
269        let command_tx = self.internal_cmd_tx.clone();
270        let conn_id = ctx.session().conn_id().clone();
271        let closure = Box::new(move |batches| {
272            let _ = command_tx.send(crate::coord::Message::StagedBatches {
273                conn_id,
274                table_id: target_id,
275                batches,
276            });
277        });
278        // Stash the execute context so we can cancel the COPY.
279        let conn_id = ctx.session().conn_id().clone();
280        self.active_copies.insert(
281            conn_id,
282            ActiveCopyFrom {
283                ingestion_id,
284                cluster_id,
285                table_id: target_id,
286                ctx,
287            },
288        );
289
290        let _result = self
291            .controller
292            .storage
293            .create_oneshot_ingestion(ingestion_id, collection_id, cluster_id, request, closure)
294            .await;
295    }
296
297    /// Sets up a streaming COPY FROM STDIN operation.
298    ///
299    /// Spawns N parallel background batch builder tasks that each receive
300    /// raw byte chunks, decode them, apply column defaults/reordering,
301    /// and build persist batches. Returns a [`CopyFromStdinWriter`] for
302    /// pgwire to distribute raw byte chunks across the workers.
303    pub(crate) fn setup_copy_from_stdin(
304        &self,
305        session: &Session,
306        target_id: CatalogItemId,
307        target_name: String,
308        columns: Vec<ColumnIndex>,
309        row_desc: RelationDesc,
310        params: CopyFormatParams<'static>,
311    ) -> Result<CopyFromStdinWriter, AdapterError> {
312        // Look up the table and its persist shard metadata.
313        let Some(entry) = self.catalog().try_get_entry(&target_id) else {
314            return Err(AdapterError::ConcurrentDependencyDrop {
315                dependency_kind: "table",
316                dependency_id: target_id.to_string(),
317            });
318        };
319        let Some(dest_table) = entry.table() else {
320            let typ = entry.item().typ();
321            return Err(AdapterError::Unstructured(anyhow::anyhow!(
322                "programming error: expected a Table found {typ:?}"
323            )));
324        };
325        let collection_id = dest_table.global_id_writes();
326
327        let collection_meta = self
328            .controller
329            .storage
330            .collection_metadata(collection_id)
331            .map_err(|e| AdapterError::Unstructured(anyhow::anyhow!("{e}")))?;
332        let shard_id = collection_meta.data_shard;
333        let collection_desc = collection_meta.relation_desc.clone();
334
335        // Pre-compute the column transformation.
336        let pcx = session.pcx().clone();
337        let session_meta = session.meta();
338        let catalog = self.catalog().clone();
339        let conn_catalog = catalog.for_session(session);
340        let catalog_state = conn_catalog.state();
341        let optimizer_config = optimize::OptimizerConfig::from(conn_catalog.system_vars());
342
343        // Determine if we need column rewriting (defaults/reordering).
344        let target_desc = catalog
345            .try_get_entry(&target_id)
346            .expect("table must exist")
347            .relation_desc_latest()
348            .expect("table has desc")
349            .into_owned();
350        let all_columns_in_order = columns.len() == target_desc.arity()
351            && columns.iter().enumerate().all(|(i, c)| c.to_raw() == i);
352
353        // If we need column rewriting, pre-compute the transform by running
354        // plan_copy_from with a single dummy row through the optimizer.
355        let column_transform = if all_columns_in_order {
356            None
357        } else {
358            let dummy_datums: Vec<Datum> = columns.iter().map(|_| Datum::Null).collect();
359            let dummy_row = Row::pack(&dummy_datums);
360
361            let prep = ExprPrepOneShot {
362                logical_time: EvalTime::NotAvailable,
363                session: &session_meta,
364                catalog_state,
365            };
366            let mut optimizer = optimize::view::Optimizer::new_with_prep_no_limit(
367                optimizer_config.clone(),
368                None,
369                prep,
370            );
371
372            let hir = mz_sql::plan::plan_copy_from(
373                &pcx,
374                &conn_catalog,
375                target_id,
376                target_name.clone(),
377                columns.clone(),
378                vec![dummy_row],
379            )?;
380            let mir = optimize::Optimize::optimize(&mut optimizer, hir)?;
381            let mir_expr = mir.into_inner();
382            let (result_ref, _) = mir_expr
383                .as_const()
384                .expect("optimizer should produce constant");
385            let result_rows = result_ref
386                .clone()
387                .map_err(|e| AdapterError::Unstructured(anyhow::anyhow!("eval error: {e}")))?;
388
389            let (full_row, _) = result_rows.into_iter().next().expect("should have one row");
390            let full_datums: Vec<Datum> = full_row.unpack();
391
392            let col_to_source: std::collections::BTreeMap<ColumnIndex, usize> =
393                columns.iter().enumerate().map(|(a, b)| (*b, a)).collect();
394
395            let mut sources: Vec<ColumnSource> = Vec::with_capacity(target_desc.arity());
396            let mut default_datums: Vec<Datum> = Vec::new();
397
398            for i in 0..target_desc.arity() {
399                let col_idx = ColumnIndex::from_raw(i);
400                if let Some(&src_idx) = col_to_source.get(&col_idx) {
401                    sources.push(ColumnSource::Input(src_idx));
402                } else {
403                    sources.push(ColumnSource::Default(default_datums.len()));
404                    default_datums.push(full_datums[i]);
405                }
406            }
407
408            let defaults_row = Row::pack(&default_datums);
409
410            Some(ColumnTransform {
411                sources,
412                defaults_row,
413            })
414        };
415
416        // Compute column types for decoding (same logic as pgwire used to do).
417        let column_types: Arc<[mz_pgrepr::Type]> = row_desc
418            .typ()
419            .column_types
420            .iter()
421            .map(|x| &x.scalar_type)
422            .map(mz_pgrepr::Type::from)
423            .collect::<Vec<_>>()
424            .into();
425
426        // Determine number of parallel workers, capped so that a single COPY
427        // cannot reserve an unbounded share of the shared blocking pool.
428        let num_workers = std::cmp::min(
429            std::thread::available_parallelism()
430                .map(|n| n.get())
431                .unwrap_or(1),
432            COPY_FROM_STDIN_MAX_WORKERS,
433        );
434        tracing::info!(
435            %target_id, num_workers,
436            "starting parallel COPY FROM STDIN batch builders"
437        );
438
439        // Shared state across workers.
440        let column_transform = Arc::new(column_transform);
441        let target_desc = Arc::new(target_desc);
442        let collection_desc = Arc::new(collection_desc);
443        let persist_client = self.persist_client.clone();
444
445        // Create per-worker channels and spawn one async task per worker. Each
446        // worker offloads the CPU-intensive processing of a chunk (decode plus
447        // the per-row transform/constraint-check/columnar encode) to the
448        // blocking pool for the duration of that chunk (see
449        // `copy_from_stdin_batch_builder`), so workers run in parallel while
450        // doing CPU work but hold no thread while idle between chunks.
451        let mut batch_txs = Vec::with_capacity(num_workers);
452        let mut worker_handles = Vec::with_capacity(num_workers);
453
454        // When COPY FROM uses CSV with HEADER, only the very first chunk in
455        // the stream contains the real header line. The pgwire handler splits
456        // data into ~32MB chunks distributed round-robin across workers, so
457        // subsequent chunks' first rows are data, not headers. We must only
458        // skip the header on the first chunk of worker 0.
459        let first_chunk_has_header = params.requires_header();
460        let mut worker_params = params;
461        if let CopyFormatParams::Csv(ref mut csv) = worker_params {
462            csv.header = false;
463        }
464
465        for worker_id in 0..num_workers {
466            // Keep in-flight buffering tight: at most one chunk queued per
467            // worker in addition to the currently-processed chunk.
468            let (batch_tx, batch_rx) = mpsc::channel::<Vec<u8>>(1);
469            batch_txs.push(batch_tx);
470
471            let persist_client = persist_client.clone();
472            let column_types = Arc::clone(&column_types);
473            let column_transform = Arc::clone(&column_transform);
474            let target_desc = Arc::clone(&target_desc);
475            let collection_desc = Arc::clone(&collection_desc);
476            let params = worker_params.clone();
477            // Only worker 0 receives the first chunk (round-robin), so only
478            // it needs to skip the CSV header on its first chunk.
479            let skip_header_on_first_chunk = worker_id == 0 && first_chunk_has_header;
480
481            let handle = mz_ore::task::spawn(
482                || format!("copy_from_stdin_worker:{target_id}:{worker_id}"),
483                Self::copy_from_stdin_batch_builder(
484                    persist_client,
485                    shard_id,
486                    collection_id,
487                    collection_desc,
488                    target_desc,
489                    column_transform,
490                    column_types,
491                    params,
492                    skip_header_on_first_chunk,
493                    batch_rx,
494                ),
495            );
496            worker_handles.push(handle);
497        }
498
499        // Spawn a collector task that waits for all workers.
500        let (completion_tx, completion_rx) = oneshot::channel();
501        mz_ore::task::spawn(
502            || format!("copy_from_stdin_collector:{target_id}"),
503            async move {
504                let mut all_batches = Vec::with_capacity(num_workers);
505                let mut total_rows: u64 = 0;
506
507                for handle in worker_handles {
508                    match handle.await {
509                        Ok((proto_batches, count)) => {
510                            all_batches.extend(proto_batches);
511                            total_rows += count;
512                        }
513                        Err(e) => {
514                            let _ = completion_tx.send(Err(e));
515                            return;
516                        }
517                    }
518                }
519
520                let _ = completion_tx.send(Ok((all_batches, total_rows)));
521            },
522        );
523
524        Ok(CopyFromStdinWriter {
525            batch_txs,
526            completion_rx,
527        })
528    }
529
530    /// Background task: receives raw byte chunks, decodes rows, and builds
531    /// persist batches. One instance runs per parallel worker.
532    async fn copy_from_stdin_batch_builder(
533        persist_client: mz_persist_client::PersistClient,
534        shard_id: mz_persist_client::ShardId,
535        collection_id: mz_repr::GlobalId,
536        collection_desc: Arc<RelationDesc>,
537        target_desc: Arc<RelationDesc>,
538        column_transform: Arc<Option<ColumnTransform>>,
539        column_types: Arc<[mz_pgrepr::Type]>,
540        params: CopyFormatParams<'static>,
541        skip_header_on_first_chunk: bool,
542        mut batch_rx: mpsc::Receiver<Vec<u8>>,
543    ) -> Result<(Vec<ProtoBatch>, u64), AdapterError> {
544        let persist_diagnostics = Diagnostics {
545            shard_name: collection_id.to_string(),
546            handle_purpose: "CopyFromStdin::batch_builder".to_string(),
547        };
548        let write_handle = persist_client
549            .open_writer::<SourceData, (), mz_repr::Timestamp, StorageDiff>(
550                shard_id,
551                collection_desc,
552                Arc::new(UnitSchema),
553                persist_diagnostics,
554            )
555            .await
556            .map_err(|e| AdapterError::Unstructured(anyhow::anyhow!("persist open: {e}")))?;
557
558        // Build a batch at the minimum timestamp. The coordinator will
559        // re-timestamp it during commit.
560        let lower = mz_repr::Timestamp::MIN;
561        let upper = Antichain::from_elem(lower.step_forward());
562        let mut batch_builder = write_handle.builder(Antichain::from_elem(lower));
563        let mut row_count: u64 = 0;
564        let mut row_count_in_batch: u64 = 0;
565        let mut batch_bytes: usize = 0;
566        let mut proto_batches = Vec::new();
567
568        let rt = tokio::runtime::Handle::current();
569        let mut is_first_chunk = true;
570        while let Some(raw_bytes) = batch_rx.recv().await {
571            // For the first chunk of worker 0, re-enable header skipping so the
572            // real CSV header line is skipped.
573            let chunk_params = if is_first_chunk && skip_header_on_first_chunk {
574                let mut p = params.clone();
575                if let CopyFormatParams::Csv(ref mut csv) = p {
576                    csv.header = true;
577                }
578                p
579            } else {
580                params.clone()
581            };
582            is_first_chunk = false;
583            let raw_len = raw_bytes.len();
584
585            // Offload the entire CPU-bound per-chunk pipeline -- decode, column
586            // transform, constraint checks, and the columnar persist encode
587            // (`BatchBuilder::add` -> `PartBuilder::push`) -- to the blocking
588            // pool. There is no yield point in the row loop until a batch fills
589            // (`add` only awaits `flush_part`, and only once an *encoded* part
590            // reaches `blob_target_size`, far beyond the 32 MiB *raw* batch
591            // boundary), so left on the async runtime each chunk's rows would
592            // run as one uninterrupted burst on a shared runtime worker thread,
593            // starving other connections. The blocking thread is held only
594            // while a chunk is in flight and released back to the pool between
595            // chunks (during `recv().await`), so idle workers still hold no
596            // thread. `block_on` is invoked once per chunk -- not per row -- to
597            // drive the row loop and the rare `flush_part` it may await.
598            let chunk_column_types = Arc::clone(&column_types);
599            let chunk_transform = Arc::clone(&column_transform);
600            let chunk_target_desc = Arc::clone(&target_desc);
601            let chunk_rt = rt.clone();
602            let (returned_builder, added_rows) = mz_ore::task::spawn_blocking(
603                || "copy_from_stdin_process_chunk",
604                move || {
605                    let rows = mz_pgcopy::decode_copy_format(
606                        &raw_bytes,
607                        &chunk_column_types,
608                        chunk_params,
609                    )
610                    .map_err(|e| AdapterError::CopyFormatError(e.to_string()))?;
611
612                    chunk_rt.block_on(async move {
613                        let mut added: u64 = 0;
614                        for row in rows {
615                            // Apply column transform if needed (add defaults, reorder).
616                            let full_row = if let Some(ref transform) = *chunk_transform {
617                                transform.apply(&row)
618                            } else {
619                                row
620                            };
621
622                            // Check constraints.
623                            for (i, datum) in full_row.iter().enumerate() {
624                                chunk_target_desc.constraints_met(i, &datum).map_err(|e| {
625                                    AdapterError::Unstructured(anyhow::anyhow!(
626                                        "constraint violation: {e}"
627                                    ))
628                                })?;
629                            }
630
631                            let data = SourceData(Ok(full_row));
632                            batch_builder
633                                .add(&data, &(), &lower, &1)
634                                .await
635                                .map_err(|e| {
636                                    AdapterError::Unstructured(anyhow::anyhow!("persist add: {e}"))
637                                })?;
638                            added += 1;
639                        }
640                        Ok::<_, AdapterError>((batch_builder, added))
641                    })
642                },
643            )
644            .await?;
645            batch_builder = returned_builder;
646            row_count += added_rows;
647            row_count_in_batch += added_rows;
648
649            batch_bytes = batch_bytes.saturating_add(raw_len);
650            if batch_bytes >= COPY_FROM_STDIN_MAX_BATCH_BYTES {
651                let batch = batch_builder.finish(upper.clone()).await.map_err(|e| {
652                    AdapterError::Unstructured(anyhow::anyhow!("persist finish: {e}"))
653                })?;
654                proto_batches.push(batch.into_transmittable_batch());
655
656                batch_builder = write_handle.builder(Antichain::from_elem(lower));
657                row_count_in_batch = 0;
658                batch_bytes = 0;
659            }
660        }
661
662        if row_count_in_batch > 0 || proto_batches.is_empty() {
663            let batch = batch_builder
664                .finish(upper)
665                .await
666                .map_err(|e| AdapterError::Unstructured(anyhow::anyhow!("persist finish: {e}")))?;
667            proto_batches.push(batch.into_transmittable_batch());
668        }
669
670        Ok((proto_batches, row_count))
671    }
672
673    pub(crate) fn commit_staged_batches(
674        &mut self,
675        conn_id: ConnectionId,
676        table_id: CatalogItemId,
677        batches: Vec<Result<ProtoBatch, String>>,
678    ) {
679        let Some(active_copy) = self.active_copies.remove(&conn_id) else {
680            // Getting a successful response for a cancel COPY FROM is unexpected.
681            tracing::warn!(%conn_id, ?batches, "got response for canceled COPY FROM");
682            return;
683        };
684
685        let ActiveCopyFrom {
686            ingestion_id,
687            cluster_id: _,
688            table_id: _,
689            mut ctx,
690        } = active_copy;
691        tracing::info!(%ingestion_id, num_batches = ?batches.len(), "received batches to append");
692
693        let mut all_batches = SmallVec::with_capacity(batches.len());
694        let mut all_errors = SmallVec::<[String; 1]>::with_capacity(batches.len());
695        let mut row_count = 0u64;
696
697        for maybe_batch in batches {
698            match maybe_batch {
699                Ok(batch) => {
700                    let count = batch.batch.as_ref().map(|b| b.len).unwrap_or(0);
701                    all_batches.push(batch);
702                    row_count = row_count.saturating_add(count);
703                }
704                Err(err) => all_errors.push(err),
705            }
706        }
707
708        // If we got any errors we need to fail the whole operation.
709        if let Some(error) = all_errors.pop() {
710            tracing::warn!(?error, ?all_errors, "failed COPY FROM");
711
712            // TODO(cf1): Cleanup the existing ProtoBatches to prevent leaking them.
713            // TODO(cf2): Carry structured errors all the way through.
714
715            ctx.retire(Err(AdapterError::Unstructured(anyhow::anyhow!(
716                "COPY FROM: {error}"
717            ))));
718
719            return;
720        }
721
722        // Stage a WriteOp, then when the Session is retired we complete the
723        // transaction, which handles acquiring the write lock for `table_id`,
724        // advancing the timestamps of the staged batches, and waiting for
725        // everything to complete before sending a response to the client.
726        let stage_write = ctx
727            .session_mut()
728            .add_transaction_ops(TransactionOps::Writes(vec![WriteOp {
729                id: table_id,
730                rows: TableData::Batches(all_batches),
731            }]));
732
733        if let Err(err) = stage_write {
734            ctx.retire(Err(err));
735        } else {
736            ctx.retire(Ok(ExecuteResponse::Copied(row_count.cast_into())));
737        }
738    }
739
740    /// Cancel any active `COPY FROM` statements/oneshot ingestions.
741    #[mz_ore::instrument(level = "debug")]
742    pub(crate) fn cancel_pending_copy(&mut self, conn_id: &ConnectionId) {
743        if let Some(ActiveCopyFrom {
744            ingestion_id,
745            cluster_id: _,
746            table_id: _,
747            ctx,
748        }) = self.active_copies.remove(conn_id)
749        {
750            let cancel_result = self
751                .controller
752                .storage
753                .cancel_oneshot_ingestion(ingestion_id);
754            if let Err(err) = cancel_result {
755                tracing::error!(?err, "failed to cancel OneshotIngestion");
756            }
757
758            ctx.retire(Err(AdapterError::Canceled));
759        }
760    }
761}
762
763/// Describes how to transform a partial row (with only specified columns)
764/// into a full row matching the table schema.
765struct ColumnTransform {
766    /// For each column in the target table, where to get the value.
767    sources: Vec<ColumnSource>,
768    /// Pre-computed default values for columns not in the COPY column list.
769    /// Packed as a Row; indexed by the `Default(idx)` variant.
770    defaults_row: Row,
771}
772
773enum ColumnSource {
774    /// Take the value from the input row at this position.
775    Input(usize),
776    /// Use the pre-computed default at this index in `defaults_row`.
777    Default(usize),
778}
779
780impl ColumnTransform {
781    /// Apply the transform to produce a full row from a partial input row.
782    fn apply(&self, input: &Row) -> Row {
783        let input_datums: Vec<Datum> = input.unpack();
784        let default_datums: Vec<Datum> = self.defaults_row.unpack();
785        let mut output_datums = Vec::with_capacity(self.sources.len());
786        for source in &self.sources {
787            match source {
788                ColumnSource::Input(idx) => output_datums.push(input_datums[*idx]),
789                ColumnSource::Default(idx) => output_datums.push(default_datums[*idx]),
790            }
791        }
792        Row::pack(&output_datums)
793    }
794}