1use std::str::FromStr;
11
12use tokio_postgres::{
13 Client,
14 types::{Oid, PgLsn},
15};
16
17use mz_ssh_util::tunnel_manager::SshTunnelManager;
18
19use crate::{Config, PostgresError, Sql, query, query_one, simple_query, simple_query_opt};
20
/// The value of a PostgreSQL server's `wal_level` setting.
///
/// Variants are declared in ascending order, so the derived `Ord`
/// implementation yields `Minimal < Replica < Logical`. Do not reorder the
/// variants without reviewing every comparison made on this type.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum WalLevel {
    /// Parsed from / rendered as the string `"minimal"`.
    Minimal,
    /// Parsed from / rendered as the string `"replica"`.
    Replica,
    /// Parsed from / rendered as the string `"logical"`.
    Logical,
}
27
28impl std::str::FromStr for WalLevel {
29 type Err = anyhow::Error;
30 fn from_str(s: &str) -> Result<Self, Self::Err> {
31 match s {
32 "minimal" => Ok(Self::Minimal),
33 "replica" => Ok(Self::Replica),
34 "logical" => Ok(Self::Logical),
35 o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
36 }
37 }
38}
39
40impl std::fmt::Display for WalLevel {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 let s = match self {
43 WalLevel::Minimal => "minimal",
44 WalLevel::Replica => "replica",
45 WalLevel::Logical => "logical",
46 };
47
48 f.write_str(s)
49 }
50}
51
52#[mz_ore::test]
53fn test_wal_level_max() {
54 for o in [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical] {
56 assert_eq!(WalLevel::Logical, WalLevel::Logical.max(o))
57 }
58}
59
60pub async fn get_wal_level(client: &Client) -> Result<WalLevel, PostgresError> {
61 let wal_level = query_one(client, crate::sql!("SHOW wal_level"), &[]).await?;
62 let wal_level: String = wal_level.get("wal_level");
63 Ok(WalLevel::from_str(&wal_level)?)
64}
65
66pub async fn get_max_wal_senders(client: &Client) -> Result<i64, PostgresError> {
67 let max_wal_senders = query_one(
68 client,
69 crate::sql!("SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders"),
70 &[],
71 )
72 .await?;
73 Ok(max_wal_senders.get("max_wal_senders"))
74}
75
76pub async fn available_replication_slots(client: &Client) -> Result<i64, PostgresError> {
77 let available_replication_slots = query_one(
78 client,
79 crate::sql!(
80 "SELECT
81 CAST(current_setting('max_replication_slots') AS int8)
82 - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
83 AS available_replication_slots;"
84 ),
85 &[],
86 )
87 .await?;
88
89 let available_replication_slots: i64 =
90 available_replication_slots.get("available_replication_slots");
91
92 Ok(available_replication_slots)
93}
94
95pub async fn bypass_rls_attribute(client: &Client) -> Result<bool, PostgresError> {
99 let rls_attribute = query_one(
100 client,
101 crate::sql!("SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;"),
102 &[],
103 )
104 .await?;
105 Ok(rls_attribute.get("rolbypassrls"))
106}
107
108pub async fn validate_no_rls_policies(
116 client: &Client,
117 table_oids: &[Oid],
118) -> Result<(), PostgresError> {
119 if table_oids.is_empty() {
120 return Ok(());
121 }
122 let tables_with_rls_for_user = query(
123 client,
124 crate::sql!(
125 "SELECT
126 format('%I.%I', n.nspname, pc.relname) AS qualified_name
127 FROM pg_policy pp
128 JOIN pg_class pc ON pc.oid = pp.polrelid
129 JOIN pg_namespace n ON n.oid = pc.relnamespace
130 WHERE
131 pp.polrelid = ANY($1::oid[])
132 AND pp.polcmd IN ('r', '*')
133 AND (
134 0 = ANY(pp.polroles)
135 OR EXISTS (
136 SELECT 1
137 FROM unnest(pp.polroles) AS r(role_oid)
138 WHERE pg_has_role(CURRENT_USER, r.role_oid, 'USAGE')
139 )
140 );"
141 ),
142 &[&table_oids],
143 )
144 .await?;
145
146 let mut tables_with_rls_for_user = tables_with_rls_for_user
147 .into_iter()
148 .map(|row| row.get("qualified_name"))
149 .collect::<Vec<String>>();
150
151 if tables_with_rls_for_user.is_empty() || bypass_rls_attribute(client).await? {
154 Ok(())
155 } else {
156 tables_with_rls_for_user.sort();
157 Err(PostgresError::BypassRLSRequired(tables_with_rls_for_user))
158 }
159}
160
161pub async fn drop_replication_slots(
162 ssh_tunnel_manager: &SshTunnelManager,
163 config: Config,
164 slots: &[(&str, bool)],
165) -> Result<(), PostgresError> {
166 let client = config
167 .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
168 .await?;
169 let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
170 for (slot, should_wait) in slots {
171 let rows = query(
172 &*client,
173 crate::sql!("SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT"),
174 &[&slot],
175 )
176 .await?;
177 match &*rows {
178 [] => {
179 tracing::info!(
181 "drop_replication_slots called on non-existent slot {}",
182 slot
183 );
184 continue;
185 }
186 [row] => {
187 let active_pid: Option<i32> = row.get("active_pid");
193 if let Some(active_pid) = active_pid {
194 let query = crate::sql!("SELECT pg_terminate_backend({})", active_pid);
195 simple_query(&*client, query).await?;
196 }
197
198 let wait = if *should_wait {
199 crate::sql!(" WAIT")
200 } else {
201 crate::sql!("")
202 };
203 let query = crate::sql!("DROP_REPLICATION_SLOT {}{}", Sql::ident(slot), wait);
204 simple_query(&*replication_client, query).await?;
205 }
206 _ => {
207 return Err(PostgresError::Generic(anyhow::anyhow!(
208 "multiple pg_replication_slots entries for slot {}",
209 &slot
210 )));
211 }
212 }
213 }
214 Ok(())
215}
216
217pub async fn get_timeline_id(client: &Client) -> Result<u64, PostgresError> {
218 if let Some(r) = simple_query_opt(
219 client,
220 crate::sql!("SELECT timeline_id FROM pg_control_checkpoint()"),
221 )
222 .await?
223 {
224 r.get("timeline_id")
225 .expect("Returns a row with a timeline ID")
226 .parse::<u64>()
227 .map_err(|err| {
228 PostgresError::Generic(anyhow::anyhow!(
229 "Failed to parse timeline ID from pg_control_checkpoint(): {}",
230 err
231 ))
232 })
233 } else {
234 Err(PostgresError::Generic(anyhow::anyhow!(
235 "pg_control_checkpoint() did not return a result row"
236 )))
237 }
238}
239
240pub async fn fetch_max_lsn(
241 client: &Client,
242 is_physical_standby: bool,
243) -> Result<PgLsn, PostgresError> {
244 if is_physical_standby {
245 if let Some(lsn) = fetch_lsn(client, crate::sql!("SELECT pg_last_wal_replay_lsn()")).await?
248 {
249 return Ok(lsn);
250 }
251 return Err(PostgresError::Generic(anyhow::anyhow!(
252 "pg_last_wal_replay_lsn unexpectedly None, expected client in recovery, is_in_recovery: {}",
253 get_is_in_recovery(client).await?
254 )));
255 }
256 if let Some(lsn) = fetch_lsn(client, crate::sql!("SELECT pg_current_wal_lsn()")).await? {
261 return Ok(lsn);
262 }
263
264 Err(PostgresError::Generic(anyhow::anyhow!(
265 "pg_current_wal_lsn mysteriously has no value"
266 )))
267}
268
269async fn fetch_lsn(client: &Client, query: Sql) -> Result<Option<PgLsn>, PostgresError> {
270 let row = query_one(client, query, &[]).await?;
271 let lsn: Option<PgLsn> = row.get(0);
272 Ok(lsn)
273}
274
275pub async fn get_is_in_recovery(client: &Client) -> Result<bool, PostgresError> {
277 let row = query_one(client, crate::sql!("SELECT pg_is_in_recovery()"), &[]).await?;
278 let is_in_recovery: Option<bool> = row.get(0);
279 if let Some(is_in_recovery) = is_in_recovery {
280 return Ok(is_in_recovery);
281 }
282
283 Err(PostgresError::Generic(anyhow::anyhow!(
284 "pg_is_in_recovery() mysteriously has no value"
285 )))
286}