//! Renders the dataflow that ingests data from a PostgreSQL source.
//!
//! The ingestion is split between a snapshot operator, which copies the
//! initial contents of the upstream tables, and a replication operator, which
//! follows the upstream replication slot from that point on. Their outputs are
//! concatenated into a single collection of updates timestamped by
//! [`MzOffset`].

use std::collections::BTreeMap;
use std::convert::Infallible;
use std::rc::Rc;
use std::time::Duration;

use itertools::Itertools as _;
use mz_expr::{EvalError, MirScalarExpr};
use mz_ore::error::ErrorExt;
use mz_postgres_util::desc::PostgresTableDesc;
use mz_postgres_util::{simple_query_opt, Client, PostgresError};
use mz_repr::{Datum, Row};
use mz_sql_parser::ast::display::AstDisplay;
use mz_sql_parser::ast::Ident;
use mz_storage_types::errors::{DataflowError, SourceError, SourceErrorDetails};
use mz_storage_types::sources::postgres::CastType;
use mz_storage_types::sources::{
IndexedSourceExport, MzOffset, PostgresSourceConnection, SourceExport, SourceExportDetails,
SourceTimestamp,
};
use mz_timely_util::builder_async::PressOnDropButton;
use serde::{Deserialize, Serialize};
use timely::dataflow::operators::{Concat, Map, ToStream};
use timely::dataflow::{Scope, Stream};
use timely::progress::Antichain;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::PgLsn;

use crate::healthcheck::{HealthStatusMessage, HealthStatusUpdate, StatusNamespace};
use crate::source::types::{Probe, ProgressStatisticsUpdate, SourceRender, StackedCollection};
use crate::source::{RawSourceCreationConfig, SourceMessage};
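
/// Renders the operator that continually ingests changes arriving on the
/// upstream replication slot.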
mod replication;
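
/// Renders the operator that performs the initial snapshot of the upstream
/// tables.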
mod snapshot;

impl SourceRender for PostgresSourceConnection {
type Time = MzOffset;
const STATUS_NAMESPACE: StatusNamespace = StatusNamespace::Postgres;
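
    /// Renders the ingestion dataflow, returning the combined collection of
    /// snapshot and replication updates alongside the health, statistics, and
    /// progress probe streams.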
fn render<G: Scope<Timestamp = MzOffset>>(
self,
scope: &mut G,
config: RawSourceCreationConfig,
resume_uppers: impl futures::Stream<Item = Antichain<MzOffset>> + 'static,
_start_signal: impl std::future::Future<Output = ()> + 'static,
) -> (
StackedCollection<G, (usize, Result<SourceMessage, DataflowError>)>,
Option<Stream<G, Infallible>>,
Stream<G, HealthStatusMessage>,
Stream<G, ProgressStatisticsUpdate>,
Option<Stream<G, Probe<MzOffset>>>,
Vec<PressOnDropButton>,
) {
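        // Collect the table description, column casts, and resume upper of
        // every PostgreSQL export of this source, keyed by upstream table OID.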
let mut table_info = BTreeMap::new();
for (
id,
IndexedSourceExport {
ingestion_output,
export:
SourceExport {
details,
storage_metadata: _,
data_config: _,
},
},
) in &config.source_exports
{
let details = match details {
SourceExportDetails::Postgres(details) => details,
SourceExportDetails::None => continue,
_ => panic!("unexpected source export details: {:?}", details),
};
let desc = details.table.clone();
let casts = details.column_casts.clone();
let resume_upper = Antichain::from_iter(
config
.source_resume_uppers
.get(id)
.expect("all source exports must be present in source resume uppers")
.iter()
.map(MzOffset::decode_row),
);
let output = SourceOutputInfo {
desc,
casts,
resume_upper,
};
table_info
.entry(output.desc.oid)
.or_insert_with(BTreeMap::new)
.insert(*ingestion_output, output);
}
let metrics = config.metrics.get_postgres_source_metrics(config.id);
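
        // Render the snapshot operator first: it produces the rewind requests
        // and the slot-ready signal that the replication operator consumes.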
let (snapshot_updates, rewinds, slot_ready, snapshot_stats, snapshot_err, snapshot_token) =
snapshot::render(
scope.clone(),
config.clone(),
self.clone(),
table_info.clone(),
metrics.snapshot_metrics.clone(),
);
let (repl_updates, uppers, stats_stream, probe_stream, repl_err, repl_token) =
replication::render(
scope.clone(),
config,
self,
table_info,
&rewinds,
&slot_ready,
resume_uppers,
metrics,
);
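
        // Merge the updates and statistics produced by both operators.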
let stats_stream = stats_stream.concat(&snapshot_stats);
let updates = snapshot_updates.concat(&repl_updates);
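
        // Report a running status as soon as the dataflow is rendered.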
let init = std::iter::once(HealthStatusMessage {
index: 0,
namespace: Self::STATUS_NAMESPACE,
update: HealthStatusUpdate::Running,
})
.to_stream(scope);
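
        // Convert errors from either operator into halting health updates,
        // attributing SSH tunnel failures to the SSH status namespace.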
let errs = snapshot_err.concat(&repl_err).map(move |err| {
let err_string = err.display_with_causes().to_string();
let update = HealthStatusUpdate::halting(err_string.clone(), None);
let namespace = match err {
ReplicationError::Transient(err)
if matches!(
&*err,
TransientError::PostgresError(PostgresError::Ssh(_))
| TransientError::PostgresError(PostgresError::SshIo(_))
) =>
{
StatusNamespace::Ssh
}
_ => Self::STATUS_NAMESPACE,
};
HealthStatusMessage {
index: 0,
namespace: namespace.clone(),
update,
}
});
let health = init.concat(&errs);
(
updates,
Some(uppers),
health,
stats_stream,
probe_stream,
vec![snapshot_token, repl_token],
)
}
}
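
/// Information about a single source export that is needed by both the
/// snapshot and replication operators.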
#[derive(Clone, Debug)]
struct SourceOutputInfo {
desc: PostgresTableDesc,
casts: Vec<(CastType, MirScalarExpr)>,
resume_upper: Antichain<MzOffset>,
}
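
/// An error reported by the snapshot or replication operators, classified as
/// either transient or definite.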
#[derive(Clone, Debug, thiserror::Error)]
pub enum ReplicationError {
#[error(transparent)]
Transient(#[from] Rc<TransientError>),
#[error(transparent)]
Definite(#[from] Rc<DefiniteError>),
}
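
/// Errors that may be resolved by retrying, such as network or SQL client
/// failures.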
#[derive(Debug, thiserror::Error)]
pub enum TransientError {
#[error("replication slot mysteriously missing")]
MissingReplicationSlot,
#[error("slot overcompacted. Requested LSN {requested_lsn} but only LSNs >= {available_lsn} are available")]
OvercompactedReplicationSlot {
requested_lsn: MzOffset,
available_lsn: MzOffset,
},
#[error("replication slot already exists")]
ReplicationSlotAlreadyExists,
#[error("stream ended prematurely")]
ReplicationEOF,
#[error("unexpected replication message")]
UnknownReplicationMessage,
#[error("unexpected logical replication message")]
UnknownLogicalReplicationMessage,
#[error("received replication event outside of transaction")]
BareTransactionEvent,
#[error("lsn mismatch between BEGIN and COMMIT")]
InvalidTransaction,
#[error("BEGIN within existing BEGIN stream")]
NestedTransaction,
#[error("recoverable errors should crash the process during snapshots")]
SyntheticError,
#[error("sql client error")]
SQLClient(#[from] tokio_postgres::Error),
#[error(transparent)]
PostgresError(#[from] PostgresError),
#[error(transparent)]
Generic(#[from] anyhow::Error),
}
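
/// Errors that affect the source data itself and are therefore reported in the
/// output as [`DataflowError`]s.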
#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
pub enum DefiniteError {
#[error("slot compacted past snapshot point. snapshot consistent point={0} resume_lsn={1}")]
SlotCompactedPastResumePoint(MzOffset, MzOffset),
#[error("table was truncated")]
TableTruncated,
#[error("table was dropped")]
TableDropped,
#[error("publication {0:?} does not exist")]
PublicationDropped(String),
#[error("replication slot has been invalidated because it exceeded the maximum reserved size")]
InvalidReplicationSlot,
#[error("unexpected number of columns while parsing COPY output")]
MissingColumn,
#[error("failed to parse COPY protocol")]
InvalidCopyInput,
#[error("invalid timeline ID from PostgreSQL server. Expected {expected} but got {actual}")]
InvalidTimelineId { expected: u64, actual: u64 },
#[error("TOASTed value missing from old row. Did you forget to set REPLICA IDENTITY to FULL for your table?")]
MissingToast,
#[error("old row missing from replication stream. Did you forget to set REPLICA IDENTITY to FULL for your table?")]
DefaultReplicaIdentity,
#[error("incompatible schema change: {0}")]
IncompatibleSchema(String),
#[error("invalid UTF8 string: {0:?}")]
InvalidUTF8(Vec<u8>),
#[error("failed to cast raw column: {0}")]
CastError(#[source] EvalError),
}
impl From<DefiniteError> for DataflowError {
fn from(err: DefiniteError) -> Self {
let m = err.to_string().into();
DataflowError::SourceError(Box::new(SourceError {
error: match &err {
DefiniteError::SlotCompactedPastResumePoint(_, _) => SourceErrorDetails::Other(m),
DefiniteError::TableTruncated => SourceErrorDetails::Other(m),
DefiniteError::TableDropped => SourceErrorDetails::Other(m),
DefiniteError::PublicationDropped(_) => SourceErrorDetails::Initialization(m),
DefiniteError::InvalidReplicationSlot => SourceErrorDetails::Initialization(m),
DefiniteError::MissingColumn => SourceErrorDetails::Other(m),
DefiniteError::InvalidCopyInput => SourceErrorDetails::Other(m),
DefiniteError::InvalidTimelineId { .. } => SourceErrorDetails::Initialization(m),
DefiniteError::MissingToast => SourceErrorDetails::Other(m),
DefiniteError::DefaultReplicaIdentity => SourceErrorDetails::Other(m),
DefiniteError::IncompatibleSchema(_) => SourceErrorDetails::Other(m),
DefiniteError::InvalidUTF8(_) => SourceErrorDetails::Other(m),
DefiniteError::CastError(_) => SourceErrorDetails::Other(m),
},
}))
}
}
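
/// Ensures that the named logical replication slot exists on the upstream
/// PostgreSQL server, treating an "already exists" error as success.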
async fn ensure_replication_slot(client: &Client, slot: &str) -> Result<(), TransientError> {
let slot = Ident::new_unchecked(slot).to_ast_string();
let query = format!("CREATE_REPLICATION_SLOT {slot} LOGICAL \"pgoutput\" NOEXPORT_SNAPSHOT");
match simple_query_opt(client, &query).await {
Ok(_) => Ok(()),
Err(PostgresError::Postgres(err)) if err.code() == Some(&SqlState::DUPLICATE_OBJECT) => {
tracing::trace!("replication slot {slot} already existed");
Ok(())
}
Err(err) => Err(TransientError::PostgresError(err)),
}
}
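
/// Metadata about a replication slot, as reported by `pg_replication_slots`.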
struct SlotMetadata {
active_pid: Option<i32>,
confirmed_flush_lsn: MzOffset,
}
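
/// Fetches the metadata of the named replication slot, retrying every
/// `interval` until the slot reports a `confirmed_flush_lsn`.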
async fn fetch_slot_metadata(
client: &Client,
slot: &str,
interval: Duration,
) -> Result<SlotMetadata, TransientError> {
loop {
let query = "SELECT active_pid, confirmed_flush_lsn
FROM pg_replication_slots WHERE slot_name = $1";
let Some(row) = client.query_opt(query, &[&slot]).await? else {
return Err(TransientError::MissingReplicationSlot);
};
match row.get::<_, Option<PgLsn>>("confirmed_flush_lsn") {
Some(lsn) => {
return Ok(SlotMetadata {
confirmed_flush_lsn: MzOffset::from(lsn),
active_pid: row.get("active_pid"),
})
}
None => tokio::time::sleep(interval).await,
};
}
}
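
/// Fetches the current write-ahead log position of the upstream server.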
async fn fetch_max_lsn(client: &Client) -> Result<MzOffset, TransientError> {
let query = "SELECT pg_current_wal_lsn()";
let row = simple_query_opt(client, query).await?;
match row.and_then(|row| {
row.get("pg_current_wal_lsn")
.map(|lsn| lsn.parse::<PgLsn>().unwrap())
}) {
Some(lsn) => Ok(MzOffset::from(lsn)),
None => Err(TransientError::Generic(anyhow::anyhow!(
"pg_current_wal_lsn() mysteriously has no value"
))),
}
}
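
/// Verifies that the upstream table identified by `oid` is still compatible
/// with `expected_desc`. Columns that are cast to text are allowed to change
/// their upstream type without breaking compatibility.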
fn verify_schema(
oid: u32,
expected_desc: &PostgresTableDesc,
upstream_info: &BTreeMap<u32, PostgresTableDesc>,
casts: &[(CastType, MirScalarExpr)],
) -> Result<(), DefiniteError> {
let current_desc = upstream_info.get(&oid).ok_or(DefiniteError::TableDropped)?;
let allow_oids_to_change_by_col_num = expected_desc
.columns
.iter()
.zip_eq(casts.iter())
.flat_map(|(col, (cast_type, _))| match cast_type {
CastType::Text => Some(col.col_num),
CastType::Natural => None,
})
.collect();
match expected_desc.determine_compatibility(current_desc, &allow_oids_to_change_by_col_num) {
Ok(()) => Ok(()),
Err(err) => Err(DefiniteError::IncompatibleSchema(err.to_string())),
}
}
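
/// Evaluates the column cast expressions over `datums` and packs the results
/// into `row`.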
fn cast_row(
casts: &[(CastType, MirScalarExpr)],
datums: &[Datum<'_>],
row: &mut Row,
) -> Result<(), DefiniteError> {
let arena = mz_repr::RowArena::new();
let mut packer = row.packer();
for (_, column_cast) in casts {
let datum = column_cast
.eval(datums, &arena)
.map_err(DefiniteError::CastError)?;
packer.push(datum);
}
Ok(())
}
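
/// Decodes the given bytes as UTF-8 text, returning a definite error if they
/// are not valid UTF-8.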
fn decode_utf8_text(bytes: &[u8]) -> Result<Datum<'_>, DefiniteError> {
match std::str::from_utf8(bytes) {
Ok(text) => Ok(Datum::String(text)),
Err(_) => Err(DefiniteError::InvalidUTF8(bytes.to_vec())),
}
}