im/nodes/
hamt.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5use std::borrow::Borrow;
6use std::fmt;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::iter::FusedIterator;
9use std::slice::{Iter as SliceIter, IterMut as SliceIterMut};
10use std::{mem, ptr};
11
12use bitmaps::Bits;
13use sized_chunks::sparse_chunk::{Iter as ChunkIter, IterMut as ChunkIterMut, SparseChunk};
14use typenum::{Pow, Unsigned, U2};
15
16use crate::config::HashLevelSize;
17use crate::util::{clone_ref, Pool, PoolClone, PoolDefault, PoolRef, Ref};
18
/// Number of slots in one trie node, as a type-level integer: 2 raised to
/// the power `HashLevelSize`.
pub(crate) type HashWidth = <U2 as Pow<HashLevelSize>>::Output;
/// Integer type that hashes are stored in throughout this module: the
/// backing store `bitmaps` picks for a bitmap of `HashWidth` bits.
pub(crate) type HashBits = <HashWidth as Bits>::Store; // a uint of HASH_WIDTH bits
/// Number of hash bits consumed per trie level; `shift` grows by this much
/// on every descent.
pub(crate) const HASH_SHIFT: usize = HashLevelSize::USIZE;
/// Number of slots in one node, as a runtime value. Also used as the limit
/// that `shift + HASH_SHIFT` is compared against to detect the deepest level.
pub(crate) const HASH_WIDTH: usize = HashWidth::USIZE;
/// Mask keeping the low `HASH_SHIFT` bits of a shifted hash (`HASH_WIDTH` is
/// a power of two, so `HASH_WIDTH - 1` has exactly those bits set).
pub(crate) const HASH_MASK: HashBits = (HASH_WIDTH - 1) as HashBits;
24
25pub(crate) fn hash_key<K: Hash + ?Sized, S: BuildHasher>(bh: &S, key: &K) -> HashBits {
26    let mut hasher = bh.build_hasher();
27    key.hash(&mut hasher);
28    hasher.finish() as HashBits
29}
30
31#[inline]
32fn mask(hash: HashBits, shift: usize) -> HashBits {
33    hash >> shift & HASH_MASK
34}
35
/// A value that can be stored in the trie. The nodes in this module never
/// see keys separately from values; they ask each stored value for its key.
pub trait HashValue {
    /// Key type. Lookups, inserts and removals compare keys with `==`.
    type Key: Eq;

