use std::fmt;
use chrono::{DateTime, Utc};
use mz_lowertest::MzReflect;
use mz_ore::cast::TryCastFrom;
use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
use mz_repr::adt::timestamp::CheckedTimestamp;
use mz_repr::{strconv, ColumnType, ScalarType};
use serde::{Deserialize, Serialize};
use crate::scalar::func::EagerUnaryFunc;
use crate::scalar::DomainLimit;
use crate::EvalError;
// Arithmetic negation of a double precision value.
sqlfunc!(
    #[sqlname = "-"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(NegFloat64)]
    #[is_monotone = true]
    fn neg_float64(a: f64) -> f64 {
        std::ops::Neg::neg(a)
    }
);
// Absolute value of a double precision value.
sqlfunc!(
    #[sqlname = "abs"]
    fn abs_float64(a: f64) -> f64 {
        f64::abs(a)
    }
);
// Rounds to the nearest integral value; exact halves go to the even neighbor.
sqlfunc!(
    #[sqlname = "roundf64"]
    fn round_float64(a: f64) -> f64 {
        f64::round_ties_even(a)
    }
);
// Drops the fractional part, rounding toward zero.
sqlfunc!(
    #[sqlname = "truncf64"]
    fn trunc_float64(a: f64) -> f64 {
        f64::trunc(a)
    }
);
// Smallest integral value that is not less than the input.
sqlfunc!(
    #[sqlname = "ceilf64"]
    fn ceil_float64(a: f64) -> f64 {
        f64::ceil(a)
    }
);
// Largest integral value that is not greater than the input.
sqlfunc!(
    #[sqlname = "floorf64"]
    fn floor_float64(a: f64) -> f64 {
        f64::floor(a)
    }
);
// Casts a double to `i16`: the value is rounded with `round_float64` first and
// anything that then falls outside the `i16` range (including NaN) is rejected
// with `Int16OutOfRange`.
sqlfunc!(
    #[sqlname = "double_to_smallint"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastInt16ToFloat64)]
    #[is_monotone = true]
    fn cast_float64_to_int16(a: f64) -> Result<i16, EvalError> {
        let rounded = round_float64(a);
        // The bounds are expressed through MIN, a power of two that is exact
        // in f64; `-min` is therefore MAX + 1 and acts as an exclusive bound.
        #[allow(clippy::as_conversions)]
        let min = i16::MIN as f64;
        // NaN is contained in no range, so it takes the error path as well.
        if !(min..-min).contains(&rounded) {
            return Err(EvalError::Int16OutOfRange(rounded.to_string().into()));
        }
        #[allow(clippy::as_conversions)]
        let narrowed = rounded as i16;
        Ok(narrowed)
    }
);
// Casts a double to `i32`: the value is rounded with `round_float64` first and
// anything that then falls outside the `i32` range (including NaN) is rejected
// with `Int32OutOfRange`.
sqlfunc!(
    #[sqlname = "double_to_integer"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastInt32ToFloat64)]
    #[is_monotone = true]
    fn cast_float64_to_int32(a: f64) -> Result<i32, EvalError> {
        let rounded = round_float64(a);
        // The bounds are expressed through MIN, a power of two that is exact
        // in f64; `-min` is therefore MAX + 1 and acts as an exclusive bound.
        #[allow(clippy::as_conversions)]
        let min = i32::MIN as f64;
        // NaN is contained in no range, so it takes the error path as well.
        if !(min..-min).contains(&rounded) {
            return Err(EvalError::Int32OutOfRange(rounded.to_string().into()));
        }
        #[allow(clippy::as_conversions)]
        let narrowed = rounded as i32;
        Ok(narrowed)
    }
);
// Casts a double to `i64`: the value is rounded with `round_float64` first and
// anything that then falls outside the `i64` range (including NaN) is rejected
// with `Int64OutOfRange`.
sqlfunc!(
    #[sqlname = "f64toi64"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastInt64ToFloat64)]
    #[is_monotone = true]
    fn cast_float64_to_int64(a: f64) -> Result<i64, EvalError> {
        let rounded = round_float64(a);
        // The bounds are expressed through MIN, a power of two that is exact
        // in f64; `-min` is therefore MAX + 1 and acts as an exclusive bound.
        #[allow(clippy::as_conversions)]
        let min = i64::MIN as f64;
        // NaN is contained in no range, so it takes the error path as well.
        if !(min..-min).contains(&rounded) {
            return Err(EvalError::Int64OutOfRange(rounded.to_string().into()));
        }
        #[allow(clippy::as_conversions)]
        let narrowed = rounded as i64;
        Ok(narrowed)
    }
);
// Narrows a double to single precision, reporting `FloatOverflow` when a
// finite input becomes infinite and `FloatUnderflow` when a non-zero input
// collapses to zero.
sqlfunc!(
    #[sqlname = "double_to_real"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastFloat32ToFloat64)]
    #[is_monotone = true]
    fn cast_float64_to_float32(a: f64) -> Result<f32, EvalError> {
        #[allow(clippy::as_conversions)]
        let narrowed = a as f32;
        // An infinity that was not already present in the input means the
        // magnitude did not fit in f32.
        let overflowed = narrowed.is_infinite() && !a.is_infinite();
        // A zero that was not already present in the input means the
        // magnitude was too small for f32.
        let underflowed = narrowed == 0.0 && a != 0.0;
        match (overflowed, underflowed) {
            (true, _) => Err(EvalError::FloatOverflow),
            (false, true) => Err(EvalError::FloatUnderflow),
            (false, false) => Ok(narrowed),
        }
    }
);
// Renders a double as text by delegating to `strconv::format_float64`.
sqlfunc!(
    #[sqlname = "double_to_text"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastStringToFloat64)]
    fn cast_float64_to_string(a: f64) -> String {
        let mut rendered = String::new();
        strconv::format_float64(&mut rendered, a);
        rendered
    }
);
// Casts a double to `u16`: the value is rounded with `round_float64` first and
// anything outside `0..=u16::MAX` (including NaN) is rejected with
// `UInt16OutOfRange`.
sqlfunc!(
    #[sqlname = "double_to_uint2"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastUint16ToFloat64)]
    #[is_monotone = true]
    fn cast_float64_to_uint16(a: f64) -> Result<u16, EvalError> {
        let rounded = round_float64(a);
        #[allow(clippy::as_conversions)]
        let max = u16::MAX as f64;
        // NaN is contained in no range, so it takes the error path as well.
        if !(0.0..=max).contains(&rounded) {
            return Err(EvalError::UInt16OutOfRange(rounded.to_string().into()));
        }
        #[allow(clippy::as_conversions)]
        let narrowed = rounded as u16;
        Ok(narrowed)
    }
);
// Casts a double to `u32`: the value is rounded with `round_float64` first and
// anything outside `0..=u32::MAX` (including NaN) is rejected with
// `UInt32OutOfRange`.
sqlfunc!(
    #[sqlname = "double_to_uint4"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastUint32ToFloat64)]
    #[is_monotone = true]
    fn cast_float64_to_uint32(a: f64) -> Result<u32, EvalError> {
        let rounded = round_float64(a);
        #[allow(clippy::as_conversions)]
        let max = u32::MAX as f64;
        // NaN is contained in no range, so it takes the error path as well.
        if !(0.0..=max).contains(&rounded) {
            return Err(EvalError::UInt32OutOfRange(rounded.to_string().into()));
        }
        #[allow(clippy::as_conversions)]
        let narrowed = rounded as u32;
        Ok(narrowed)
    }
);
// Casts a double to `u64`: the value is rounded with `round_float64` first and
// anything outside the `u64` range (including NaN) is rejected with
// `UInt64OutOfRange`.
sqlfunc!(
    #[sqlname = "double_to_uint8"]
    #[preserves_uniqueness = false]
    #[inverse = to_unary!(super::CastUint64ToFloat64)]
    #[is_monotone = true]
    fn cast_float64_to_uint64(a: f64) -> Result<u64, EvalError> {
        let f = round_float64(a);
        // `u64::MAX` is not representable in f64: `u64::MAX as f64` rounds up
        // to 2^64, one past the largest valid value. The upper bound therefore
        // has to be exclusive; an inclusive check would accept 2^64 and the
        // saturating `as` cast would silently turn it into `u64::MAX`.
        #[allow(clippy::as_conversions)]
        if (f >= 0.0) && (f < (u64::MAX as f64)) {
            Ok(f as u64)
        } else {
            Err(EvalError::UInt64OutOfRange(f.to_string().into()))
        }
    }
);
/// Unary function that casts a double precision value to [`Numeric`].
///
/// The wrapped value is an optional maximum scale. When present, the converted
/// value is rescaled to it and the output column type carries it as its
/// `max_scale`.
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
pub struct CastFloat64ToNumeric(pub Option<NumericMaxScale>);
impl<'a> EagerUnaryFunc<'a> for CastFloat64ToNumeric {
    type Input = f64;
    type Output = Result<Numeric, EvalError>;

