// brotli/enc/stride_eval.rs
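//! Evaluates the cost of coding each block's literals under stride priors of
//! length 1 through 8 in parallel, so the encoder can pick the cheapest stride
//! for every block-type span (see `choose_stride` below).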

use core;

use super::super::alloc;
use super::super::alloc::{Allocator, SliceWrapper, SliceWrapperMut};
use super::backward_references::BrotliEncoderParams;
use super::input_pair::{InputPair, InputReference, InputReferenceMut};
use super::interface;
use super::ir_interpret::{push_base, IRInterpreter};
use super::prior_eval::DEFAULT_SPEED;
use super::util::{floatX, FastLog2u16};
use crate::enc::combined_alloc::{alloc_default, allocate};

const NIBBLE_PRIOR_SIZE: usize = 16;
// two nibble CDFs (high and low) for every (stride byte, context) pair
pub const STRIDE_PRIOR_SIZE: usize = 256 * 256 * NIBBLE_PRIOR_SIZE * 2;

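/// Initializes every 16-entry nibble CDF in `cdfs` to the uniform cumulative
/// distribution 4, 8, ..., 64.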
pub fn local_init_cdfs(cdfs: &mut [u16]) {
    for (index, item) in cdfs.iter_mut().enumerate() {
        *item = 4 + 4 * (index as u16 & 0x0f);
    }
}
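/// Maps a (stride byte, context, optional high nibble) triple to the linear
/// index of a 16-entry nibble CDF. Even indices hold high-nibble CDFs keyed by
/// the full stride byte; odd indices hold low-nibble CDFs keyed by the stride
/// byte's low nibble and the literal's high nibble. Scaled by
/// `NIBBLE_PRIOR_SIZE`, every index stays within `STRIDE_PRIOR_SIZE`.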
#[allow(unused_variables)]
fn stride_lookup_lin(
    stride_byte: u8,
    selected_context: u8,
    actual_context: usize,
    high_nibble: Option<u8>,
) -> usize {
    if let Some(nibble) = high_nibble {
        1 + 2 * (actual_context | ((stride_byte as usize & 0xf) << 8) | ((nibble as usize) << 12))
    } else {
        2 * (actual_context | ((stride_byte as usize) << 8))
    }
}

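/// A mutable view of one 16-entry cumulative distribution over nibble values,
/// borrowed from a stride prior table.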
struct CDF<'a> {
    cdf: &'a mut [u16],
}
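/// Indexing scheme shared by all eight stride prior tables.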
struct Stride1Prior {}
impl Stride1Prior {
    fn lookup_lin(
        stride_byte: u8,
        selected_context: u8,
        actual_context: usize,
        high_nibble: Option<u8>,
    ) -> usize {
        stride_lookup_lin(stride_byte, selected_context, actual_context, high_nibble)
    }
    fn lookup_mut(
        data: &mut [u16],
        stride_byte: u8,
        selected_context: u8,
        actual_context: usize,
        high_nibble: Option<u8>,
    ) -> CDF {
        let index = Self::lookup_lin(stride_byte, selected_context, actual_context, high_nibble)
            * NIBBLE_PRIOR_SIZE;
        CDF::from(data.split_at_mut(index).1.split_at_mut(16).0)
    }
}

impl<'a> CDF<'a> {
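    /// Returns the cost in bits of coding `nibble_u8` under this CDF:
    /// `log2(total) - log2(pdf)`.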
    pub fn cost(&self, nibble_u8: u8) -> floatX {
        assert_eq!(self.cdf.len(), 16);
        let nibble = nibble_u8 as usize & 0xf;
        let mut pdf = self.cdf[nibble];
        if nibble != 0 {
            // test the masked nibble so an out-of-range input cannot underflow the index
            pdf -= self.cdf[nibble - 1];
        }
        FastLog2u16(self.cdf[15]) - FastLog2u16(pdf)
    }
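    /// Adapts the CDF toward `nibble_u8` by `speed.0`; once the running total
    /// reaches `speed.1`, rescales every entry to roughly 3/4 of its value to
    /// keep the counts bounded.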
    pub fn update(&mut self, nibble_u8: u8, speed: (u16, u16)) {
        assert_eq!(self.cdf.len(), 16);
        for nib_range in (nibble_u8 as usize & 0xf)..16 {
            self.cdf[nib_range] += speed.0;
        }
        if self.cdf[15] >= speed.1 {
            const CDF_BIAS: [u16; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
            for nibble_index in 0..16 {
                let tmp = &mut self.cdf[nibble_index];
                *tmp = (tmp.wrapping_add(CDF_BIAS[nibble_index]))
                    .wrapping_sub(tmp.wrapping_add(CDF_BIAS[nibble_index]) >> 2);
            }
        }
    }
}

impl<'a> From<&'a mut [u16]> for CDF<'a> {
    fn from(cdf: &'a mut [u16]) -> CDF<'a> {
        assert_eq!(cdf.len(), 16);
        CDF { cdf }
    }
}

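/// Accumulates, for each block-type span, the literal cost of each candidate
/// stride (1 through 8) alongside the adaptive CDF state used to measure it.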
pub struct StrideEval<
    'a,
    Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX> + 'a,
> {
    input: InputPair<'a>,
    alloc: &'a mut Alloc,
    context_map: &'a interface::PredictionModeContextMap<InputReferenceMut<'a>>,
    block_type: u8,
    local_byte_offset: usize,
    // one CDF table per candidate stride (1 through 8)
    stride_priors: [<Alloc as Allocator<u16>>::AllocatedMemory; 8],
    // eight cost accumulators per block-type epoch
    score: <Alloc as Allocator<floatX>>::AllocatedMemory,
    cur_score_epoch: usize,
    stride_speed: [(u16, u16); 2],
    cur_stride: u8,
}

impl<'a, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX> + 'a>
    StrideEval<'a, Alloc>
{
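    /// Builds an evaluator over `input`, taking the adaptation speeds from the
    /// prediction mode when present, else `params.literal_adaptation`, else
    /// `DEFAULT_SPEED`.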
    pub fn new(
        alloc: &'a mut Alloc,
        input: InputPair<'a>,
        prediction_mode: &'a interface::PredictionModeContextMap<InputReferenceMut<'a>>,
        params: &BrotliEncoderParams,
    ) -> Self {
        let do_alloc = true;
        let mut stride_speed = prediction_mode.stride_context_speed();
        if stride_speed[0] == (0, 0) {
            stride_speed[0] = params.literal_adaptation[0];
        }
        if stride_speed[0] == (0, 0) {
            stride_speed[0] = DEFAULT_SPEED;
        }
        if stride_speed[1] == (0, 0) {
            stride_speed[1] = params.literal_adaptation[1];
        }
        if stride_speed[1] == (0, 0) {
            stride_speed[1] = stride_speed[0];
        }
        let score = if do_alloc {
            allocate::<floatX, _>(alloc, 8 * 4) // FIXME make this bigger than just 4
        } else {
            alloc_default::<floatX, Alloc>()
        };
        let stride_priors = if do_alloc {
            [
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
                allocate::<u16, _>(alloc, STRIDE_PRIOR_SIZE),
            ]
        } else {
            [
                alloc_default::<u16, Alloc>(),
                alloc_default::<u16, Alloc>(),
                alloc_default::<u16, Alloc>(),
                alloc_default::<u16, Alloc>(),
                alloc_default::<u16, Alloc>(),
                alloc_default::<u16, Alloc>(),
                alloc_default::<u16, Alloc>(),
                alloc_default::<u16, Alloc>(),
            ]
        };
        let mut ret = StrideEval::<Alloc> {
            input,
            context_map: prediction_mode,
            block_type: 0,
            alloc,
            cur_stride: 1,
            cur_score_epoch: 0,
            local_byte_offset: 0,
            stride_priors,
            score,
            stride_speed,
        };
        for stride_prior in ret.stride_priors.iter_mut() {
            local_init_cdfs(stride_prior.slice_mut());
        }
        ret
    }
    pub fn alloc(&mut self) -> &mut Alloc {
        self.alloc
    }
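    /// Writes the best stride choice (0-based, i.e. stride minus one) for each
    /// completed epoch into `stride_data`; a candidate must score at least 2.0
    /// bits better than the running best to be selected.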
    pub fn choose_stride(&self, stride_data: &mut [u8]) {
        assert_eq!(stride_data.len(), self.cur_score_epoch);
        assert!(self.score.slice().len() > stride_data.len());
        assert!(self.score.slice().len() > (stride_data.len() << 3) + 7 + 8);
        for (index, choice) in stride_data.iter_mut().enumerate() {
            let choices = self
                .score
                .slice()
                .split_at((1 + index) << 3)
                .1
                .split_at(8)
                .0;
            let mut best_choice: u8 = 0;
            let mut best_score = choices[0];
            for (cur_index, cur_score) in choices.iter().enumerate() {
                if *cur_score + 2.0 < best_score {
                    // needs to be 2 bits better to be worth the type switch
                    best_score = *cur_score;
                    best_choice = cur_index as u8;
                }
            }
            *choice = best_choice;
        }
    }
    pub fn num_types(&self) -> usize {
        self.cur_score_epoch
    }
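    /// Charges the literal's high nibble, then its low nibble, against all
    /// eight stride priors, accumulating the bit costs in this epoch's score
    /// slots and adapting each CDF afterward.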
    fn update_cost_base(
        &mut self,
        stride_prior: [u8; 8],
        selected_bits: u8,
        cm_prior: usize,
        literal: u8,
    ) {
        type CurPrior = Stride1Prior;
        // high nibble, coded without a high-nibble context
        for i in 0..8 {
            let mut cdf = CurPrior::lookup_mut(
                self.stride_priors[i].slice_mut(),
                stride_prior[i],
                selected_bits,
                cm_prior,
                None,
            );
            self.score.slice_mut()[self.cur_score_epoch * 8 + i] += cdf.cost(literal >> 4);
            cdf.update(literal >> 4, self.stride_speed[1]);
        }
        // low nibble, conditioned on the high nibble
        for i in 0..8 {
            let mut cdf = CurPrior::lookup_mut(
                self.stride_priors[i].slice_mut(),
                stride_prior[i],
                selected_bits,
                cm_prior,
                Some(literal >> 4),
            );
            self.score.slice_mut()[self.cur_score_epoch * 8 + i] += cdf.cost(literal & 0xf);
            cdf.update(literal & 0xf, self.stride_speed[0]);
        }
    }
}
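/// Returns the score buffer and the eight prior tables to the allocator.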
impl<'a, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>> Drop
    for StrideEval<'a, Alloc>
{
    fn drop(&mut self) {
        <Alloc as Allocator<floatX>>::free_cell(self.alloc, core::mem::take(&mut self.score));
        for i in 0..8 {
            <Alloc as Allocator<u16>>::free_cell(
                self.alloc,
                core::mem::take(&mut self.stride_priors[i]),
            );
        }
    }
}

impl<'a, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
    IRInterpreter for StrideEval<'a, Alloc>
{
    fn inc_local_byte_offset(&mut self, inc: usize) {
        self.local_byte_offset += inc;
    }
    fn local_byte_offset(&self) -> usize {
        self.local_byte_offset
    }
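    /// Starts a new scoring epoch for `new_type`, doubling the score buffer
    /// whenever the next epoch's eight slots would no longer fit.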
    fn update_block_type(&mut self, new_type: u8, stride: u8) {
        self.block_type = new_type;
        self.cur_stride = stride;
        self.cur_score_epoch += 1;
        if self.cur_score_epoch * 8 + 7 >= self.score.slice().len() {
            let new_len = self.score.slice().len() * 2;
            let mut new_score = allocate::<floatX, _>(self.alloc, new_len);
            for (src, dst) in self.score.slice().iter().zip(
                new_score
                    .slice_mut()
                    .split_at_mut(self.score.slice().len())
                    .0
                    .iter_mut(),
            ) {
                *dst = *src;
            }
            <Alloc as Allocator<floatX>>::free_cell(
                self.alloc,
                core::mem::replace(&mut self.score, new_score),
            );
        }
    }
    fn block_type(&self) -> u8 {
        self.block_type
    }
    fn literal_data_at_offset(&self, index: usize) -> u8 {
        self.input[index]
    }
    fn literal_context_map(&self) -> &[u8] {
        self.context_map.literal_context_map.slice()
    }
    fn prediction_mode(&self) -> crate::interface::LiteralPredictionModeNibble {
        self.context_map.literal_prediction_mode()
    }
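    /// Rotates the ring buffer of preceding bytes so that entry `i` holds the
    /// byte `i + 1` positions before the literal (the predictor for stride
    /// `i + 1`), then delegates to `update_cost_base`.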
    fn update_cost(
        &mut self,
        stride_prior: [u8; 8],
        stride_prior_offset: usize,
        selected_bits: u8,
        cm_prior: usize,
        literal: u8,
    ) {
        let reversed_stride_priors = [
            stride_prior[stride_prior_offset & 7],
            stride_prior[stride_prior_offset.wrapping_sub(1) & 7],
            stride_prior[stride_prior_offset.wrapping_sub(2) & 7],
            stride_prior[stride_prior_offset.wrapping_sub(3) & 7],
            stride_prior[stride_prior_offset.wrapping_sub(4) & 7],
            stride_prior[stride_prior_offset.wrapping_sub(5) & 7],
            stride_prior[stride_prior_offset.wrapping_sub(6) & 7],
            stride_prior[stride_prior_offset.wrapping_sub(7) & 7],
        ];
        self.update_cost_base(reversed_stride_priors, selected_bits, cm_prior, literal)
    }
}

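/// Feeds incoming IR commands through the shared `push_base` driver, which
/// calls back into the `IRInterpreter` methods above.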
impl<'a, 'b, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
    interface::CommandProcessor<'b> for StrideEval<'a, Alloc>
{
    fn push(&mut self, val: interface::Command<InputReference<'b>>) {
        push_base(self, val)
    }
}