1#![cfg(feature = "std")]
2
3use alloc::{Allocator, SliceWrapper};
4use core::marker::PhantomData;
5use core::mem;
6use std;
7use std::sync::RwLock;
9use std::thread::JoinHandle;
10
11use crate::enc::backward_references::UnionHasher;
12use crate::enc::threading::{
13 AnyBoxConstructor, BatchSpawnable, BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti,
14 CompressionThreadResult, InternalOwned, InternalSendAlloc, Joinable, Owned, OwnedRetriever,
15 PoisonedThreadError, SendAlloc,
16};
17use crate::enc::{BrotliAlloc, BrotliEncoderParams};
18
19pub struct MultiThreadedJoinable<T: Send + 'static, U: Send + 'static>(
20 JoinHandle<T>,
21 PhantomData<U>,
22);
23
24impl<T: Send + 'static, U: Send + 'static + AnyBoxConstructor> Joinable<T, U>
25 for MultiThreadedJoinable<T, U>
26{
27 fn join(self) -> Result<T, U> {
28 match self.0.join() {
29 Ok(t) => Ok(t),
30 Err(e) => Err(<U as AnyBoxConstructor>::new(e)),
31 }
32 }
33}
34
35pub struct MultiThreadedOwnedRetriever<U: Send + 'static>(RwLock<U>);
36
37impl<U: Send + 'static> OwnedRetriever<U> for MultiThreadedOwnedRetriever<U> {
38 fn view<T, F: FnOnce(&U) -> T>(&self, f: F) -> Result<T, PoisonedThreadError> {
39 match self.0.read() {
40 Ok(u) => Ok(f(&*u)),
41 Err(_) => Err(PoisonedThreadError::default()),
42 }
43 }
44 fn unwrap(self) -> Result<U, PoisonedThreadError> {
45 match self.0.into_inner() {
46 Ok(u) => Ok(u),
47 Err(_) => Err(PoisonedThreadError::default()),
48 }
49 }
50}
51
/// Spawner that runs each unit of work it is given on a new `std::thread` worker.
///
/// The type holds no state. It therefore derives `Clone`/`Copy` and also
/// `Debug`, as public types should.
#[derive(Clone, Copy, Debug, Default)]
pub struct MultiThreadedSpawner {}
54
55fn spawn_work<
56 ReturnValue: Send + 'static,
57 ExtraInput: Send + 'static,
58 F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static,
59 Alloc: BrotliAlloc + Send + 'static,
60 U: Send + 'static + Sync,
61>(
62 extra_input: ExtraInput,
63 index: usize,
64 num_threads: usize,
65 locked_input: std::sync::Arc<RwLock<U>>,
66 alloc: Alloc,
67 f: F,
68) -> std::thread::JoinHandle<ReturnValue>
69where
70 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
71{
72 std::thread::spawn(move || {
73 let t: ReturnValue = locked_input
74 .view(move |guard: &U| -> ReturnValue {
75 f(extra_input, index, num_threads, guard, alloc)
76 })
77 .unwrap();
78 t
79 })
80}
81
82impl<
83 ReturnValue: Send + 'static,
84 ExtraInput: Send + 'static,
85 Alloc: BrotliAlloc + Send + 'static,
86 U: Send + 'static + Sync,
87 > BatchSpawnable<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
88where
89 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
90{
91 type JoinHandle = MultiThreadedJoinable<ReturnValue, BrotliEncoderThreadError>;
92 type FinalJoinHandle = std::sync::Arc<RwLock<U>>;
93 fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
94 std::sync::Arc::<RwLock<U>>::new(RwLock::new(
95 mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
96 ))
97 }
98 fn spawn<F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static + Copy>(
99 &mut self,
100 input: &mut Self::FinalJoinHandle,
101 work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
102 index: usize,
103 num_threads: usize,
104 f: F,
105 ) {
106 let (alloc, extra_input) = work.replace_with_default();
107 let ret = spawn_work(extra_input, index, num_threads, input.clone(), alloc, f);
108 *work = SendAlloc(InternalSendAlloc::Join(MultiThreadedJoinable(
109 ret,
110 PhantomData,
111 )));
112 }
113}
114impl<
115 ReturnValue: Send + 'static,
116 ExtraInput: Send + 'static,
117 Alloc: BrotliAlloc + Send + 'static,
118 U: Send + 'static + Sync,
119 > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
120where
121 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
122 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
123 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
124{
125 type JoinHandle =
126 <MultiThreadedSpawner as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::JoinHandle;
127 type FinalJoinHandle = <MultiThreadedSpawner as BatchSpawnable<
128 ReturnValue,
129 ExtraInput,
130 Alloc,
131 U,
132 >>::FinalJoinHandle;
133 fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
134 <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::make_spawner(self, input)
135 }
136 fn spawn(
137 &mut self,
138 handle: &mut Self::FinalJoinHandle,
139 alloc_per_thread: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
140 index: usize,
141 num_threads: usize,
142 f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
143 ) {
144 <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::spawn(
145 self,
146 handle,
147 alloc_per_thread,
148 index,
149 num_threads,
150 f,
151 )
152 }
153}
154
155pub fn compress_multi<
156 Alloc: BrotliAlloc + Send + 'static,
157 SliceW: SliceWrapper<u8> + Send + 'static + Sync,
158>(
159 params: &BrotliEncoderParams,
160 owned_input: &mut Owned<SliceW>,
161 output: &mut [u8],
162 alloc_per_thread: &mut [SendAlloc<
163 CompressionThreadResult<Alloc>,
164 UnionHasher<Alloc>,
165 Alloc,
166 <MultiThreadedSpawner as BatchSpawnable<
167 CompressionThreadResult<Alloc>,
168 UnionHasher<Alloc>,
169 Alloc,
170 SliceW,
171 >>::JoinHandle,
172 >],
173) -> Result<usize, BrotliEncoderThreadError>
174where
175 <Alloc as Allocator<u8>>::AllocatedMemory: Send,
176 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
177 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
178{
179 CompressMulti(
180 params,
181 owned_input,
182 output,
183 alloc_per_thread,
184 &mut MultiThreadedSpawner::default(),
185 )
186}