mz_postgres_util/
replication.rs1use std::str::FromStr;
11use tokio_postgres::{
12 Client,
13 types::{Oid, PgLsn},
14};
15
16use mz_ssh_util::tunnel_manager::SshTunnelManager;
17
18use crate::{Config, PostgresError, simple_query_opt};
19
/// PostgreSQL's `wal_level` server setting.
///
/// Variants are declared from the least to the most WAL information, so the derived `Ord`
/// ranks `Minimal < Replica < Logical`; e.g. `max` picks the more capable of two levels.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum WalLevel {
    /// `wal_level = minimal`.
    Minimal,
    /// `wal_level = replica`.
    Replica,
    /// `wal_level = logical`.
    Logical,
}
26
27impl std::str::FromStr for WalLevel {
28 type Err = anyhow::Error;
29 fn from_str(s: &str) -> Result<Self, Self::Err> {
30 match s {
31 "minimal" => Ok(Self::Minimal),
32 "replica" => Ok(Self::Replica),
33 "logical" => Ok(Self::Logical),
34 o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
35 }
36 }
37}
38
39impl std::fmt::Display for WalLevel {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 let s = match self {
42 WalLevel::Minimal => "minimal",
43 WalLevel::Replica => "replica",
44 WalLevel::Logical => "logical",
45 };
46
47 f.write_str(s)
48 }
49}
50
51#[mz_ore::test]
52fn test_wal_level_max() {
53 for o in [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical] {
55 assert_eq!(WalLevel::Logical, WalLevel::Logical.max(o))
56 }
57}
58
59pub async fn get_wal_level(client: &Client) -> Result<WalLevel, PostgresError> {
60 let wal_level = client.query_one("SHOW wal_level", &[]).await?;
61 let wal_level: String = wal_level.get("wal_level");
62 Ok(WalLevel::from_str(&wal_level)?)
63}
64
65pub async fn get_max_wal_senders(client: &Client) -> Result<i64, PostgresError> {
66 let max_wal_senders = client
67 .query_one(
68 "SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders",
69 &[],
70 )
71 .await?;
72 Ok(max_wal_senders.get("max_wal_senders"))
73}
74
75pub async fn available_replication_slots(client: &Client) -> Result<i64, PostgresError> {
76 let available_replication_slots = client
77 .query_one(
78 "SELECT
79 CAST(current_setting('max_replication_slots') AS int8)
80 - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
81 AS available_replication_slots;",
82 &[],
83 )
84 .await?;
85
86 let available_replication_slots: i64 =
87 available_replication_slots.get("available_replication_slots");
88
89 Ok(available_replication_slots)
90}
91
92pub async fn bypass_rls_attribute(client: &Client) -> Result<bool, PostgresError> {
96 let rls_attribute = client
97 .query_one(
98 "SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;",
99 &[],
100 )
101 .await?;
102 Ok(rls_attribute.get("rolbypassrls"))
103}
104
105pub async fn validate_no_rls_policies(
113 client: &Client,
114 table_oids: &[Oid],
115) -> Result<(), PostgresError> {
116 if table_oids.is_empty() {
117 return Ok(());
118 }
119 let tables_with_rls_for_user = client
120 .query(
121 "SELECT
122 format('%I.%I', pc.relnamespace::regnamespace, pc.relname) AS qualified_name
123 FROM pg_policy pp
124 JOIN pg_class pc ON pc.oid = polrelid
125 WHERE
126 polrelid = ANY($1::oid[])
127 AND
128 (0 = ANY(polroles) OR CURRENT_USER::regrole::oid = ANY(polroles));",
129 &[&table_oids],
130 )
131 .await
132 .map_err(PostgresError::from)?;
133
134 let mut tables_with_rls_for_user = tables_with_rls_for_user
135 .into_iter()
136 .map(|row| row.get("qualified_name"))
137 .collect::<Vec<String>>();
138
139 if tables_with_rls_for_user.is_empty() || bypass_rls_attribute(client).await? {
142 Ok(())
143 } else {
144 tables_with_rls_for_user.sort();
145 Err(PostgresError::BypassRLSRequired(tables_with_rls_for_user))
146 }
147}
148
149pub async fn drop_replication_slots(
150 ssh_tunnel_manager: &SshTunnelManager,
151 config: Config,
152 slots: &[(&str, bool)],
153) -> Result<(), PostgresError> {
154 let client = config
155 .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
156 .await?;
157 let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
158 for (slot, should_wait) in slots {
159 let rows = client
160 .query(
161 "SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT",
162 &[&slot],
163 )
164 .await?;
165 match &*rows {
166 [] => {
167 tracing::info!(
169 "drop_replication_slots called on non-existent slot {}",
170 slot
171 );
172 continue;
173 }
174 [row] => {
175 let active_pid: Option<i32> = row.get("active_pid");
181 if let Some(active_pid) = active_pid {
182 client
183 .simple_query(&format!("SELECT pg_terminate_backend({active_pid})"))
184 .await?;
185 }
186
187 let wait_str = if *should_wait { " WAIT" } else { "" };
188 replication_client
189 .simple_query(&format!("DROP_REPLICATION_SLOT {slot}{wait_str}"))
190 .await?;
191 }
192 _ => {
193 return Err(PostgresError::Generic(anyhow::anyhow!(
194 "multiple pg_replication_slots entries for slot {}",
195 &slot
196 )));
197 }
198 }
199 }
200 Ok(())
201}
202
203pub async fn get_timeline_id(client: &Client) -> Result<u64, PostgresError> {
204 if let Some(r) =
205 simple_query_opt(client, "SELECT timeline_id FROM pg_control_checkpoint()").await?
206 {
207 r.get("timeline_id")
208 .expect("Returns a row with a timeline ID")
209 .parse::<u64>()
210 .map_err(|err| {
211 PostgresError::Generic(anyhow::anyhow!(
212 "Failed to parse timeline ID from IDENTIFY_SYSTEM: {}",
213 err
214 ))
215 })
216 } else {
217 Err(PostgresError::Generic(anyhow::anyhow!(
218 "IDENTIFY_SYSTEM did not return a result row"
219 )))
220 }
221}
222
223pub async fn get_current_wal_lsn(client: &Client) -> Result<PgLsn, PostgresError> {
224 let row = client.query_one("SELECT pg_current_wal_lsn()", &[]).await?;
225 let lsn: PgLsn = row.get(0);
226
227 Ok(lsn)
228}