    /// Converts `a` to a `Numeric`, applying the configured maximum scale.
    ///
    /// # Errors
    ///
    /// * `InfinityOutOfDomain` when `a` is positive or negative infinity.
    /// * `NumericFieldOverflow` when rescaling or the final
    ///   `munge_numeric` step fails.
    fn call(&self, a: f64) -> Result<Numeric, EvalError> {
        if a.is_infinite() {
            return Err(EvalError::InfinityOutOfDomain(
                "casting double precision to numeric".into(),
            ));
        }

        let mut converted = Numeric::from(a);
        if let Some(scale) = self.0 {
            numeric::rescale(&mut converted, scale.into_u8())
                .map_err(|_| EvalError::NumericFieldOverflow)?;
        }
        numeric::munge_numeric(&mut converted).map_err(|_| EvalError::NumericFieldOverflow)?;
        Ok(converted)
    }

    /// The output is a numeric carrying this function's maximum scale; it is
    /// nullable exactly when the input column is.
    fn output_type(&self, input: ColumnType) -> ColumnType {
        let max_scale = self.0;
        ScalarType::Numeric { max_scale }.nullable(input.nullable)
    }

    /// The inverse cast is `CastNumericToFloat64`.
    fn inverse(&self) -> Option<crate::UnaryFunc> {
        to_unary!(super::CastNumericToFloat64)
    }

