hibitset/
util.rs

/// Type used for indexing.
pub type Index = u32;

/// Base two log of the number of bits in a usize.
#[cfg(target_pointer_width = "64")]
pub const BITS: usize = 6;
/// Base two log of the number of bits in a usize.
#[cfg(target_pointer_width = "32")]
pub const BITS: usize = 5;
/// Amount of layers in the hierarchical bitset.
pub const LAYERS: usize = 4;
/// Number of index bits covered by all layers together (`BITS` per layer).
pub const MAX: usize = BITS * LAYERS;
/// Maximum amount of bits per bitset, i.e. `2^MAX`.
///
/// Written as `1 << MAX` instead of `2 << MAX - 1`: the value is the same,
/// but it no longer depends on `-` binding tighter than `<<`.
pub const MAX_EID: usize = 1 << MAX;

/// Layer0 shift (bottom layer, true bitset).
pub const SHIFT0: usize = 0;
/// Layer1 shift (third layer); one `BITS` step above `SHIFT0`.
pub const SHIFT1: usize = SHIFT0 + BITS;
/// Layer2 shift (second layer); one `BITS` step above `SHIFT1`.
pub const SHIFT2: usize = SHIFT1 + BITS;
/// Top layer shift; one `BITS` step above `SHIFT2`.
pub const SHIFT3: usize = SHIFT2 + BITS;
23
/// Positional queries for a value that addresses a bit inside a layer.
///
/// `shift` selects the layer; the `SHIFT*` constants are the values passed
/// by the code in this file.
pub trait Row: Sized + Copy {
    /// Position of the bit within its row.
    fn row(self, shift: usize) -> usize;

    /// Index of the row that contains the bit.
    fn offset(self, shift: usize) -> usize;

