1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Code to render the ingestion dataflow of a [`PostgresSourceConnection`]. The dataflow consists
//! of multiple operators in order to take advantage of all the available workers.
//!
//! # Snapshot
//!
//! One part of the dataflow deals with snapshotting the tables involved in the ingestion. Each
//! table that needs a snapshot is assigned to a specific worker which performs a `COPY` query
//! and distributes the raw COPY bytes to all workers to decode the text encoded rows.
//!
//! For all tables that ended up being snapshotted the snapshot reader also emits a rewind request
//! to the replication reader which will ensure that the requested portion of the replication
//! stream is subtracted from the snapshot.
//!
//! See the [snapshot] module for more information on the snapshot strategy.
//!
//! # Replication
//!
//! The other part of the dataflow deals with reading the logical replication slot, which must
//! happen from a single worker. The minimum amount of processing is performed from that worker
//! and the data is then distributed among all workers for decoding.
//!
//! See the [replication] module for more information on the replication strategy.
//!
//! # Error handling
//!
//! There are two kinds of errors that can happen during ingestion that are represented as two
//! separate error types:
//!
//! [`DefiniteError`]s are errors that happen during processing of a specific
//! collection record at a specific LSN. These are the only errors that can ever end up in the
//! error collection of a subsource.
//!
//! Transient errors are any errors that can happen for reasons that are unrelated to the data
//! itself. This could be authentication failures, connection failures, etc. The only operators
//! that can emit such errors are the `TableReader` and the `ReplicationReader` operators, which
//! are the ones that talk to the external world. Both of these operators are built with the
//! `AsyncOperatorBuilder::build_fallible` method which allows transient errors to be propagated
//! upwards with the standard `?` operator without risking downgrading the capability and producing
//! bogus frontiers.
//!
//! The error streams from both of those operators are published to the source status and also
//! trigger a restart of the dataflow.
//!
//! ```text
//!    ┏━━━━━━━━━━━━━━┓
//!    ┃    table     ┃
//!    ┃    reader    ┃
//!    ┗━┯━━━━━━━━━━┯━┛
//!      │          │rewind
//!      │          │requests
//!      │          ╰────╮
//!      │             ┏━v━━━━━━━━━━━┓
//!      │             ┃ replication ┃
//!      │             ┃   reader    ┃
//!      │             ┗━┯━━━━━━━━━┯━┛
//!  COPY│           slot│         │
//!  data│           data│         │
//! ┏━━━━v━━━━━┓ ┏━━━━━━━v━━━━━┓   │
//! ┃  COPY    ┃ ┃ replication ┃   │
//! ┃ decoder  ┃ ┃   decoder   ┃   │
//! ┗━━━━┯━━━━━┛ ┗━━━━━┯━━━━━━━┛   │
//!      │snapshot     │replication│
//!      │updates      │updates    │
//!      ╰────╮    ╭───╯           │
//!          ╭┴────┴╮              │
//!          │concat│              │
//!          ╰──┬───╯              │
//!             │ data             │progress
//!             │ output           │output
//!             v                  v
//! ```

use std::collections::BTreeMap;
use std::convert::Infallible;
use std::rc::Rc;
use std::time::Duration;

use itertools::Itertools as _;
use mz_expr::{EvalError, MirScalarExpr};
use mz_ore::error::ErrorExt;
use mz_postgres_util::desc::PostgresTableDesc;
use mz_postgres_util::{simple_query_opt, Client, PostgresError};
use mz_repr::{Datum, Row};
use mz_sql_parser::ast::display::AstDisplay;
use mz_sql_parser::ast::Ident;
use mz_storage_types::errors::{DataflowError, SourceError, SourceErrorDetails};
use mz_storage_types::sources::postgres::CastType;
use mz_storage_types::sources::{
    IndexedSourceExport, MzOffset, PostgresSourceConnection, SourceExport, SourceExportDetails,
    SourceTimestamp,
};
use mz_timely_util::builder_async::PressOnDropButton;
use serde::{Deserialize, Serialize};
use timely::dataflow::operators::{Concat, Map, ToStream};
use timely::dataflow::{Scope, Stream};
use timely::progress::Antichain;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::PgLsn;