    /// Returns a reference to the key carried by this value.
    fn extract_key(&self) -> &Self::Key;
    /// Identity comparison between two values.
    // NOTE(review): not called anywhere in this file; presumably reports
    // whether both values share the same underlying storage — confirm
    // against the implementors.
    fn ptr_eq(&self, other: &Self) -> bool;
}
42
/// A branch of the trie: a sparse array with room for `HashWidth` entries.
/// The slot an item lives in is chosen by `mask(hash, shift)` for the node's
/// level.
#[derive(Clone)]
pub(crate) struct Node<A> {
    // Only occupied slots are stored; each one is a value, a collision
    // bucket or a child node (see `Entry`).
    data: SparseChunk<Entry<A>, HashWidth>,
}
47
// Implements the crate's `PoolDefault` hook for `Node`; presumably used by
// the memory pool to construct an empty node in place — confirm in
// `crate::util`.
#[allow(unsafe_code)]
impl<A> PoolDefault for Node<A> {
    /// Initialises `target` by delegating to `SparseChunk::default_uninit`.
    ///
    /// Only compiled when the `pool` feature is enabled.
    #[cfg(feature = "pool")]
    unsafe fn default_uninit(target: &mut mem::MaybeUninit<Self>) {
        // The pointer to the uninitialised `Node` is reinterpreted as a
        // pointer to its uninitialised `data` chunk, relying on `data`
        // being the struct's only field.
        // NOTE(review): `Node` has no `#[repr(transparent)]` in this file;
        // confirm the layout assumption behind the cast.
        // `as_mut()` only yields `None` for a null pointer; this one comes
        // from a reference, so the `unwrap()` cannot fail.
        SparseChunk::default_uninit(
            target
                .as_mut_ptr()
                .cast::<mem::MaybeUninit<SparseChunk<Entry<A>, HashWidth>>>()
                .as_mut()
                .unwrap(),
        )
    }
}
61
// Implements the crate's `PoolClone` hook for `Node`; presumably used by
// the memory pool to clone a node straight into pooled storage — confirm
// in `crate::util`.
#[allow(unsafe_code)]
impl<A> PoolClone for Node<A>
where
    A: Clone,
{
    /// Clones `self` into `target` by delegating to the `clone_uninit` of
    /// the inner `SparseChunk`.
    ///
    /// Only compiled when the `pool` feature is enabled.
    #[cfg(feature = "pool")]
    unsafe fn clone_uninit(&self, target: &mut mem::MaybeUninit<Self>) {
        // Same pointer reinterpretation as `default_uninit`: the
        // uninitialised `Node` is treated as its uninitialised `data`
        // chunk, relying on `data` being the only field.
        // NOTE(review): `Node` has no `#[repr(transparent)]` in this file;
        // confirm the layout assumption behind the cast.
        // The pointer is derived from a reference, so `as_mut()` never
        // returns `None` and the `unwrap()` cannot fail.
        self.data.clone_uninit(
            target
                .as_mut_ptr()
                .cast::<mem::MaybeUninit<SparseChunk<Entry<A>, HashWidth>>>()
                .as_mut()
                .unwrap(),
        )
    }
}
78
/// Bucket for values with different keys that still share a slot at the
/// deepest trie level (it is created where `shift + HASH_SHIFT >= HASH_WIDTH`,
/// i.e. where no further level can separate them).
#[derive(Clone)]
pub(crate) struct CollisionNode<A> {
    // Hash given to `CollisionNode::new`; reported for every value the
    // bucket yields through `pop` and the iterators.
    hash: HashBits,
    // The colliding values; searched linearly by key.
    data: Vec<A>,
}
84
/// Contents of one occupied slot in a `Node`.
pub(crate) enum Entry<A> {
    /// A single value stored together with the hash of its key.
    Value(A, HashBits),
    /// A reference-counted bucket of values that collided at the deepest level.
    Collision(Ref<CollisionNode<A>>),
    /// A pool-allocated child node one level further down the trie.
    Node(PoolRef<Node<A>>),
}
90
91impl<A: Clone> Clone for Entry<A> {
92    fn clone(&self) -> Self {
93        match self {
94            Entry::Value(value, hash) => Entry::Value(value.clone(), *hash),
95            Entry::Collision(coll) => Entry::Collision(coll.clone()),
96            Entry::Node(node) => Entry::Node(node.clone()),
97        }
98    }
99}
100
101impl<A> Entry<A> {
102    fn is_value(&self) -> bool {
103        matches!(self, Entry::Value(_, _))
104    }
105
106    fn unwrap_value(self) -> A {
107        match self {
108            Entry::Value(a, _) => a,
109            _ => panic!("nodes::hamt::Entry::unwrap_value: unwrapped a non-value"),
110        }
111    }
112
113    fn from_node(pool: &Pool<Node<A>>, node: Node<A>) -> Self {
114        Entry::Node(PoolRef::new(pool, node))
115    }
116}
117
118impl<A> From<CollisionNode<A>> for Entry<A> {
119    fn from(node: CollisionNode<A>) -> Self {
120        Entry::Collision(Ref::new(node))
121    }
122}
123
124impl<A> Default for Node<A> {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130impl<A> Node<A> {
131    #[inline]
132    pub(crate) fn new() -> Self {
133        Node {
134            data: SparseChunk::new(),
135        }
136    }
137
138    #[inline]
139    fn len(&self) -> usize {
140        self.data.len()
141    }
142
143    #[inline]
144    pub(crate) fn unit(index: usize, value: Entry<A>) -> Self {
145        Node {
146            data: SparseChunk::unit(index, value),
147        }
148    }
149
150    #[inline]
151    pub(crate) fn pair(index1: usize, value1: Entry<A>, index2: usize, value2: Entry<A>) -> Self {
152        Node {
153            data: SparseChunk::pair(index1, value1, index2, value2),
154        }
155    }
156
157    #[inline]
158    pub(crate) fn single_child(pool: &Pool<Node<A>>, index: usize, node: Self) -> Self {
159        Node {
160            data: SparseChunk::unit(index, Entry::from_node(pool, node)),
161        }
162    }
163
164    fn pop(&mut self) -> Entry<A> {
165        self.data.pop().unwrap()
166    }
167}
168
169impl<A: HashValue> Node<A> {
170    fn merge_values(
171        pool: &Pool<Node<A>>,
172        value1: A,
173        hash1: HashBits,
174        value2: A,
175        hash2: HashBits,
176        shift: usize,
177    ) -> Self {
178        let index1 = mask(hash1, shift) as usize;
179        let index2 = mask(hash2, shift) as usize;
180        if index1 != index2 {
181            // Both values fit on the same level.
182            Node::pair(
183                index1,
184                Entry::Value(value1, hash1),
185                index2,
186                Entry::Value(value2, hash2),
187            )
188        } else if shift + HASH_SHIFT >= HASH_WIDTH {
189            // If we're at the bottom, we've got a collision.
190            Node::unit(
191                index1,
192                Entry::from(CollisionNode::new(hash1, value1, value2)),
193            )
194        } else {
195            // Pass the values down a level.
196            let node = Node::merge_values(pool, value1, hash1, value2, hash2, shift + HASH_SHIFT);
197            Node::single_child(pool, index1, node)
198        }
199    }
200
201    pub(crate) fn get<BK>(&self, hash: HashBits, shift: usize, key: &BK) -> Option<&A>
202    where
203        BK: Eq + ?Sized,
204        A::Key: Borrow<BK>,
205    {
206        let index = mask(hash, shift) as usize;
207        if let Some(entry) = self.data.get(index) {
208            match entry {
209                Entry::Value(ref value, _) => {
210                    if key == value.extract_key().borrow() {
211                        Some(value)
212                    } else {
213                        None
214                    }
215                }
216                Entry::Collision(ref coll) => coll.get(key),
217                Entry::Node(ref child) => child.get(hash, shift + HASH_SHIFT, key),
218            }
219        } else {
220            None
221        }
222    }
223
224    pub(crate) fn get_mut<BK>(
225        &mut self,
226        pool: &Pool<Node<A>>,
227        hash: HashBits,
228        shift: usize,
229        key: &BK,
230    ) -> Option<&mut A>
231    where
232        A: Clone,
233        BK: Eq + ?Sized,
234        A::Key: Borrow<BK>,
235    {
236        let index = mask(hash, shift) as usize;
237        if let Some(entry) = self.data.get_mut(index) {
238            match entry {
239                Entry::Value(ref mut value, _) => {
240                    if key == value.extract_key().borrow() {
241                        Some(value)
242                    } else {
243                        None
244                    }
245                }
246                Entry::Collision(ref mut coll_ref) => {
247                    let coll = Ref::make_mut(coll_ref);
248                    coll.get_mut(key)
249                }
250                Entry::Node(ref mut child_ref) => {
251                    let child = PoolRef::make_mut(pool, child_ref);
252                    child.get_mut(pool, hash, shift + HASH_SHIFT, key)
253                }
254            }
255        } else {
256            None
257        }
258    }
259
    /// Inserts `value` into the subtree rooted at this node.
    ///
    /// `hash` selects the slot at each level and `shift` is the bit offset
    /// of this node's level. Returns the value that was previously stored
    /// under an equal key, or `None` if the key is new.
    pub(crate) fn insert(
        &mut self,
        pool: &Pool<Node<A>>,
        hash: HashBits,
        shift: usize,
        value: A,
    ) -> Option<A>
    where
        A: Clone,
    {
        let index = mask(hash, shift) as usize;
        if let Some(entry) = self.data.get_mut(index) {
            // Set when the slot holds a value with the same key, which is
            // then replaced by the plain insert at the bottom of the function.
            let mut fallthrough = false;
            // Value is here
            match entry {
                // Update value or create a subtree
                Entry::Value(ref current, _) => {
                    if current.extract_key() == value.extract_key() {
                        // If we have a key match, fall through to the outer
                        // level where we replace the current value. If we
                        // don't, fall through to the inner level where we merge
                        // some nodes.
                        fallthrough = true;
                    }
                }
                // There's already a collision here.
                Entry::Collision(ref mut collision) => {
                    let coll = Ref::make_mut(collision);
                    return coll.insert(value);
                }
                Entry::Node(ref mut child_ref) => {
                    // Child node
                    let child = PoolRef::make_mut(pool, child_ref);
                    return child.insert(pool, hash, shift + HASH_SHIFT, value);
                }
            }
            if !fallthrough {
                // If we get here, we're looking at a value entry that needs a merge.
                // We're going to be unsafe and pry it out of the reference, trusting
                // that we overwrite it with the merged node.
                // SAFETY: `entry` is a live `&mut Entry<A>`, so reading through
                // it is valid. After the read the slot must be treated as
                // moved-out: both reachable branches below `ptr::write` a
                // replacement before returning, and `entry` is neither read
                // nor dropped in between.
                // NOTE(review): a panic between the read and the write would
                // leave the old value owned twice — confirm that nothing
                // called in between can unwind.
                #[allow(unsafe_code)]
                let old_entry = unsafe { ptr::read(entry) };
                if shift + HASH_SHIFT >= HASH_WIDTH {
                    // We're at the lowest level, need to set up a collision node.
                    let coll = CollisionNode::new(hash, old_entry.unwrap_value(), value);
                    // SAFETY: overwrites the moved-out slot without dropping
                    // its stale contents.
                    #[allow(unsafe_code)]
                    unsafe {
                        ptr::write(entry, Entry::from(coll))
                    };
                } else if let Entry::Value(old_value, old_hash) = old_entry {
                    // Deeper levels exist: build a subtree holding both values,
                    // starting one level down.
                    let node = Node::merge_values(
                        pool,
                        old_value,
                        old_hash,
                        value,
                        hash,
                        shift + HASH_SHIFT,
                    );
                    // SAFETY: overwrites the moved-out slot without dropping
                    // its stale contents.
                    #[allow(unsafe_code)]
                    unsafe {
                        ptr::write(entry, Entry::from_node(pool, node))
                    };
                } else {
                    // The collision and child-node arms above return early,
                    // so only a value entry can reach this point.
                    unreachable!()
                }
                return None;
            }
        }
        // If we get here, either we found nothing at this index, in which case
        // we insert a new entry, or we hit a value entry with the same key, in
        // which case we replace it.
        self.data
            .insert(index, Entry::Value(value, hash))
            .map(Entry::unwrap_value)
    }
335
    /// Removes the value stored under `key` from the subtree rooted at this
    /// node.
    ///
    /// `hash` selects the slot at each level and `shift` is the bit offset
    /// of this node's level. Returns the removed value, or `None` if the
    /// key was not present.
    ///
    /// After a removal the tree is compacted on the way back up: a collision
    /// bucket left with one value becomes a plain value entry, and a child
    /// node left with a single value entry is replaced by that value.
    pub(crate) fn remove<BK>(
        &mut self,
        pool: &Pool<Node<A>>,
        hash: HashBits,
        shift: usize,
        key: &BK,
    ) -> Option<A>
    where
        A: Clone,
        BK: Eq + ?Sized,
        A::Key: Borrow<BK>,
    {
        let index = mask(hash, shift) as usize;
        // Entry that should overwrite the slot at `index` once the borrow of
        // the current entry has ended.
        let mut new_node = None;
        // Value removed by a nested call, returned together with the overwrite.
        let mut removed = None;
        if let Some(entry) = self.data.get_mut(index) {
            match entry {
                Entry::Value(ref value, _) => {
                    if key != value.extract_key().borrow() {
                        // Key wasn't in the map.
                        return None;
                    } // Otherwise, fall through to the removal.
                }
                Entry::Collision(ref mut coll_ref) => {
                    // `make_mut` runs before it is known whether the key is
                    // in the bucket.
                    let coll = Ref::make_mut(coll_ref);
                    removed = coll.remove(key);
                    if coll.len() == 1 {
                        // A bucket of one is no longer a collision: turn the
                        // survivor back into a plain value entry.
                        new_node = Some(coll.pop());
                    } else {
                        return removed;
                    }
                }
                Entry::Node(ref mut child_ref) => {
                    // `make_mut` runs before it is known whether the key is
                    // in the child.
                    let child = PoolRef::make_mut(pool, child_ref);
                    match child.remove(pool, hash, shift + HASH_SHIFT, key) {
                        None => {
                            return None;
                        }
                        Some(value) => {
                            if child.len() == 1
                                && child.data[child.data.first_index().unwrap()].is_value()
                            {
                                // If the child now contains only a single value node,
                                // pull it up one level and discard the child.
                                removed = Some(value);
                                new_node = Some(child.pop());
                            } else {
                                return Some(value);
                            }
                        }
                    }
                }
            }
        }
        if let Some(node) = new_node {
            // Overwrite the bucket or child with the entry pulled up from it.
            self.data.insert(index, node);
            return removed;
        }
        // Either the slot was empty (this yields `None`) or it held the
        // matching value, which is removed and returned here.
        self.data.remove(index).map(Entry::unwrap_value)
    }
396}
397
398impl<A: HashValue> CollisionNode<A> {
399    fn new(hash: HashBits, value1: A, value2: A) -> Self {
400        CollisionNode {
401            hash,
402            data: vec![value1, value2],
403        }
404    }
405
406    #[inline]
407    fn len(&self) -> usize {
408        self.data.len()
409    }
410
411    fn get<BK>(&self, key: &BK) -> Option<&A>
412    where
413        BK: Eq + ?Sized,
414        A::Key: Borrow<BK>,
415    {
416        for entry in &self.data {
417            if key == entry.extract_key().borrow() {
418                return Some(entry);
419            }
420        }
421        None
422    }
423
424    fn get_mut<BK>(&mut self, key: &BK) -> Option<&mut A>
425    where
426        BK: Eq + ?Sized,
427        A::Key: Borrow<BK>,
428    {
429        for entry in &mut self.data {
430            if key == entry.extract_key().borrow() {
431                return Some(entry);
432            }
433        }
434        None
435    }
436
437    fn insert(&mut self, value: A) -> Option<A> {
438        for item in &mut self.data {
439            if value.extract_key() == item.extract_key() {
440                return Some(mem::replace(item, value));
441            }
442        }
443        self.data.push(value);
444        None
445    }
446
447    fn remove<BK>(&mut self, key: &BK) -> Option<A>
448    where
449        BK: Eq + ?Sized,
450        A::Key: Borrow<BK>,
451    {
452        let mut loc = None;
453        for (index, item) in self.data.iter().enumerate() {
454            if key == item.extract_key().borrow() {
455                loc = Some(index);
456            }
457        }
458        if let Some(index) = loc {
459            Some(self.data.remove(index))
460        } else {
461            None
462        }
463    }
464
465    fn pop(&mut self) -> Entry<A> {
466        Entry::Value(self.data.pop().unwrap(), self.hash)
467    }
468}
469
470// Ref iterator
471
/// Borrowing iterator over every value in a trie, yielding each value
/// together with its hash.
pub(crate) struct Iter<'a, A> {
    // Number of values still to be yielded; `next` stops once it reaches zero.
    count: usize,
    // Suspended iterators of the ancestor nodes, innermost last.
    stack: Vec<ChunkIter<'a, Entry<A>, HashWidth>>,
    // Iterator over the entries of the node currently being walked.
    current: ChunkIter<'a, Entry<A>, HashWidth>,
    // Collision bucket currently being walked, with the hash shared by its values.
    collision: Option<(HashBits, SliceIter<'a, A>)>,
}
478
479impl<'a, A> Iter<'a, A>
480where
481    A: 'a,
482{
483    pub(crate) fn new(root: &'a Node<A>, size: usize) -> Self {
484        Iter {
485            count: size,
486            stack: Vec::with_capacity((HASH_WIDTH / HASH_SHIFT) + 1),
487            current: root.data.iter(),
488            collision: None,
489        }
490    }
491}
492
493impl<'a, A> Iterator for Iter<'a, A>
494where
495    A: 'a,
496{
497    type Item = (&'a A, HashBits);
498
499    fn next(&mut self) -> Option<Self::Item> {
500        if self.count == 0 {
501            return None;
502        }
503        if self.collision.is_some() {
504            if let Some((hash, ref mut coll)) = self.collision {
505                match coll.next() {
506                    None => {}
507                    Some(value) => {
508                        self.count -= 1;
509                        return Some((value, hash));
510                    }
511                }
512            }
513            self.collision = None;
514            return self.next();
515        }
516        match self.current.next() {
517            Some(Entry::Value(value, hash)) => {
518                self.count -= 1;
519                Some((value, *hash))
520            }
521            Some(Entry::Node(child)) => {
522                let current = mem::replace(&mut self.current, child.data.iter());
523                self.stack.push(current);
524                self.next()
525            }
526            Some(Entry::Collision(coll)) => {
527                self.collision = Some((coll.hash, coll.data.iter()));
528                self.next()
529            }
530            None => match self.stack.pop() {
531                None => None,
532                Some(iter) => {
533                    self.current = iter;
534                    self.next()
535                }
536            },
537        }
538    }
539
540    fn size_hint(&self) -> (usize, Option<usize>) {
541        (self.count, Some(self.count))
542    }
543}
544
// `size_hint` returns `count` as both its lower and upper bound.
// NOTE(review): exactness assumes the `size` passed to `Iter::new` equals
// the number of values in the trie — confirm at the call sites.
impl<'a, A> ExactSizeIterator for Iter<'a, A> where A: 'a {}

// `next` returns `None` without touching the tree whenever `count` is zero.
impl<'a, A> FusedIterator for Iter<'a, A> where A: 'a {}
548
549// Mut ref iterator
550
/// Mutably borrowing iterator over every value in a trie, yielding each
/// value together with its hash.
pub(crate) struct IterMut<'a, A> {
    // Number of values still to be yielded; `next` stops once it reaches zero.
    count: usize,
    // Pool handle passed to `PoolRef::make_mut` when descending into a child.
    pool: Pool<Node<A>>,
    // Suspended iterators of the ancestor nodes, innermost last.
    stack: Vec<ChunkIterMut<'a, Entry<A>, HashWidth>>,
    // Iterator over the entries of the node currently being walked.
    current: ChunkIterMut<'a, Entry<A>, HashWidth>,
    // Collision bucket currently being walked, with the hash shared by its values.
    collision: Option<(HashBits, SliceIterMut<'a, A>)>,
}
558
559impl<'a, A> IterMut<'a, A>
560where
561    A: 'a,
562{
563    pub(crate) fn new(pool: &Pool<Node<A>>, root: &'a mut Node<A>, size: usize) -> Self {
564        IterMut {
565            count: size,
566            pool: pool.clone(),
567            stack: Vec::with_capacity((HASH_WIDTH / HASH_SHIFT) + 1),
568            current: root.data.iter_mut(),
569            collision: None,
570        }
571    }
572}
573
574impl<'a, A> Iterator for IterMut<'a, A>
575where
576    A: Clone + 'a,
577{
578    type Item = (&'a mut A, HashBits);
579
580    fn next(&mut self) -> Option<Self::Item> {
581        if self.count == 0 {
582            return None;
583        }
584        if self.collision.is_some() {
585            if let Some((hash, ref mut coll)) = self.collision {
586                match coll.next() {
587                    None => {}
588                    Some(value) => {
589                        self.count -= 1;
590                        return Some((value, hash));
591                    }
592                }
593            }
594            self.collision = None;
595            return self.next();
596        }
597        match self.current.next() {
598            Some(Entry::Value(value, hash)) => {
599                self.count -= 1;
600                Some((value, *hash))
601            }
602            Some(Entry::Node(child_ref)) => {
603                let child = PoolRef::make_mut(&self.pool, child_ref);
604                let current = mem::replace(&mut self.current, child.data.iter_mut());
605                self.stack.push(current);
606                self.next()
607            }
608            Some(Entry::Collision(coll_ref)) => {
609                let coll = Ref::make_mut(coll_ref);
610                self.collision = Some((coll.hash, coll.data.iter_mut()));
611                self.next()
612            }
613            None => match self.stack.pop() {
614                None => None,
615                Some(iter) => {
616                    self.current = iter;
617                    self.next()
618                }
619            },
620        }
621    }
622
623    fn size_hint(&self) -> (usize, Option<usize>) {
624        (self.count, Some(self.count))
625    }
626}
627
// `size_hint` returns `count` as both its lower and upper bound.
// NOTE(review): exactness assumes the `size` passed to `IterMut::new`
// equals the number of values in the trie — confirm at the call sites.
impl<'a, A> ExactSizeIterator for IterMut<'a, A> where A: Clone + 'a {}

// `next` returns `None` without touching the tree whenever `count` is zero.
impl<'a, A> FusedIterator for IterMut<'a, A> where A: Clone + 'a {}
631
632// Consuming iterator
633
/// Consuming iterator: takes ownership of a trie root and yields every
/// value by value, together with its hash.
pub(crate) struct Drain<A>
where
    A: HashValue,
{
    // Number of values still to be yielded; `next` stops once it reaches zero.
    count: usize,
    // Pool handle passed to `PoolRef::make_mut` before popping from a node.
    pool: Pool<Node<A>>,
    // Partially drained ancestor nodes, innermost last.
    stack: Vec<PoolRef<Node<A>>>,
    // Node currently being emptied.
    current: PoolRef<Node<A>>,
    // Collision bucket currently being emptied.
    collision: Option<CollisionNode<A>>,
}
644
645impl<A> Drain<A>
646where
647    A: HashValue,
648{
649    pub(crate) fn new(pool: &Pool<Node<A>>, root: PoolRef<Node<A>>, size: usize) -> Self {
650        Drain {
651            count: size,
652            pool: pool.clone(),
653            stack: vec![],
654            current: root,
655            collision: None,
656        }
657    }
658}
659
660impl<A> Iterator for Drain<A>
661where
662    A: HashValue + Clone,
663{
664    type Item = (A, HashBits);
665
666    fn next(&mut self) -> Option<Self::Item> {
667        if self.count == 0 {
668            return None;
669        }
670        if self.collision.is_some() {
671            if let Some(ref mut coll) = self.collision {
672                if let Some(value) = coll.data.pop() {
673                    self.count -= 1;
674                    return Some((value, coll.hash));
675                }
676            }
677            self.collision = None;
678            return self.next();
679        }
680        match PoolRef::make_mut(&self.pool, &mut self.current).data.pop() {
681            Some(Entry::Value(value, hash)) => {
682                self.count -= 1;
683                Some((value, hash))
684            }
685            Some(Entry::Collision(coll_ref)) => {
686                self.collision = Some(clone_ref(coll_ref));
687                self.next()
688            }
689            Some(Entry::Node(child)) => {
690                let parent = mem::replace(&mut self.current, child);
691                self.stack.push(parent);
692                self.next()
693            }
694            None => match self.stack.pop() {
695                None => None,
696                Some(parent) => {
697                    self.current = parent;
698                    self.next()
699                }
700            },
701        }
702    }
703
704    fn size_hint(&self) -> (usize, Option<usize>) {
705        (self.count, Some(self.count))
706    }
707}
708
// `size_hint` returns `count` as both its lower and upper bound.
// NOTE(review): exactness assumes the `size` passed to `Drain::new` equals
// the number of values in the trie — confirm at the call sites.
impl<A: HashValue> ExactSizeIterator for Drain<A> where A: Clone {}

// `next` returns `None` without touching the tree whenever `count` is zero.
impl<A: HashValue> FusedIterator for Drain<A> where A: Clone {}
712
713impl<A: HashValue + fmt::Debug> fmt::Debug for Node<A> {
714    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
715        write!(f, "Node[ ")?;
716        for i in self.data.indices() {
717            write!(f, "{}: ", i)?;
718            match &self.data[i] {
719                Entry::Value(v, h) => write!(f, "{:?} :: {}, ", v, h)?,
720                Entry::Collision(c) => write!(f, "Coll{:?} :: {}", c.data, c.hash)?,
721                Entry::Node(n) => write!(f, "{:?}, ", n)?,
722            }
723        }
724        write!(f, " ]")
725    }
726}