1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/// Macro that generates a packing function taking the number of bits as a const generic
///
/// `pack_impl!(T, BYTES, BITS)` emits `pub fn pack::<NUM_BITS>`, which packs the `BITS`
/// values of an `&[T; BITS]` into `NUM_BITS` little-endian `T` words (`NUM_BITS * BYTES`
/// bytes), least-significant bit first. `BYTES` must be the size of `T` in bytes and `BITS`
/// its width in bits.
macro_rules! pack_impl {
    ($t:ty, $bytes:literal, $bits:tt) => {
        /// Pack `input` into `output` using `NUM_BITS` bits per value.
        ///
        /// Only the low `NUM_BITS` bits of each input value are stored. The first `NUM_BITS`
        /// element-sized words of `output` are overwritten and any bytes after them are left
        /// untouched; when `NUM_BITS == 0` the whole of `output` is zeroed instead.
        ///
        /// # Panics
        ///
        /// Panics if `NUM_BITS` exceeds the bit width of the element type, or if `output` is
        /// shorter than `NUM_BITS` element-sized words.
        pub fn pack<const NUM_BITS: usize>(input: &[$t; $bits], output: &mut [u8]) {
            // ORs the little-endian bytes of `bits` into the `word`-th element-sized slot
            // of `output`.
            fn or_word(output: &mut [u8], word: usize, bits: $t) {
                let dst = &mut output[word * $bytes..(word + 1) * $bytes];
                for (d, s) in dst.iter_mut().zip(bits.to_le_bytes().iter()) {
                    *d |= *s;
                }
            }

            if NUM_BITS == 0 {
                output.fill(0);
                return;
            }
            assert!(NUM_BITS <= $bytes * 8);
            assert!(output.len() >= NUM_BITS * $bytes);

            // Words are assembled with `|=`, so clear the destination first instead of
            // relying on the caller to pass a zeroed buffer.
            let output = &mut output[..NUM_BITS * $bytes];
            output.fill(0);

            // Low `NUM_BITS` bits set. Shifting `MAX` right (rather than computing
            // `(1 << NUM_BITS) - 1`) stays in range when NUM_BITS equals the full width.
            let mask: $t = <$t>::MAX >> ($bits - NUM_BITS);

            for (i, &value) in input.iter().enumerate() {
                // Truncate once so that both halves of a value straddling a word boundary
                // carry only this value's bits and never spill into its neighbour.
                let value = value & mask;

                let start_bit = i * NUM_BITS;
                let end_bit = start_bit + NUM_BITS;
                let start_bit_offset = start_bit % $bits;
                let end_bit_offset = end_bit % $bits;
                let start_word = start_bit / $bits;
                let end_word = end_bit / $bits;

                // Low part (or the whole value) goes into the starting word; bits shifted
                // past the top of the word are discarded by `<<`.
                or_word(output, start_word, value << start_bit_offset);

                if start_word != end_word && end_bit_offset != 0 {
                    // The top `end_bit_offset` bits continue at the bottom of the next word.
                    or_word(output, end_word, value >> (NUM_BITS - end_bit_offset));
                }
            }
        }
    };
}

/// Macro that generates pack functions that accept num_bits as a parameter
///
/// `pack!(name, T, BYTES, BITS)` emits a private module `name` containing the const-generic
/// `pack::<NUM_BITS>` produced by `pack_impl!`, plus a public `fn name` that dispatches a
/// runtime `num_bits` to the matching instantiation.
macro_rules! pack {
    ($name:ident, $t:ty, $bytes:literal, $bits:tt) => {
        mod $name {
            pack_impl!($t, $bytes, $bits);
        }

        /// Pack unpacked `input` into `output` with a bit width of `num_bits`
        ///
        /// # Panics
        ///
        /// Panics if `num_bits` is greater than the bit width of the element type, or if
        /// `output` is too short to hold `num_bits` packed element-sized words.
        pub fn $name(input: &[$t; $bits], output: &mut [u8], num_bits: usize) {
            // An out-of-range width is a caller error, so check it as a precondition rather
            // than letting the dispatch below fall through to `unreachable!`.
            assert!(num_bits <= $bits, "invalid num_bits {}", num_bits);

            // One branch per supported width; expected to be optimised into a jump table
            seq_macro::seq!(i in 0..=$bits {
                if i == num_bits {
                    return $name::pack::<i>(input, output);
                }
            });
            unreachable!("num_bits {} passed the range check but matched no width", num_bits);
        }
    };
}

