1use std::fmt::Formatter;
13use std::str::FromStr;
14use std::sync::Arc;
15use std::time::Duration;
16
17use anyhow::anyhow;
18use async_stream::try_stream;
19use async_trait::async_trait;
20use bytes::Bytes;
21use deadpool_postgres::tokio_postgres::Config;
22use deadpool_postgres::tokio_postgres::types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
23use deadpool_postgres::{Object, PoolError};
24use futures_util::StreamExt;
25use mz_ore::cast::CastFrom;
26use mz_ore::metrics::MetricsRegistry;
27use mz_ore::url::SensitiveUrl;
28use mz_postgres_client::metrics::PostgresClientMetrics;
29use mz_postgres_client::{PostgresClient, PostgresClientConfig, PostgresClientKnobs};
30use postgres_protocol::escape::escape_identifier;
31use tokio_postgres::error::SqlState;
32use tracing::{info, warn};
33
34use crate::error::Error;
35use crate::location::{CaSResult, Consensus, ExternalError, ResultStream, SeqNo, VersionedData};
36
/// DDL for the single table holding every shard's consensus history: one row
/// per `(shard, sequence_number)` pair, carrying that version's opaque `data`
/// bytes. `IF NOT EXISTS` makes it safe to run on every open.
const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS consensus (
    shard text NOT NULL,
    sequence_number bigint NOT NULL,
    data bytea NOT NULL,
    PRIMARY KEY(shard, sequence_number)
)
";

