brotli/enc/weights.rs

use core::cmp::{max, min};
pub type Prob = u16;

pub const BLEND_FIXED_POINT_PRECISION: i8 = 15;
#[allow(dead_code)]
pub const LOG2_SCALE: i32 = 15;
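/// Adaptive two-model probability mixer. The raw per-model weights live in
/// `model_weights`; model 0's share of the blend is cached in
/// `normalized_weight` as a fixed-point value with
/// `BLEND_FIXED_POINT_PRECISION` fractional bits.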
#[derive(Debug, Copy, Clone)]
pub struct Weights {
    model_weights: [i32; 2],
    mixing_param: u8,
    normalized_weight: Prob,
}
impl Default for Weights {
    fn default() -> Self {
        Self::new()
    }
}
impl Weights {
    pub fn new() -> Self {
        Weights {
            model_weights: [1; 2],
            mixing_param: 1,
            normalized_weight: 1 << (BLEND_FIXED_POINT_PRECISION - 1),
        }
    }
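    /// Runs one mixing step: rescales the weights if they have grown too
    /// large, nudges each model's weight by how well it predicted relative
    /// to the blended probability, and refreshes the cached normalized
    /// weight. Requires a nonzero `mixing_param`.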
    #[allow(unused)]
    #[inline(always)]
    pub fn update(&mut self, model_probs: [Prob; 2], weighted_prob: Prob) {
        debug_assert!(self.mixing_param != 0);
        normalize_weights(&mut self.model_weights);
        let w0new = compute_new_weight(
            model_probs,
            weighted_prob,
            self.model_weights,
            false,
            self.mixing_param - 1,
        );
        let w1new = compute_new_weight(
            model_probs,
            weighted_prob,
            self.model_weights,
            true,
            self.mixing_param - 1,
        );
        self.model_weights = [w0new, w1new];
        self.normalized_weight = compute_normalized_weight(self.model_weights);
    }
    #[allow(dead_code)]
    #[inline(always)]
    pub fn norm_weight(&self) -> Prob {
        self.normalized_weight
    }
}

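// Reduces w0 / (w0 + w1) to a Prob with BLEND_FIXED_POINT_PRECISION
// fractional bits. The sum is first shifted down to at most 8 bits so the
// quotient can be produced with a cheap 16-by-8-bit division (the
// commented-out block shows the intended lookup-based divide).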
#[allow(dead_code)]
#[inline(always)]
fn compute_normalized_weight(model_weights: [i32; 2]) -> Prob {
    let total = i64::from(model_weights[0]) + i64::from(model_weights[1]);
    let leading_zeros = total.leading_zeros();
    let shift = max(56 - (leading_zeros as i8), 0);
    let total_8bit = total >> shift;
    /*::probability::numeric::fast_divide_16bit_by_8bit(
    ((model_weights[0] >> shift) as u16) << 8,
    ::probability::numeric::lookup_divisor8(total_8bit as u8)) << (BLEND_FIXED_POINT_PRECISION - 8)
    */
    ((((model_weights[0] >> shift) as u16) << 8) / total_8bit as u16/*fixme??*/)
        << (BLEND_FIXED_POINT_PRECISION - 8)
}

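// Cold path of normalize_weights: shifts both weights right so the larger
// one fits in max_log (24) bits, preserving their approximate ratio while
// keeping later arithmetic well inside the i32 range.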
#[allow(dead_code)]
#[cold]
fn fix_weights(weights: &mut [i32; 2]) {
    let ilog = 32 - min(weights[0].leading_zeros(), weights[1].leading_zeros());
    let max_log = 24;
    if ilog >= max_log {
        weights[0] >>= ilog - max_log;
        weights[1] >>= ilog - max_log;
    }
}

#[allow(dead_code)]
#[inline(always)]
fn normalize_weights(weights: &mut [i32; 2]) {
    // Only take the cold rescaling path once a weight reaches bit 24.
    if ((weights[0] | weights[1]) & 0x7f00_0000) != 0 {
        fix_weights(weights);
    }
}

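// Floating-point weight update, compiled only with the
// "floating_point_context_mixing" feature. It nudges the selected model's
// weight by its deviation from the blended prediction, scaled by
// 1 / (s0 * s1), which amounts to one gradient-style step; the result is
// clamped to a small positive epsilon so a weight can never reach zero.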
#[allow(dead_code)]
#[cfg(feature = "floating_point_context_mixing")]
fn compute_new_weight(
    probs: [Prob; 2],
    weighted_prob: Prob,
    weights: [i32; 2],
    index_equal_1: bool,
    _speed: u8,
) -> i32 {
    // speed ranges from 1 to 14 inclusive
    let index = index_equal_1 as usize;
    let n1i = probs[index] as f64 / ((1i64 << LOG2_SCALE) as f64);
    //let n0i = 1.0f64 - n1i;
    let ni = 1.0f64;
    let s1 = weighted_prob as f64 / ((1i64 << LOG2_SCALE) as f64);
    let s0 = 1.0f64 - s1;
    let s = 1.0f64;
    //let p0 = s0;
    let p1 = s1;
    let wi = weights[index] as f64 / ((1i64 << LOG2_SCALE) as f64);
    let mut wi_new = wi + (1.0 - p1) * (s * n1i - s1 * ni) / (s0 * s1);
    let eps = 0.00001f64;
    // `!(wi_new > eps)` also catches NaN, which a plain `<` would miss.
    if !(wi_new > eps) {
        wi_new = eps;
    }
    (wi_new * ((1i64 << LOG2_SCALE) as f64)) as i32
}

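// Default fixed-point weight update, mirroring the floating-point variant
// above with every quantity scaled by 1 << LOG2_SCALE. The division by
// s0 * s1 is approximated by shifting right by the bit length of
// full_model_sum_p1 * full_model_sum_p0, and the result is clamped to at
// least 1.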
#[allow(dead_code)]
#[cfg(not(feature = "floating_point_context_mixing"))]
#[inline(always)]
fn compute_new_weight(
    probs: [Prob; 2],
    weighted_prob: Prob,
    weights: [i32; 2],
    index_equal_1: bool,
    _speed: u8,
) -> i32 {
    // speed ranges from 1 to 14 inclusive
    let index = index_equal_1 as usize;
    let full_model_sum_p1 = i64::from(weighted_prob);
    let full_model_total = 1i64 << LOG2_SCALE;
    let full_model_sum_p0 = full_model_total.wrapping_sub(i64::from(weighted_prob));
    let n1i = i64::from(probs[index]);
    let ni = 1i64 << LOG2_SCALE;
    let error = full_model_total.wrapping_sub(full_model_sum_p1);
    let wi = i64::from(weights[index]);
    let efficacy = full_model_total.wrapping_mul(n1i) - full_model_sum_p1.wrapping_mul(ni);
    //let geometric_probabilities = full_model_sum_p1 * full_model_sum_p0;
    let log_geometric_probabilities =
        64 - (full_model_sum_p1.wrapping_mul(full_model_sum_p0)).leading_zeros();
    //let scaled_geometric_probabilities = geometric_probabilities * S;
    //let new_weight_adj = (error * efficacy) >> log_geometric_probabilities;
    //let new_weight_adj = (error * efficacy) / (full_model_sum_p1 * full_model_sum_p0);
    let new_weight_adj = (error.wrapping_mul(efficacy)) >> log_geometric_probabilities;
    //assert!(wi + new_weight_adj < (1i64 << 31));
    max(1, wi.wrapping_add(new_weight_adj) as i32)
}
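
#[cfg(test)]
mod tests {
    use super::{Weights, BLEND_FIXED_POINT_PRECISION};

    // Minimal usage sketch, added for illustration (not part of the
    // original file). The probability values are arbitrary and assume the
    // convention that a Prob of 1 << 14 encodes 0.5 on the 15-bit
    // LOG2_SCALE.
    #[test]
    fn update_shifts_weight_toward_better_model() {
        let mut w = Weights::new();
        // A fresh mixer weights both models equally.
        assert_eq!(w.norm_weight(), 1 << (BLEND_FIXED_POINT_PRECISION - 1));
        // Model 0 predicted above the blended estimate and model 1 below
        // it, so model 0's share of the mix should grow.
        w.update([20000, 12000], 1 << 14);
        assert!(w.norm_weight() > 1 << (BLEND_FIXED_POINT_PRECISION - 1));
    }
}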