use crate::healthcheck::{HealthStatusMessage, HealthStatusUpdate, StatusNamespace};
use crate::source::types::{Probe, ProgressStatisticsUpdate, SourceRender, StackedCollection};
use crate::source::{RawSourceCreationConfig, SourceMessage};

mod replication;
mod snapshot;

impl SourceRender for PostgresSourceConnection {
    type Time = MzOffset;

    const STATUS_NAMESPACE: StatusNamespace = StatusNamespace::Postgres;

    /// Render the ingestion dataflow. This function only connects things together and contains no
    /// actual processing logic.
    ///
    /// The work is delegated to [`snapshot::render`] and [`replication::render`]; this function
    /// groups the source exports by upstream table, wires the two sub-dataflows together and
    /// merges their outputs.
    ///
    /// Returns, in order:
    /// * the concatenation of the snapshot and replication updates,
    /// * the upper stream produced by the replication dataflow,
    /// * the health stream (an initial `Running` status followed by one halting status per error
    ///   emitted by either sub-dataflow),
    /// * the concatenation of the replication and snapshot statistics streams,
    /// * the probe stream produced by the replication dataflow,
    /// * the tokens of the snapshot and replication dataflows.
    ///
    /// # Panics
    ///
    /// Panics if a source export carries details that are neither `Postgres` nor `None`, or if a
    /// Postgres export has no entry in `config.source_resume_uppers`.
    fn render<G: Scope<Timestamp = MzOffset>>(
        self,
        scope: &mut G,
        config: RawSourceCreationConfig,
        resume_uppers: impl futures::Stream<Item = Antichain<MzOffset>> + 'static,
        _start_signal: impl std::future::Future<Output = ()> + 'static,
    ) -> (
        StackedCollection<G, (usize, Result<SourceMessage, DataflowError>)>,
        Option<Stream<G, Infallible>>,
        Stream<G, HealthStatusMessage>,
        Stream<G, ProgressStatisticsUpdate>,
        Option<Stream<G, Probe<MzOffset>>>,
        Vec<PressOnDropButton>,
    ) {
        // Collect the source outputs that we will be exporting into a per-table map. The outer
        // key is the upstream table OID and the inner key is the output index, so several
        // outputs may refer to the same upstream table.
        let mut table_info = BTreeMap::new();
        for (
            id,
            IndexedSourceExport {
                ingestion_output,
                export:
                    SourceExport {
                        details,
                        storage_metadata: _,
                        data_config: _,
                    },
            },
        ) in &config.source_exports
        {
            let details = match details {
                SourceExportDetails::Postgres(details) => details,
                // This is an export that doesn't need any data output to it.
                SourceExportDetails::None => continue,
                _ => panic!("unexpected source export details: {:?}", details),
            };
            let desc = details.table.clone();
            let casts = details.column_casts.clone();
            let resume_upper = Antichain::from_iter(
                config
                    .source_resume_uppers
                    .get(id)
                    .expect("all source exports must be present in source resume uppers")
                    .iter()
                    .map(MzOffset::decode_row),
            );
            let output = SourceOutputInfo {
                desc,
                casts,
                resume_upper,
            };
            table_info
                .entry(output.desc.oid)
                .or_insert_with(BTreeMap::new)
                .insert(*ingestion_output, output);
        }

        let metrics = config.metrics.get_postgres_source_metrics(config.id);

        let (snapshot_updates, rewinds, slot_ready, snapshot_stats, snapshot_err, snapshot_token) =
            snapshot::render(
                scope.clone(),
                config.clone(),
                self.clone(),
                table_info.clone(),
                metrics.snapshot_metrics.clone(),
            );

        // The replication dataflow consumes the rewind requests and the slot-ready signal
        // produced by the snapshot dataflow.
        let (repl_updates, uppers, stats_stream, probe_stream, repl_err, repl_token) =
            replication::render(
                scope.clone(),
                config,
                self,
                table_info,
                &rewinds,
                &slot_ready,
                resume_uppers,
                metrics,
            );

        let stats_stream = stats_stream.concat(&snapshot_stats);

        let updates = snapshot_updates.concat(&repl_updates);

        let init = std::iter::once(HealthStatusMessage {
            index: 0,
            namespace: Self::STATUS_NAMESPACE,
            update: HealthStatusUpdate::Running,
        })
        .to_stream(scope);

        // N.B. Note that we don't check ssh tunnel statuses here. We could, but immediately on
        // restart we are going to set the status to an ssh error correctly, so we don't do this
        // extra work.
        let errs = snapshot_err.concat(&repl_err).map(move |err| {
            // This update will cause the dataflow to restart. The message is rendered before
            // `err` is consumed by the `match` below.
            let update = HealthStatusUpdate::halting(err.display_with_causes().to_string(), None);

            // SSH failures are reported under the SSH namespace; everything else is reported
            // under the namespace of this source.
            let namespace = match err {
                ReplicationError::Transient(err)
                    if matches!(
                        &*err,
                        TransientError::PostgresError(PostgresError::Ssh(_))
                            | TransientError::PostgresError(PostgresError::SshIo(_))
                    ) =>
                {
                    StatusNamespace::Ssh
                }
                _ => Self::STATUS_NAMESPACE,
            };

            HealthStatusMessage {
                index: 0,
                namespace,
                update,
            }
        });

        let health = init.concat(&errs);

        (
            updates,
            Some(uppers),
            health,
            stats_stream,
            probe_stream,
            vec![snapshot_token, repl_token],
        )
    }
}

