// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Renders the table snapshot side of the [`PostgresSourceConnection`] ingestion dataflow.
//!
//! # Snapshot reading
//!
//! Depending on the resumption LSNs, the table reader decides which tables need to be
//! snapshotted. Each table is partitioned across all workers using PostgreSQL's `ctid` (tuple
//! identifier) column, which identifies the physical location of each row. This allows parallel
//! snapshotting of large tables across all available workers.
//!
//! There are a few subtle points about this operation, described in the following sections.
//!
//! ## Consistent LSN point for snapshot transactions
//!
//! Given that all our ingestion is based on correctly timestamping updates with the LSN they
//! happened at, it is important that we run the `COPY` query at a specific LSN point that is
//! relatable to the LSN numbers we receive from the replication stream. Such a point does not
//! necessarily exist for a normal SQL transaction. To achieve this we must force Postgres to
//! produce a consistent point and tell us its LSN, which we do by creating a replication slot as
//! the first statement in a transaction.
//!
//! This is a temporary dummy slot that is only used to put our snapshot transaction on a
//! consistent LSN point. Unfortunately, no lighter-weight method exists for doing this. See this
//! [postgres thread] for more details.
//!
//! One might wonder why we don't use the actual real slot to provide us with the snapshot point,
//! which would automatically be at the correct LSN. The answer is that we might crash and restart
//! after having already created the slot but before having finished the snapshot. In that case
//! the restarting process will have lost its opportunity to run queries at the slot's consistent
//! point, as that opportunity only exists in the ephemeral transaction that created the slot, and
//! that transaction is long gone. Additionally, there are good reasons why we'd like to move the
//! slot creation much earlier, e.g. during purification, in which case the slot will always be
//! pre-created.
//!
//! [postgres thread]: https://www.postgresql.org/message-id/flat/CAMN0T-vzzNy6TV1Jvh4xzNQdAvCLBQK_kh6_U7kAXgGU3ZFg-Q%40mail.gmail.com
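//!
//! As a rough illustration of the statement order this relies on, the hypothetical helper
//! `consistent_point_statements` lists what a replication-mode session could issue to pin its
//! transaction to a new temporary slot's consistent point. It is a sketch rather than a
//! description of [`export_snapshot`]. It assumes the
//! `CREATE_REPLICATION_SLOT ... TEMPORARY LOGICAL pgoutput USE_SNAPSHOT` form of the replication
//! protocol command, which PostgreSQL only accepts as the first statement of a read-only
//! `REPEATABLE READ` transaction.
//!
//! ```rust
//! /// Hypothetical helper: the statements a replication-mode session issues so that its
//! /// transaction observes the database exactly at a new temporary slot's consistent point.
//! fn consistent_point_statements(tmp_slot: &str) -> Vec<String> {
//!     vec![
//!         // `USE_SNAPSHOT` requires a read-only REPEATABLE READ transaction in which slot
//!         // creation is the very first statement.
//!         "BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ".to_string(),
//!         // The reply includes a `consistent_point` column: the LSN that the transaction's
//!         // snapshot corresponds to.
//!         format!("CREATE_REPLICATION_SLOT {tmp_slot} TEMPORARY LOGICAL pgoutput USE_SNAPSHOT"),
//!     ]
//! }
//! ```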
//!
//! ## Reusing the consistent point among all workers
//!
//! Creating replication slots is potentially expensive, so the code makes it such that all
//! workers cooperate and reuse one consistent snapshot among them. To do so, we make use of the
//! "export transaction" feature of Postgres. This feature allows one SQL session to create an
//! identifier (a string) for the transaction it is currently in, which other sessions can use to
//! enter the same "snapshot".
//!
//! We accomplish this by picking one worker at random to function as the transaction leader. The
//! transaction leader is responsible for starting a SQL session, creating a temporary replication
//! slot in a transaction, exporting the transaction id, and broadcasting the transaction
//! information to all other workers via a broadcasted feedback edge.
//!
//! During this phase the follower workers are simply waiting to hear on the feedback edge,
//! effectively synchronizing with the leader. Once all workers have received the snapshot
//! information they can all start to perform their assigned COPY queries.
//!
//! The leader and follower steps described above are accomplished by the [`export_snapshot`] and
//! [`use_snapshot`] functions respectively.
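//!
//! The export/import handshake can be sketched as SQL text. The helpers below are hypothetical.
//! They assume the leader obtains the identifier with PostgreSQL's `pg_export_snapshot()` and that
//! followers join with `SET TRANSACTION SNAPSHOT`. That statement must be the first one in a
//! `REPEATABLE READ` (or `SERIALIZABLE`) transaction, and it only succeeds while the exporting
//! transaction is still open.
//!
//! ```rust
//! /// Quotes a value as a SQL string literal by doubling any embedded single quotes.
//! fn quote_literal(s: &str) -> String {
//!     format!("'{}'", s.replace('\'', "''"))
//! }
//!
//! /// Run by the leader inside its snapshot transaction; returns an identifier such as
//! /// `00000003-0000001B-1` that other sessions can import.
//! fn leader_export_statement() -> &'static str {
//!     "SELECT pg_export_snapshot()"
//! }
//!
//! /// Run by each follower on its own session to enter the leader's snapshot.
//! fn follower_import_statements(snapshot_id: &str) -> [String; 2] {
//!     [
//!         "BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ".to_string(),
//!         format!("SET TRANSACTION SNAPSHOT {}", quote_literal(snapshot_id)),
//!     ]
//! }
//! ```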
//!
//! ## Coordinated transaction COMMIT
//!
//! When follower workers are done with snapshotting they commit their transaction, close their
//! session, and then drop their snapshot feedback capability. When the leader worker is done with
//! snapshotting it drops its snapshot feedback capability and waits until it observes the
//! snapshot input advancing to the empty frontier. This allows the leader to COMMIT its
//! transaction, the one that exported the snapshot, last.
//!
//! It's unclear if this is strictly necessary, but having the frontiers made it easy enough that I
//! added the synchronization.
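//!
//! The ordering can be modeled outside of timely with standard threads. This sketch is an analogy
//! and not the operator's code. Each follower records its COMMIT and then drops a channel sender,
//! which stands in for its feedback capability. The leader blocks until every sender is gone,
//! which stands in for the empty frontier, and only then records its own COMMIT.
//!
//! ```rust
//! use std::sync::mpsc;
//! use std::sync::{Arc, Mutex};
//! use std::thread;
//!
//! /// Returns the order in which the simulated sessions committed.
//! fn simulate_commits(followers: usize) -> Vec<String> {
//!     let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
//!     let (done_tx, done_rx) = mpsc::channel::<()>();
//!
//!     let handles: Vec<_> = (0..followers)
//!         .map(|i| {
//!             let log = Arc::clone(&log);
//!             let done_tx = done_tx.clone();
//!             thread::spawn(move || {
//!                 log.lock().unwrap().push(format!("follower-{i} COMMIT"));
//!                 // Analogue of dropping the snapshot feedback capability.
//!                 drop(done_tx);
//!             })
//!         })
//!         .collect();
//!     // The leader releases its own sender so the channel closes once all followers are done.
//!     drop(done_tx);
//!
//!     // No messages are ever sent: `recv` fails once every sender has been dropped, which plays
//!     // the role of observing the empty frontier.
//!     while done_rx.recv().is_ok() {}
//!     log.lock().unwrap().push("leader COMMIT".to_string());
//!
//!     for handle in handles {
//!         handle.join().unwrap();
//!     }
//!     Arc::try_unwrap(log).unwrap().into_inner().unwrap()
//! }
//! ```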
//!
//! ## Snapshot rewinding
//!
//! Ingestion dataflows must produce definite data, including the snapshot. In practice, this
//! means that whenever we deem it necessary to snapshot a table we must do so at the same LSN.
//! However, the method for running a transaction described above doesn't let us choose the LSN.
//! It could be an LSN in the future, chosen by PostgreSQL while it creates the temporary
//! replication slot.
//!
//! The definition of differential collections states that a collection at some time `t_snapshot`
//! is defined to be the accumulation of all updates that happen at `t <= t_snapshot`, where `<=`
//! is the partial order. In this case we are faced with the problem of knowing the state of a
//! table at `t_snapshot` but actually wanting to know the snapshot at `t_slot <= t_snapshot`.
//!
//! From the definition, and because LSNs are totally ordered (every update with
//! `t <= t_snapshot` falls either at `t <= t_slot` or strictly after `t_slot`), we can see that
//! the snapshot at `t_slot` is related to the snapshot at `t_snapshot` by the following equations:
//!
//! ```text
//! sum(update: t <= t_snapshot) = sum(update: t <= t_slot) + sum(update: t_slot < t <= t_snapshot)
//!                                         |
//!                                         V
//! sum(update: t <= t_slot) = sum(update: t <= t_snapshot) - sum(update: t_slot < t <= t_snapshot)
//! ```
//!
//! Therefore, if we manage to recover the `sum(update: t_slot < t <= t_snapshot)` term we will be
//! able to "rewind" the snapshot we obtained at `t_snapshot` to `t_slot` by emitting all updates
//! that happen between these two points with their diffs negated.
//!
//! It turns out that this term is exactly what the main replication slot provides us with, so we
//! can rewind the snapshot at arbitrary points! To do this, the snapshot dataflow emits rewind
//! requests to the replication reader. A rewind request informs the reader that a certain range
//! of updates must be emitted at LSN 0 (by convention) with their diffs negated. These negated
//! diffs are consolidated with the diffs taken at `t_snapshot`, which were also emitted at LSN 0
//! (by convention). The result is a TVC that at LSN 0 contains the snapshot at `t_slot`.
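//!
//! The identity can be checked on a toy collection. The sketch below assumes updates are
//! `(key, lsn, diff)` triples and uses plain `BTreeMap` consolidation in place of differential
//! dataflow. It accumulates the snapshot at `t_snapshot`, adds every update in
//! `(t_slot, t_snapshot]` with its diff negated, and compares the result with the snapshot
//! accumulated directly at `t_slot`.
//!
//! ```rust
//! use std::collections::BTreeMap;
//!
//! /// A `(key, lsn, diff)` update, as the replication stream would deliver it.
//! type Update = (&'static str, u64, i64);
//!
//! /// Consolidates every update with `lsn <= t`, dropping keys whose diffs cancel out.
//! fn snapshot_at(updates: &[Update], t: u64) -> BTreeMap<&'static str, i64> {
//!     let mut acc = BTreeMap::new();
//!     for &(key, lsn, diff) in updates {
//!         if lsn <= t {
//!             *acc.entry(key).or_insert(0) += diff;
//!         }
//!     }
//!     acc.retain(|_, d| *d != 0);
//!     acc
//! }
//!
//! /// Rewinds `snapshot` (taken at `t_snapshot`) to `t_slot` by consolidating it with the
//! /// negated updates from the range `t_slot < lsn <= t_snapshot`.
//! fn rewind(
//!     snapshot: &BTreeMap<&'static str, i64>,
//!     updates: &[Update],
//!     t_slot: u64,
//!     t_snapshot: u64,
//! ) -> BTreeMap<&'static str, i64> {
//!     let mut acc = snapshot.clone();
//!     for &(key, lsn, diff) in updates {
//!         if t_slot < lsn && lsn <= t_snapshot {
//!             *acc.entry(key).or_insert(0) -= diff;
//!         }
//!     }
//!     acc.retain(|_, d| *d != 0);
//!     acc
//! }
//! ```
//!
//! Consider the updates `("a", 1, 1)`, `("a", 5, -1)` and `("c", 6, 1)`. The snapshot at LSN 7 is
//! `{c: 1}`. Rewinding it to LSN 3 negates the two later updates and yields `{a: 1}`, which is
//! exactly the snapshot at LSN 3.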
//!
//! # Parallel table snapshotting with ctid ranges
//!
//! Each table is partitioned across workers using PostgreSQL's `ctid` column. The `ctid` is a
//! tuple identifier of the form `(block_number, tuple_index)` that represents the physical
//! location of a row on disk. By partitioning the ctid range, each worker can independently
//! fetch a portion of the table.
//!
//! The partitioning works as follows:
//! 1. The snapshot leader queries `pg_class.relpages` to estimate the number of blocks for each
//!    table. This is much faster than querying `max(ctid)`, which would require a sequential scan.
//! 2. The leader broadcasts the block count estimates along with the snapshot transaction ID
//!    to all workers, ensuring all workers use consistent estimates for partitioning.
//! 3. Each worker calculates its assigned block range and fetches rows using a `COPY` query
//!    with a `SELECT` that filters by `ctid >= start AND ctid < end`.
//! 4. The last worker uses an open-ended range (`ctid >= start`) to capture any rows beyond
//!    the estimated block count. This handles cases where statistics are stale or the table has
//!    grown.
//!
//! This approach efficiently parallelizes large table snapshots while maintaining the benefits
//! of the `COPY` protocol for bulk data transfer.
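//!
//! The arithmetic and the resulting filter can be sketched as follows. `block_ranges` mirrors the
//! integer division performed by `worker_ctid_range` further down in this module.
//! `ctid_copy_query` is a hypothetical illustration of the query text, not the exact SQL the
//! operator sends. Tuple offsets start at 1, so the `tid` literal `(n,0)` sorts before every row
//! stored in block `n`. As a result, consecutive ranges neither overlap nor leave gaps.
//!
//! ```rust
//! /// Splits `estimated_blocks` among at most `workers` workers. Each entry is
//! /// `(start_block, end_block)`; the last one is open-ended (`None`).
//! fn block_ranges(estimated_blocks: u64, workers: u64) -> Vec<(u64, Option<u64>)> {
//!     let effective = workers.min(estimated_blocks);
//!     (0..effective)
//!         .map(|w| {
//!             let start = w * estimated_blocks / effective;
//!             let end = (w + 1 < effective).then(|| (w + 1) * estimated_blocks / effective);
//!             (start, end)
//!         })
//!         .collect()
//! }
//!
//! /// Hypothetical helper: a `COPY` statement restricted to one worker's block range.
//! fn ctid_copy_query(table: &str, start_block: u64, end_block: Option<u64>) -> String {
//!     let mut predicate = format!("ctid >= '({start_block},0)'::tid");
//!     if let Some(end) = end_block {
//!         predicate.push_str(&format!(" AND ctid < '({end},0)'::tid"));
//!     }
//!     format!("COPY (SELECT * FROM {table} WHERE {predicate}) TO STDOUT")
//! }
//! ```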
//!
//! ## PostgreSQL version requirements
//!
//! Ctid range scans are only efficient on PostgreSQL >= 14 due to TID range scan optimizations
//! introduced in that version. For older PostgreSQL versions, the snapshot falls back to the
//! single-worker-per-table mode, where each table is assigned to one worker based on consistent
//! hashing. This is implemented by having the leader broadcast all-zero block counts when the
//! PostgreSQL version is < 14.
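//!
//! A sketch of the version gate and the fallback assignment follows. `block_estimates` mirrors the
//! leader behavior described above. `responsible_worker` is a hypothetical hash-modulo stand-in
//! for `config.responsible_for(oid)`, whose actual implementation is not shown here. It only
//! illustrates that exactly one worker picks up each table when that table's estimate is 0.
//!
//! ```rust
//! use std::collections::hash_map::DefaultHasher;
//! use std::collections::BTreeMap;
//! use std::hash::{Hash, Hasher};
//!
//! /// Block estimates the leader would broadcast, given `relpages` per table OID.
//! fn block_estimates(pg_major: u32, relpages: &BTreeMap<u32, u64>) -> BTreeMap<u32, u64> {
//!     if pg_major >= 14 {
//!         relpages.clone()
//!     } else {
//!         // All-zero estimates put every table into single-worker mode.
//!         relpages.keys().map(|&oid| (oid, 0)).collect()
//!     }
//! }
//!
//! /// Hypothetical stand-in for `responsible_for`: deterministically maps a table to one worker.
//! fn responsible_worker(oid: u32, worker_count: u64) -> u64 {
//!     let mut hasher = DefaultHasher::new();
//!     oid.hash(&mut hasher);
//!     hasher.finish() % worker_count
//! }
//! ```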
//!
//! # Snapshot decoding
//!
//! Each worker fetches its ctid range directly and decodes the COPY stream locally.
//!
//! ```text
//!                 ╭──────────────────╮
//!    ┏━━━━━━━━━━━━v━┓                │ exported
//!    ┃    table     ┃   ╭─────────╮  │ snapshot id
//!    ┃   readers    ┠─>─┤broadcast├──╯
//!    ┃  (parallel)  ┃   ╰─────────╯
//!    ┗━┯━━━━━━━━━━┯━┛
//!   raw│          │
//!  COPY│          │
//!  data│          │
//! ┏━━━━┷━━━━┓     │
//! ┃  COPY   ┃     │
//! ┃ decoder ┃     │
//! ┗━━━━┯━━━━┛     │
//!      │ snapshot │rewind
//!      │ updates  │requests
//!      v          v
//! ```
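//!
//! For illustration, the sketch below decodes one row (without its trailing newline) of
//! PostgreSQL's `COPY ... TO STDOUT` text format. In that format, columns are separated by tabs,
//! `\N` marks NULL, and special characters are backslash-escaped. This excerpt does not show
//! whether the operator requests the text or the binary format, so treat the choice of format as
//! an assumption. Octal and hex escapes are also left out.
//!
//! ```rust
//! /// Splits a COPY text-format line into columns; `None` is SQL NULL.
//! fn decode_copy_text_line(line: &str) -> Vec<Option<String>> {
//!     // Splitting on raw tabs is safe: tabs inside values are escaped as `\t`.
//!     line.split('\t')
//!         .map(|field| {
//!             if field == "\\N" {
//!                 return None;
//!             }
//!             let mut out = String::with_capacity(field.len());
//!             let mut chars = field.chars();
//!             while let Some(c) = chars.next() {
//!                 if c != '\\' {
//!                     out.push(c);
//!                     continue;
//!                 }
//!                 match chars.next() {
//!                     Some('t') => out.push('\t'),
//!                     Some('n') => out.push('\n'),
//!                     Some('r') => out.push('\r'),
//!                     Some('b') => out.push('\x08'),
//!                     Some('f') => out.push('\x0C'),
//!                     Some('v') => out.push('\x0B'),
//!                     // `\\` and other escaped characters stand for themselves
//!                     // (octal/hex escapes are not handled in this sketch).
//!                     Some(other) => out.push(other),
//!                     // A trailing lone backslash is kept as-is.
//!                     None => out.push('\\'),
//!                 }
//!             }
//!             Some(out)
//!         })
//!         .collect()
//! }
//! ```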

use std::collections::BTreeMap;
use std::convert::Infallible;
use std::pin::pin;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;

use anyhow::bail;
use differential_dataflow::AsCollection;
use futures::{StreamExt as _, TryStreamExt};
use itertools::Itertools;
use mz_ore::cast::CastFrom;
use mz_ore::future::InTask;
use mz_postgres_util::desc::PostgresTableDesc;
use mz_postgres_util::schemas::get_pg_major_version;
use mz_postgres_util::{Client, Config, PostgresError, simple_query_opt};
use mz_repr::{Datum, DatumVec, Diff, Row};
use mz_sql_parser::ast::{
    Ident,
    display::{AstDisplay, escaped_string_literal},
};
use mz_storage_types::connections::ConnectionContext;
use mz_storage_types::errors::DataflowError;
use mz_storage_types::parameters::PgSourceSnapshotConfig;
use mz_storage_types::sources::{MzOffset, PostgresSourceConnection};
use mz_timely_util::builder_async::{
    Event as AsyncEvent, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pact::Pipeline;
use timely::dataflow::operators::core::Map;
use timely::dataflow::operators::vec::Broadcast;
use timely::dataflow::operators::{CapabilitySet, Concat, ConnectLoop, Feedback, Operator};
use timely::dataflow::{Scope, StreamVec};
use timely::progress::Timestamp;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::{Oid, PgLsn};
use tracing::trace;

use crate::metrics::source::postgres::PgSnapshotMetrics;
use crate::source::RawSourceCreationConfig;
use crate::source::postgres::replication::RewindRequest;
use crate::source::postgres::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError, verify_schema,
};
use crate::source::types::{FuelSize, SignaledFuture, SourceMessage, StackedCollection};
use crate::statistics::SourceStatistics;

/// Information broadcasted from the snapshot leader to all workers.
/// This includes the exported snapshot ID, its LSN, and the estimated block count and upstream
/// schema of each table.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct SnapshotInfo {
    /// The exported transaction snapshot identifier.
    snapshot_id: String,
    /// The LSN at which the snapshot was taken.
    snapshot_lsn: MzOffset,
    /// Estimated number of blocks (pages) for each table, keyed by OID.
    /// This is derived from `pg_class.relpages` and used to partition ctid ranges.
    table_block_counts: BTreeMap<u32, u64>,
    /// The current upstream schema of each table.
    upstream_info: BTreeMap<u32, PostgresTableDesc>,
}

/// Represents a ctid range that a worker should snapshot.
/// The range is `[start_block, end_block)`, where `end_block` is optional (`None` means
/// unbounded).
#[derive(Debug)]
struct CtidRange {
    /// The starting block number (inclusive).
    start_block: u64,
    /// The ending block number (exclusive). `None` means unbounded (open-ended range).
    end_block: Option<u64>,
}

/// Calculate the ctid range for a given worker based on estimated block count.
///
/// The table is partitioned by block number across all workers. Each worker gets a contiguous
/// range of blocks. The last worker gets an open-ended range to handle any rows beyond the
/// estimated block count.
///
/// When `estimated_blocks` is 0 (either because statistics are unavailable, the table appears
/// empty, or PostgreSQL version < 14 doesn't support ctid range scans), the table is assigned
/// to a single worker determined by `config.responsible_for(oid)` and that worker scans the
/// full table.
///
/// Returns `None` if this worker has no work to do.
fn worker_ctid_range(
    config: &RawSourceCreationConfig,
    estimated_blocks: u64,
    oid: u32,
) -> Option<CtidRange> {
    // If estimated_blocks is 0, fall back to single-worker mode for this table.
    // This handles:
    // - PostgreSQL < 14 (ctid range scans not supported)
    // - Tables that appear empty in statistics
    // - Tables with stale/missing statistics
    // The responsible worker scans the full table with an open-ended range.
    if estimated_blocks == 0 {
        let fallback = if config.responsible_for(oid) {
            Some(CtidRange {
                start_block: 0,
                end_block: None,
            })
        } else {
            None
        };
        return fallback;
    }

    let worker_id = u64::cast_from(config.worker_id);
    let worker_count = u64::cast_from(config.worker_count);

    // If there are more workers than blocks, only assign work to workers with id < estimated_blocks.
    // The last assigned worker still gets an open range.
    let effective_worker_count = std::cmp::min(worker_count, estimated_blocks);

    if worker_id >= effective_worker_count {
        // This worker has no work to do
        return None;
    }

    // Calculate start block for this worker (integer division distributes blocks evenly)
    let start_block = worker_id * estimated_blocks / effective_worker_count;

    // The last effective worker gets an open-ended range
    let is_last_effective_worker = worker_id == effective_worker_count - 1;
    if is_last_effective_worker {
        Some(CtidRange {
            start_block,
            end_block: None,
        })
    } else {
        let end_block = (worker_id + 1) * estimated_blocks / effective_worker_count;
        Some(CtidRange {
            start_block,
            end_block: Some(end_block),
        })
    }
}

/// Estimate the number of blocks for each table from pg_class statistics.
/// This is used to partition ctid ranges across workers.
async fn estimate_table_block_counts(
    client: &Client,
    table_oids: &[u32],
) -> Result<BTreeMap<u32, u64>, TransientError> {
    if table_oids.is_empty() {
        return Ok(BTreeMap::new());
    }

    // Query relpages for all tables at once
    let oid_list = table_oids
        .iter()
        .map(|oid| oid.to_string())
        .collect::<Vec<_>>()
        .join(",");
    let query = format!(
        "SELECT oid, relpages FROM pg_class WHERE oid IN ({})",
        oid_list
    );

    let mut block_counts = BTreeMap::new();
    // Initialize all tables with 0 blocks (in case they're not in pg_class)
    for &oid in table_oids {
        block_counts.insert(oid, 0);
    }

    // Execute the query and collect results
    let rows = client.simple_query(&query).await?;
    for msg in rows {
        if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
            let oid: u32 = row.get("oid").unwrap().parse().unwrap();
            let relpages: i64 = row.get("relpages").unwrap().parse().unwrap_or(0);
            // relpages can be -1 if never analyzed, treat as 0
            let relpages = std::cmp::max(0, relpages).try_into().unwrap();
            block_counts.insert(oid, relpages);
        }
    }

    Ok(block_counts)
}

/// Renders the snapshot dataflow. See the module documentation for more information.
pub(crate) fn render<'scope>(
    scope: Scope<'scope, MzOffset>,
    config: RawSourceCreationConfig,
    connection: PostgresSourceConnection,
    table_info: BTreeMap<u32, BTreeMap<usize, SourceOutputInfo>>,
    metrics: PgSnapshotMetrics,
) -> (
    StackedCollection<'scope, MzOffset, (usize, Result<SourceMessage, DataflowError>)>,
    StreamVec<'scope, MzOffset, RewindRequest>,
    StreamVec<'scope, MzOffset, Infallible>,
    StreamVec<'scope, MzOffset, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("TableReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope.clone());

    let (feedback_handle, feedback_data) = scope.feedback(Default::default());

    let (raw_handle, raw_data) = builder.new_output();
    let (rewinds_handle, rewinds) = builder.new_output::<CapacityContainerBuilder<_>>();
    // This output is used to signal to the replication operator that the replication slot has been
    // created. With the current state of execution serialization there isn't a lot of benefit
    // in splitting the snapshot and replication phases into two operators.
    // TODO(petrosagg): merge the two operators in one (while still maintaining separation as
    // functions/modules)
    let (_, slot_ready) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (snapshot_handle, snapshot) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    // This operator needs to broadcast data to itself in order to synchronize the transaction
    // snapshot. However, none of the feedback capabilities result in output messages and, for the
    // feedback edge specifically, having a default connection would result in a loop.
    let mut snapshot_input = builder.new_disconnected_input(feedback_data, Pipeline);

    // The export id must be sent to all workers, so we broadcast the feedback connection
    snapshot.broadcast().connect_loop(feedback_handle);

    let is_snapshot_leader = config.responsible_for("snapshot_leader");

    // A global view of all outputs that will be snapshotted by all workers.
    let mut all_outputs = vec![];
    // Table info for tables that need snapshotting. All workers will snapshot all tables,
    // but each worker will handle a different ctid range within each table.
    let mut tables_to_snapshot = BTreeMap::new();
    // A collection of `SourceStatistics` to update for a given Oid. Same info exists in table_info,
    // but this avoids having to iterate + map each time the statistics are needed.
    let mut export_statistics = BTreeMap::new();
    for (table, outputs) in table_info.iter() {
        for (&output_index, output) in outputs {
            if *output.resume_upper != [MzOffset::minimum()] {
                // Already has been snapshotted.
                continue;
            }
            all_outputs.push(output_index);
            tables_to_snapshot
                .entry(*table)
                .or_insert_with(BTreeMap::new)
                .insert(output_index, output.clone());
            let statistics = config
                .statistics
                .get(&output.export_id)
                .expect("statistics are initialized")
                .clone();
            export_statistics.insert((*table, output_index), statistics);
        }
    }

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let id = config.id;
            let worker_id = config.worker_id;
            let [
                data_cap_set,
                rewind_cap_set,
                slot_ready_cap_set,
                snapshot_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 5] = caps.try_into().unwrap();

            trace!(
                %id,
                "timely-{worker_id} initializing table reader \
                    with {} tables to snapshot",
                    tables_to_snapshot.len()
            );

            let connection_config = connection
                .connection
                .config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;

            // The snapshot operator is responsible for creating the replication slot(s).
            // This first slot is the permanent slot that will be used for reading the replication
            // stream. A temporary slot is created further on to capture table snapshots.
            let replication_client = if is_snapshot_leader {
                let client = connection_config
                    .connect_replication(&config.config.connection_context.ssh_tunnel_manager)
                    .await?;
                let main_slot = &connection.publication_details.slot;

                tracing::info!(%id, "ensuring replication slot {main_slot} exists");
                super::ensure_replication_slot(&client, main_slot).await?;
                Some(client)
            } else {
                None
            };
            *slot_ready_cap_set = CapabilitySet::new();

            // Nothing needs to be snapshotted.
            if all_outputs.is_empty() {
                trace!(%id, "no exports to snapshot");
                // Note we do not emit a `ProgressStatisticsUpdate::Snapshot` update here,
                // as we do not want to attempt to override the current value with 0. We
                // just leave it null.
                return Ok(());
            }

            // A worker *must* emit a count even if not responsible for snapshotting a table
            // as statistic summarization will return null if any worker hasn't set a value.
            // This will also reset snapshot stats for any exports not snapshotting.
            // If no workers need to snapshot, then avoid emitting these as they will clear
            // previous stats.
            for statistics in config.statistics.values() {
                statistics.set_snapshot_records_known(0);
                statistics.set_snapshot_records_staged(0);
            }

            // Collect table OIDs for block count estimation
            let table_oids: Vec<u32> = tables_to_snapshot.keys().copied().collect();

            // The replication client is only set if this worker is the snapshot leader
            let client = match replication_client {
                Some(client) => {
                    let tmp_slot = format!("mzsnapshot_{}", uuid::Uuid::new_v4()).replace('-', "");
                    let (snapshot_id, snapshot_lsn) =
                        export_snapshot(&client, &tmp_slot, true).await?;

                    // Check PostgreSQL version. Ctid range scans are only efficient on PG >= 14
                    // due to improvements in TID range scan support.
                    let pg_version = get_pg_major_version(&client).await?;

                    // Estimate block counts for all tables from pg_class statistics.
                    // This must be done by the leader and broadcasted to ensure all workers
                    // use the same estimates for ctid range partitioning.
                    //
                    // For PostgreSQL < 14, we set all block counts to 0 to fall back to
                    // single-worker-per-table mode, as ctid range scans are not well supported.
                    let table_block_counts = if pg_version >= 14 {
                        estimate_table_block_counts(&client, &table_oids).await?
                    } else {
                        trace!(
                            %id,
                            "timely-{worker_id} PostgreSQL version {pg_version} < 14, \
                             falling back to single-worker-per-table snapshot mode"
                        );
                        // Return all zeros to trigger fallback mode
                        table_oids.iter().map(|&oid| (oid, 0u64)).collect()
                    };

                    report_snapshot_size(
                        &client,
                        &tables_to_snapshot,
                        metrics,
                        &config,
                        &export_statistics,
                    )
                    .await?;

                    let upstream_info = {
                        // As part of retrieving the schema info, RLS policies are checked to ensure the
                        // snapshot can successfully read the tables. RLS policy errors are treated as
                        // transient, as the customer can simply add the BYPASSRLS attribute to the PG
                        // account used by MZ.
                        match retrieve_schema_info(
                            &connection_config,
                            &config.config.connection_context,
                            &connection.publication,
                            &table_oids)
                            .await
                        {
                            // If the replication stream cannot be obtained in a definite way there is
                            // nothing else to do. These errors are not retractable.
                            Err(PostgresError::PublicationMissing(publication)) => {
                                let err = DefiniteError::PublicationDropped(publication);
                                for (oid, outputs) in tables_to_snapshot.iter() {
                                    // Produce a definite error here and then exit to ensure
                                    // a missing publication doesn't generate a transient
                                    // error and restart this dataflow indefinitely.
                                    //
                                    // We pick `u64::MAX` as the LSN which will (in
                                    // practice) never conflict with any previously revealed
                                    // portions of the TVC.
                                    for output_index in outputs.keys() {
                                        let update = (
                                            (*oid, *output_index, Err(err.clone().into())),
                                            MzOffset::from(u64::MAX),
                                            Diff::ONE,
                                        );
                                        let size = update.fuel_size();
                                        raw_handle
                                            .give_fueled(&data_cap_set[0], update, size)
                                            .await;
                                    }
                                }

                                definite_error_handle.give(
                                    &definite_error_cap_set[0],
                                    ReplicationError::Definite(Rc::new(err)),
                                );
                                return Ok(());
                            },
                            Err(e) => Err(TransientError::from(e))?,
                            Ok(i) => i,
                        }
                    };

                    let snapshot_info = SnapshotInfo {
                        snapshot_id,
                        snapshot_lsn,
                        upstream_info,
                        table_block_counts,
                    };
                    trace!(
                        %id,
                        "timely-{worker_id} exporting snapshot info {snapshot_info:?}");
                    snapshot_handle.give(&snapshot_cap_set[0], snapshot_info);

                    client
                }
                None => {
                    // Only the snapshot leader needs a replication connection.
                    let task_name = format!("timely-{worker_id} PG snapshotter");
                    connection_config
                        .connect(
                            &task_name,
                            &config.config.connection_context.ssh_tunnel_manager,
                        )
                        .await?
                }
            };

            // Configure statement_timeout based on the `pg_source_snapshot_statement_timeout`
            // parameter. We want to be able to override the server value here in case it's set
            // too low relative to the size of the data we need to copy.
            set_statement_timeout(
                &client,
                config
                    .config
                    .parameters
                    .pg_source_snapshot_statement_timeout,
            )
            .await?;

603            let snapshot_info = loop {
604                match snapshot_input.next().await {
605                    Some(AsyncEvent::Data(_, mut data)) => {
606                        break data.pop().expect("snapshot sent above")
607                    }
608                    Some(AsyncEvent::Progress(_)) => continue,
609                    None => panic!(
610                        "feedback closed \
611                    before sending snapshot info"
612                    ),
613                }
614            };
615            let SnapshotInfo {
616                snapshot_id,
617                snapshot_lsn,
618                table_block_counts,
619                upstream_info,
620            } = snapshot_info;
621
622            // Snapshot leader is already in identified transaction but all other workers need to enter it.
623            if !is_snapshot_leader {
624                trace!(%id, "timely-{worker_id} using snapshot id {snapshot_id:?}");
625                use_snapshot(&client, &snapshot_id).await?;
626            }

            for (&oid, outputs) in tables_to_snapshot.iter() {
                for (&output_index, info) in outputs.iter() {
                    if let Err(err) = verify_schema(oid, info, &upstream_info) {
                        let update = (
                            (oid, output_index, Err(err.into())),
                            MzOffset::minimum(),
                            Diff::ONE,
                        );
                        let size = update.fuel_size();
                        raw_handle
                            .give_fueled(&data_cap_set[0], update, size)
                            .await;
                        continue;
                    }

                    // Get the estimated block count from the broadcast table statistics.
                    let block_count = table_block_counts.get(&oid).copied().unwrap_or(0);

                    // Calculate this worker's ctid range based on the estimated block count.
                    // When `block_count` is 0 (PG < 14 or empty table), fall back to
                    // single-worker mode, using `responsible_for` to pick the worker.
                    let Some(ctid_range) = worker_ctid_range(&config, block_count, oid) else {
                        // This worker has no work for this table: either there are more workers
                        // than blocks, or another worker was picked in single-worker mode.
                        trace!(
                            %id,
                            "timely-{worker_id} no ctid range assigned for table {:?}({oid})",
                            info.desc.name
                        );
                        continue;
                    };

                    trace!(
                        %id,
                        "timely-{worker_id} snapshotting table {:?}({oid}) output {output_index} \
                         @ {snapshot_lsn} with ctid range {:?}",
                        info.desc.name,
                        ctid_range
                    );

                    // To handle quoted/keyword names, we can use `Ident`'s AST printing, which
                    // emulates PG's rules for name formatting.
                    let namespace = Ident::new_unchecked(&info.desc.namespace)
                        .to_ast_string_stable();
                    let table = Ident::new_unchecked(&info.desc.name)
                        .to_ast_string_stable();
                    let column_list = info
                        .desc
                        .columns
                        .iter()
                        .map(|c| Ident::new_unchecked(&c.name).to_ast_string_stable())
                        .join(",");

                    let ctid_filter = match ctid_range.end_block {
                        Some(end) => format!(
                            "WHERE ctid >= '({},0)'::tid AND ctid < '({},0)'::tid",
                            ctid_range.start_block, end
                        ),
                        None => format!("WHERE ctid >= '({},0)'::tid", ctid_range.start_block),
                    };
                    let query = format!(
                        "COPY (SELECT {column_list} FROM {namespace}.{table} {ctid_filter}) \
                         TO STDOUT (FORMAT TEXT, DELIMITER '\t')"
                    );
                    let mut stream = pin!(client.copy_out_simple(&query).await?);

                    let mut snapshot_staged = 0;
                    while let Some(bytes) = stream.try_next().await? {
                        let update = (
                            (oid, output_index, Ok(bytes)),
                            MzOffset::minimum(),
                            Diff::ONE,
                        );
                        let size = update.fuel_size();
                        raw_handle
                            .give_fueled(&data_cap_set[0], update, size)
                            .await;
                        snapshot_staged += 1;
                        if snapshot_staged % 1000 == 0 {
                            let stat = &export_statistics[&(oid, output_index)];
                            stat.set_snapshot_records_staged(snapshot_staged);
                        }
                    }
                    // Final update for `snapshot_staged`. We report the staged count itself
                    // because the known total is only an estimate.
                    let stat = &export_statistics[&(oid, output_index)];
                    stat.set_snapshot_records_staged(snapshot_staged);
                }
            }

            // We are done with the snapshot, so now we emit rewind requests. It is important
            // that this happens after the snapshot has finished, because this is what unblocks
            // the replication operator and we want this to happen serially. It might seem like a
            // good idea to read the replication stream concurrently with the snapshot, but that
            // leads to a lot of data being staged for the future, which needlessly consumes
            // memory in the cluster.
            //
            // Since all workers now snapshot all tables (each with a different ctid range), only
            // the worker responsible for each output emits its rewind request, to avoid
            // duplicates.
            for (&oid, output) in tables_to_snapshot.iter() {
                for (output_index, info) in output {
                    // Only emit the rewind request from one worker per output.
                    if !config.responsible_for((oid, *output_index)) {
                        continue;
                    }
                    trace!(%id, "timely-{worker_id} producing rewind request for table {} output {output_index}", info.desc.name);
                    let req = RewindRequest { output_index: *output_index, snapshot_lsn };
                    rewinds_handle.give(&rewind_cap_set[0], req);
                }
            }
            *rewind_cap_set = CapabilitySet::new();

            // Failpoint: fail after producing the snapshot but before a successful COMMIT.
            fail::fail_point!("pg_snapshot_failure", |_| Err(
                TransientError::SyntheticError
            ));

            // The exporting worker should wait for all the other workers to commit before dropping
            // its client, since this is what keeps the exported transaction alive.
            if is_snapshot_leader {
                trace!(%id, "timely-{worker_id} waiting for all workers to finish");
                *snapshot_cap_set = CapabilitySet::new();
                while snapshot_input.next().await.is_some() {}
                trace!(%id, "timely-{worker_id} (leader) committing COPY transaction");
                client.simple_query("COMMIT").await?;
            } else {
                trace!(%id, "timely-{worker_id} committing COPY transaction");
                client.simple_query("COMMIT").await?;
                *snapshot_cap_set = CapabilitySet::new();
            }
            drop(client);
            Ok(())
        }))
    });

    // We now decode the text-format COPY rows and apply the cast expressions.
    let mut text_row = Row::default();
    let mut final_row = Row::default();
    let mut datum_vec = DatumVec::new();
    let snapshot_updates = raw_data
        .unary(Pipeline, "PgCastSnapshotRows", |_, _| {
            move |input, output| {
                input.for_each_time(|time, data| {
                    let mut session = output.session(&time);
                    for ((oid, output_index, event), time, diff) in
                        data.flat_map(|data| data.drain(..))
                    {
                        let output = &table_info
                            .get(&oid)
                            .and_then(|outputs| outputs.get(&output_index))
                            .expect("table_info contains all outputs");

                        let event = event
                            .as_ref()
                            .map_err(|e: &DataflowError| e.clone())
                            .and_then(|bytes| {
                                decode_copy_row(bytes, output.casts.len(), &mut text_row)?;
                                let datums = datum_vec.borrow_with(&text_row);
                                super::cast_row(&output.casts, &datums, &mut final_row)?;
                                Ok(SourceMessage {
                                    key: Row::default(),
                                    value: final_row.clone(),
                                    metadata: Row::default(),
                                })
                            });

                        session.give(((output_index, event), time, diff));
                    }
                });
            }
        })
        .as_collection();

    let errors = definite_errors.concat(transient_errors.map(ReplicationError::from));

    (
        snapshot_updates,
        rewinds,
        slot_ready,
        errors,
        button.press_on_drop(),
    )
}
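
/*
Sketch, not part of this module: how a table's estimated block count could be split into
per-worker `ctid` ranges and turned into the `WHERE` clause built in the snapshot loop above.
`worker_ctid_range` itself is defined elsewhere, so the even split, the open-ended last range
and the hash-based fallback below are assumptions. `CtidRange`, `responsible_worker`,
`sketch_ctid_range` and `ctid_filter` are hypothetical names used only for illustration.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hypothetical stand-in for the range returned by `worker_ctid_range`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CtidRange {
    pub start_block: u64,
    /// `None` means "until the end of the table".
    pub end_block: Option<u64>,
}

/// Hypothetical: deterministically picks one worker for `key` by hashing it.
pub fn responsible_worker<K: Hash>(key: &K, worker_count: u64) -> u64 {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    hasher.finish() % worker_count
}

/// Hypothetical: splits `block_count` estimated blocks into contiguous, equally sized ranges.
pub fn sketch_ctid_range(
    worker_id: u64,
    worker_count: u64,
    block_count: u64,
    oid: u32,
) -> Option<CtidRange> {
    if block_count == 0 {
        // No usable estimate: a single worker scans the whole table.
        let owner = responsible_worker(&oid, worker_count);
        return (owner == worker_id).then_some(CtidRange { start_block: 0, end_block: None });
    }
    let per_worker = block_count.div_ceil(worker_count);
    let start_block = worker_id * per_worker;
    if start_block >= block_count {
        // More workers than blocks: this worker gets nothing.
        return None;
    }
    // The range containing the last estimated block stays open-ended, so rows in blocks
    // beyond a stale estimate are still read by exactly one worker.
    let end = start_block + per_worker;
    let end_block = if end >= block_count { None } else { Some(end) };
    Some(CtidRange { start_block, end_block })
}

/// Renders a range as a `ctid` filter of the shape used in the COPY query.
pub fn ctid_filter(range: &CtidRange) -> String {
    match range.end_block {
        Some(end) => format!(
            "WHERE ctid >= '({},0)'::tid AND ctid < '({},0)'::tid",
            range.start_block, end
        ),
        None => format!("WHERE ctid >= '({},0)'::tid", range.start_block),
    }
}
```

With 10 estimated blocks and 4 workers this yields blocks [0, 3), [3, 6), [6, 9) and [9, end).
The same "hash a key, take it modulo the worker count" idea is one plausible way to pick a
single worker per output, which is what the rewind requests above rely on.
*/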

/// Starts a read-only transaction on the SQL session of `client` at a consistent LSN point by
/// creating a replication slot. Returns a snapshot identifier that can be imported in
/// other SQL sessions, and the LSN the snapshot is taken at (one below the slot's consistent
/// point).
async fn export_snapshot(
    client: &Client,
    slot: &str,
    temporary: bool,
) -> Result<(String, MzOffset), TransientError> {
    match export_snapshot_inner(client, slot, temporary).await {
        Ok(ok) => Ok(ok),
        Err(err) => {
            // We don't want to leave the client inside a failed transaction.
            client.simple_query("ROLLBACK;").await?;
            Err(err)
        }
    }
}
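
/*
Sketch, not part of this module: the commit ordering the snapshot leader enforces at the end of
the snapshot. A snapshot exported with `pg_export_snapshot()` can only be imported while the
exporting transaction is still open, so the leader commits only after every other worker has
committed (and has therefore already imported the snapshot). Here std threads stand in for
timely workers, and dropping an `mpsc::Sender` stands in for dropping a capability;
`simulate_commit_order` is a hypothetical name.

```rust
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical simulation: returns the order in which the transactions committed.
pub fn simulate_commit_order(followers: usize) -> Vec<String> {
    let log = Arc::new(Mutex::new(Vec::new()));
    let (tx, rx) = mpsc::channel::<()>();
    let mut handles = Vec::new();
    for i in 0..followers {
        let tx = tx.clone();
        let log = Arc::clone(&log);
        handles.push(thread::spawn(move || {
            // A follower imports the snapshot, copies its share, then commits...
            log.lock().unwrap().push(format!("follower-{i} COMMIT"));
            // ...and only afterwards releases its handle, like dropping a capability.
            drop(tx);
        }));
    }
    // The leader keeps no sender, so `recv` errors once every follower has dropped theirs.
    drop(tx);
    while rx.recv().is_ok() {}
    // Only now may the leader end the exporting transaction.
    log.lock().unwrap().push("leader COMMIT".to_string());
    for handle in handles {
        handle.join().unwrap();
    }
    let order = log.lock().unwrap().clone();
    order
}
```

Because the wait loop only exits once all senders are gone, "leader COMMIT" is always the last
entry, regardless of how the follower threads are scheduled.
*/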

async fn export_snapshot_inner(
    client: &Client,
    slot: &str,
    temporary: bool,
) -> Result<(String, MzOffset), TransientError> {
    client
        .simple_query("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;")
        .await?;

    // Using `new_unchecked` (which skips identifier validation) is fine here: the name is only
    // rendered, with quoting as needed, into the query text.
    let slot = Ident::new_unchecked(slot).to_ast_string_simple();
    let temporary_str = if temporary { " TEMPORARY" } else { "" };
    let query =
        format!("CREATE_REPLICATION_SLOT {slot}{temporary_str} LOGICAL \"pgoutput\" USE_SNAPSHOT");
    let row = match simple_query_opt(client, &query).await {
        Ok(row) => Ok(row.unwrap()),
        Err(PostgresError::Postgres(err)) if err.code() == Some(&SqlState::DUPLICATE_OBJECT) => {
            return Err(TransientError::ReplicationSlotAlreadyExists);
        }
        Err(err) => Err(err),
    }?;

    // When creating a replication slot postgres returns the LSN of its consistent point, which is
    // the LSN that must be passed to `START_REPLICATION` to cleanly transition from the snapshot
    // phase to the replication phase. `START_REPLICATION` includes all transactions that commit at
    // LSNs *greater than or equal* to the passed LSN. Therefore the snapshot phase must happen at
    // the greatest LSN strictly below the consistent point, which is `consistent_point - 1`.
    let consistent_point: PgLsn = row.get("consistent_point").unwrap().parse().unwrap();
    let consistent_point = u64::from(consistent_point)
        .checked_sub(1)
        .expect("consistent point is always non-zero");

    let row = simple_query_opt(client, "SELECT pg_export_snapshot();")
        .await?
        .unwrap();
    let snapshot = row.get("pg_export_snapshot").unwrap().to_owned();

    Ok((snapshot, MzOffset::from(consistent_point)))
}
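
/*
Worked example for the `consistent_point - 1` arithmetic above. PostgreSQL prints LSNs as two
hexadecimal halves separated by a slash (`X/Y`), denoting the 64-bit value `(X << 32) | Y`. The
helpers below are hypothetical and only approximate what parsing into `PgLsn` does; they are a
sketch, not this module's code.

```rust
/// Hypothetical: parses the textual `X/Y` LSN form into a 64-bit position.
pub fn parse_lsn(s: &str) -> Option<u64> {
    let (hi, lo) = s.split_once('/')?;
    let hi = u32::from_str_radix(hi, 16).ok()?;
    let lo = u32::from_str_radix(lo, 16).ok()?;
    Some((u64::from(hi) << 32) | u64::from(lo))
}

/// Hypothetical: the LSN the snapshot is timestamped at, one below the consistent point.
/// Returns `None` for a zero consistent point, which the code above treats as impossible.
pub fn snapshot_lsn(consistent_point: &str) -> Option<u64> {
    parse_lsn(consistent_point)?.checked_sub(1)
}
```

For a consistent point of `0/1000` the snapshot is timestamped at `0xFFF`, and every
transaction committing at `0x1000` or later reaches us through `START_REPLICATION` instead.
*/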

/// Starts a read-only transaction on the SQL session of `client` at the consistent LSN point of
/// `snapshot`.
async fn use_snapshot(client: &Client, snapshot: &str) -> Result<(), TransientError> {
    client
        .simple_query("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;")
        .await?;
    let query = format!(
        "SET TRANSACTION SNAPSHOT {};",
        escaped_string_literal(snapshot)
    );
    client.simple_query(&query).await?;
    Ok(())
}
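
/*
Sketch of the statement `use_snapshot` sends. It assumes `escaped_string_literal` produces a
standard SQL string literal (single quotes, embedded quotes doubled), which is valid under
PostgreSQL's default `standard_conforming_strings = on`. `quote_literal` and
`set_snapshot_sql` are hypothetical names, and the real helper may differ in detail.

```rust
/// Hypothetical stand-in for `escaped_string_literal`.
pub fn quote_literal(s: &str) -> String {
    // Wrap in single quotes and double any embedded single quote.
    format!("'{}'", s.replace('\'', "''"))
}

/// Builds the statement that imports an exported snapshot into the current transaction.
pub fn set_snapshot_sql(snapshot_id: &str) -> String {
    format!("SET TRANSACTION SNAPSHOT {};", quote_literal(snapshot_id))
}
```

Snapshot identifiers such as `00000003-0000001B-1` (the format shown in the PostgreSQL docs)
need no escaping, but quoting defensively keeps the statement well-formed for any input.
*/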

async fn set_statement_timeout(client: &Client, timeout: Duration) -> Result<(), TransientError> {
    // A unitless value is interpreted as milliseconds.
    // https://www.postgresql.org/docs/current/runtime-config-client.html
    client
        .simple_query(&format!("SET statement_timeout = {}", timeout.as_millis()))
        .await?;
    Ok(())
}
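
/*
Sketch of the statement `set_statement_timeout` builds, using a hypothetical helper name.
Because `Duration::as_millis` truncates, any timeout below one millisecond becomes `0`, which
PostgreSQL interprets as "no timeout"; whether that edge case matters depends on how the
`pg_source_snapshot_statement_timeout` parameter is configured.

```rust
use std::time::Duration;

/// Hypothetical mirror of the query text built by `set_statement_timeout`.
pub fn statement_timeout_sql(timeout: Duration) -> String {
    // A unitless `statement_timeout` is read as milliseconds.
    format!("SET statement_timeout = {}", timeout.as_millis())
}
```
*/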

/// Decodes a row of `col_len` columns obtained from a text-encoded COPY query into `row`.
fn decode_copy_row(data: &[u8], col_len: usize, row: &mut Row) -> Result<(), DefiniteError> {
    let mut packer = row.packer();
    let row_parser = mz_pgcopy::CopyTextFormatParser::new(data, b'\t', "\\N");
    let mut column_iter = row_parser.iter_raw_truncating(col_len);
    for _ in 0..col_len {
        let value = match column_iter.next() {
            Some(Ok(value)) => value,
            Some(Err(_)) => return Err(DefiniteError::InvalidCopyInput),
            None => return Err(DefiniteError::MissingColumn),
        };
        let datum = value.map(super::decode_utf8_text).transpose()?;
        packer.push(datum.unwrap_or(Datum::Null));
    }
    Ok(())
}
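
/*
Sketch, not the `mz_pgcopy` parser used above: a simplified decoder for one row of PostgreSQL's
COPY TEXT format, to show what `decode_copy_row` consumes. Columns are separated by tabs, `\N`
marks NULL, and backslash escapes such as `\t` and `\n` encode special characters, so a raw tab
can only ever be a delimiter. Octal (`\123`) and hex (`\x41`) escapes are deliberately left
out, and `decode_text_row` is a hypothetical name.

```rust
/// Hypothetical, simplified COPY TEXT row decoder: returns `col_len` values, `None` for NULL.
pub fn decode_text_row(line: &str, col_len: usize) -> Result<Vec<Option<String>>, String> {
    let line = line.strip_suffix('\n').unwrap_or(line);
    let mut fields = line.split('\t');
    let mut out = Vec::with_capacity(col_len);
    for i in 0..col_len {
        let raw = fields.next().ok_or_else(|| format!("missing column {i}"))?;
        if raw == "\\N" {
            out.push(None);
            continue;
        }
        let mut value = String::with_capacity(raw.len());
        let mut chars = raw.chars();
        while let Some(c) = chars.next() {
            if c != '\\' {
                value.push(c);
                continue;
            }
            // Undo the single-character escapes; any other escaped char stands for itself.
            let unescaped = match chars.next() {
                Some('b') => '\u{8}',
                Some('f') => '\u{c}',
                Some('n') => '\n',
                Some('r') => '\r',
                Some('t') => '\t',
                Some('v') => '\u{b}',
                Some(other) => other,
                None => return Err("trailing backslash".to_string()),
            };
            value.push(unescaped);
        }
        out.push(Some(value));
    }
    // Extra trailing columns are ignored here (an assumption about the truncating iterator).
    Ok(out)
}
```
*/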

/// Records the sizes of the tables being snapshotted in `PgSnapshotMetrics` and emits snapshot
/// statistics for each export.
async fn report_snapshot_size(
    client: &Client,
    tables_to_snapshot: &BTreeMap<u32, BTreeMap<usize, SourceOutputInfo>>,
    metrics: PgSnapshotMetrics,
    config: &RawSourceCreationConfig,
    export_statistics: &BTreeMap<(u32, usize), SourceStatistics>,
) -> Result<(), anyhow::Error> {
    // TODO(guswynn): delete unused configs
    let snapshot_config = config.config.parameters.pg_snapshot_config;

    for (&oid, outputs) in tables_to_snapshot {
        // Use the first output's desc to make the table name, since it is the same for all outputs.
        let Some((_, info)) = outputs.first_key_value() else {
            continue;
        };
        let table = format!(
            "{}.{}",
            Ident::new_unchecked(info.desc.namespace.clone()).to_ast_string_simple(),
            Ident::new_unchecked(info.desc.name.clone()).to_ast_string_simple()
        );
        let stats = collect_table_statistics(
            client,
            snapshot_config,
            &info.desc.namespace,
            &info.desc.name,
            info.desc.oid,
        )
        .await?;
        metrics.record_table_count_latency(table, stats.count_latency);
        for &output_index in outputs.keys() {
            export_statistics[&(oid, output_index)].set_snapshot_records_known(stats.count);
            export_statistics[&(oid, output_index)].set_snapshot_records_staged(0);
        }
    }
    Ok(())
}
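
/*
Sketch of the qualified table name built above, with a hypothetical always-quote helper in place
of `Ident::to_ast_string_simple`, which, per the comment in the snapshot loop, follows
PostgreSQL's own rules for when quoting is needed. Double-quoting every identifier is always
valid in PostgreSQL and preserves the exact case of names read from the catalog; `quote_ident`
and `qualified_table` are illustrative names.

```rust
/// Hypothetical: renders `name` as a double-quoted identifier, doubling embedded quotes.
pub fn quote_ident(name: &str) -> String {
    format!("\"{}\"", name.replace('"', "\"\""))
}

/// Hypothetical: renders a schema-qualified table name such as `"public"."Users"`.
pub fn qualified_table(namespace: &str, table: &str) -> String {
    format!("{}.{}", quote_ident(namespace), quote_ident(table))
}
```

Quoting is also what keeps a column named after a keyword, such as `select`, from breaking the
generated `COPY (SELECT ...)` query.
*/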

#[derive(Default)]
struct TableStatistics {
    count: u64,
    count_latency: f64,
}

async fn collect_table_statistics(
    client: &Client,
    config: PgSourceSnapshotConfig,
    namespace: &str,
    table_name: &str,
    oid: u32,
) -> Result<TableStatistics, anyhow::Error> {
    use mz_ore::metrics::MetricsFutureExt;
    let mut stats = TableStatistics::default();
    let table = format!(
        "{}.{}",
        Ident::new_unchecked(namespace).to_ast_string_simple(),
        Ident::new_unchecked(table_name).to_ast_string_simple()
    );

    let estimate_row = simple_query_opt(
        client,
        &format!("SELECT reltuples::bigint AS estimate_count FROM pg_class WHERE oid = '{oid}'"),
    )
    .wall_time()
    .set_at(&mut stats.count_latency)
    .await?;
    stats.count = match estimate_row {
        Some(row) => row.get("estimate_count").unwrap().parse().unwrap_or(0),
        None => bail!("failed to get estimate count for {table}"),
    };

    // If the estimate is low enough we can attempt to get an exact count. Note that tables that
    // have not been vacuumed or analyzed yet report zero rows here (PostgreSQL 14+ reports -1,
    // which fails to parse as `u64` and is also treated as zero), and they may well be very
    // large. We accept this risk and offer the feature flag as an escape hatch if it becomes
    // problematic.
    if config.collect_strict_count && stats.count < 1_000_000 {
        let count_row = simple_query_opt(client, &format!("SELECT count(*) as count from {table}"))
            .wall_time()
            .set_at(&mut stats.count_latency)
            .await?;
        stats.count = match count_row {
            Some(row) => row.get("count").unwrap().parse().unwrap(),
            None => bail!("failed to get count for {table}"),
        }
    }

    Ok(stats)
}
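
/*
Sketch of the counting strategy in `collect_table_statistics`, pulled out as hypothetical pure
helpers. The estimate comes from `pg_class.reltuples`; an exact `count(*)` is only attempted
when strict counting is enabled and the estimate is below one million rows. On PostgreSQL 14+,
`reltuples` is -1 for tables that were never vacuumed or analyzed, which the parser below maps
to zero, matching the `unwrap_or(0)` above.

```rust
/// Hypothetical: parses the `reltuples::bigint` text, treating -1 ("unknown") and garbage as 0.
pub fn parse_estimate(s: &str) -> u64 {
    s.parse::<i64>()
        .ok()
        .and_then(|n| u64::try_from(n).ok())
        .unwrap_or(0)
}

/// Hypothetical: mirrors the condition guarding the exact `count(*)` query.
pub fn use_exact_count(collect_strict_count: bool, estimate: u64) -> bool {
    collect_strict_count && estimate < 1_000_000
}
```

An unknown estimate therefore always qualifies for the exact count when the flag is on, which
is the risk the comment above accepts.
*/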

/// Validates that there are no blocking RLS policies on the tables and retrieves the table
/// schemas for the given publication.
async fn retrieve_schema_info(
    connection_config: &Config,
    connection_context: &ConnectionContext,
    publication: &str,
    table_oids: &[Oid],
) -> Result<BTreeMap<u32, PostgresTableDesc>, PostgresError> {
    let schema_client = connection_config
        .connect(
            "snapshot schema info",
            &connection_context.ssh_tunnel_manager,
        )
        .await?;
    mz_postgres_util::validate_no_rls_policies(&schema_client, table_oids).await?;
    mz_postgres_util::publication_info(&schema_client, publication, Some(table_oids)).await
}