hyper_timeout/
stream.rs

//! Wrappers for applying timeouts to IO operations.
//!
//! This used to depend on [tokio-io-timeout]. After hyper 1.0 introduced hyper-specific IO traits, this was rewritten
//! to use hyper IO traits instead of tokio IO traits.
//!
//! These timeouts are analogous to the read and write timeouts on traditional blocking sockets. A timeout countdown is
//! initiated when a read/write operation returns [`Poll::Pending`]. If a read/write does not complete before
//! the countdown expires, an [`io::Error`] with a kind of [`TimedOut`](io::ErrorKind::TimedOut) is returned.
//!
//! [tokio-io-timeout]: https://crates.io/crates/tokio-io-timeout
8#![warn(missing_docs)]
9
10use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write};
11use hyper_util::client::legacy::connect::{Connected, Connection};
12use pin_project_lite::pin_project;
13use std::future::Future;
14use std::io;
15use std::pin::Pin;
16use std::task::{ready, Context, Poll};
17use std::time::Duration;
18use tokio::time::{sleep_until, Instant, Sleep};
19
20pin_project! {
21    #[derive(Debug)]
22    struct TimeoutState {
23        timeout: Option<Duration>,
24        #[pin]
25        cur: Sleep,
26        active: bool,
27    }
28}
29
30impl TimeoutState {
31    #[inline]
32    fn new() -> TimeoutState {
33        TimeoutState {
34            timeout: None,
35            cur: sleep_until(Instant::now()),
36            active: false,
37        }
38    }
39
40    #[inline]
41    fn timeout(&self) -> Option<Duration> {
42        self.timeout
43    }
44
45    #[inline]
46    fn set_timeout(&mut self, timeout: Option<Duration>) {
47        // since this takes &mut self, we can't yet be active
48        self.timeout = timeout;
49    }
50
51    #[inline]
52    fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
53        *self.as_mut().project().timeout = timeout;
54        self.reset();
55    }
56
57    #[inline]
58    fn reset(self: Pin<&mut Self>) {
59        let this = self.project();
60
61        if *this.active {
62            *this.active = false;
63            this.cur.reset(Instant::now());
64        }
65    }
66
67    #[inline]
68    fn poll_check(self: Pin<&mut Self>, cx: &mut Context) -> io::Result<()> {
69        let mut this = self.project();
70
71        let timeout = match this.timeout {
72            Some(timeout) => *timeout,
73            None => return Ok(()),
74        };
75
76        if !*this.active {
77            this.cur.as_mut().reset(Instant::now() + timeout);
78            *this.active = true;
79        }
80
81        match this.cur.poll(cx) {
82            Poll::Ready(()) => Err(io::Error::from(io::ErrorKind::TimedOut)),
83            Poll::Pending => Ok(()),
84        }
85    }
86}
87
88pin_project! {
89    /// An `hyper::rt::Read`er which applies a timeout to read operations.
90    #[derive(Debug)]
91    pub struct TimeoutReader<R> {
92        #[pin]
93        reader: R,
94        #[pin]
95        state: TimeoutState,
96    }
97}
98
99impl<R> TimeoutReader<R>
100where
101    R: Read,
102{
103    /// Returns a new `TimeoutReader` wrapping the specified reader.
104    ///
105    /// There is initially no timeout.
106    pub fn new(reader: R) -> TimeoutReader<R> {
107        TimeoutReader {
108            reader,
109            state: TimeoutState::new(),
110        }
111    }
112
113    /// Returns the current read timeout.
114    pub fn timeout(&self) -> Option<Duration> {
115        self.state.timeout()
116    }
117
118    /// Sets the read timeout.
119    ///
120    /// This can only be used before the reader is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned)
121    /// otherwise.
122    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
123        self.state.set_timeout(timeout);
124    }
125
126    /// Sets the read timeout.
127    ///
128    /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet
129    /// pinned.
130    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
131        self.project().state.set_timeout_pinned(timeout);
132    }
133
134    /// Returns a shared reference to the inner reader.
135    pub fn get_ref(&self) -> &R {
136        &self.reader
137    }
138
139    /// Returns a mutable reference to the inner reader.
140    pub fn get_mut(&mut self) -> &mut R {
141        &mut self.reader
142    }
143
144    /// Returns a pinned mutable reference to the inner reader.
145    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
146        self.project().reader
147    }
148
149    /// Consumes the `TimeoutReader`, returning the inner reader.
150    pub fn into_inner(self) -> R {
151        self.reader
152    }
153}
154
155impl<R> Read for TimeoutReader<R>
156where
157    R: Read,
158{
159    fn poll_read(
160        self: Pin<&mut Self>,
161        cx: &mut Context,
162        buf: ReadBufCursor,
163    ) -> Poll<Result<(), io::Error>> {
164        let this = self.project();
165        let r = this.reader.poll_read(cx, buf);
166        match r {
167            Poll::Pending => this.state.poll_check(cx)?,
168            _ => this.state.reset(),
169        }
170        r
171    }
172}
173
174impl<R> Write for TimeoutReader<R>
175where
176    R: Write,
177{
178    fn poll_write(
179        self: Pin<&mut Self>,
180        cx: &mut Context,
181        buf: &[u8],
182    ) -> Poll<Result<usize, io::Error>> {
183        self.project().reader.poll_write(cx, buf)
184    }
185
186    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
187        self.project().reader.poll_flush(cx)
188    }
189
190    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
191        self.project().reader.poll_shutdown(cx)
192    }
193
194    fn poll_write_vectored(
195        self: Pin<&mut Self>,
196        cx: &mut Context,
197        bufs: &[io::IoSlice],
198    ) -> Poll<io::Result<usize>> {
199        self.project().reader.poll_write_vectored(cx, bufs)
200    }
201
202    fn is_write_vectored(&self) -> bool {
203        self.reader.is_write_vectored()
204    }
205}
206
207pin_project! {
208    /// An `hyper::rt::Write`er which applies a timeout to write operations.
209    #[derive(Debug)]
210    pub struct TimeoutWriter<W> {
211        #[pin]
212        writer: W,
213        #[pin]
214        state: TimeoutState,
215    }
216}
217
218impl<W> TimeoutWriter<W>
219where
220    W: Write,
221{
222    /// Returns a new `TimeoutReader` wrapping the specified reader.
223    ///
224    /// There is initially no timeout.
225    pub fn new(writer: W) -> TimeoutWriter<W> {
226        TimeoutWriter {
227            writer,
228            state: TimeoutState::new(),
229        }
230    }
231
232    /// Returns the current write timeout.
233    pub fn timeout(&self) -> Option<Duration> {
234        self.state.timeout()
235    }
236
237    /// Sets the write timeout.
238    ///
239    /// This can only be used before the writer is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned)
240    /// otherwise.
241    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
242        self.state.set_timeout(timeout);
243    }
244
245    /// Sets the write timeout.
246    ///
247    /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet
248    /// pinned.
249    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
250        self.project().state.set_timeout_pinned(timeout);
251    }
252
253    /// Returns a shared reference to the inner writer.
254    pub fn get_ref(&self) -> &W {
255        &self.writer
256    }
257
258    /// Returns a mutable reference to the inner writer.
259    pub fn get_mut(&mut self) -> &mut W {
260        &mut self.writer
261    }
262
263    /// Returns a pinned mutable reference to the inner writer.
264    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
265        self.project().writer
266    }
267
268    /// Consumes the `TimeoutWriter`, returning the inner writer.
269    pub fn into_inner(self) -> W {
270        self.writer
271    }
272}
273
274impl<W> Write for TimeoutWriter<W>
275where
276    W: Write,
277{
278    fn poll_write(
279        self: Pin<&mut Self>,
280        cx: &mut Context,
281        buf: &[u8],
282    ) -> Poll<Result<usize, io::Error>> {
283        let this = self.project();
284        let r = this.writer.poll_write(cx, buf);
285        match r {
286            Poll::Pending => this.state.poll_check(cx)?,
287            _ => this.state.reset(),
288        }
289        r
290    }
291
292    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
293        let this = self.project();
294        let r = this.writer.poll_flush(cx);
295        match r {
296            Poll::Pending => this.state.poll_check(cx)?,
297            _ => this.state.reset(),
298        }
299        r
300    }
301
302    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
303        let this = self.project();
304        let r = this.writer.poll_shutdown(cx);
305        match r {
306            Poll::Pending => this.state.poll_check(cx)?,
307            _ => this.state.reset(),
308        }
309        r
310    }
311
312    fn poll_write_vectored(
313        self: Pin<&mut Self>,
314        cx: &mut Context,
315        bufs: &[io::IoSlice],
316    ) -> Poll<io::Result<usize>> {
317        let this = self.project();
318        let r = this.writer.poll_write_vectored(cx, bufs);
319        match r {
320            Poll::Pending => this.state.poll_check(cx)?,
321            _ => this.state.reset(),
322        }
323        r
324    }
325
326    fn is_write_vectored(&self) -> bool {
327        self.writer.is_write_vectored()
328    }
329}
330
331impl<W> Read for TimeoutWriter<W>
332where
333    W: Read,
334{
335    fn poll_read(
336        self: Pin<&mut Self>,
337        cx: &mut Context,
338        buf: ReadBufCursor,
339    ) -> Poll<Result<(), io::Error>> {
340        self.project().writer.poll_read(cx, buf)
341    }
342}
343
344pin_project! {
345    /// A stream which applies read and write timeouts to an inner stream.
346    #[derive(Debug)]
347    pub struct TimeoutStream<S> {
348        #[pin]
349        stream: TimeoutReader<TimeoutWriter<S>>
350    }
351}
352
353impl<S> TimeoutStream<S>
354where
355    S: Read + Write,
356{
357    /// Returns a new `TimeoutStream` wrapping the specified stream.
358    ///
359    /// There is initially no read or write timeout.
360    pub fn new(stream: S) -> TimeoutStream<S> {
361        let writer = TimeoutWriter::new(stream);
362        let stream = TimeoutReader::new(writer);
363        TimeoutStream { stream }
364    }
365
366    /// Returns the current read timeout.
367    pub fn read_timeout(&self) -> Option<Duration> {
368        self.stream.timeout()
369    }
370
371    /// Sets the read timeout.
372    ///
373    /// This can only be used before the stream is pinned; use
374    /// [`set_read_timeout_pinned`](Self::set_read_timeout_pinned) otherwise.
375    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
376        self.stream.set_timeout(timeout)
377    }
378
379    /// Sets the read timeout.
380    ///
381    /// This will reset any pending read timeout. Use [`set_read_timeout`](Self::set_read_timeout) instead if the stream
382    /// has not yet been pinned.
383    pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
384        self.project().stream.set_timeout_pinned(timeout)
385    }
386
387    /// Returns the current write timeout.
388    pub fn write_timeout(&self) -> Option<Duration> {
389        self.stream.get_ref().timeout()
390    }
391
392    /// Sets the write timeout.
393    ///
394    /// This can only be used before the stream is pinned; use
395    /// [`set_write_timeout_pinned`](Self::set_write_timeout_pinned) otherwise.
396    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
397        self.stream.get_mut().set_timeout(timeout)
398    }
399
400    /// Sets the write timeout.
401    ///
402    /// This will reset any pending write timeout. Use [`set_write_timeout`](Self::set_write_timeout) instead if the
403    /// stream has not yet been pinned.
404    pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
405        self.project()
406            .stream
407            .get_pin_mut()
408            .set_timeout_pinned(timeout)
409    }
410
411    /// Returns a shared reference to the inner stream.
412    pub fn get_ref(&self) -> &S {
413        self.stream.get_ref().get_ref()
414    }
415
416    /// Returns a mutable reference to the inner stream.
417    pub fn get_mut(&mut self) -> &mut S {
418        self.stream.get_mut().get_mut()
419    }
420
421    /// Returns a pinned mutable reference to the inner stream.
422    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
423        self.project().stream.get_pin_mut().get_pin_mut()
424    }
425
426    /// Consumes the stream, returning the inner stream.
427    pub fn into_inner(self) -> S {
428        self.stream.into_inner().into_inner()
429    }
430}
431
432impl<S> Read for TimeoutStream<S>
433where
434    S: Read + Write,
435{
436    fn poll_read(
437        self: Pin<&mut Self>,
438        cx: &mut Context,
439        buf: ReadBufCursor,
440    ) -> Poll<Result<(), io::Error>> {
441        self.project().stream.poll_read(cx, buf)
442    }
443}
444
445impl<S> Write for TimeoutStream<S>
446where
447    S: Read + Write,
448{
449    fn poll_write(
450        self: Pin<&mut Self>,
451        cx: &mut Context,
452        buf: &[u8],
453    ) -> Poll<Result<usize, io::Error>> {
454        self.project().stream.poll_write(cx, buf)
455    }
456
457    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
458        self.project().stream.poll_flush(cx)
459    }
460
461    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
462        self.project().stream.poll_shutdown(cx)
463    }
464
465    fn poll_write_vectored(
466        self: Pin<&mut Self>,
467        cx: &mut Context,
468        bufs: &[io::IoSlice],
469    ) -> Poll<io::Result<usize>> {
470        self.project().stream.poll_write_vectored(cx, bufs)
471    }
472
473    fn is_write_vectored(&self) -> bool {
474        self.stream.is_write_vectored()
475    }
476}
477
478impl<S> Connection for TimeoutStream<S>
479where
480    S: Read + Write + Connection + Unpin,
481{
482    fn connected(&self) -> Connected {
483        self.get_ref().connected()
484    }
485}
486
487impl<S> Connection for Pin<Box<TimeoutStream<S>>>
488where
489    S: Read + Write + Connection + Unpin,
490{
491    fn connected(&self) -> Connected {
492        self.get_ref().connected()
493    }
494}
495
496pin_project! {
497    /// A future which can be used to easily read available number of bytes to fill
498    /// a buffer. Based on the internal [tokio::io::util::read::Read]
499    struct ReadFut<'a, R: ?Sized> {
500        reader: &'a mut R,
501        buf: &'a mut [u8],
502    }
503}
504
505/// Tries to read some bytes directly into the given `buf` in asynchronous
506/// manner, returning a future type.
507///
508/// The returned future will resolve to both the I/O stream and the buffer
509/// as well as the number of bytes read once the read operation is completed.
510fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> ReadFut<'a, R>
511where
512    R: Read + Unpin + ?Sized,
513{
514    ReadFut { reader, buf }
515}
516
517impl<R> Future for ReadFut<'_, R>
518where
519    R: Read + Unpin + ?Sized,
520{
521    type Output = io::Result<usize>;
522
523    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
524        let me = self.project();
525        let mut buf = ReadBuf::new(me.buf);
526        ready!(Pin::new(me.reader).poll_read(cx, buf.unfilled()))?;
527        Poll::Ready(Ok(buf.filled().len()))
528    }
529}
530
531trait ReadExt: Read {
532    /// Pulls some bytes from this source into the specified buffer,
533    /// returning how many bytes were read.
534    fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self>
535    where
536        Self: Unpin,
537    {
538        read(self, buf)
539    }
540}
541
542pin_project! {
543    /// A future to write some of the buffer to an `AsyncWrite`.-
544    struct WriteFut<'a, W: ?Sized> {
545        writer: &'a mut W,
546        buf: &'a [u8],
547    }
548}
549
550/// Tries to write some bytes from the given `buf` to the writer in an
551/// asynchronous manner, returning a future.
552fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteFut<'a, W>
553where
554    W: Write + Unpin + ?Sized,
555{
556    WriteFut { writer, buf }
557}
558
559impl<W> Future for WriteFut<'_, W>
560where
561    W: Write + Unpin + ?Sized,
562{
563    type Output = io::Result<usize>;
564
565    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
566        let me = self.project();
567        Pin::new(&mut *me.writer).poll_write(cx, me.buf)
568    }
569}
570
571trait WriteExt: Write {
572    /// Writes a buffer into this writer, returning how many bytes were
573    /// written.
574    fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self>
575    where
576        Self: Unpin,
577    {
578        write(self, src)
579    }
580}
581
582impl<R> ReadExt for Pin<&mut TimeoutReader<R>>
583where
584    R: Read,
585{
586    fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> {
587        read(self, buf)
588    }
589}
590
591impl<W> WriteExt for Pin<&mut TimeoutWriter<W>>
592where
593    W: Write,
594{
595    fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> {
596        write(self, src)
597    }
598}
599
600impl<S> ReadExt for Pin<&mut TimeoutStream<S>>
601where
602    S: Read + Write,
603{
604    fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> {
605        read(self, buf)
606    }
607}
608
609impl<S> WriteExt for Pin<&mut TimeoutStream<S>>
610where
611    S: Read + Write,
612{
613    fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> {
614        write(self, src)
615    }
616}
617
#[cfg(test)]
mod test {
    use super::*;
    use hyper_util::rt::TokioIo;
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Test double whose reads and writes stay pending until `sleep` completes.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        fn new(until: Instant) -> Self {
            let sleep = sleep_until(until);
            DelayStream { sleep }
        }
    }

    impl Read for DelayStream {
        // Completes without filling any bytes once the delay has elapsed.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: ReadBufCursor,
        ) -> Poll<Result<(), io::Error>> {
            self.project().sleep.poll(cx).map(Ok)
        }
    }

    impl hyper::rt::Write for DelayStream {
        // Accepts the whole buffer once the delay has elapsed.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            self.project().sleep.poll(cx).map(|()| Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn read_timeout() {
        // Data arrives after 500ms but the timeout is 100ms.
        let ready_at = Instant::now() + Duration::from_millis(500);
        let mut reader = TimeoutReader::new(DelayStream::new(ready_at));
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);

        let err = reader.read(&mut [0, 1, 2]).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn read_ok() {
        // Data arrives after 100ms, well inside the 500ms timeout.
        let ready_at = Instant::now() + Duration::from_millis(100);
        let mut reader = TimeoutReader::new(DelayStream::new(ready_at));
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);

        reader.read(&mut [0]).await.unwrap();
    }

    #[tokio::test]
    async fn write_timeout() {
        // The sink accepts data after 500ms but the timeout is 100ms.
        let ready_at = Instant::now() + Duration::from_millis(500);
        let mut writer = TimeoutWriter::new(DelayStream::new(ready_at));
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);

        let err = writer.write(&[0]).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn write_ok() {
        // The sink accepts data after 100ms, well inside the 500ms timeout.
        let ready_at = Instant::now() + Duration::from_millis(100);
        let mut writer = TimeoutWriter::new(DelayStream::new(ready_at));
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);

        writer.write(&[0]).await.unwrap();
    }

    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        // Peer: send one byte promptly, then stall well past the read timeout.
        thread::spawn(move || {
            let (mut peer, _) = listener.accept().unwrap();
            thread::sleep(Duration::from_millis(10));
            peer.write_all(b"f").unwrap();
            thread::sleep(Duration::from_millis(500));
            let _ = peer.write_all(b"f"); // this may hit an eof
        });

        let tcp = TokioIo::new(TcpStream::connect(&addr).await.unwrap());
        let mut stream = TimeoutStream::new(tcp);
        stream.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(stream);

        stream.read(&mut [0]).await.unwrap();
        match stream.read(&mut [0]).await {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.kind() == io::ErrorKind::TimedOut => (),
            Err(e) => panic!("{:?}", e),
        }
    }
}