use std::fmt::Formatter;
use std::sync::Arc;
use std::time::Duration;
use anyhow::anyhow;
use async_stream::try_stream;
use async_trait::async_trait;
use bytes::Bytes;
use deadpool_postgres::tokio_postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use deadpool_postgres::{Object, PoolError};
use futures_util::StreamExt;
use mz_ore::cast::CastFrom;
use mz_ore::metrics::MetricsRegistry;
use mz_postgres_client::metrics::PostgresClientMetrics;
use mz_postgres_client::{PostgresClient, PostgresClientConfig, PostgresClientKnobs};
use tokio_postgres::error::SqlState;
use tracing::warn;
use crate::error::Error;
use crate::location::{CaSResult, Consensus, ExternalError, ResultStream, SeqNo, VersionedData};
/// DDL for the table backing [`PostgresConsensus`].
///
/// Each row stores the bytes (`data`) of one version of one shard. The
/// primary key `(shard, sequence_number)` allows at most one row per
/// sequence number per shard. `CREATE TABLE IF NOT EXISTS` makes it safe to
/// run this statement repeatedly.
///
/// NOTE(review): the `sql_stats_automatic_collection_enabled` storage
/// parameter looks CockroachDB-specific. Confirm whether vanilla Postgres
/// accepts this statement.
const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS consensus (
shard text NOT NULL,
sequence_number bigint NOT NULL,
data bytea NOT NULL,
PRIMARY KEY(shard, sequence_number)
) WITH (sql_stats_automatic_collection_enabled = false);
";
impl ToSql for SeqNo {
    /// Encodes the sequence number with `i64`'s wire format, because the
    /// `sequence_number` column is a signed `bigint`.
    ///
    /// Values above `i64::MAX` produce an error rather than wrapping.
    fn to_sql(
        &self,
        ty: &Type,
        w: &mut bytes::BytesMut,
    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
        let as_signed: i64 = self.0.try_into()?;
        as_signed.to_sql(ty, w)
    }

    /// Accepts exactly the Postgres types that `i64` accepts.
    fn accepts(ty: &Type) -> bool {
        <i64 as ToSql>::accepts(ty)
    }

    to_sql_checked!();
}
impl<'a> FromSql<'a> for SeqNo {
    /// Decodes a `bigint` into a [`SeqNo`].
    ///
    /// Negative values do not fit in the `u64` that `SeqNo` wraps, so they
    /// produce an error.
    fn from_sql(
        ty: &Type,
        raw: &'a [u8],
    ) -> Result<SeqNo, Box<dyn std::error::Error + Sync + Send>> {
        let signed: i64 = <i64 as FromSql>::from_sql(ty, raw)?;
        Ok(SeqNo(u64::try_from(signed)?))
    }

    /// Accepts exactly the Postgres types that `i64` accepts.
    fn accepts(ty: &Type) -> bool {
        <i64 as FromSql>::accepts(ty)
    }
}
/// Configuration for opening a [`PostgresConsensus`].
///
/// Converted into a [`PostgresClientConfig`] when the connection pool is
/// opened.
#[derive(Clone, Debug)]
pub struct PostgresConsensusConfig {
    /// URL of the database to connect to.
    url: String,
    /// Connection-pool tuning knobs. Held in an `Arc` (rather than the `Box`
    /// accepted by `new`) so the config can derive `Clone`.
    knobs: Arc<dyn PostgresClientKnobs>,
    /// Metrics handles passed through to the Postgres client.
    metrics: PostgresClientMetrics,
}
impl From<PostgresConsensusConfig> for PostgresClientConfig {
fn from(config: PostgresConsensusConfig) -> Self {
PostgresClientConfig::new(config.url, config.knobs, config.metrics)
}
}
impl PostgresConsensusConfig {
    /// Environment variable that names the database used by external-storage tests.
    const EXTERNAL_TESTS_POSTGRES_URL: &'static str =
        "MZ_PERSIST_EXTERNAL_STORAGE_TEST_POSTGRES_URL";

    /// Builds a config from a connection URL, pool knobs, and client metrics.
    ///
    /// Currently never returns an error.
    pub fn new(
        url: &str,
        knobs: Box<dyn PostgresClientKnobs>,
        metrics: PostgresClientMetrics,
    ) -> Result<Self, Error> {
        let shared_knobs: Arc<dyn PostgresClientKnobs> = knobs.into();
        Ok(Self {
            url: url.to_owned(),
            knobs: shared_knobs,
            metrics,
        })
    }

    /// Returns a config pointing at the database named by
    /// `MZ_PERSIST_EXTERNAL_STORAGE_TEST_POSTGRES_URL`.
    ///
    /// Returns `Ok(None)` when that variable is unset (or not valid Unicode),
    /// so the caller can skip the test.
    ///
    /// # Panics
    ///
    /// Panics if the variable is unavailable while the `CI` variable is truthy.
    pub fn new_for_test() -> Result<Option<Self>, Error> {
        let Ok(url) = std::env::var(Self::EXTERNAL_TESTS_POSTGRES_URL) else {
            if mz_ore::env::is_var_truthy("CI") {
                panic!("CI is supposed to run this test but something has gone wrong!");
            }
            return Ok(None);
        };

        /// Test-only pool settings. At most two connections, and a
        /// checkout waits up to one second before failing.
        struct TestConsensusKnobs;

        impl std::fmt::Debug for TestConsensusKnobs {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("TestConsensusKnobs").finish_non_exhaustive()
            }
        }

