mz_expr/scalar/func/impls/
float64.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt;
11
12use chrono::{DateTime, Utc};
13use mz_lowertest::MzReflect;
14use mz_ore::cast::TryCastFrom;
15use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
16use mz_repr::adt::timestamp::CheckedTimestamp;
17use mz_repr::{ColumnType, ScalarType, strconv};
18use serde::{Deserialize, Serialize};
19
20use crate::EvalError;
21use crate::scalar::DomainLimit;
22use crate::scalar::func::EagerUnaryFunc;
23
24sqlfunc!(
25    #[sqlname = "-"]
26    #[preserves_uniqueness = false]
27    #[inverse = to_unary!(NegFloat64)]
28    #[is_monotone = true]
29    fn neg_float64(a: f64) -> f64 {
30        -a
31    }
32);
33
34sqlfunc!(
35    #[sqlname = "abs"]
36    fn abs_float64(a: f64) -> f64 {
37        a.abs()
38    }
39);
40
41sqlfunc!(
42    #[sqlname = "roundf64"]
43    fn round_float64(a: f64) -> f64 {
44        a.round_ties_even()
45    }
46);
47
48sqlfunc!(
49    #[sqlname = "truncf64"]
50    fn trunc_float64(a: f64) -> f64 {
51        a.trunc()
52    }
53);
54
55sqlfunc!(
56    #[sqlname = "ceilf64"]
57    fn ceil_float64(a: f64) -> f64 {
58        a.ceil()
59    }
60);
61
62sqlfunc!(
63    #[sqlname = "floorf64"]
64    fn floor_float64(a: f64) -> f64 {
65        a.floor()
66    }
67);
68
69sqlfunc!(
70    #[sqlname = "double_to_smallint"]
71    #[preserves_uniqueness = false]
72    #[inverse = to_unary!(super::CastInt16ToFloat64)]
73    #[is_monotone = true]
74    fn cast_float64_to_int16(a: f64) -> Result<i16, EvalError> {
75        let f = round_float64(a);
76        // TODO(benesch): remove potentially dangerous usage of `as`.
77        #[allow(clippy::as_conversions)]
78        if (f >= (i16::MIN as f64)) && (f < -(i16::MIN as f64)) {
79            Ok(f as i16)
80        } else {
81            Err(EvalError::Int16OutOfRange(f.to_string().into()))
82        }
83    }
84);
85
86sqlfunc!(
87    #[sqlname = "double_to_integer"]
88    #[preserves_uniqueness = false]
89    #[inverse = to_unary!(super::CastInt32ToFloat64)]
90    #[is_monotone = true]
91    fn cast_float64_to_int32(a: f64) -> Result<i32, EvalError> {
92        let f = round_float64(a);
93        // This condition is delicate because i32::MIN can be represented exactly by
94        // an f64 but not i32::MAX. We follow PostgreSQL's approach here.
95        //
96        // See: https://github.com/postgres/postgres/blob/ca3b37487/src/include/c.h#L1074-L1096
97        // TODO(benesch): remove potentially dangerous usage of `as`.
98        #[allow(clippy::as_conversions)]
99        if (f >= (i32::MIN as f64)) && (f < -(i32::MIN as f64)) {
100            Ok(f as i32)
101        } else {
102            Err(EvalError::Int32OutOfRange(f.to_string().into()))
103        }
104    }
105);
106
107sqlfunc!(
108    #[sqlname = "f64toi64"]
109    #[preserves_uniqueness = false]
110    #[inverse = to_unary!(super::CastInt64ToFloat64)]
111    #[is_monotone = true]
112    fn cast_float64_to_int64(a: f64) -> Result<i64, EvalError> {
113        let f = round_float64(a);
114        // This condition is delicate because i64::MIN can be represented exactly by
115        // an f64 but not i64::MAX. We follow PostgreSQL's approach here.
116        //
117        // See: https://github.com/postgres/postgres/blob/ca3b37487/src/include/c.h#L1074-L1096
118        // TODO(benesch): remove potentially dangerous usage of `as`.
119        #[allow(clippy::as_conversions)]
120        if (f >= (i64::MIN as f64)) && (f < -(i64::MIN as f64)) {
121            Ok(f as i64)
122        } else {
123            Err(EvalError::Int64OutOfRange(f.to_string().into()))
124        }
125    }
126);
127
128sqlfunc!(
129    #[sqlname = "double_to_real"]
130    #[preserves_uniqueness = false]
131    #[inverse = to_unary!(super::CastFloat32ToFloat64)]
132    #[is_monotone = true]
133    fn cast_float64_to_float32(a: f64) -> Result<f32, EvalError> {
134        // TODO(benesch): remove potentially dangerous usage of `as`.
135        #[allow(clippy::as_conversions)]
136        let result = a as f32;
137        if result.is_infinite() && !a.is_infinite() {
138            Err(EvalError::FloatOverflow)
139        } else if result == 0.0 && a != 0.0 {
140            Err(EvalError::FloatUnderflow)
141        } else {
142            Ok(result)
143        }
144    }
145);
146
147sqlfunc!(
148    #[sqlname = "double_to_text"]
149    #[preserves_uniqueness = false]
150    #[inverse = to_unary!(super::CastStringToFloat64)]
151    fn cast_float64_to_string(a: f64) -> String {
152        let mut s = String::new();
153        strconv::format_float64(&mut s, a);
154        s
155    }
156);
157
158sqlfunc!(
159    #[sqlname = "double_to_uint2"]
160    #[preserves_uniqueness = false]
161    #[inverse = to_unary!(super::CastUint16ToFloat64)]
162    #[is_monotone = true]
163    fn cast_float64_to_uint16(a: f64) -> Result<u16, EvalError> {
164        let f = round_float64(a);
165        // TODO(benesch): remove potentially dangerous usage of `as`.
166        #[allow(clippy::as_conversions)]
167        if (f >= 0.0) && (f <= (u16::MAX as f64)) {
168            Ok(f as u16)
169        } else {
170            Err(EvalError::UInt16OutOfRange(f.to_string().into()))
171        }
172    }
173);
174
175sqlfunc!(
176    #[sqlname = "double_to_uint4"]
177    #[preserves_uniqueness = false]
178    #[inverse = to_unary!(super::CastUint32ToFloat64)]
179    #[is_monotone = true]
180    fn cast_float64_to_uint32(a: f64) -> Result<u32, EvalError> {
181        let f = round_float64(a);
182        // TODO(benesch): remove potentially dangerous usage of `as`.
183        #[allow(clippy::as_conversions)]
184        if (f >= 0.0) && (f <= (u32::MAX as f64)) {
185            Ok(f as u32)
186        } else {
187            Err(EvalError::UInt32OutOfRange(f.to_string().into()))
188        }
189    }
190);
191
192sqlfunc!(
193    #[sqlname = "double_to_uint8"]
194    #[preserves_uniqueness = false]
195    #[inverse = to_unary!(super::CastUint64ToFloat64)]
196    #[is_monotone = true]
197    fn cast_float64_to_uint64(a: f64) -> Result<u64, EvalError> {
198        let f = round_float64(a);
199        // TODO(benesch): remove potentially dangerous usage of `as`.
200        #[allow(clippy::as_conversions)]
201        if (f >= 0.0) && (f <= (u64::MAX as f64)) {
202            Ok(f as u64)
203        } else {
204            Err(EvalError::UInt64OutOfRange(f.to_string().into()))
205        }
206    }
207);
208
/// Unary function that casts an `f64` to a [`Numeric`].
///
/// The wrapped value is the optional maximum scale: when it is `Some`, the
/// converted value is rescaled to that scale, and the same value is reported
/// as `max_scale` in the function's output type.
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
pub struct CastFloat64ToNumeric(pub Option<NumericMaxScale>);
211
212impl<'a> EagerUnaryFunc<'a> for CastFloat64ToNumeric {
213    type Input = f64;
214    type Output = Result<Numeric, EvalError>;
215
216    fn call(&self, a: f64) -> Result<Numeric, EvalError> {
217        if a.is_infinite() {
218            return Err(EvalError::InfinityOutOfDomain(
219                "casting double precision to numeric".into(),
220            ));
221        }
222        let mut a = Numeric::from(a);
223        if let Some(scale) = self.0 {
224            if numeric::rescale(&mut a, scale.into_u8()).is_err() {
225                return Err(EvalError::NumericFieldOverflow);
226            }
227        }
228        match numeric::munge_numeric(&mut a) {
229            Ok(_) => Ok(a),
230            Err(_) => Err(EvalError::NumericFieldOverflow),
231        }
232    }
233
234    fn output_type(&self, input: ColumnType) -> ColumnType {
235        ScalarType::Numeric { max_scale: self.0 }.nullable(input.nullable)
236    }
237
238    fn inverse(&self) -> Option<crate::UnaryFunc> {
239        to_unary!(super::CastNumericToFloat64)
240    }
241
242    fn is_monotone(&self) -> bool {
243        true
244    }
245}
246
247impl fmt::Display for CastFloat64ToNumeric {
248    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
249        f.write_str("double_to_numeric")
250    }
251}
252
253sqlfunc!(
254    #[sqlname = "sqrtf64"]
255    fn sqrt_float64(a: f64) -> Result<f64, EvalError> {
256        if a < 0.0 {
257            return Err(EvalError::NegSqrt);
258        }
259        Ok(a.sqrt())
260    }
261);
262
263sqlfunc!(
264    #[sqlname = "cbrtf64"]
265    fn cbrt_float64(a: f64) -> f64 {
266        a.cbrt()
267    }
268);
269
270sqlfunc!(
271    fn cos(a: f64) -> Result<f64, EvalError> {
272        if a.is_infinite() {
273            return Err(EvalError::InfinityOutOfDomain("cos".into()));
274        }
275        Ok(a.cos())
276    }
277);
278
279sqlfunc!(
280    fn acos(a: f64) -> Result<f64, EvalError> {
281        if a < -1.0 || 1.0 < a {
282            return Err(EvalError::OutOfDomain(
283                DomainLimit::Inclusive(-1),
284                DomainLimit::Inclusive(1),
285                "acos".into(),
286            ));
287        }
288        Ok(a.acos())
289    }
290);
291
292sqlfunc!(
293    fn cosh(a: f64) -> f64 {
294        a.cosh()
295    }
296);
297
298sqlfunc!(
299    fn acosh(a: f64) -> Result<f64, EvalError> {
300        if a < 1.0 {
301            return Err(EvalError::OutOfDomain(
302                DomainLimit::Inclusive(1),
303                DomainLimit::None,
304                "acosh".into(),
305            ));
306        }
307        Ok(a.acosh())
308    }
309);
310
311sqlfunc!(
312    fn sin(a: f64) -> Result<f64, EvalError> {
313        if a.is_infinite() {
314            return Err(EvalError::InfinityOutOfDomain("sin".into()));
315        }
316        Ok(a.sin())
317    }
318);
319
320sqlfunc!(
321    fn asin(a: f64) -> Result<f64, EvalError> {
322        if a < -1.0 || 1.0 < a {
323            return Err(EvalError::OutOfDomain(
324                DomainLimit::Inclusive(-1),
325                DomainLimit::Inclusive(1),
326                "asin".into(),
327            ));
328        }
329        Ok(a.asin())
330    }
331);
332
333sqlfunc!(
334    fn sinh(a: f64) -> f64 {
335        a.sinh()
336    }
337);
338
339sqlfunc!(
340    fn asinh(a: f64) -> f64 {
341        a.asinh()
342    }
343);
344
345sqlfunc!(
346    fn tan(a: f64) -> Result<f64, EvalError> {
347        if a.is_infinite() {
348            return Err(EvalError::InfinityOutOfDomain("tan".into()));
349        }
350        Ok(a.tan())
351    }
352);
353
354sqlfunc!(
355    fn atan(a: f64) -> f64 {
356        a.atan()
357    }
358);
359
360sqlfunc!(
361    fn tanh(a: f64) -> f64 {
362        a.tanh()
363    }
364);
365
366sqlfunc!(
367    fn atanh(a: f64) -> Result<f64, EvalError> {
368        if a < -1.0 || 1.0 < a {
369            return Err(EvalError::OutOfDomain(
370                DomainLimit::Inclusive(-1),
371                DomainLimit::Inclusive(1),
372                "atanh".into(),
373            ));
374        }
375        Ok(a.atanh())
376    }
377);
378
379sqlfunc!(
380    fn cot(a: f64) -> Result<f64, EvalError> {
381        if a.is_infinite() {
382            return Err(EvalError::InfinityOutOfDomain("cot".into()));
383        }
384        Ok(1.0 / a.tan())
385    }
386);
387
388sqlfunc!(
389    fn radians(a: f64) -> f64 {
390        a.to_radians()
391    }
392);
393
394sqlfunc!(
395    fn degrees(a: f64) -> f64 {
396        a.to_degrees()
397    }
398);
399
400sqlfunc!(
401    #[sqlname = "log10f64"]
402    fn log10(a: f64) -> Result<f64, EvalError> {
403        if a.is_sign_negative() {
404            return Err(EvalError::NegativeOutOfDomain("log10".into()));
405        }
406        if a == 0.0 {
407            return Err(EvalError::ZeroOutOfDomain("log10".into()));
408        }
409        Ok(a.log10())
410    }
411);
412
413sqlfunc!(
414    #[sqlname = "lnf64"]
415    fn ln(a: f64) -> Result<f64, EvalError> {
416        if a.is_sign_negative() {
417            return Err(EvalError::NegativeOutOfDomain("ln".into()));
418        }
419        if a == 0.0 {
420            return Err(EvalError::ZeroOutOfDomain("ln".into()));
421        }
422        Ok(a.ln())
423    }
424);
425
426sqlfunc!(
427    #[sqlname = "expf64"]
428    fn exp(a: f64) -> Result<f64, EvalError> {
429        let r = a.exp();
430        if r.is_infinite() {
431            return Err(EvalError::FloatOverflow);
432        }
433        if r == 0.0 {
434            return Err(EvalError::FloatUnderflow);
435        }
436        Ok(r)
437    }
438);
439
440sqlfunc!(
441    #[sqlname = "mz_sleep"]
442    fn sleep(a: f64) -> Option<CheckedTimestamp<DateTime<Utc>>> {
443        let duration = std::time::Duration::from_secs_f64(a);
444        std::thread::sleep(duration);
445        None
446    }
447);
448
449sqlfunc!(
450    #[sqlname = "tots"]
451    fn to_timestamp(f: f64) -> Result<CheckedTimestamp<DateTime<Utc>>, EvalError> {
452        const NANO_SECONDS_PER_SECOND: i64 = 1_000_000_000;
453        if f.is_nan() {
454            Err(EvalError::TimestampCannotBeNan)
455        } else if f.is_infinite() {
456            // TODO(jkosh44) implement infinite timestamps
457            Err(EvalError::TimestampOutOfRange)
458        } else {
459            let mut secs = i64::try_cast_from(f.trunc()).ok_or(EvalError::TimestampOutOfRange)?;
460            // NOTE(benesch): PostgreSQL has microsecond precision in its timestamps,
461            // while chrono has nanosecond precision. While we normally accept
462            // nanosecond precision, here we round to the nearest microsecond because
463            // f64s lose quite a bit of accuracy in the nanosecond digits when dealing
464            // with common Unix timestamp values (> 1 billion).
465            let microsecs = (f.fract() * 1_000_000.0).round();
466            let mut nanosecs =
467                i64::try_cast_from(microsecs * 1_000.0).ok_or(EvalError::TimestampOutOfRange)?;
468            if nanosecs < 0 {
469                secs = secs.checked_sub(1).ok_or(EvalError::TimestampOutOfRange)?;
470                nanosecs = NANO_SECONDS_PER_SECOND
471                    .checked_add(nanosecs)
472                    .ok_or(EvalError::TimestampOutOfRange)?;
473            }
474            // Ensure `nanosecs` is less than 1 second.
475            secs = secs
476                .checked_add(nanosecs / NANO_SECONDS_PER_SECOND)
477                .ok_or(EvalError::TimestampOutOfRange)?;
478            nanosecs %= NANO_SECONDS_PER_SECOND;
479            let nanosecs = u32::try_from(nanosecs).map_err(|_| EvalError::TimestampOutOfRange)?;
480            match DateTime::from_timestamp(secs, nanosecs) {
481                Some(dt) => CheckedTimestamp::from_timestamplike(dt)
482                    .map_err(|_| EvalError::TimestampOutOfRange),
483                None => Err(EvalError::TimestampOutOfRange),
484            }
485        }
486    }
487);