/// CockroachDB-only table options, appended directly after [`SCHEMA`] when
/// creating the table. Disables CockroachDB's automatic SQL statistics
/// collection for the `consensus` table.
///
/// Postgres rejects this clause. The table-creation code treats a syntax or
/// invalid-parameter error from the combined DDL as "running on Postgres" and
/// retries without the CockroachDB-specific parts.
const CRDB_SCHEMA_OPTIONS: &str = "WITH (sql_stats_automatic_collection_enabled = false)";
/// CockroachDB-only zone configuration for the `consensus` table: sets its
/// garbage-collection TTL (`gc.ttlseconds`) to 600 seconds (10 minutes).
/// NOTE(review): the rationale for this specific value is not visible here.
const CRDB_CONFIGURE_ZONE: &str = "ALTER TABLE consensus CONFIGURE ZONE USING gc.ttlseconds = 600";
61
62impl ToSql for SeqNo {
63 fn to_sql(
64 &self,
65 ty: &Type,
66 w: &mut bytes::BytesMut,
67 ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
68 let value = i64::try_from(self.0)?;
70 <i64 as ToSql>::to_sql(&value, ty, w)
71 }
72
73 fn accepts(ty: &Type) -> bool {
74 <i64 as ToSql>::accepts(ty)
75 }
76
77 to_sql_checked!();
78}
79
80impl<'a> FromSql<'a> for SeqNo {
81 fn from_sql(
82 ty: &Type,
83 raw: &'a [u8],
84 ) -> Result<SeqNo, Box<dyn std::error::Error + Sync + Send>> {
85 let sequence_number = <i64 as FromSql>::from_sql(ty, raw)?;
86
87 let sequence_number = u64::try_from(sequence_number)?;
90 Ok(SeqNo(sequence_number))
91 }
92
93 fn accepts(ty: &Type) -> bool {
94 <i64 as FromSql>::accepts(ty)
95 }
96}
97
98#[derive(Clone, Debug)]
100pub struct PostgresConsensusConfig {
101 url: SensitiveUrl,
102 knobs: Arc<dyn PostgresClientKnobs>,
103 metrics: PostgresClientMetrics,
104}
105
106impl From<PostgresConsensusConfig> for PostgresClientConfig {
107 fn from(config: PostgresConsensusConfig) -> Self {
108 PostgresClientConfig::new(config.url, config.knobs, config.metrics)
109 }
110}
111
112impl PostgresConsensusConfig {
113 const EXTERNAL_TESTS_POSTGRES_URL: &'static str =
114 "MZ_PERSIST_EXTERNAL_STORAGE_TEST_POSTGRES_URL";
115
116 pub fn new(
118 url: &SensitiveUrl,
119 knobs: Box<dyn PostgresClientKnobs>,
120 metrics: PostgresClientMetrics,
121 ) -> Result<Self, Error> {
122 Ok(PostgresConsensusConfig {
123 url: url.clone(),
124 knobs: Arc::from(knobs),
125 metrics,
126 })
127 }
128
129 pub fn new_for_test() -> Result<Option<Self>, Error> {
139 let url = match std::env::var(Self::EXTERNAL_TESTS_POSTGRES_URL) {
140 Ok(url) => SensitiveUrl::from_str(&url).map_err(|e| e.to_string())?,
141 Err(_) => {
142 if mz_ore::env::is_var_truthy("CI") {
143 panic!("CI is supposed to run this test but something has gone wrong!");
144 }
145 return Ok(None);
146 }
147 };
148
149 struct TestConsensusKnobs;
150 impl std::fmt::Debug for TestConsensusKnobs {
151 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152 f.debug_struct("TestConsensusKnobs").finish_non_exhaustive()
153 }
154 }
155 impl PostgresClientKnobs for TestConsensusKnobs {
156 fn connection_pool_max_size(&self) -> usize {
157 2
158 }
159
160 fn connection_pool_max_wait(&self) -> Option<Duration> {
161 Some(Duration::from_secs(1))
162 }
163
164 fn connection_pool_ttl(&self) -> Duration {
165 Duration::MAX
166 }
167 fn connection_pool_ttl_stagger(&self) -> Duration {
168 Duration::MAX
169 }
170 fn connect_timeout(&self) -> Duration {
171 Duration::MAX
172 }
173 fn tcp_user_timeout(&self) -> Duration {
174 Duration::ZERO
175 }
176 }
177
178 let config = PostgresConsensusConfig::new(
179 &url,
180 Box::new(TestConsensusKnobs),
181 PostgresClientMetrics::new(&MetricsRegistry::new(), "mz_persist"),
182 )?;
183 Ok(Some(config))
184 }
185}
186
187pub struct PostgresConsensus {
189 postgres_client: PostgresClient,
190}
191
192impl std::fmt::Debug for PostgresConsensus {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("PostgresConsensus").finish_non_exhaustive()
195 }
196}
197
198impl PostgresConsensus {
199 pub async fn open(config: PostgresConsensusConfig) -> Result<Self, ExternalError> {
202 let pg_config: Config = config.url.to_string().parse()?;
204 let role = pg_config.get_user().unwrap();
205 let create_schema = format!(
206 "CREATE SCHEMA IF NOT EXISTS consensus AUTHORIZATION {}",
207 escape_identifier(role),
208 );
209
210 let postgres_client = PostgresClient::open(config.into())?;
211
212 let client = postgres_client.get_connection().await?;
213
214 let crdb_mode = match client
215 .batch_execute(&format!(
216 "{}; {}{}; {};",
217 create_schema, SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE,
218 ))
219 .await
220 {
221 Ok(()) => true,
222 Err(e) if e.code() == Some(&SqlState::INSUFFICIENT_PRIVILEGE) => {
223 warn!(
224 "unable to ALTER TABLE consensus, this is expected and OK when connecting with a read-only user"
225 );
226 true
227 }
228 Err(e)
229 if e.code() == Some(&SqlState::INVALID_PARAMETER_VALUE)
230 || e.code() == Some(&SqlState::SYNTAX_ERROR) =>
231 {
232 info!(
233 "unable to initiate consensus with CRDB params, this is expected and OK when running against Postgres: {:?}",
234 e
235 );
236 false
237 }
238 Err(e) => return Err(e.into()),
239 };
240
241 if !crdb_mode {
242 client
243 .batch_execute(&format!("{}; {};", create_schema, SCHEMA))
244 .await?;
245 }
246
247 Ok(PostgresConsensus { postgres_client })
248 }
249
250 pub async fn drop_and_recreate(&self) -> Result<(), ExternalError> {
254 let client = self.get_connection().await?;
256 client.execute("DROP TABLE consensus", &[]).await?;
257 let crdb_mode = match client
258 .batch_execute(&format!(
259 "{}{}; {}",
260 SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE,
261 ))
262 .await
263 {
264 Ok(()) => true,
265 Err(e) if e.code() == Some(&SqlState::INSUFFICIENT_PRIVILEGE) => {
266 warn!(
267 "unable to ALTER TABLE consensus, this is expected and OK when connecting with a read-only user"
268 );
269 true
270 }
271 Err(e)
272 if e.code() == Some(&SqlState::INVALID_PARAMETER_VALUE)
273 || e.code() == Some(&SqlState::SYNTAX_ERROR) =>
274 {
275 info!(
276 "unable to initiate consensus with CRDB params, this is expected and OK when running against Postgres: {:?}",
277 e
278 );
279 false
280 }
281 Err(e) => return Err(e.into()),
282 };
283
284 if !crdb_mode {
285 client.execute(SCHEMA, &[]).await?;
286 }
287 Ok(())
288 }
289
290 async fn get_connection(&self) -> Result<Object, PoolError> {
291 self.postgres_client.get_connection().await
292 }
293}
294
295#[async_trait]
296impl Consensus for PostgresConsensus {
297 fn list_keys(&self) -> ResultStream<String> {
298 let q = "SELECT DISTINCT shard FROM consensus";
299
300 Box::pin(try_stream! {
301 let client = self.get_connection().await?;
304 let statement = client.prepare_cached(q).await?;
305 let params: &[String] = &[];
306 let mut rows = Box::pin(client.query_raw(&statement, params).await?);
307 while let Some(row) = rows.next().await {
308 let shard: String = row?.try_get("shard")?;
309 yield shard;
310 }
311 })
312 }
313
314 async fn head(&self, key: &str) -> Result<Option<VersionedData>, ExternalError> {
315 let q = "SELECT sequence_number, data FROM consensus
316 WHERE shard = $1 ORDER BY sequence_number DESC LIMIT 1";
317 let row = {
318 let client = self.get_connection().await?;
319 let statement = client.prepare_cached(q).await?;
320 client.query_opt(&statement, &[&key]).await?
321 };
322 let row = match row {
323 None => return Ok(None),
324 Some(row) => row,
325 };
326
327 let seqno: SeqNo = row.try_get("sequence_number")?;
328
329 let data: Vec<u8> = row.try_get("data")?;
330 Ok(Some(VersionedData {
331 seqno,
332 data: Bytes::from(data),
333 }))
334 }
335
336 async fn compare_and_set(
337 &self,
338 key: &str,
339 expected: Option<SeqNo>,
340 new: VersionedData,
341 ) -> Result<CaSResult, ExternalError> {
342 if let Some(expected) = expected {
343 if new.seqno <= expected {
344 return Err(Error::from(
345 format!("new seqno must be strictly greater than expected. Got new: {:?} expected: {:?}",
346 new.seqno, expected)).into());
347 }
348 }
349
350 let result = if let Some(expected) = expected {
351 let q = r#"
358 INSERT INTO consensus (shard, sequence_number, data)
359 SELECT $1, $2, $3
360 WHERE (SELECT sequence_number FROM consensus
361 WHERE shard = $1
362 ORDER BY sequence_number DESC LIMIT 1) = $4;
363 "#;
364 let client = self.get_connection().await?;
365 let statement = client.prepare_cached(q).await?;
366 client
367 .execute(
368 &statement,
369 &[&key, &new.seqno, &new.data.as_ref(), &expected],
370 )
371 .await?
372 } else {
373 let q = "INSERT INTO consensus SELECT $1, $2, $3 WHERE
375 NOT EXISTS (
376 SELECT * FROM consensus WHERE shard = $1
377 )
378 ON CONFLICT DO NOTHING";
379 let client = self.get_connection().await?;
380 let statement = client.prepare_cached(q).await?;
381 client
382 .execute(&statement, &[&key, &new.seqno, &new.data.as_ref()])
383 .await?
384 };
385
386 if result == 1 {
387 Ok(CaSResult::Committed)
388 } else {
389 Ok(CaSResult::ExpectationMismatch)
390 }
391 }
392
393 async fn scan(
394 &self,
395 key: &str,
396 from: SeqNo,
397 limit: usize,
398 ) -> Result<Vec<VersionedData>, ExternalError> {
399 let q = "SELECT sequence_number, data FROM consensus
400 WHERE shard = $1 AND sequence_number >= $2
401 ORDER BY sequence_number ASC LIMIT $3";
402 let Ok(limit) = i64::try_from(limit) else {
403 return Err(ExternalError::from(anyhow!(
404 "limit must be [0, i64::MAX]. was: {:?}",
405 limit
406 )));
407 };
408 let rows = {
409 let client = self.get_connection().await?;
410 let statement = client.prepare_cached(q).await?;
411 client.query(&statement, &[&key, &from, &limit]).await?
412 };
413 let mut results = Vec::with_capacity(rows.len());
414
415 for row in rows {
416 let seqno: SeqNo = row.try_get("sequence_number")?;
417 let data: Vec<u8> = row.try_get("data")?;
418 results.push(VersionedData {
419 seqno,
420 data: Bytes::from(data),
421 });
422 }
423 Ok(results)
424 }
425
426 async fn truncate(&self, key: &str, seqno: SeqNo) -> Result<usize, ExternalError> {
427 let q = "DELETE FROM consensus
428 WHERE shard = $1 AND sequence_number < $2 AND
429 EXISTS(
430 SELECT * FROM consensus WHERE shard = $1 AND sequence_number >= $2
431 )";
432
433 let result = {
434 let client = self.get_connection().await?;
435 let statement = client.prepare_cached(q).await?;
436 client.execute(&statement, &[&key, &seqno]).await?
437 };
438 if result == 0 {
439 let current = self.head(key).await?;
451 if current.map_or(true, |data| data.seqno < seqno) {
452 return Err(ExternalError::from(anyhow!(
453 "upper bound too high for truncate: {:?}",
454 seqno
455 )));
456 }
457 }
458
459 Ok(usize::cast_from(result))
460 }
461}
462
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use tracing::info;
    use uuid::Uuid;

    use crate::location::tests::consensus_impl_test;

    use super::*;

    /// End-to-end test against an external Postgres-compatible database.
    ///
    /// The test runs the shared `consensus_impl_test` suite, then checks two
    /// Postgres-specific behaviors: `drop_and_recreate` and connection-pool
    /// exhaustion. It returns `Ok` early (skips) when `new_for_test` finds no
    /// test database URL.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // needs a live database connection
    async fn postgres_consensus() -> Result<(), ExternalError> {
        // Look the config up once. A second lookup would re-read the same env
        // var, and its skip branch could never meaningfully fire.
        let Some(config) = PostgresConsensusConfig::new_for_test()? else {
            info!(
                "{} env not set: skipping test that uses external service",
                PostgresConsensusConfig::EXTERNAL_TESTS_POSTGRES_URL
            );
            return Ok(());
        };

        consensus_impl_test(|| PostgresConsensus::open(config.clone())).await?;

        // `drop_and_recreate` must discard previously committed state.
        let consensus = PostgresConsensus::open(config.clone()).await?;
        let key = Uuid::new_v4().to_string();
        let state = VersionedData {
            seqno: SeqNo(5),
            data: Bytes::from("abc"),
        };

        assert_eq!(
            consensus.compare_and_set(&key, None, state.clone()).await,
            Ok(CaSResult::Committed),
        );
        assert_eq!(consensus.head(&key).await, Ok(Some(state)));

        consensus.drop_and_recreate().await?;
        assert_eq!(consensus.head(&key).await, Ok(None));

        // A freshly opened consensus gets its own pool. The test knobs cap it
        // at 2 connections with a 1s checkout wait, so a third checkout must
        // fail while the first two are held.
        let consensus = PostgresConsensus::open(config).await?;
        let _conn1 = consensus.get_connection().await?;
        let _conn2 = consensus.get_connection().await?;
        let conn3 = consensus.get_connection().await;
        assert_err!(conn3);

        Ok(())
    }
}