mz_storage/source/postgres/
replication.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Renders the logical replication side of the [`PostgresSourceConnection`] ingestion dataflow.
11//!
12//! ```text
13//!              o
14//!              │rewind
15//!              │requests
16//!          ╭───┴────╮
17//!          │exchange│ (collect all requests to one worker)
18//!          ╰───┬────╯
19//!           ┏━━v━━━━━━━━━━┓
20//!           ┃ replication ┃ (single worker)
21//!           ┃   reader    ┃
22//!           ┗━┯━━━━━━━━┯━━┛
23//!             │raw     │
24//!             │data    │
25//!        ╭────┴─────╮  │
26//!        │distribute│  │ (distribute to all workers)
27//!        ╰────┬─────╯  │
28//! ┏━━━━━━━━━━━┷━┓      │
29//! ┃ replication ┃      │ (parallel decode)
30//! ┃   decoder   ┃      │
31//! ┗━━━━━┯━━━━━━━┛      │
32//!       │ replication  │ progress
33//!       │ updates      │ output
34//!       v              v
35//! ```
36//!
37//! # Progress tracking
38//!
39//! In order to avoid causing excessive resource usage in the upstream server it's important to
40//! track the LSN that we have successfully committed to persist and communicate that back to
41//! PostgreSQL. Under normal operation this gauge of progress is provided by the presence of
42//! transactions themselves. Since at a given LSN offset there can be only a single message, when a
43//! transaction is received and processed we can infer that we have seen all the messages that are
44//! not beyond `commit_lsn + 1`.
45//!
//! Things are a bit more complicated in the absence of transactions though, because even though
//! we don't receive any, the server might very well be generating WAL records. This can happen if
//! there is a separate logical database performing writes (which is the case for RDS databases),
//! or because, in servers running PostgreSQL version 15 or greater, the logical replication
//! process includes an optimization that omits empty transactions, which can happen if you're
//! only replicating a subset of the tables and there are writes going to the other ones.
52//!
53//! If we fail to detect this situation and don't send LSN feedback in a timely manner the server
54//! will be forced to keep around WAL data that can eventually lead to disk space exhaustion.
55//!
56//! In the absence of transactions the only available piece of information in the replication
57//! stream are keepalive messages. Keepalive messages are documented[1] to contain the current end
58//! of WAL on the server. That is a useless number when it comes to progress tracking because there
59//! might be pending messages at LSNs between the last received commit_lsn and the current end of
60//! WAL.
61//!
62//! Fortunately for us, the documentation for PrimaryKeepalive messages is wrong and it actually
63//! contains the last *sent* LSN[2]. Here sent doesn't necessarily mean sent over the wire, but
64//! sent to the upstream process that is handling producing the logical stream. Therefore, if we
65//! receive a keepalive with a particular LSN we can be certain that there are no other replication
66//! messages at previous LSNs, because they would have been already generated and received. We
67//! therefore connect the keepalive messages directly to our capability.
68//!
69//! [1]: https://www.postgresql.org/docs/15/protocol-replication.html#PROTOCOL-REPLICATION-START-REPLICATION
70//! [2]: https://www.postgresql.org/message-id/CAFPTHDZS9O9WG02EfayBd6oONzK%2BqfUxS6AbVLJ7W%2BKECza2gg%40mail.gmail.com
71
72use std::collections::BTreeMap;
73use std::convert::Infallible;
74use std::pin::pin;
75use std::rc::Rc;
76use std::str::FromStr;
77use std::sync::Arc;
78use std::sync::LazyLock;
79use std::time::Instant;
80use std::time::{Duration, SystemTime, UNIX_EPOCH};
81
82use differential_dataflow::AsCollection;
83use futures::{FutureExt, Stream as AsyncStream, StreamExt, TryStreamExt};
84use mz_ore::cast::CastFrom;
85use mz_ore::collections::HashSet;
86use mz_ore::future::InTask;
87use mz_ore::iter::IteratorExt;
88use mz_postgres_util::PostgresError;
89use mz_postgres_util::tunnel::PostgresFlavor;
90use mz_postgres_util::{Client, simple_query_opt};
91use mz_repr::{Datum, DatumVec, Diff, Row};
92use mz_sql_parser::ast::{Ident, display::AstDisplay};
93use mz_storage_types::dyncfgs::{PG_OFFSET_KNOWN_INTERVAL, PG_SCHEMA_VALIDATION_INTERVAL};
94use mz_storage_types::errors::DataflowError;
95use mz_storage_types::sources::SourceTimestamp;
96use mz_storage_types::sources::{MzOffset, PostgresSourceConnection};
97use mz_timely_util::builder_async::{
98    AsyncOutputHandle, Event as AsyncEvent, OperatorBuilder as AsyncOperatorBuilder,
99    PressOnDropButton,
100};
101use postgres_replication::LogicalReplicationStream;
102use postgres_replication::protocol::{LogicalReplicationMessage, ReplicationMessage, TupleData};
103use serde::{Deserialize, Serialize};
104use timely::container::CapacityContainerBuilder;
105use timely::dataflow::channels::pact::{Exchange, Pipeline};
106use timely::dataflow::channels::pushers::Tee;
107use timely::dataflow::operators::Capability;
108use timely::dataflow::operators::Concat;
109use timely::dataflow::operators::Operator;
110use timely::dataflow::operators::core::Map;
111use timely::dataflow::{Scope, Stream};
112use timely::progress::Antichain;
113use tokio::sync::{mpsc, watch};
114use tokio_postgres::error::SqlState;
115use tokio_postgres::types::PgLsn;
116use tracing::{error, trace};
117
118use crate::metrics::source::postgres::PgSourceMetrics;
119use crate::source::RawSourceCreationConfig;
120use crate::source::postgres::verify_schema;
121use crate::source::postgres::{DefiniteError, ReplicationError, SourceOutputInfo, TransientError};
122use crate::source::probe;
123use crate::source::types::{
124    Probe, ProgressStatisticsUpdate, SignaledFuture, SourceMessage, StackedCollection,
125};
126
/// The PostgreSQL epoch (2000-01-01T00:00:00Z) expressed as a [`SystemTime`].
///
/// The value is computed lazily on first access.
static PG_EPOCH: LazyLock<SystemTime> = LazyLock::new(|| {
    // Number of seconds between the Unix epoch (1970-01-01T00:00:00Z) and the Postgres epoch.
    let unix_to_pg_epoch = Duration::from_secs(946_684_800);
    UNIX_EPOCH + unix_to_pg_epoch
});
130
131// A request to rewind a snapshot taken at `snapshot_lsn` to the initial LSN of the replication
132// slot. This is accomplished by emitting `(data, 0, -diff)` for all updates `(data, lsn, diff)`
133// whose `lsn <= snapshot_lsn`. By convention the snapshot is always emitted at LSN 0.
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub(crate) struct RewindRequest {
136    /// The output index that should be rewound.
137    pub(crate) output_index: usize,
138    /// The LSN that the snapshot was taken at.
139    pub(crate) snapshot_lsn: MzOffset,
140}
141
142/// Renders the replication dataflow. See the module documentation for more information.
143pub(crate) fn render<G: Scope<Timestamp = MzOffset>>(
144    scope: G,
145    config: RawSourceCreationConfig,
146    connection: PostgresSourceConnection,
147    table_info: BTreeMap<u32, BTreeMap<usize, SourceOutputInfo>>,
148    rewind_stream: &Stream<G, RewindRequest>,
149    slot_ready_stream: &Stream<G, Infallible>,
150    committed_uppers: impl futures::Stream<Item = Antichain<MzOffset>> + 'static,
151    metrics: PgSourceMetrics,
152) -> (
153    StackedCollection<G, (usize, Result<SourceMessage, DataflowError>)>,
154    Stream<G, Infallible>,
155    Stream<G, ProgressStatisticsUpdate>,
156    Option<Stream<G, Probe<MzOffset>>>,
157    Stream<G, ReplicationError>,
158    PressOnDropButton,
159) {
160    let op_name = format!("ReplicationReader({})", config.id);
161    let mut builder = AsyncOperatorBuilder::new(op_name, scope.clone());
162
163    let slot_reader = u64::cast_from(config.responsible_worker("slot"));
164    let (data_output, data_stream) = builder.new_output();
165    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
166    let (definite_error_handle, definite_errors) = builder.new_output();
167
168    let (stats_output, stats_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
169    let (probe_output, probe_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
170
171    // Yugabyte doesn't support LSN probing currently.
172    let probe_stream = match connection.connection.flavor {
173        PostgresFlavor::Vanilla => Some(probe_stream),
174        PostgresFlavor::Yugabyte => None,
175    };
176
177    let mut rewind_input =
178        builder.new_disconnected_input(rewind_stream, Exchange::new(move |_| slot_reader));
179    let mut slot_ready_input = builder.new_disconnected_input(slot_ready_stream, Pipeline);
180    let mut output_uppers = table_info
181        .iter()
182        .flat_map(|(_, outputs)| outputs.values().map(|o| o.resume_upper.clone()))
183        .collect::<Vec<_>>();
184    metrics.tables.set(u64::cast_from(output_uppers.len()));
185
186    // Include the upper of the main source output for use in calculating the initial
187    // resume upper.
188    output_uppers.push(Antichain::from_iter(
189        config
190            .source_resume_uppers
191            .get(&config.id)
192            .expect("id exists")
193            .iter()
194            .map(MzOffset::decode_row),
195    ));
196
197    let reader_table_info = table_info.clone();
198    let (button, transient_errors) = builder.build_fallible(move |caps| {
199        let table_info = reader_table_info;
200        let busy_signal = Arc::clone(&config.busy_signal);
201        Box::pin(SignaledFuture::new(busy_signal, async move {
202            let (id, worker_id) = (config.id, config.worker_id);
203            let [
204                data_cap_set,
205                upper_cap_set,
206                definite_error_cap_set,
207                stats_cap,
208                probe_cap,
209            ]: &mut [_; 5] = caps.try_into().unwrap();
210
211            if !config.responsible_for("slot") {
212                // Emit 0, to mark this worker as having started up correctly.
213                stats_output.give(
214                    &stats_cap[0],
215                    ProgressStatisticsUpdate::SteadyState {
216                        offset_known: 0,
217                        offset_committed: 0,
218                    },
219                );
220                return Ok(());
221            }
222
223            // Determine the slot lsn.
224            let connection_config = connection
225                .connection
226                .config(
227                    &config.config.connection_context.secrets_reader,
228                    &config.config,
229                    InTask::Yes,
230                )
231                .await?;
232
233            let slot = &connection.publication_details.slot;
234            let replication_client = connection_config
235                .connect_replication(&config.config.connection_context.ssh_tunnel_manager)
236                .await?;
237
238            let metadata_client = connection_config
239                .connect(
240                    "replication metadata",
241                    &config.config.connection_context.ssh_tunnel_manager,
242                )
243                .await?;
244            let metadata_client = Arc::new(metadata_client);
245
246            while let Some(_) = slot_ready_input.next().await {
247                // Wait for the slot to be created
248            }
249            tracing::info!(%id, "ensuring replication slot {slot} exists");
250            super::ensure_replication_slot(&replication_client, slot).await?;
251            let slot_metadata = super::fetch_slot_metadata(
252                &*metadata_client,
253                slot,
254                mz_storage_types::dyncfgs::PG_FETCH_SLOT_RESUME_LSN_INTERVAL
255                    .get(config.config.config_set()),
256            )
257            .await?;
258
259            // We're the only application that should be using this replication
260            // slot. The only way that there can be another connection using
261            // this slot under normal operation is if there's a stale TCP
262            // connection from a prior incarnation of the source holding on to
263            // the slot. We don't want to wait for the WAL sender timeout and/or
264            // TCP keepalives to time out that connection, because these values
265            // are generally under the control of the DBA and may not time out
266            // the connection for multiple minutes, or at all. Instead we just
267            // force kill the connection that's using the slot.
268            //
269            // Note that there's a small risk that *we're* the zombie cluster
270            // that should not be using the replication slot. Kubernetes cannot
271            // 100% guarantee that only one cluster is alive at a time. However,
272            // this situation should not last long, and the worst that can
273            // happen is a bit of transient thrashing over ownership of the
274            // replication slot.
275            if let Some(active_pid) = slot_metadata.active_pid {
276                tracing::warn!(
277                    %id, %active_pid,
278                    "replication slot already in use; will attempt to kill existing connection",
279                );
280
281                match metadata_client
282                    .execute("SELECT pg_terminate_backend($1)", &[&active_pid])
283                    .await
284                {
285                    Ok(_) => {
286                        tracing::info!(
287                            "successfully killed existing connection; \
288                            starting replication is likely to succeed"
289                        );
290                        // Note that `pg_terminate_backend` does not wait for
291                        // the termination of the targeted connection to
292                        // complete. We may try to start replication before the
293                        // targeted connection has cleaned up its state. That's
294                        // okay. If that happens we'll just try again from the
295                        // top via the suspend-and-restart flow.
296                    }
297                    Err(e) => {
298                        tracing::warn!(
299                            %e,
300                            "failed to kill existing replication connection; \
301                            replication will likely fail to start"
302                        );
303                        // Continue on anyway, just in case the replication slot
304                        // is actually available. Maybe PostgreSQL has some
305                        // staleness when it reports `active_pid`, for example.
306                    }
307                }
308            }
309
310            // The overall resumption point for this source is the minimum of the resumption points
311            // contributed by each of the outputs.
312            let resume_lsn = output_uppers
313                .iter()
314                .flat_map(|f| f.elements())
315                .map(|&lsn| {
316                    // An output is either an output that has never had data committed to it or one
317                    // that has and needs to resume. We differentiate between the two by checking
318                    // whether an output wishes to "resume" from the minimum timestamp. In that case
319                    // its contribution to the overal resumption point is the earliest point available
320                    // in the slot. This information would normally be something that the storage
321                    // controller figures out in the form of an as-of frontier, but at the moment the
322                    // storage controller does not have visibility into what the replication slot is
323                    // doing.
324                    if lsn == MzOffset::from(0) {
325                        slot_metadata.confirmed_flush_lsn
326                    } else {
327                        lsn
328                    }
329                })
330                .min();
331            let Some(resume_lsn) = resume_lsn else {
332                return Ok(());
333            };
334            upper_cap_set.downgrade([&resume_lsn]);
335            trace!(%id, "timely-{worker_id} replication reader started lsn={resume_lsn}");
336
337            // Emitting an initial probe before we start waiting for rewinds ensures that we will
338            // have a timestamp binding in the remap collection while the snapshot is processed.
339            // This is important because otherwise the snapshot updates would need to be buffered
340            // in the reclock operator, instead of being spilled to S3 in the persist sink.
341            //
342            // Note that we need to fetch the probe LSN _after_ having created the replication
343            // slot, to make sure the fetched LSN will be included in the replication stream.
344            let probe_ts = (config.now_fn)().into();
345            let max_lsn = super::fetch_max_lsn(&*metadata_client).await?;
346            let probe = Probe {
347                probe_ts,
348                upstream_frontier: Antichain::from_elem(max_lsn),
349            };
350            probe_output.give(&probe_cap[0], probe);
351
352            let mut rewinds = BTreeMap::new();
353            while let Some(event) = rewind_input.next().await {
354                if let AsyncEvent::Data(_, data) = event {
355                    for req in data {
356                        if resume_lsn > req.snapshot_lsn + 1 {
357                            let err = DefiniteError::SlotCompactedPastResumePoint(
358                                req.snapshot_lsn + 1,
359                                resume_lsn,
360                            );
361                            // If the replication stream cannot be obtained from the resume point there is nothing
362                            // else to do. These errors are not retractable.
363                            for (oid, outputs) in table_info.iter() {
364                                for output_index in outputs.keys() {
365                                    // We pick `u64::MAX` as the LSN which will (in practice) never conflict
366                                    // any previously revealed portions of the TVC.
367                                    let update = (
368                                        (
369                                            *oid,
370                                            *output_index,
371                                            Err(DataflowError::from(err.clone())),
372                                        ),
373                                        MzOffset::from(u64::MAX),
374                                        Diff::ONE,
375                                    );
376                                    data_output.give_fueled(&data_cap_set[0], update).await;
377                                }
378                            }
379                            definite_error_handle.give(
380                                &definite_error_cap_set[0],
381                                ReplicationError::Definite(Rc::new(err)),
382                            );
383                            return Ok(());
384                        }
385                        rewinds.insert(req.output_index, req);
386                    }
387                }
388            }
389            trace!(%id, "timely-{worker_id} pending rewinds {rewinds:?}");
390
391            let mut committed_uppers = pin!(committed_uppers);
392
393            let stream_result = raw_stream(
394                &config,
395                replication_client,
396                Arc::clone(&metadata_client),
397                &connection.publication_details.slot,
398                &connection.publication_details.timeline_id,
399                &connection.publication,
400                resume_lsn,
401                committed_uppers.as_mut(),
402                &stats_output,
403                &stats_cap[0],
404                &probe_output,
405                &probe_cap[0],
406            )
407            .await?;
408
409            let stream = match stream_result {
410                Ok(stream) => stream,
411                Err(err) => {
412                    // If the replication stream cannot be obtained in a definite way there is
413                    // nothing else to do. These errors are not retractable.
414                    for (oid, outputs) in table_info.iter() {
415                        for output_index in outputs.keys() {
416                            // We pick `u64::MAX` as the LSN which will (in practice) never conflict
417                            // any previously revealed portions of the TVC.
418                            let update = (
419                                (*oid, *output_index, Err(DataflowError::from(err.clone()))),
420                                MzOffset::from(u64::MAX),
421                                Diff::ONE,
422                            );
423                            data_output.give_fueled(&data_cap_set[0], update).await;
424                        }
425                    }
426
427                    definite_error_handle.give(
428                        &definite_error_cap_set[0],
429                        ReplicationError::Definite(Rc::new(err)),
430                    );
431                    return Ok(());
432                }
433            };
434            let mut stream = pin!(stream.peekable());
435
436            // Run the periodic schema validation on a separate task using a separate client,
437            // to prevent it from blocking the replication reading progress.
438            let ssh_tunnel_manager = &config.config.connection_context.ssh_tunnel_manager;
439            let client = connection_config
440                .connect("schema validation", ssh_tunnel_manager)
441                .await?;
442            let mut schema_errors = spawn_schema_validator(
443                client,
444                &config,
445                connection.publication.clone(),
446                table_info.clone(),
447            );
448
449            let mut errored = HashSet::new();
450            // Instead of downgrading the capability for every transaction we process we only do it
451            // if we're about to yield, which is checked at the bottom of the loop. This avoids
452            // creating excessive progress tracking traffic when there are multiple small
453            // transactions ready to go.
454            let mut data_upper = resume_lsn;
455            // A stash of reusable vectors to convert from bytes::Bytes based data, which is not
456            // compatible with `columnation`, to Vec<u8> data that is.
457            while let Some(event) = stream.as_mut().next().await {
458                use LogicalReplicationMessage::*;
459                use ReplicationMessage::*;
460                match event {
461                    Ok(XLogData(data)) => match data.data() {
462                        Begin(begin) => {
463                            let commit_lsn = MzOffset::from(begin.final_lsn());
464
465                            let mut tx = pin!(extract_transaction(
466                                stream.by_ref(),
467                                &*metadata_client,
468                                commit_lsn,
469                                &table_info,
470                                &metrics,
471                                &connection.publication,
472                                &mut errored
473                            ));
474
475                            trace!(
476                                %id,
477                                "timely-{worker_id} extracting transaction \
478                                    at {commit_lsn}"
479                            );
480                            assert!(
481                                data_upper <= commit_lsn,
482                                "new_upper={data_upper} tx_lsn={commit_lsn}",
483                            );
484                            data_upper = commit_lsn + 1;
485                            // We are about to ingest a transaction which has the possiblity to be
486                            // very big and we certainly don't want to hold the data in memory. For
487                            // this reason we eagerly downgrade the upper capability in order for
488                            // the reclocking machinery to mint a binding that includes
489                            // this transaction and therefore be able to pass the data of the
490                            // transaction through as we stream it.
491                            upper_cap_set.downgrade([&data_upper]);
492                            while let Some((oid, output_index, event, diff)) = tx.try_next().await?
493                            {
494                                let event = event.map_err(Into::into);
495                                let mut data = (oid, output_index, event);
496                                if let Some(req) = rewinds.get(&output_index) {
497                                    if commit_lsn <= req.snapshot_lsn {
498                                        let update = (data, MzOffset::from(0), -diff);
499                                        data_output.give_fueled(&data_cap_set[0], &update).await;
500                                        data = update.0;
501                                    }
502                                }
503                                let update = (data, commit_lsn, diff);
504                                data_output.give_fueled(&data_cap_set[0], &update).await;
505                            }
506                        }
507                        _ => return Err(TransientError::BareTransactionEvent),
508                    },
509                    Ok(PrimaryKeepAlive(keepalive)) => {
510                        trace!( %id,
511                            "timely-{worker_id} received keepalive lsn={}",
512                            keepalive.wal_end()
513                        );
514
515                        // Take the opportunity to report any schema validation errors.
516                        while let Ok(error) = schema_errors.try_recv() {
517                            use SchemaValidationError::*;
518                            match error {
519                                Postgres(PostgresError::PublicationMissing(publication)) => {
520                                    let err = DefiniteError::PublicationDropped(publication);
521                                    for (oid, outputs) in table_info.iter() {
522                                        for output_index in outputs.keys() {
523                                            let update = (
524                                                (
525                                                    *oid,
526                                                    *output_index,
527                                                    Err(DataflowError::from(err.clone())),
528                                                ),
529                                                data_cap_set[0].time().clone(),
530                                                Diff::ONE,
531                                            );
532                                            data_output.give_fueled(&data_cap_set[0], update).await;
533                                        }
534                                    }
535                                    definite_error_handle.give(
536                                        &definite_error_cap_set[0],
537                                        ReplicationError::Definite(Rc::new(err)),
538                                    );
539                                    return Ok(());
540                                }
541                                Postgres(pg_error) => Err(TransientError::from(pg_error))?,
542                                Schema {
543                                    oid,
544                                    output_index,
545                                    error,
546                                } => {
547                                    if errored.contains(&output_index) {
548                                        continue;
549                                    }
550
551                                    let update = (
552                                        (oid, output_index, Err(error.into())),
553                                        data_cap_set[0].time().clone(),
554                                        Diff::ONE,
555                                    );
556                                    data_output.give_fueled(&data_cap_set[0], update).await;
557                                    errored.insert(output_index);
558                                }
559                            }
560                        }
561                        data_upper = std::cmp::max(data_upper, keepalive.wal_end().into());
562                    }
563                    Ok(_) => return Err(TransientError::UnknownReplicationMessage),
564                    Err(err) => return Err(err),
565                }
566
567                let will_yield = stream.as_mut().peek().now_or_never().is_none();
568                if will_yield {
569                    trace!(%id, "timely-{worker_id} yielding at lsn={data_upper}");
570                    rewinds.retain(|_, req| data_upper <= req.snapshot_lsn);
571                    // As long as there are pending rewinds we can't downgrade our data capability
572                    // since we must be able to produce data at offset 0.
573                    if rewinds.is_empty() {
574                        data_cap_set.downgrade([&data_upper]);
575                    }
576                    upper_cap_set.downgrade([&data_upper]);
577                }
578            }
579            // We never expect the replication stream to gracefully end
580            Err(TransientError::ReplicationEOF)
581        }))
582    });
583
584    // We now process the slot updates and apply the cast expressions
585    let mut final_row = Row::default();
586    let mut datum_vec = DatumVec::new();
587    let mut next_worker = (0..u64::cast_from(scope.peers()))
588        // Round robin on 1000-records basis to avoid creating tiny containers when there are a
589        // small number of updates and a large number of workers.
590        .flat_map(|w| std::iter::repeat_n(w, 1000))
591        .cycle();
592    let round_robin = Exchange::new(move |_| next_worker.next().unwrap());
593    let replication_updates = data_stream
594        .map::<Vec<_>, _, _>(Clone::clone)
595        .unary(round_robin, "PgCastReplicationRows", |_, _| {
596            move |input, output| {
597                while let Some((time, data)) = input.next() {
598                    let mut session = output.session(&time);
599                    for ((oid, output_index, event), time, diff) in data.drain(..) {
600                        let output = &table_info
601                            .get(&oid)
602                            .and_then(|outputs| outputs.get(&output_index))
603                            .expect("table_info contains all outputs");
604                        let event = event.and_then(|row| {
605                            let datums = datum_vec.borrow_with(&row);
606                            super::cast_row(&output.casts, &datums, &mut final_row)?;
607                            Ok(SourceMessage {
608                                key: Row::default(),
609                                value: final_row.clone(),
610                                metadata: Row::default(),
611                            })
612                        });
613
614                        session.give(((output_index, event), time, diff));
615                    }
616                }
617            }
618        })
619        .as_collection();
620
621    let errors = definite_errors.concat(&transient_errors.map(ReplicationError::from));
622
623    (
624        replication_updates,
625        upper_stream,
626        stats_stream,
627        probe_stream,
628        errors,
629        button.press_on_drop(),
630    )
631}
632
633/// Produces the logical replication stream while taking care of regularly sending standby
634/// keepalive messages with the provided `uppers` stream.
635///
636/// The returned stream will contain all transactions that whose commit LSN is beyond `resume_lsn`.
637async fn raw_stream<'a>(
638    config: &'a RawSourceCreationConfig,
639    replication_client: Client,
640    metadata_client: Arc<Client>,
641    slot: &'a str,
642    timeline_id: &'a Option<u64>,
643    publication: &'a str,
644    resume_lsn: MzOffset,
645    uppers: impl futures::Stream<Item = Antichain<MzOffset>> + 'a,
646    stats_output: &'a AsyncOutputHandle<
647        MzOffset,
648        CapacityContainerBuilder<Vec<ProgressStatisticsUpdate>>,
649        Tee<MzOffset, Vec<ProgressStatisticsUpdate>>,
650    >,
651    stats_cap: &'a Capability<MzOffset>,
652    probe_output: &'a AsyncOutputHandle<
653        MzOffset,
654        CapacityContainerBuilder<Vec<Probe<MzOffset>>>,
655        Tee<MzOffset, Vec<Probe<MzOffset>>>,
656    >,
657    probe_cap: &'a Capability<MzOffset>,
658) -> Result<
659    Result<
660        impl AsyncStream<Item = Result<ReplicationMessage<LogicalReplicationMessage>, TransientError>>
661        + 'a,
662        DefiniteError,
663    >,
664    TransientError,
665> {
666    if let Err(err) = ensure_publication_exists(&*metadata_client, publication).await? {
667        // If the publication gets deleted there is nothing else to do. These errors
668        // are not retractable.
669        return Ok(Err(err));
670    }
671
672    // Skip the timeline ID check for sources without a known timeline ID
673    // (sources created before the timeline ID was added to the source details)
674    if let Some(expected_timeline_id) = timeline_id {
675        if let Err(err) =
676            ensure_replication_timeline_id(&replication_client, expected_timeline_id).await?
677        {
678            return Ok(Err(err));
679        }
680    }
681
682    // How often a proactive standby status update message should be sent to the server.
683    //
684    // The upstream will periodically request status updates by setting the keepalive's reply field
685    // value to 1. However, we cannot rely on these messages arriving on time. For example, when
686    // the upstream is sending a big transaction its keepalive messages are queued and can be
687    // delayed arbitrarily.
688    //
689    // See: <https://www.postgresql.org/message-id/CAMsr+YE2dSfHVr7iEv1GSPZihitWX-PMkD9QALEGcTYa+sdsgg@mail.gmail.com>
690    //
691    // For this reason we query the server's timeout value and proactively send a keepalive at
692    // twice the frequency to have a healthy margin from the deadline.
693    //
694    // Note: We must use the metadata client here which is NOT in replication mode. Some Aurora
695    // Postgres versions disallow SHOW commands from within replication connection.
696    // See: https://github.com/readysettech/readyset/discussions/28#discussioncomment-4405671
697    let row = simple_query_opt(&*metadata_client, "SHOW wal_sender_timeout;")
698        .await?
699        .unwrap();
700    let wal_sender_timeout = match row.get("wal_sender_timeout") {
701        // When this parameter is zero the timeout mechanism is disabled
702        Some("0") => None,
703        Some(value) => Some(
704            mz_repr::adt::interval::Interval::from_str(value)
705                .unwrap()
706                .duration()
707                .unwrap(),
708        ),
709        None => panic!("ubiquitous parameter missing"),
710    };
711
712    // This interval controls the cadence at which we send back status updates and, crucially,
713    // request PrimaryKeepAlive messages. PrimaryKeepAlive messages drive the frontier forward in
714    // the absence of data updates and we don't want a large `wal_sender_timeout` value to slow us
715    // down. For this reason the feedback interval is set to one second, or less if the
716    // wal_sender_timeout is less than 2 seconds.
717    let feedback_interval = match wal_sender_timeout {
718        Some(t) => std::cmp::min(Duration::from_secs(1), t.checked_div(2).unwrap()),
719        None => Duration::from_secs(1),
720    };
721
722    let mut feedback_timer = tokio::time::interval(feedback_interval);
723    // 'Delay' ensures we always tick at least 'feedback_interval'.
724    feedback_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
725
726    // Postgres will return all transactions that commit *at or after* after the provided LSN,
727    // following the timely upper semantics.
728    let lsn = PgLsn::from(resume_lsn.offset);
729    let query = format!(
730        r#"START_REPLICATION SLOT "{}" LOGICAL {} ("proto_version" '1', "publication_names" '{}')"#,
731        Ident::new_unchecked(slot).to_ast_string_simple(),
732        lsn,
733        publication,
734    );
735    let copy_stream = match replication_client.copy_both_simple(&query).await {
736        Ok(copy_stream) => copy_stream,
737        Err(err) if err.code() == Some(&SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE) => {
738            return Ok(Err(DefiniteError::InvalidReplicationSlot));
739        }
740        Err(err) => return Err(err.into()),
741    };
742
743    // According to the documentation [1] we must check that the slot LSN matches our
744    // expectations otherwise we risk getting silently fast-forwarded to a future LSN. In order
745    // to avoid a TOCTOU issue we must do this check after starting the replication stream. We
746    // cannot use the replication client to do that because it's already in CopyBoth mode.
747    // [1] https://www.postgresql.org/docs/15/protocol-replication.html#PROTOCOL-REPLICATION-START-REPLICATION-SLOT-LOGICAL
748    let slot_metadata = super::fetch_slot_metadata(
749        &*metadata_client,
750        slot,
751        mz_storage_types::dyncfgs::PG_FETCH_SLOT_RESUME_LSN_INTERVAL
752            .get(config.config.config_set()),
753    )
754    .await?;
755    let min_resume_lsn = slot_metadata.confirmed_flush_lsn;
756    tracing::info!(
757        %config.id,
758        "started replication using backend PID={:?}. wal_sender_timeout={:?}",
759        slot_metadata.active_pid, wal_sender_timeout
760    );
761
762    let (probe_tx, mut probe_rx) = watch::channel(None);
763    let config_set = Arc::clone(config.config.config_set());
764    let now_fn = config.now_fn.clone();
765    let max_lsn_task_handle =
766        mz_ore::task::spawn(|| format!("pg_current_wal_lsn:{}", config.id), async move {
767            let mut probe_ticker =
768                probe::Ticker::new(|| PG_OFFSET_KNOWN_INTERVAL.get(&config_set), now_fn);
769
770            while !probe_tx.is_closed() {
771                let probe_ts = probe_ticker.tick().await;
772                let probe_or_err = super::fetch_max_lsn(&*metadata_client)
773                    .await
774                    .map(|lsn| Probe {
775                        probe_ts,
776                        upstream_frontier: Antichain::from_elem(lsn),
777                    });
778                let _ = probe_tx.send(Some(probe_or_err));
779            }
780        })
781        .abort_on_drop();
782
783    let stream = async_stream::try_stream!({
784        // Ensure we don't pre-drop the task
785        let _max_lsn_task_handle = max_lsn_task_handle;
786
787        let mut uppers = pin!(uppers);
788        let mut last_committed_upper = resume_lsn;
789
790        let mut stream = pin!(LogicalReplicationStream::new(copy_stream));
791
792        if !(resume_lsn == MzOffset::from(0) || min_resume_lsn <= resume_lsn) {
793            let err = TransientError::OvercompactedReplicationSlot {
794                available_lsn: min_resume_lsn,
795                requested_lsn: resume_lsn,
796            };
797            error!("timely-{} ({}) {err}", config.worker_id, config.id);
798            Err(err)?;
799        }
800
801        loop {
802            tokio::select! {
803                Some(next_message) = stream.next() => match next_message {
804                    Ok(ReplicationMessage::XLogData(data)) => {
805                        yield ReplicationMessage::XLogData(data);
806                        Ok(())
807                    }
808                    Ok(ReplicationMessage::PrimaryKeepAlive(keepalive)) => {
809                        yield ReplicationMessage::PrimaryKeepAlive(keepalive);
810                        Ok(())
811                    }
812                    Err(err) => Err(err.into()),
813                    _ => Err(TransientError::UnknownReplicationMessage),
814                },
815                _ = feedback_timer.tick() => {
816                    let ts: i64 = PG_EPOCH.elapsed().unwrap().as_micros().try_into().unwrap();
817                    let lsn = PgLsn::from(last_committed_upper.offset);
818                    trace!("timely-{} ({}) sending keepalive {lsn:?}", config.worker_id, config.id);
819                    // Postgres only sends PrimaryKeepAlive messages when *it* wants a reply, which
820                    // happens when out status update is late. Since we send them proactively this
821                    // may never happen. It is therefore *crucial* that we set the last parameter
822                    // (the reply flag) to 1 here. This will cause the upstream server to send us a
823                    // PrimaryKeepAlive message promptly which will give us frontier advancement
824                    // information in the absence of data updates.
825                    let res = stream.as_mut().standby_status_update(lsn, lsn, lsn, ts, 1).await;
826                    res.map_err(|e| e.into())
827                },
828                Some(upper) = uppers.next() => match upper.into_option() {
829                    Some(lsn) => {
830                        last_committed_upper = std::cmp::max(last_committed_upper, lsn);
831                        Ok(())
832                    }
833                    None => Ok(()),
834                },
835                Ok(()) = probe_rx.changed() => match &*probe_rx.borrow() {
836                    Some(Ok(probe)) => {
837                        if let Some(offset_known) = probe.upstream_frontier.as_option() {
838                            stats_output.give(
839                                stats_cap,
840                                ProgressStatisticsUpdate::SteadyState {
841                                    // Similar to the kafka source, we don't subtract 1 from the
842                                    // upper as we want to report the _number of bytes_ we have
843                                    // processed/in upstream.
844                                    offset_known: offset_known.offset,
845                                    offset_committed: last_committed_upper.offset,
846                                },
847                            );
848                        }
849                        probe_output.give(probe_cap, probe.clone());
850                        Ok(())
851                    },
852                    Some(Err(err)) => Err(anyhow::anyhow!("{err}").into()),
853                    None => Ok(()),
854                },
855                else => return
856            }?;
857        }
858    });
859    Ok(Ok(stream))
860}
861
862/// Extracts a single transaction from the replication stream delimited by a BEGIN and COMMIT
863/// message. The BEGIN message must have already been consumed from the stream before calling this
864/// function.
865fn extract_transaction<'a>(
866    stream: impl AsyncStream<
867        Item = Result<ReplicationMessage<LogicalReplicationMessage>, TransientError>,
868    > + 'a,
869    metadata_client: &'a Client,
870    commit_lsn: MzOffset,
871    table_info: &'a BTreeMap<u32, BTreeMap<usize, SourceOutputInfo>>,
872    metrics: &'a PgSourceMetrics,
873    publication: &'a str,
874    errored_outputs: &'a mut HashSet<usize>,
875) -> impl AsyncStream<Item = Result<(u32, usize, Result<Row, DefiniteError>, Diff), TransientError>> + 'a
876{
877    use LogicalReplicationMessage::*;
878    let mut row = Row::default();
879    async_stream::try_stream!({
880        let mut stream = pin!(stream);
881        metrics.transactions.inc();
882        metrics.lsn.set(commit_lsn.offset);
883        while let Some(event) = stream.try_next().await? {
884            // We can ignore keepalive messages while processing a transaction because the
885            // commit_lsn will drive progress.
886            let message = match event {
887                ReplicationMessage::XLogData(data) => data.into_data(),
888                ReplicationMessage::PrimaryKeepAlive(_) => continue,
889                _ => Err(TransientError::UnknownReplicationMessage)?,
890            };
891            metrics.total.inc();
892            match message {
893                Insert(body) if !table_info.contains_key(&body.rel_id()) => continue,
894                Update(body) if !table_info.contains_key(&body.rel_id()) => continue,
895                Delete(body) if !table_info.contains_key(&body.rel_id()) => continue,
896                Relation(body) if !table_info.contains_key(&body.rel_id()) => continue,
897                Insert(body) => {
898                    metrics.inserts.inc();
899                    let row = unpack_tuple(body.tuple().tuple_data(), &mut row);
900                    let rel = body.rel_id();
901                    for ((output, _), row) in table_info
902                        .get(&rel)
903                        .map(|o| o.iter().filter(|(o, _)| !errored_outputs.contains(o)))
904                        .into_iter()
905                        .flatten()
906                        .repeat_clone(row)
907                    {
908                        yield (rel, *output, row, Diff::ONE);
909                    }
910                }
911                Update(body) => match body.old_tuple() {
912                    Some(old_tuple) => {
913                        metrics.updates.inc();
914                        // If the new tuple contains unchanged toast values we reference the old ones
915                        let new_tuple =
916                            std::iter::zip(body.new_tuple().tuple_data(), old_tuple.tuple_data())
917                                .map(|(new, old)| match new {
918                                    TupleData::UnchangedToast => old,
919                                    _ => new,
920                                });
921                        let old_row = unpack_tuple(old_tuple.tuple_data(), &mut row);
922                        let new_row = unpack_tuple(new_tuple, &mut row);
923                        let rel = body.rel_id();
924                        for ((output, _), (old_row, new_row)) in table_info
925                            .get(&rel)
926                            .map(|o| o.iter().filter(|(o, _)| !errored_outputs.contains(o)))
927                            .into_iter()
928                            .flatten()
929                            .repeat_clone((old_row, new_row))
930                        {
931                            yield (rel, *output, old_row, Diff::MINUS_ONE);
932                            yield (rel, *output, new_row, Diff::ONE);
933                        }
934                    }
935                    None => {
936                        let rel = body.rel_id();
937                        for (output, _) in table_info
938                            .get(&rel)
939                            .map(|o| o.iter().filter(|(o, _)| !errored_outputs.contains(o)))
940                            .into_iter()
941                            .flatten()
942                        {
943                            yield (
944                                rel,
945                                *output,
946                                Err(DefiniteError::DefaultReplicaIdentity),
947                                Diff::ONE,
948                            );
949                        }
950                    }
951                },
952                Delete(body) => match body.old_tuple() {
953                    Some(old_tuple) => {
954                        metrics.deletes.inc();
955                        let row = unpack_tuple(old_tuple.tuple_data(), &mut row);
956                        let rel = body.rel_id();
957                        for ((output, _), row) in table_info
958                            .get(&rel)
959                            .map(|o| o.iter().filter(|(o, _)| !errored_outputs.contains(o)))
960                            .into_iter()
961                            .flatten()
962                            .repeat_clone(row)
963                        {
964                            yield (rel, *output, row, Diff::MINUS_ONE);
965                        }
966                    }
967                    None => {
968                        let rel = body.rel_id();
969                        for (output, _) in table_info
970                            .get(&rel)
971                            .map(|o| o.iter().filter(|(o, _)| !errored_outputs.contains(o)))
972                            .into_iter()
973                            .flatten()
974                        {
975                            yield (
976                                rel,
977                                *output,
978                                Err(DefiniteError::DefaultReplicaIdentity),
979                                Diff::ONE,
980                            );
981                        }
982                    }
983                },
984                Relation(body) => {
985                    let rel_id = body.rel_id();
986                    let valid_outputs = table_info
987                        .get(&rel_id)
988                        .map(|o| o.iter().filter(|(o, _)| !errored_outputs.contains(o)))
989                        .into_iter()
990                        .flatten()
991                        .collect::<Vec<_>>();
992                    if valid_outputs.len() > 0 {
993                        // Because the replication stream doesn't include columns' attnums, we need
994                        // to check the current local schema against the current remote schema to
995                        // ensure e.g. we haven't received a schema update with the same terminal
996                        // column name which is actually a different column.
997                        let oids = std::iter::once(rel_id)
998                            .chain(table_info.keys().copied())
999                            .collect::<Vec<_>>();
1000                        let upstream_info = mz_postgres_util::publication_info(
1001                            metadata_client,
1002                            publication,
1003                            Some(&oids),
1004                        )
1005                        .await?;
1006
1007                        for (output_index, output) in valid_outputs {
1008                            if let Err(err) =
1009                                verify_schema(rel_id, &output.desc, &upstream_info, &output.casts)
1010                            {
1011                                errored_outputs.insert(*output_index);
1012                                yield (rel_id, *output_index, Err(err), Diff::ONE);
1013                            }
1014
1015                            // Error any dropped tables.
1016                            for (oid, outputs) in table_info {
1017                                if !upstream_info.contains_key(oid) {
1018                                    for output in outputs.keys() {
1019                                        if errored_outputs.insert(*output) {
1020                                            // Minimize the number of excessive errors
1021                                            // this will generate.
1022                                            yield (
1023                                                *oid,
1024                                                *output,
1025                                                Err(DefiniteError::TableDropped),
1026                                                Diff::ONE,
1027                                            );
1028                                        }
1029                                    }
1030                                }
1031                            }
1032                        }
1033                    }
1034                }
1035                Truncate(body) => {
1036                    for &rel_id in body.rel_ids() {
1037                        if let Some(outputs) = table_info.get(&rel_id) {
1038                            for (output, _) in outputs {
1039                                if errored_outputs.insert(*output) {
1040                                    yield (
1041                                        rel_id,
1042                                        *output,
1043                                        Err(DefiniteError::TableTruncated),
1044                                        Diff::ONE,
1045                                    );
1046                                }
1047                            }
1048                        }
1049                    }
1050                }
1051                Commit(body) => {
1052                    if commit_lsn != body.commit_lsn().into() {
1053                        Err(TransientError::InvalidTransaction)?
1054                    }
1055                    return;
1056                }
1057                // TODO: We should handle origin messages and emit an error as they indicate that
1058                // the upstream performed a point in time restore so all bets are off about the
1059                // continuity of the stream.
1060                Origin(_) | Type(_) => metrics.ignored.inc(),
1061                Begin(_) => Err(TransientError::NestedTransaction)?,
1062                // The enum is marked as non_exhaustive. Better to be conservative
1063                _ => Err(TransientError::UnknownLogicalReplicationMessage)?,
1064            }
1065        }
1066        Err(TransientError::ReplicationEOF)?;
1067    })
1068}
1069
1070/// Unpacks an iterator of TupleData into a list of nullable bytes or an error if this can't be
1071/// done.
1072#[inline]
1073fn unpack_tuple<'a, I>(tuple_data: I, row: &mut Row) -> Result<Row, DefiniteError>
1074where
1075    I: IntoIterator<Item = &'a TupleData>,
1076    I::IntoIter: ExactSizeIterator,
1077{
1078    let iter = tuple_data.into_iter();
1079    let mut packer = row.packer();
1080    for data in iter {
1081        let datum = match data {
1082            TupleData::Text(bytes) => super::decode_utf8_text(bytes)?,
1083            TupleData::Null => Datum::Null,
1084            TupleData::UnchangedToast => return Err(DefiniteError::MissingToast),
1085            TupleData::Binary(_) => return Err(DefiniteError::UnexpectedBinaryData),
1086        };
1087        packer.push(datum);
1088    }
1089    Ok(row.clone())
1090}
1091
1092/// Ensures the publication exists on the server. It returns an outer transient error in case of
1093/// connection issues and an inner definite error if the publication is dropped.
1094async fn ensure_publication_exists(
1095    client: &Client,
1096    publication: &str,
1097) -> Result<Result<(), DefiniteError>, TransientError> {
1098    // Figure out the last written LSN and then add one to convert it into an upper.
1099    let result = client
1100        .query_opt(
1101            "SELECT 1 FROM pg_publication WHERE pubname = $1;",
1102            &[&publication],
1103        )
1104        .await?;
1105    match result {
1106        Some(_) => Ok(Ok(())),
1107        None => Ok(Err(DefiniteError::PublicationDropped(
1108            publication.to_owned(),
1109        ))),
1110    }
1111}
1112
1113/// Ensure the active replication timeline_id matches the one we expect such that we can safely
1114/// resume replication. It returns an outer transient error in case of
1115/// connection issues and an inner definite error if the timeline id does not match.
1116async fn ensure_replication_timeline_id(
1117    replication_client: &Client,
1118    expected_timeline_id: &u64,
1119) -> Result<Result<(), DefiniteError>, TransientError> {
1120    let timeline_id = mz_postgres_util::get_timeline_id(replication_client).await?;
1121    if timeline_id == *expected_timeline_id {
1122        Ok(Ok(()))
1123    } else {
1124        Ok(Err(DefiniteError::InvalidTimelineId {
1125            expected: *expected_timeline_id,
1126            actual: timeline_id,
1127        }))
1128    }
1129}
1130
/// A problem reported by the background task started in `spawn_schema_validator`.
enum SchemaValidationError {
    /// Fetching the upstream publication info failed.
    Postgres(PostgresError),
    /// `verify_schema` rejected the upstream schema for one output of a table.
    Schema {
        /// OID of the upstream table whose schema failed verification.
        oid: u32,
        /// Index of the source output that failed verification.
        output_index: usize,
        /// The error returned by `verify_schema`.
        error: DefiniteError,
    },
}
1139
1140fn spawn_schema_validator(
1141    client: Client,
1142    config: &RawSourceCreationConfig,
1143    publication: String,
1144    table_info: BTreeMap<u32, BTreeMap<usize, SourceOutputInfo>>,
1145) -> mpsc::UnboundedReceiver<SchemaValidationError> {
1146    let (tx, rx) = mpsc::unbounded_channel();
1147    let source_id = config.id;
1148    let config_set = Arc::clone(config.config.config_set());
1149
1150    mz_ore::task::spawn(|| format!("schema-validator:{}", source_id), async move {
1151        while !tx.is_closed() {
1152            trace!(%source_id, "validating schemas");
1153
1154            let validation_start = Instant::now();
1155
1156            let upstream_info = match mz_postgres_util::publication_info(
1157                &*client,
1158                &publication,
1159                Some(&table_info.keys().copied().collect::<Vec<_>>()),
1160            )
1161            .await
1162            {
1163                Ok(info) => info,
1164                Err(error) => {
1165                    let _ = tx.send(SchemaValidationError::Postgres(error));
1166                    continue;
1167                }
1168            };
1169
1170            for (&oid, outputs) in table_info.iter() {
1171                for (&output_index, output_info) in outputs {
1172                    let expected_desc = &output_info.desc;
1173                    let casts = &output_info.casts;
1174                    if let Err(error) = verify_schema(oid, expected_desc, &upstream_info, casts) {
1175                        trace!(
1176                            %source_id,
1177                            "schema of output index {output_index} for oid {oid} invalid",
1178                        );
1179                        let _ = tx.send(SchemaValidationError::Schema {
1180                            oid,
1181                            output_index,
1182                            error,
1183                        });
1184                    } else {
1185                        trace!(
1186                            %source_id,
1187                            "schema of output index {output_index} for oid {oid} valid",
1188                        );
1189                    }
1190                }
1191            }
1192
1193            let interval = PG_SCHEMA_VALIDATION_INTERVAL.get(&config_set);
1194            let elapsed = validation_start.elapsed();
1195            let wait = interval.saturating_sub(elapsed);
1196            tokio::time::sleep(wait).await;
1197        }
1198    });
1199
1200    rx
1201}