1#![cfg(feature = "std")]
2
3use alloc::{Allocator, SliceWrapper};
4use core::mem;
5use std;
6use std::sync::RwLock;
8use std::sync::{Arc, Condvar, Mutex};
9
10use crate::enc::backward_references::UnionHasher;
11use crate::enc::fixed_queue::{FixedQueue, MAX_THREADS};
12use crate::enc::threading::{
13 BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti, CompressionThreadResult,
14 InternalOwned, InternalSendAlloc, Joinable, Owned, SendAlloc,
15};
16use crate::enc::{BrotliAlloc, BrotliEncoderParams};
17
/// Outcome of one finished job, tagged with the id of the request that produced it.
struct JobReply<T>
where
    T: Send + 'static,
{
    /// Value returned by the job function.
    result: T,
    /// Id copied from the originating request; `join` looks replies up by this id.
    work_id: u64,
}
22
23struct JobRequest<
24 ReturnValue: Send + 'static,
25 ExtraInput: Send + 'static,
26 Alloc: BrotliAlloc + Send + 'static,
27 U: Send + 'static + Sync,
28> {
29 func: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
30 extra_input: ExtraInput,
31 index: usize,
32 thread_size: usize,
33 data: Arc<RwLock<U>>,
34 alloc: Alloc,
35 work_id: u64,
36}
37
38struct WorkQueue<
39 ReturnValue: Send + 'static,
40 ExtraInput: Send + 'static,
41 Alloc: BrotliAlloc + Send + 'static,
42 U: Send + 'static + Sync,
43> {
44 jobs: FixedQueue<JobRequest<ReturnValue, ExtraInput, Alloc, U>>,
45 results: FixedQueue<JobReply<ReturnValue>>,
46 shutdown: bool,
47 immediate_shutdown: bool,
48 num_in_progress: usize,
49 cur_work_id: u64,
50}
51impl<
52 ReturnValue: Send + 'static,
53 ExtraInput: Send + 'static,
54 Alloc: BrotliAlloc + Send + 'static,
55 U: Send + 'static + Sync,
56 > Default for WorkQueue<ReturnValue, ExtraInput, Alloc, U>
57{
58 fn default() -> Self {
59 WorkQueue {
60 jobs: FixedQueue::default(),
61 results: FixedQueue::default(),
62 num_in_progress: 0,
63 immediate_shutdown: false,
64 shutdown: false,
65 cur_work_id: 0,
66 }
67 }
68}
69
70pub struct GuardedQueue<
71 ReturnValue: Send + 'static,
72 ExtraInput: Send + 'static,
73 Alloc: BrotliAlloc + Send + 'static,
74 U: Send + 'static + Sync,
75>(Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>);
76pub struct WorkerPool<
77 ReturnValue: Send + 'static,
78 ExtraInput: Send + 'static,
79 Alloc: BrotliAlloc + Send + 'static,
80 U: Send + 'static + Sync,
81> {
82 queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
83 join: [Option<std::thread::JoinHandle<()>>; MAX_THREADS],
84}
85
86impl<
87 ReturnValue: Send + 'static,
88 ExtraInput: Send + 'static,
89 Alloc: BrotliAlloc + Send + 'static,
90 U: Send + 'static + Sync,
91 > Drop for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
92{
93 fn drop(&mut self) {
94 {
95 let (lock, cvar) = &*self.queue.0;
96 let mut local_queue = lock.lock().unwrap();
97 local_queue.immediate_shutdown = true;
98 cvar.notify_all();
99 }
100 for thread_handle in self.join.iter_mut() {
101 if let Some(th) = thread_handle.take() {
102 th.join().unwrap();
103 }
104 }
105 }
106}
107impl<
108 ReturnValue: Send + 'static,
109 ExtraInput: Send + 'static,
110 Alloc: BrotliAlloc + Send + 'static,
111 U: Send + 'static + Sync,
112 > WorkerPool<ReturnValue, ExtraInput, Alloc, U>
113{
114 fn do_work(queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>) {
115 loop {
116 let ret;
117 {
118 let possible_job;
123 {
124 let (lock, cvar) = &*queue;
125 let mut local_queue = lock.lock().unwrap();
126 if local_queue.immediate_shutdown {
127 break;
128 }
129 possible_job = if let Some(res) = local_queue.jobs.pop() {
130 cvar.notify_all();
131 local_queue.num_in_progress += 1;
132 res
133 } else if local_queue.shutdown {
134 break;
135 } else {
136 let _lock = cvar.wait(local_queue); continue;
138 };
139 }
140 ret = if let Ok(job_data) = possible_job.data.read() {
141 JobReply {
142 result: (possible_job.func)(
143 possible_job.extra_input,
144 possible_job.index,
145 possible_job.thread_size,
146 &*job_data,
147 possible_job.alloc,
148 ),
149 work_id: possible_job.work_id,
150 }
151 } else {
152 break; };
154 }
155 {
156 let (lock, cvar) = &*queue;
157 let mut local_queue = lock.lock().unwrap();
158 local_queue.num_in_progress -= 1;
159 local_queue.results.push(ret).unwrap();
160 cvar.notify_all();
161 }
162 }
163 }
164 fn _push_job(&mut self, job: JobRequest<ReturnValue, ExtraInput, Alloc, U>) {
165 let (lock, cvar) = &*self.queue.0;
166 let mut local_queue = lock.lock().unwrap();
167 loop {
168 if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
169 < MAX_THREADS
170 {
171 local_queue.jobs.push(job).unwrap();
172 cvar.notify_all();
173 break;
174 }
175 local_queue = cvar.wait(local_queue).unwrap();
176 }
177 }
178 fn _try_push_job(
179 &mut self,
180 job: JobRequest<ReturnValue, ExtraInput, Alloc, U>,
181 ) -> Result<(), JobRequest<ReturnValue, ExtraInput, Alloc, U>> {
182 let (lock, cvar) = &*self.queue.0;
183 let mut local_queue = lock.lock().unwrap();
184 if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
185 < MAX_THREADS
186 {
187 local_queue.jobs.push(job).unwrap();
188 cvar.notify_all();
189 Ok(())
190 } else {
191 Err(job)
192 }
193 }
194 fn start(
195 queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>,
196 ) -> std::thread::JoinHandle<()> {
197 std::thread::spawn(move || Self::do_work(queue))
198 }
199 pub fn new(num_threads: usize) -> Self {
200 let queue = Arc::new((Mutex::new(WorkQueue::default()), Condvar::new()));
201 WorkerPool {
202 queue: GuardedQueue(queue.clone()),
203 join: [
204 Some(Self::start(queue.clone())),
205 if 1 < num_threads {
206 Some(Self::start(queue.clone()))
207 } else {
208 None
209 },
210 if 2 < num_threads {
211 Some(Self::start(queue.clone()))
212 } else {
213 None
214 },
215 if 3 < num_threads {
216 Some(Self::start(queue.clone()))
217 } else {
218 None
219 },
220 if 4 < num_threads {
221 Some(Self::start(queue.clone()))
222 } else {
223 None
224 },
225 if 5 < num_threads {
226 Some(Self::start(queue.clone()))
227 } else {
228 None
229 },
230 if 6 < num_threads {
231 Some(Self::start(queue.clone()))
232 } else {
233 None
234 },
235 if 7 < num_threads {
236 Some(Self::start(queue.clone()))
237 } else {
238 None
239 },
240 if 8 < num_threads {
241 Some(Self::start(queue.clone()))
242 } else {
243 None
244 },
245 if 9 < num_threads {
246 Some(Self::start(queue.clone()))
247 } else {
248 None
249 },
250 if 10 < num_threads {
251 Some(Self::start(queue.clone()))
252 } else {
253 None
254 },
255 if 11 < num_threads {
256 Some(Self::start(queue.clone()))
257 } else {
258 None
259 },
260 if 12 < num_threads {
261 Some(Self::start(queue.clone()))
262 } else {
263 None
264 },
265 if 13 < num_threads {
266 Some(Self::start(queue.clone()))
267 } else {
268 None
269 },
270 if 14 < num_threads {
271 Some(Self::start(queue.clone()))
272 } else {
273 None
274 },
275 if 15 < num_threads {
276 Some(Self::start(queue.clone()))
277 } else {
278 None
279 },
280 ],
281 }
282 }
283}
284
285pub fn new_work_pool<
286 Alloc: BrotliAlloc + Send + 'static,
287 SliceW: SliceWrapper<u8> + Send + 'static + Sync,
288>(
289 num_threads: usize,
290) -> WorkerPool<
291 CompressionThreadResult<Alloc>,
292 UnionHasher<Alloc>,
293 Alloc,
294 (SliceW, BrotliEncoderParams),
295>
296where
297 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
298 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
299 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
300{
301 WorkerPool::new(num_threads)
302}
303
304pub struct WorkerJoinable<
305 ReturnValue: Send + 'static,
306 ExtraInput: Send + 'static,
307 Alloc: BrotliAlloc + Send + 'static,
308 U: Send + 'static + Sync,
309> {
310 queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
311 work_id: u64,
312}
313impl<
314 ReturnValue: Send + 'static,
315 ExtraInput: Send + 'static,
316 Alloc: BrotliAlloc + Send + 'static,
317 U: Send + 'static + Sync,
318 > Joinable<ReturnValue, BrotliEncoderThreadError>
319 for WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>
320{
321 fn join(self) -> Result<ReturnValue, BrotliEncoderThreadError> {
322 let (lock, cvar) = &*self.queue.0;
323 let mut local_queue = lock.lock().unwrap();
324 loop {
325 match local_queue
326 .results
327 .remove(|data: &Option<JobReply<ReturnValue>>| {
328 if let Some(ref item) = *data {
329 item.work_id == self.work_id
330 } else {
331 false
332 }
333 }) {
334 Some(matched) => return Ok(matched.result),
335 None => local_queue = cvar.wait(local_queue).unwrap(),
336 };
337 }
338 }
339}
340
341impl<
342 ReturnValue: Send + 'static,
343 ExtraInput: Send + 'static,
344 Alloc: BrotliAlloc + Send + 'static,
345 U: Send + 'static + Sync,
346 > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U>
347 for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
348where
349 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
350 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
351 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
352{
353 type FinalJoinHandle = Arc<RwLock<U>>;
354 type JoinHandle = WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>;
355
356 fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
357 std::sync::Arc::<RwLock<U>>::new(RwLock::new(
358 mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
359 ))
360 }
361 fn spawn(
362 &mut self,
363 locked_input: &mut Self::FinalJoinHandle,
364 work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
365 index: usize,
366 num_threads: usize,
367 f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
368 ) {
369 assert!(num_threads <= MAX_THREADS);
370 let (lock, cvar) = &*self.queue.0;
371 let mut local_queue = lock.lock().unwrap();
372 loop {
373 if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
374 <= MAX_THREADS
375 {
376 let work_id = local_queue.cur_work_id;
377 local_queue.cur_work_id += 1;
378 let (local_alloc, local_extra) = work.replace_with_default();
379 local_queue
380 .jobs
381 .push(JobRequest {
382 func: f,
383 extra_input: local_extra,
384 index,
385 thread_size: num_threads,
386 data: locked_input.clone(),
387 alloc: local_alloc,
388 work_id,
389 })
390 .unwrap();
391 *work = SendAlloc(InternalSendAlloc::Join(WorkerJoinable {
392 queue: GuardedQueue(self.queue.0.clone()),
393 work_id,
394 }));
395 cvar.notify_all();
396 break;
397 } else {
398 local_queue = cvar.wait(local_queue).unwrap(); }
400 }
401 }
402}
403
404pub fn compress_worker_pool<
405 Alloc: BrotliAlloc + Send + 'static,
406 SliceW: SliceWrapper<u8> + Send + 'static + Sync,
407>(
408 params: &BrotliEncoderParams,
409 owned_input: &mut Owned<SliceW>,
410 output: &mut [u8],
411 alloc_per_thread: &mut [SendAlloc<
412 CompressionThreadResult<Alloc>,
413 UnionHasher<Alloc>,
414 Alloc,
415 <WorkerPool<
416 CompressionThreadResult<Alloc>,
417 UnionHasher<Alloc>,
418 Alloc,
419 (SliceW, BrotliEncoderParams),
420 > as BatchSpawnableLite<
421 CompressionThreadResult<Alloc>,
422 UnionHasher<Alloc>,
423 Alloc,
424 (SliceW, BrotliEncoderParams),
425 >>::JoinHandle,
426 >],
427 work_pool: &mut WorkerPool<
428 CompressionThreadResult<Alloc>,
429 UnionHasher<Alloc>,
430 Alloc,
431 (SliceW, BrotliEncoderParams),
432 >,
433) -> Result<usize, BrotliEncoderThreadError>
434where
435 <Alloc as Allocator<u8>>::AllocatedMemory: Send,
436 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
437 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
438{
439 CompressMulti(params, owned_input, output, alloc_per_thread, work_pool)
440}
441
442