    /// This cast is declared monotone.
    fn is_monotone(&self) -> bool {
        true
    }
}
impl fmt::Display for CastFloat64ToNumeric {
    /// Writes the fixed function name `double_to_numeric`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "double_to_numeric")
    }
}
// Square root; inputs below zero are rejected with `NegSqrt`.
sqlfunc!(
    #[sqlname = "sqrtf64"]
    fn sqrt_float64(a: f64) -> Result<f64, EvalError> {
        if a < 0.0 {
            Err(EvalError::NegSqrt)
        } else {
            Ok(f64::sqrt(a))
        }
    }
);
// Cube root of a double precision value.
sqlfunc!(
    #[sqlname = "cbrtf64"]
    fn cbrt_float64(a: f64) -> f64 {
        f64::cbrt(a)
    }
);
// Cosine; infinite inputs are rejected with `InfinityOutOfDomain`.
sqlfunc!(
    fn cos(a: f64) -> Result<f64, EvalError> {
        if a.is_infinite() {
            Err(EvalError::InfinityOutOfDomain("cos".into()))
        } else {
            Ok(f64::cos(a))
        }
    }
);
// Inverse cosine; inputs outside [-1, 1] are rejected with `OutOfDomain`.
sqlfunc!(
    fn acos(a: f64) -> Result<f64, EvalError> {
        // NaN compares false here and is passed through to `f64::acos`.
        if a.abs() > 1.0 {
            Err(EvalError::OutOfDomain(
                DomainLimit::Inclusive(-1),
                DomainLimit::Inclusive(1),
                "acos".into(),
            ))
        } else {
            Ok(f64::acos(a))
        }
    }
);
// Hyperbolic cosine.
sqlfunc!(
    fn cosh(a: f64) -> f64 {
        f64::cosh(a)
    }
);
// Inverse hyperbolic cosine; inputs below 1 are rejected with `OutOfDomain`.
sqlfunc!(
    fn acosh(a: f64) -> Result<f64, EvalError> {
        if a < 1.0 {
            Err(EvalError::OutOfDomain(
                DomainLimit::Inclusive(1),
                DomainLimit::None,
                "acosh".into(),
            ))
        } else {
            Ok(f64::acosh(a))
        }
    }
);
// Sine; infinite inputs are rejected with `InfinityOutOfDomain`.
sqlfunc!(
    fn sin(a: f64) -> Result<f64, EvalError> {
        if a.is_infinite() {
            Err(EvalError::InfinityOutOfDomain("sin".into()))
        } else {
            Ok(f64::sin(a))
        }
    }
);
// Inverse sine; inputs outside [-1, 1] are rejected with `OutOfDomain`.
sqlfunc!(
    fn asin(a: f64) -> Result<f64, EvalError> {
        // NaN compares false here and is passed through to `f64::asin`.
        if a.abs() > 1.0 {
            Err(EvalError::OutOfDomain(
                DomainLimit::Inclusive(-1),
                DomainLimit::Inclusive(1),
                "asin".into(),
            ))
        } else {
            Ok(f64::asin(a))
        }
    }
);
// Hyperbolic sine.
sqlfunc!(
    fn sinh(a: f64) -> f64 {
        f64::sinh(a)
    }
);
// Inverse hyperbolic sine.
sqlfunc!(
    fn asinh(a: f64) -> f64 {
        f64::asinh(a)
    }
);
// Tangent; infinite inputs are rejected with `InfinityOutOfDomain`.
sqlfunc!(
    fn tan(a: f64) -> Result<f64, EvalError> {
        if a.is_infinite() {
            Err(EvalError::InfinityOutOfDomain("tan".into()))
        } else {
            Ok(f64::tan(a))
        }
    }
);
// Inverse tangent.
sqlfunc!(
    fn atan(a: f64) -> f64 {
        f64::atan(a)
    }
);
// Hyperbolic tangent.
sqlfunc!(
    fn tanh(a: f64) -> f64 {
        f64::tanh(a)
    }
);
// Inverse hyperbolic tangent; inputs outside [-1, 1] are rejected with
// `OutOfDomain`.
sqlfunc!(
    fn atanh(a: f64) -> Result<f64, EvalError> {
        // NaN compares false here and is passed through to `f64::atanh`.
        if a.abs() > 1.0 {
            Err(EvalError::OutOfDomain(
                DomainLimit::Inclusive(-1),
                DomainLimit::Inclusive(1),
                "atanh".into(),
            ))
        } else {
            Ok(f64::atanh(a))
        }
    }
);
// Cotangent, computed as the reciprocal of the tangent; infinite inputs are
// rejected with `InfinityOutOfDomain`.
sqlfunc!(
    fn cot(a: f64) -> Result<f64, EvalError> {
        if a.is_infinite() {
            Err(EvalError::InfinityOutOfDomain("cot".into()))
        } else {
            Ok(a.tan().recip())
        }
    }
);
// Converts an angle from degrees to radians.
sqlfunc!(
    fn radians(a: f64) -> f64 {
        f64::to_radians(a)
    }
);
// Converts an angle from radians to degrees.
sqlfunc!(
    fn degrees(a: f64) -> f64 {
        f64::to_degrees(a)
    }
);
// Base-10 logarithm. Negative inputs yield `NegativeOutOfDomain` and zero
// yields `ZeroOutOfDomain`.
sqlfunc!(
    #[sqlname = "log10f64"]
    fn log10(a: f64) -> Result<f64, EvalError> {
        // The sign-bit test runs first, so `-0.0` is reported as negative
        // rather than as zero.
        if a.is_sign_negative() {
            Err(EvalError::NegativeOutOfDomain("log10".into()))
        } else if a == 0.0 {
            Err(EvalError::ZeroOutOfDomain("log10".into()))
        } else {
            Ok(f64::log10(a))
        }
    }
);
// Natural logarithm. Negative inputs yield `NegativeOutOfDomain` and zero
// yields `ZeroOutOfDomain`.
sqlfunc!(
    #[sqlname = "lnf64"]
    fn ln(a: f64) -> Result<f64, EvalError> {
        // The sign-bit test runs first, so `-0.0` is reported as negative
        // rather than as zero.
        if a.is_sign_negative() {
            Err(EvalError::NegativeOutOfDomain("ln".into()))
        } else if a == 0.0 {
            Err(EvalError::ZeroOutOfDomain("ln".into()))
        } else {
            Ok(f64::ln(a))
        }
    }
);
// Exponential function. An infinite result is reported as `FloatOverflow`
// and a zero result as `FloatUnderflow`.
sqlfunc!(
    #[sqlname = "expf64"]
    fn exp(a: f64) -> Result<f64, EvalError> {
        match a.exp() {
            value if value.is_infinite() => Err(EvalError::FloatOverflow),
            value if value == 0.0 => Err(EvalError::FloatUnderflow),
            value => Ok(value),
        }
    }
);
// Blocks the calling thread for `a` seconds and then returns `None`.
//
// `Duration::from_secs_f64` panics on negative, NaN, infinite or overflowing
// input, so the argument is sanitized before it is converted:
//   * NaN and non-positive values do not sleep at all;
//   * values too large for a `Duration` (including +infinity) are clamped to
//     `Duration::MAX`.
sqlfunc!(
    #[sqlname = "mz_sleep"]
    fn sleep(a: f64) -> Option<CheckedTimestamp<DateTime<Utc>>> {
        let duration = if a.is_nan() || a <= 0.0 {
            std::time::Duration::ZERO
        } else {
            // Only overflow can fail here: the sign and NaN cases were
            // handled above.
            std::time::Duration::try_from_secs_f64(a).unwrap_or(std::time::Duration::MAX)
        };
        std::thread::sleep(duration);
        None
    }
);
// Interprets `f` as seconds since the Unix epoch and converts it to a
// timestamp with microsecond resolution.
//
// NaN yields `TimestampCannotBeNan`; infinities and any value that cannot be
// represented yield `TimestampOutOfRange`.
sqlfunc!(
    #[sqlname = "tots"]
    fn to_timestamp(f: f64) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
        const NANO_SECONDS_PER_SECOND: i64 = 1_000_000_000;

        if f.is_nan() {
            return Err(EvalError::TimestampCannotBeNan);
        }
        if f.is_infinite() {
            return Err(EvalError::TimestampOutOfRange);
        }

        // Split into whole seconds and a fraction rounded to whole microseconds.
        let whole_secs = i64::try_cast_from(f.trunc()).ok_or(EvalError::TimestampOutOfRange)?;
        let rounded_micros = (f.fract() * 1_000_000.0).round();
        let frac_nanos =
            i64::try_cast_from(rounded_micros * 1_000.0).ok_or(EvalError::TimestampOutOfRange)?;

        // A negative fraction borrows one second so the sub-second part ends
        // up non-negative.
        let (whole_secs, frac_nanos) = if frac_nanos < 0 {
            (
                whole_secs
                    .checked_sub(1)
                    .ok_or(EvalError::TimestampOutOfRange)?,
                NANO_SECONDS_PER_SECOND
                    .checked_add(frac_nanos)
                    .ok_or(EvalError::TimestampOutOfRange)?,
            )
        } else {
            (whole_secs, frac_nanos)
        };

        // Rounding can produce a full second of nanoseconds; carry it over.
        let secs = whole_secs
            .checked_add(frac_nanos / NANO_SECONDS_PER_SECOND)
            .ok_or(EvalError::TimestampOutOfRange)?;
        let subsec_nanos = u32::try_from(frac_nanos % NANO_SECONDS_PER_SECOND)
            .map_err(|_| EvalError::TimestampOutOfRange)?;

        DateTime::from_timestamp(secs, subsec_nanos)
            .ok_or(EvalError::TimestampOutOfRange)
            .and_then(|dt| {
                CheckedTimestamp::from_timestamplike(dt).map_err(|_| EvalError::TimestampOutOfRange)
            })
    }
);