timely_communication/allocator/zero_copy/
spill.rs1use std::collections::VecDeque;
20use std::sync::Arc;
21
22use timely_bytes::arc::Bytes;
23
24use super::bytes_exchange::QueueEntry;
25
26pub type SpillPolicyFn = Arc<dyn Fn() -> (Box<dyn SpillPolicy>, Box<dyn SpillPolicy>) + Send + Sync>;
32
33pub trait SpillPolicy: Send {
35 fn apply(&mut self, queue: &mut VecDeque<QueueEntry>);
42}
43
44pub trait BytesSpill: Send {
46 fn spill(&mut self, chunks: &mut Vec<Bytes>, handles: &mut Vec<Box<dyn BytesFetch>>);
53}
54
55pub trait BytesFetch: Send {
57 fn fetch(self: Box<Self>) -> Result<Vec<Bytes>, Box<dyn BytesFetch>>;
61}
62
63pub mod threshold {
65 use super::*;
66
67 pub struct Threshold {
72 strategy: Box<dyn BytesSpill>,
73 pub threshold_bytes: usize,
76 pub head_reserve_bytes: usize,
79 }
80
81 impl Threshold {
82 pub fn new(strategy: Box<dyn BytesSpill>) -> Self {
85 Threshold {
86 strategy,
87 threshold_bytes: 256 << 20, head_reserve_bytes: 64 << 20, }
90 }
91 }
92
93 impl SpillPolicy for Threshold {
94 fn apply(&mut self, queue: &mut VecDeque<QueueEntry>) {
95 let resident: usize = queue.iter().map(|e| match e {
96 QueueEntry::Bytes(b) => b.len(),
97 QueueEntry::Paged(_) => 0,
98 }).sum();
99 if resident <= self.head_reserve_bytes + self.threshold_bytes {
100 return;
101 }
102
103 let head_reserve = self.head_reserve_bytes;
104
105 let mut cumulative: usize = 0;
106 let last_index = queue.len().saturating_sub(1);
107 let mut target_indices: Vec<usize> = Vec::new();
108 let mut target_bytes: Vec<Bytes> = Vec::new();
109 for (i, entry) in queue.iter().enumerate() {
110 if i == last_index { break; }
111 match entry {
112 QueueEntry::Bytes(b) => {
113 if cumulative >= head_reserve {
114 target_indices.push(i);
115 target_bytes.push(b.clone());
116 }
117 cumulative += b.len();
118 }
119 QueueEntry::Paged(_) => {}
120 }
121 }
122
123 if target_bytes.is_empty() {
124 return;
125 }
126
127 let mut handles: Vec<Box<dyn BytesFetch>> = Vec::new();
128 self.strategy.spill(&mut target_bytes, &mut handles);
129 for (i, handle) in target_indices.into_iter().zip(handles) {
131 queue[i] = QueueEntry::Paged(handle);
132 }
133 }
135 }
136}
137
138pub mod prefetch {
140 use super::*;
141
142 pub struct PrefetchPolicy {
147 pub budget: usize,
149 }
150
151 impl PrefetchPolicy {
152 pub fn new(budget: usize) -> Self {
154 PrefetchPolicy { budget }
155 }
156 }
157
158 impl SpillPolicy for PrefetchPolicy {
159 fn apply(&mut self, queue: &mut VecDeque<QueueEntry>) {
160 let mut resident_head = 0;
161 let mut i = 0;
162 while i < queue.len() && resident_head < self.budget {
163 match &queue[i] {
164 QueueEntry::Bytes(b) => {
165 resident_head += b.len();
166 i += 1;
167 }
168 QueueEntry::Paged(_) => {
169 let entry = queue.remove(i).expect("index valid");
170 if let QueueEntry::Paged(h) = entry {
171 match h.fetch() {
172 Ok(fetched) => {
173 let n = fetched.len();
174 for (j, b) in fetched.into_iter().enumerate() {
175 resident_head += b.len();
176 queue.insert(i + j, QueueEntry::Bytes(b));
177 }
178 i += n;
179 }
180 Err(h) => {
181 queue.insert(i, QueueEntry::Paged(h));
183 break;
184 }
185 }
186 }
187 }
188 }
189 }
190 }
191 }
192}
193
194pub use threshold::Threshold;
196pub use prefetch::PrefetchPolicy;
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frozen `Bytes` holding a copy of `payload`.
    fn bytes_of(payload: &[u8]) -> Bytes {
        let buffer = timely_bytes::arc::BytesMut::from(payload.to_vec());
        buffer.freeze()
    }

    /// Spill strategy that keeps each chunk in memory, wrapped in its own handle.
    struct MockStrategy;

    /// Handle that simply owns the chunk it stands in for.
    struct MockHandle {
        data: Bytes,
    }

    impl BytesSpill for MockStrategy {
        fn spill(&mut self, chunks: &mut Vec<Bytes>, handles: &mut Vec<Box<dyn BytesFetch>>) {
            for data in chunks.drain(..) {
                handles.push(Box::new(MockHandle { data }));
            }
        }
    }

    impl BytesFetch for MockHandle {
        fn fetch(self: Box<Self>) -> Result<Vec<Bytes>, Box<dyn BytesFetch>> {
            let MockHandle { data } = *self;
            Ok(vec![data])
        }
    }

    #[test]
    fn eager_policy_moves_middle_entries() {
        // Spills every resident entry except the last one, regardless of size.
        struct EagerPolicy {
            strategy: Box<dyn BytesSpill>,
        }

        impl SpillPolicy for EagerPolicy {
            fn apply(&mut self, queue: &mut VecDeque<QueueEntry>) {
                let spillable = queue.len().saturating_sub(1);
                let (positions, mut chunks): (Vec<usize>, Vec<Bytes>) = queue
                    .iter()
                    .take(spillable)
                    .enumerate()
                    .filter_map(|(pos, entry)| match entry {
                        QueueEntry::Bytes(b) => Some((pos, b.clone())),
                        QueueEntry::Paged(_) => None,
                    })
                    .unzip();
                if chunks.is_empty() {
                    return;
                }
                let mut handles = Vec::new();
                self.strategy.spill(&mut chunks, &mut handles);
                for (pos, handle) in positions.into_iter().zip(handles) {
                    queue[pos] = QueueEntry::Paged(handle);
                }
            }
        }

        let mut policy = EagerPolicy { strategy: Box::new(MockStrategy) };
        let mut queue: VecDeque<QueueEntry> = (0..4u8)
            .map(|fill| QueueEntry::Bytes(bytes_of(&[fill; 8])))
            .collect();
        policy.apply(&mut queue);
        for position in 0..3 {
            assert!(matches!(queue[position], QueueEntry::Paged(_)));
        }
        assert!(matches!(queue[3], QueueEntry::Bytes(_)));
    }

    #[test]
    fn merge_queue_spill_roundtrip_mock() {
        use super::super::bytes_exchange::{MergeQueue, BytesPush, BytesPull};

        let head_reserve = 128;

        let writer_policy: Box<dyn SpillPolicy> = {
            let mut spill_policy = Threshold::new(Box::new(MockStrategy));
            spill_policy.threshold_bytes = 512;
            spill_policy.head_reserve_bytes = head_reserve;
            Box::new(spill_policy)
        };
        let reader_policy: Box<dyn SpillPolicy> = Box::new(PrefetchPolicy::new(head_reserve));

        let buzzer = crate::buzzer::Buzzer::default();
        let (mut writer, mut reader) =
            MergeQueue::new_pair(buzzer, Some(writer_policy), Some(reader_policy));

        // 100 chunks of 64 bytes: far more than `threshold_bytes + head_reserve_bytes`.
        let payloads: Vec<Vec<u8>> = (0..100).map(|i| vec![(i % 251) as u8; 64]).collect();
        for payload in &payloads {
            writer.extend(Some(bytes_of(payload)));
        }

        // Keep draining until a pass adds nothing new.
        let mut received: Vec<Bytes> = Vec::new();
        let mut seen = usize::MAX;
        while received.len() != seen {
            seen = received.len();
            reader.drain_into(&mut received);
        }

        let expected_flat: Vec<u8> = payloads.concat();
        let received_flat: Vec<u8> = received.iter().flat_map(|b| b.iter().copied()).collect();
        assert_eq!(expected_flat, received_flat);
    }
}