1use core;
2
3use super::super::alloc;
4use super::super::alloc::{Allocator, SliceWrapper, SliceWrapperMut};
5use super::input_pair::{InputPair, InputReference, InputReferenceMut};
6pub use super::ir_interpret::{push_base, Context, IRInterpreter};
7use super::util::{floatX, FastLog2u16};
8use super::weights::{Weights, BLEND_FIXED_POINT_PRECISION};
9use super::{find_stride, interface};
10use crate::enc::combined_alloc::alloc_if;
11
/// Candidate index whose context-map CDF column `update_cost_base` extracts as the
/// provisional mixing CDF handed to `compute_combined_cost`.
const DEFAULT_CM_SPEED_INDEX: usize = 8;
/// Number of (speed, max) candidates that are tracked side by side in every CDF bundle.
const NUM_SPEEDS_TO_TRY: usize = 16;
/// Per-candidate increment that `update_cdf` adds to every CDF entry at or above the
/// observed nibble.
const SPEEDS_TO_SEARCH: [u16; NUM_SPEEDS_TO_TRY] = [
    0, 1, 1, 1, 2, 4, 8, 16, 16, 32, 64, 128, 128, 512, 1664, 1664,
];
/// Per-candidate ceiling: once a candidate's last CDF entry reaches this value,
/// `update_cdf` scales that candidate's whole column back down.
const MAXES_TO_SEARCH: [u16; NUM_SPEEDS_TO_TRY] = [
    32, 32, 128, 16384, 1024, 1024, 8192, 48, 8192, 4096, 16384, 256, 16384, 16384, 16384, 16384,
];
/// Length, in `u16` entries, of one CDF bundle: 16 nibble rows of `NUM_SPEEDS_TO_TRY`
/// candidates each.
const NIBBLE_PRIOR_SIZE: usize = 16 * NUM_SPEEDS_TO_TRY;
/// Length of the context-map prior table. `get_cm_cdf_high` / `get_cm_cdf_low` address
/// 17 bundles per `cm_prior` (one high-nibble bundle plus one low-nibble bundle per high
/// nibble); the factor 256 is presumably the number of `cm_prior` values -- confirm
/// against the caller.
const CONTEXT_MAP_PRIOR_SIZE: usize = 256 * NIBBLE_PRIOR_SIZE * 17;
/// Length of the stride prior table. `get_stride_cdf_high` / `get_stride_cdf_low`
/// interleave high (even) and low (odd) bundles over a 16-bit combined prior, which
/// gives `2 * 256 * 256` bundles.
const STRIDE_PRIOR_SIZE: usize = 256 * 256 * NIBBLE_PRIOR_SIZE * 2;
/// A (speed, max) candidate pair. Within this file it is always built from matching
/// entries of `SPEEDS_TO_SEARCH` (field `0`) and `MAXES_TO_SEARCH` (field `1`).
#[derive(Clone, Copy, Debug)]
pub struct SpeedAndMax(pub u16, pub u16);
26
27pub fn speed_to_tuple(inp: [SpeedAndMax; 2]) -> [(u16, u16); 2] {
28 [(inp[0].0, inp[0].1), (inp[1].0, inp[1].1)]
29}
30
31fn get_stride_cdf_low(
32 data: &mut [u16],
33 stride_prior: u8,
34 cm_prior: usize,
35 high_nibble: u8,
36) -> &mut [u16] {
37 let index: usize =
38 1 + 2 * (cm_prior | ((stride_prior as usize & 0xf) << 8) | ((high_nibble as usize) << 12));
39 data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
40 .1
41 .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
42 .0
43}
44
45fn get_stride_cdf_high(data: &mut [u16], stride_prior: u8, cm_prior: usize) -> &mut [u16] {
46 let index: usize = 2 * (cm_prior | ((stride_prior as usize) << 8));
47 data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
48 .1
49 .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
50 .0
51}
52
53fn get_cm_cdf_low(data: &mut [u16], cm_prior: usize, high_nibble: u8) -> &mut [u16] {
54 let index: usize = (high_nibble as usize + 1) + 17 * cm_prior;
55 data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
56 .1
57 .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
58 .0
59}
60
61fn get_cm_cdf_high(data: &mut [u16], cm_prior: usize) -> &mut [u16] {
62 let index: usize = 17 * cm_prior;
63 data.split_at_mut((NUM_SPEEDS_TO_TRY * index) << 4)
64 .1
65 .split_at_mut(16 * NUM_SPEEDS_TO_TRY)
66 .0
67}
68fn init_cdfs(cdfs: &mut [u16]) {
69 assert_eq!(cdfs.len() % (16 * NUM_SPEEDS_TO_TRY), 0);
70 let mut total_index = 0usize;
71 let len = cdfs.len();
72 loop {
73 for cdf_index in 0..16 {
74 let vec = cdfs
75 .split_at_mut(total_index)
76 .1
77 .split_at_mut(NUM_SPEEDS_TO_TRY)
78 .0;
79 for item in vec {
80 *item = 4 + 4 * cdf_index as u16;
81 }
82 total_index += NUM_SPEEDS_TO_TRY;
83 }
84 if total_index == len {
85 break;
86 }
87 }
88}
89fn compute_combined_cost(
90 singleton_cost: &mut [floatX; NUM_SPEEDS_TO_TRY],
91 cdfs: &[u16],
92 mixing_cdf: [u16; 16],
93 nibble_u8: u8,
94 _weights: &mut [Weights; NUM_SPEEDS_TO_TRY],
95) {
96 assert_eq!(cdfs.len(), 16 * NUM_SPEEDS_TO_TRY);
97 let nibble = nibble_u8 as usize & 0xf;
98 let mut stride_pdf = [0u16; NUM_SPEEDS_TO_TRY];
99 stride_pdf.clone_from_slice(
100 cdfs.split_at(NUM_SPEEDS_TO_TRY * nibble)
101 .1
102 .split_at(NUM_SPEEDS_TO_TRY)
103 .0,
104 );
105 let mut cm_pdf: u16 = mixing_cdf[nibble];
106 if nibble_u8 != 0 {
107 let mut tmp = [0u16; NUM_SPEEDS_TO_TRY];
108 tmp.clone_from_slice(
109 cdfs.split_at(NUM_SPEEDS_TO_TRY * (nibble - 1))
110 .1
111 .split_at(NUM_SPEEDS_TO_TRY)
112 .0,
113 );
114 for i in 0..NUM_SPEEDS_TO_TRY {
115 stride_pdf[i] -= tmp[i];
116 }
117 cm_pdf -= mixing_cdf[nibble - 1]
118 }
119 let mut stride_max = [0u16; NUM_SPEEDS_TO_TRY];
120 stride_max.clone_from_slice(cdfs.split_at(NUM_SPEEDS_TO_TRY * 15).1);
121 let cm_max = mixing_cdf[15];
122 for i in 0..NUM_SPEEDS_TO_TRY {
123 if stride_pdf[i] == 0 {
124 assert_ne!(stride_pdf[i], 0);
125 }
126 if stride_max[i] == 0 {
127 assert_ne!(stride_max[i], 0);
128 }
129
130 let w = (1 << (BLEND_FIXED_POINT_PRECISION - 2)); let combined_pdf = w * u32::from(stride_pdf[i])
132 + ((1 << BLEND_FIXED_POINT_PRECISION) - w) * u32::from(cm_pdf);
133 let combined_max = w * u32::from(stride_max[i])
134 + ((1 << BLEND_FIXED_POINT_PRECISION) - w) * u32::from(cm_max);
135 let del = FastLog2u16((combined_pdf >> BLEND_FIXED_POINT_PRECISION) as u16)
136 - FastLog2u16((combined_max >> BLEND_FIXED_POINT_PRECISION) as u16);
137 singleton_cost[i] -= del;
138 }
139}
140fn compute_cost(singleton_cost: &mut [floatX; NUM_SPEEDS_TO_TRY], cdfs: &[u16], nibble_u8: u8) {
141 assert_eq!(cdfs.len(), 16 * NUM_SPEEDS_TO_TRY);
142 let nibble = nibble_u8 as usize & 0xf;
143 let mut pdf = [0u16; NUM_SPEEDS_TO_TRY];
144 pdf.clone_from_slice(
145 cdfs.split_at(NUM_SPEEDS_TO_TRY * nibble)
146 .1
147 .split_at(NUM_SPEEDS_TO_TRY)
148 .0,
149 );
150 if nibble_u8 != 0 {
151 let mut tmp = [0u16; NUM_SPEEDS_TO_TRY];
152 tmp.clone_from_slice(
153 cdfs.split_at(NUM_SPEEDS_TO_TRY * (nibble - 1))
154 .1
155 .split_at(NUM_SPEEDS_TO_TRY)
156 .0,
157 );
158 for i in 0..NUM_SPEEDS_TO_TRY {
159 pdf[i] -= tmp[i];
160 }
161 }
162 let mut max = [0u16; NUM_SPEEDS_TO_TRY];
163 max.clone_from_slice(cdfs.split_at(NUM_SPEEDS_TO_TRY * 15).1);
164 for i in 0..NUM_SPEEDS_TO_TRY {
165 if pdf[i] == 0 {
166 assert_ne!(pdf[i], 0);
167 }
168 if max[i] == 0 {
169 assert_ne!(max[i], 0);
170 }
171 let del = FastLog2u16(pdf[i]) - FastLog2u16(max[i]);
172 singleton_cost[i] -= del;
173 }
174}
175fn update_cdf(cdfs: &mut [u16], nibble_u8: u8) {
176 assert_eq!(cdfs.len(), 16 * NUM_SPEEDS_TO_TRY);
177 let mut overall_index = nibble_u8 as usize * NUM_SPEEDS_TO_TRY;
178 for _nibble in (nibble_u8 as usize & 0xf)..16 {
179 for speed_index in 0..NUM_SPEEDS_TO_TRY {
180 cdfs[overall_index + speed_index] += SPEEDS_TO_SEARCH[speed_index];
181 }
182 overall_index += NUM_SPEEDS_TO_TRY;
183 }
184 overall_index = 0;
185 for nibble in 0..16 {
186 for speed_index in 0..NUM_SPEEDS_TO_TRY {
187 if nibble == 0 {
188 assert_ne!(cdfs[overall_index + speed_index], 0);
189 } else {
190 assert_ne!(
191 cdfs[overall_index + speed_index]
192 - cdfs[overall_index + speed_index - NUM_SPEEDS_TO_TRY],
193 0
194 );
195 }
196 }
197 overall_index += NUM_SPEEDS_TO_TRY;
198 }
199 for max_index in 0..NUM_SPEEDS_TO_TRY {
200 if cdfs[15 * NUM_SPEEDS_TO_TRY + max_index] >= MAXES_TO_SEARCH[max_index] {
201 const CDF_BIAS: [u16; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
202 for nibble_index in 0..16 {
203 let tmp = &mut cdfs[nibble_index * NUM_SPEEDS_TO_TRY + max_index];
204 *tmp = (tmp.wrapping_add(CDF_BIAS[nibble_index]))
205 .wrapping_sub(tmp.wrapping_add(CDF_BIAS[nibble_index]) >> 2);
206 }
207 }
208 }
209 overall_index = 0;
210 for nibble in 0..16 {
211 for speed_index in 0..NUM_SPEEDS_TO_TRY {
212 if nibble == 0 {
213 assert_ne!(cdfs[overall_index + speed_index], 0);
214 } else {
215 assert_ne!(
216 cdfs[overall_index + speed_index]
217 - cdfs[overall_index + speed_index - NUM_SPEEDS_TO_TRY],
218 0
219 );
220 }
221 }
222 overall_index += NUM_SPEEDS_TO_TRY;
223 }
224}
225
226fn extract_single_cdf(cdf_bundle: &[u16], index: usize) -> [u16; 16] {
227 assert_eq!(cdf_bundle.len(), 16 * NUM_SPEEDS_TO_TRY);
228 assert!(index < NUM_SPEEDS_TO_TRY);
229
230 #[allow(clippy::identity_op)]
231 [
232 cdf_bundle[index + 0 * NUM_SPEEDS_TO_TRY],
233 cdf_bundle[index + 1 * NUM_SPEEDS_TO_TRY],
234 cdf_bundle[index + 2 * NUM_SPEEDS_TO_TRY],
235 cdf_bundle[index + 3 * NUM_SPEEDS_TO_TRY],
236 cdf_bundle[index + 4 * NUM_SPEEDS_TO_TRY],
237 cdf_bundle[index + 5 * NUM_SPEEDS_TO_TRY],
238 cdf_bundle[index + 6 * NUM_SPEEDS_TO_TRY],
239 cdf_bundle[index + 7 * NUM_SPEEDS_TO_TRY],
240 cdf_bundle[index + 8 * NUM_SPEEDS_TO_TRY],
241 cdf_bundle[index + 9 * NUM_SPEEDS_TO_TRY],
242 cdf_bundle[index + 10 * NUM_SPEEDS_TO_TRY],
243 cdf_bundle[index + 11 * NUM_SPEEDS_TO_TRY],
244 cdf_bundle[index + 12 * NUM_SPEEDS_TO_TRY],
245 cdf_bundle[index + 13 * NUM_SPEEDS_TO_TRY],
246 cdf_bundle[index + 14 * NUM_SPEEDS_TO_TRY],
247 cdf_bundle[index + 15 * NUM_SPEEDS_TO_TRY],
248 ]
249}
250
251fn min_cost_index_for_speed(cost: &[floatX]) -> usize {
252 assert_eq!(cost.len(), NUM_SPEEDS_TO_TRY);
253 let mut min_cost = cost[0];
254 let mut best_choice = 0;
255 for i in 1..NUM_SPEEDS_TO_TRY {
256 if cost[i] < min_cost {
257 best_choice = i;
258 min_cost = cost[i];
259 }
260 }
261 best_choice
262}
263fn min_cost_speed_max(cost: &[floatX]) -> SpeedAndMax {
264 let best_choice = min_cost_index_for_speed(cost);
265 SpeedAndMax(SPEEDS_TO_SEARCH[best_choice], MAXES_TO_SEARCH[best_choice])
266}
267
268fn min_cost_value(cost: &[floatX]) -> floatX {
269 let best_choice = min_cost_index_for_speed(cost);
270 cost[best_choice]
271}
272
// First index into `ContextMapEntropy::singleton_costs`, selecting which model the
// accumulated costs belong to.
/// Costs of the blended stride + context-map model (`compute_combined_cost`).
const SINGLETON_COMBINED_STRATEGY: usize = 2;
/// Costs of the stride model alone.
const SINGLETON_STRIDE_STRATEGY: usize = 1;
/// Costs of the context-map model alone.
const SINGLETON_CM_STRATEGY: usize = 0;
276
/// Accumulates, for every (speed, max) candidate, the cost of coding the literals it is
/// fed under three models: context map only, stride only, and a blend of both.
pub struct ContextMapEntropy<
    'a,
    Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>,
> {
    /// Source bytes; `literal_data_at_offset` indexes into it.
    input: InputPair<'a>,
    /// Prediction-mode / context-map buffers; handed back by `take_prediction_mode`.
    context_map: interface::PredictionModeContextMap<InputReferenceMut<'a>>,
    /// Block type most recently set through `update_block_type`.
    block_type: u8,
    /// Stride most recently set through `update_block_type`; `update_cost` uses it to
    /// pick the stride prior byte.
    cur_stride: u8,
    /// Running byte offset maintained through `inc_local_byte_offset`.
    local_byte_offset: usize,
    /// Per-candidate blend weights, index 1 for the high nibble and 0 for the low nibble.
    /// They are passed to `compute_combined_cost`, which currently ignores them.
    weight: [[Weights; NUM_SPEEDS_TO_TRY]; 2],

    /// CDF bundles of the context-map model (`CONTEXT_MAP_PRIOR_SIZE` entries when
    /// CDF detection is enabled).
    cm_priors: <Alloc as Allocator<u16>>::AllocatedMemory,
    /// CDF bundles of the stride model (`STRIDE_PRIOR_SIZE` entries when CDF detection
    /// is enabled).
    stride_priors: <Alloc as Allocator<u16>>::AllocatedMemory,
    /// Stride array supplied to `new`; stored but not read by the code in this file
    /// section.
    _stride_pyramid_leaves: [u8; find_stride::NUM_LEAF_NODES],
    /// Accumulated costs indexed `[strategy][nibble half][candidate]`, where strategy is
    /// one of the `SINGLETON_*_STRATEGY` constants and nibble half is 1 for the high
    /// nibble and 0 for the low nibble.
    singleton_costs: [[[floatX; NUM_SPEEDS_TO_TRY]; 2]; 3],
}
293impl<'a, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
294 ContextMapEntropy<'a, Alloc>
295{
296 pub fn new(
297 m16: &mut Alloc,
298 input: InputPair<'a>,
299 stride: [u8; find_stride::NUM_LEAF_NODES],
300 prediction_mode: interface::PredictionModeContextMap<InputReferenceMut<'a>>,
301 cdf_detection_quality: u8,
302 ) -> Self {
303 let cdf_detect = cdf_detection_quality != 0;
304 let mut ret = ContextMapEntropy::<Alloc> {
305 input,
306 context_map: prediction_mode,
307 block_type: 0,
308 cur_stride: 1,
309 local_byte_offset: 0,
310 cm_priors: alloc_if::<u16, _>(cdf_detect, m16, CONTEXT_MAP_PRIOR_SIZE),
311 stride_priors: alloc_if::<u16, _>(cdf_detect, m16, STRIDE_PRIOR_SIZE),
312 _stride_pyramid_leaves: stride,
313 weight: [
314 [Weights::new(); NUM_SPEEDS_TO_TRY],
315 [Weights::new(); NUM_SPEEDS_TO_TRY],
316 ],
317 singleton_costs: [[[0.0; NUM_SPEEDS_TO_TRY]; 2]; 3],
318 };
319 if cdf_detect {
320 init_cdfs(ret.cm_priors.slice_mut());
321 init_cdfs(ret.stride_priors.slice_mut());
322 }
323 ret
324 }
325 pub fn take_prediction_mode(
326 &mut self,
327 ) -> interface::PredictionModeContextMap<InputReferenceMut<'a>> {
328 core::mem::replace(
329 &mut self.context_map,
330 interface::PredictionModeContextMap::<InputReferenceMut<'a>> {
331 literal_context_map: InputReferenceMut::default(),
332 predmode_speed_and_distance_context_map: InputReferenceMut::default(),
333 },
334 )
335 }
336 pub fn prediction_mode_mut(
337 &mut self,
338 ) -> &mut interface::PredictionModeContextMap<InputReferenceMut<'a>> {
339 &mut self.context_map
340 }
341 pub fn best_singleton_speeds(
342 &self,
343 cm: bool,
344 combined: bool,
345 ) -> ([SpeedAndMax; 2], [floatX; 2]) {
346 let cost_type_index = if combined {
347 2usize
348 } else if cm {
349 0usize
350 } else {
351 1
352 };
353 let mut ret_cost = [
354 self.singleton_costs[cost_type_index][0][0],
355 self.singleton_costs[cost_type_index][1][0],
356 ];
357 let mut best_indexes = [0, 0];
358 for speed_index in 1..NUM_SPEEDS_TO_TRY {
359 for highness in 0..2 {
360 let cur_cost = self.singleton_costs[cost_type_index][highness][speed_index];
361 if cur_cost < ret_cost[highness] {
362 best_indexes[highness] = speed_index;
363 ret_cost[highness] = cur_cost;
364 }
365 }
366 }
367 let ret_speed = [
368 SpeedAndMax(
369 SPEEDS_TO_SEARCH[best_indexes[0]],
370 MAXES_TO_SEARCH[best_indexes[0]],
371 ),
372 SpeedAndMax(
373 SPEEDS_TO_SEARCH[best_indexes[1]],
374 MAXES_TO_SEARCH[best_indexes[1]],
375 ),
376 ];
377 (ret_speed, ret_cost)
378 }
379 pub fn best_speeds(
380 &mut self, cm: bool,
382 combined: bool,
383 ) -> [SpeedAndMax; 2] {
384 let mut ret = [SpeedAndMax(SPEEDS_TO_SEARCH[0], MAXES_TO_SEARCH[0]); 2];
385 let cost_type_index = if combined {
386 2usize
387 } else if cm {
388 0usize
389 } else {
390 1
391 };
392 for high in 0..2 {
393 ret[high] = min_cost_speed_max(&self.singleton_costs[cost_type_index][high][..]);
398 }
399 ret
400 }
401 pub fn best_speeds_costs(
402 &mut self, cm: bool,
404 combined: bool,
405 ) -> [floatX; 2] {
406 let cost_type_index = if combined {
407 2usize
408 } else if cm {
409 0usize
410 } else {
411 1
412 };
413 let mut ret = [0.0; 2];
414 for high in 0..2 {
415 ret[high] = min_cost_value(&self.singleton_costs[cost_type_index][high][..]);
416 }
417 ret
418 }
419 pub fn free(&mut self, alloc: &mut Alloc) {
420 <Alloc as Allocator<u16>>::free_cell(alloc, core::mem::take(&mut self.cm_priors));
421 <Alloc as Allocator<u16>>::free_cell(alloc, core::mem::take(&mut self.stride_priors));
422 }
423 fn update_cost_base(
424 &mut self,
425 stride_prior: u8,
426 _selected_bits: u8,
427 cm_prior: usize,
428 literal: u8,
429 ) {
430 let upper_nibble = (literal >> 4);
431 let lower_nibble = literal & 0xf;
432 let provisional_cm_high_cdf: [u16; 16];
433 let provisional_cm_low_cdf: [u16; 16];
434 {
435 let cm_cdf_high = get_cm_cdf_high(self.cm_priors.slice_mut(), cm_prior);
436 compute_cost(
437 &mut self.singleton_costs[SINGLETON_CM_STRATEGY][1],
438 cm_cdf_high,
439 upper_nibble,
440 );
441 let best_cm_index = DEFAULT_CM_SPEED_INDEX; provisional_cm_high_cdf = extract_single_cdf(cm_cdf_high, best_cm_index);
444 }
445 {
446 let cm_cdf_low = get_cm_cdf_low(self.cm_priors.slice_mut(), cm_prior, upper_nibble);
447 compute_cost(
448 &mut self.singleton_costs[SINGLETON_CM_STRATEGY][0],
449 cm_cdf_low,
450 lower_nibble,
451 );
452 let best_cm_index = DEFAULT_CM_SPEED_INDEX; provisional_cm_low_cdf = extract_single_cdf(cm_cdf_low, best_cm_index);
455 }
456 {
457 let stride_cdf_high =
458 get_stride_cdf_high(self.stride_priors.slice_mut(), stride_prior, cm_prior);
459 compute_combined_cost(
460 &mut self.singleton_costs[SINGLETON_COMBINED_STRATEGY][1],
461 stride_cdf_high,
462 provisional_cm_high_cdf,
463 upper_nibble,
464 &mut self.weight[1],
465 );
466 compute_cost(
467 &mut self.singleton_costs[SINGLETON_STRIDE_STRATEGY][1],
468 stride_cdf_high,
469 upper_nibble,
470 );
471 update_cdf(stride_cdf_high, upper_nibble);
472 }
473 {
474 let stride_cdf_low = get_stride_cdf_low(
475 self.stride_priors.slice_mut(),
476 stride_prior,
477 cm_prior,
478 upper_nibble,
479 );
480 compute_combined_cost(
481 &mut self.singleton_costs[SINGLETON_COMBINED_STRATEGY][0],
482 stride_cdf_low,
483 provisional_cm_low_cdf,
484 lower_nibble,
485 &mut self.weight[0],
486 );
487 compute_cost(
488 &mut self.singleton_costs[SINGLETON_STRIDE_STRATEGY][0],
489 stride_cdf_low,
490 lower_nibble,
491 );
492 update_cdf(stride_cdf_low, lower_nibble);
493 }
494 {
495 let cm_cdf_high = get_cm_cdf_high(self.cm_priors.slice_mut(), cm_prior);
496 update_cdf(cm_cdf_high, upper_nibble);
497 }
498 {
499 let cm_cdf_low = get_cm_cdf_low(self.cm_priors.slice_mut(), cm_prior, upper_nibble);
500 update_cdf(cm_cdf_low, lower_nibble);
501 }
502 }
503}
504
505impl<'a, 'b, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
506 interface::CommandProcessor<'b> for ContextMapEntropy<'a, Alloc>
507{
508 fn push(&mut self, val: interface::Command<InputReference<'b>>) {
509 push_base(self, val)
510 }
511}
512
513impl<'a, Alloc: alloc::Allocator<u16> + alloc::Allocator<u32> + alloc::Allocator<floatX>>
514 IRInterpreter for ContextMapEntropy<'a, Alloc>
515{
516 fn inc_local_byte_offset(&mut self, inc: usize) {
517 self.local_byte_offset += inc;
518 }
519 fn local_byte_offset(&self) -> usize {
520 self.local_byte_offset
521 }
522 fn update_block_type(&mut self, new_type: u8, stride: u8) {
523 self.block_type = new_type;
524 self.cur_stride = stride;
525 }
526 fn block_type(&self) -> u8 {
527 self.block_type
528 }
529 fn literal_data_at_offset(&self, index: usize) -> u8 {
530 self.input[index]
531 }
532 fn literal_context_map(&self) -> &[u8] {
533 self.context_map.literal_context_map.slice()
534 }
535 fn prediction_mode(&self) -> crate::interface::LiteralPredictionModeNibble {
536 self.context_map.literal_prediction_mode()
537 }
538 fn update_cost(
539 &mut self,
540 stride_prior: [u8; 8],
541 stride_prior_offset: usize,
542 selected_bits: u8,
543 cm_prior: usize,
544 literal: u8,
545 ) {
546 let stride = self.cur_stride as usize;
547 self.update_cost_base(
548 stride_prior[stride_prior_offset.wrapping_sub(stride) & 7],
549 selected_bits,
550 cm_prior,
551 literal,
552 )
553 }
554}