/// Information about a single output of the ingestion, as collected by `render` from the
/// Postgres details of a source export.
#[derive(Clone, Debug)]
struct SourceOutputInfo {
    /// The table description recorded in the export details (`details.table`). Its `oid` is used
    /// to group outputs by upstream table.
    desc: PostgresTableDesc,
    /// The column casts recorded in the export details (`details.column_casts`).
    casts: Vec<(CastType, MirScalarExpr)>,
    /// The resume upper of this output, decoded from the rows stored for the export in
    /// `source_resume_uppers`.
    resume_upper: Antichain<MzOffset>,
}

/// An error emitted by the ingestion dataflow, which is either a [`TransientError`] or a
/// [`DefiniteError`].
///
/// The inner errors are wrapped in an [`Rc`] and the type derives `Clone`. Both variants are
/// `transparent`, so the display output and source are those of the wrapped error.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ReplicationError {
    /// An error unrelated to the data itself; see [`TransientError`].
    #[error(transparent)]
    Transient(#[from] Rc<TransientError>),
    /// An error tied to the data of a specific table; see [`DefiniteError`].
    #[error(transparent)]
    Definite(#[from] Rc<DefiniteError>),
}

/// A transient error that never ends up in the collection of a specific table.
///
/// The `#[from]` conversions on the last three variants allow the underlying errors to be
/// propagated with the `?` operator.
#[derive(Debug, thiserror::Error)]
pub enum TransientError {
    /// No row for the requested slot was found in `pg_replication_slots`.
    #[error("replication slot mysteriously missing")]
    MissingReplicationSlot,
    /// The slot can no longer serve the requested LSN.
    #[error("slot overcompacted. Requested LSN {requested_lsn} but only LSNs >= {available_lsn} are available")]
    OvercompactedReplicationSlot {
        /// The LSN that was requested from the slot.
        requested_lsn: MzOffset,
        /// The lowest LSN that the slot can still provide.
        available_lsn: MzOffset,
    },
    #[error("replication slot already exists")]
    ReplicationSlotAlreadyExists,
    #[error("stream ended prematurely")]
    ReplicationEOF,
    #[error("unexpected replication message")]
    UnknownReplicationMessage,
    #[error("unexpected logical replication message")]
    UnknownLogicalReplicationMessage,
    #[error("received replication event outside of transaction")]
    BareTransactionEvent,
    #[error("lsn mismatch between BEGIN and COMMIT")]
    InvalidTransaction,
    #[error("BEGIN within existing BEGIN stream")]
    NestedTransaction,
    #[error("recoverable errors should crash the process during snapshots")]
    SyntheticError,
    /// An error returned directly by the `tokio_postgres` client.
    #[error("sql client error")]
    SQLClient(#[from] tokio_postgres::Error),
    /// An error returned by `mz_postgres_util`. The `Ssh` and `SshIo` variants of the inner error
    /// are reported under the SSH status namespace by `render`.
    #[error(transparent)]
    PostgresError(#[from] PostgresError),
    /// Any other error, carried as an [`anyhow::Error`].
    #[error(transparent)]
    Generic(#[from] anyhow::Error),
}

/// A definite error that always ends up in the collection of a specific table.
///
/// The type is serializable and converts into a [`DataflowError`] through the `From`
/// implementation in this file, which wraps the display string of the error in either
/// `SourceErrorDetails::Initialization` or `SourceErrorDetails::Other`.
#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
pub enum DefiniteError {
    /// The slot no longer holds the data needed to resume. The fields are, in order, the
    /// snapshot consistent point and the resume LSN.
    #[error("slot compacted past snapshot point. snapshot consistent point={0} resume_lsn={1}")]
    SlotCompactedPastResumePoint(MzOffset, MzOffset),
    #[error("table was truncated")]
    TableTruncated,
    /// The table is no longer present upstream. Returned by `verify_schema` when the table OID
    /// is missing from the upstream schema information.
    #[error("table was dropped")]
    TableDropped,
    /// The publication with the contained name does not exist.
    #[error("publication {0:?} does not exist")]
    PublicationDropped(String),
    #[error("replication slot has been invalidated because it exceeded the maximum reserved size")]
    InvalidReplicationSlot,
    #[error("unexpected number of columns while parsing COPY output")]
    MissingColumn,
    #[error("failed to parse COPY protocol")]
    InvalidCopyInput,
    #[error("invalid timeline ID from PostgreSQL server. Expected {expected} but got {actual}")]
    InvalidTimelineId { expected: u64, actual: u64 },
    #[error("TOASTed value missing from old row. Did you forget to set REPLICA IDENTITY to FULL for your table?")]
    MissingToast,
    #[error("old row missing from replication stream. Did you forget to set REPLICA IDENTITY to FULL for your table?")]
    DefaultReplicaIdentity,
    /// The upstream schema is no longer compatible with the expected one. Returned by
    /// `verify_schema` with the text of the compatibility error.
    #[error("incompatible schema change: {0}")]
    // TODO: proper error variants for all the expected schema violations
    IncompatibleSchema(String),
    /// Bytes that were expected to be UTF8 text but are not. Carries a copy of the offending
    /// bytes; returned by `decode_utf8_text`.
    #[error("invalid UTF8 string: {0:?}")]
    InvalidUTF8(Vec<u8>),
    /// Evaluating a column cast failed. Returned by `cast_row`.
    #[error("failed to cast raw column: {0}")]
    CastError(#[source] EvalError),
}

impl From<DefiniteError> for DataflowError {
    fn from(err: DefiniteError) -> Self {
        let m = err.to_string().into();
        DataflowError::SourceError(Box::new(SourceError {
            error: match &err {
                DefiniteError::SlotCompactedPastResumePoint(_, _) => SourceErrorDetails::Other(m),
                DefiniteError::TableTruncated => SourceErrorDetails::Other(m),
                DefiniteError::TableDropped => SourceErrorDetails::Other(m),
                DefiniteError::PublicationDropped(_) => SourceErrorDetails::Initialization(m),
                DefiniteError::InvalidReplicationSlot => SourceErrorDetails::Initialization(m),
                DefiniteError::MissingColumn => SourceErrorDetails::Other(m),
                DefiniteError::InvalidCopyInput => SourceErrorDetails::Other(m),
                DefiniteError::InvalidTimelineId { .. } => SourceErrorDetails::Initialization(m),
                DefiniteError::MissingToast => SourceErrorDetails::Other(m),
                DefiniteError::DefaultReplicaIdentity => SourceErrorDetails::Other(m),
                DefiniteError::IncompatibleSchema(_) => SourceErrorDetails::Other(m),
                DefiniteError::InvalidUTF8(_) => SourceErrorDetails::Other(m),
                DefiniteError::CastError(_) => SourceErrorDetails::Other(m),
            },
        }))
    }
}

/// Creates the logical replication slot `slot` using the `pgoutput` plugin, treating a slot that
/// already exists as success.
///
/// # Errors
///
/// Any error other than a `DUPLICATE_OBJECT` SQL state is returned as
/// [`TransientError::PostgresError`].
async fn ensure_replication_slot(client: &Client, slot: &str) -> Result<(), TransientError> {
    // Note: Using unchecked here is okay because we're using it in a SQL query.
    let slot = Ident::new_unchecked(slot).to_ast_string();
    let query = format!("CREATE_REPLICATION_SLOT {slot} LOGICAL \"pgoutput\" NOEXPORT_SNAPSHOT");

    let Err(err) = simple_query_opt(client, &query).await else {
        return Ok(());
    };

    match err {
        // If the slot already exists that's still ok
        PostgresError::Postgres(pg_err)
            if pg_err.code() == Some(&SqlState::DUPLICATE_OBJECT) =>
        {
            tracing::trace!("replication slot {slot} already existed");
            Ok(())
        }
        other => Err(TransientError::PostgresError(other)),
    }
}

/// The state of a replication slot, as read from the `pg_replication_slots` view by
/// `fetch_slot_metadata`.
struct SlotMetadata {
    /// The process ID of the session using this slot if the slot is currently actively being used.
    /// None if inactive.
    active_pid: Option<i32>,
    /// The address (LSN) up to which the logical slot's consumer has confirmed receiving data.
    /// Data corresponding to the transactions committed before this LSN is not available anymore.
    confirmed_flush_lsn: MzOffset,
}

/// Fetches the state of the replication slot `slot`, including the minimum LSN at which the slot
/// can safely resume.
///
/// While the slot reports a NULL `confirmed_flush_lsn` the query is retried, sleeping for
/// `interval` between attempts.
///
/// # Errors
///
/// Returns [`TransientError::MissingReplicationSlot`] if the slot is not listed in
/// `pg_replication_slots`, and propagates any error of the query itself.
async fn fetch_slot_metadata(
    client: &Client,
    slot: &str,
    interval: Duration,
) -> Result<SlotMetadata, TransientError> {
    let query = "SELECT active_pid, confirmed_flush_lsn
                FROM pg_replication_slots WHERE slot_name = $1";

    loop {
        let row = match client.query_opt(query, &[&slot]).await? {
            Some(row) => row,
            None => return Err(TransientError::MissingReplicationSlot),
        };

        // For postgres, `confirmed_flush_lsn` means that the slot is able to produce
        // all transactions that happen at tx_lsn >= confirmed_flush_lsn. Therefore this value
        // already has "upper" semantics.
        if let Some(lsn) = row.get::<_, Option<PgLsn>>("confirmed_flush_lsn") {
            return Ok(SlotMetadata {
                confirmed_flush_lsn: MzOffset::from(lsn),
                active_pid: row.get("active_pid"),
            });
        }

        // It can happen that `confirmed_flush_lsn` is NULL while the slot initializes. This
        // could probably be a `tokio::time::interval`, but this function is only called twice,
        // so a plain sleep is fine.
        tokio::time::sleep(interval).await;
    }
}

/// Fetch the `pg_current_wal_lsn`, used to report metrics.
///
/// # Errors
///
/// Propagates any error of the query itself, and returns [`TransientError::Generic`] if the
/// query produces no value or if the returned value cannot be parsed as an LSN.
async fn fetch_max_lsn(client: &Client) -> Result<MzOffset, TransientError> {
    let query = "SELECT pg_current_wal_lsn()";
    let row = simple_query_opt(client, query).await?;

    let Some(raw_lsn) = row
        .as_ref()
        .and_then(|row| row.get("pg_current_wal_lsn"))
    else {
        return Err(TransientError::Generic(anyhow::anyhow!(
            "pg_current_wal_lsn() mysteriously has no value"
        )));
    };

    // The value comes from the upstream server, so a malformed LSN is reported as a transient
    // error instead of panicking.
    let lsn = raw_lsn.parse::<PgLsn>().map_err(|_| {
        TransientError::Generic(anyhow::anyhow!(
            "pg_current_wal_lsn() returned an invalid LSN: {:?}",
            raw_lsn
        ))
    })?;

    // Based on the documentation, it appears that `pg_current_wal_lsn` has
    // the same "upper" semantics of `confirmed_flush_lsn`:
    // <https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADMIN-BACKUP>
    // We may need to revisit this and use `pg_current_wal_flush_lsn`.
    Ok(MzOffset::from(lsn))
}

/// Ensures that the table with oid `oid`, whose expected description is `expected_desc`, is
/// still compatible with the current upstream schema found in `upstream_info`.
///
/// Columns whose cast type is `CastType::Text` are passed to the compatibility check as columns
/// whose OIDs are allowed to change.
///
/// # Errors
///
/// Returns [`DefiniteError::TableDropped`] if `oid` is absent from `upstream_info` and
/// [`DefiniteError::IncompatibleSchema`] if the compatibility check fails.
///
/// # Panics
///
/// Panics if `casts` and the columns of `expected_desc` differ in length (`zip_eq`).
fn verify_schema(
    oid: u32,
    expected_desc: &PostgresTableDesc,
    upstream_info: &BTreeMap<u32, PostgresTableDesc>,
    casts: &[(CastType, MirScalarExpr)],
) -> Result<(), DefiniteError> {
    let Some(current_desc) = upstream_info.get(&oid) else {
        return Err(DefiniteError::TableDropped);
    };

    let allow_oids_to_change_by_col_num = expected_desc
        .columns
        .iter()
        .zip_eq(casts.iter())
        .filter_map(|(col, (cast_type, _))| match cast_type {
            CastType::Text => Some(col.col_num),
            CastType::Natural => None,
        })
        .collect();

    expected_desc
        .determine_compatibility(current_desc, &allow_oids_to_change_by_col_num)
        .map_err(|err| DefiniteError::IncompatibleSchema(err.to_string()))
}

/// Casts a text row into the target types.
///
/// Every cast expression in `casts` is evaluated against `datums`, in order, and the results are
/// packed into `row`, replacing its previous contents.
///
/// # Errors
///
/// Returns [`DefiniteError::CastError`] as soon as one expression fails to evaluate. The datums
/// packed before the failure remain in `row`.
fn cast_row(
    casts: &[(CastType, MirScalarExpr)],
    datums: &[Datum<'_>],
    row: &mut Row,
) -> Result<(), DefiniteError> {
    let temp_storage = mz_repr::RowArena::new();
    let mut row_packer = row.packer();
    for (_cast_type, cast_expr) in casts {
        match cast_expr.eval(datums, &temp_storage) {
            Ok(datum) => row_packer.push(datum),
            Err(err) => return Err(DefiniteError::CastError(err)),
        }
    }
    Ok(())
}

/// Converts raw bytes that are expected to be UTF8 encoded into a `Datum::String`.
///
/// # Errors
///
/// Returns [`DefiniteError::InvalidUTF8`] carrying a copy of `bytes` if they are not valid UTF8.
fn decode_utf8_text(bytes: &[u8]) -> Result<Datum<'_>, DefiniteError> {
    std::str::from_utf8(bytes)
        .map(Datum::String)
        .map_err(|_| DefiniteError::InvalidUTF8(bytes.to_vec()))
}