mz_postgres_util/
replication.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::str::FromStr;
11use tokio_postgres::{Client, types::PgLsn};
12
13use mz_ssh_util::tunnel_manager::SshTunnelManager;
14
15use crate::{Config, PostgresError, simple_query_opt};
16
/// The value of a PostgreSQL server's `wal_level` setting.
///
/// Variants are declared in increasing order of how much information is written to
/// the WAL, so the derived `Ord` gives `Minimal < Replica < Logical`. Code relies on
/// `Logical` being the maximum (see `test_wal_level_max`), so keep this order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum WalLevel {
    /// `wal_level = minimal`.
    Minimal,
    /// `wal_level = replica`.
    Replica,
    /// `wal_level = logical`, the level needed for logical decoding.
    Logical,
}
23
24impl std::str::FromStr for WalLevel {
25    type Err = anyhow::Error;
26    fn from_str(s: &str) -> Result<Self, Self::Err> {
27        match s {
28            "minimal" => Ok(Self::Minimal),
29            "replica" => Ok(Self::Replica),
30            "logical" => Ok(Self::Logical),
31            o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
32        }
33    }
34}
35
36impl std::fmt::Display for WalLevel {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        let s = match self {
39            WalLevel::Minimal => "minimal",
40            WalLevel::Replica => "replica",
41            WalLevel::Logical => "logical",
42        };
43
44        f.write_str(s)
45    }
46}
47
48#[mz_ore::test]
49fn test_wal_level_max() {
50    // Ensure `WalLevel::Logical` is the max among all levels.
51    for o in [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical] {
52        assert_eq!(WalLevel::Logical, WalLevel::Logical.max(o))
53    }
54}
55
56pub async fn get_wal_level(client: &Client) -> Result<WalLevel, PostgresError> {
57    let wal_level = client.query_one("SHOW wal_level", &[]).await?;
58    let wal_level: String = wal_level.get("wal_level");
59    Ok(WalLevel::from_str(&wal_level)?)
60}
61
62pub async fn get_max_wal_senders(client: &Client) -> Result<i64, PostgresError> {
63    let max_wal_senders = client
64        .query_one(
65            "SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders",
66            &[],
67        )
68        .await?;
69    Ok(max_wal_senders.get("max_wal_senders"))
70}
71
72pub async fn available_replication_slots(client: &Client) -> Result<i64, PostgresError> {
73    let available_replication_slots = client
74        .query_one(
75            "SELECT
76            CAST(current_setting('max_replication_slots') AS int8)
77              - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
78              AS available_replication_slots;",
79            &[],
80        )
81        .await?;
82
83    let available_replication_slots: i64 =
84        available_replication_slots.get("available_replication_slots");
85
86    Ok(available_replication_slots)
87}
88
89pub async fn drop_replication_slots(
90    ssh_tunnel_manager: &SshTunnelManager,
91    config: Config,
92    slots: &[(&str, bool)],
93) -> Result<(), PostgresError> {
94    let client = config
95        .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
96        .await?;
97    let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
98    for (slot, should_wait) in slots {
99        let rows = client
100            .query(
101                "SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT",
102                &[&slot],
103            )
104            .await?;
105        match &*rows {
106            [] => {
107                // DROP_REPLICATION_SLOT will error if the slot does not exist
108                tracing::info!(
109                    "drop_replication_slots called on non-existent slot {}",
110                    slot
111                );
112                continue;
113            }
114            [row] => {
115                // The drop of a replication slot happens concurrently with an ingestion dataflow
116                // shutting down, therefore there is the possibility that the slot is still in use.
117                // We really don't want to leak the slot and not forcefully terminating the
118                // dataflow's connection risks timing out. For this reason we always kill the
119                // active backend and drop the slot.
120                let active_pid: Option<i32> = row.get("active_pid");
121                if let Some(active_pid) = active_pid {
122                    client
123                        .simple_query(&format!("SELECT pg_terminate_backend({active_pid})"))
124                        .await?;
125                }
126
127                let wait_str = if *should_wait { " WAIT" } else { "" };
128                replication_client
129                    .simple_query(&format!("DROP_REPLICATION_SLOT {slot}{wait_str}"))
130                    .await?;
131            }
132            _ => {
133                return Err(PostgresError::Generic(anyhow::anyhow!(
134                    "multiple pg_replication_slots entries for slot {}",
135                    &slot
136                )));
137            }
138        }
139    }
140    Ok(())
141}
142
143pub async fn get_timeline_id(replication_client: &Client) -> Result<u64, PostgresError> {
144    if let Some(r) = simple_query_opt(replication_client, "IDENTIFY_SYSTEM").await? {
145        r.get("timeline")
146            .expect("Returns a row with a timeline ID")
147            .parse::<u64>()
148            .map_err(|err| {
149                PostgresError::Generic(anyhow::anyhow!(
150                    "Failed to parse timeline ID from IDENTIFY_SYSTEM: {}",
151                    err
152                ))
153            })
154    } else {
155        Err(PostgresError::Generic(anyhow::anyhow!(
156            "IDENTIFY_SYSTEM did not return a result row"
157        )))
158    }
159}
160
161pub async fn get_current_wal_lsn(client: &Client) -> Result<PgLsn, PostgresError> {
162    let row = client.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
163    let lsn: PgLsn = row.get(0);
164
165    Ok(lsn)
166}