use std::collections::BTreeMap;
use std::sync::Arc;
use bytes::BytesMut;
use mz_controller_types::ClusterId;
use mz_ore::now::{to_datetime, NowFn};
use mz_ore::task::spawn;
use mz_ore::{cast::CastFrom, cast::CastInto, now::EpochMillis};
use mz_repr::adt::array::ArrayDimension;
use mz_repr::adt::timestamp::TimestampLike;
use mz_repr::{Datum, Diff, GlobalId, Row, RowPacker, Timestamp};
use mz_sql::ast::display::AstDisplay;
use mz_sql::ast::{AstInfo, Statement};
use mz_sql::plan::Params;
use mz_sql::session::metadata::SessionMetadata;
use mz_sql_parser::ast::{statement_kind_label_value, StatementKind};
use mz_storage_client::controller::IntrospectionType;
use qcell::QCell;
use rand::SeedableRng;
use rand::{distributions::Bernoulli, prelude::Distribution, thread_rng};
use sha2::{Digest, Sha256};
use tokio::time::MissedTickBehavior;
use tracing::debug;
use uuid::Uuid;
use crate::coord::{ConnMeta, Coordinator};
use crate::session::Session;
use crate::statement_logging::{
SessionHistoryEvent, StatementBeganExecutionRecord, StatementEndedExecutionReason,
StatementEndedExecutionRecord, StatementLifecycleEvent, StatementPreparedRecord,
};
use super::Message;
/// Statement-logging state attached to a prepared statement.
///
/// A prepared statement starts out as `StillToLog`, carrying everything needed to log it,
/// and is switched to `AlreadyLogged` once its prepared-statement record has been queued
/// for logging; from then on only its logging UUID is kept.
#[derive(Debug)]
pub enum PreparedStatementLoggingInfo {
/// The prepared statement has already been logged under `uuid`.
AlreadyLogged { uuid: Uuid },
/// The prepared statement has not been logged yet.
StillToLog {
/// The SQL text to log. When built via `still_to_log`, this is the redacted text for
/// `CREATE SECRET` / `ALTER SECRET` statements and the raw text otherwise.
sql: String,
/// The redacted rendering of the statement (empty when no parsed statement was given).
redacted_sql: String,
/// When the statement was prepared.
prepared_at: EpochMillis,
/// The prepared statement's name.
name: String,
/// The session that prepared the statement.
session_id: Uuid,
/// Whether the length of `sql` has already been added to the statement-logging byte
/// metrics (done once, in `Coordinator::begin_statement_execution`).
accounted: bool,
/// The statement's kind, if a parsed statement was available.
kind: Option<StatementKind>,
/// Prevents construction outside this module; use `still_to_log` instead.
_sealed: sealed::Private,
},
}
impl PreparedStatementLoggingInfo {
    /// Builds the not-yet-logged state for a freshly prepared statement.
    ///
    /// `redacted_sql` is the redacted rendering of `stmt`, or empty when `stmt` is `None`.
    /// `sql` is `raw_sql`, except for `CREATE SECRET` and `ALTER SECRET` statements, whose
    /// raw text is replaced by the redacted rendering so it is never logged verbatim.
    pub fn still_to_log<A: AstInfo>(
        raw_sql: String,
        stmt: Option<&Statement<A>>,
        prepared_at: EpochMillis,
        name: String,
        session_id: Uuid,
        accounted: bool,
    ) -> Self {
        let kind = stmt.map(StatementKind::from);
        let redacted_sql = stmt
            .map(|statement| statement.to_ast_string_redacted())
            .unwrap_or_default();
        let handles_secret = matches!(
            kind,
            Some(StatementKind::CreateSecret | StatementKind::AlterSecret)
        );
        let sql = if handles_secret {
            redacted_sql.clone()
        } else {
            raw_sql
        };
        Self::StillToLog {
            sql,
            redacted_sql,
            prepared_at,
            name,
            session_id,
            accounted,
            kind,
            _sealed: sealed::Private,
        }
    }
}
/// Identifies one sampled statement execution by its UUID.
///
/// Returned by `Coordinator::begin_statement_execution` and passed back to amend
/// (`set_statement_execution_*`, `set_transient_index_id`) or end
/// (`end_statement_execution`) that execution's record.
#[derive(Copy, Clone, Debug, Ord, Eq, PartialOrd, PartialEq)]
pub struct StatementLoggingId(Uuid);
/// The rows produced when a prepared statement is logged.
#[derive(Debug)]
pub(crate) struct PreparedStatementEvent {
/// Row appended to the `PreparedStatementHistory` introspection collection.
prepared_statement: Row,
/// Row appended to the `SqlText` introspection collection.
sql_text: Row,
}
/// Coordinator-side statement-logging state.
///
/// Holds in-flight executions, rows buffered until `Coordinator::drain_statement_log`
/// flushes them, and the bookkeeping for byte-rate throttling.
#[derive(Debug)]
pub(crate) struct StatementLogging {
/// Executions that have begun but not yet ended, keyed by execution UUID.
executions_begun: BTreeMap<Uuid, StatementBeganExecutionRecord>,
/// Session-history events not yet emitted, keyed by session UUID. A session's event is
/// buffered the first time one of its prepared statements is logged.
unlogged_sessions: BTreeMap<Uuid, SessionHistoryEvent>,
/// Fixed-seed RNG used for sampling when `statement_logging_use_reproducible_rng` is set.
reproducible_rng: rand_chacha::ChaCha8Rng,
/// Buffered statement-execution updates, with their diffs.
pending_statement_execution_events: Vec<(Row, Diff)>,
/// Buffered prepared-statement / SQL-text rows.
pending_prepared_statement_events: Vec<PreparedStatementEvent>,
/// Buffered session-history rows.
pending_session_events: Vec<Row>,
/// Buffered statement-lifecycle rows.
pending_statement_lifecycle_events: Vec<Row>,
/// Clock read by `throttling_check`.
now: NowFn,
/// Currently available throttling credit (compared against row byte lengths).
tokens: u64,
/// Time, in whole seconds, of the most recent throttling check.
last_logged_ts_seconds: u64,
/// Number of writes rejected by throttling since the last admitted write.
throttled_count: usize,
}
impl StatementLogging {
    /// Creates empty statement-logging state.
    ///
    /// The throttling bucket starts with no credit, and the last-check time is set to the
    /// current time (in whole seconds) according to `now`. The reproducible RNG uses a fixed
    /// seed, so sampling is deterministic when `statement_logging_use_reproducible_rng` is on.
    pub(crate) fn new(now: NowFn) -> Self {
        let last_logged_ts_seconds = (now)() / 1000;
        Self {
            executions_begun: BTreeMap::new(),
            unlogged_sessions: BTreeMap::new(),
            reproducible_rng: rand_chacha::ChaCha8Rng::seed_from_u64(42),
            pending_statement_execution_events: Vec::new(),
            pending_prepared_statement_events: Vec::new(),
            pending_session_events: Vec::new(),
            pending_statement_lifecycle_events: Vec::new(),
            tokens: 0,
            last_logged_ts_seconds,
            // `now` is not used after this point, so move it instead of cloning.
            now,
            throttled_count: 0,
        }
    }