        impl PostgresClientKnobs for TestConsensusKnobs {
            fn connection_pool_max_size(&self) -> usize {
                2
            }

            fn connection_pool_max_wait(&self) -> Option<Duration> {
                Some(Duration::from_secs(1))
            }

            fn connection_pool_ttl(&self) -> Duration {
                Duration::MAX
            }

            fn connection_pool_ttl_stagger(&self) -> Duration {
                Duration::MAX
            }

            fn connect_timeout(&self) -> Duration {
                Duration::MAX
            }

            fn tcp_user_timeout(&self) -> Duration {
                Duration::ZERO
            }
        }

        let metrics = PostgresClientMetrics::new(&MetricsRegistry::new(), "mz_persist");
        Self::new(&url, Box::new(TestConsensusKnobs), metrics).map(Some)
    }
}
/// [`Consensus`] implementation backed by the `consensus` table (see
/// `SCHEMA`) in a Postgres-compatible database.
pub struct PostgresConsensus {
    /// Pooled client through which every query is issued.
    postgres_client: PostgresClient,
}
impl std::fmt::Debug for PostgresConsensus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PostgresConsensus").finish_non_exhaustive()
}
}
impl PostgresConsensus {
    /// Opens a connection pool for `config` and ensures the `consensus`
    /// table exists with its zone configuration applied.
    ///
    /// If the setup statements fail with `INSUFFICIENT_PRIVILEGE`, a warning
    /// is logged and opening still succeeds, so read-only users can connect.
    ///
    /// # Errors
    ///
    /// Returns an error if the pool cannot be created, no connection can be
    /// acquired, or the setup statements fail for any other reason.
    pub async fn open(config: PostgresConsensusConfig) -> Result<Self, ExternalError> {
        let postgres_client = PostgresClient::open(config.into())?;
        let client = postgres_client.get_connection().await?;
        Self::ensure_schema(&client).await?;
        Ok(PostgresConsensus { postgres_client })
    }

    /// Drops the `consensus` table and recreates it, deleting the data of
    /// every shard. Used by the tests in this file.
    ///
    /// The table is recreated through the same setup path as [`Self::open`],
    /// so its zone configuration is reapplied. Previously only `SCHEMA` was
    /// run here, which silently dropped that configuration.
    pub async fn drop_and_recreate(&self) -> Result<(), ExternalError> {
        let client = self.get_connection().await?;
        client.execute("DROP TABLE consensus", &[]).await?;
        Self::ensure_schema(&client).await
    }

    /// Runs `SCHEMA` followed by the table's zone-config `ALTER` in a single batch.
    ///
    /// An `INSUFFICIENT_PRIVILEGE` failure is downgraded to a warning. Any
    /// other failure is returned.
    async fn ensure_schema(client: &Object) -> Result<(), ExternalError> {
        let setup = format!(
            "{} {}",
            SCHEMA, "ALTER TABLE consensus CONFIGURE ZONE USING gc.ttlseconds = 600;",
        );
        match client.batch_execute(&setup).await {
            Ok(()) => Ok(()),
            Err(e) if e.code() == Some(&SqlState::INSUFFICIENT_PRIVILEGE) => {
                warn!("unable to ALTER TABLE consensus, this is expected and OK when connecting with a read-only user");
                Ok(())
            }
            Err(e) => Err(e.into()),
        }
    }

