1use futures::{stream::Fuse, Stream, StreamExt};
4use hashbrown::{hash_map::RawEntryMut, HashMap};
5use pin_project::pin_project;
6use std::{
7 collections::HashSet,
8 hash::Hash,
9 pin::Pin,
10 task::{Context, Poll},
11 time::Duration,
12};
13use tokio::time::Instant;
14use tokio_util::time::delay_queue::{self, DelayQueue};
15
16#[derive(Debug)]
18pub struct ScheduleRequest<T> {
19 pub message: T,
20 pub run_at: Instant,
21}
22
23struct ScheduledEntry {
25 run_at: Instant,
26 queue_key: delay_queue::Key,
27}
28
29#[pin_project(project = SchedulerProj)]
30pub struct Scheduler<T, R> {
31 queue: DelayQueue<T>,
39 scheduled: HashMap<T, ScheduledEntry>,
43 pending: HashSet<T>,
45 #[pin]
47 requests: Fuse<R>,
48 debounce: Duration,
54}
55
56impl<T, R: Stream> Scheduler<T, R> {
57 fn new(requests: R, debounce: Duration) -> Self {
58 Self {
59 queue: DelayQueue::new(),
60 scheduled: HashMap::new(),
61 pending: HashSet::new(),
62 requests: requests.fuse(),
63 debounce,
64 }
65 }
66}
67
68impl<T: Hash + Eq + Clone, R> SchedulerProj<'_, T, R> {
69 fn schedule_message(&mut self, request: ScheduleRequest<T>) {
73 if self.pending.contains(&request.message) {
74 return;
76 }
77 let next_time = request
78 .run_at
79 .checked_add(*self.debounce)
80 .map_or_else(max_schedule_time, |time|
81 time.min(max_schedule_time()));
83 match self.scheduled.raw_entry_mut().from_key(&request.message) {
84 RawEntryMut::Occupied(mut old_entry) if old_entry.get().run_at >= request.run_at => {
88 let entry = old_entry.get_mut();
90 self.queue.reset_at(&entry.queue_key, next_time);
91 entry.run_at = next_time;
92 old_entry.insert_key(request.message);
93 }
94 RawEntryMut::Occupied(_old_entry) => {
95 }
97 RawEntryMut::Vacant(entry) => {
98 let message = request.message.clone();
100 entry.insert(request.message, ScheduledEntry {
101 run_at: next_time,
102 queue_key: self.queue.insert_at(message, next_time),
103 });
104 }
105 }
106 }
107
108 fn poll_pop_queue_message(
110 &mut self,
111 cx: &mut Context<'_>,
112 can_take_message: impl Fn(&T) -> bool,
113 ) -> Poll<T> {
114 if let Some(msg) = self.pending.iter().find(|msg| can_take_message(*msg)).cloned() {
115 return Poll::Ready(self.pending.take(&msg).unwrap());
116 }
117
118 loop {
119 match self.queue.poll_expired(cx) {
120 Poll::Ready(Some(msg)) => {
121 let msg = msg.into_inner();
122 let (msg, _) = self.scheduled.remove_entry(&msg).expect(
123 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
124 );
125 if can_take_message(&msg) {
126 break Poll::Ready(msg);
127 }
128 self.pending.insert(msg);
129 }
130 Poll::Ready(None) | Poll::Pending => break Poll::Pending,
131 }
132 }
133 }
134
135 pub fn pop_queue_message_into_pending(&mut self, cx: &mut Context<'_>) {
137 while let Poll::Ready(Some(msg)) = self.queue.poll_expired(cx) {
138 let msg = msg.into_inner();
139 self.scheduled.remove_entry(&msg).expect(
140 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
141 );
142 self.pending.insert(msg);
143 }
144 }
145}
146
147pub struct Hold<'a, T, R> {
149 scheduler: Pin<&'a mut Scheduler<T, R>>,
150}
151
152impl<T, R> Stream for Hold<'_, T, R>
153where
154 T: Eq + Hash + Clone,
155 R: Stream<Item = ScheduleRequest<T>>,
156{
157 type Item = T;
158
159 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
160 let this = self.get_mut();
161 let mut scheduler = this.scheduler.as_mut().project();
162
163 loop {
164 match scheduler.requests.as_mut().poll_next(cx) {
165 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
166 Poll::Ready(None) => return Poll::Ready(None),
167 Poll::Pending => break,
168 }
169 }
170
171 scheduler.pop_queue_message_into_pending(cx);
172 Poll::Pending
173 }
174}
175
176pub struct HoldUnless<'a, T, R, C> {
178 scheduler: Pin<&'a mut Scheduler<T, R>>,
179 can_take_message: C,
180}
181
182impl<T, R, C> Stream for HoldUnless<'_, T, R, C>
183where
184 T: Eq + Hash + Clone,
185 R: Stream<Item = ScheduleRequest<T>>,
186 C: Fn(&T) -> bool + Unpin,
187{
188 type Item = T;
189
190 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
191 let this = self.get_mut();
192 let can_take_message = &this.can_take_message;
193 let mut scheduler = this.scheduler.as_mut().project();
194
195 loop {
196 match scheduler.requests.as_mut().poll_next(cx) {
197 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
198 Poll::Ready(None) => return Poll::Ready(None),
199 Poll::Pending => break,
200 }
201 }
202
203 match scheduler.poll_pop_queue_message(cx, can_take_message) {
204 Poll::Ready(expired) => Poll::Ready(Some(expired)),
205 Poll::Pending => Poll::Pending,
206 }
207 }
208}
209
210impl<T, R> Scheduler<T, R>
211where
212 T: Eq + Hash + Clone,
213 R: Stream<Item = ScheduleRequest<T>>,
214{
215 pub fn hold_unless<C: Fn(&T) -> bool>(self: Pin<&mut Self>, can_take_message: C) -> HoldUnless<T, R, C> {
226 HoldUnless {
227 scheduler: self,
228 can_take_message,
229 }
230 }
231
232 #[must_use]
236 pub fn hold(self: Pin<&mut Self>) -> Hold<T, R> {
237 Hold { scheduler: self }
238 }
239
240 #[cfg(test)]
242 pub fn contains_pending(&self, msg: &T) -> bool {
243 self.pending.contains(msg)
244 }
245}
246
247impl<T, R> Stream for Scheduler<T, R>
248where
249 T: Eq + Hash + Clone,
250 R: Stream<Item = ScheduleRequest<T>>,
251{
252 type Item = T;
253
254 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
255 Pin::new(&mut self.hold_unless(|_| true)).poll_next(cx)
256 }
257}
258
259pub fn scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(requests: S) -> Scheduler<T, S> {
270 Scheduler::new(requests, Duration::ZERO)
271}
272
273#[allow(clippy::module_name_repetitions)]
281pub fn debounced_scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(
282 requests: S,
283 debounce: Duration,
284) -> Scheduler<T, S> {
285 Scheduler::new(requests, debounce)
286}
287
288pub(crate) fn max_schedule_time() -> Instant {
292 Instant::now() + Duration::from_secs(86400 * 30 * 6)
293}
294
#[cfg(test)]
mod tests {
    use crate::utils::KubeRuntimeStreamExt;

