#![deny(missing_docs)]
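//! Provides hierarchical bit sets.
//!
//! A hierarchical bit set stores its bits in layers: layer 0 holds one bit
//! per index, and each bit of a higher layer summarizes a whole word of the
//! layer below, being set exactly when that word is non-zero. Iteration can
//! therefore skip an entire empty word with a single test, which makes scans
//! over sparse sets fast.
//!
//! A usage sketch (the crate name `hibitset` is assumed here; adjust the
//! import path to match how this crate is actually published):
//!
//! ```
//! use hibitset::{BitSet, BitSetLike};
//!
//! let mut set = BitSet::new();
//! set.add(7);
//! set.add(429);
//! assert!(set.contains(7));
//! assert_eq!((&set).iter().collect::<Vec<_>>(), vec![7, 429]);
//! ```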

#[cfg(test)]
extern crate rand;
#[cfg(feature = "parallel")]
extern crate rayon;

mod atomic;
mod iter;
mod ops;
mod util;

pub use atomic::AtomicBitSet;
pub use iter::{BitIter, DrainBitIter};
#[cfg(feature = "parallel")]
pub use iter::{BitParIter, BitProducer};
pub use ops::{BitSetAll, BitSetAnd, BitSetNot, BitSetOr, BitSetXor};

use util::*;

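/// A `BitSet` is a simple set designed to track which indices are placed
/// into it.
///
/// It is hierarchical: besides the bottom layer of membership bits, three
/// summary layers record which words below them are non-zero, so scans can
/// skip large empty regions. Adding an index beyond the supported range
/// (`MAX_EID`) will panic.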
#[derive(Clone, Debug, Default)]
pub struct BitSet {
    layer3: usize,
    layer2: Vec<usize>,
    layer1: Vec<usize>,
    layer0: Vec<usize>,
}

impl BitSet {
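    /// Creates an empty `BitSet`.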
    pub fn new() -> BitSet {
        Default::default()
    }

    #[inline]
    fn valid_range(max: Index) {
        if (MAX_EID as u32) < max {
            panic!("Expected index to be less than {}, found {}", MAX_EID, max);
        }
    }

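    /// Creates an empty `BitSet`, preallocated for up to `max` indices.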
    pub fn with_capacity(max: Index) -> BitSet {
        Self::valid_range(max);
        let mut value = BitSet::new();
        value.extend(max);
        value
    }

    #[inline(never)]
    fn extend(&mut self, id: Index) {
        Self::valid_range(id);
        let (p0, p1, p2) = offsets(id);

        Self::fill_up(&mut self.layer2, p2);
        Self::fill_up(&mut self.layer1, p1);
        Self::fill_up(&mut self.layer0, p0);
    }

    fn fill_up(vec: &mut Vec<usize>, upper_index: usize) {
        if vec.len() <= upper_index {
            vec.resize(upper_index + 1, 0);
        }
    }

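    /// Sets the bits above layer 0 for `id`. Only needed when the layer 0
    /// word went from zero to non-zero, since the upper layers just record
    /// which lower words are non-zero.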
    #[inline(never)]
    fn add_slow(&mut self, id: Index) {
        let (_, p1, p2) = offsets(id);
        self.layer1[p1] |= id.mask(SHIFT1);
        self.layer2[p2] |= id.mask(SHIFT2);
        self.layer3 |= id.mask(SHIFT3);
    }

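    /// Adds `id` to the `BitSet`, returning `true` if the value was
    /// already present.
    ///
    /// A short example (assuming, as in the crate docs, that the crate is
    /// published as `hibitset`):
    ///
    /// ```
    /// use hibitset::BitSet;
    ///
    /// let mut set = BitSet::new();
    /// assert!(!set.add(3));
    /// assert!(set.add(3));
    /// assert!(set.contains(3));
    /// ```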
    #[inline]
    pub fn add(&mut self, id: Index) -> bool {
        let (p0, mask) = (id.offset(SHIFT1), id.mask(SHIFT0));

        if p0 >= self.layer0.len() {
            self.extend(id);
        }

        if self.layer0[p0] & mask != 0 {
            return true;
        }

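        // The upper layers only track which layer 0 words are non-zero, so
        // they need updating only when this word was previously empty.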
        let old = self.layer0[p0];
        self.layer0[p0] |= mask;
        if old == 0 {
            self.add_slow(id);
        }
        false
    }

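    /// Returns a mutable reference to the word at `idx` of the given layer,
    /// growing that layer if it is too short.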
    fn layer_mut(&mut self, level: usize, idx: usize) -> &mut usize {
        match level {
            0 => {
                Self::fill_up(&mut self.layer0, idx);
                &mut self.layer0[idx]
            }
            1 => {
                Self::fill_up(&mut self.layer1, idx);
                &mut self.layer1[idx]
            }
            2 => {
                Self::fill_up(&mut self.layer2, idx);
                &mut self.layer2[idx]
            }
            3 => &mut self.layer3,
            _ => panic!("Invalid layer: {}", level),
        }
    }

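    /// Removes `id` from the set, returning `true` if the value was
    /// removed and `false` if it was never in the set to begin with.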
    #[inline]
    pub fn remove(&mut self, id: Index) -> bool {
        let (p0, p1, p2) = offsets(id);

        if p0 >= self.layer0.len() {
            return false;
        }

        if self.layer0[p0] & id.mask(SHIFT0) == 0 {
            return false;
        }

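        // Clear the bit in layer 0. Each summary layer above only needs
        // clearing if the word below it just became empty.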
        self.layer0[p0] &= !id.mask(SHIFT0);
        if self.layer0[p0] != 0 {
            return true;
        }

        self.layer1[p1] &= !id.mask(SHIFT1);
        if self.layer1[p1] != 0 {
            return true;
        }

        self.layer2[p2] &= !id.mask(SHIFT2);
        if self.layer2[p2] != 0 {
            return true;
        }

        self.layer3 &= !id.mask(SHIFT3);
        true
    }

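    /// Returns `true` if `id` is in the set.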
    #[inline]
    pub fn contains(&self, id: Index) -> bool {
        let p0 = id.offset(SHIFT1);
        p0 < self.layer0.len() && (self.layer0[p0] & id.mask(SHIFT0)) != 0
    }

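    /// Returns `true` if all indices in `other` are also contained in this
    /// set.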
    #[inline]
    pub fn contains_set(&self, other: &BitSet) -> bool {
        for id in other.iter() {
            if !self.contains(id) {
                return false;
            }
        }
        true
    }

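    /// Completely wipes out the bit set.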
    pub fn clear(&mut self) {
        self.layer0.clear();
        self.layer1.clear();
        self.layer2.clear();
        self.layer3 = 0;
    }

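    /// How many bits there are in one `usize` (32 on this target).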
    #[cfg(target_pointer_width = "32")]
    pub const BITS_PER_USIZE: usize = 32;

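    /// How many bits there are in one `usize` (64 on this target).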
    #[cfg(target_pointer_width = "64")]
    pub const BITS_PER_USIZE: usize = 64;

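    /// Returns the bottom layer of the bit set as a slice of words. Bit
    /// `i % BITS_PER_USIZE` of word `i / BITS_PER_USIZE` is set exactly
    /// when index `i` is in the set.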
    pub fn layer0_as_slice(&self) -> &[usize] {
        self.layer0.as_slice()
    }

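    /// How many indices are summarized by a single layer 1 bit: one full
    /// `usize` worth of layer 0 bits.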
    pub const LAYER1_GRANULARITY: usize = Self::BITS_PER_USIZE;

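    /// Returns layer 1 of the bit set as a slice of words. Each bit
    /// summarizes `LAYER1_GRANULARITY` indices and is set exactly when the
    /// corresponding layer 0 word is non-zero.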
    pub fn layer1_as_slice(&self) -> &[usize] {
        self.layer1.as_slice()
    }

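    /// How many indices are summarized by a single layer 2 bit:
    /// `LAYER1_GRANULARITY` indices for each bit of a layer 1 word.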
    pub const LAYER2_GRANULARITY: usize = Self::LAYER1_GRANULARITY * Self::BITS_PER_USIZE;

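    /// Returns layer 2 of the bit set as a slice of words. Each bit
    /// summarizes `LAYER2_GRANULARITY` indices and is set exactly when the
    /// corresponding layer 1 word is non-zero.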
    pub fn layer2_as_slice(&self) -> &[usize] {
        self.layer2.as_slice()
    }
}

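/// A generic interface for hierarchical bit sets.
///
/// Every `BitSetLike` has four layers. In layer 0 each bit represents one
/// `Index` of the set. In layer 1 each bit represents one `usize` of
/// layer 0, and is set only when that word is non-zero; layer 2 summarizes
/// layer 1, and layer 3 summarizes layer 2, in the same way. This
/// arrangement allows iteration to make rapid jumps across the key space.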
pub trait BitSetLike {
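    /// Gets the `usize` corresponding to the given layer and index.
    ///
    /// The `layer` must be in the range [0, 3]; any other value panics.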
    fn get_from_layer(&self, layer: usize, idx: usize) -> usize {
        match layer {
            0 => self.layer0(idx),
            1 => self.layer1(idx),
            2 => self.layer2(idx),
            3 => self.layer3(),
            _ => panic!("Invalid layer: {}", layer),
        }
    }

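    /// Returns `true` if this bit set contains no bits. Because every set
    /// bit is mirrored in layer 3, checking the top word suffices.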
    fn is_empty(&self) -> bool {
        self.layer3() == 0
    }

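    /// Return the topmost `usize`, in which each bit records whether the
    /// corresponding layer 2 word is non-zero.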
    fn layer3(&self) -> usize;

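    /// Return the `i`th word of layer 2, in which each bit records whether
    /// the corresponding layer 1 word is non-zero.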
    fn layer2(&self, i: usize) -> usize;

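    /// Return the `i`th word of layer 1, in which each bit records whether
    /// the corresponding layer 0 word is non-zero.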
    fn layer1(&self, i: usize) -> usize;

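    /// Return the `i`th word of layer 0, whose bits record actual set
    /// membership.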
    fn layer0(&self, i: usize) -> usize;

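    /// Checks whether bit `i` is set.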
    fn contains(&self, i: Index) -> bool;

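    /// Create an iterator that will scan over the keyspace, yielding every
    /// index that is set.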
    fn iter(self) -> BitIter<Self>
    where
        Self: Sized,
    {
        let layer3 = self.layer3();

        BitIter::new(self, [0, 0, 0, layer3], [0; LAYERS - 1])
    }

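    /// Create a parallel iterator that will scan over the keyspace.
    /// Requires the `parallel` feature.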
    #[cfg(feature = "parallel")]
    fn par_iter(self) -> BitParIter<Self>
    where
        Self: Sized,
    {
        BitParIter::new(self)
    }
}

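/// A `BitSetLike` that supports removing bits, which makes draining
/// iteration possible.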
pub trait DrainableBitSet: BitSetLike {
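    /// Removes bit `i` from the bit set. Returns `true` if the bit was
    /// removed and `false` if it was not set to begin with.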
    fn remove(&mut self, i: Index) -> bool;

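    /// Create a draining iterator that scans over the keyspace and clears
    /// each bit it yields.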
    fn drain<'a>(&'a mut self) -> DrainBitIter<'a, Self>
    where
        Self: Sized,
    {
        let layer3 = self.layer3();

        DrainBitIter::new(self, [0, 0, 0, layer3], [0; LAYERS - 1])
    }
}

impl<'a, T> BitSetLike for &'a T
where
    T: BitSetLike + ?Sized,
{
    #[inline]
    fn layer3(&self) -> usize {
        (*self).layer3()
    }

    #[inline]
    fn layer2(&self, i: usize) -> usize {
        (*self).layer2(i)
    }

    #[inline]
    fn layer1(&self, i: usize) -> usize {
        (*self).layer1(i)
    }

    #[inline]
    fn layer0(&self, i: usize) -> usize {
        (*self).layer0(i)
    }

    #[inline]
    fn contains(&self, i: Index) -> bool {
        (*self).contains(i)
    }
}

impl<'a, T> BitSetLike for &'a mut T
where
    T: BitSetLike + ?Sized,
{
    #[inline]
    fn layer3(&self) -> usize {
        (**self).layer3()
    }

    #[inline]
    fn layer2(&self, i: usize) -> usize {
        (**self).layer2(i)
    }

    #[inline]
    fn layer1(&self, i: usize) -> usize {
        (**self).layer1(i)
    }

    #[inline]
    fn layer0(&self, i: usize) -> usize {
        (**self).layer0(i)
    }

    #[inline]
    fn contains(&self, i: Index) -> bool {
        (**self).contains(i)
    }
}

impl<'a, T> DrainableBitSet for &'a mut T
where
    T: DrainableBitSet,
{
    #[inline]
    fn remove(&mut self, i: Index) -> bool {
        (**self).remove(i)
    }
}

impl BitSetLike for BitSet {
    #[inline]
    fn layer3(&self) -> usize {
        self.layer3
    }

    #[inline]
    fn layer2(&self, i: usize) -> usize {
        self.layer2.get(i).map(|&x| x).unwrap_or(0)
    }

    #[inline]
    fn layer1(&self, i: usize) -> usize {
        self.layer1.get(i).map(|&x| x).unwrap_or(0)
    }

    #[inline]
    fn layer0(&self, i: usize) -> usize {
        self.layer0.get(i).map(|&x| x).unwrap_or(0)
    }

    #[inline]
    fn contains(&self, i: Index) -> bool {
        self.contains(i)
    }
}

impl DrainableBitSet for BitSet {
    #[inline]
    fn remove(&mut self, i: Index) -> bool {
        self.remove(i)
    }
}

impl PartialEq for BitSet {
    #[inline]
    fn eq(&self, rhv: &BitSet) -> bool {
        if self.layer3 != rhv.layer3 {
            return false;
        }
        if self.layer2.len() != rhv.layer2.len()
            || self.layer1.len() != rhv.layer1.len()
            || self.layer0.len() != rhv.layer0.len()
        {
            return false;
        }

        for i in 0..self.layer2.len() {
            if self.layer2(i) != rhv.layer2(i) {
                return false;
            }
        }
        for i in 0..self.layer1.len() {
            if self.layer1(i) != rhv.layer1(i) {
                return false;
            }
        }
        for i in 0..self.layer0.len() {
            if self.layer0(i) != rhv.layer0(i) {
                return false;
            }
        }

        true
    }
}

impl Eq for BitSet {}

#[cfg(test)]
mod tests {
    use super::{BitSet, BitSetAnd, BitSetLike, BitSetNot};

    #[test]
    fn insert() {
        let mut c = BitSet::new();
        for i in 0..1_000 {
            assert!(!c.add(i));
            assert!(c.add(i));
        }

        for i in 0..1_000 {
            assert!(c.contains(i));
        }
    }

    #[test]
    fn insert_100k() {
        let mut c = BitSet::new();
        for i in 0..100_000 {
            assert!(!c.add(i));
            assert!(c.add(i));
        }

        for i in 0..100_000 {
            assert!(c.contains(i));
        }
    }

    #[test]
    fn remove() {
        let mut c = BitSet::new();
        for i in 0..1_000 {
            assert!(!c.add(i));
        }

        for i in 0..1_000 {
            assert!(c.contains(i));
            assert!(c.remove(i));
            assert!(!c.contains(i));
            assert!(!c.remove(i));
        }
    }

    #[test]
    fn iter() {
        let mut c = BitSet::new();
        for i in 0..100_000 {
            c.add(i);
        }

        let mut count = 0;
        for (idx, i) in c.iter().enumerate() {
            count += 1;
            assert_eq!(idx, i as usize);
        }
        assert_eq!(count, 100_000);
    }

    #[test]
    fn iter_odd_even() {
        let mut odd = BitSet::new();
        let mut even = BitSet::new();
        for i in 0..100_000 {
            if i % 2 == 1 {
                odd.add(i);
            } else {
                even.add(i);
            }
        }

        assert_eq!((&odd).iter().count(), 50_000);
        assert_eq!((&even).iter().count(), 50_000);
        assert_eq!(BitSetAnd(&odd, &even).iter().count(), 0);
    }

    #[test]
    fn iter_random_add() {
        use rand::prelude::*;

        let mut set = BitSet::new();
        let mut rng = thread_rng();
        let limit = 1_048_576;
        let mut added = 0;
        for _ in 0..(limit / 10) {
            let index = rng.gen_range(0, limit);
            if !set.add(index) {
                added += 1;
            }
        }
        assert_eq!(set.iter().count(), added as usize);
    }

    #[test]
    fn iter_clusters() {
        let mut set = BitSet::new();
        for x in 0..8 {
            let x = (x * 3) << (::BITS * 2);
            for y in 0..8 {
                let y = (y * 3) << (::BITS);
                for z in 0..8 {
                    let z = z * 2;
                    set.add(x + y + z);
                }
            }
        }
        assert_eq!(set.iter().count(), 8usize.pow(3));
    }

    #[test]
    fn not() {
        let mut c = BitSet::new();
        for i in 0..10_000 {
            if i % 2 == 1 {
                c.add(i);
            }
        }
        let d = BitSetNot(c);
        for (idx, i) in d.iter().take(5_000).enumerate() {
            assert_eq!(idx * 2, i as usize);
        }
    }
}

#[cfg(all(test, feature = "parallel"))]
mod test_parallel {
    use super::{BitSet, BitSetAnd, BitSetLike};
    use rayon::iter::ParallelIterator;

    #[test]
    fn par_iter_one() {
        let step = 5000;
        let tests = 1_048_576 / step;
        for n in 0..tests {
            let n = n * step;
            let mut set = BitSet::new();
            set.add(n);
            assert_eq!(set.par_iter().count(), 1);
        }
        let mut set = BitSet::new();
        set.add(1_048_576 - 1);
        assert_eq!(set.par_iter().count(), 1);
    }

    #[test]
    fn par_iter_random_add() {
        use rand::prelude::*;
        use std::collections::HashSet;
        use std::sync::{Arc, Mutex};

        let mut set = BitSet::new();
        let mut check_set = HashSet::new();
        let mut rng = thread_rng();
        let limit = 1_048_576;
        for _ in 0..(limit / 10) {
            let index = rng.gen_range(0, limit);
            set.add(index);
            check_set.insert(index);
        }
        let check_set = Arc::new(Mutex::new(check_set));
        let missing_set = Arc::new(Mutex::new(HashSet::new()));
        set.par_iter().for_each(|n| {
            let check_set = check_set.clone();
            let missing_set = missing_set.clone();
            let mut check = check_set.lock().unwrap();
            if !check.remove(&n) {
                let mut missing = missing_set.lock().unwrap();
                missing.insert(n);
            }
        });
        let check_set = check_set.lock().unwrap();
        let missing_set = missing_set.lock().unwrap();
        if !check_set.is_empty() && !missing_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}
                There were values that got iterated, but that shouldn't be: {:?}",
                *check_set, *missing_set
            );
        }
        if !check_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}",
                *check_set
            );
        }
        if !missing_set.is_empty() {
            panic!(
                "There were values that got iterated, but that shouldn't be: {:?}",
                *missing_set
            );
        }
    }

    #[test]
    fn par_iter_odd_even() {
        let mut odd = BitSet::new();
        let mut even = BitSet::new();
        for i in 0..100_000 {
            if i % 2 == 1 {
                odd.add(i);
            } else {
                even.add(i);
            }
        }

        assert_eq!((&odd).par_iter().count(), 50_000);
        assert_eq!((&even).par_iter().count(), 50_000);
        assert_eq!(BitSetAnd(&odd, &even).par_iter().count(), 0);
    }

    #[test]
    fn par_iter_clusters() {
        use std::collections::HashSet;
        use std::sync::{Arc, Mutex};

        let mut set = BitSet::new();
        let mut check_set = HashSet::new();
        for x in 0..8 {
            let x = (x * 3) << (::BITS * 2);
            for y in 0..8 {
                let y = (y * 3) << (::BITS);
                for z in 0..8 {
                    let z = z * 2;
                    let index = x + y + z;
                    set.add(index);
                    check_set.insert(index);
                }
            }
        }
        let check_set = Arc::new(Mutex::new(check_set));
        let missing_set = Arc::new(Mutex::new(HashSet::new()));
        set.par_iter().for_each(|n| {
            let check_set = check_set.clone();
            let missing_set = missing_set.clone();
            let mut check = check_set.lock().unwrap();
            if !check.remove(&n) {
                let mut missing = missing_set.lock().unwrap();
                missing.insert(n);
            }
        });
        let check_set = check_set.lock().unwrap();
        let missing_set = missing_set.lock().unwrap();
        if !check_set.is_empty() && !missing_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}
                There were values that got iterated, but that shouldn't be: {:?}",
                *check_set, *missing_set
            );
        }
        if !check_set.is_empty() {
            panic!(
                "There were values that didn't get iterated: {:?}",
                *check_set
            );
        }
        if !missing_set.is_empty() {
            panic!(
                "There were values that got iterated, but that shouldn't be: {:?}",
                *missing_set
            );
        }
    }
}