mz_postgres_util/
replication.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::str::FromStr;
11use tokio_postgres::{
12    Client,
13    types::{Oid, PgLsn},
14};
15
16use mz_ssh_util::tunnel_manager::SshTunnelManager;
17
18use crate::{Config, PostgresError, simple_query_opt};
19
/// The PostgreSQL `wal_level` server setting.
///
/// Variants are declared from least to most information written to the WAL, and the derived
/// `Ord` follows declaration order, so `Logical` compares greater than every other level.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum WalLevel {
    /// `wal_level = minimal`.
    Minimal,
    /// `wal_level = replica`.
    Replica,
    /// `wal_level = logical`; the only level that supports logical replication.
    Logical,
}
26
27impl std::str::FromStr for WalLevel {
28    type Err = anyhow::Error;
29    fn from_str(s: &str) -> Result<Self, Self::Err> {
30        match s {
31            "minimal" => Ok(Self::Minimal),
32            "replica" => Ok(Self::Replica),
33            "logical" => Ok(Self::Logical),
34            o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
35        }
36    }
37}
38
39impl std::fmt::Display for WalLevel {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        let s = match self {
42            WalLevel::Minimal => "minimal",
43            WalLevel::Replica => "replica",
44            WalLevel::Logical => "logical",
45        };
46
47        f.write_str(s)
48    }
49}
50
51#[mz_ore::test]
52fn test_wal_level_max() {
53    // Ensure `WalLevel::Logical` is the max among all levels.
54    for o in [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical] {
55        assert_eq!(WalLevel::Logical, WalLevel::Logical.max(o))
56    }
57}
58
59pub async fn get_wal_level(client: &Client) -> Result<WalLevel, PostgresError> {
60    let wal_level = client.query_one("SHOW wal_level", &[]).await?;
61    let wal_level: String = wal_level.get("wal_level");
62    Ok(WalLevel::from_str(&wal_level)?)
63}
64
65pub async fn get_max_wal_senders(client: &Client) -> Result<i64, PostgresError> {
66    let max_wal_senders = client
67        .query_one(
68            "SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders",
69            &[],
70        )
71        .await?;
72    Ok(max_wal_senders.get("max_wal_senders"))
73}
74
75pub async fn available_replication_slots(client: &Client) -> Result<i64, PostgresError> {
76    let available_replication_slots = client
77        .query_one(
78            "SELECT
79            CAST(current_setting('max_replication_slots') AS int8)
80              - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
81              AS available_replication_slots;",
82            &[],
83        )
84        .await?;
85
86    let available_replication_slots: i64 =
87        available_replication_slots.get("available_replication_slots");
88
89    Ok(available_replication_slots)
90}
91
92/// Returns true if BYPASSRLS is set for the current user, false otherwise.
93///
94/// See <https://www.postgresql.org/docs/current/ddl-rowsecurity.html>
95pub async fn bypass_rls_attribute(client: &Client) -> Result<bool, PostgresError> {
96    let rls_attribute = client
97        .query_one(
98            "SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;",
99            &[],
100        )
101        .await?;
102    Ok(rls_attribute.get("rolbypassrls"))
103}
104
105/// Returns an error if the tables identified by the oid's have RLS policies which
106/// affect the current user. Two checks are made:
107///
108/// 1. Identify which tables, from the provided oid's, have RLS policies that affecct the user or
109///    public.
110/// 2. If there are policies that affect the user, check if the BYPASSRLS attribute is set. If set,
111///    the role is unaffected by the policies.
112pub async fn validate_no_rls_policies(
113    client: &Client,
114    table_oids: &[Oid],
115) -> Result<(), PostgresError> {
116    if table_oids.is_empty() {
117        return Ok(());
118    }
119    let tables_with_rls_for_user = client
120        .query(
121            "SELECT
122                    format('%I.%I', pc.relnamespace::regnamespace, pc.relname) AS qualified_name
123                FROM pg_policy pp
124                JOIN pg_class pc ON pc.oid = polrelid
125                WHERE
126                    polrelid = ANY($1::oid[])
127                    AND
128                    (0 = ANY(polroles) OR CURRENT_USER::regrole::oid = ANY(polroles));",
129            &[&table_oids],
130        )
131        .await
132        .map_err(PostgresError::from)?;
133
134    let mut tables_with_rls_for_user = tables_with_rls_for_user
135        .into_iter()
136        .map(|row| row.get("qualified_name"))
137        .collect::<Vec<String>>();
138
139    // If the user has the BYPASSRLS flag set, then the policies don't apply, so we can
140    // return success.
141    if tables_with_rls_for_user.is_empty() || bypass_rls_attribute(client).await? {
142        Ok(())
143    } else {
144        tables_with_rls_for_user.sort();
145        Err(PostgresError::BypassRLSRequired(tables_with_rls_for_user))
146    }
147}
148
149pub async fn drop_replication_slots(
150    ssh_tunnel_manager: &SshTunnelManager,
151    config: Config,
152    slots: &[(&str, bool)],
153) -> Result<(), PostgresError> {
154    let client = config
155        .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
156        .await?;
157    let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
158    for (slot, should_wait) in slots {
159        let rows = client
160            .query(
161                "SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT",
162                &[&slot],
163            )
164            .await?;
165        match &*rows {
166            [] => {
167                // DROP_REPLICATION_SLOT will error if the slot does not exist
168                tracing::info!(
169                    "drop_replication_slots called on non-existent slot {}",
170                    slot
171                );
172                continue;
173            }
174            [row] => {
175                // The drop of a replication slot happens concurrently with an ingestion dataflow
176                // shutting down, therefore there is the possibility that the slot is still in use.
177                // We really don't want to leak the slot and not forcefully terminating the
178                // dataflow's connection risks timing out. For this reason we always kill the
179                // active backend and drop the slot.
180                let active_pid: Option<i32> = row.get("active_pid");
181                if let Some(active_pid) = active_pid {
182                    client
183                        .simple_query(&format!("SELECT pg_terminate_backend({active_pid})"))
184                        .await?;
185                }
186
187                let wait_str = if *should_wait { " WAIT" } else { "" };
188                replication_client
189                    .simple_query(&format!("DROP_REPLICATION_SLOT {slot}{wait_str}"))
190                    .await?;
191            }
192            _ => {
193                return Err(PostgresError::Generic(anyhow::anyhow!(
194                    "multiple pg_replication_slots entries for slot {}",
195                    &slot
196                )));
197            }
198        }
199    }
200    Ok(())
201}
202
203pub async fn get_timeline_id(client: &Client) -> Result<u64, PostgresError> {
204    if let Some(r) =
205        simple_query_opt(client, "SELECT timeline_id FROM pg_control_checkpoint()").await?
206    {
207        r.get("timeline_id")
208            .expect("Returns a row with a timeline ID")
209            .parse::<u64>()
210            .map_err(|err| {
211                PostgresError::Generic(anyhow::anyhow!(
212                    "Failed to parse timeline ID from IDENTIFY_SYSTEM: {}",
213                    err
214                ))
215            })
216    } else {
217        Err(PostgresError::Generic(anyhow::anyhow!(
218            "IDENTIFY_SYSTEM did not return a result row"
219        )))
220    }
221}
222
223pub async fn get_current_wal_lsn(client: &Client) -> Result<PgLsn, PostgresError> {
224    let row = client.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
225    let lsn: PgLsn = row.get(0);
226
227    Ok(lsn)
228}