// brotli/enc/weights.rs
1use core::cmp::{max, min};
/// Probability stored as an unsigned 16-bit fixed-point number
/// (see [`LOG2_SCALE`] for the scale used by the weight update).
pub type Prob = u16;

/// Number of fractional bits in the blend weight reported by `Weights::norm_weight`:
/// `1 << BLEND_FIXED_POINT_PRECISION` would give all of the weight to model 0.
pub const BLEND_FIXED_POINT_PRECISION: i8 = 15;

/// Base-2 logarithm of the fixed-point scale of a [`Prob`]: `1 << LOG2_SCALE` stands for
/// probability 1.0 in `compute_new_weight`.
#[allow(dead_code)]
pub const LOG2_SCALE: i32 = 15;
7#[derive(Debug, Copy, Clone)]
8pub struct Weights {
9 model_weights: [i32; 2],
10 mixing_param: u8,
11 normalized_weight: Prob,
12}
13impl Default for Weights {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18impl Weights {
19 pub fn new() -> Self {
20 Weights {
21 model_weights: [1; 2],
22 mixing_param: 1,
23 normalized_weight: 1 << (BLEND_FIXED_POINT_PRECISION - 1),
24 }
25 }
26 #[allow(unused)]
27 #[inline(always)]
28 pub fn update(&mut self, model_probs: [Prob; 2], weighted_prob: Prob) {
29 debug_assert!(self.mixing_param != 0);
30 normalize_weights(&mut self.model_weights);
31 let w0new = compute_new_weight(
32 model_probs,
33 weighted_prob,
34 self.model_weights,
35 false,
36 self.mixing_param - 1,
37 );
38 let w1new = compute_new_weight(
39 model_probs,
40 weighted_prob,
41 self.model_weights,
42 true,
43 self.mixing_param - 1,
44 );
45 self.model_weights = [w0new, w1new];
46 self.normalized_weight = compute_normalized_weight(self.model_weights);
47 }
48 #[allow(dead_code)]
49 #[inline(always)]
50 pub fn norm_weight(&self) -> Prob {
51 self.normalized_weight
52 }
53}
54
55#[allow(dead_code)]
56#[inline(always)]
57fn compute_normalized_weight(model_weights: [i32; 2]) -> Prob {
58 let total = i64::from(model_weights[0]) + i64::from(model_weights[1]);
59 let leading_zeros = total.leading_zeros();
60 let shift = max(56 - (leading_zeros as i8), 0);
61 let total_8bit = total >> shift;
62 ((((model_weights[0] >> shift) as u16) << 8) / total_8bit as u16)
67 << (BLEND_FIXED_POINT_PRECISION - 8)
68}
69
/// Rescales both weights so that the wider one occupies at most 24 bits.
///
/// Both weights are shifted right by the same amount, which keeps their ratio roughly
/// intact; a much smaller weight may round down to 0. Assumes non-negative weights: a
/// negative value has no leading zeros and forces the maximum shift of 8.
#[allow(dead_code)]
#[cold]
fn fix_weights(weights: &mut [i32; 2]) {
    const MAX_LOG: u32 = 24;
    let fewest_leading_zeros = min(weights[0].leading_zeros(), weights[1].leading_zeros());
    let widest_bits = 32 - fewest_leading_zeros;
    if widest_bits < MAX_LOG {
        return;
    }
    let excess_bits = widest_bits - MAX_LOG;
    for weight in weights.iter_mut() {
        *weight >>= excess_bits;
    }
}
80
81#[allow(dead_code)]
82#[inline(always)]
83fn normalize_weights(weights: &mut [i32; 2]) {
84 if ((weights[0] | weights[1]) & 0x7f00_0000) != 0 {
85 fix_weights(weights);
86 }
87}
88
/// Floating-point variant of the per-model weight update, compiled only with the
/// `floating_point_context_mixing` feature.
///
/// The fixed-point inputs are converted to `f64` using the scale `1 << LOG2_SCALE`. The
/// weight of model `index_equal_1` is then updated as
/// `wi += (1 - s1) * (n1i - s1) / (s0 * s1)`, floored at a small epsilon, and converted
/// back to fixed point. The returned weight is always at least 1, matching the integer
/// variant. Without that floor, two zero weights would make `compute_normalized_weight`
/// divide by zero.
#[allow(dead_code)]
#[cfg(feature = "floating_point_context_mixing")]
fn compute_new_weight(
    probs: [Prob; 2],
    weighted_prob: Prob,
    weights: [i32; 2],
    index_equal_1: bool,
    _speed: u8,
) -> i32 {
    let scale = (1i64 << LOG2_SCALE) as f64;
    let index = index_equal_1 as usize;
    // Probability assigned by model `index`, measured against a total of `ni` (1.0).
    let n1i = probs[index] as f64 / scale;
    let ni = 1.0f64;
    // Blended probability `s1`, its complement `s0`, and their total `s` (1.0).
    let s1 = weighted_prob as f64 / scale;
    let s0 = 1.0f64 - s1;
    let s = 1.0f64;
    let p1 = s1;
    let wi = weights[index] as f64 / scale;
    // When s0 * s1 == 0 this produces +/-inf or NaN; the floor below absorbs NaN and -inf,
    // and +inf saturates to i32::MAX in the final cast.
    let mut wi_new = wi + (1.0 - p1) * (s * n1i - s1 * ni) / (s0 * s1);
    let eps = 0.00001f64;
    // Written as `!(x > eps)` rather than `x <= eps` so that NaN is floored as well.
    if !(wi_new > eps) {
        wi_new = eps;
    }
    // `eps * scale` is below 1 and would truncate to a weight of 0; clamp to 1 so the
    // result is consistent with the integer variant and never zero.
    max(1, (wi_new * scale) as i32)
}
116
117#[allow(dead_code)]
118#[cfg(not(feature = "floating_point_context_mixing"))]
119#[inline(always)]
120fn compute_new_weight(
121 probs: [Prob; 2],
122 weighted_prob: Prob,
123 weights: [i32; 2],
124 index_equal_1: bool,
125 _speed: u8,
126) -> i32 {
127 let index = index_equal_1 as usize;
129 let full_model_sum_p1 = i64::from(weighted_prob);
130 let full_model_total = 1i64 << LOG2_SCALE;
131 let full_model_sum_p0 = full_model_total.wrapping_sub(i64::from(weighted_prob));
132 let n1i = i64::from(probs[index]);
133 let ni = 1i64 << LOG2_SCALE;
134 let error = full_model_total.wrapping_sub(full_model_sum_p1);
135 let wi = i64::from(weights[index]);
136 let efficacy = full_model_total.wrapping_mul(n1i) - full_model_sum_p1.wrapping_mul(ni);
137 let log_geometric_probabilities =
139 64 - (full_model_sum_p1.wrapping_mul(full_model_sum_p0)).leading_zeros();
140 let new_weight_adj = (error.wrapping_mul(efficacy)) >> log_geometric_probabilities;
144 max(1, wi.wrapping_add(new_weight_adj) as i32)
147}