// One packer per supported element type, as `pack!(fn name, element type, bytes per
// element, bits per element)`. Each generated function packs one array of `bits` values,
// writing `num_bits * bytes` bytes of output.
pack!(pack8, u8, 1, 8);
pack!(pack16, u16, 2, 16);
pack!(pack32, u32, 4, 32);
pack!(pack64, u64, 8, 64);

#[cfg(test)]
mod tests {
    use super::super::unpack::*;
    use super::*;

    #[test]
    fn test_basic() {
        let input = [0u16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        for num_bits in 4..16 {
            let mut output = [0u8; 16 * 2];
            pack16(&input, &mut output, num_bits);
            let mut other = [0u16; 16];
            unpack16(&output, &mut other, num_bits);
            assert_eq!(other, input);
        }
    }

    #[test]
    fn test_u32() {
        let input = [
            0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0u32, 1, 2, 3, 4, 5, 6, 7, 8,
            9, 10, 11, 12, 13, 14, 15,
        ];
        for num_bits in 4..32 {
            let mut output = [0u8; 32 * 4];
            pack32(&input, &mut output, num_bits);
            let mut other = [0u32; 32];
            unpack32(&output, &mut other, num_bits);
            assert_eq!(other, input);
        }
    }

    /// Reference packer: writes the low `num_bits` bits of each value, least-significant bit
    /// first, into one continuous bit stream in which stream bit `g` is bit `g % 8` of byte
    /// `g / 8`.
    ///
    /// Packing into little-endian words whose width is a multiple of 8 yields exactly this
    /// layout, so it can check every packer regardless of element type.
    fn reference_pack(input: &[u64], num_bits: usize, output: &mut [u8]) {
        for (i, &value) in input.iter().enumerate() {
            for bit in 0..num_bits {
                if (value >> bit) & 1 == 1 {
                    let g = i * num_bits + bit;
                    output[g / 8] |= 1u8 << (g % 8);
                }
            }
        }
    }

    /// Compares `$pack` against `reference_pack` for every width `0..=$bits`, using inputs
    /// already truncated to that width.
    macro_rules! check_against_reference {
        ($pack:ident, $t:ty, $bits:literal) => {
            for num_bits in 0usize..=$bits {
                let mask: $t = if num_bits == 0 {
                    0
                } else {
                    <$t>::MAX >> ($bits - num_bits)
                };
                let mut input = [0 as $t; $bits];
                for (i, v) in input.iter_mut().enumerate() {
                    // Mix of patterns so both narrow and wide widths see non-trivial data.
                    *v = ((i as $t).wrapping_mul(0x9D) ^ (<$t>::MAX >> i)) & mask;
                }

                let mut actual = [0u8; $bits * std::mem::size_of::<$t>()];
                $pack(&input, &mut actual, num_bits);

                let mut expected = [0u8; $bits * std::mem::size_of::<$t>()];
                let wide: Vec<u64> = input.iter().map(|&v| u64::from(v)).collect();
                reference_pack(&wide, num_bits, &mut expected);

                assert_eq!(actual, expected, "num_bits = {}", num_bits);
            }
        };
    }

    #[test]
    fn test_matches_reference_bit_order() {
        check_against_reference!(pack8, u8, 8);
        check_against_reference!(pack16, u16, 16);
        check_against_reference!(pack32, u32, 32);
        check_against_reference!(pack64, u64, 64);
    }
}