    /// Token-bucket throttling check for a write costing `cost` bytes.
    ///
    /// Credit accrues at `target_data_rate` per elapsed second since the previous check and
    /// is capped at `max_data_credit` when one is given. If enough credit is available, the
    /// cost is debited and `Some(n)` is returned. Here `n` is the number of writes rejected
    /// since the previous admitted write, and that counter is reset. Otherwise the counter
    /// is incremented and `None` is returned.
    fn throttling_check(
        &mut self,
        cost: u64,
        target_data_rate: u64,
        max_data_credit: Option<u64>,
    ) -> Option<usize> {
        let ts = (self.now)() / 1000;
        // `now` is wall-clock time and need not be monotonic. If it moved backwards, accrue
        // no credit rather than underflowing. Plain subtraction would panic with overflow
        // checks on, and otherwise wrap to an enormous `elapsed` that refills the bucket.
        let elapsed = ts.saturating_sub(self.last_logged_ts_seconds);
        self.last_logged_ts_seconds = ts;
        self.tokens = self
            .tokens
            .saturating_add(target_data_rate.saturating_mul(elapsed));
        if let Some(max_data_credit) = max_data_credit {
            self.tokens = self.tokens.min(max_data_credit);
        }
        if let Some(remaining) = self.tokens.checked_sub(cost) {
            debug!("throttling check passed. tokens remaining: {remaining}; cost: {cost}");
            self.tokens = remaining;
            Some(std::mem::take(&mut self.throttled_count))
        } else {
            debug!(
                "throttling check failed. tokens available: {}; cost: {cost}",
                self.tokens
            );
            self.throttled_count += 1;
            None
        }
    }
}
impl Coordinator {
/// Spawns a background task that asks the coordinator to drain the statement log every
/// five seconds.
///
/// The task sends `Message::DrainStatementLog` on the internal command channel. Missed
/// ticks are skipped rather than delivered in a burst. If a send fails, the task exits
/// instead of ticking forever. A failed send presumably means the receiving side of the
/// channel is gone; verify against the channel type of `internal_cmd_tx`.
pub(crate) fn spawn_statement_logging_task(&self) {
    let internal_cmd_tx = self.internal_cmd_tx.clone();
    spawn(|| "statement_logging", async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
        interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
        loop {
            interval.tick().await;
            if internal_cmd_tx.send(Message::DrainStatementLog).is_err() {
                // Nobody will ever receive another drain request; stop the task.
                break;
            }
        }
    });
}
/// Flushes every buffered statement-logging row to its introspection collection.
///
/// All pending buffers in `self.statement_logging` are emptied. Rows buffered as plain
/// `Row`s are emitted with a diff of `1`. Each non-empty batch is handed to the storage
/// controller, unless the controller is read-only, in which case the drained rows are
/// dropped.
#[mz_ore::instrument(level = "debug")]
pub(crate) fn drain_statement_log(&mut self) {
    let buffers = &mut self.statement_logging;
    let session_rows = std::mem::take(&mut buffers.pending_session_events);
    let prepared_events = std::mem::take(&mut buffers.pending_prepared_statement_events);
    let statement_execution_updates =
        std::mem::take(&mut buffers.pending_statement_execution_events);
    let lifecycle_rows = std::mem::take(&mut buffers.pending_statement_lifecycle_events);

    let session_updates: Vec<_> = session_rows.into_iter().map(|row| (row, 1)).collect();

    // Each prepared-statement event feeds two collections: the prepared-statement history
    // and the SQL text.
    let mut prepared_statement_updates = Vec::with_capacity(prepared_events.len());
    let mut sql_text_updates = Vec::with_capacity(prepared_events.len());
    for PreparedStatementEvent {
        prepared_statement,
        sql_text,
    } in prepared_events
    {
        prepared_statement_updates.push((prepared_statement, 1));
        sql_text_updates.push((sql_text, 1));
    }

    let statement_lifecycle_updates: Vec<_> =
        lifecycle_rows.into_iter().map(|row| (row, 1)).collect();

    use IntrospectionType::*;
    let batches = [
        (SessionHistory, session_updates),
        (PreparedStatementHistory, prepared_statement_updates),
        (StatementExecutionHistory, statement_execution_updates),
        (StatementLifecycleHistory, statement_lifecycle_updates),
        (SqlText, sql_text_updates),
    ];
    for (introspection_type, updates) in batches {
        if updates.is_empty() || self.controller.read_only() {
            continue;
        }
        self.controller
            .storage
            .append_introspection_updates(introspection_type, updates);
    }
}
/// Decides whether a statement-logging write of `cost` bytes may proceed.
///
/// With no `statement_logging_target_data_rate` configured, throttling is disabled. The
/// write is always admitted, and the count of previously throttled writes is returned and
/// reset. Otherwise the decision is delegated to `StatementLogging::throttling_check`,
/// using the configured target rate and optional maximum data credit.
fn statement_logging_throttling_check(&mut self, cost: usize) -> Option<usize> {
    let system_config = self.catalog.system_config();
    match system_config.statement_logging_target_data_rate() {
        None => Some(std::mem::take(&mut self.statement_logging.throttled_count)),
        Some(target_data_rate) => {
            let max_data_credit = system_config.statement_logging_max_data_credit();
            self.statement_logging.throttling_check(
                cost.cast_into(),
                target_data_rate.cast_into(),
                max_data_credit.map(CastInto::cast_into),
            )
        }
    }
}
/// Logs the prepared statement behind `logging`, unless that was already done.
///
/// Returns `None` if the write was rejected by throttling. In that case `logging` is left
/// untouched (SQL text and name intact), so a later, unthrottled execution can still log
/// it correctly. Otherwise returns the prepared statement's logging UUID. If this call
/// performed the logging, it also returns the prepared record and the rows to buffer; if
/// the statement had been logged earlier, that part is `None`.
///
/// # Panics
///
/// Panics if the statement was not yet accounted for by `begin_statement_execution`.
pub(crate) fn log_prepared_statement(
    &mut self,
    session: &mut Session,
    logging: &Arc<QCell<PreparedStatementLoggingInfo>>,
) -> Option<(
    Option<(StatementPreparedRecord, PreparedStatementEvent)>,
    Uuid,
)> {
    let logging = session.qcell_rw(&*logging);
    let mut out = None;
    let uuid = match logging {
        PreparedStatementLoggingInfo::AlreadyLogged { uuid } => *uuid,
        PreparedStatementLoggingInfo::StillToLog {
            sql,
            redacted_sql,
            prepared_at,
            name,
            session_id,
            accounted,
            kind,
            _sealed: _,
        } => {
            assert!(
                *accounted,
                "accounting for logging should be done in `begin_statement_execution`"
            );
            let uuid = Uuid::new_v4();
            // Only borrow (or clone) the statement's text here. If the throttling check
            // below fails, `?` returns early and this `StillToLog` state must still hold
            // the original SQL and name for a later attempt. Moving them out with
            // `mem::take` would leave empty strings behind.
            let sql_hash: [u8; 32] = Sha256::digest(sql.as_bytes()).into();
            let record = StatementPreparedRecord {
                id: uuid,
                sql_hash,
                name: name.clone(),
                session_id: *session_id,
                prepared_at: *prepared_at,
                kind: *kind,
            };
            let mut mpsh_row = Row::default();
            let mut mpsh_packer = mpsh_row.packer();
            Self::pack_statement_prepared_update(&record, &mut mpsh_packer);
            let sql_row = Row::pack([
                Datum::TimestampTz(
                    to_datetime(*prepared_at)
                        .truncate_day()
                        .try_into()
                        .expect("must fit"),
                ),
                Datum::Bytes(sql_hash.as_slice()),
                Datum::String(sql.as_str()),
                Datum::String(redacted_sql.as_str()),
            ]);
            let cost = mpsh_packer.byte_len() + sql_row.byte_len();
            let throttled_count = self.statement_logging_throttling_check(cost)?;
            // The prepared-statement row's final column reports how many writes were
            // throttled before this one.
            mpsh_packer.push(Datum::UInt64(throttled_count.try_into().expect("must fit")));
            out = Some((
                record,
                PreparedStatementEvent {
                    prepared_statement: mpsh_row,
                    sql_text: sql_row,
                },
            ));
            // Logged: replacing the state drops the no-longer-needed SQL text and name.
            *logging = PreparedStatementLoggingInfo::AlreadyLogged { uuid };
            uuid
        }
    };
    Some((out, uuid))
}
/// Returns the probability with which executions in `session` are sampled for logging.
///
/// This is the smaller of the system-wide `statement_logging_max_sample_rate` and the
/// session's own statement-logging sample rate.
pub fn statement_execution_sample_rate(&self, session: &Session) -> f64 {
    let system_max_rate: f64 = self
        .catalog()
        .system_config()
        .statement_logging_max_sample_rate()
        .try_into()
        .expect("value constrained to be convertible to f64");
    let session_rate: f64 = session
        .vars()
        .get_statement_logging_sample_rate()
        .try_into()
        .expect("value constrained to be convertible to f64");
    system_max_rate.min(session_rate)
}
/// Marks the logged execution `id` as finished for `reason`.
///
/// Buffers a retraction of the execution's "began" row together with the completed row,
/// then records an `ExecutionFinished` lifecycle event at the same timestamp.
///
/// # Panics
///
/// Panics if `id` has no matching `begin_statement_execution`, or was already ended.
pub fn end_statement_execution(
    &mut self,
    id: StatementLoggingId,
    reason: StatementEndedExecutionReason,
) {
    let StatementLoggingId(uuid) = id;
    let ended_at = self.now();
    let ended_record = StatementEndedExecutionRecord {
        id: uuid,
        reason,
        ended_at,
    };
    let began_record = self
        .statement_logging
        .executions_begun
        .remove(&uuid)
        .expect("matched `begin_statement_execution` and `end_statement_execution` invocations");
    let updates = Self::pack_statement_ended_execution_updates(&began_record, &ended_record);
    self.statement_logging
        .pending_statement_execution_events
        .extend(updates);
    self.record_statement_lifecycle_event(
        &id,
        &StatementLifecycleEvent::ExecutionFinished,
        ended_at,
    );
}
/// Packs the columns of a statement-execution row that are known when execution begins.
///
/// The columns, in order, are:
/// - id, prepared_statement_id, sample_rate
/// - cluster_id (text or null), application_name, cluster_name (nullable), database_name
/// - search_path (list)
/// - transaction_isolation, execution_timestamp (nullable), transaction_id
/// - transient_index_id (text or null)
/// - params (1-D text array with nullable elements)
/// - mz_version, began_at
fn pack_statement_execution_inner(
    record: &StatementBeganExecutionRecord,
    packer: &mut RowPacker,
) {
    // Exhaustive destructuring: adding a field to the record forces an update here.
    let StatementBeganExecutionRecord {
        id,
        prepared_statement_id,
        sample_rate,
        params,
        began_at,
        cluster_id,
        cluster_name,
        database_name,
        search_path,
        application_name,
        transaction_isolation,
        execution_timestamp,
        transaction_id,
        transient_index_id,
        mz_version,
    } = record;
    let cluster_id_text = cluster_id.as_ref().map(|cluster| cluster.to_string());
    let transient_index_text = transient_index_id.as_ref().map(|index| index.to_string());

    packer.extend([
        Datum::Uuid(*id),
        Datum::Uuid(*prepared_statement_id),
        Datum::Float64((*sample_rate).into()),
        Datum::from(cluster_id_text.as_deref()),
        Datum::String(application_name.as_str()),
        Datum::from(cluster_name.as_deref()),
        Datum::String(database_name.as_str()),
    ]);
    packer.push_list(search_path.iter().map(|schema| Datum::String(schema.as_str())));
    packer.extend([
        Datum::String(transaction_isolation.as_str()),
        Datum::from(*execution_timestamp),
        Datum::UInt64(*transaction_id),
        Datum::from(transient_index_text.as_deref()),
    ]);

    let param_dims = [ArrayDimension {
        lower_bound: 1,
        length: params.len(),
    }];
    packer
        .push_array(
            &param_dims,
            params.iter().map(|param| Datum::from(param.as_deref())),
        )
        .expect("correct array dimensions");

    packer.push(Datum::from(mz_version.as_str()));
    packer.push(Datum::TimestampTz(
        to_datetime(*began_at).try_into().expect("Sane system time"),
    ));
}
/// Packs the statement-execution row for an execution that has begun but not ended.
///
/// The five trailing "ended" columns (ended_at, status, error_message, rows_returned,
/// execution_strategy) are null.
fn pack_statement_began_execution_update(record: &StatementBeganExecutionRecord) -> Row {
    let mut row = Row::default();
    {
        let mut packer = row.packer();
        Self::pack_statement_execution_inner(record, &mut packer);
        packer.extend(std::iter::repeat(Datum::Null).take(5));
    }
    row
}
/// Appends the prepared-statement history columns for `record` to `packer`.
///
/// The columns, in order, are: id, session_id, name, sql_hash, prepared_at and statement
/// kind label (nullable). The caller pushes the throttled-count column afterwards.
fn pack_statement_prepared_update(record: &StatementPreparedRecord, packer: &mut RowPacker) {
    let StatementPreparedRecord {
        id,
        session_id,
        name,
        sql_hash,
        prepared_at,
        kind,
    } = record;
    let kind_label = kind.map(statement_kind_label_value);
    packer.extend([
        Datum::Uuid(*id),
        Datum::Uuid(*session_id),
        Datum::String(name),
        Datum::Bytes(&sql_hash[..]),
        Datum::TimestampTz(to_datetime(*prepared_at).try_into().expect("must fit")),
        Datum::from(kind_label),
    ]);
}
fn pack_session_history_update(event: &SessionHistoryEvent) -> Row {
let SessionHistoryEvent {
id,
connected_at,
application_name,
authenticated_user,
} = event;
Row::pack_slice(&[
Datum::Uuid(*id),
Datum::TimestampTz(
mz_ore::now::to_datetime(*connected_at)
.try_into()
.expect("must fit"),
),
Datum::String(&*application_name),
Datum::String(&*authenticated_user),
])
}
fn pack_statement_lifecycle_event(
StatementLoggingId(uuid): &StatementLoggingId,
event: &StatementLifecycleEvent,
when: EpochMillis,
) -> Row {
Row::pack_slice(&[
Datum::Uuid(*uuid),
Datum::String(event.as_str()),
Datum::TimestampTz(mz_ore::now::to_datetime(when).try_into().expect("must fit")),
])
}
pub fn pack_full_statement_execution_update(
began_record: &StatementBeganExecutionRecord,
ended_record: &StatementEndedExecutionRecord,
) -> Row {
let mut row = Row::default();
let mut packer = row.packer();
Self::pack_statement_execution_inner(began_record, &mut packer);
let (status, error_message, rows_returned, execution_strategy) = match &ended_record.reason
{
StatementEndedExecutionReason::Success {
rows_returned,
execution_strategy,
} => (
"success",
None,
rows_returned.map(|rr| i64::try_from(rr).expect("must fit")),
execution_strategy.map(|es| es.name()),
),
StatementEndedExecutionReason::Canceled => ("canceled", None, None, None),
StatementEndedExecutionReason::Errored { error } => {
("error", Some(error.as_str()), None, None)
}
StatementEndedExecutionReason::Aborted => ("aborted", None, None, None),
};
packer.extend([
Datum::TimestampTz(
to_datetime(ended_record.ended_at)
.try_into()
.expect("Sane system time"),
),
status.into(),
error_message.into(),
rows_returned.into(),
execution_strategy.into(),
]);
row
}
/// Returns the updates that turn an execution's "began" row into its completed row.
///
/// The first update retracts the "began" row (diff `-1`); the second inserts the full
/// row (diff `1`).
pub fn pack_statement_ended_execution_updates(
    began_record: &StatementBeganExecutionRecord,
    ended_record: &StatementEndedExecutionRecord,
) -> [(Row, Diff); 2] {
    [
        (Self::pack_statement_began_execution_update(began_record), -1),
        (
            Self::pack_full_statement_execution_update(began_record, ended_record),
            1,
        ),
    ]
}
fn mutate_record<F: FnOnce(&mut StatementBeganExecutionRecord)>(
&mut self,
StatementLoggingId(id): StatementLoggingId,
f: F,
) {
let record = self
.statement_logging
.executions_begun
.get_mut(&id)
.expect("mutate_record must not be called after execution ends");
let retraction = Self::pack_statement_began_execution_update(record);
self.statement_logging
.pending_statement_execution_events
.push((retraction, -1));
f(record);
let update = Self::pack_statement_began_execution_update(record);
self.statement_logging
.pending_statement_execution_events
.push((update, 1));
}
/// Records that execution `id` runs on cluster `cluster_id`, together with the cluster's
/// current name from the catalog.
///
/// # Panics
///
/// Panics if execution `id` has already ended.
pub fn set_statement_execution_cluster(
    &mut self,
    id: StatementLoggingId,
    cluster_id: ClusterId,
) {
    let cluster = self.catalog().get_cluster(cluster_id);
    let cluster_name = cluster.name.clone();
    self.mutate_record(id, move |record| {
        record.cluster_name = Some(cluster_name);
        record.cluster_id = Some(cluster_id);
    });
}
/// Records the timestamp at which execution `id` reads or writes.
///
/// # Panics
///
/// Panics if execution `id` has already ended.
pub fn set_statement_execution_timestamp(
    &mut self,
    id: StatementLoggingId,
    timestamp: Timestamp,
) {
    let raw_timestamp = u64::from(timestamp);
    self.mutate_record(id, move |record| {
        record.execution_timestamp = Some(raw_timestamp);
    });
}
/// Records the transient index created for execution `id`.
///
/// # Panics
///
/// Panics if execution `id` has already ended.
pub fn set_transient_index_id(&mut self, id: StatementLoggingId, transient_index_id: GlobalId) {
    let index_id = Some(transient_index_id);
    self.mutate_record(id, move |record| record.transient_index_id = index_id);
}
/// Possibly begins logging an execution of the prepared statement described by `logging`.
///
/// Returns `None`, logging nothing further, in three cases:
/// - the session's user is internal and internal statement logging is disabled;
/// - this execution is not sampled;
/// - logging the prepared statement is throttled.
///
/// Otherwise it does the following and returns the new execution's id:
/// - records an `ExecutionBegan` lifecycle event;
/// - buffers a "began" execution row;
/// - keeps the record until `end_statement_execution`;
/// - if the prepared statement was logged by this call, buffers its rows and the
///   session's pending history row (if any).
///
/// Once past the internal-user check, the statement's SQL length is added to the byte
/// metrics exactly once per prepared statement, whether or not this execution is sampled.
pub fn begin_statement_execution(
    &mut self,
    session: &mut Session,
    params: &Params,
    logging: &Arc<QCell<PreparedStatementLoggingInfo>>,
) -> Option<StatementLoggingId> {
    let enable_internal_statement_logging = self
        .catalog()
        .system_config()
        .enable_internal_statement_logging();
    if session.user().is_internal() && !enable_internal_statement_logging {
        return None;
    }

    let sample_rate = self.statement_execution_sample_rate(session);
    let distribution = Bernoulli::new(sample_rate).expect("rate must be in range [0, 1]");
    let sample = if self
        .catalog()
        .system_config()
        .statement_logging_use_reproducible_rng()
    {
        distribution.sample(&mut self.statement_logging.reproducible_rng)
    } else {
        distribution.sample(&mut thread_rng())
    };
    let sampled_label = if sample { "true" } else { "false" };
    self.metrics
        .statement_logging_records
        .with_label_values(&[sampled_label])
        .inc_by(1);

    // Account for the statement's SQL text in the byte metrics once per prepared statement.
    if let PreparedStatementLoggingInfo::StillToLog { sql, accounted, .. } =
        session.qcell_rw(logging)
    {
        if !*accounted {
            let sql_len = u64::cast_from(sql.len());
            self.metrics
                .statement_logging_unsampled_bytes
                .with_label_values(&[])
                .inc_by(sql_len);
            if sample {
                self.metrics
                    .statement_logging_actual_bytes
                    .with_label_values(&[])
                    .inc_by(sql_len);
            }
            *accounted = true;
        }
    }

    if !sample {
        return None;
    }

    let (ps_record, ps_uuid) = self.log_prepared_statement(session, logging)?;

    let ev_id = Uuid::new_v4();
    // A single timestamp is used for both the `ExecutionBegan` lifecycle event and the
    // record's `began_at`, so the two always agree. The wall clock is not monotonic, so
    // two separate reads could even disagree in order.
    let now = self.now();
    self.record_statement_lifecycle_event(
        &StatementLoggingId(ev_id),
        &StatementLifecycleEvent::ExecutionBegan,
        now,
    );

    // Render each bound parameter in its text encoding. Entries are `None` where
    // `from_datum` yields no value (presumably SQL NULL).
    let params = std::iter::zip(params.types.iter(), params.datums.iter())
        .map(|(r#type, datum)| {
            mz_pgrepr::Value::from_datum(datum, r#type).map(|val| {
                let mut buf = BytesMut::new();
                val.encode_text(&mut buf);
                String::from_utf8(Into::<Vec<u8>>::into(buf))
                    .expect("Serialization shouldn't produce non-UTF-8 strings.")
            })
        })
        .collect();
    let record = StatementBeganExecutionRecord {
        id: ev_id,
        prepared_statement_id: ps_uuid,
        sample_rate,
        params,
        began_at: now,
        application_name: session.application_name().to_string(),
        transaction_isolation: session.vars().transaction_isolation().to_string(),
        transaction_id: session
            .transaction()
            .inner()
            .expect("Every statement runs in an explicit or implicit transaction")
            .id,
        mz_version: self
            .catalog()
            .state()
            .config()
            .build_info
            .human_version(None),
        cluster_id: None,
        cluster_name: None,
        execution_timestamp: None,
        transient_index_id: None,
        database_name: session.vars().database().into(),
        search_path: session
            .vars()
            .search_path()
            .iter()
            .map(|s| s.as_str().to_string())
            .collect(),
    };
    let mseh_update = Self::pack_statement_began_execution_update(&record);
    self.statement_logging
        .pending_statement_execution_events
        .push((mseh_update, 1));
    self.statement_logging
        .executions_begun
        .insert(ev_id, record);

    // The prepared statement was logged by this call: buffer its rows. If its session's
    // history row has not been emitted yet, buffer that too.
    if let Some((ps_record, ps_update)) = ps_record {
        self.statement_logging
            .pending_prepared_statement_events
            .push(ps_update);
        if let Some(sh) = self
            .statement_logging
            .unlogged_sessions
            .remove(&ps_record.session_id)
        {
            let sh_update = Self::pack_session_history_update(&sh);
            self.statement_logging
                .pending_session_events
                .push(sh_update);
        }
    }
    Some(StatementLoggingId(ev_id))
}
/// Registers a newly connected session for statement logging.
///
/// The session's history event is held in `unlogged_sessions`. It is only buffered for
/// output once one of the session's prepared statements is actually logged (see
/// `begin_statement_execution`).
pub fn begin_session_for_statement_logging(&mut self, session: &ConnMeta) {
    let authenticated_user = self
        .catalog
        .get_role(session.authenticated_role_id())
        .name
        .clone();
    let event = SessionHistoryEvent {
        id: session.uuid(),
        connected_at: session.connected_at(),
        application_name: session.application_name().to_owned(),
        authenticated_user,
    };
    self.statement_logging
        .unlogged_sessions
        .insert(event.id, event);
}
/// Forgets the session `uuid`.
///
/// If none of its prepared statements was ever logged, its pending history event is
/// discarded and never emitted.
pub fn end_session_for_statement_logging(&mut self, uuid: Uuid) {
    let pending_sessions = &mut self.statement_logging.unlogged_sessions;
    pending_sessions.remove(&uuid);
}
/// Buffers a lifecycle `event` for execution `id` that occurred at `when`.
///
/// This is a no-op unless the `ENABLE_STATEMENT_LIFECYCLE_LOGGING` dyncfg is enabled.
pub fn record_statement_lifecycle_event(
    &mut self,
    id: &StatementLoggingId,
    event: &StatementLifecycleEvent,
    when: EpochMillis,
) {
    let lifecycle_logging_enabled = mz_adapter_types::dyncfgs::ENABLE_STATEMENT_LIFECYCLE_LOGGING
        .get(self.catalog().system_config().dyncfgs());
    if !lifecycle_logging_enabled {
        return;
    }
    let lifecycle_row = Self::pack_statement_lifecycle_event(id, event, when);
    self.statement_logging
        .pending_statement_lifecycle_events
        .push(lifecycle_row);
}
}
/// Sealing marker for `PreparedStatementLoggingInfo::StillToLog`.
mod sealed {
/// `sealed` is a private module, so code outside the enclosing module cannot name
/// `Private`. It therefore cannot construct `StillToLog` directly and must go through
/// `PreparedStatementLoggingInfo::still_to_log`.
#[derive(Debug, Copy, Clone)]
pub struct Private;
}