mz_ore/
channel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Channel utilities and extensions.
17
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use futures::{Future, FutureExt};
22use prometheus::core::Atomic;
23use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, error, unbounded_channel};
24use tokio::sync::oneshot;
25
26use crate::metrics::PromLabelsExt;
27
28pub mod trigger;
29
/// A trait describing a metric that can be used with an `instrumented_unbounded_channel`.
pub trait InstrumentedChannelMetric {
    /// Bump the metric, increasing the count of operations (sends or receives) that occurred.
    fn bump(&self);
}
35
36impl<P, L> InstrumentedChannelMetric for crate::metrics::DeleteOnDropCounter<P, L>
37where
38    P: Atomic,
39    L: PromLabelsExt,
40{
41    fn bump(&self) {
42        self.inc()
43    }
44}
45
46/// A wrapper around tokio's mpsc unbounded channels that connects
47/// metrics that are incremented when sends or receives happen.
48pub fn instrumented_unbounded_channel<T, M>(
49    sender_metric: M,
50    receiver_metric: M,
51) -> (
52    InstrumentedUnboundedSender<T, M>,
53    InstrumentedUnboundedReceiver<T, M>,
54)
55where
56    M: InstrumentedChannelMetric,
57{
58    let (tx, rx) = unbounded_channel();
59
60    (
61        InstrumentedUnboundedSender {
62            tx,
63            metric: sender_metric,
64        },
65        InstrumentedUnboundedReceiver {
66            rx,
67            metric: receiver_metric,
68        },
69    )
70}
71
72/// A wrapper around tokio's `UnboundedSender` that increments a metric when a send occurs.
73///
74/// The metric is not dropped until this sender is dropped.
75#[derive(Debug, Clone)]
76pub struct InstrumentedUnboundedSender<T, M> {
77    tx: UnboundedSender<T>,
78    metric: M,
79}
80
81impl<T, M> InstrumentedUnboundedSender<T, M>
82where
83    M: InstrumentedChannelMetric,
84{
85    /// The same as `UnboundedSender::send`.
86    pub fn send(&self, message: T) -> Result<(), error::SendError<T>> {
87        let res = self.tx.send(message);
88        self.metric.bump();
89        res
90    }
91}
92
93/// A wrapper around tokio's `UnboundedReceiver` that increments a metric when a recv _finishes_.
94///
95/// The metric is not dropped until this receiver is dropped.
96#[derive(Debug)]
97pub struct InstrumentedUnboundedReceiver<T, M> {
98    rx: UnboundedReceiver<T>,
99    metric: M,
100}
101
102impl<T, M> InstrumentedUnboundedReceiver<T, M>
103where
104    M: InstrumentedChannelMetric,
105{
106    /// The same as `UnboundedSender::recv`.
107    pub async fn recv(&mut self) -> Option<T> {
108        let res = self.rx.recv().await;
109        self.metric.bump();
110        res
111    }
112
113    /// The same as `UnboundedSender::try_recv`.
114    pub fn try_recv(&mut self) -> Result<T, error::TryRecvError> {
115        let res = self.rx.try_recv();
116
117        if res.is_ok() {
118            self.metric.bump();
119        }
120        res
121    }
122}
123
124/// Extensions for oneshot channel types.
125pub trait OneshotReceiverExt<T> {
126    /// If the receiver is dropped without the value being observed, the provided closure will be
127    /// called with the value that was left in the channel.
128    ///
129    /// This is useful in cases where you want to cleanup resources if the receiver of this value
130    /// has gone away. If the sender and receiver are running on separate threads, it's possible
131    /// for the sender to succeed, and for the receiver to be concurrently dropped, never realizing
132    /// that it received a value.
133    fn with_guard<F>(self, guard: F) -> GuardedReceiver<F, T>
134    where
135        F: FnMut(T);
136}
137
138impl<T> OneshotReceiverExt<T> for oneshot::Receiver<T> {
139    fn with_guard<F>(self, guard: F) -> GuardedReceiver<F, T>
140    where
141        F: FnMut(T),
142    {
143        GuardedReceiver { guard, inner: self }
144    }
145}
146
147/// A wrapper around [`oneshot::Receiver`] that will call the provided closure if there is a value
148/// in the receiver when it's dropped.
149#[derive(Debug)]
150pub struct GuardedReceiver<F: FnMut(T), T> {
151    guard: F,
152    inner: oneshot::Receiver<T>,
153}
154
155// Note(parkmycar): If this Unpin requirement becomes too restrictive, we can refactor
156// GuardedReceiver to use `pin_project`.
157impl<F: FnMut(T) + Unpin, T> Future for GuardedReceiver<F, T> {
158    type Output = Result<T, oneshot::error::RecvError>;
159
160    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
161        self.inner.poll_unpin(cx)
162    }
163}
164
165impl<F: FnMut(T), T> Drop for GuardedReceiver<F, T> {
166    fn drop(&mut self) {
167        // Close the channel so the sender is guaranteed to fail.
168        self.inner.close();
169
170        // If there was some value waiting in the channel call the guard with the value.
171        if let Ok(x) = self.inner.try_recv() {
172            (self.guard)(x)
173        }
174    }
175}