brotli/enc/multithreading.rs

1#![cfg(feature = "std")]
2
3use alloc::{Allocator, SliceWrapper};
4use core::marker::PhantomData;
5use core::mem;
6use std;
7// in-place thread create
8use std::sync::RwLock;
9use std::thread::JoinHandle;
10
11use crate::enc::backward_references::UnionHasher;
12use crate::enc::threading::{
13    AnyBoxConstructor, BatchSpawnable, BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti,
14    CompressionThreadResult, InternalOwned, InternalSendAlloc, Joinable, Owned, OwnedRetriever,
15    PoisonedThreadError, SendAlloc,
16};
17use crate::enc::{BrotliAlloc, BrotliEncoderParams};
18
/// Handle to a worker thread started by [`MultiThreadedSpawner`].
///
/// The first field is the thread's [`JoinHandle`], which yields `T` on success.
/// `U` is stored only as a type-level marker: it names the error type that
/// joining produces (see the `Joinable<T, U>` implementation).
pub struct MultiThreadedJoinable<T, U>(JoinHandle<T>, PhantomData<U>)
where
    T: Send + 'static,
    U: Send + 'static;
23
24impl<T: Send + 'static, U: Send + 'static + AnyBoxConstructor> Joinable<T, U>
25    for MultiThreadedJoinable<T, U>
26{
27    fn join(self) -> Result<T, U> {
28        match self.0.join() {
29            Ok(t) => Ok(t),
30            Err(e) => Err(<U as AnyBoxConstructor>::new(e)),
31        }
32    }
33}
34
/// Holds a value of type `U` behind an [`RwLock`].
///
/// `view` borrows the value under a read lock. `unwrap` consumes the lock and
/// hands the value back.
pub struct MultiThreadedOwnedRetriever<U>(RwLock<U>)
where
    U: Send + 'static;
36
37impl<U: Send + 'static> OwnedRetriever<U> for MultiThreadedOwnedRetriever<U> {
38    fn view<T, F: FnOnce(&U) -> T>(&self, f: F) -> Result<T, PoisonedThreadError> {
39        match self.0.read() {
40            Ok(u) => Ok(f(&*u)),
41            Err(_) => Err(PoisonedThreadError::default()),
42        }
43    }
44    fn unwrap(self) -> Result<U, PoisonedThreadError> {
45        match self.0.into_inner() {
46            Ok(u) => Ok(u),
47            Err(_) => Err(PoisonedThreadError::default()),
48        }
49    }
50}
51
/// Spawner that runs each work item on its own thread started with
/// [`std::thread::spawn`].
///
/// It carries no state, so it is trivially `Copy` and has a `Debug` form for
/// diagnostics.
#[derive(Clone, Copy, Debug, Default)]
pub struct MultiThreadedSpawner {}
54
55fn spawn_work<
56    ReturnValue: Send + 'static,
57    ExtraInput: Send + 'static,
58    F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static,
59    Alloc: BrotliAlloc + Send + 'static,
60    U: Send + 'static + Sync,
61>(
62    extra_input: ExtraInput,
63    index: usize,
64    num_threads: usize,
65    locked_input: std::sync::Arc<RwLock<U>>,
66    alloc: Alloc,
67    f: F,
68) -> std::thread::JoinHandle<ReturnValue>
69where
70    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
71{
72    std::thread::spawn(move || {
73        let t: ReturnValue = locked_input
74            .view(move |guard: &U| -> ReturnValue {
75                f(extra_input, index, num_threads, guard, alloc)
76            })
77            .unwrap();
78        t
79    })
80}
81
82impl<
83        ReturnValue: Send + 'static,
84        ExtraInput: Send + 'static,
85        Alloc: BrotliAlloc + Send + 'static,
86        U: Send + 'static + Sync,
87    > BatchSpawnable<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
88where
89    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
90{
91    type JoinHandle = MultiThreadedJoinable<ReturnValue, BrotliEncoderThreadError>;
92    type FinalJoinHandle = std::sync::Arc<RwLock<U>>;
93    fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
94        std::sync::Arc::<RwLock<U>>::new(RwLock::new(
95            mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
96        ))
97    }
98    fn spawn<F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static + Copy>(
99        &mut self,
100        input: &mut Self::FinalJoinHandle,
101        work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
102        index: usize,
103        num_threads: usize,
104        f: F,
105    ) {
106        let (alloc, extra_input) = work.replace_with_default();
107        let ret = spawn_work(extra_input, index, num_threads, input.clone(), alloc, f);
108        *work = SendAlloc(InternalSendAlloc::Join(MultiThreadedJoinable(
109            ret,
110            PhantomData,
111        )));
112    }
113}
114impl<
115        ReturnValue: Send + 'static,
116        ExtraInput: Send + 'static,
117        Alloc: BrotliAlloc + Send + 'static,
118        U: Send + 'static + Sync,
119    > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
120where
121    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
122    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
123    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
124{
125    type JoinHandle =
126        <MultiThreadedSpawner as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::JoinHandle;
127    type FinalJoinHandle = <MultiThreadedSpawner as BatchSpawnable<
128        ReturnValue,
129        ExtraInput,
130        Alloc,
131        U,
132    >>::FinalJoinHandle;
133    fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
134        <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::make_spawner(self, input)
135    }
136    fn spawn(
137        &mut self,
138        handle: &mut Self::FinalJoinHandle,
139        alloc_per_thread: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
140        index: usize,
141        num_threads: usize,
142        f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
143    ) {
144        <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::spawn(
145            self,
146            handle,
147            alloc_per_thread,
148            index,
149            num_threads,
150            f,
151        )
152    }
153}
154
155pub fn compress_multi<
156    Alloc: BrotliAlloc + Send + 'static,
157    SliceW: SliceWrapper<u8> + Send + 'static + Sync,
158>(
159    params: &BrotliEncoderParams,
160    owned_input: &mut Owned<SliceW>,
161    output: &mut [u8],
162    alloc_per_thread: &mut [SendAlloc<
163        CompressionThreadResult<Alloc>,
164        UnionHasher<Alloc>,
165        Alloc,
166        <MultiThreadedSpawner as BatchSpawnable<
167            CompressionThreadResult<Alloc>,
168            UnionHasher<Alloc>,
169            Alloc,
170            SliceW,
171        >>::JoinHandle,
172    >],
173) -> Result<usize, BrotliEncoderThreadError>
174where
175    <Alloc as Allocator<u8>>::AllocatedMemory: Send,
176    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
177    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
178{
179    CompressMulti(
180        params,
181        owned_input,
182        output,
183        alloc_per_thread,
184        &mut MultiThreadedSpawner::default(),
185    )
186}