thread_local/lib.rs

// Copyright 2017 Amanieu d'Antras
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Per-object thread-local storage
//!
//! This library provides the `ThreadLocal` type which allows a separate copy of
//! an object to be used for each thread. This allows for per-object
//! thread-local storage, unlike the standard library's `thread_local!` macro
//! which only allows static thread-local storage.
//!
//! Per-thread objects are not destroyed when a thread exits. Instead, objects
//! are only destroyed when the `ThreadLocal` containing them is destroyed.
//!
//! You can also iterate over the thread-local values of all threads in a
//! `ThreadLocal` object using the `iter_mut` and `into_iter` methods. This can
//! only be done if you have mutable access to the `ThreadLocal` object, which
//! guarantees that you are the only thread currently accessing it.
//!
//! Note that since thread IDs are recycled when a thread exits, it is possible
//! for one thread to retrieve the object of another thread. Since this can only
//! occur after a thread has exited, this does not lead to any race conditions.
//!
//! # Examples
//!
//! Basic usage of `ThreadLocal`:
//!
//! ```rust
//! use thread_local::ThreadLocal;
//! let tls: ThreadLocal<u32> = ThreadLocal::new();
//! assert_eq!(tls.get(), None);
//! assert_eq!(tls.get_or(|| 5), &5);
//! assert_eq!(tls.get(), Some(&5));
//! ```
//!
//! Combining thread-local values into a single result:
//!
//! ```rust
//! use thread_local::ThreadLocal;
//! use std::sync::Arc;
//! use std::cell::Cell;
//! use std::thread;
//!
//! let tls = Arc::new(ThreadLocal::new());
//!
//! // Create a bunch of threads to do stuff
//! for _ in 0..5 {
//!     let tls2 = tls.clone();
//!     thread::spawn(move || {
//!         // Increment a counter to count some event...
//!         let cell = tls2.get_or(|| Cell::new(0));
//!         cell.set(cell.get() + 1);
//!     }).join().unwrap();
//! }
//!
//! // Once all threads are done, collect the counter values and return the
//! // sum of all thread-local counter values.
//! let tls = Arc::try_unwrap(tls).unwrap();
//! let total = tls.into_iter().fold(0, |x, y| x + y.get());
//! assert_eq!(total, 5);
//! ```

#![warn(missing_docs)]
#![allow(clippy::mutex_atomic)]

mod cached;
mod thread_id;
mod unreachable;

#[allow(deprecated)]
pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};

use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
use std::mem;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use std::sync::Mutex;
use thread_id::Thread;
use unreachable::UncheckedResultExt;

// Use usize::BITS once it has stabilized and the MSRV has been bumped.
#[cfg(target_pointer_width = "16")]
const POINTER_WIDTH: u8 = 16;
#[cfg(target_pointer_width = "32")]
const POINTER_WIDTH: u8 = 32;
#[cfg(target_pointer_width = "64")]
const POINTER_WIDTH: u8 = 64;

/// The total number of buckets stored in each thread local. Buckets 0 and 1
/// each hold one entry and every later bucket doubles in size, so the
/// `POINTER_WIDTH + 1` buckets together hold `2^POINTER_WIDTH` entries, one
/// per possible thread ID.
const BUCKETS: usize = (POINTER_WIDTH + 1) as usize;

/// Thread-local variable wrapper
///
/// See the [module-level documentation](index.html) for more.
pub struct ThreadLocal<T: Send> {
    /// The buckets in the thread local. Bucket 0 contains a single element;
    /// the nth bucket for `n >= 1` contains `2^(n-1)` elements. Each bucket
    /// is lazily allocated.
    buckets: [AtomicPtr<Entry<T>>; BUCKETS],

    /// The number of values in the thread local. This can be less than the real number of values,
    /// but is never more.
    values: AtomicUsize,

    /// Lock used to guard against concurrent modifications. This is taken when
    /// there is a possibility of allocating a new bucket, which only occurs
    /// when inserting values.
    lock: Mutex<()>,
}

struct Entry<T> {
    present: AtomicBool,
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> Drop for Entry<T> {
    fn drop(&mut self) {
        unsafe {
            if *self.present.get_mut() {
                ptr::drop_in_place((*self.value.get()).as_mut_ptr());
            }
        }
    }
}

// ThreadLocal is always Sync, even if T isn't
unsafe impl<T: Send> Sync for ThreadLocal<T> {}

impl<T: Send> Default for ThreadLocal<T> {
    fn default() -> ThreadLocal<T> {
        ThreadLocal::new()
    }
}

impl<T: Send> Drop for ThreadLocal<T> {
    fn drop(&mut self) {
        let mut bucket_size = 1;

        // Free each non-null bucket
        for (i, bucket) in self.buckets.iter_mut().enumerate() {
            let bucket_ptr = *bucket.get_mut();

            let this_bucket_size = bucket_size;
            if i != 0 {
                bucket_size <<= 1;
            }

            if bucket_ptr.is_null() {
                continue;
            }

            unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket_ptr, this_bucket_size)) };
        }
    }
}

impl<T: Send> ThreadLocal<T> {
    /// Creates a new empty `ThreadLocal`.
    pub fn new() -> ThreadLocal<T> {
        Self::with_capacity(2)
    }

    /// Creates a new `ThreadLocal` with an initial capacity. If fewer than
    /// `capacity` threads access the thread local, it will never reallocate.
    /// The capacity may be rounded up to the nearest power of two.
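    ///
    /// # Examples
    ///
    /// A quick sketch, sizing up front for an expected number of threads
    /// (later threads still work; they just allocate a bucket on first
    /// access):
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::with_capacity(16);
    /// assert_eq!(tls.get(), None);
    /// ```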
    pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
        let allocated_buckets = capacity
            .checked_sub(1)
            .map(|c| usize::from(POINTER_WIDTH) - (c.leading_zeros() as usize) + 1)
            .unwrap_or(0);

        let mut buckets = [ptr::null_mut(); BUCKETS];
        let mut bucket_size = 1;
        for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() {
            *bucket = allocate_bucket::<T>(bucket_size);

            if i != 0 {
                bucket_size <<= 1;
            }
        }

        ThreadLocal {
            // Safety: AtomicPtr has the same representation as a pointer and arrays have the same
            // representation as a sequence of their inner type.
            buckets: unsafe { mem::transmute(buckets) },
            values: AtomicUsize::new(0),
            lock: Mutex::new(()),
        }
    }

    /// Returns the element for the current thread, if it exists.
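    ///
    /// # Examples
    ///
    /// A quick sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// assert_eq!(tls.get(), None);
    /// tls.get_or(|| 5);
    /// assert_eq!(tls.get(), Some(&5));
    /// ```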
    pub fn get(&self) -> Option<&T> {
        let thread = thread_id::get();
        self.get_inner(thread)
    }

    /// Returns the element for the current thread, or creates it if it doesn't
    /// exist.
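    ///
    /// # Examples
    ///
    /// A quick sketch; the closure runs at most once per thread:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// assert_eq!(tls.get_or(|| 5), &5);
    /// // Already initialized on this thread, so the closure is not run again.
    /// assert_eq!(tls.get_or(|| 99), &5);
    /// ```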
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> T,
    {
        unsafe {
            self.get_or_try(|| Ok::<T, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

    /// Returns the element for the current thread, or creates it if it doesn't
    /// exist. If `create` fails, that error is returned and no element is
    /// added.
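    ///
    /// # Examples
    ///
    /// A short sketch of fallible initialization:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// // A failed `create` leaves the slot empty...
    /// assert_eq!(tls.get_or_try(|| Err::<u32, ()>(())), Err(()));
    /// assert_eq!(tls.get(), None);
    /// // ...so a later attempt can still succeed.
    /// assert_eq!(tls.get_or_try(|| Ok::<u32, ()>(5)), Ok(&5));
    /// ```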
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let thread = thread_id::get();
        match self.get_inner(thread) {
            Some(x) => Ok(x),
            None => Ok(self.insert(thread, create()?)),
        }
    }

    fn get_inner(&self, thread: Thread) -> Option<&T> {
        let bucket_ptr =
            unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
        if bucket_ptr.is_null() {
            return None;
        }
        unsafe {
            let entry = &*bucket_ptr.add(thread.index);
            // Read without atomic operations as only this thread can set the value.
            if (&entry.present as *const _ as *const bool).read() {
                Some(&*(&*entry.value.get()).as_ptr())
            } else {
                None
            }
        }
    }

    #[cold]
    fn insert(&self, thread: Thread, data: T) -> &T {
        // Lock the Mutex to ensure only a single thread is allocating buckets at once
        let _guard = self.lock.lock().unwrap();

        let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };

        let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
        let bucket_ptr = if bucket_ptr.is_null() {
            // Allocate a new bucket
            let bucket_ptr = allocate_bucket(thread.bucket_size);
            bucket_atomic_ptr.store(bucket_ptr, Ordering::Release);
            bucket_ptr
        } else {
            bucket_ptr
        };

        drop(_guard);

        // Insert the new element into the bucket
        let entry = unsafe { &*bucket_ptr.add(thread.index) };
        let value_ptr = entry.value.get();
        unsafe { value_ptr.write(MaybeUninit::new(data)) };
        entry.present.store(true, Ordering::Release);

        self.values.fetch_add(1, Ordering::Release);

        unsafe { &*(&*value_ptr).as_ptr() }
    }

    /// Returns an iterator over the local values of all threads in unspecified
    /// order.
    ///
    /// This call can be done safely, as `T` is required to implement [`Sync`].
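    ///
    /// # Examples
    ///
    /// A small sketch summing the values written by several threads:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// use std::sync::Arc;
    /// use std::thread;
    ///
    /// let tls = Arc::new(ThreadLocal::new());
    /// for i in 0..3 {
    ///     let tls = Arc::clone(&tls);
    ///     thread::spawn(move || {
    ///         tls.get_or(|| i);
    ///     })
    ///     .join()
    ///     .unwrap();
    /// }
    /// // A shared borrow is enough because `i32: Sync`.
    /// assert_eq!(tls.iter().sum::<i32>(), 3);
    /// ```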
    pub fn iter(&self) -> Iter<'_, T>
    where
        T: Sync,
    {
        Iter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

    /// Returns a mutable iterator over the local values of all threads in
    /// unspecified order.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, this operation can
    /// be done safely---the mutable borrow statically guarantees no other
    /// threads are currently accessing their associated values.
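    ///
    /// # Examples
    ///
    /// A brief sketch, mutating every value once all other threads are done:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let mut tls: ThreadLocal<i32> = ThreadLocal::new();
    /// tls.get_or(|| 1);
    /// for value in tls.iter_mut() {
    ///     *value *= 10;
    /// }
    /// assert_eq!(tls.get(), Some(&10));
    /// ```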
    pub fn iter_mut(&mut self) -> IterMut<T> {
        IterMut {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

    /// Removes all thread-specific values from the `ThreadLocal`, effectively
    /// resetting it to its original state.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, this operation can
    /// be done safely---the mutable borrow statically guarantees no other
    /// threads are currently accessing their associated values.
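    ///
    /// # Examples
    ///
    /// A quick sketch; the old values are dropped:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let mut tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| 5);
    /// tls.clear();
    /// assert_eq!(tls.get(), None);
    /// ```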
    pub fn clear(&mut self) {
        *self = ThreadLocal::new();
    }
}

impl<T: Send> IntoIterator for ThreadLocal<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    fn into_iter(self) -> IntoIter<T> {
        IntoIter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }
}

impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

impl<'a, T: Send> IntoIterator for &'a mut ThreadLocal<T> {
    type Item = &'a mut T;
    type IntoIter = IterMut<'a, T>;

    fn into_iter(self) -> IterMut<'a, T> {
        self.iter_mut()
    }
}

impl<T: Send + Default> ThreadLocal<T> {
    /// Returns the element for the current thread, or creates a default one if
    /// it doesn't exist.
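    ///
    /// # Examples
    ///
    /// Equivalent to `get_or(Default::default)`:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// // `u32::default()` is 0.
    /// assert_eq!(tls.get_or_default(), &0);
    /// ```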
    pub fn get_or_default(&self) -> &T {
        self.get_or(Default::default)
    }
}

impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
    }
}

impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}

#[derive(Debug)]
struct RawIter {
    yielded: usize,
    bucket: usize,
    bucket_size: usize,
    index: usize,
}
impl RawIter {
    #[inline]
    fn new() -> Self {
        Self {
            yielded: 0,
            bucket: 0,
            bucket_size: 1,
            index: 0,
        }
    }

    fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
        while self.bucket < BUCKETS {
            let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
            let bucket = bucket.load(Ordering::Acquire);

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &*bucket.add(self.index) };
                    self.index += 1;
                    if entry.present.load(Ordering::Acquire) {
                        self.yielded += 1;
                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
                    }
                }
            }

            self.next_bucket();
        }
        None
    }
    fn next_mut<'a, T: Send>(
        &mut self,
        thread_local: &'a mut ThreadLocal<T>,
    ) -> Option<&'a mut Entry<T>> {
        if *thread_local.values.get_mut() == self.yielded {
            return None;
        }

        loop {
            let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
            let bucket = *bucket.get_mut();

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &mut *bucket.add(self.index) };
                    self.index += 1;
                    if *entry.present.get_mut() {
                        self.yielded += 1;
                        return Some(entry);
                    }
                }
            }

            self.next_bucket();
        }
    }

    #[inline]
    fn next_bucket(&mut self) {
        if self.bucket != 0 {
            self.bucket_size <<= 1;
        }
        self.bucket += 1;
        self.index = 0;
    }

    fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = thread_local.values.load(Ordering::Acquire);
        (total - self.yielded, None)
    }
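    /// Used by `IterMut` and `IntoIter`, which have exclusive access to the
    /// `ThreadLocal`: no thread can be inserting concurrently, so the value
    /// count is frozen, can be read non-atomically, and gives an exact bound.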
    fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
        let remaining = total - self.yielded;
        (remaining, Some(remaining))
    }
}

/// Iterator over the contents of a `ThreadLocal`.
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
    thread_local: &'a ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
    type Item = &'a T;
    fn next(&mut self) -> Option<Self::Item> {
        self.raw.next(self.thread_local)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint(self.thread_local)
    }
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}

/// Mutable iterator over the contents of a `ThreadLocal`.
pub struct IterMut<'a, T: Send> {
    thread_local: &'a mut ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<&'a mut T> {
        self.raw
            .next_mut(self.thread_local)
            .map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
impl<T: Send> FusedIterator for IterMut<'_, T> {}

// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
// this thread's value that potentially aliases with a mutable reference we have given out.
impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IterMut").field("raw", &self.raw).finish()
    }
}

/// An iterator that moves out of a `ThreadLocal`.
#[derive(Debug)]
pub struct IntoIter<T: Send> {
    thread_local: ThreadLocal<T>,
    raw: RawIter,
}

impl<T: Send> Iterator for IntoIter<T> {
    type Item = T;
    fn next(&mut self) -> Option<T> {
        self.raw.next_mut(&mut self.thread_local).map(|entry| {
            *entry.present.get_mut() = false;
            unsafe {
                std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
            }
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(&self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IntoIter<T> {}
impl<T: Send> FusedIterator for IntoIter<T> {}

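/// Allocates a bucket of `size` entries, all initially absent, and leaks the
/// boxed slice as a raw pointer; `ThreadLocal`'s `Drop` impl later rebuilds
/// the `Box` from this pointer to free it.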
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
    Box::into_raw(
        (0..size)
            .map(|_| Entry::<T> {
                present: AtomicBool::new(false),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect(),
    ) as *mut _
}

#[cfg(test)]
mod tests {
    use super::ThreadLocal;
    use std::cell::RefCell;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::sync::Arc;
    use std::thread;

    fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || count.fetch_add(1, Relaxed))
    }

    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        })
        .join()
        .unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn iter() {
        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || {
                tls3.get_or(|| Box::new(3));
            })
            .join()
            .unwrap();
            drop(tls2);
        })
        .join()
        .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();

        let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn test_drop() {
        let local = ThreadLocal::new();
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let dropped = Arc::new(AtomicUsize::new(0));
        local.get_or(|| Dropped(dropped.clone()));
        assert_eq!(dropped.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    #[test]
    fn is_sync() {
        fn foo<T: Sync>() {}
        foo::<ThreadLocal<String>>();
        foo::<ThreadLocal<RefCell<String>>>();
    }
}