    /// Single-bit mask selecting this bit inside its row.
    #[inline(always)]
    fn mask(self, shift: usize) -> usize {
        let bit_in_row = self.row(shift);
        1 << bit_in_row
    }
}
37
38impl Row for Index {
39    #[inline(always)]
40    fn row(self, shift: usize) -> usize {
41        ((self >> shift) as usize) & ((1 << BITS) - 1)
42    }
43
44    #[inline(always)]
45    fn offset(self, shift: usize) -> usize {
46        self as usize / (1 << shift)
47    }
48}
49
50/// Helper method for getting parent offsets of 3 layers at once.
51///
52/// Returns them in (Layer0, Layer1, Layer2) order.
53#[inline]
54pub fn offsets(bit: Index) -> (usize, usize, usize) {
55    (bit.offset(SHIFT1), bit.offset(SHIFT2), bit.offset(SHIFT3))
56}
57
/// Finds the highest bit that splits the set bits of `n` in half
/// (rounding up).
///
/// Returns `None` when `n` has fewer than two set bits.
///
/// This is a thin wrapper: it forwards to the fixed-width implementation
/// matching the target's pointer width and converts the result back to
/// `usize`.
///
/// # Examples
/// ```rust,ignore
/// use hibitset::util::average_ones;
///
/// assert_eq!(Some(4), average_ones(0b10110));
/// assert_eq!(Some(5), average_ones(0b100010));
/// assert_eq!(None, average_ones(0));
/// assert_eq!(None, average_ones(1));
/// ```
// TODO: Can the 64/32 bit variants be merged into one implementation?
// It seems that this would need integer generics.
#[cfg(feature = "parallel")]
pub fn average_ones(n: usize) -> Option<usize> {
    #[cfg(target_pointer_width = "64")]
    let split = average_ones_u64(n as u64);

    #[cfg(target_pointer_width = "32")]
    let split = average_ones_u32(n as u32);

    split.map(|bit| bit as usize)
}
84
85#[cfg(all(any(test, target_pointer_width = "32"), feature = "parallel"))]
/// 32-bit implementation of the "split the set bits in half" search.
///
/// Returns `None` when `n` has zero or one set bit. Otherwise returns a bit
/// index `r` such that the set bits below `r` and the set bits at or above
/// `r` differ in number by at most one.
fn average_ones_u32(n: u32) -> Option<u32> {
    // Lane masks for the parallel bit count; each equals
    // `!0 / ((1 << (1 << k)) | 1)` for k = 0..=4.
    const PAIRS: u32 = 0x5555_5555;
    const NIBBLES: u32 = 0x3333_3333;
    const BYTES: u32 = 0x0F0F_0F0F;
    const HALVES: u32 = 0x00FF_00FF;
    const WORD: u32 = 0x0000_FFFF;

    // Per-lane population counts, doubling the lane width at every step.
    let per2 = n - ((n >> 1) & PAIRS);
    let per4 = (per2 & NIBBLES) + ((per2 >> 2) & NIBBLES);
    let per8 = (per4 + (per4 >> 4)) & BYTES;
    let per16 = (per8 + (per8 >> 8)) & HALVES;

    // `upper` is the count of set bits in the upper half of the window
    // currently being searched (initially bits 16..32).
    let mut upper = per16 >> 16;
    let total = (per16 + upper) & WORD;
    if total < 2 {
        return None;
    }

    // Number of set bits wanted on each side of the split.
    let mut wanted = total / 2;

    // `split` is the exclusive top of the search window.
    let mut split: u32 = 32;

    // (lane counts, lane width, mask wide enough for one lane's count),
    // from the widest lanes down to single bits.
    let levels: [(u32, u32, u32); 4] = [
        (per8, 8, 0xF),
        (per4, 4, 0x7),
        (per2, 2, 0x3),
        (n, 1, 0x1),
    ];
    for &(counts, stride, lane_mask) in levels.iter() {
        if upper < wanted {
            // Not enough set bits in the upper half: continue in the lower
            // half and discount what the upper half already contributes.
            split -= 2 * stride;
            wanted -= upper;
        }
        // Count of the lane directly below `split` in the narrower level.
        upper = (counts >> (split - stride)) & lane_mask;
    }
    if upper < wanted {
        split -= 1;
    }

    Some(split - 1)
}
128
129#[cfg(all(any(test, target_pointer_width = "64"), feature = "parallel"))]
/// 64-bit implementation of the "split the set bits in half" search.
///
/// Returns `None` when `n` has zero or one set bit. Otherwise returns a bit
/// index `r` such that the set bits below `r` and the set bits at or above
/// `r` differ in number by at most one.
fn average_ones_u64(n: u64) -> Option<u64> {
    // Lane masks for the parallel bit count; each equals
    // `!0 / ((1 << (1 << k)) | 1)` for k = 0..=5.
    const PAIRS: u64 = 0x5555_5555_5555_5555;
    const NIBBLES: u64 = 0x3333_3333_3333_3333;
    const BYTES: u64 = 0x0F0F_0F0F_0F0F_0F0F;
    const HALVES: u64 = 0x00FF_00FF_00FF_00FF;
    const WORDS: u64 = 0x0000_FFFF_0000_FFFF;
    const LOW: u64 = 0x0000_0000_FFFF_FFFF;

    // Per-lane population counts, doubling the lane width at every step.
    let per2 = n - ((n >> 1) & PAIRS);
    let per4 = (per2 & NIBBLES) + ((per2 >> 2) & NIBBLES);
    let per8 = (per4 + (per4 >> 4)) & BYTES;
    let per16 = (per8 + (per8 >> 8)) & HALVES;
    let per32 = (per16 + (per16 >> 16)) & WORDS;

    // `upper` is the count of set bits in the upper half of the window
    // currently being searched (initially bits 32..64).
    let mut upper = per32 >> 32;
    let total = (per32 + upper) & LOW;
    if total < 2 {
        return None;
    }

    // Number of set bits wanted on each side of the split.
    let mut wanted = total / 2;

    // `split` is the exclusive top of the search window.
    let mut split: u64 = 64;

    // (lane counts, lane width, mask wide enough for one lane's count),
    // from the widest lanes down to single bits.
    let levels: [(u64, u64, u64); 5] = [
        (per16, 16, 0xFF),
        (per8, 8, 0xF),
        (per4, 4, 0x7),
        (per2, 2, 0x3),
        (n, 1, 0x1),
    ];
    for &(counts, stride, lane_mask) in levels.iter() {
        if upper < wanted {
            // Not enough set bits in the upper half: continue in the lower
            // half and discount what the upper half already contributes.
            split -= 2 * stride;
            wanted -= upper;
        }
        // Count of the lane directly below `split` in the narrower level.
        upper = (counts >> (split - stride)) & lane_mask;
    }
    if upper < wanted {
        split -= 1;
    }

    Some(split - 1)
}
181
#[cfg(all(test, feature = "parallel"))]
mod test_average_ones {
    use super::*;

    /// Number of sample windows spread over the value range, and also the
    /// number of values checked inside each window.
    const STEPS: u32 = 1000;

