use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LockLatch, SpinLatch};
use crate::log::Event::*;
use crate::log::Logger;
use crate::sleep::Sleep;
use crate::unwind;
use crate::{
    ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::any::Any;
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
#[allow(deprecated)]
use std::sync::atomic::ATOMIC_USIZE_INIT;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once};
use std::thread;
use std::usize;

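/// Thread builder used for customization via
/// `ThreadPoolBuilder::spawn_handler`.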
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_ref().map(String::as_str)
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self.worker, self.registry, self.index) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

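/// Generalized trait for spawning a thread in the `Registry`.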
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameter, and then
    /// execute its main loop with `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

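/// Spawns a thread in the "normal" way with `std::thread::Builder`.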
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

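/// Spawns a thread with a user's custom callback.
///
/// An illustrative sketch of how such a callback is supplied, assuming the
/// public `ThreadPoolBuilder::spawn_handler` API (not compiled as a doc-test):
///
/// ```ignore
/// let pool = ThreadPoolBuilder::new()
///     .spawn_handler(|thread| {
///         // `thread` is the `ThreadBuilder` defined above.
///         std::thread::Builder::new().spawn(|| thread.run())?;
///         Ok(())
///     })
///     .build()?;
/// ```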
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

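/// The shared state of a thread pool: the per-thread info (including the
/// "stealer" halves of the worker deques), the queue of injected jobs, and
/// the sleep/wake machinery.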
pub(super) struct Registry {
    logger: Logger,
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    panic_handler: Option<Box<PanicHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,

    /// Number of outstanding references that keep this registry alive; it
    /// starts at one for the registry's owner. `terminate()` decrements it,
    /// and when it reaches zero the worker threads are signaled to shut down.
    terminate_count: AtomicUsize,
}

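/// The global registry, lazily initialized at most once via
/// `THE_REGISTRY_SET` below.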
static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

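/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default configuration.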
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(|| Registry::new(ThreadPoolBuilder::new()))
        .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with the
/// given builder.
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

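/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.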
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry()
            .map(|registry: Arc<Registry>| unsafe { &*THE_REGISTRY.get_or_insert(registry) })
    });

    result
}

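/// Drop guard that terminates the registry if `Registry::new` exits early
/// (e.g. because spawning a thread failed); it is disarmed with
/// `mem::forget` on the success path.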
struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        let n_threads = builder.get_num_threads();
        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let logger = Logger::new(n_threads);
        let registry = Arc::new(Registry {
            logger: logger.clone(),
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(logger, n_threads),
            injected_jobs: Injector::new(),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, worker) in workers.into_iter().enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: registry.clone(),
                worker,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry.clone())
    }

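    /// Returns the registry of the current worker thread, or the global
    /// registry if the current thread is not a worker.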
    pub(super) fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().clone()
            } else {
                (*worker_thread).registry.clone()
            }
        }
    }

    /// Returns the number of threads in the current registry. This is
    /// better than `Registry::current().num_threads()` because it avoids
    /// incrementing the `Arc`.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it belongs to this registry.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.logger.log(event)
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn handle_panic(&self, err: Box<dyn Any + Send>) {
        match self.panic_handler {
            Some(ref handler) => {
                // If the customizable panic handler itself panics,
                // then we abort.
                let abort_guard = unwind::AbortIfPanic;
                handler(err);
                mem::forget(abort_guard);
            }
            None => {
                // Default panic handler aborts.
                let _ = unwind::AbortIfPanic; // let this drop.
            }
        }
    }

    /// Waits for the worker threads to get up and running. This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    #[cfg(test)]
    pub(super) fn wait_until_stopped(&self) {
        for info in &self.thread_infos {
            info.stopped.wait();
        }
    }

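    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Else, it will inject from the outside (which is slower).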
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(&[job_ref]);
            }
        }
    }

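    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that you
    /// are not on a worker of this registry.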
    pub(super) fn inject(&self, injected_jobs: &[JobRef]) {
        self.log(|| JobsInjected {
            count: injected_jobs.len(),
        });

        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        for &job_ref in injected_jobs {
            self.injected_jobs.push(job_ref);
        }

        self.sleep
            .new_injected_jobs(usize::MAX, injected_jobs.len() as u32, queue_was_empty);
    }

    fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self, worker_index: usize) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => {
                    self.log(|| JobUninjected {
                        worker: worker_index,
                    });
                    return Some(job);
                }
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

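    /// If already in a worker thread of this registry, just execute `op`.
    /// Otherwise, inject `op` into this thread pool and block until it is
    /// complete. Either way, return the result of `op`. The `bool` passed
    /// to `op` is true if the closure was injected (i.e. it ran on a
    /// different thread than the caller's).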
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread is not a worker, so it blocks on a lock latch
            // until the injected job completes.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                l,
            );
            self.inject(&[job.as_job_ref()]);
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.

            // flush accumulated logs as we exit the thread
            self.logger.log(|| Flush);

            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while waiting for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(&[job.as_job_ref()]);
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

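    /// Increments the terminate counter. This increment should be balanced
    /// by a call to `terminate`, which will decrement. This is used when
    /// spawning asynchronous work, which needs to prevent the registry from
    /// terminating so long as it is active.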
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(
            previous != std::usize::MAX,
            "overflow in registry ref count"
        );
    }

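    /// Signals that the thread pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.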
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                thread_info.terminate.set_and_tickle_one(self, i);
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been "set".
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the
    /// `Registry`, once the registry's main "terminate" counter
    /// reaches zero.
    terminate: CountLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: CountLatch::new(),
            stealer,
        }
    }
}

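/// One worker thread's view of the pool: the "worker" half of its own
/// deque, plus a handle back to the registry it belongs to.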
pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    registry: Arc<Registry>,
}

thread_local! {
    /// Pointer to the current worker thread, if the current thread is a
    /// worker; null otherwise. Set on entry to `main_loop` and cleared
    /// when the `WorkerThread` is dropped.
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = Cell::new(ptr::null());
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` index for the current thread; returns
    /// NULL if this is not a worker thread. This pointer is valid
    /// anywhere on the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `self` as the worker thread index for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.registry.logger.log(event)
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        self.log(|| JobPushed { worker: self.index });
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry
            .sleep
            .new_internal_jobs(self.index, 1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) unsafe fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            self.log(|| JobPopped { worker: self.index });
        }

        popped_job
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // The code below should swallow all panics and hence never
        // exit via a panic; but if it somehow does, abort the process.
        let abort_guard = unwind::AbortIfPanic;

        let mut idle_state = self.registry.sleep.start_looking(self.index, latch);
        while !latch.probe() {
            // Try to find some work to do. We give preference first
            // to things in our local deque, then in other workers
            // deques, and finally to injected jobs from the
            // outside. The idea is to finish what we started before
            // we take on something new.
            if let Some(job) = self
                .take_local_job()
                .or_else(|| self.steal())
                .or_else(|| self.registry.pop_injected_job(self.index))
            {
                self.registry.sleep.work_found(idle_state);
                self.execute(job);
                idle_state = self.registry.sleep.start_looking(self.index, latch);
            } else {
                self.registry
                    .sleep
                    .no_work_found(&mut idle_state, latch, || self.registry.has_injected_job())
            }
        }

        // If we were sleepy, we are not anymore. We "found work" --
        // whatever the surrounding thread was doing before it had to
        // wait.
        self.registry.sleep.work_found(idle_state);

        self.log(|| ThreadSawLatchSet {
            worker: self.index,
            latch_addr: latch.addr(),
        });
        mem::forget(abort_guard); // successful execution, do not abort
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    unsafe fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => {
                            self.log(|| JobStolen {
                                worker: self.index,
                                victim: victim_index,
                            });
                            Some(job)
                        }
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

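/// The main loop of a worker thread: registers itself as the current
/// worker, runs the optional start handler, processes jobs until
/// termination is signaled, then runs the optional exit handler.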
unsafe fn main_loop(worker: Worker<JobRef>, registry: Arc<Registry>, index: usize) {
    let worker_thread = &WorkerThread {
        worker,
        fifo: JobFifo::new(),
        index,
        rng: XorShift64Star::new(),
        registry: registry.clone(),
    };
    WorkerThread::set_current(worker_thread);

    // let registry know we are ready to do work
    registry.thread_infos[index].primed.set();

    // Worker threads should not panic. If they do, just abort, as the
    // whole process will be in an inconsistent state.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        let registry = registry.clone();
        match unwind::halt_unwinding(|| handler(index)) {
            Ok(()) => {}
            Err(err) => {
                registry.handle_panic(err);
            }
        }
    }

    let my_terminate_latch = &registry.thread_infos[index].terminate;
    worker_thread.log(|| ThreadStart {
        worker: index,
        terminate_addr: my_terminate_latch.as_core_latch().addr(),
    });
    worker_thread.wait_until(my_terminate_latch);

    // Should not be any work left in our queue.
    debug_assert!(worker_thread.take_local_job().is_none());

    // let registry know we are done
    registry.thread_infos[index].stopped.set();

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    worker_thread.log(|| ThreadTerminate { worker: index });

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        let registry = registry.clone();
        match unwind::halt_unwinding(|| handler(index)) {
            Ok(()) => {}
            Err(err) => {
                registry.handle_panic(err);
            }
        }
        // We're already exiting the thread, there's nothing else to do.
    }
}

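/// If already in a worker thread, just execute `op`. Otherwise, execute
/// `op` in the default thread pool. Either way, return the result of `op`.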
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker_cold(op)
        }
    }
}

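/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*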
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            #[allow(deprecated)]
            static COUNTER: AtomicUsize = ATOMIC_USIZE_INIT;
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}