use std::collections::BTreeMap;
use mz_pgrepr::Type;
use mz_sql_parser::ast::{Ident, UnresolvedItemName};
use postgres_protocol::escape;
use tokio_postgres::Client;
use crate::destination::{config, FIVETRAN_SYSTEM_COLUMN_DELETE};
use crate::error::{Context, OpError, OpErrorKind};
use crate::fivetran_sdk::{
AlterTableRequest, Column, CreateTableRequest, DataType, DecimalParams, DescribeTableRequest,
Table,
};
use crate::utils;
/// Column comment used to mark a column as part of the table's primary key.
///
/// `handle_create_table` attaches it with `COMMENT ON COLUMN`, and `describe_table` reports a
/// column as `primary_key` exactly when its comment equals this string.
const PRIMARY_KEY_MAGIC_STRING: &str = "mz_is_primary_key";
pub async fn handle_describe_table(
request: DescribeTableRequest,
) -> Result<Option<Table>, OpError> {
let (dbname, client) = config::connect(request.configuration).await?;
describe_table(&client, &dbname, &request.schema_name, &request.table_name).await
}
pub async fn describe_table(
client: &Client,
database: &str,
schema: &str,
table: &str,
) -> Result<Option<Table>, OpError> {
let table_id = {
let rows = client
.query(
r#"SELECT t.id
FROM mz_tables t
JOIN mz_schemas s ON s.id = t.schema_id
JOIN mz_databases d ON d.id = s.database_id
WHERE d.name = $1 AND s.name = $2 AND t.name = $3
"#,
&[&database, &schema, &table],
)
.await
.context("fetching table ID")?;
match &*rows {
[] => return Ok(None),
[row] => row.get::<_, String>("id"),
_ => {
let err = OpErrorKind::InvariantViolated(
"describe table query returned multiple results".to_string(),
);
return Err(err.into());
}
}
};
let columns = {
let stmt = r#"SELECT
name,
type_oid,
type_mod,
COALESCE(coms.comment, '') = $1 AS primary_key
FROM mz_columns AS cols
LEFT JOIN mz_internal.mz_comments AS coms
ON cols.id = coms.id AND cols.position = coms.object_sub_id
WHERE cols.id = $2
ORDER BY cols.position ASC"#;
let rows = client
.query(stmt, &[&PRIMARY_KEY_MAGIC_STRING, &table_id])
.await
.context("fetching table columns")?;
let mut columns = vec![];
for row in rows {
let name = row.get::<_, String>("name");
let primary_key = row.get::<_, bool>("primary_key");
let ty_oid = row.get::<_, u32>("type_oid");
let ty_mod = row.get::<_, i32>("type_mod");
let ty = Type::from_oid_and_typmod(ty_oid, ty_mod).with_context(|| {
format!("looking up type with OID {ty_oid} and modifier {ty_mod}")
})?;
let (ty, decimal) = utils::to_fivetran_type(ty)?;
columns.push(Column {
name,
r#type: ty.into(),
primary_key,
decimal,
})
}
columns
};
Ok(Some(Table {
name: table.to_string(),
columns,
}))
}
/// Handles a Fivetran `CreateTable` request.
///
/// Creates the target schema if it does not exist, then creates the table. A boolean
/// `FIVETRAN_SYSTEM_COLUMN_DELETE` column is appended when the request does not already
/// contain one.
///
/// The `CREATE TABLE` statement declares no primary-key constraint. Instead, each primary-key
/// column gets its comment set to `PRIMARY_KEY_MAGIC_STRING`, which `describe_table` reads back
/// to report `primary_key`.
///
/// # Errors
///
/// Fails if the request has no table, if a schema, table, or column type cannot be
/// represented, or if any statement fails to execute.
pub async fn handle_create_table(request: CreateTableRequest) -> Result<(), OpError> {
    let table = request.table.ok_or(OpErrorKind::FieldMissing("table"))?;
    let schema = Ident::new(&request.schema_name)?;
    let qualified_table_name =
        UnresolvedItemName::qualified(&[schema.clone(), Ident::new(&table.name)?]);

    let mut total_columns = table.columns;
    let contains_delete = total_columns
        .iter()
        .any(|col| col.name == FIVETRAN_SYSTEM_COLUMN_DELETE);
    if !contains_delete {
        total_columns.push(Column {
            name: FIVETRAN_SYSTEM_COLUMN_DELETE.to_string(),
            r#type: DataType::Boolean.into(),
            primary_key: false,
            decimal: None,
        });
    }

    let mut defs = Vec::with_capacity(total_columns.len());
    let mut primary_key_columns = vec![];
    for column in total_columns {
        let name = escape::escape_identifier(&column.name);
        let mut ty = utils::to_materialize_type(column.r#type())?.to_string();
        // Decimal columns carry their precision and scale as a type modifier.
        if let Some(d) = column.decimal {
            ty += &format!("({}, {})", d.precision, d.scale);
        }
        defs.push(format!("{name} {ty}"));
        if column.primary_key {
            primary_key_columns.push(name);
        }
    }

    // Schema and table creation each run in their own transaction.
    // NOTE(review): presumably because the two DDL statements cannot share one transaction
    // here — confirm.
    let sql = format!(
        r#"BEGIN; CREATE SCHEMA IF NOT EXISTS {schema}; COMMIT;
BEGIN; CREATE TABLE {qualified_table_name} ({defs}); COMMIT;"#,
        defs = defs.join(","),
    );
    let (_dbname, client) = config::connect(request.configuration).await?;
    client
        .batch_execute(&sql)
        .await
        .context("creating table")?;

    // The primary-key markers are applied after the table already exists. If one of these
    // statements fails, the table remains with only some of its markers set.
    for column_name in primary_key_columns {
        let stmt = format!(
            "COMMENT ON COLUMN {qualified_table_name}.{column_name} IS {magic_comment}",
            magic_comment = escape::escape_literal(PRIMARY_KEY_MAGIC_STRING),
        );
        client
            .execute(&stmt, &[])
            .await
            .context("setting magic primary key comment")?;
    }
    Ok(())
}
/// Handles a Fivetran `AlterTable` request.
///
/// Altering tables is not supported. The request succeeds only when the requested definition
/// already matches the existing table according to `columns_match`. A request that carries no
/// table is a no-op.
///
/// # Errors
///
/// - `UnknownTable` if the table does not exist.
/// - `Unsupported` if the table's columns differ from the request.
/// - Any error from connecting or from reading the catalog.
pub async fn handle_alter_table(request: AlterTableRequest) -> Result<(), OpError> {
    let (dbname, client) = config::connect(request.configuration).await?;
    let Some(request_table) = request.table else {
        return Ok(());
    };

    let current_table = describe_table(&client, &dbname, &request.schema_name, &request_table.name)
        .await
        .context("alter table")?;
    let Some(current_table) = current_table else {
        return Err(OpErrorKind::UnknownTable {
            database: dbname,
            schema: request.schema_name,
            table: request_table.name,
        }
        .into());
    };

    // Previously written as a mis-encoded `¤t_table`, which is not valid Rust.
    if columns_match(&request_table, &current_table) {
        Ok(())
    } else {
        let error = format!("alter_table, request: {request_table:?}, current: {current_table:?}");
        Err(OpErrorKind::Unsupported(error).into())
    }
}
/// Reports whether two table definitions have the same set of columns.
///
/// Columns are matched by name, in any order. Two matching columns must agree on data type,
/// on the primary-key flag, and on whether decimal parameters are present. Decimal precision
/// and scale are NOT compared.
/// NOTE(review): presumably deliberate — confirm before tightening.
///
/// If a table lists the same column name twice, only the last occurrence is considered.
fn columns_match(request: &Table, current: &Table) -> bool {
    /// Projects a table onto `name -> (type, primary key?, has decimal params?)`.
    fn signature(table: &Table) -> BTreeMap<&str, (DataType, bool, bool)> {
        let mut sig = BTreeMap::new();
        for col in &table.columns {
            let decimal: Option<&DecimalParams> = col.decimal.as_ref();
            sig.insert(
                col.name.as_str(),
                (col.r#type(), col.primary_key, decimal.is_some()),
            );
        }
        sig
    }

    signature(request) == signature(current)
}