mz_postgres_util/
replication.rs1use std::str::FromStr;
11
12use mz_sql_parser::ast::{Ident, display::AstDisplay};
13use tokio_postgres::{
14 Client,
15 types::{Oid, PgLsn},
16};
17
18use mz_ssh_util::tunnel_manager::SshTunnelManager;
19
20use crate::{Config, PostgresError, simple_query_opt};
21
/// A value of the PostgreSQL `wal_level` setting.
///
/// Variants are declared in ascending order, so the derived `Ord` gives
/// `Minimal < Replica < Logical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum WalLevel {
    /// Corresponds to the textual value `minimal`.
    Minimal,
    /// Corresponds to the textual value `replica`.
    Replica,
    /// Corresponds to the textual value `logical`.
    Logical,
}
28
29impl std::str::FromStr for WalLevel {
30 type Err = anyhow::Error;
31 fn from_str(s: &str) -> Result<Self, Self::Err> {
32 match s {
33 "minimal" => Ok(Self::Minimal),
34 "replica" => Ok(Self::Replica),
35 "logical" => Ok(Self::Logical),
36 o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
37 }
38 }
39}
40
41impl std::fmt::Display for WalLevel {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 let s = match self {
44 WalLevel::Minimal => "minimal",
45 WalLevel::Replica => "replica",
46 WalLevel::Logical => "logical",
47 };
48
49 f.write_str(s)
50 }
51}
52
53#[mz_ore::test]
54fn test_wal_level_max() {
55 for o in [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical] {
57 assert_eq!(WalLevel::Logical, WalLevel::Logical.max(o))
58 }
59}
60
61pub async fn get_wal_level(client: &Client) -> Result<WalLevel, PostgresError> {
62 let wal_level = client.query_one("SHOW wal_level", &[]).await?;
63 let wal_level: String = wal_level.get("wal_level");
64 Ok(WalLevel::from_str(&wal_level)?)
65}
66
67pub async fn get_max_wal_senders(client: &Client) -> Result<i64, PostgresError> {
68 let max_wal_senders = client
69 .query_one(
70 "SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders",
71 &[],
72 )
73 .await?;
74 Ok(max_wal_senders.get("max_wal_senders"))
75}
76
77pub async fn available_replication_slots(client: &Client) -> Result<i64, PostgresError> {
78 let available_replication_slots = client
79 .query_one(
80 "SELECT
81 CAST(current_setting('max_replication_slots') AS int8)
82 - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
83 AS available_replication_slots;",
84 &[],
85 )
86 .await?;
87
88 let available_replication_slots: i64 =
89 available_replication_slots.get("available_replication_slots");
90
91 Ok(available_replication_slots)
92}
93
94pub async fn bypass_rls_attribute(client: &Client) -> Result<bool, PostgresError> {
98 let rls_attribute = client
99 .query_one(
100 "SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;",
101 &[],
102 )
103 .await?;
104 Ok(rls_attribute.get("rolbypassrls"))
105}
106
107pub async fn validate_no_rls_policies(
115 client: &Client,
116 table_oids: &[Oid],
117) -> Result<(), PostgresError> {
118 if table_oids.is_empty() {
119 return Ok(());
120 }
121 let tables_with_rls_for_user = client
122 .query(
123 "SELECT
124 format('%I.%I', n.nspname, pc.relname) AS qualified_name
125 FROM pg_policy pp
126 JOIN pg_class pc ON pc.oid = pp.polrelid
127 JOIN pg_namespace n ON n.oid = pc.relnamespace
128 WHERE
129 pp.polrelid = ANY($1::oid[])
130 AND pp.polcmd IN ('r', '*')
131 AND (
132 0 = ANY(pp.polroles)
133 OR EXISTS (
134 SELECT 1
135 FROM unnest(pp.polroles) AS r(role_oid)
136 WHERE pg_has_role(CURRENT_USER, r.role_oid, 'USAGE')
137 )
138 );",
139 &[&table_oids],
140 )
141 .await
142 .map_err(PostgresError::from)?;
143
144 let mut tables_with_rls_for_user = tables_with_rls_for_user
145 .into_iter()
146 .map(|row| row.get("qualified_name"))
147 .collect::<Vec<String>>();
148
149 if tables_with_rls_for_user.is_empty() || bypass_rls_attribute(client).await? {
152 Ok(())
153 } else {
154 tables_with_rls_for_user.sort();
155 Err(PostgresError::BypassRLSRequired(tables_with_rls_for_user))
156 }
157}
158
159pub async fn drop_replication_slots(
160 ssh_tunnel_manager: &SshTunnelManager,
161 config: Config,
162 slots: &[(&str, bool)],
163) -> Result<(), PostgresError> {
164 let client = config
165 .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
166 .await?;
167 let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
168 for (slot, should_wait) in slots {
169 let rows = client
170 .query(
171 "SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT",
172 &[&slot],
173 )
174 .await?;
175 match &*rows {
176 [] => {
177 tracing::info!(
179 "drop_replication_slots called on non-existent slot {}",
180 slot
181 );
182 continue;
183 }
184 [row] => {
185 let active_pid: Option<i32> = row.get("active_pid");
191 if let Some(active_pid) = active_pid {
192 client
193 .simple_query(&format!("SELECT pg_terminate_backend({active_pid})"))
194 .await?;
195 }
196
197 let wait_str = if *should_wait { " WAIT" } else { "" };
198 let slot = Ident::new_unchecked(*slot).to_ast_string_simple();
199 replication_client
200 .simple_query(&format!("DROP_REPLICATION_SLOT {slot}{wait_str}"))
201 .await?;
202 }
203 _ => {
204 return Err(PostgresError::Generic(anyhow::anyhow!(
205 "multiple pg_replication_slots entries for slot {}",
206 &slot
207 )));
208 }
209 }
210 }
211 Ok(())
212}
213
214pub async fn get_timeline_id(client: &Client) -> Result<u64, PostgresError> {
215 if let Some(r) =
216 simple_query_opt(client, "SELECT timeline_id FROM pg_control_checkpoint()").await?
217 {
218 r.get("timeline_id")
219 .expect("Returns a row with a timeline ID")
220 .parse::<u64>()
221 .map_err(|err| {
222 PostgresError::Generic(anyhow::anyhow!(
223 "Failed to parse timeline ID from IDENTIFY_SYSTEM: {}",
224 err
225 ))
226 })
227 } else {
228 Err(PostgresError::Generic(anyhow::anyhow!(
229 "IDENTIFY_SYSTEM did not return a result row"
230 )))
231 }
232}
233
234pub async fn get_current_wal_lsn(client: &Client) -> Result<PgLsn, PostgresError> {
235 let row = client.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
236 let lsn: PgLsn = row.get(0);
237
238 Ok(lsn)
239}