mz_storage/source/postgres/snapshot.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Renders the table snapshot side of the [`PostgresSourceConnection`] ingestion dataflow.
//!
//! # Snapshot reading
//!
//! Depending on the resumption LSNs the table reader decides which tables need to be
//! snapshotted. Each table is partitioned across all workers using PostgreSQL's `ctid` (tuple
//! identifier) column, which identifies the physical location of each row. This allows parallel
//! snapshotting of large tables across all available workers.
//!
//! There are a few subtle points about this operation, described in the following sections.
//!
//! ## Consistent LSN point for snapshot transactions
//!
//! Given that all our ingestion is based on correctly timestamping updates with the LSN they
//! happened at, it is important that we run the `COPY` query at a specific LSN point that is
//! relatable to the LSN numbers we receive from the replication stream. Such a point does not
//! necessarily exist for a normal SQL transaction. To achieve this we must force postgres to
//! produce a consistent point and let us know its LSN by creating a replication slot as the
//! first statement in a transaction.
//!
//! This is a temporary dummy slot that is only used to put our snapshot transaction on a
//! consistent LSN point. Unfortunately no lighter-weight method exists for doing this. See this
//! [postgres thread] for more details.
//!
//! One might wonder why we don't use the actual real slot to provide us with the snapshot
//! point, which would automatically be at the correct LSN. The answer is that it's possible
//! that we crash and restart after having already created the slot but before having finished
//! the snapshot. In that case the restarting process will have lost its opportunity to run
//! queries at the slot's consistent point, as that opportunity only exists in the ephemeral
//! transaction that created the slot, and that is long gone. Additionally there are good
//! reasons why we'd like to move the slot creation much earlier, e.g. during purification, in
//! which case the slot will always be pre-created.
//!
//! [postgres thread]: https://www.postgresql.org/message-id/flat/CAMN0T-vzzNy6TV1Jvh4xzNQdAvCLBQK_kh6_U7kAXgGU3ZFg-Q%40mail.gmail.com
//!
//! ## Reusing the consistent point among all workers
//!
//! Creating replication slots is potentially expensive, so the code makes it such that all
//! workers cooperate and reuse one consistent snapshot among them. In order to do so we make
//! use of the "export transaction" feature of postgres. This feature allows one SQL session to
//! create a string identifier for the transaction it is currently in, which can be used by
//! other sessions to enter the same "snapshot".
//!
//! We accomplish this by picking one worker at random to function as the transaction leader.
//! The transaction leader is responsible for starting a SQL session, creating a temporary
//! replication slot in a transaction, exporting the transaction id, and broadcasting the
//! transaction information to all other workers via a broadcasted feedback edge.
//!
//! During this phase the follower workers are simply waiting to hear on the feedback edge,
//! effectively synchronizing with the leader. Once all workers have received the snapshot
//! information they can all start to perform their assigned COPY queries.
//!
//! The leader and follower steps described above are accomplished by the [`export_snapshot`]
//! and [`use_snapshot`] functions respectively.
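//!
//! In SQL terms the exchange looks roughly like the following sketch (the statements mirror
//! [`export_snapshot`] and [`use_snapshot`] below; the slot name and the response columns shown
//! are abbreviations, not verbatim protocol output):
//!
//! ```text
//! -- leader session:
//! BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;
//! CREATE_REPLICATION_SLOT "mzsnapshot_..." TEMPORARY LOGICAL "pgoutput" USE_SNAPSHOT;
//! --> returns (slot_name, consistent_point, snapshot_name, ...)
//!
//! -- each follower session, after receiving snapshot_name via the feedback edge:
//! BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;
//! SET TRANSACTION SNAPSHOT '<snapshot_name>';
//! ```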
//!
//! ## Coordinated transaction COMMIT
//!
//! When follower workers are done with snapshotting they commit their transaction, close their
//! session, and then drop their snapshot feedback capability. When the leader worker is done
//! with snapshotting it drops its snapshot feedback capability and waits until it observes the
//! snapshot input advancing to the empty frontier. This allows the leader to COMMIT its
//! transaction last, which is the transaction that exported the snapshot.
//!
//! It's unclear if this is strictly necessary, but having the frontiers made it easy enough
//! that I added the synchronization.
//!
//! ## Snapshot rewinding
//!
//! Ingestion dataflows must produce definite data, including the snapshot. What this means
//! practically is that whenever we deem it necessary to snapshot a table we must do so at the
//! same LSN. However, the method for running a transaction described above doesn't let us
//! choose the LSN; it could be an LSN in the future chosen by PostgreSQL while it creates the
//! temporary replication slot.
//!
//! The definition of differential collections states that a collection at some time
//! `t_snapshot` is defined to be the accumulation of all updates that happen at
//! `t <= t_snapshot`, where `<=` is the partial order. In this case we are faced with the
//! problem of knowing the state of a table at `t_snapshot` but actually wanting to know the
//! snapshot at `t_slot <= t_snapshot`.
//!
//! From the definition we can see that the snapshot at `t_slot` is related to the snapshot at
//! `t_snapshot` with the following equations:
//!
//! ```text
//! sum(update: t <= t_snapshot) = sum(update: t <= t_slot) + sum(update: t_slot <= t <= t_snapshot)
//!                                       |
//!                                       V
//! sum(update: t <= t_slot) = sum(update: t <= t_snapshot) - sum(update: t_slot <= t <= t_snapshot)
//! ```
//!
//! Therefore, if we manage to recover the `sum(update: t_slot <= t <= t_snapshot)` term we will
//! be able to "rewind" the snapshot we obtained at `t_snapshot` to `t_slot` by emitting all
//! updates that happen between these two points with their diffs negated.
//!
//! It turns out that this term is exactly what the main replication slot provides us with, so
//! we can rewind snapshots taken at arbitrary points! In order to do this the snapshot dataflow
//! emits rewind requests to the replication reader which inform it that a certain range of
//! updates must be emitted at LSN 0 (by convention) with their diffs negated. These negated
//! diffs are consolidated with the diffs taken at `t_snapshot` that were also emitted at LSN 0
//! (by convention) and we end up with a TVC that at LSN 0 contains the snapshot at `t_slot`.
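//!
//! As a hypothetical worked example (the key, values, and LSNs below are made up), suppose the
//! slot's consistent point is LSN 100, the `COPY` query runs at LSN 130, and a single row
//! changed from `v1` to `v2` at LSN 120:
//!
//! ```text
//! snapshot emits:   (k, v2) at LSN 0 with diff +1     -- state at t_snapshot = 130
//! rewind emits:     (k, v2) at LSN 0 with diff -1     -- the LSN 120 update, negated
//!                   (k, v1) at LSN 0 with diff +1
//! consolidated:     (k, v1) at LSN 0 with diff +1     -- state at t_slot = 100
//! ```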
//!
//! # Parallel table snapshotting with ctid ranges
//!
//! Each table is partitioned across workers using PostgreSQL's `ctid` column. The `ctid` is a
//! tuple identifier of the form `(block_number, tuple_index)` that represents the physical
//! location of a row on disk. By partitioning the ctid range, each worker can independently
//! fetch a portion of the table.
//!
//! The partitioning works as follows:
//! 1. The snapshot leader queries `pg_class.relpages` to estimate the number of blocks for each
//!    table. This is much faster than querying `max(ctid)`, which would require a sequential
//!    scan.
//! 2. The leader broadcasts the block count estimates along with the snapshot transaction ID
//!    to all workers, ensuring all workers use consistent estimates for partitioning.
//! 3. Each worker calculates its assigned block range and fetches rows using a `COPY` query
//!    with a `SELECT` that filters by `ctid >= start AND ctid < end`.
//! 4. The last worker uses an open-ended range (`ctid >= start`) to capture any rows beyond
//!    the estimated block count (this handles cases where statistics are stale or the table has
//!    grown).
//!
//! This approach efficiently parallelizes large table snapshots while maintaining the benefits
//! of the `COPY` protocol for bulk data transfer.
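//!
//! As an illustrative sketch (the schema, column names, and block numbers are hypothetical),
//! the query issued by a non-final worker assigned blocks `[1000, 2000)` has the shape:
//!
//! ```text
//! COPY (
//!     SELECT "a", "b" FROM "public"."t"
//!     WHERE ctid >= '(1000,0)'::tid AND ctid < '(2000,0)'::tid
//! ) TO STDOUT (FORMAT TEXT, DELIMITER '\t')
//! ```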
//!
//! ## PostgreSQL version requirements
//!
//! Ctid range scans are only efficient on PostgreSQL >= 14 due to TID range scan optimizations
//! introduced in that version. For older PostgreSQL versions, the snapshot falls back to the
//! single-worker-per-table mode where each table is assigned to one worker based on consistent
//! hashing. This is implemented by having the leader broadcast all-zero block counts when
//! PostgreSQL version < 14.
//!
//! # Snapshot decoding
//!
//! Each worker fetches its ctid range directly and decodes the COPY stream locally.
//!
//! ```text
//!                 ╭──────────────────╮
//!    ┏━━━━━━━━━━━━v━┓                │ exported
//!    ┃    table     ┃   ╭─────────╮  │ snapshot id
//!    ┃   readers    ┠─>─┤broadcast├──╯
//!    ┃  (parallel)  ┃   ╰─────────╯
//!    ┗━┯━━━━━━━━━━┯━┛
//!   raw│          │
//!  COPY│          │
//!  data│          │
//! ┏━━━━┷━━━━┓     │
//! ┃  COPY   ┃     │
//! ┃ decoder ┃     │
//! ┗━━━━┯━━━━┛     │
//!      │ snapshot │rewind
//!      │ updates  │requests
//!      v          v
//! ```

use std::collections::BTreeMap;
use std::convert::Infallible;
use std::pin::pin;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;

use anyhow::bail;
use differential_dataflow::AsCollection;
use futures::{StreamExt as _, TryStreamExt};
use itertools::Itertools;
use mz_ore::cast::CastFrom;
use mz_ore::future::InTask;
use mz_postgres_util::desc::PostgresTableDesc;
use mz_postgres_util::schemas::get_pg_major_version;
use mz_postgres_util::{Client, Config, PostgresError, simple_query_opt};
use mz_repr::{Datum, DatumVec, Diff, Row};
use mz_sql_parser::ast::{
    Ident,
    display::{AstDisplay, escaped_string_literal},
};
use mz_storage_types::connections::ConnectionContext;
use mz_storage_types::errors::DataflowError;
use mz_storage_types::parameters::PgSourceSnapshotConfig;
use mz_storage_types::sources::{MzOffset, PostgresSourceConnection};
use mz_timely_util::builder_async::{
    Event as AsyncEvent, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use timely::container::CapacityContainerBuilder;
use timely::container::DrainContainer;
use timely::dataflow::channels::pact::Pipeline;
use timely::dataflow::operators::core::Map;
use timely::dataflow::operators::vec::Broadcast;
use timely::dataflow::operators::{CapabilitySet, Concat, ConnectLoop, Feedback, Operator};
use timely::dataflow::{Scope, StreamVec};
use timely::progress::Timestamp;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::{Oid, PgLsn};
use tracing::trace;

use crate::metrics::source::postgres::PgSnapshotMetrics;
use crate::source::RawSourceCreationConfig;
use crate::source::postgres::replication::RewindRequest;
use crate::source::postgres::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError, verify_schema,
};
use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};
use crate::statistics::SourceStatistics;

/// Information broadcasted from the snapshot leader to all workers.
/// This includes the transaction snapshot ID, LSN, and estimated block counts for each table.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct SnapshotInfo {
    /// The exported transaction snapshot identifier.
    snapshot_id: String,
    /// The LSN at which the snapshot was taken.
    snapshot_lsn: MzOffset,
    /// Estimated number of blocks (pages) for each table, keyed by OID.
    /// This is derived from `pg_class.relpages` and used to partition ctid ranges.
    table_block_counts: BTreeMap<u32, u64>,
    /// The current upstream schema of each table.
    upstream_info: BTreeMap<u32, PostgresTableDesc>,
}

/// Represents a ctid range that a worker should snapshot.
/// The range is `[start_block, end_block)` where `end_block` is optional (`None` means
/// unbounded).
#[derive(Debug)]
struct CtidRange {
    /// The starting block number (inclusive).
    start_block: u64,
    /// The ending block number (exclusive). `None` means unbounded (open-ended range).
    end_block: Option<u64>,
}

/// Calculate the ctid range for a given worker based on estimated block count.
///
/// The table is partitioned by block number across all workers. Each worker gets a contiguous
/// range of blocks. The last worker gets an open-ended range to handle any rows beyond the
/// estimated block count.
///
/// When `estimated_blocks` is 0 (either because statistics are unavailable, the table appears
/// empty, or PostgreSQL version < 14 doesn't support ctid range scans), the table is assigned
/// to a single worker determined by `config.responsible_for(oid)` and that worker scans the
/// full table.
///
/// Returns `None` if this worker has no work to do.
fn worker_ctid_range(
    config: &RawSourceCreationConfig,
    estimated_blocks: u64,
    oid: u32,
) -> Option<CtidRange> {
    // If estimated_blocks is 0, fall back to single-worker mode for this table.
    // This handles:
    // - PostgreSQL < 14 (ctid range scans not supported)
    // - Tables that appear empty in statistics
    // - Tables with stale/missing statistics
    // The responsible worker scans the full table with an open-ended range.
    if estimated_blocks == 0 {
        let fallback = if config.responsible_for(oid) {
            Some(CtidRange {
                start_block: 0,
                end_block: None,
            })
        } else {
            None
        };
        return fallback;
    }

    let worker_id = u64::cast_from(config.worker_id);
    let worker_count = u64::cast_from(config.worker_count);

    // If there are more workers than blocks, only assign work to workers with
    // id < estimated_blocks. The last assigned worker still gets an open range.
    let effective_worker_count = std::cmp::min(worker_count, estimated_blocks);

    if worker_id >= effective_worker_count {
        // This worker has no work to do.
        return None;
    }

    // Calculate the start block for this worker (integer division distributes blocks evenly).
    let start_block = worker_id * estimated_blocks / effective_worker_count;

    // The last effective worker gets an open-ended range.
    let is_last_effective_worker = worker_id == effective_worker_count - 1;
    if is_last_effective_worker {
        Some(CtidRange {
            start_block,
            end_block: None,
        })
    } else {
        let end_block = (worker_id + 1) * estimated_blocks / effective_worker_count;
        Some(CtidRange {
            start_block,
            end_block: Some(end_block),
        })
    }
}

/// Estimate the number of blocks for each table from `pg_class` statistics.
/// This is used to partition ctid ranges across workers.
async fn estimate_table_block_counts(
    client: &Client,
    table_oids: &[u32],
) -> Result<BTreeMap<u32, u64>, TransientError> {
    if table_oids.is_empty() {
        return Ok(BTreeMap::new());
    }

    // Query relpages for all tables at once.
    let oid_list = table_oids
        .iter()
        .map(|oid| oid.to_string())
        .collect::<Vec<_>>()
        .join(",");
    let query = format!(
        "SELECT oid, relpages FROM pg_class WHERE oid IN ({})",
        oid_list
    );

    let mut block_counts = BTreeMap::new();
    // Initialize all tables with 0 blocks (in case they're not in pg_class).
    for &oid in table_oids {
        block_counts.insert(oid, 0);
    }

    // Execute the query and collect the results.
    let rows = client.simple_query(&query).await?;
    for msg in rows {
        if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
            let oid: u32 = row.get("oid").unwrap().parse().unwrap();
            let relpages: i64 = row.get("relpages").unwrap().parse().unwrap_or(0);
            // relpages can be -1 if the table was never analyzed; treat that as 0.
            let relpages = std::cmp::max(0, relpages).try_into().unwrap();
            block_counts.insert(oid, relpages);
        }
    }

    Ok(block_counts)
}

/// Renders the snapshot dataflow. See the module documentation for more information.
pub(crate) fn render<G: Scope<Timestamp = MzOffset>>(
    mut scope: G,
    config: RawSourceCreationConfig,
    connection: PostgresSourceConnection,
    table_info: BTreeMap<u32, BTreeMap<usize, SourceOutputInfo>>,
    metrics: PgSnapshotMetrics,
) -> (
    StackedCollection<G, (usize, Result<SourceMessage, DataflowError>)>,
    StreamVec<G, RewindRequest>,
    StreamVec<G, Infallible>,
    StreamVec<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("TableReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope.clone());

    let (feedback_handle, feedback_data) = scope.feedback(Default::default());

    let (raw_handle, raw_data) = builder.new_output();
    let (rewinds_handle, rewinds) = builder.new_output::<CapacityContainerBuilder<_>>();
    // This output is used to signal to the replication operator that the replication slot has
    // been created. With the current state of execution serialization there isn't a lot of
    // benefit in splitting the snapshot and replication phases into two operators.
    // TODO(petrosagg): merge the two operators in one (while still maintaining separation as
    // functions/modules)
    let (_, slot_ready) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (snapshot_handle, snapshot) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    // This operator needs to broadcast data to itself in order to synchronize the transaction
    // snapshot. However, none of the feedback capabilities result in output messages and for
    // the feedback edge specifically, having a default connection would result in a loop.
    let mut snapshot_input = builder.new_disconnected_input(feedback_data, Pipeline);

    // The export id must be sent to all workers, so we broadcast the feedback connection.
    snapshot.broadcast().connect_loop(feedback_handle);

    let is_snapshot_leader = config.responsible_for("snapshot_leader");

    // A global view of all outputs that will be snapshot by all workers.
    let mut all_outputs = vec![];
    // Table info for tables that need snapshotting. All workers will snapshot all tables,
    // but each worker will handle a different ctid range within each table.
    let mut tables_to_snapshot = BTreeMap::new();
    // A collection of `SourceStatistics` to update for a given Oid. The same info exists in
    // `table_info`, but this avoids having to iterate + map each time the statistics are
    // needed.
    let mut export_statistics = BTreeMap::new();
    for (table, outputs) in table_info.iter() {
        for (&output_index, output) in outputs {
            if *output.resume_upper != [MzOffset::minimum()] {
                // Already has been snapshotted.
                continue;
            }
            all_outputs.push(output_index);
            tables_to_snapshot
                .entry(*table)
                .or_insert_with(BTreeMap::new)
                .insert(output_index, output.clone());
            let statistics = config
                .statistics
                .get(&output.export_id)
                .expect("statistics are initialized")
                .clone();
            export_statistics.insert((*table, output_index), statistics);
        }
    }
    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let id = config.id;
            let worker_id = config.worker_id;
            let [
                data_cap_set,
                rewind_cap_set,
                slot_ready_cap_set,
                snapshot_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 5] = caps.try_into().unwrap();

            trace!(
                %id,
                "timely-{worker_id} initializing table reader \
                    with {} tables to snapshot",
                tables_to_snapshot.len()
            );

            let connection_config = connection
                .connection
                .config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;

            // The snapshot operator is responsible for creating the replication slot(s).
            // This first slot is the permanent slot that will be used for reading the
            // replication stream. A temporary slot is created further on to capture table
            // snapshots.
            let replication_client = if is_snapshot_leader {
                let client = connection_config
                    .connect_replication(&config.config.connection_context.ssh_tunnel_manager)
                    .await?;
                let main_slot = &connection.publication_details.slot;

                tracing::info!(%id, "ensuring replication slot {main_slot} exists");
                super::ensure_replication_slot(&client, main_slot).await?;
                Some(client)
            } else {
                None
            };
            *slot_ready_cap_set = CapabilitySet::new();

            // Nothing needs to be snapshot.
            if all_outputs.is_empty() {
                trace!(%id, "no exports to snapshot");
                // Note we do not emit a `ProgressStatisticsUpdate::Snapshot` update here,
                // as we do not want to attempt to override the current value with 0. We
                // just leave it null.
                return Ok(());
            }

            // A worker *must* emit a count even if not responsible for snapshotting a table,
            // as statistic summarization will return null if any worker hasn't set a value.
            // This will also reset snapshot stats for any exports not snapshotting.
            // If no workers need to snapshot, then avoid emitting these as they will clear
            // previous stats.
            for statistics in config.statistics.values() {
                statistics.set_snapshot_records_known(0);
                statistics.set_snapshot_records_staged(0);
            }
            // Collect table OIDs for block count estimation.
            let table_oids: Vec<u32> = tables_to_snapshot.keys().copied().collect();

            // The replication client is only set if this worker is the snapshot leader.
            let client = match replication_client {
                Some(client) => {
                    let tmp_slot =
                        format!("mzsnapshot_{}", uuid::Uuid::new_v4()).replace('-', "");
                    let (snapshot_id, snapshot_lsn) =
                        export_snapshot(&client, &tmp_slot, true).await?;

                    // Check the PostgreSQL version. Ctid range scans are only efficient on
                    // PG >= 14 due to improvements in TID range scan support.
                    let pg_version = get_pg_major_version(&client).await?;

                    // Estimate block counts for all tables from pg_class statistics.
                    // This must be done by the leader and broadcasted to ensure all workers
                    // use the same estimates for ctid range partitioning.
                    //
                    // For PostgreSQL < 14, we set all block counts to 0 to fall back to
                    // single-worker-per-table mode, as ctid range scans are not well
                    // supported.
                    let table_block_counts = if pg_version >= 14 {
                        estimate_table_block_counts(&client, &table_oids).await?
                    } else {
                        trace!(
                            %id,
                            "timely-{worker_id} PostgreSQL version {pg_version} < 14, \
                                falling back to single-worker-per-table snapshot mode"
                        );
                        // Return all zeros to trigger fallback mode.
                        table_oids.iter().map(|&oid| (oid, 0u64)).collect()
                    };

                    report_snapshot_size(
                        &client,
                        &tables_to_snapshot,
                        metrics,
                        &config,
                        &export_statistics,
                    )
                    .await?;

                    let upstream_info = {
                        // As part of retrieving the schema info, RLS policies are checked to
                        // ensure the snapshot can successfully read the tables. RLS policy
                        // errors are treated as transient, as the customer can simply add the
                        // BYPASSRLS attribute to the PG account used by MZ.
                        match retrieve_schema_info(
                            &connection_config,
                            &config.config.connection_context,
                            &connection.publication,
                            &table_oids,
                        )
                        .await
                        {
                            // If the replication stream cannot be obtained in a definite way
                            // there is nothing else to do. These errors are not retractable.
                            Err(PostgresError::PublicationMissing(publication)) => {
                                let err = DefiniteError::PublicationDropped(publication);
                                for (oid, outputs) in tables_to_snapshot.iter() {
                                    // Produce a definite error here and then exit to ensure
                                    // a missing publication doesn't generate a transient
                                    // error and restart this dataflow indefinitely.
                                    //
                                    // We pick `u64::MAX` as the LSN which will (in
                                    // practice) never conflict any previously revealed
                                    // portions of the TVC.
                                    for output_index in outputs.keys() {
                                        let update = (
                                            (*oid, *output_index, Err(err.clone().into())),
                                            MzOffset::from(u64::MAX),
                                            Diff::ONE,
                                        );
                                        raw_handle.give_fueled(&data_cap_set[0], update).await;
                                    }
                                }

                                definite_error_handle.give(
                                    &definite_error_cap_set[0],
                                    ReplicationError::Definite(Rc::new(err)),
                                );
                                return Ok(());
                            }
                            Err(e) => Err(TransientError::from(e))?,
                            Ok(i) => i,
                        }
                    };

                    let snapshot_info = SnapshotInfo {
                        snapshot_id,
                        snapshot_lsn,
                        upstream_info,
                        table_block_counts,
                    };
                    trace!(
                        %id,
                        "timely-{worker_id} exporting snapshot info {snapshot_info:?}"
                    );
                    snapshot_handle.give(&snapshot_cap_set[0], snapshot_info);

                    client
                }
                None => {
                    // Only the snapshot leader needs a replication connection.
                    let task_name = format!("timely-{worker_id} PG snapshotter");
                    connection_config
                        .connect(
                            &task_name,
                            &config.config.connection_context.ssh_tunnel_manager,
                        )
                        .await?
                }
            };

            // Configure statement_timeout based on param. We want to be able to
            // override the server value here in case it's set too low, relative
            // to the size of the data we need to copy.
            set_statement_timeout(
                &client,
                config
                    .config
                    .parameters
                    .pg_source_snapshot_statement_timeout,
            )
            .await?;

            let snapshot_info = loop {
                match snapshot_input.next().await {
                    Some(AsyncEvent::Data(_, mut data)) => {
                        break data.pop().expect("snapshot sent above");
                    }
                    Some(AsyncEvent::Progress(_)) => continue,
                    None => panic!("feedback closed before sending snapshot info"),
                }
            };
            let SnapshotInfo {
                snapshot_id,
                snapshot_lsn,
                table_block_counts,
                upstream_info,
            } = snapshot_info;

            // The snapshot leader is already inside the exported transaction, but all other
            // workers need to enter it.
            if !is_snapshot_leader {
                trace!(%id, "timely-{worker_id} using snapshot id {snapshot_id:?}");
                use_snapshot(&client, &snapshot_id).await?;
            }

            for (&oid, outputs) in tables_to_snapshot.iter() {
                for (&output_index, info) in outputs.iter() {
                    if let Err(err) = verify_schema(oid, info, &upstream_info) {
                        raw_handle
                            .give_fueled(
                                &data_cap_set[0],
                                (
                                    (oid, output_index, Err(err.into())),
                                    MzOffset::minimum(),
                                    Diff::ONE,
                                ),
                            )
                            .await;
                        continue;
                    }

                    // Get the estimated block count from the broadcasted table statistics.
                    let block_count = table_block_counts.get(&oid).copied().unwrap_or(0);

                    // Calculate this worker's ctid range based on estimated blocks.
                    // When estimated_blocks is 0 (PG < 14 or empty table), fall back to
                    // single-worker mode using responsible_for to pick the worker.
                    let Some(ctid_range) = worker_ctid_range(&config, block_count, oid) else {
                        // This worker has no work for this table (more workers than blocks).
                        trace!(
                            %id,
                            "timely-{worker_id} no ctid range assigned for table {:?}({oid})",
                            info.desc.name
                        );
                        continue;
                    };

                    trace!(
                        %id,
                        "timely-{worker_id} snapshotting table {:?}({oid}) output \
                            {output_index} @ {snapshot_lsn} with ctid range {:?}",
                        info.desc.name,
                        ctid_range
                    );

                    // To handle quoted/keyword names, we can use `Ident`'s AST printing,
                    // which emulates PG's rules for name formatting.
                    let namespace =
                        Ident::new_unchecked(&info.desc.namespace).to_ast_string_stable();
                    let table = Ident::new_unchecked(&info.desc.name).to_ast_string_stable();
                    let column_list = info
                        .desc
                        .columns
                        .iter()
                        .map(|c| Ident::new_unchecked(&c.name).to_ast_string_stable())
                        .join(",");

                    let ctid_filter = match ctid_range.end_block {
                        Some(end) => format!(
                            "WHERE ctid >= '({},0)'::tid AND ctid < '({},0)'::tid",
                            ctid_range.start_block, end
                        ),
                        None => format!("WHERE ctid >= '({},0)'::tid", ctid_range.start_block),
                    };
                    let query = format!(
                        "COPY (SELECT {column_list} FROM {namespace}.{table} {ctid_filter}) \
                            TO STDOUT (FORMAT TEXT, DELIMITER '\t')"
                    );
                    let mut stream = pin!(client.copy_out_simple(&query).await?);

                    let mut snapshot_staged = 0;
                    let mut update =
                        ((oid, output_index, Ok(vec![])), MzOffset::minimum(), Diff::ONE);
                    while let Some(bytes) = stream.try_next().await? {
                        let data = update.0 .2.as_mut().unwrap();
                        data.clear();
                        data.extend_from_slice(&bytes);
                        raw_handle.give_fueled(&data_cap_set[0], &update).await;
                        snapshot_staged += 1;
                        if snapshot_staged % 1000 == 0 {
                            let stat = &export_statistics[&(oid, output_index)];
                            stat.set_snapshot_records_staged(snapshot_staged);
                        }
                    }
                    // Final update for snapshot_staged, using the staged value
                    // because the total is only an estimate.
                    let stat = &export_statistics[&(oid, output_index)];
                    stat.set_snapshot_records_staged(snapshot_staged);
                }
            }

            // We are done with the snapshot so now we will emit rewind requests. It is
            // important that this happens after the snapshot has finished because this is
            // what unblocks the replication operator and we want this to happen serially. It
            // might seem like a good idea to read the replication stream concurrently with
            // the snapshot but it actually leads to a lot of data being staged for the
            // future, which needlessly consumes memory in the cluster.
            //
            // Since all workers now snapshot all tables (each with different ctid ranges), we
            // only emit rewind requests from the worker responsible for each output to avoid
            // duplicates.
            for (&oid, output) in tables_to_snapshot.iter() {
                for (output_index, info) in output {
                    // Only emit the rewind request from one worker per output.
                    if !config.responsible_for((oid, *output_index)) {
                        continue;
                    }
                    trace!(
                        %id,
                        "timely-{worker_id} producing rewind request for table {} \
                            output {output_index}",
                        info.desc.name
                    );
                    let req = RewindRequest { output_index: *output_index, snapshot_lsn };
                    rewinds_handle.give(&rewind_cap_set[0], req);
                }
            }
            *rewind_cap_set = CapabilitySet::new();

            // Failure scenario after we have produced the snapshot, but before a successful
            // COMMIT.
            fail::fail_point!("pg_snapshot_failure", |_| Err(
                TransientError::SyntheticError
            ));

            // The exporting worker should wait for all the other workers to commit before
            // dropping its client since this is what holds the exported transaction alive.
            if is_snapshot_leader {
                trace!(%id, "timely-{worker_id} waiting for all workers to finish");
                *snapshot_cap_set = CapabilitySet::new();
                while snapshot_input.next().await.is_some() {}
                trace!(%id, "timely-{worker_id} (leader) committing COPY transaction");
                client.simple_query("COMMIT").await?;
            } else {
                trace!(%id, "timely-{worker_id} committing COPY transaction");
                client.simple_query("COMMIT").await?;
                *snapshot_cap_set = CapabilitySet::new();
            }
            drop(client);
            Ok(())
        }))
    });

    // We now decode the COPY protocol and apply the cast expressions.
    let mut text_row = Row::default();
    let mut final_row = Row::default();
    let mut datum_vec = DatumVec::new();
    let snapshot_updates = raw_data
        .unary(Pipeline, "PgCastSnapshotRows", |_, _| {
            move |input, output| {
                input.for_each_time(|time, data| {
                    let mut session = output.session(&time);
                    for ((oid, output_index, event), time, diff) in
                        data.flat_map(|data| data.drain())
                    {
                        let output = &table_info
                            .get(oid)
                            .and_then(|outputs| outputs.get(output_index))
                            .expect("table_info contains all outputs");

                        let event = event
                            .as_ref()
                            .map_err(|e: &DataflowError| e.clone())
                            .and_then(|bytes| {
                                decode_copy_row(bytes, output.casts.len(), &mut text_row)?;
                                let datums = datum_vec.borrow_with(&text_row);
                                super::cast_row(&output.casts, &datums, &mut final_row)?;
                                Ok(SourceMessage {
                                    key: Row::default(),
                                    value: final_row.clone(),
                                    metadata: Row::default(),
                                })
                            });

                        session.give(((*output_index, event), *time, *diff));
                    }
                });
            }
        })
        .as_collection();

    let errors = definite_errors.concat(transient_errors.map(ReplicationError::from));

    (
        snapshot_updates,
        rewinds,
        slot_ready,
        errors,
        button.press_on_drop(),
    )
}

/// Starts a read-only transaction on the SQL session of `client` at a consistent LSN point by
/// creating a replication slot. Returns a snapshot identifier that can be imported into other
/// SQL sessions and the LSN of the consistent point.
async fn export_snapshot(
    client: &Client,
    slot: &str,
    temporary: bool,
) -> Result<(String, MzOffset), TransientError> {
    match export_snapshot_inner(client, slot, temporary).await {
        Ok(ok) => Ok(ok),
        Err(err) => {
            // We don't want to leave the client inside a failed transaction
            client.simple_query("ROLLBACK;").await?;
            Err(err)
        }
    }
}

async fn export_snapshot_inner(
    client: &Client,
    slot: &str,
    temporary: bool,
) -> Result<(String, MzOffset), TransientError> {
    client
        .simple_query("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;")
        .await?;

    // Note: Using unchecked here is okay because we're using it in a SQL query.
    let slot = Ident::new_unchecked(slot).to_ast_string_simple();
    let temporary_str = if temporary { " TEMPORARY" } else { "" };
    let query =
        format!("CREATE_REPLICATION_SLOT {slot}{temporary_str} LOGICAL \"pgoutput\" USE_SNAPSHOT");
    let row = match simple_query_opt(client, &query).await {
        Ok(row) => Ok(row.unwrap()),
        Err(PostgresError::Postgres(err)) if err.code() == Some(&SqlState::DUPLICATE_OBJECT) => {
            return Err(TransientError::ReplicationSlotAlreadyExists);
        }
        Err(err) => Err(err),
    }?;

    // When creating a replication slot postgres returns the LSN of its consistent point, which is
    // the LSN that must be passed to `START_REPLICATION` to cleanly transition from the snapshot
    // phase to the replication phase. `START_REPLICATION` includes all transactions that commit at
    // LSNs *greater than or equal* to the passed LSN. Therefore the snapshot phase must happen at
    // the greatest LSN that is not beyond the consistent point. That LSN is `consistent_point - 1`.
    let consistent_point: PgLsn = row.get("consistent_point").unwrap().parse().unwrap();
    let consistent_point = u64::from(consistent_point)
        .checked_sub(1)
        .expect("consistent point is always non-zero");

    let row = simple_query_opt(client, "SELECT pg_export_snapshot();")
        .await?
        .unwrap();
    let snapshot = row.get("pg_export_snapshot").unwrap().to_owned();

    Ok((snapshot, MzOffset::from(consistent_point)))
}
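The off-by-one adjustment in `export_snapshot_inner` can be sketched in isolation. The helper below (`snapshot_lsn_sketch` and the sample LSN value are hypothetical, not part of this module) restates why the snapshot phase is pinned one LSN below the consistent point:

```rust
/// Minimal sketch of the consistent-point adjustment: `START_REPLICATION`
/// replays transactions committing at LSNs greater than or equal to the
/// passed LSN, so the snapshot must sit at `consistent_point - 1` to avoid
/// double-counting the first replicated transaction.
fn snapshot_lsn_sketch(consistent_point: u64) -> u64 {
    consistent_point
        .checked_sub(1)
        .expect("consistent point is always non-zero")
}

fn main() {
    // Hypothetical consistent point as reported by CREATE_REPLICATION_SLOT.
    let consistent_point = 23_803_720u64;
    assert_eq!(snapshot_lsn_sketch(consistent_point), 23_803_719);
}
```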

/// Starts a read-only transaction on the SQL session of `client` at the consistent LSN point of
/// `snapshot`.
async fn use_snapshot(client: &Client, snapshot: &str) -> Result<(), TransientError> {
    client
        .simple_query("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;")
        .await?;
    let query = format!(
        "SET TRANSACTION SNAPSHOT {};",
        escaped_string_literal(snapshot)
    );
    client.simple_query(&query).await?;
    Ok(())
}

async fn set_statement_timeout(client: &Client, timeout: Duration) -> Result<(), TransientError> {
    // Value is known to accept milliseconds w/o units.
    // https://www.postgresql.org/docs/current/runtime-config-client.html
    client
        .simple_query(&format!("SET statement_timeout = {}", timeout.as_millis()))
        .await?;
    Ok(())
}

/// Decodes a row of `col_len` columns obtained from a text encoded COPY query into `row`.
fn decode_copy_row(data: &[u8], col_len: usize, row: &mut Row) -> Result<(), DefiniteError> {
    let mut packer = row.packer();
    let row_parser = mz_pgcopy::CopyTextFormatParser::new(data, b'\t', "\\N");
    let mut column_iter = row_parser.iter_raw_truncating(col_len);
    for _ in 0..col_len {
        let value = match column_iter.next() {
            Some(Ok(value)) => value,
            Some(Err(_)) => return Err(DefiniteError::InvalidCopyInput),
            None => return Err(DefiniteError::MissingColumn),
        };
        let datum = value.map(super::decode_utf8_text).transpose()?;
        packer.push(datum.unwrap_or(Datum::Null));
    }
    Ok(())
}
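As an illustration of the framing that `decode_copy_row` relies on, here is a std-only sketch of PostgreSQL's COPY text format: columns are tab-separated and NULL is encoded as the `\N` sentinel. `split_copy_text_row` is a hypothetical simplification; the real `mz_pgcopy::CopyTextFormatParser` also handles backslash escapes inside values, which this sketch does not:

```rust
// Std-only sketch of COPY text framing: tab-delimited columns with `\N` as
// the NULL sentinel. Real COPY data also escapes tabs, newlines, and
// backslashes inside values; this sketch ignores that.
fn split_copy_text_row(data: &[u8]) -> Vec<Option<String>> {
    std::str::from_utf8(data)
        .expect("sketch assumes valid UTF-8")
        .split('\t')
        .map(|col| {
            if col == "\\N" {
                None // NULL sentinel
            } else {
                Some(col.to_owned())
            }
        })
        .collect()
}

fn main() {
    let row = split_copy_text_row(b"1\t\\N\tfoo");
    let expected: Vec<Option<String>> =
        vec![Some("1".to_owned()), None, Some("foo".to_owned())];
    assert_eq!(row, expected);
}
```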

/// Record the sizes of the tables being snapshotted in `PgSnapshotMetrics` and emit snapshot
/// statistics for each export.
async fn report_snapshot_size(
    client: &Client,
    tables_to_snapshot: &BTreeMap<u32, BTreeMap<usize, SourceOutputInfo>>,
    metrics: PgSnapshotMetrics,
    config: &RawSourceCreationConfig,
    export_statistics: &BTreeMap<(u32, usize), SourceStatistics>,
) -> Result<(), anyhow::Error> {
    // TODO(guswynn): delete unused configs
    let snapshot_config = config.config.parameters.pg_snapshot_config;

    for (&oid, outputs) in tables_to_snapshot {
        // Use the first output's desc to make the table name since it is the same for all outputs
        let Some((_, info)) = outputs.first_key_value() else {
            continue;
        };
        let table = format!(
            "{}.{}",
            Ident::new_unchecked(info.desc.namespace.clone()).to_ast_string_simple(),
            Ident::new_unchecked(info.desc.name.clone()).to_ast_string_simple()
        );
        let stats = collect_table_statistics(
            client,
            snapshot_config,
            &info.desc.namespace,
            &info.desc.name,
            info.desc.oid,
        )
        .await?;
        metrics.record_table_count_latency(table, stats.count_latency);
        for &output_index in outputs.keys() {
            export_statistics[&(oid, output_index)].set_snapshot_records_known(stats.count);
            export_statistics[&(oid, output_index)].set_snapshot_records_staged(0);
        }
    }
    Ok(())
}

#[derive(Default)]
struct TableStatistics {
    count: u64,
    count_latency: f64,
}

async fn collect_table_statistics(
    client: &Client,
    config: PgSourceSnapshotConfig,
    namespace: &str,
    table_name: &str,
    oid: u32,
) -> Result<TableStatistics, anyhow::Error> {
    use mz_ore::metrics::MetricsFutureExt;
    let mut stats = TableStatistics::default();
    let table = format!(
        "{}.{}",
        Ident::new_unchecked(namespace).to_ast_string_simple(),
        Ident::new_unchecked(table_name).to_ast_string_simple()
    );

    let estimate_row = simple_query_opt(
        client,
        &format!("SELECT reltuples::bigint AS estimate_count FROM pg_class WHERE oid = '{oid}'"),
    )
    .wall_time()
    .set_at(&mut stats.count_latency)
    .await?;
    stats.count = match estimate_row {
        Some(row) => row.get("estimate_count").unwrap().parse().unwrap_or(0),
        None => bail!("failed to get estimate count for {table}"),
    };

    // If the estimate is low enough we can attempt to get an exact count. Note that tables that
    // have not yet been vacuumed will report zero rows here, and there is a possibility that they
    // are very large. We accept this risk and we offer the feature flag as an escape hatch if it
    // becomes problematic.
    if config.collect_strict_count && stats.count < 1_000_000 {
        let count_row = simple_query_opt(client, &format!("SELECT count(*) AS count FROM {table}"))
            .wall_time()
            .set_at(&mut stats.count_latency)
            .await?;
        stats.count = match count_row {
            Some(row) => row.get("count").unwrap().parse().unwrap(),
            None => bail!("failed to get count for {table}"),
        }
    }

    Ok(stats)
}
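The estimate-then-exact decision in `collect_table_statistics` can be restated as a tiny predicate. `should_collect_exact_count` and the threshold constant below are a hypothetical restatement for illustration, not part of this module:

```rust
// Hypothetical restatement of the decision in `collect_table_statistics`:
// only pay for an exact `count(*)` when the feature flag is enabled and the
// planner's `reltuples` estimate suggests the table is small.
const EXACT_COUNT_THRESHOLD: u64 = 1_000_000;

fn should_collect_exact_count(collect_strict_count: bool, estimated_rows: u64) -> bool {
    collect_strict_count && estimated_rows < EXACT_COUNT_THRESHOLD
}

fn main() {
    assert!(should_collect_exact_count(true, 10_000));
    assert!(!should_collect_exact_count(true, 5_000_000));
    assert!(!should_collect_exact_count(false, 10_000));
}
```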

/// Validates that there are no blocking RLS policies on the tables and retrieves table schemas
/// for the given publication.
async fn retrieve_schema_info(
    connection_config: &Config,
    connection_context: &ConnectionContext,
    publication: &str,
    table_oids: &[Oid],
) -> Result<BTreeMap<u32, PostgresTableDesc>, PostgresError> {
    let schema_client = connection_config
        .connect(
            "snapshot schema info",
            &connection_context.ssh_tunnel_manager,
        )
        .await?;
    mz_postgres_util::validate_no_rls_policies(&schema_client, table_oids).await?;
    mz_postgres_util::publication_info(&schema_client, publication, Some(table_oids)).await
}