hibitset/atomic.rs

use std::default::Default;
use std::fmt::{Debug, Error as FormatError, Formatter};
use std::iter::repeat;
use std::marker::PhantomData;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};

use util::*;
use {BitSetLike, DrainableBitSet};

/// This is similar to a [`BitSet`], but allows setting bits
/// without unique ownership of the structure.
///
/// An `AtomicBitSet` can add an item to the set
/// without unique ownership (given that the set is big enough).
/// Removing elements does require unique ownership as an effect
/// of the hierarchy it holds. In the worst case, multiple writers set the
/// same bit twice (but only one of them is told it set the bit).
///
/// It is possible to atomically remove from the set, but not at the
/// same time as atomically adding. This is because there is no way
/// to know if layers 1-3 would be left in a consistent state if they were
/// being cleared and set at the same time.
///
/// `AtomicBitSet` resolves this race by disallowing atomic
/// clearing of bits.
///
/// [`BitSet`]: ../struct.BitSet.html
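///
/// # Examples
///
/// A brief usage sketch (assuming the crate root re-exports `AtomicBitSet`,
/// as the tests below do):
///
/// ```
/// use hibitset::AtomicBitSet;
///
/// let set = AtomicBitSet::new();
/// // `add_atomic` only needs a shared reference; it returns whether the
/// // bit was already set.
/// assert!(!set.add_atomic(7));
/// assert!(set.add_atomic(7));
/// assert!(set.contains(7));
///
/// // Removing still requires unique ownership.
/// let mut set = set;
/// assert!(set.remove(7));
/// assert!(!set.contains(7));
/// ```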
#[derive(Debug)]
pub struct AtomicBitSet {
    layer3: AtomicUsize,
    layer2: Vec<AtomicUsize>,
    layer1: Vec<AtomicBlock>,
}

impl AtomicBitSet {
    /// Creates an empty `AtomicBitSet`.
    pub fn new() -> AtomicBitSet {
        Default::default()
    }

    /// Adds `id` to the `AtomicBitSet`. Returns `true` if the value was
    /// already in the set.
    ///
    /// Because we cannot safely extend an `AtomicBitSet` without unique
    /// ownership, this will panic if the index is out of range.
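    ///
    /// # Examples
    ///
    /// A minimal sketch of adding bits from several threads through a shared
    /// reference (assumes a toolchain with `std::thread::scope` and that the
    /// crate root re-exports `AtomicBitSet`):
    ///
    /// ```
    /// use hibitset::AtomicBitSet;
    ///
    /// let set = AtomicBitSet::new();
    /// std::thread::scope(|s| {
    ///     for t in 0u32..4 {
    ///         let set = &set;
    ///         s.spawn(move || {
    ///             // Each thread only needs `&AtomicBitSet` to add bits.
    ///             for i in 0..128 {
    ///                 set.add_atomic(t * 128 + i);
    ///             }
    ///         });
    ///     }
    /// });
    /// for i in 0..512 {
    ///     assert!(set.contains(i));
    /// }
    /// ```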
    #[inline]
    pub fn add_atomic(&self, id: Index) -> bool {
        let (_, p1, p2) = offsets(id);

        // While it is tempting to check if the bit was set and exit here if it
        // was, this can result in a data race. If this thread and another
        // thread both set the same bit, it is possible for the second thread
        // to exit before layer 3 was set, leaving the iterator in an
        // incorrect state. The window is small, but it exists.
        let set = self.layer1[p1].add(id);
        self.layer2[p2].fetch_or(id.mask(SHIFT2), Ordering::Relaxed);
        self.layer3.fetch_or(id.mask(SHIFT3), Ordering::Relaxed);
        set
    }

    /// Adds `id` to the `AtomicBitSet`. Returns `true` if the value was
    /// already in the set.
    #[inline]
    pub fn add(&mut self, id: Index) -> bool {
        use std::sync::atomic::Ordering::Relaxed;

        let (_, p1, p2) = offsets(id);
        if self.layer1[p1].add(id) {
            return true;
        }

        self.layer2[p2].store(self.layer2[p2].load(Relaxed) | id.mask(SHIFT2), Relaxed);
        self.layer3
            .store(self.layer3.load(Relaxed) | id.mask(SHIFT3), Relaxed);
        false
    }

    /// Removes `id` from the set. Returns `true` if the value
    /// was removed, and `false` if the value was not set
    /// to begin with.
    #[inline]
    pub fn remove(&mut self, id: Index) -> bool {
        use std::sync::atomic::Ordering::Relaxed;
        let (_, p1, p2) = offsets(id);

        // If the bit was set, we need to clear it in layers 0 through 3.
        // A higher layer should only be cleared if the cleared bit was the
        // last set bit in its block.
        //
        // Plain loads and stores are used instead of `fetch_and` because we
        // have mutable access to the `AtomicBitSet`, so this is sound (and
        // faster).
        if !self.layer1[p1].remove(id) {
            return false;
        }
        if self.layer1[p1].mask.load(Ordering::Relaxed) != 0 {
            return true;
        }

        let v = self.layer2[p2].load(Relaxed) & !id.mask(SHIFT2);
        self.layer2[p2].store(v, Relaxed);
        if v != 0 {
            return true;
        }

        let v = self.layer3.load(Relaxed) & !id.mask(SHIFT3);
        self.layer3.store(v, Relaxed);
        true
    }

    /// Returns `true` if `id` is in the set.
    #[inline]
    pub fn contains(&self, id: Index) -> bool {
        let i = id.offset(SHIFT2);
        self.layer1[i].contains(id)
    }

    /// Clears all bits in the set.
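    ///
    /// A brief sketch of the effect (assuming the crate root re-exports
    /// `AtomicBitSet`):
    ///
    /// ```
    /// use hibitset::AtomicBitSet;
    ///
    /// let mut set = AtomicBitSet::new();
    /// set.add(3);
    /// set.add(12_345);
    /// set.clear();
    /// assert!(!set.contains(3));
    /// assert!(!set.contains(12_345));
    /// ```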
    pub fn clear(&mut self) {
        // This is the same hierarchical striding used in the iterators.
        // Using this technique we can avoid clearing segments of the bitset
        // that are already clear. In the best case, when the set is already
        // cleared, this will only touch the highest layer.

        let (mut m3, mut m2) = (self.layer3.swap(0, Ordering::Relaxed), 0usize);
        let mut offset = 0;

        loop {
            if m2 != 0 {
                let bit = m2.trailing_zeros() as usize;
                m2 &= !(1 << bit);

                // Layers 1 & 0 are cleared unconditionally. It's only 32-64
                // words, and the extra logic to select the correct words is
                // slower than just clearing them all.
                self.layer1[offset + bit].clear();
                continue;
            }

            if m3 != 0 {
                let bit = m3.trailing_zeros() as usize;
                m3 &= !(1 << bit);
                offset = bit << BITS;
                m2 = self.layer2[bit].swap(0, Ordering::Relaxed);
                continue;
            }
            break;
        }
    }
}

impl BitSetLike for AtomicBitSet {
    #[inline]
    fn layer3(&self) -> usize {
        self.layer3.load(Ordering::Relaxed)
    }
    #[inline]
    fn layer2(&self, i: usize) -> usize {
        self.layer2[i].load(Ordering::Relaxed)
    }
    #[inline]
    fn layer1(&self, i: usize) -> usize {
        self.layer1[i].mask.load(Ordering::Relaxed)
    }
    #[inline]
    fn layer0(&self, i: usize) -> usize {
        let (o1, o0) = (i >> BITS, i & ((1 << BITS) - 1));
        self.layer1[o1]
            .atom
            .get()
            .map(|layer0| layer0[o0].load(Ordering::Relaxed))
            .unwrap_or(0)
    }
    #[inline]
    fn contains(&self, i: Index) -> bool {
        self.contains(i)
    }
}

impl DrainableBitSet for AtomicBitSet {
    #[inline]
    fn remove(&mut self, i: Index) -> bool {
        self.remove(i)
    }
}

impl Default for AtomicBitSet {
    fn default() -> Self {
        AtomicBitSet {
            layer3: Default::default(),
            layer2: repeat(0)
                .map(|_| AtomicUsize::new(0))
                .take(1 << BITS)
                .collect(),
            layer1: repeat(0)
                .map(|_| AtomicBlock::new())
                .take(1 << (2 * BITS))
                .collect(),
        }
    }
}

struct OnceAtom {
    inner: AtomicPtr<[AtomicUsize; 1 << BITS]>,
    marker: PhantomData<Option<Box<[AtomicUsize; 1 << BITS]>>>,
}

impl Drop for OnceAtom {
    fn drop(&mut self) {
        let ptr = *self.inner.get_mut();
        if !ptr.is_null() {
            // SAFETY: If the pointer is not null, we created it from
            // `Box::into_raw` in `Self::get_or_init`.
            drop(unsafe { Box::from_raw(ptr) });
        }
    }
}

impl OnceAtom {
    fn new() -> Self {
        Self {
            inner: AtomicPtr::new(ptr::null_mut()),
            marker: PhantomData,
        }
    }

    fn get_or_init(&self) -> &[AtomicUsize; 1 << BITS] {
        let current_ptr = self.inner.load(Ordering::Acquire);
        let ptr = if current_ptr.is_null() {
            const ZERO: AtomicUsize = AtomicUsize::new(0);
            let new_ptr = Box::into_raw(Box::new([ZERO; 1 << BITS]));
            if let Err(existing_ptr) = self.inner.compare_exchange(
                ptr::null_mut(),
                new_ptr,
                // On success, Release matches any Acquire loads of the non-null
                // pointer, to ensure the new box is visible to other threads.
                Ordering::Release,
                Ordering::Acquire,
            ) {
                // SAFETY: We obtained this pointer from `Box::into_raw` above
                // and failed to publish it to the `AtomicPtr`, so no other
                // thread can be holding it.
                drop(unsafe { Box::from_raw(new_ptr) });
                existing_ptr
            } else {
                new_ptr
            }
        } else {
            current_ptr
        };

        // SAFETY: We checked that this pointer is not null (either by the
        // `.is_null()` check, by `compare_exchange`, or because it came from
        // `Box::into_raw`). We created it from `Box::into_raw` (at some point)
        // and we only use it to create immutable references (unless we have
        // exclusive access to self).
        unsafe { &*ptr }
    }

    fn get(&self) -> Option<&[AtomicUsize; 1 << BITS]> {
        let ptr = self.inner.load(Ordering::Acquire);
        // SAFETY: If it is not null, we created this pointer from
        // `Box::into_raw` and only use it to create immutable references
        // (unless we have exclusive access to self).
        unsafe { ptr.as_ref() }
    }

    fn get_mut(&mut self) -> Option<&mut [AtomicUsize; 1 << BITS]> {
        let ptr = self.inner.get_mut();
        // SAFETY: If this is not null, we created this pointer from
        // `Box::into_raw` and we have an exclusive borrow of self.
        unsafe { ptr.as_mut() }
    }
}

struct AtomicBlock {
    mask: AtomicUsize,
    atom: OnceAtom,
}

impl AtomicBlock {
    fn new() -> AtomicBlock {
        AtomicBlock {
            mask: AtomicUsize::new(0),
            atom: OnceAtom::new(),
        }
    }

    fn add(&self, id: Index) -> bool {
        let (i, m) = (id.row(SHIFT1), id.mask(SHIFT0));
        let old = self.atom.get_or_init()[i].fetch_or(m, Ordering::Relaxed);
        self.mask.fetch_or(id.mask(SHIFT1), Ordering::Relaxed);
        old & m != 0
    }

    fn contains(&self, id: Index) -> bool {
        self.atom
            .get()
            .map(|layer0| layer0[id.row(SHIFT1)].load(Ordering::Relaxed) & id.mask(SHIFT0) != 0)
            .unwrap_or(false)
    }

    fn remove(&mut self, id: Index) -> bool {
        if let Some(layer0) = self.atom.get_mut() {
            let (i, m) = (id.row(SHIFT1), !id.mask(SHIFT0));
            let v = layer0[i].get_mut();
            let was_set = *v & id.mask(SHIFT0) == id.mask(SHIFT0);
            *v &= m;
            if *v == 0 {
                // No other bits are set, so unset the bit in the next
                // level up.
                *self.mask.get_mut() &= !id.mask(SHIFT1);
            }
            was_set
        } else {
            false
        }
    }

    fn clear(&mut self) {
        *self.mask.get_mut() = 0;
        if let Some(layer0) = self.atom.get_mut() {
            for l in layer0 {
                *l.get_mut() = 0;
            }
        }
    }
}

impl Debug for AtomicBlock {
    fn fmt(&self, f: &mut Formatter) -> Result<(), FormatError> {
        f.debug_struct("AtomicBlock")
            .field("mask", &self.mask)
            // `atom` may not be initialized yet, so avoid unwrapping here.
            .field("atom", &self.atom.get().map(|layer0| layer0.iter()))
            .finish()
    }
}

#[cfg(test)]
mod atomic_set_test {
    use {AtomicBitSet, BitSetAnd, BitSetLike};

    #[test]
    fn insert() {
        let mut c = AtomicBitSet::new();
        for i in 0..1_000 {
            assert!(!c.add(i));
            assert!(c.add(i));
        }

        for i in 0..1_000 {
            assert!(c.contains(i));
        }
    }

    #[test]
    fn insert_100k() {
        let mut c = AtomicBitSet::new();
        for i in 0..100_000 {
            assert!(!c.add(i));
            assert!(c.add(i));
        }

        for i in 0..100_000 {
            assert!(c.contains(i));
        }
    }

    #[test]
    fn add_atomic() {
        let c = AtomicBitSet::new();
        for i in 0..1_000 {
            assert!(!c.add_atomic(i));
            assert!(c.add_atomic(i));
        }

        for i in 0..1_000 {
            assert!(c.contains(i));
        }
    }

    #[test]
    fn add_atomic_100k() {
        let c = AtomicBitSet::new();
        for i in 0..100_000 {
            assert!(!c.add_atomic(i));
            assert!(c.add_atomic(i));
        }

        for i in 0..100_000 {
            assert!(c.contains(i));
        }
    }
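
    // A minimal concurrency sketch, not part of the original test suite:
    // several threads call `add_atomic` through a shared reference, and the
    // contents are checked afterwards. Assumes a toolchain that provides
    // `std::thread::scope`.
    #[test]
    fn add_atomic_concurrent() {
        use std::thread;

        let c = AtomicBitSet::new();
        thread::scope(|s| {
            for t in 0u32..4 {
                let c = &c;
                s.spawn(move || {
                    // Adding only needs a shared reference to the set.
                    for i in 0..1_000 {
                        c.add_atomic(t * 1_000 + i);
                    }
                });
            }
        });

        for i in 0..4_000 {
            assert!(c.contains(i));
        }
    }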

    #[test]
    fn remove() {
        let mut c = AtomicBitSet::new();
        for i in 0..1_000 {
            assert!(!c.add(i));
        }

        for i in 0..1_000 {
            assert!(c.contains(i));
            assert!(c.remove(i));
            assert!(!c.contains(i));
            assert!(!c.remove(i));
        }
    }

    #[test]
    fn iter() {
        let mut c = AtomicBitSet::new();
        for i in 0..100_000 {
            c.add(i);
        }

        let mut count = 0;
        for (idx, i) in c.iter().enumerate() {
            count += 1;
            assert_eq!(idx, i as usize);
        }
        assert_eq!(count, 100_000);
    }

    #[test]
    fn iter_odd_even() {
        let mut odd = AtomicBitSet::new();
        let mut even = AtomicBitSet::new();
        for i in 0..100_000 {
            if i % 2 == 1 {
                odd.add(i);
            } else {
                even.add(i);
            }
        }

        assert_eq!((&odd).iter().count(), 50_000);
        assert_eq!((&even).iter().count(), 50_000);
        assert_eq!(BitSetAnd(&odd, &even).iter().count(), 0);
    }

    #[test]
    fn clear() {
        let mut set = AtomicBitSet::new();
        for i in 0..1_000 {
            set.add(i);
        }

        assert_eq!((&set).iter().sum::<u32>(), 500_500 - 1_000);

        assert_eq!((&set).iter().count(), 1_000);
        set.clear();
        assert_eq!((&set).iter().count(), 0);

        for i in 0..1_000 {
            set.add(i * 64);
        }

        assert_eq!((&set).iter().count(), 1_000);
        set.clear();
        assert_eq!((&set).iter().count(), 0);

        for i in 0..1_000 {
            set.add(i * 1_000);
        }

        assert_eq!((&set).iter().count(), 1_000);
        set.clear();
        assert_eq!((&set).iter().count(), 0);

        for i in 0..100 {
            set.add(i * 10_000);
        }

        assert_eq!((&set).iter().count(), 100);
        set.clear();
        assert_eq!((&set).iter().count(), 0);

        for i in 0..10 {
            set.add(i * 10_000);
        }

        assert_eq!((&set).iter().count(), 10);
        set.clear();
        assert_eq!((&set).iter().count(), 0);
    }
}