1use std::fmt::Formatter;
13use std::str::FromStr;
14use std::sync::Arc;
15use std::time::Duration;
16
17use anyhow::anyhow;
18use async_stream::try_stream;
19use async_trait::async_trait;
20use bytes::Bytes;
21use deadpool_postgres::tokio_postgres::Config;
22use deadpool_postgres::tokio_postgres::types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
23use deadpool_postgres::{Object, PoolError};
24use futures_util::StreamExt;
25use mz_dyncfg::ConfigSet;
26use mz_ore::cast::CastFrom;
27use mz_ore::metrics::MetricsRegistry;
28use mz_ore::url::SensitiveUrl;
29use mz_postgres_client::metrics::PostgresClientMetrics;
30use mz_postgres_client::{PostgresClient, PostgresClientConfig, PostgresClientKnobs};
31use postgres_protocol::escape::escape_identifier;
32use tokio_postgres::error::SqlState;
33use tracing::{info, warn};
34
35use crate::error::Error;
36use crate::location::{CaSResult, Consensus, ExternalError, ResultStream, SeqNo, VersionedData};
37
38pub const USE_POSTGRES_TUNED_QUERIES: mz_dyncfg::Config<bool> = mz_dyncfg::Config::new(
40 "persist_use_postgres_tuned_queries",
41 false,
42 "Use a set of queries for consensus that have specifically been tuned against
43 Postgres to ensure we acquire a minimal number of locks.",
44);
45
/// DDL for the consensus table: one row per `(shard, sequence_number)` pair,
/// holding that version's `data` payload (read back as `VersionedData::data`).
///
/// The statement deliberately ends right after the closing parenthesis, with no
/// semicolon, so that `CRDB_SCHEMA_OPTIONS` can be appended directly to it
/// when creating the table on CockroachDB.
const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS consensus (
    shard text NOT NULL,
    sequence_number bigint NOT NULL,
    data bytea NOT NULL,
    PRIMARY KEY(shard, sequence_number)
)
";
54
/// CockroachDB-only table options, appended directly after `SCHEMA`'s closing
/// parenthesis; they turn off automatic SQL statistics collection for the
/// `consensus` table. On Postgres the combined DDL is expected to fail with
/// `INVALID_PARAMETER_VALUE` or `SYNTAX_ERROR`, which is how
/// `PostgresConsensus` tells the two databases apart.
const CRDB_SCHEMA_OPTIONS: &str = "WITH (sql_stats_automatic_collection_enabled = false)";
/// Sets the CockroachDB zone configuration of the `consensus` table to
/// `gc.ttlseconds = 600`. It is run in the same batch as the table creation; an
/// `INSUFFICIENT_PRIVILEGE` failure is tolerated and logged as expected for
/// read-only users.
/// NOTE(review): presumably meant to bound how long CockroachDB retains old row
/// versions of this frequently rewritten table — confirm.
const CRDB_CONFIGURE_ZONE: &str = "ALTER TABLE consensus CONFIGURE ZONE USING gc.ttlseconds = 600";
70
71impl ToSql for SeqNo {
72 fn to_sql(
73 &self,
74 ty: &Type,
75 w: &mut bytes::BytesMut,
76 ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
77 let value = i64::try_from(self.0)?;
79 <i64 as ToSql>::to_sql(&value, ty, w)
80 }
81
82 fn accepts(ty: &Type) -> bool {
83 <i64 as ToSql>::accepts(ty)
84 }
85
86 to_sql_checked!();
87}
88
89impl<'a> FromSql<'a> for SeqNo {
90 fn from_sql(
91 ty: &Type,
92 raw: &'a [u8],
93 ) -> Result<SeqNo, Box<dyn std::error::Error + Sync + Send>> {
94 let sequence_number = <i64 as FromSql>::from_sql(ty, raw)?;
95
96 let sequence_number = u64::try_from(sequence_number)?;
99 Ok(SeqNo(sequence_number))
100 }
101
102 fn accepts(ty: &Type) -> bool {
103 <i64 as FromSql>::accepts(ty)
104 }
105}
106
/// Configuration for opening a [`PostgresConsensus`].
#[derive(Clone, Debug)]
pub struct PostgresConsensusConfig {
    /// Connection URL of the backing Postgres or CockroachDB database.
    url: SensitiveUrl,
    /// Connection and pool tuning, forwarded to [`PostgresClientConfig`].
    knobs: Arc<dyn PostgresClientKnobs>,
    /// Client metrics, forwarded to [`PostgresClientConfig`].
    metrics: PostgresClientMetrics,
    /// Dynamic configuration; not forwarded to the client, but kept by
    /// [`PostgresConsensus`] to choose between query variants.
    dyncfg: Arc<ConfigSet>,
}
115
116impl From<PostgresConsensusConfig> for PostgresClientConfig {
117 fn from(config: PostgresConsensusConfig) -> Self {
118 PostgresClientConfig::new(config.url, config.knobs, config.metrics)
119 }
120}
121
122impl PostgresConsensusConfig {
123 const EXTERNAL_TESTS_POSTGRES_URL: &'static str =
124 "MZ_PERSIST_EXTERNAL_STORAGE_TEST_POSTGRES_URL";
125
126 pub fn new(
128 url: &SensitiveUrl,
129 knobs: Box<dyn PostgresClientKnobs>,
130 metrics: PostgresClientMetrics,
131 dyncfg: Arc<ConfigSet>,
132 ) -> Result<Self, Error> {
133 Ok(PostgresConsensusConfig {
134 url: url.clone(),
135 knobs: Arc::from(knobs),
136 metrics,
137 dyncfg,
138 })
139 }
140
141 pub fn new_for_test() -> Result<Option<Self>, Error> {
151 let url = match std::env::var(Self::EXTERNAL_TESTS_POSTGRES_URL) {
152 Ok(url) => SensitiveUrl::from_str(&url).map_err(|e| e.to_string())?,
153 Err(_) => {
154 if mz_ore::env::is_var_truthy("CI") {
155 panic!("CI is supposed to run this test but something has gone wrong!");
156 }
157 return Ok(None);
158 }
159 };
160
161 struct TestConsensusKnobs;
162 impl std::fmt::Debug for TestConsensusKnobs {
163 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("TestConsensusKnobs").finish_non_exhaustive()
165 }
166 }
167 impl PostgresClientKnobs for TestConsensusKnobs {
168 fn connection_pool_max_size(&self) -> usize {
169 2
170 }
171
172 fn connection_pool_max_wait(&self) -> Option<Duration> {
173 Some(Duration::from_secs(1))
174 }
175
176 fn connection_pool_ttl(&self) -> Duration {
177 Duration::MAX
178 }
179 fn connection_pool_ttl_stagger(&self) -> Duration {
180 Duration::MAX
181 }
182 fn connect_timeout(&self) -> Duration {
183 Duration::MAX
184 }
185 fn tcp_user_timeout(&self) -> Duration {
186 Duration::ZERO
187 }
188
189 fn keepalives_idle(&self) -> Duration {
190 Duration::from_secs(10)
191 }
192
193 fn keepalives_interval(&self) -> Duration {
194 Duration::from_secs(5)
195 }
196
197 fn keepalives_retries(&self) -> u32 {
198 5
199 }
200 }
201
202 let dyncfg = ConfigSet::default().add(&USE_POSTGRES_TUNED_QUERIES);
203 let config = PostgresConsensusConfig::new(
204 &url,
205 Box::new(TestConsensusKnobs),
206 PostgresClientMetrics::new(&MetricsRegistry::new(), "mz_persist"),
207 Arc::new(dyncfg),
208 )?;
209 Ok(Some(config))
210 }
211}
212
/// Which database flavor holds the consensus table, as detected when the table
/// is created by `PostgresConsensus::open`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PostgresMode {
    /// The CockroachDB-specific DDL succeeded, or failed only with
    /// `INSUFFICIENT_PRIVILEGE`.
    CockroachDB,
    /// The CockroachDB-specific DDL was rejected with
    /// `INVALID_PARAMETER_VALUE` or `SYNTAX_ERROR`.
    Postgres,
}
221
/// A [`Consensus`] implementation backed by a `consensus` table in Postgres or
/// CockroachDB.
pub struct PostgresConsensus {
    /// Pooled client used for every query.
    postgres_client: PostgresClient,
    /// Dynamic configuration; consulted for [`USE_POSTGRES_TUNED_QUERIES`].
    dyncfg: Arc<ConfigSet>,
    /// Database flavor detected at open time; selects between query variants.
    mode: PostgresMode,
}
228
229impl std::fmt::Debug for PostgresConsensus {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("PostgresConsensus").finish_non_exhaustive()
232 }
233}
234
235impl PostgresConsensus {
236 pub async fn open(config: PostgresConsensusConfig) -> Result<Self, ExternalError> {
239 let pg_config: Config = config.url.to_string().parse()?;
241 let role = pg_config.get_user().unwrap();
242 let create_schema = format!(
243 "CREATE SCHEMA IF NOT EXISTS consensus AUTHORIZATION {}",
244 escape_identifier(role),
245 );
246
247 let dyncfg = Arc::clone(&config.dyncfg);
248 let postgres_client = PostgresClient::open(config.into())?;
249
250 let client = postgres_client.get_connection().await?;
251
252 let mode = match client
253 .batch_execute(&format!(
254 "{}; {}{}; {};",
255 create_schema, SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE,
256 ))
257 .await
258 {
259 Ok(()) => PostgresMode::CockroachDB,
260 Err(e) if e.code() == Some(&SqlState::INSUFFICIENT_PRIVILEGE) => {
261 warn!(
262 "unable to ALTER TABLE consensus, this is expected and OK when connecting with a read-only user"
263 );
264 PostgresMode::CockroachDB
265 }
266 Err(e)
269 if e.code() == Some(&SqlState::INVALID_PARAMETER_VALUE)
270 || e.code() == Some(&SqlState::SYNTAX_ERROR) =>
271 {
272 info!(
273 "unable to initiate consensus with CRDB params, this is expected and OK when running against Postgres: {:?}",
274 e
275 );
276 PostgresMode::Postgres
277 }
278 Err(e) => return Err(e.into()),
279 };
280
281 if mode != PostgresMode::CockroachDB {
282 client
283 .batch_execute(&format!("{}; {};", create_schema, SCHEMA))
284 .await?;
285 }
286
287 Ok(PostgresConsensus {
288 postgres_client,
289 dyncfg,
290 mode,
291 })
292 }
293
294 pub async fn drop_and_recreate(&self) -> Result<(), ExternalError> {
298 let client = self.get_connection().await?;
300 client.execute("DROP TABLE consensus", &[]).await?;
301 let crdb_mode = match client
302 .batch_execute(&format!(
303 "{}{}; {}",
304 SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE,
305 ))
306 .await
307 {
308 Ok(()) => true,
309 Err(e) if e.code() == Some(&SqlState::INSUFFICIENT_PRIVILEGE) => {
310 warn!(
311 "unable to ALTER TABLE consensus, this is expected and OK when connecting with a read-only user"
312 );
313 true
314 }
315 Err(e)
316 if e.code() == Some(&SqlState::INVALID_PARAMETER_VALUE)
317 || e.code() == Some(&SqlState::SYNTAX_ERROR) =>
318 {
319 info!(
320 "unable to initiate consensus with CRDB params, this is expected and OK when running against Postgres: {:?}",
321 e
322 );
323 false
324 }
325 Err(e) => return Err(e.into()),
326 };
327
328 if !crdb_mode {
329 client.execute(SCHEMA, &[]).await?;
330 }
331 Ok(())
332 }
333
334 async fn get_connection(&self) -> Result<Object, PoolError> {
335 self.postgres_client.get_connection().await
336 }
337}
338
339#[async_trait]
340impl Consensus for PostgresConsensus {
341 fn list_keys(&self) -> ResultStream<'_, String> {
342 let q = "SELECT DISTINCT shard FROM consensus";
343
344 Box::pin(try_stream! {
345 let client = self.get_connection().await?;
348 let statement = client.prepare_cached(q).await?;
349 let params: &[String] = &[];
350 let mut rows = Box::pin(client.query_raw(&statement, params).await?);
351 while let Some(row) = rows.next().await {
352 let shard: String = row?.try_get("shard")?;
353 yield shard;
354 }
355 })
356 }
357
358 async fn head(&self, key: &str) -> Result<Option<VersionedData>, ExternalError> {
359 let q = "SELECT sequence_number, data FROM consensus
360 WHERE shard = $1 ORDER BY sequence_number DESC LIMIT 1";
361 let row = {
362 let client = self.get_connection().await?;
363 let statement = client.prepare_cached(q).await?;
364 client.query_opt(&statement, &[&key]).await?
365 };
366 let row = match row {
367 None => return Ok(None),
368 Some(row) => row,
369 };
370
371 let seqno: SeqNo = row.try_get("sequence_number")?;
372
373 let data: Vec<u8> = row.try_get("data")?;
374 Ok(Some(VersionedData {
375 seqno,
376 data: Bytes::from(data),
377 }))
378 }
379
380 async fn compare_and_set(
381 &self,
382 key: &str,
383 expected: Option<SeqNo>,
384 new: VersionedData,
385 ) -> Result<CaSResult, ExternalError> {
386 if let Some(expected) = expected {
387 if new.seqno <= expected {
388 return Err(Error::from(
389 format!("new seqno must be strictly greater than expected. Got new: {:?} expected: {:?}",
390 new.seqno, expected)).into());
391 }
392 }
393
394 let result = if let Some(expected) = expected {
395 static CRDB_CAS_QUERY: &str = "
402 INSERT INTO consensus (shard, sequence_number, data)
403 SELECT $1, $2, $3
404 WHERE (SELECT sequence_number FROM consensus
405 WHERE shard = $1
406 ORDER BY sequence_number DESC LIMIT 1) = $4;
407 ";
408
409 static POSTGRES_CAS_QUERY: &str = "
414 WITH last_seq AS (
415 SELECT sequence_number FROM consensus
416 WHERE shard = $1
417 ORDER BY sequence_number DESC
418 LIMIT 1
419 FOR UPDATE
420 )
421 INSERT INTO consensus (shard, sequence_number, data)
422 SELECT $1, $2, $3
423 FROM last_seq
424 WHERE last_seq.sequence_number = $4;
425 ";
426
427 let q = if USE_POSTGRES_TUNED_QUERIES.get(&self.dyncfg)
428 && self.mode == PostgresMode::Postgres
429 {
430 POSTGRES_CAS_QUERY
431 } else {
432 CRDB_CAS_QUERY
433 };
434 let client = self.get_connection().await?;
435 let statement = client.prepare_cached(q).await?;
436 client
437 .execute(
438 &statement,
439 &[&key, &new.seqno, &new.data.as_ref(), &expected],
440 )
441 .await?
442 } else {
443 let q = "INSERT INTO consensus SELECT $1, $2, $3 WHERE
445 NOT EXISTS (
446 SELECT * FROM consensus WHERE shard = $1
447 )
448 ON CONFLICT DO NOTHING";
449 let client = self.get_connection().await?;
450 let statement = client.prepare_cached(q).await?;
451 client
452 .execute(&statement, &[&key, &new.seqno, &new.data.as_ref()])
453 .await?
454 };
455
456 if result == 1 {
457 Ok(CaSResult::Committed)
458 } else {
459 Ok(CaSResult::ExpectationMismatch)
460 }
461 }
462
463 async fn scan(
464 &self,
465 key: &str,
466 from: SeqNo,
467 limit: usize,
468 ) -> Result<Vec<VersionedData>, ExternalError> {
469 let q = "SELECT sequence_number, data FROM consensus
470 WHERE shard = $1 AND sequence_number >= $2
471 ORDER BY sequence_number ASC LIMIT $3";
472 let Ok(limit) = i64::try_from(limit) else {
473 return Err(ExternalError::from(anyhow!(
474 "limit must be [0, i64::MAX]. was: {:?}",
475 limit
476 )));
477 };
478 let rows = {
479 let client = self.get_connection().await?;
480 let statement = client.prepare_cached(q).await?;
481 client.query(&statement, &[&key, &from, &limit]).await?
482 };
483 let mut results = Vec::with_capacity(rows.len());
484
485 for row in rows {
486 let seqno: SeqNo = row.try_get("sequence_number")?;
487 let data: Vec<u8> = row.try_get("data")?;
488 results.push(VersionedData {
489 seqno,
490 data: Bytes::from(data),
491 });
492 }
493 Ok(results)
494 }
495
496 async fn truncate(&self, key: &str, seqno: SeqNo) -> Result<Option<usize>, ExternalError> {
497 static CRDB_TRUNCATE_QUERY: &str = "
498 DELETE FROM consensus
499 WHERE shard = $1 AND sequence_number < $2 AND
500 EXISTS (
501 SELECT * FROM consensus WHERE shard = $1 AND sequence_number >= $2
502 )
503 ";
504
505 static POSTGRES_TRUNCATE_QUERY: &str = "
517 WITH newer_exists AS (
518 SELECT * FROM consensus
519 WHERE shard = $1
520 AND sequence_number >= $2
521 ORDER BY sequence_number ASC
522 LIMIT 1
523 FOR UPDATE
524 ),
525 to_lock AS (
526 SELECT ctid FROM consensus
527 WHERE shard = $1
528 AND sequence_number < $2
529 AND EXISTS (SELECT * FROM newer_exists)
530 ORDER BY sequence_number DESC
531 FOR UPDATE
532 )
533 DELETE FROM consensus
534 USING to_lock
535 WHERE consensus.ctid = to_lock.ctid;
536 ";
537
538 let q = if USE_POSTGRES_TUNED_QUERIES.get(&self.dyncfg)
539 && self.mode == PostgresMode::Postgres
540 {
541 POSTGRES_TRUNCATE_QUERY
542 } else {
543 CRDB_TRUNCATE_QUERY
544 };
545 let result = {
546 let client = self.get_connection().await?;
547 let statement = client.prepare_cached(q).await?;
548 client.execute(&statement, &[&key, &seqno]).await?
549 };
550 if result == 0 {
551 let current = self.head(key).await?;
563 if current.map_or(true, |data| data.seqno < seqno) {
564 return Err(ExternalError::from(anyhow!(
565 "upper bound too high for truncate: {:?}",
566 seqno
567 )));
568 }
569 }
570
571 Ok(Some(usize::cast_from(result)))
572 }
573}
574
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use tracing::info;
    use uuid::Uuid;

    use crate::location::tests::consensus_impl_test;

    use super::*;

    /// Runs the shared consensus test suite against an external database, then
    /// checks `drop_and_recreate` and the pool's connection limit.
    ///
    /// Skips (with a log line) when the external-database URL env var is unset.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn postgres_consensus() -> Result<(), ExternalError> {
        let config = match PostgresConsensusConfig::new_for_test()? {
            Some(config) => config,
            None => {
                info!(
                    "{} env not set: skipping test that uses external service",
                    PostgresConsensusConfig::EXTERNAL_TESTS_POSTGRES_URL
                );
                return Ok(());
            }
        };

        consensus_impl_test(|| PostgresConsensus::open(config.clone())).await?;

        // `drop_and_recreate` must wipe previously written data.
        let consensus = PostgresConsensus::open(config.clone()).await?;
        let key = Uuid::new_v4().to_string();
        let state = VersionedData {
            seqno: SeqNo(5),
            data: Bytes::from("abc"),
        };

        assert_eq!(
            consensus.compare_and_set(&key, None, state.clone()).await,
            Ok(CaSResult::Committed),
        );
        assert_eq!(consensus.head(&key).await, Ok(Some(state.clone())));

        consensus.drop_and_recreate().await?;

        assert_eq!(consensus.head(&key).await, Ok(None));

        // The test knobs cap the pool at two connections with a one-second
        // checkout wait, so a third concurrent checkout must fail. The config
        // from above is reused: `open` builds a fresh pool from it each time.
        let consensus: PostgresConsensus = PostgresConsensus::open(config.clone()).await?;
        let _conn1 = consensus.get_connection().await?;
        let _conn2 = consensus.get_connection().await?;

        let conn3 = consensus.get_connection().await;

        assert_err!(conn3);

        Ok(())
    }
}