1#![warn(missing_docs)]
9
10use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write};
11use hyper_util::client::legacy::connect::{Connected, Connection};
12use pin_project_lite::pin_project;
13use std::future::Future;
14use std::io;
15use std::pin::Pin;
16use std::task::{ready, Context, Poll};
17use std::time::Duration;
18use tokio::time::{sleep_until, Instant, Sleep};
19
pin_project! {
    /// Timeout bookkeeping for one I/O direction, embedded in `TimeoutReader`
    /// and `TimeoutWriter`.
    #[derive(Debug)]
    struct TimeoutState {
        // How long an operation may stay pending; `None` disables the timeout.
        timeout: Option<Duration>,
        // Deadline timer; its deadline is only meaningful while `active` is true.
        #[pin]
        cur: Sleep,
        // Whether `cur` has been armed for the operation currently pending.
        active: bool,
    }
}
29
30impl TimeoutState {
31 #[inline]
32 fn new() -> TimeoutState {
33 TimeoutState {
34 timeout: None,
35 cur: sleep_until(Instant::now()),
36 active: false,
37 }
38 }
39
40 #[inline]
41 fn timeout(&self) -> Option<Duration> {
42 self.timeout
43 }
44
45 #[inline]
46 fn set_timeout(&mut self, timeout: Option<Duration>) {
47 self.timeout = timeout;
49 }
50
51 #[inline]
52 fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
53 *self.as_mut().project().timeout = timeout;
54 self.reset();
55 }
56
57 #[inline]
58 fn reset(self: Pin<&mut Self>) {
59 let this = self.project();
60
61 if *this.active {
62 *this.active = false;
63 this.cur.reset(Instant::now());
64 }
65 }
66
67 #[inline]
68 fn poll_check(self: Pin<&mut Self>, cx: &mut Context) -> io::Result<()> {
69 let mut this = self.project();
70
71 let timeout = match this.timeout {
72 Some(timeout) => *timeout,
73 None => return Ok(()),
74 };
75
76 if !*this.active {
77 this.cur.as_mut().reset(Instant::now() + timeout);
78 *this.active = true;
79 }
80
81 match this.cur.poll(cx) {
82 Poll::Ready(()) => Err(io::Error::from(io::ErrorKind::TimedOut)),
83 Poll::Pending => Ok(()),
84 }
85 }
86}
87
pin_project! {
    /// A [`Read`] wrapper that fails a read with `io::ErrorKind::TimedOut` when
    /// the read stays pending longer than the configured timeout.
    ///
    /// When the inner type also implements [`Write`], writes pass straight
    /// through without a timeout.
    #[derive(Debug)]
    pub struct TimeoutReader<R> {
        #[pin]
        reader: R,
        #[pin]
        state: TimeoutState,
    }
}
98
99impl<R> TimeoutReader<R>
100where
101 R: Read,
102{
103 pub fn new(reader: R) -> TimeoutReader<R> {
107 TimeoutReader {
108 reader,
109 state: TimeoutState::new(),
110 }
111 }
112
113 pub fn timeout(&self) -> Option<Duration> {
115 self.state.timeout()
116 }
117
118 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
123 self.state.set_timeout(timeout);
124 }
125
126 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
131 self.project().state.set_timeout_pinned(timeout);
132 }
133
134 pub fn get_ref(&self) -> &R {
136 &self.reader
137 }
138
139 pub fn get_mut(&mut self) -> &mut R {
141 &mut self.reader
142 }
143
144 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
146 self.project().reader
147 }
148
149 pub fn into_inner(self) -> R {
151 self.reader
152 }
153}
154
155impl<R> Read for TimeoutReader<R>
156where
157 R: Read,
158{
159 fn poll_read(
160 self: Pin<&mut Self>,
161 cx: &mut Context,
162 buf: ReadBufCursor,
163 ) -> Poll<Result<(), io::Error>> {
164 let this = self.project();
165 let r = this.reader.poll_read(cx, buf);
166 match r {
167 Poll::Pending => this.state.poll_check(cx)?,
168 _ => this.state.reset(),
169 }
170 r
171 }
172}
173
174impl<R> Write for TimeoutReader<R>
175where
176 R: Write,
177{
178 fn poll_write(
179 self: Pin<&mut Self>,
180 cx: &mut Context,
181 buf: &[u8],
182 ) -> Poll<Result<usize, io::Error>> {
183 self.project().reader.poll_write(cx, buf)
184 }
185
186 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
187 self.project().reader.poll_flush(cx)
188 }
189
190 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
191 self.project().reader.poll_shutdown(cx)
192 }
193
194 fn poll_write_vectored(
195 self: Pin<&mut Self>,
196 cx: &mut Context,
197 bufs: &[io::IoSlice],
198 ) -> Poll<io::Result<usize>> {
199 self.project().reader.poll_write_vectored(cx, bufs)
200 }
201
202 fn is_write_vectored(&self) -> bool {
203 self.reader.is_write_vectored()
204 }
205}
206
pin_project! {
    /// A [`Write`] wrapper that fails a write, flush, or shutdown with
    /// `io::ErrorKind::TimedOut` when it stays pending longer than the
    /// configured timeout.
    ///
    /// When the inner type also implements [`Read`], reads pass straight
    /// through without a timeout.
    #[derive(Debug)]
    pub struct TimeoutWriter<W> {
        #[pin]
        writer: W,
        #[pin]
        state: TimeoutState,
    }
}
217
218impl<W> TimeoutWriter<W>
219where
220 W: Write,
221{
222 pub fn new(writer: W) -> TimeoutWriter<W> {
226 TimeoutWriter {
227 writer,
228 state: TimeoutState::new(),
229 }
230 }
231
232 pub fn timeout(&self) -> Option<Duration> {
234 self.state.timeout()
235 }
236
237 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
242 self.state.set_timeout(timeout);
243 }
244
245 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
250 self.project().state.set_timeout_pinned(timeout);
251 }
252
253 pub fn get_ref(&self) -> &W {
255 &self.writer
256 }
257
258 pub fn get_mut(&mut self) -> &mut W {
260 &mut self.writer
261 }
262
263 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
265 self.project().writer
266 }
267
268 pub fn into_inner(self) -> W {
270 self.writer
271 }
272}
273
274impl<W> Write for TimeoutWriter<W>
275where
276 W: Write,
277{
278 fn poll_write(
279 self: Pin<&mut Self>,
280 cx: &mut Context,
281 buf: &[u8],
282 ) -> Poll<Result<usize, io::Error>> {
283 let this = self.project();
284 let r = this.writer.poll_write(cx, buf);
285 match r {
286 Poll::Pending => this.state.poll_check(cx)?,
287 _ => this.state.reset(),
288 }
289 r
290 }
291
292 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
293 let this = self.project();
294 let r = this.writer.poll_flush(cx);
295 match r {
296 Poll::Pending => this.state.poll_check(cx)?,
297 _ => this.state.reset(),
298 }
299 r
300 }
301
302 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
303 let this = self.project();
304 let r = this.writer.poll_shutdown(cx);
305 match r {
306 Poll::Pending => this.state.poll_check(cx)?,
307 _ => this.state.reset(),
308 }
309 r
310 }
311
312 fn poll_write_vectored(
313 self: Pin<&mut Self>,
314 cx: &mut Context,
315 bufs: &[io::IoSlice],
316 ) -> Poll<io::Result<usize>> {
317 let this = self.project();
318 let r = this.writer.poll_write_vectored(cx, bufs);
319 match r {
320 Poll::Pending => this.state.poll_check(cx)?,
321 _ => this.state.reset(),
322 }
323 r
324 }
325
326 fn is_write_vectored(&self) -> bool {
327 self.writer.is_write_vectored()
328 }
329}
330
331impl<W> Read for TimeoutWriter<W>
332where
333 W: Read,
334{
335 fn poll_read(
336 self: Pin<&mut Self>,
337 cx: &mut Context,
338 buf: ReadBufCursor,
339 ) -> Poll<Result<(), io::Error>> {
340 self.project().writer.poll_read(cx, buf)
341 }
342}
343
pin_project! {
    /// A stream wrapper with independent read and write timeouts.
    ///
    /// Internally this is a [`TimeoutReader`] layered around a [`TimeoutWriter`].
    /// Each layer owns the timer for its own direction.
    #[derive(Debug)]
    pub struct TimeoutStream<S> {
        #[pin]
        stream: TimeoutReader<TimeoutWriter<S>>
    }
}
352
353impl<S> TimeoutStream<S>
354where
355 S: Read + Write,
356{
357 pub fn new(stream: S) -> TimeoutStream<S> {
361 let writer = TimeoutWriter::new(stream);
362 let stream = TimeoutReader::new(writer);
363 TimeoutStream { stream }
364 }
365
366 pub fn read_timeout(&self) -> Option<Duration> {
368 self.stream.timeout()
369 }
370
371 pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
376 self.stream.set_timeout(timeout)
377 }
378
379 pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
384 self.project().stream.set_timeout_pinned(timeout)
385 }
386
387 pub fn write_timeout(&self) -> Option<Duration> {
389 self.stream.get_ref().timeout()
390 }
391
392 pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
397 self.stream.get_mut().set_timeout(timeout)
398 }
399
400 pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
405 self.project()
406 .stream
407 .get_pin_mut()
408 .set_timeout_pinned(timeout)
409 }
410
411 pub fn get_ref(&self) -> &S {
413 self.stream.get_ref().get_ref()
414 }
415
416 pub fn get_mut(&mut self) -> &mut S {
418 self.stream.get_mut().get_mut()
419 }
420
421 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
423 self.project().stream.get_pin_mut().get_pin_mut()
424 }
425
426 pub fn into_inner(self) -> S {
428 self.stream.into_inner().into_inner()
429 }
430}
431
432impl<S> Read for TimeoutStream<S>
433where
434 S: Read + Write,
435{
436 fn poll_read(
437 self: Pin<&mut Self>,
438 cx: &mut Context,
439 buf: ReadBufCursor,
440 ) -> Poll<Result<(), io::Error>> {
441 self.project().stream.poll_read(cx, buf)
442 }
443}
444
445impl<S> Write for TimeoutStream<S>
446where
447 S: Read + Write,
448{
449 fn poll_write(
450 self: Pin<&mut Self>,
451 cx: &mut Context,
452 buf: &[u8],
453 ) -> Poll<Result<usize, io::Error>> {
454 self.project().stream.poll_write(cx, buf)
455 }
456
457 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
458 self.project().stream.poll_flush(cx)
459 }
460
461 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
462 self.project().stream.poll_shutdown(cx)
463 }
464
465 fn poll_write_vectored(
466 self: Pin<&mut Self>,
467 cx: &mut Context,
468 bufs: &[io::IoSlice],
469 ) -> Poll<io::Result<usize>> {
470 self.project().stream.poll_write_vectored(cx, bufs)
471 }
472
473 fn is_write_vectored(&self) -> bool {
474 self.stream.is_write_vectored()
475 }
476}
477
478impl<S> Connection for TimeoutStream<S>
479where
480 S: Read + Write + Connection + Unpin,
481{
482 fn connected(&self) -> Connected {
483 self.get_ref().connected()
484 }
485}
486
487impl<S> Connection for Pin<Box<TimeoutStream<S>>>
488where
489 S: Read + Write + Connection + Unpin,
490{
491 fn connected(&self) -> Connected {
492 self.get_ref().connected()
493 }
494}
495
pin_project! {
    /// Future returned by [`read`]: a single read into a borrowed buffer that
    /// resolves to the number of bytes read.
    struct ReadFut<'a, R: ?Sized> {
        // The reader being polled.
        reader: &'a mut R,
        // Destination for the bytes read.
        buf: &'a mut [u8],
    }
}
504
505fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> ReadFut<'a, R>
511where
512 R: Read + Unpin + ?Sized,
513{
514 ReadFut { reader, buf }
515}
516
517impl<R> Future for ReadFut<'_, R>
518where
519 R: Read + Unpin + ?Sized,
520{
521 type Output = io::Result<usize>;
522
523 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
524 let me = self.project();
525 let mut buf = ReadBuf::new(me.buf);
526 ready!(Pin::new(me.reader).poll_read(cx, buf.unfilled()))?;
527 Poll::Ready(Ok(buf.filled().len()))
528 }
529}
530
531trait ReadExt: Read {
532 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self>
535 where
536 Self: Unpin,
537 {
538 read(self, buf)
539 }
540}
541
pin_project! {
    /// Future returned by [`write`]: a single write from a borrowed buffer that
    /// resolves to the number of bytes written.
    struct WriteFut<'a, W: ?Sized> {
        // The writer being polled.
        writer: &'a mut W,
        // Source bytes to write.
        buf: &'a [u8],
    }
}
549
550fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteFut<'a, W>
553where
554 W: Write + Unpin + ?Sized,
555{
556 WriteFut { writer, buf }
557}
558
559impl<W> Future for WriteFut<'_, W>
560where
561 W: Write + Unpin + ?Sized,
562{
563 type Output = io::Result<usize>;
564
565 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
566 let me = self.project();
567 Pin::new(&mut *me.writer).poll_write(cx, me.buf)
568 }
569}
570
571trait WriteExt: Write {
572 fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self>
575 where
576 Self: Unpin,
577 {
578 write(self, src)
579 }
580}
581
582impl<R> ReadExt for Pin<&mut TimeoutReader<R>>
583where
584 R: Read,
585{
586 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> {
587 read(self, buf)
588 }
589}
590
591impl<W> WriteExt for Pin<&mut TimeoutWriter<W>>
592where
593 W: Write,
594{
595 fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> {
596 write(self, src)
597 }
598}
599
600impl<S> ReadExt for Pin<&mut TimeoutStream<S>>
601where
602 S: Read + Write,
603{
604 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> {
605 read(self, buf)
606 }
607}
608
609impl<S> WriteExt for Pin<&mut TimeoutStream<S>>
610where
611 S: Read + Write,
612{
613 fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> {
614 write(self, src)
615 }
616}
617
#[cfg(test)]
mod test {
    use super::*;
    use hyper_util::rt::TokioIo;
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Test double: reads and writes stay `Pending` until `sleep`
        // completes, then succeed immediately.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        // Creates a stream whose operations become ready at `until`.
        fn new(until: Instant) -> Self {
            DelayStream {
                sleep: sleep_until(until),
            }
        }
    }

    impl Read for DelayStream {
        // Completes once the delay has elapsed, without filling any bytes.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: ReadBufCursor,
        ) -> Poll<Result<(), io::Error>> {
            match self.project().sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Ok(())),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    impl hyper::rt::Write for DelayStream {
        // Completes once the delay has elapsed, reporting the whole buffer as written.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            match self.project().sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Ok(buf.len())),
                Poll::Pending => Poll::Pending,
            }
        }

        // Flush is always immediately ready.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        // Shutdown is always immediately ready.
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    // A 100 ms read timeout must fire before the 500 ms delay completes.
    #[tokio::test]
    async fn read_timeout() {
        let reader = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut reader = TimeoutReader::new(reader);
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);

        let r = reader.read(&mut [0, 1, 2]).await;
        assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut);
    }

    // A 500 ms read timeout must not fire when the read completes after 100 ms.
    #[tokio::test]
    async fn read_ok() {
        let reader = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut reader = TimeoutReader::new(reader);
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);

        reader.read(&mut [0]).await.unwrap();
    }

    // Write-side counterpart of `read_timeout`.
    #[tokio::test]
    async fn write_timeout() {
        let writer = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut writer = TimeoutWriter::new(writer);
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);

        let r = writer.write(&[0]).await;
        assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut);
    }

    // Write-side counterpart of `read_ok`.
    #[tokio::test]
    async fn write_ok() {
        let writer = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut writer = TimeoutWriter::new(writer);
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);

        writer.write(&[0]).await.unwrap();
    }

    // End-to-end over a real TCP socket. The first byte arrives well within
    // the 100 ms read timeout. The second is sent ~500 ms later, so the second
    // read must time out.
    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        thread::spawn(move || {
            let mut socket = listener.accept().unwrap().0;
            thread::sleep(Duration::from_millis(10));
            socket.write_all(b"f").unwrap();
            thread::sleep(Duration::from_millis(500));
            // This write may fail if the client has already dropped the
            // connection after timing out, so its result is ignored.
            let _ = socket.write_all(b"f");
        });

        let s = TcpStream::connect(&addr).await.unwrap();
        let s = TokioIo::new(s);
        let mut s = TimeoutStream::new(s);
        s.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(s);
        s.read(&mut [0]).await.unwrap();
        let r = s.read(&mut [0]).await;

        match r {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.kind() == io::ErrorKind::TimedOut => (),
            Err(e) => panic!("{:?}", e),
        }
    }
}