Skip to main content

mz_postgres_util/
replication.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::str::FromStr;
11
12use mz_sql_parser::ast::{Ident, display::AstDisplay};
13use tokio_postgres::{
14    Client,
15    types::{Oid, PgLsn},
16};
17
18use mz_ssh_util::tunnel_manager::SshTunnelManager;
19
20use crate::{Config, PostgresError, simple_query_opt};
21
/// The values of the PostgreSQL `wal_level` setting that this crate recognizes.
///
/// The derived ordering follows declaration order, so
/// `Minimal < Replica < Logical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum WalLevel {
    /// Corresponds to `wal_level = minimal`.
    Minimal,
    /// Corresponds to `wal_level = replica`.
    Replica,
    /// Corresponds to `wal_level = logical`.
    Logical,
}
28
29impl std::str::FromStr for WalLevel {
30    type Err = anyhow::Error;
31    fn from_str(s: &str) -> Result<Self, Self::Err> {
32        match s {
33            "minimal" => Ok(Self::Minimal),
34            "replica" => Ok(Self::Replica),
35            "logical" => Ok(Self::Logical),
36            o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
37        }
38    }
39}
40
41impl std::fmt::Display for WalLevel {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        let s = match self {
44            WalLevel::Minimal => "minimal",
45            WalLevel::Replica => "replica",
46            WalLevel::Logical => "logical",
47        };
48
49        f.write_str(s)
50    }
51}
52
53#[mz_ore::test]
54fn test_wal_level_max() {
55    // Ensure `WalLevel::Logical` is the max among all levels.
56    for o in [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical] {
57        assert_eq!(WalLevel::Logical, WalLevel::Logical.max(o))
58    }
59}
60
61pub async fn get_wal_level(client: &Client) -> Result<WalLevel, PostgresError> {
62    let wal_level = client.query_one("SHOW wal_level", &[]).await?;
63    let wal_level: String = wal_level.get("wal_level");
64    Ok(WalLevel::from_str(&wal_level)?)
65}
66
67pub async fn get_max_wal_senders(client: &Client) -> Result<i64, PostgresError> {
68    let max_wal_senders = client
69        .query_one(
70            "SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders",
71            &[],
72        )
73        .await?;
74    Ok(max_wal_senders.get("max_wal_senders"))
75}
76
77pub async fn available_replication_slots(client: &Client) -> Result<i64, PostgresError> {
78    let available_replication_slots = client
79        .query_one(
80            "SELECT
81            CAST(current_setting('max_replication_slots') AS int8)
82              - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
83              AS available_replication_slots;",
84            &[],
85        )
86        .await?;
87
88    let available_replication_slots: i64 =
89        available_replication_slots.get("available_replication_slots");
90
91    Ok(available_replication_slots)
92}
93
94/// Returns true if BYPASSRLS is set for the current user, false otherwise.
95///
96/// See <https://www.postgresql.org/docs/current/ddl-rowsecurity.html>
97pub async fn bypass_rls_attribute(client: &Client) -> Result<bool, PostgresError> {
98    let rls_attribute = client
99        .query_one(
100            "SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;",
101            &[],
102        )
103        .await?;
104    Ok(rls_attribute.get("rolbypassrls"))
105}
106
107/// Returns an error if the tables identified by the oid's have RLS policies which
108/// affect the current user. Two checks are made:
109///
110/// 1. Identify which tables, from the provided oid's, have RLS SELECT or ALL policies that affect
111///    the user or public.
112/// 2. If there are policies that affect the user, check if the BYPASSRLS attribute is set. If set,
113///    the role is unaffected by the policies.
114pub async fn validate_no_rls_policies(
115    client: &Client,
116    table_oids: &[Oid],
117) -> Result<(), PostgresError> {
118    if table_oids.is_empty() {
119        return Ok(());
120    }
121    let tables_with_rls_for_user = client
122        .query(
123            "SELECT
124                    format('%I.%I', n.nspname, pc.relname) AS qualified_name
125                FROM pg_policy pp
126                JOIN pg_class pc ON pc.oid = pp.polrelid
127                JOIN pg_namespace n ON n.oid = pc.relnamespace
128                WHERE
129                    pp.polrelid = ANY($1::oid[])
130                    AND pp.polcmd IN ('r', '*')
131                    AND (
132                        0 = ANY(pp.polroles)
133                        OR EXISTS (
134                            SELECT 1
135                            FROM unnest(pp.polroles) AS r(role_oid)
136                            WHERE pg_has_role(CURRENT_USER, r.role_oid, 'USAGE')
137                        )
138                    );",
139            &[&table_oids],
140        )
141        .await
142        .map_err(PostgresError::from)?;
143
144    let mut tables_with_rls_for_user = tables_with_rls_for_user
145        .into_iter()
146        .map(|row| row.get("qualified_name"))
147        .collect::<Vec<String>>();
148
149    // If the user has the BYPASSRLS flag set, then the policies don't apply, so we can
150    // return success.
151    if tables_with_rls_for_user.is_empty() || bypass_rls_attribute(client).await? {
152        Ok(())
153    } else {
154        tables_with_rls_for_user.sort();
155        Err(PostgresError::BypassRLSRequired(tables_with_rls_for_user))
156    }
157}
158
159pub async fn drop_replication_slots(
160    ssh_tunnel_manager: &SshTunnelManager,
161    config: Config,
162    slots: &[(&str, bool)],
163) -> Result<(), PostgresError> {
164    let client = config
165        .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
166        .await?;
167    let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
168    for (slot, should_wait) in slots {
169        let rows = client
170            .query(
171                "SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT",
172                &[&slot],
173            )
174            .await?;
175        match &*rows {
176            [] => {
177                // DROP_REPLICATION_SLOT will error if the slot does not exist
178                tracing::info!(
179                    "drop_replication_slots called on non-existent slot {}",
180                    slot
181                );
182                continue;
183            }
184            [row] => {
185                // The drop of a replication slot happens concurrently with an ingestion dataflow
186                // shutting down, therefore there is the possibility that the slot is still in use.
187                // We really don't want to leak the slot and not forcefully terminating the
188                // dataflow's connection risks timing out. For this reason we always kill the
189                // active backend and drop the slot.
190                let active_pid: Option<i32> = row.get("active_pid");
191                if let Some(active_pid) = active_pid {
192                    client
193                        .simple_query(&format!("SELECT pg_terminate_backend({active_pid})"))
194                        .await?;
195                }
196
197                let wait_str = if *should_wait { " WAIT" } else { "" };
198                let slot = Ident::new_unchecked(*slot).to_ast_string_simple();
199                replication_client
200                    .simple_query(&format!("DROP_REPLICATION_SLOT {slot}{wait_str}"))
201                    .await?;
202            }
203            _ => {
204                return Err(PostgresError::Generic(anyhow::anyhow!(
205                    "multiple pg_replication_slots entries for slot {}",
206                    &slot
207                )));
208            }
209        }
210    }
211    Ok(())
212}
213
214pub async fn get_timeline_id(client: &Client) -> Result<u64, PostgresError> {
215    if let Some(r) =
216        simple_query_opt(client, "SELECT timeline_id FROM pg_control_checkpoint()").await?
217    {
218        r.get("timeline_id")
219            .expect("Returns a row with a timeline ID")
220            .parse::<u64>()
221            .map_err(|err| {
222                PostgresError::Generic(anyhow::anyhow!(
223                    "Failed to parse timeline ID from IDENTIFY_SYSTEM: {}",
224                    err
225                ))
226            })
227    } else {
228        Err(PostgresError::Generic(anyhow::anyhow!(
229            "IDENTIFY_SYSTEM did not return a result row"
230        )))
231    }
232}
233
234pub async fn get_current_wal_lsn(client: &Client) -> Result<PgLsn, PostgresError> {
235    let row = client.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
236    let lsn: PgLsn = row.get(0);
237
238    Ok(lsn)
239}