    /// Values with an even number of set bits must split into equal halves.
    #[test]
    fn parity_0_average_ones_u32() {
        let stride = u32::max_value() / STEPS;
        for window in 0..STEPS {
            // Sampling starts just above the window start.
            let start = window * stride + 1;
            let evens = (start..=u32::max_value()).filter(|v| v.count_ones() % 2 == 0);
            for i in evens.take(STEPS as usize) {
                let mask = (1u32 << average_ones_u32(i).unwrap_or(31)) - 1;
                assert_eq!((i & mask).count_ones(), (i & !mask).count_ones(), "{:x}", i);
            }
        }
    }

    /// Values with an odd number of set bits must split into halves whose
    /// counts differ by exactly one.
    #[test]
    fn parity_1_average_ones_u32() {
        let stride = u32::max_value() / STEPS;
        for window in 0..STEPS {
            let start = window * stride + 1;
            let odds = (start..=u32::max_value()).filter(|v| v.count_ones() % 2 == 1);
            for i in odds.take(STEPS as usize) {
                let mask = (1u32 << average_ones_u32(i).unwrap_or(31)) - 1;
                let a = (i & mask).count_ones();
                let b = (i & !mask).count_ones();
                if a < b {
                    assert_eq!(a + 1, b, "{:x}", i);
                } else if b < a {
                    assert_eq!(a, b + 1, "{:x}", i);
                } else {
                    panic!("Odd parity shouldn't split in exactly half");
                }
            }
        }
    }

    #[test]
    fn empty_average_ones_u32() {
        assert_eq!(None, average_ones_u32(0));
    }

    #[test]
    fn singleton_average_ones_u32() {
        for i in 0..32 {
            assert_eq!(None, average_ones_u32(1 << i), "{:x}", i);
        }
    }

    /// Values with an even number of set bits must split into equal halves.
    #[test]
    fn parity_0_average_ones_u64() {
        let stride = u64::max_value() / u64::from(STEPS);
        for window in 0..u64::from(STEPS) {
            let start = window * stride + 1;
            let evens = (start..=u64::max_value()).filter(|v| v.count_ones() % 2 == 0);
            for i in evens.take(STEPS as usize) {
                let mask = (1u64 << average_ones_u64(i).unwrap_or(63)) - 1;
                assert_eq!((i & mask).count_ones(), (i & !mask).count_ones(), "{:x}", i);
            }
        }
    }

    /// Values with an odd number of set bits must split into halves whose
    /// counts differ by exactly one.
    #[test]
    fn parity_1_average_ones_u64() {
        let stride = u64::max_value() / u64::from(STEPS);
        for window in 0..u64::from(STEPS) {
            let start = window * stride + 1;
            let odds = (start..=u64::max_value()).filter(|v| v.count_ones() % 2 == 1);
            for i in odds.take(STEPS as usize) {
                let mask = (1u64 << average_ones_u64(i).unwrap_or(63)) - 1;
                let a = (i & mask).count_ones();
                let b = (i & !mask).count_ones();
                if a < b {
                    assert_eq!(a + 1, b, "{:x}", i);
                } else if b < a {
                    assert_eq!(a, b + 1, "{:x}", i);
                } else {
                    panic!("Odd parity shouldn't split in exactly half");
                }
            }
        }
    }

    #[test]
    fn empty_average_ones_u64() {
        assert_eq!(None, average_ones_u64(0));
    }

    #[test]
    fn singleton_average_ones_u64() {
        for i in 0..64 {
            assert_eq!(None, average_ones_u64(1 << i), "{:x}", i);
        }
    }

    /// The 32-bit and 64-bit implementations must agree on every `u32`.
    #[test]
    fn average_ones_agree_u32_u64() {
        let stride = u32::max_value() / STEPS;
        for window in 0..STEPS {
            let pos = window * stride;
            // The range must be `pos..pos + STEPS`: with `pos..STEPS` every
            // window after the first is empty and only 0..STEPS gets
            // compared. `pos + STEPS` cannot overflow because
            // `pos <= (STEPS - 1) * stride`.
            for i in pos..pos + STEPS {
                assert_eq!(
                    average_ones_u32(i),
                    average_ones_u64(u64::from(i)).map(|n| n as u32),
                    "{:x}",
                    i
                );
            }
        }
    }

    #[test]
    fn specific_values() {
        assert_eq!(Some(4), average_ones_u32(0b10110));
        assert_eq!(Some(5), average_ones_u32(0b100010));
        assert_eq!(None, average_ones_u32(0));
        assert_eq!(None, average_ones_u32(1));

        assert_eq!(Some(4), average_ones_u64(0b10110));
        assert_eq!(Some(5), average_ones_u64(0b100010));
        assert_eq!(None, average_ones_u64(0));
        assert_eq!(None, average_ones_u64(1));
    }
}
377}