// mz_ore/netio/timeout.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file at the
// root of this repository, or online at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! `Async{Read,Write}` wrappers that enforce a configurable timeout on each I/O operation.

use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::Sleep;

26/// An [`AsyncRead`] wrapper that enforces a timeout on each read.
27#[derive(Debug)]
28pub struct TimedReader<R> {
29    reader: R,
30    timeout: Duration,
31    sleep: Option<Pin<Box<Sleep>>>,
32}
33
34impl<R> TimedReader<R> {
35    /// Wrap a reader with a timeout.
36    pub fn new(reader: R, timeout: Duration) -> Self {
37        Self {
38            reader,
39            timeout,
40            sleep: None,
41        }
42    }
43
44    /// Poll the sleep future, creating it if necessary.
45    fn poll_sleep(&mut self, cx: &mut Context<'_>) -> Poll<()> {
46        let timeout = self.timeout;
47        let sleep = self.sleep.get_or_insert_with(|| {
48            let sleep = tokio::time::sleep(timeout);
49            Box::pin(sleep)
50        });
51
52        sleep.as_mut().poll(cx)
53    }
54}
55
56impl<R: AsyncRead + Unpin> AsyncRead for TimedReader<R> {
57    fn poll_read(
58        mut self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &mut ReadBuf<'_>,
61    ) -> Poll<io::Result<()>> {
62        let poll = if self.poll_sleep(cx).is_ready() {
63            Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
64        } else if let Poll::Ready(result) = Pin::new(&mut self.reader).poll_read(cx, buf) {
65            Poll::Ready(result)
66        } else {
67            Poll::Pending
68        };
69
70        if poll.is_ready() {
71            self.sleep = None;
72        }
73
74        poll
75    }
76}
77
78/// An [`AsyncWrite`] wrapper that enforces a timeout on each write.
79#[derive(Debug)]
80pub struct TimedWriter<W> {
81    writer: W,
82    timeout: Duration,
83    sleep: Option<Pin<Box<Sleep>>>,
84}
85
86impl<W> TimedWriter<W> {
87    /// Wrap a writer with a timeout.
88    pub fn new(writer: W, timeout: Duration) -> Self {
89        Self {
90            writer,
91            timeout,
92            sleep: None,
93        }
94    }
95
96    /// Poll the sleep future, creating it if necessary.
97    fn poll_sleep(&mut self, cx: &mut Context<'_>) -> Poll<()> {
98        let timeout = self.timeout;
99        let sleep = self.sleep.get_or_insert_with(|| {
100            let sleep = tokio::time::sleep(timeout);
101            Box::pin(sleep)
102        });
103
104        sleep.as_mut().poll(cx)
105    }
106}
107
108impl<W: AsyncWrite + Unpin> AsyncWrite for TimedWriter<W> {
109    fn poll_write(
110        mut self: Pin<&mut Self>,
111        cx: &mut Context<'_>,
112        buf: &[u8],
113    ) -> Poll<Result<usize, io::Error>> {
114        let poll = if self.poll_sleep(cx).is_ready() {
115            Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
116        } else if let Poll::Ready(result) = Pin::new(&mut self.writer).poll_write(cx, buf) {
117            Poll::Ready(result)
118        } else {
119            Poll::Pending
120        };
121
122        if poll.is_ready() {
123            self.sleep = None;
124        }
125
126        poll
127    }
128
129    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
130        let poll = if self.poll_sleep(cx).is_ready() {
131            Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
132        } else if let Poll::Ready(result) = Pin::new(&mut self.writer).poll_flush(cx) {
133            Poll::Ready(result)
134        } else {
135            Poll::Pending
136        };
137
138        if poll.is_ready() {
139            self.sleep = None;
140        }
141
142        poll
143    }
144
145    fn poll_shutdown(
146        mut self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148    ) -> Poll<Result<(), io::Error>> {
149        let poll = if self.poll_sleep(cx).is_ready() {
150            Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
151        } else if let Poll::Ready(result) = Pin::new(&mut self.writer).poll_shutdown(cx) {
152            Poll::Ready(result)
153        } else {
154            Poll::Pending
155        };
156
157        if poll.is_ready() {
158            self.sleep = None;
159        }
160
161        poll
162    }
163}