    use super::{debounced_scheduler, scheduler, ScheduleRequest};
    use educe::Educe;
    use futures::{channel::mpsc, future, poll, stream, FutureExt, SinkExt, StreamExt};
    use std::{pin::pin, task::Poll};
    use tokio::time::{advance, pause, sleep, Duration, Instant};

    /// Extracts the value from a `Poll::Ready`, panicking if the poll is still pending.
    fn unwrap_poll<T>(poll: Poll<T>) -> T {
        if let Poll::Ready(x) = poll {
            x
        } else {
            panic!("Tried to unwrap a pending poll!")
        }
    }

    /// A message whose `PartialEq` and `Hash` ignore the payload, so all instances count as the
    /// same message for the scheduler. The payload lets tests tell which version was emitted.
    #[derive(Educe, Eq, Clone, Debug)]
    #[educe(PartialEq, Hash)]
    struct SingletonMessage(#[educe(PartialEq(ignore), Hash(ignore))] u8);

    #[tokio::test]
    async fn scheduler_should_hold_and_release_items() {
        pause();
        // `on_complete(sleep(..))` keeps the request stream open for a while after its items
        // (presumably by awaiting the future before ending). The scheduler ends with its input.
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![ScheduleRequest {
                message: 1_u8,
                run_at: Instant::now(),
            }])
            .on_complete(sleep(Duration::from_secs(4))),
        ));
        assert!(!scheduler.contains_pending(&1));
        // A rejecting filter parks the expired message in the pending set.
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        assert!(scheduler.contains_pending(&1));
        // An accepting filter then releases it straight from the pending set.
        assert_eq!(
            unwrap_poll(poll!(scheduler.as_mut().hold_unless(|_| true).next())).unwrap(),
            1_u8
        );
        assert!(!scheduler.contains_pending(&1));
        // Once the request stream ends, so does the scheduler.
        assert!(scheduler.as_mut().hold_unless(|_| true).next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_should_not_reschedule_pending_items() {
        pause();
        let (mut tx, rx) = mpsc::unbounded::<ScheduleRequest<u8>>();
        let mut scheduler = Box::pin(scheduler(rx));
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        // Park message 1 in the pending set.
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        // A second request for the same message while it is pending must be ignored.
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        future::join(
            async {
                // Closing the channel later ends the scheduler.
                sleep(Duration::from_secs(2)).await;
                drop(tx);
            },
            async {
                // Exactly one copy is emitted before the stream ends.
                assert_eq!(scheduler.next().await.unwrap(), 1);
                assert!(scheduler.next().await.is_none())
            },
        )
        .await;
    }

    // NOTE(review): unlike the other tests this one does not call `pause()`, so it runs against
    // the real clock.
    #[tokio::test]
    async fn scheduler_pending_message_should_not_block_head_of_line() {
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1,
                    run_at: Instant::now(),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now(),
                },
            ])
            .on_complete(sleep(Duration::from_secs(2))),
        ));
        // Message 1 is rejected by the filter but must not prevent message 2 from being emitted.
        assert_eq!(
            scheduler.as_mut().hold_unless(|x| *x != 1).next().await.unwrap(),
            2
        );
    }

    #[tokio::test]
    async fn scheduler_should_emit_items_as_requested() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1_u8,
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        // t=2s: only message 1 (due at 1s) is ready.
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 1);
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        // t=4s: message 2 (due at 3s) is ready.
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 2);
        assert!(scheduler.next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_dedupe_should_keep_earlier_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        // The 1s request is kept and the later 3s duplicate is ignored, so exactly one item is
        // emitted before the stream ends.
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(scheduler.next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_dedupe_should_replace_later_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        // The earlier 1s request replaced the 3s one, so the item is ready at t=2s and is not
        // emitted a second time.
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(scheduler.next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_dedupe_should_allow_rescheduling_emitted_item() {
        pause();
        let (mut schedule_tx, schedule_rx) = mpsc::unbounded();
        let mut scheduler = scheduler(schedule_rx);
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        // Once emitted, the same message can be scheduled again and is emitted again.
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
    }

    #[tokio::test]
    async fn scheduler_should_overwrite_message_with_soonest_version() {
        pause();

        let now = Instant::now();
        // Both requests are the "same" message. The sooner one (payload 2) replaces the queued
        // payload 1, so only payload 2 is emitted.
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(2),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![2]);
    }

    #[tokio::test]
    async fn scheduler_should_not_overwrite_message_with_later_version() {
        pause();

        let now = Instant::now();
        // The later request (payload 2) is ignored because payload 1 is already due sooner.
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(2),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![1]);
    }

    #[tokio::test]
    async fn scheduler_should_add_debounce_to_a_request() {
        pause();

        let now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(2));

        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        // t=1s: the debounced deadline (0s + 2s) has not been reached yet.
        advance(Duration::from_secs(1)).await;
        assert!(poll!(scheduler.next()).is_pending());
        // t=4s: past the debounced deadline.
        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 1);
    }

    #[tokio::test]
    async fn scheduler_should_dedup_message_within_debounce_period() {
        pause();

        let mut now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(3));

        // First request at t=0s gets a debounced deadline of t=3s.
        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(1)).await;

        // A repeat at t=1s falls inside the debounce window. It replaces the message (payload 2)
        // and restarts the debounce, moving the deadline to t=4s.
        now = Instant::now();
        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(2),
                run_at: now,
            })
            .await
            .unwrap();
        // t=3.5s: the original deadline has passed, but the restarted one has not.
        advance(Duration::from_millis(2500)).await;
        assert!(poll!(scheduler.next()).is_pending());

        // t=6.5s: only the newest version is emitted, and only once.
        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 2);
        assert!(poll!(scheduler.next()).is_pending());
    }
}