Skip to main content

mz_postgres_util/
replication.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::str::FromStr;
11
12use tokio_postgres::{
13    Client,
14    types::{Oid, PgLsn},
15};
16
17use mz_ssh_util::tunnel_manager::SshTunnelManager;
18
19use crate::{Config, PostgresError, Sql, query, query_one, simple_query, simple_query_opt};
20
/// A value of the PostgreSQL `wal_level` setting.
///
/// The declaration order of the variants is significant: the derived `Ord`
/// compares variants by position, and `Logical` must remain last so that it
/// is the greatest level (see `test_wal_level_max`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum WalLevel {
    /// Corresponds to the setting string `"minimal"`.
    Minimal,
    /// Corresponds to the setting string `"replica"`.
    Replica,
    /// Corresponds to the setting string `"logical"`.
    Logical,
}
27
28impl std::str::FromStr for WalLevel {
29    type Err = anyhow::Error;
30    fn from_str(s: &str) -> Result<Self, Self::Err> {
31        match s {
32            "minimal" => Ok(Self::Minimal),
33            "replica" => Ok(Self::Replica),
34            "logical" => Ok(Self::Logical),
35            o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
36        }
37    }
38}
39
40impl std::fmt::Display for WalLevel {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        let s = match self {
43            WalLevel::Minimal => "minimal",
44            WalLevel::Replica => "replica",
45            WalLevel::Logical => "logical",
46        };
47
48        f.write_str(s)
49    }
50}
51
52#[mz_ore::test]
53fn test_wal_level_max() {
54    // Ensure `WalLevel::Logical` is the max among all levels.
55    for o in [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical] {
56        assert_eq!(WalLevel::Logical, WalLevel::Logical.max(o))
57    }
58}
59
60pub async fn get_wal_level(client: &Client) -> Result<WalLevel, PostgresError> {
61    let wal_level = query_one(client, crate::sql!("SHOW wal_level"), &[]).await?;
62    let wal_level: String = wal_level.get("wal_level");
63    Ok(WalLevel::from_str(&wal_level)?)
64}
65
66pub async fn get_max_wal_senders(client: &Client) -> Result<i64, PostgresError> {
67    let max_wal_senders = query_one(
68        client,
69        crate::sql!("SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders"),
70        &[],
71    )
72    .await?;
73    Ok(max_wal_senders.get("max_wal_senders"))
74}
75
76pub async fn available_replication_slots(client: &Client) -> Result<i64, PostgresError> {
77    let available_replication_slots = query_one(
78        client,
79        crate::sql!(
80            "SELECT
81            CAST(current_setting('max_replication_slots') AS int8)
82              - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
83              AS available_replication_slots;"
84        ),
85        &[],
86    )
87    .await?;
88
89    let available_replication_slots: i64 =
90        available_replication_slots.get("available_replication_slots");
91
92    Ok(available_replication_slots)
93}
94
95/// Returns true if BYPASSRLS is set for the current user, false otherwise.
96///
97/// See <https://www.postgresql.org/docs/current/ddl-rowsecurity.html>
98pub async fn bypass_rls_attribute(client: &Client) -> Result<bool, PostgresError> {
99    let rls_attribute = query_one(
100        client,
101        crate::sql!("SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;"),
102        &[],
103    )
104    .await?;
105    Ok(rls_attribute.get("rolbypassrls"))
106}
107
108/// Returns an error if the tables identified by the oid's have RLS policies which
109/// affect the current user. Two checks are made:
110///
111/// 1. Identify which tables, from the provided oid's, have RLS SELECT or ALL policies that affect
112///    the user or public.
113/// 2. If there are policies that affect the user, check if the BYPASSRLS attribute is set. If set,
114///    the role is unaffected by the policies.
115pub async fn validate_no_rls_policies(
116    client: &Client,
117    table_oids: &[Oid],
118) -> Result<(), PostgresError> {
119    if table_oids.is_empty() {
120        return Ok(());
121    }
122    let tables_with_rls_for_user = query(
123        client,
124        crate::sql!(
125            "SELECT
126                    format('%I.%I', n.nspname, pc.relname) AS qualified_name
127                FROM pg_policy pp
128                JOIN pg_class pc ON pc.oid = pp.polrelid
129                JOIN pg_namespace n ON n.oid = pc.relnamespace
130                WHERE
131                    pp.polrelid = ANY($1::oid[])
132                    AND pp.polcmd IN ('r', '*')
133                    AND (
134                        0 = ANY(pp.polroles)
135                        OR EXISTS (
136                            SELECT 1
137                            FROM unnest(pp.polroles) AS r(role_oid)
138                            WHERE pg_has_role(CURRENT_USER, r.role_oid, 'USAGE')
139                        )
140                    );"
141        ),
142        &[&table_oids],
143    )
144    .await?;
145
146    let mut tables_with_rls_for_user = tables_with_rls_for_user
147        .into_iter()
148        .map(|row| row.get("qualified_name"))
149        .collect::<Vec<String>>();
150
151    // If the user has the BYPASSRLS flag set, then the policies don't apply, so we can
152    // return success.
153    if tables_with_rls_for_user.is_empty() || bypass_rls_attribute(client).await? {
154        Ok(())
155    } else {
156        tables_with_rls_for_user.sort();
157        Err(PostgresError::BypassRLSRequired(tables_with_rls_for_user))
158    }
159}
160
161pub async fn drop_replication_slots(
162    ssh_tunnel_manager: &SshTunnelManager,
163    config: Config,
164    slots: &[(&str, bool)],
165) -> Result<(), PostgresError> {
166    let client = config
167        .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
168        .await?;
169    let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
170    for (slot, should_wait) in slots {
171        let rows = query(
172            &*client,
173            crate::sql!("SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT"),
174            &[&slot],
175        )
176        .await?;
177        match &*rows {
178            [] => {
179                // DROP_REPLICATION_SLOT will error if the slot does not exist
180                tracing::info!(
181                    "drop_replication_slots called on non-existent slot {}",
182                    slot
183                );
184                continue;
185            }
186            [row] => {
187                // The drop of a replication slot happens concurrently with an ingestion dataflow
188                // shutting down, therefore there is the possibility that the slot is still in use.
189                // We really don't want to leak the slot and not forcefully terminating the
190                // dataflow's connection risks timing out. For this reason we always kill the
191                // active backend and drop the slot.
192                let active_pid: Option<i32> = row.get("active_pid");
193                if let Some(active_pid) = active_pid {
194                    let query = crate::sql!("SELECT pg_terminate_backend({})", active_pid);
195                    simple_query(&*client, query).await?;
196                }
197
198                let wait = if *should_wait {
199                    crate::sql!(" WAIT")
200                } else {
201                    crate::sql!("")
202                };
203                let query = crate::sql!("DROP_REPLICATION_SLOT {}{}", Sql::ident(slot), wait);
204                simple_query(&*replication_client, query).await?;
205            }
206            _ => {
207                return Err(PostgresError::Generic(anyhow::anyhow!(
208                    "multiple pg_replication_slots entries for slot {}",
209                    &slot
210                )));
211            }
212        }
213    }
214    Ok(())
215}
216
217pub async fn get_timeline_id(client: &Client) -> Result<u64, PostgresError> {
218    if let Some(r) = simple_query_opt(
219        client,
220        crate::sql!("SELECT timeline_id FROM pg_control_checkpoint()"),
221    )
222    .await?
223    {
224        r.get("timeline_id")
225            .expect("Returns a row with a timeline ID")
226            .parse::<u64>()
227            .map_err(|err| {
228                PostgresError::Generic(anyhow::anyhow!(
229                    "Failed to parse timeline ID from pg_control_checkpoint(): {}",
230                    err
231                ))
232            })
233    } else {
234        Err(PostgresError::Generic(anyhow::anyhow!(
235            "pg_control_checkpoint() did not return a result row"
236        )))
237    }
238}
239
240pub async fn fetch_max_lsn(
241    client: &Client,
242    is_physical_standby: bool,
243) -> Result<PgLsn, PostgresError> {
244    if is_physical_standby {
245        // A physical standby will not support pg_current_wal_lsn, so we use pg_last_wal_replay_lsn. This reports
246        // the latest LSN that the replica has successfully applied from the WAL.
247        if let Some(lsn) = fetch_lsn(client, crate::sql!("SELECT pg_last_wal_replay_lsn()")).await?
248        {
249            return Ok(lsn);
250        }
251        return Err(PostgresError::Generic(anyhow::anyhow!(
252            "pg_last_wal_replay_lsn unexpectedly None, expected client in recovery, is_in_recovery: {}",
253            get_is_in_recovery(client).await?
254        )));
255    }
256    // Based on the documentation, it appears that `pg_current_wal_lsn` has
257    // the same "upper" semantics of `confirmed_flush_lsn`:
258    // <https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADMIN-BACKUP>
259    // We may need to revisit this and use `pg_current_wal_flush_lsn`.
260    if let Some(lsn) = fetch_lsn(client, crate::sql!("SELECT pg_current_wal_lsn()")).await? {
261        return Ok(lsn);
262    }
263
264    Err(PostgresError::Generic(anyhow::anyhow!(
265        "pg_current_wal_lsn mysteriously has no value"
266    )))
267}
268
269async fn fetch_lsn(client: &Client, query: Sql) -> Result<Option<PgLsn>, PostgresError> {
270    let row = query_one(client, query, &[]).await?;
271    let lsn: Option<PgLsn> = row.get(0);
272    Ok(lsn)
273}
274
275/// Returns whether the server is in recovery, i.e. is a physical replica.
276pub async fn get_is_in_recovery(client: &Client) -> Result<bool, PostgresError> {
277    let row = query_one(client, crate::sql!("SELECT pg_is_in_recovery()"), &[]).await?;
278    let is_in_recovery: Option<bool> = row.get(0);
279    if let Some(is_in_recovery) = is_in_recovery {
280        return Ok(is_in_recovery);
281    }
282
283    Err(PostgresError::Generic(anyhow::anyhow!(
284        "pg_is_in_recovery() mysteriously has no value"
285    )))
286}