brotli/enc/worker_pool.rs

1#![cfg(feature = "std")]
2
3use alloc::{Allocator, SliceWrapper};
4use core::mem;
5use std;
6// in-place thread create
7use std::sync::RwLock;
8use std::sync::{Arc, Condvar, Mutex};
9
10use crate::enc::backward_references::UnionHasher;
11use crate::enc::fixed_queue::{FixedQueue, MAX_THREADS};
12use crate::enc::threading::{
13    BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti, CompressionThreadResult,
14    InternalOwned, InternalSendAlloc, Joinable, Owned, SendAlloc,
15};
16use crate::enc::{BrotliAlloc, BrotliEncoderParams};
17
/// The output of one finished job, tagged with the `work_id` of the request
/// that produced it so the matching joiner can find it in the results queue.
struct JobReply<T>
where
    T: Send + 'static,
{
    /// Value returned by the job's function.
    result: T,
    /// Copied from the originating `JobRequest::work_id`.
    work_id: u64,
}
22
23struct JobRequest<
24    ReturnValue: Send + 'static,
25    ExtraInput: Send + 'static,
26    Alloc: BrotliAlloc + Send + 'static,
27    U: Send + 'static + Sync,
28> {
29    func: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
30    extra_input: ExtraInput,
31    index: usize,
32    thread_size: usize,
33    data: Arc<RwLock<U>>,
34    alloc: Alloc,
35    work_id: u64,
36}
37
38struct WorkQueue<
39    ReturnValue: Send + 'static,
40    ExtraInput: Send + 'static,
41    Alloc: BrotliAlloc + Send + 'static,
42    U: Send + 'static + Sync,
43> {
44    jobs: FixedQueue<JobRequest<ReturnValue, ExtraInput, Alloc, U>>,
45    results: FixedQueue<JobReply<ReturnValue>>,
46    shutdown: bool,
47    immediate_shutdown: bool,
48    num_in_progress: usize,
49    cur_work_id: u64,
50}
51impl<
52        ReturnValue: Send + 'static,
53        ExtraInput: Send + 'static,
54        Alloc: BrotliAlloc + Send + 'static,
55        U: Send + 'static + Sync,
56    > Default for WorkQueue<ReturnValue, ExtraInput, Alloc, U>
57{
58    fn default() -> Self {
59        WorkQueue {
60            jobs: FixedQueue::default(),
61            results: FixedQueue::default(),
62            num_in_progress: 0,
63            immediate_shutdown: false,
64            shutdown: false,
65            cur_work_id: 0,
66        }
67    }
68}
69
70pub struct GuardedQueue<
71    ReturnValue: Send + 'static,
72    ExtraInput: Send + 'static,
73    Alloc: BrotliAlloc + Send + 'static,
74    U: Send + 'static + Sync,
75>(Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>);
76pub struct WorkerPool<
77    ReturnValue: Send + 'static,
78    ExtraInput: Send + 'static,
79    Alloc: BrotliAlloc + Send + 'static,
80    U: Send + 'static + Sync,
81> {
82    queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
83    join: [Option<std::thread::JoinHandle<()>>; MAX_THREADS],
84}
85
86impl<
87        ReturnValue: Send + 'static,
88        ExtraInput: Send + 'static,
89        Alloc: BrotliAlloc + Send + 'static,
90        U: Send + 'static + Sync,
91    > Drop for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
92{
93    fn drop(&mut self) {
94        {
95            let (lock, cvar) = &*self.queue.0;
96            let mut local_queue = lock.lock().unwrap();
97            local_queue.immediate_shutdown = true;
98            cvar.notify_all();
99        }
100        for thread_handle in self.join.iter_mut() {
101            if let Some(th) = thread_handle.take() {
102                th.join().unwrap();
103            }
104        }
105    }
106}
107impl<
108        ReturnValue: Send + 'static,
109        ExtraInput: Send + 'static,
110        Alloc: BrotliAlloc + Send + 'static,
111        U: Send + 'static + Sync,
112    > WorkerPool<ReturnValue, ExtraInput, Alloc, U>
113{
114    fn do_work(queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>) {
115        loop {
116            let ret;
117            {
118                // need to drop possible job before the final lock is taken,
119                // so refcount of possible_job Arc is 0 by the time the job is delivered
120                // to the caller. We basically need a barrier (the lock) to happen
121                // after the destructor that decrefs possible_job
122                let possible_job;
123                {
124                    let (lock, cvar) = &*queue;
125                    let mut local_queue = lock.lock().unwrap();
126                    if local_queue.immediate_shutdown {
127                        break;
128                    }
129                    possible_job = if let Some(res) = local_queue.jobs.pop() {
130                        cvar.notify_all();
131                        local_queue.num_in_progress += 1;
132                        res
133                    } else if local_queue.shutdown {
134                        break;
135                    } else {
136                        let _lock = cvar.wait(local_queue); // unlock immediately, unfortunately
137                        continue;
138                    };
139                }
140                ret = if let Ok(job_data) = possible_job.data.read() {
141                    JobReply {
142                        result: (possible_job.func)(
143                            possible_job.extra_input,
144                            possible_job.index,
145                            possible_job.thread_size,
146                            &*job_data,
147                            possible_job.alloc,
148                        ),
149                        work_id: possible_job.work_id,
150                    }
151                } else {
152                    break; // poisoned lock
153                };
154            }
155            {
156                let (lock, cvar) = &*queue;
157                let mut local_queue = lock.lock().unwrap();
158                local_queue.num_in_progress -= 1;
159                local_queue.results.push(ret).unwrap();
160                cvar.notify_all();
161            }
162        }
163    }
164    fn _push_job(&mut self, job: JobRequest<ReturnValue, ExtraInput, Alloc, U>) {
165        let (lock, cvar) = &*self.queue.0;
166        let mut local_queue = lock.lock().unwrap();
167        loop {
168            if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
169                < MAX_THREADS
170            {
171                local_queue.jobs.push(job).unwrap();
172                cvar.notify_all();
173                break;
174            }
175            local_queue = cvar.wait(local_queue).unwrap();
176        }
177    }
178    fn _try_push_job(
179        &mut self,
180        job: JobRequest<ReturnValue, ExtraInput, Alloc, U>,
181    ) -> Result<(), JobRequest<ReturnValue, ExtraInput, Alloc, U>> {
182        let (lock, cvar) = &*self.queue.0;
183        let mut local_queue = lock.lock().unwrap();
184        if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
185            < MAX_THREADS
186        {
187            local_queue.jobs.push(job).unwrap();
188            cvar.notify_all();
189            Ok(())
190        } else {
191            Err(job)
192        }
193    }
194    fn start(
195        queue: Arc<(Mutex<WorkQueue<ReturnValue, ExtraInput, Alloc, U>>, Condvar)>,
196    ) -> std::thread::JoinHandle<()> {
197        std::thread::spawn(move || Self::do_work(queue))
198    }
199    pub fn new(num_threads: usize) -> Self {
200        let queue = Arc::new((Mutex::new(WorkQueue::default()), Condvar::new()));
201        WorkerPool {
202            queue: GuardedQueue(queue.clone()),
203            join: [
204                Some(Self::start(queue.clone())),
205                if 1 < num_threads {
206                    Some(Self::start(queue.clone()))
207                } else {
208                    None
209                },
210                if 2 < num_threads {
211                    Some(Self::start(queue.clone()))
212                } else {
213                    None
214                },
215                if 3 < num_threads {
216                    Some(Self::start(queue.clone()))
217                } else {
218                    None
219                },
220                if 4 < num_threads {
221                    Some(Self::start(queue.clone()))
222                } else {
223                    None
224                },
225                if 5 < num_threads {
226                    Some(Self::start(queue.clone()))
227                } else {
228                    None
229                },
230                if 6 < num_threads {
231                    Some(Self::start(queue.clone()))
232                } else {
233                    None
234                },
235                if 7 < num_threads {
236                    Some(Self::start(queue.clone()))
237                } else {
238                    None
239                },
240                if 8 < num_threads {
241                    Some(Self::start(queue.clone()))
242                } else {
243                    None
244                },
245                if 9 < num_threads {
246                    Some(Self::start(queue.clone()))
247                } else {
248                    None
249                },
250                if 10 < num_threads {
251                    Some(Self::start(queue.clone()))
252                } else {
253                    None
254                },
255                if 11 < num_threads {
256                    Some(Self::start(queue.clone()))
257                } else {
258                    None
259                },
260                if 12 < num_threads {
261                    Some(Self::start(queue.clone()))
262                } else {
263                    None
264                },
265                if 13 < num_threads {
266                    Some(Self::start(queue.clone()))
267                } else {
268                    None
269                },
270                if 14 < num_threads {
271                    Some(Self::start(queue.clone()))
272                } else {
273                    None
274                },
275                if 15 < num_threads {
276                    Some(Self::start(queue.clone()))
277                } else {
278                    None
279                },
280            ],
281        }
282    }
283}
284
285pub fn new_work_pool<
286    Alloc: BrotliAlloc + Send + 'static,
287    SliceW: SliceWrapper<u8> + Send + 'static + Sync,
288>(
289    num_threads: usize,
290) -> WorkerPool<
291    CompressionThreadResult<Alloc>,
292    UnionHasher<Alloc>,
293    Alloc,
294    (SliceW, BrotliEncoderParams),
295>
296where
297    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
298    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
299    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
300{
301    WorkerPool::new(num_threads)
302}
303
304pub struct WorkerJoinable<
305    ReturnValue: Send + 'static,
306    ExtraInput: Send + 'static,
307    Alloc: BrotliAlloc + Send + 'static,
308    U: Send + 'static + Sync,
309> {
310    queue: GuardedQueue<ReturnValue, ExtraInput, Alloc, U>,
311    work_id: u64,
312}
313impl<
314        ReturnValue: Send + 'static,
315        ExtraInput: Send + 'static,
316        Alloc: BrotliAlloc + Send + 'static,
317        U: Send + 'static + Sync,
318    > Joinable<ReturnValue, BrotliEncoderThreadError>
319    for WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>
320{
321    fn join(self) -> Result<ReturnValue, BrotliEncoderThreadError> {
322        let (lock, cvar) = &*self.queue.0;
323        let mut local_queue = lock.lock().unwrap();
324        loop {
325            match local_queue
326                .results
327                .remove(|data: &Option<JobReply<ReturnValue>>| {
328                    if let Some(ref item) = *data {
329                        item.work_id == self.work_id
330                    } else {
331                        false
332                    }
333                }) {
334                Some(matched) => return Ok(matched.result),
335                None => local_queue = cvar.wait(local_queue).unwrap(),
336            };
337        }
338    }
339}
340
341impl<
342        ReturnValue: Send + 'static,
343        ExtraInput: Send + 'static,
344        Alloc: BrotliAlloc + Send + 'static,
345        U: Send + 'static + Sync,
346    > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U>
347    for WorkerPool<ReturnValue, ExtraInput, Alloc, U>
348where
349    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
350    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
351    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
352{
353    type FinalJoinHandle = Arc<RwLock<U>>;
354    type JoinHandle = WorkerJoinable<ReturnValue, ExtraInput, Alloc, U>;
355
356    fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
357        std::sync::Arc::<RwLock<U>>::new(RwLock::new(
358            mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
359        ))
360    }
361    fn spawn(
362        &mut self,
363        locked_input: &mut Self::FinalJoinHandle,
364        work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
365        index: usize,
366        num_threads: usize,
367        f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
368    ) {
369        assert!(num_threads <= MAX_THREADS);
370        let (lock, cvar) = &*self.queue.0;
371        let mut local_queue = lock.lock().unwrap();
372        loop {
373            if local_queue.jobs.size() + local_queue.num_in_progress + local_queue.results.size()
374                <= MAX_THREADS
375            {
376                let work_id = local_queue.cur_work_id;
377                local_queue.cur_work_id += 1;
378                let (local_alloc, local_extra) = work.replace_with_default();
379                local_queue
380                    .jobs
381                    .push(JobRequest {
382                        func: f,
383                        extra_input: local_extra,
384                        index,
385                        thread_size: num_threads,
386                        data: locked_input.clone(),
387                        alloc: local_alloc,
388                        work_id,
389                    })
390                    .unwrap();
391                *work = SendAlloc(InternalSendAlloc::Join(WorkerJoinable {
392                    queue: GuardedQueue(self.queue.0.clone()),
393                    work_id,
394                }));
395                cvar.notify_all();
396                break;
397            } else {
398                local_queue = cvar.wait(local_queue).unwrap(); // hope room frees up
399            }
400        }
401    }
402}
403
404pub fn compress_worker_pool<
405    Alloc: BrotliAlloc + Send + 'static,
406    SliceW: SliceWrapper<u8> + Send + 'static + Sync,
407>(
408    params: &BrotliEncoderParams,
409    owned_input: &mut Owned<SliceW>,
410    output: &mut [u8],
411    alloc_per_thread: &mut [SendAlloc<
412        CompressionThreadResult<Alloc>,
413        UnionHasher<Alloc>,
414        Alloc,
415        <WorkerPool<
416            CompressionThreadResult<Alloc>,
417            UnionHasher<Alloc>,
418            Alloc,
419            (SliceW, BrotliEncoderParams),
420        > as BatchSpawnableLite<
421            CompressionThreadResult<Alloc>,
422            UnionHasher<Alloc>,
423            Alloc,
424            (SliceW, BrotliEncoderParams),
425        >>::JoinHandle,
426    >],
427    work_pool: &mut WorkerPool<
428        CompressionThreadResult<Alloc>,
429        UnionHasher<Alloc>,
430        Alloc,
431        (SliceW, BrotliEncoderParams),
432    >,
433) -> Result<usize, BrotliEncoderThreadError>
434where
435    <Alloc as Allocator<u8>>::AllocatedMemory: Send,
436    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
437    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
438{
439    CompressMulti(params, owned_input, output, alloc_per_thread, work_pool)
440}
441
442// out of place thread create