criterion/stats/univariate/kde/
mod.rs

1//! Kernel density estimation
2
3pub mod kernel;
4
5use self::kernel::Kernel;
6use crate::stats::float::Float;
7use crate::stats::univariate::Sample;
8#[cfg(feature = "rayon")]
9use rayon::prelude::*;
10
11/// Univariate kernel density estimator
12pub struct Kde<'a, A, K>
13where
14    A: Float,
15    K: Kernel<A>,
16{
17    bandwidth: A,
18    kernel: K,
19    sample: &'a Sample<A>,
20}
21
22impl<'a, A, K> Kde<'a, A, K>
23where
24    A: 'a + Float,
25    K: Kernel<A>,
26{
27    /// Creates a new kernel density estimator from the `sample`, using a kernel and estimating
28    /// the bandwidth using the method `bw`
29    pub fn new(sample: &'a Sample<A>, kernel: K, bw: Bandwidth) -> Kde<'a, A, K> {
30        Kde {
31            bandwidth: bw.estimate(sample),
32            kernel,
33            sample,
34        }
35    }
36
37    /// Returns the bandwidth used by the estimator
38    pub fn bandwidth(&self) -> A {
39        self.bandwidth
40    }
41
42    /// Maps the KDE over `xs`
43    ///
44    /// - Multihreaded
45    pub fn map(&self, xs: &[A]) -> Box<[A]> {
46        #[cfg(feature = "rayon")]
47        let iter = xs.par_iter();
48
49        #[cfg(not(feature = "rayon"))]
50        let iter = xs.iter();
51
52        iter.map(|&x| self.estimate(x))
53            .collect::<Vec<_>>()
54            .into_boxed_slice()
55    }
56
57    /// Estimates the probability density of `x`
58    pub fn estimate(&self, x: A) -> A {
59        let _0 = A::cast(0);
60        let slice = self.sample;
61        let h = self.bandwidth;
62        let n = A::cast(slice.len());
63
64        let sum = slice
65            .iter()
66            .fold(_0, |acc, &x_i| acc + self.kernel.evaluate((x - x_i) / h));
67
68        sum / (h * n)
69    }
70}
71
/// Method to estimate the bandwidth
///
/// This is a field-less selector that is passed by value, so it derives `Copy` (a single value
/// can configure several estimators) along with `Debug` and equality for diagnostics and tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Bandwidth {
    /// Use Silverman's rule of thumb to estimate the bandwidth from the sample
    Silverman,
}
77
78impl Bandwidth {
79    fn estimate<A: Float>(self, sample: &Sample<A>) -> A {
80        match self {
81            Bandwidth::Silverman => {
82                let factor = A::cast(4. / 3.);
83                let exponent = A::cast(1. / 5.);
84                let n = A::cast(sample.len());
85                let sigma = sample.std_dev(None);
86
87                sigma * (factor / n).powf(exponent)
88            }
89        }
90    }
91}
92
#[cfg(test)]
macro_rules! test {
    // Generates a module named after the float type `$ty` holding its property tests
    ($ty:ident) => {
        mod $ty {
            use approx::relative_eq;
            use quickcheck::{quickcheck, TestResult};

            use crate::stats::univariate::kde::kernel::Gaussian;
            use crate::stats::univariate::kde::{Bandwidth, Kde};
            use crate::stats::univariate::Sample;

            // Integrating the estimated PDF over the whole real line should give one
            quickcheck! {
                fn integral(size: u8, start: u8) -> TestResult {
                    // Integration step
                    const DX: $ty = 1e-3;

                    let (size, start) = (size as usize, start as usize);

                    // Inputs for which no test vector can be produced are discarded
                    let v = match crate::stats::test::vec::<$ty>(size, start) {
                        Some(v) => v,
                        None => return TestResult::discard(),
                    };

                    let slice = &v[start..];
                    let data = Sample::new(slice);
                    let kde = Kde::new(data, Gaussian, Bandwidth::Silverman);
                    let h = kde.bandwidth();

                    // NB An infinite integration range is not feasible; five bandwidths past
                    // each extreme of the data works quite well
                    let lower = data.min() - 5. * h;
                    let upper = data.max() + 5. * h;

                    // Trapezoidal rule: each step adds the left and right half-areas in turn
                    let mut area = 0.;
                    let mut x = lower;
                    let mut y = kde.estimate(lower);

                    while x < upper {
                        area += DX * y / 2.;

                        x += DX;
                        y = kde.estimate(x);

                        area += DX * y / 2.;
                    }

                    TestResult::from_bool(relative_eq!(area, 1., epsilon = 2e-5))
                }
            }
        }
    };
}

#[cfg(test)]
mod test {
    // Instantiate the property tests once per supported float type
    test!(f32);
    test!(f64);
}
148}