    /// Checks a connection out of the underlying pool.
    async fn get_connection(&self) -> Result<Object, PoolError> {
        self.postgres_client.get_connection().await
    }
}
#[async_trait]
impl Consensus for PostgresConsensus {
fn list_keys(&self) -> ResultStream<String> {
let q = "SELECT DISTINCT shard FROM consensus";
Box::pin(try_stream! {
let client = self.get_connection().await?;
let statement = client.prepare_cached(q).await?;
let params: &[String] = &[];
let mut rows = Box::pin(client.query_raw(&statement, params).await?);
while let Some(row) = rows.next().await {
let shard: String = row?.try_get("shard")?;
yield shard;
}
})
}
async fn head(&self, key: &str) -> Result<Option<VersionedData>, ExternalError> {
let q = "SELECT sequence_number, data FROM consensus
WHERE shard = $1 ORDER BY sequence_number DESC LIMIT 1";
let row = {
let client = self.get_connection().await?;
let statement = client.prepare_cached(q).await?;
client.query_opt(&statement, &[&key]).await?
};
let row = match row {
None => return Ok(None),
Some(row) => row,
};
let seqno: SeqNo = row.try_get("sequence_number")?;
let data: Vec<u8> = row.try_get("data")?;
Ok(Some(VersionedData {
seqno,
data: Bytes::from(data),
}))
}
async fn compare_and_set(
&self,
key: &str,
expected: Option<SeqNo>,
new: VersionedData,
) -> Result<CaSResult, ExternalError> {
if let Some(expected) = expected {
if new.seqno <= expected {
return Err(Error::from(
format!("new seqno must be strictly greater than expected. Got new: {:?} expected: {:?}",
new.seqno, expected)).into());
}
}
let result = if let Some(expected) = expected {
let q = r#"
INSERT INTO consensus (shard, sequence_number, data)
SELECT $1, $2, $3
WHERE (SELECT sequence_number FROM consensus
WHERE shard = $1
ORDER BY sequence_number DESC LIMIT 1) = $4;
"#;
let client = self.get_connection().await?;
let statement = client.prepare_cached(q).await?;
client
.execute(
&statement,
&[&key, &new.seqno, &new.data.as_ref(), &expected],
)
.await?
} else {
let q = "INSERT INTO consensus SELECT $1, $2, $3 WHERE
NOT EXISTS (
SELECT * FROM consensus WHERE shard = $1
)
ON CONFLICT DO NOTHING";
let client = self.get_connection().await?;
let statement = client.prepare_cached(q).await?;
client
.execute(&statement, &[&key, &new.seqno, &new.data.as_ref()])
.await?
};
if result == 1 {
Ok(CaSResult::Committed)
} else {
Ok(CaSResult::ExpectationMismatch)
}
}
async fn scan(
&self,
key: &str,
from: SeqNo,
limit: usize,
) -> Result<Vec<VersionedData>, ExternalError> {
let q = "SELECT sequence_number, data FROM consensus
WHERE shard = $1 AND sequence_number >= $2
ORDER BY sequence_number ASC LIMIT $3";
let Ok(limit) = i64::try_from(limit) else {
return Err(ExternalError::from(anyhow!(
"limit must be [0, i64::MAX]. was: {:?}",
limit
)));
};
let rows = {
let client = self.get_connection().await?;
let statement = client.prepare_cached(q).await?;
client.query(&statement, &[&key, &from, &limit]).await?
};
let mut results = Vec::with_capacity(rows.len());
for row in rows {
let seqno: SeqNo = row.try_get("sequence_number")?;
let data: Vec<u8> = row.try_get("data")?;
results.push(VersionedData {
seqno,
data: Bytes::from(data),
});
}
Ok(results)
}
async fn truncate(&self, key: &str, seqno: SeqNo) -> Result<usize, ExternalError> {
let q = "DELETE FROM consensus
WHERE shard = $1 AND sequence_number < $2 AND
EXISTS(
SELECT * FROM consensus WHERE shard = $1 AND sequence_number >= $2
)";
let result = {
let client = self.get_connection().await?;
let statement = client.prepare_cached(q).await?;
client.execute(&statement, &[&key, &seqno]).await?
};
if result == 0 {
let current = self.head(key).await?;
if current.map_or(true, |data| data.seqno < seqno) {
return Err(ExternalError::from(anyhow!(
"upper bound too high for truncate: {:?}",
seqno
)));
}
}
Ok(usize::cast_from(result))
}
}
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use tracing::info;
    use uuid::Uuid;

    use crate::location::tests::consensus_impl_test;

    use super::*;

    /// Loads the external-database test config.
    ///
    /// Logs and returns `None` when the test URL variable is not set, so the
    /// caller can skip.
    fn config_or_skip() -> Result<Option<PostgresConsensusConfig>, ExternalError> {
        let maybe_config = PostgresConsensusConfig::new_for_test()?;
        if maybe_config.is_none() {
            info!(
                "{} env not set: skipping test that uses external service",
                PostgresConsensusConfig::EXTERNAL_TESTS_POSTGRES_URL
            );
        }
        Ok(maybe_config)
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn postgres_consensus() -> Result<(), ExternalError> {
        let Some(config) = config_or_skip()? else {
            return Ok(());
        };

        consensus_impl_test(|| PostgresConsensus::open(config.clone())).await?;

        // Data written before `drop_and_recreate` must be gone afterwards.
        let consensus = PostgresConsensus::open(config.clone()).await?;
        let shard_key = Uuid::new_v4().to_string();
        let written = VersionedData {
            seqno: SeqNo(5),
            data: Bytes::from("abc"),
        };

        assert_eq!(
            consensus.compare_and_set(&shard_key, None, written.clone()).await,
            Ok(CaSResult::Committed),
        );
        assert_eq!(consensus.head(&shard_key).await, Ok(Some(written.clone())));

        consensus.drop_and_recreate().await?;
        assert_eq!(consensus.head(&shard_key).await, Ok(None));

        Ok(())
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn postgres_consensus_blocking() -> Result<(), ExternalError> {
        let Some(config) = config_or_skip()? else {
            return Ok(());
        };

        // The test knobs cap the pool at two connections with a one-second
        // wait, so a third concurrent checkout is expected to fail.
        let consensus: PostgresConsensus = PostgresConsensus::open(config).await?;
        let _first = consensus.get_connection().await?;
        let _second = consensus.get_connection().await?;
        let third = consensus.get_connection().await;
        assert_err!(third);

        Ok(())
    }
}