#![warn(missing_docs)]
use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use tokio::time::{sleep_until, Instant, Sleep};
pin_project! {
    /// Timeout bookkeeping for one I/O direction, used by both [`TimeoutReader`]
    /// and [`TimeoutWriter`].
    ///
    /// The timer is armed lazily: it only starts counting when the wrapped I/O
    /// returns `Pending`, and it is disarmed whenever an operation completes.
    #[derive(Debug)]
    struct TimeoutState {
        // Configured timeout; `None` disables timing entirely.
        timeout: Option<Duration>,
        // Reusable timer; only meaningful while `active` is true.
        #[pin]
        cur: Sleep,
        // Whether `cur` currently holds a live deadline for an in-flight operation.
        active: bool,
    }
}
impl TimeoutState {
    /// Creates a state with no timeout configured and the timer disarmed.
    ///
    /// The `Sleep` is initialised with an already-elapsed deadline purely as a
    /// placeholder. It is only polled after [`poll_check`](Self::poll_check) has
    /// re-armed it with a real deadline.
    #[inline]
    fn new() -> TimeoutState {
        TimeoutState {
            timeout: None,
            cur: sleep_until(Instant::now()),
            active: false,
        }
    }

    /// Returns the configured timeout, if any.
    #[inline]
    fn timeout(&self) -> Option<Duration> {
        self.timeout
    }

    /// Sets the timeout on a state that has not been pinned.
    ///
    /// `Sleep` is `!Unpin`, so holding `&mut self` means this state has never been
    /// pinned and therefore never polled. The timer cannot be active, so there is
    /// nothing to reset.
    #[inline]
    fn set_timeout(&mut self, timeout: Option<Duration>) {
        self.timeout = timeout;
    }

    /// Sets the timeout on a pinned state and disarms any running timer.
    ///
    /// The new value is applied from the next `Pending` poll onwards.
    #[inline]
    fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
        *self.as_mut().project().timeout = timeout;
        self.reset();
    }

    /// Disarms the timer once the wrapped operation has completed.
    ///
    /// When the timer was armed, the `Sleep` is also moved to `Instant::now()`
    /// so it no longer carries the old deadline.
    #[inline]
    fn reset(self: Pin<&mut Self>) {
        let this = self.project();
        if *this.active {
            *this.active = false;
            this.cur.reset(Instant::now());
        }
    }

    /// Checks the timeout for an operation whose wrapped I/O just returned
    /// `Pending`.
    ///
    /// If no timeout is configured this returns `Ok(())` immediately. Otherwise
    /// the timer is armed with `now + timeout` (if not already armed) and polled.
    /// Polling it registers `cx`'s waker, so the task is woken when the deadline
    /// passes.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::TimedOut`] error once the deadline has elapsed.
    /// The timer is disarmed at that point, so a caller that retries gets a fresh
    /// deadline instead of failing again immediately against the expired timer.
    #[inline]
    fn poll_check(self: Pin<&mut Self>, cx: &mut Context) -> io::Result<()> {
        let mut this = self.project();
        let timeout = match this.timeout {
            Some(timeout) => *timeout,
            None => return Ok(()),
        };
        if !*this.active {
            this.cur.as_mut().reset(Instant::now() + timeout);
            *this.active = true;
        }
        match this.cur.poll(cx) {
            Poll::Ready(()) => {
                // Previously `active` stayed true here. That left an elapsed
                // deadline in place, and every later pending poll timed out
                // instantly.
                *this.active = false;
                Err(io::Error::from(io::ErrorKind::TimedOut))
            }
            Poll::Pending => Ok(()),
        }
    }
}
pin_project! {
    /// A [`Read`] wrapper that fails a read with [`io::ErrorKind::TimedOut`] when
    /// that read stays pending for longer than the configured timeout.
    ///
    /// No timeout is applied until one is configured. When `R` also implements
    /// [`Write`], writes are forwarded without any timeout.
    #[derive(Debug)]
    pub struct TimeoutReader<R> {
        // The wrapped reader.
        #[pin]
        reader: R,
        // Timer state for the read direction.
        #[pin]
        state: TimeoutState,
    }
}
impl<R> TimeoutReader<R>
where
    R: Read,
{
    /// Wraps `reader` with no read timeout configured.
    pub fn new(reader: R) -> TimeoutReader<R> {
        let state = TimeoutState::new();
        Self { reader, state }
    }

    /// Returns the current read timeout, if one is set.
    pub fn timeout(&self) -> Option<Duration> {
        let state = &self.state;
        state.timeout()
    }

    /// Sets the read timeout; `None` disables it.
    ///
    /// This takes `&mut self`, so it can only be used before the reader is
    /// pinned. Use [`set_timeout_pinned`](Self::set_timeout_pinned) afterwards.
    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
        let state = &mut self.state;
        state.set_timeout(timeout);
    }

    /// Sets the read timeout on a pinned reader; `None` disables it.
    ///
    /// Any timer running for an in-flight read is disarmed. The new value is
    /// used the next time a read returns `Pending`.
    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
        let state = self.project().state;
        state.set_timeout_pinned(timeout);
    }

    /// Returns a shared reference to the wrapped reader.
    pub fn get_ref(&self) -> &R {
        let Self { reader, .. } = self;
        reader
    }

    /// Returns a mutable reference to the wrapped reader.
    pub fn get_mut(&mut self) -> &mut R {
        let Self { reader, .. } = self;
        reader
    }

    /// Returns a pinned mutable reference to the wrapped reader.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
        let projected = self.project();
        projected.reader
    }

    /// Consumes the wrapper and returns the wrapped reader.
    pub fn into_inner(self) -> R {
        let Self { reader, .. } = self;
        reader
    }
}
impl<R> Read for TimeoutReader<R>
where
    R: Read,
{
    /// Polls the wrapped reader and enforces the read timeout while it is pending.
    ///
    /// A completed poll (success or error) disarms the timer. A pending poll arms
    /// or checks the timer, and an elapsed deadline becomes a `TimedOut` error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: ReadBufCursor,
    ) -> Poll<Result<(), io::Error>> {
        let me = self.project();
        let outcome = me.reader.poll_read(cx, buf);
        if outcome.is_pending() {
            me.state.poll_check(cx)?;
        } else {
            me.state.reset();
        }
        outcome
    }
}
/// Writes go straight through to the wrapped value; no timeout is applied.
impl<R> Write for TimeoutReader<R>
where
    R: Write,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let inner = self.project().reader;
        inner.poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let inner = self.project().reader;
        inner.poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let inner = self.project().reader;
        inner.poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &[io::IoSlice],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().reader;
        inner.poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        R::is_write_vectored(&self.reader)
    }
}
pin_project! {
    /// A [`Write`] wrapper that fails a write, flush or shutdown with
    /// [`io::ErrorKind::TimedOut`] when that operation stays pending for longer
    /// than the configured timeout.
    ///
    /// No timeout is applied until one is configured. When `W` also implements
    /// [`Read`], reads are forwarded without any timeout.
    #[derive(Debug)]
    pub struct TimeoutWriter<W> {
        // The wrapped writer.
        #[pin]
        writer: W,
        // Timer state for the write direction.
        #[pin]
        state: TimeoutState,
    }
}
impl<W> TimeoutWriter<W>
where
    W: Write,
{
    /// Wraps `writer` with no write timeout configured.
    pub fn new(writer: W) -> TimeoutWriter<W> {
        let state = TimeoutState::new();
        Self { writer, state }
    }

    /// Returns the current write timeout, if one is set.
    pub fn timeout(&self) -> Option<Duration> {
        let state = &self.state;
        state.timeout()
    }

    /// Sets the write timeout; `None` disables it.
    ///
    /// This takes `&mut self`, so it can only be used before the writer is
    /// pinned. Use [`set_timeout_pinned`](Self::set_timeout_pinned) afterwards.
    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
        let state = &mut self.state;
        state.set_timeout(timeout);
    }

    /// Sets the write timeout on a pinned writer; `None` disables it.
    ///
    /// Any timer running for an in-flight operation is disarmed. The new value is
    /// used the next time a write, flush or shutdown returns `Pending`.
    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
        let state = self.project().state;
        state.set_timeout_pinned(timeout);
    }

    /// Returns a shared reference to the wrapped writer.
    pub fn get_ref(&self) -> &W {
        let Self { writer, .. } = self;
        writer
    }

    /// Returns a mutable reference to the wrapped writer.
    pub fn get_mut(&mut self) -> &mut W {
        let Self { writer, .. } = self;
        writer
    }

    /// Returns a pinned mutable reference to the wrapped writer.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
        let projected = self.project();
        projected.writer
    }

    /// Consumes the wrapper and returns the wrapped writer.
    pub fn into_inner(self) -> W {
        let Self { writer, .. } = self;
        writer
    }
}
impl<W> Write for TimeoutWriter<W>
where
W: Write,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
let r = this.writer.poll_write(cx, buf);
match r {
Poll::Pending => this.state.poll_check(cx)?,
_ => this.state.reset(),
}
r
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
let r = this.writer.poll_flush(cx);
match r {
Poll::Pending => this.state.poll_check(cx)?,
_ => this.state.reset(),
}
r
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
let r = this.writer.poll_shutdown(cx);
match r {
Poll::Pending => this.state.poll_check(cx)?,
_ => this.state.reset(),
}
r
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context,
bufs: &[io::IoSlice],
) -> Poll<io::Result<usize>> {
let this = self.project();
let r = this.writer.poll_write_vectored(cx, bufs);
match r {
Poll::Pending => this.state.poll_check(cx)?,
_ => this.state.reset(),
}
r
}
fn is_write_vectored(&self) -> bool {
self.writer.is_write_vectored()
}
}
/// Reads go straight through to the wrapped value; no timeout is applied.
impl<W> Read for TimeoutWriter<W>
where
    W: Read,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: ReadBufCursor,
    ) -> Poll<Result<(), io::Error>> {
        let inner = self.project().writer;
        inner.poll_read(cx, buf)
    }
}
pin_project! {
    /// A stream wrapper that applies independent timeouts to reads and writes.
    ///
    /// Internally a [`TimeoutReader`] (read timeout) wraps a [`TimeoutWriter`]
    /// (write timeout), which in turn wraps `S`.
    #[derive(Debug)]
    pub struct TimeoutStream<S> {
        // The outer layer times reads; the inner `TimeoutWriter` times writes.
        #[pin]
        stream: TimeoutReader<TimeoutWriter<S>>
    }
}
impl<S> TimeoutStream<S>
where
    S: Read + Write,
{
    /// Wraps `stream` with neither a read nor a write timeout configured.
    pub fn new(stream: S) -> TimeoutStream<S> {
        Self {
            stream: TimeoutReader::new(TimeoutWriter::new(stream)),
        }
    }

    /// Returns the current read timeout, if one is set.
    pub fn read_timeout(&self) -> Option<Duration> {
        let reader = &self.stream;
        reader.timeout()
    }

    /// Sets the read timeout; `None` disables it.
    ///
    /// Only usable before the stream is pinned. Use
    /// [`set_read_timeout_pinned`](Self::set_read_timeout_pinned) afterwards.
    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
        let reader = &mut self.stream;
        reader.set_timeout(timeout)
    }

    /// Sets the read timeout on a pinned stream, disarming any running read timer.
    pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
        let reader = self.project().stream;
        reader.set_timeout_pinned(timeout)
    }

    /// Returns the current write timeout, if one is set.
    pub fn write_timeout(&self) -> Option<Duration> {
        let writer = self.stream.get_ref();
        writer.timeout()
    }

    /// Sets the write timeout; `None` disables it.
    ///
    /// Only usable before the stream is pinned. Use
    /// [`set_write_timeout_pinned`](Self::set_write_timeout_pinned) afterwards.
    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
        let writer = self.stream.get_mut();
        writer.set_timeout(timeout)
    }

    /// Sets the write timeout on a pinned stream, disarming any running write timer.
    pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
        let writer = self.project().stream.get_pin_mut();
        writer.set_timeout_pinned(timeout)
    }

    /// Returns a shared reference to the underlying stream.
    pub fn get_ref(&self) -> &S {
        let writer = self.stream.get_ref();
        writer.get_ref()
    }

    /// Returns a mutable reference to the underlying stream.
    pub fn get_mut(&mut self) -> &mut S {
        let writer = self.stream.get_mut();
        writer.get_mut()
    }

    /// Returns a pinned mutable reference to the underlying stream.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
        let writer = self.project().stream.get_pin_mut();
        writer.get_pin_mut()
    }

    /// Consumes the wrapper and returns the underlying stream.
    pub fn into_inner(self) -> S {
        let writer = self.stream.into_inner();
        writer.into_inner()
    }
}
impl<S> Read for TimeoutStream<S>
where
S: Read + Write,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: ReadBufCursor,
) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_read(cx, buf)
}
}
impl<S> Write for TimeoutStream<S>
where
S: Read + Write,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().stream.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context,
bufs: &[io::IoSlice],
) -> Poll<io::Result<usize>> {
self.project().stream.poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}
}
/// Reports the connection metadata of the wrapped stream `S`.
impl<S> Connection for TimeoutStream<S>
where
    S: Read + Write + Connection + Unpin,
{
    fn connected(&self) -> Connected {
        let inner: &S = self.get_ref();
        inner.connected()
    }
}
/// Reports the connection metadata of the stream wrapped by a boxed, pinned
/// [`TimeoutStream`].
impl<S> Connection for Pin<Box<TimeoutStream<S>>>
where
    S: Read + Write + Connection + Unpin,
{
    fn connected(&self) -> Connected {
        let stream: &TimeoutStream<S> = self;
        stream.get_ref().connected()
    }
}
pin_project! {
    /// Future returned by [`ReadExt::read`]. It performs a single read into
    /// `buf` and resolves to the number of bytes read.
    struct ReadFut<'a, R: ?Sized> {
        // Reader being polled.
        reader: &'a mut R,
        // Destination buffer for the read.
        buf: &'a mut [u8],
    }
}
/// Creates a [`ReadFut`] that reads once from `reader` into `buf`.
fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> ReadFut<'a, R>
where
    R: Read + Unpin + ?Sized,
{
    ReadFut { buf, reader }
}
impl<R> Future for ReadFut<'_, R>
where
    R: Read + Unpin + ?Sized,
{
    type Output = io::Result<usize>;

    /// Polls the reader once into a fresh [`ReadBuf`] over `buf`.
    ///
    /// Resolves to the number of bytes the reader filled, or propagates its error.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        let this = self.project();
        let mut window = ReadBuf::new(this.buf);
        let reader = Pin::new(this.reader);
        let outcome = ready!(reader.poll_read(cx, window.unfilled()));
        match outcome {
            Ok(()) => Poll::Ready(Ok(window.filled().len())),
            Err(err) => Poll::Ready(Err(err)),
        }
    }
}
trait ReadExt: Read {
fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self>
where
Self: Unpin,
{
read(self, buf)
}
}
pin_project! {
    /// Future returned by [`WriteExt::write`]. It polls a single write of `buf`
    /// and resolves to whatever the writer reports.
    struct WriteFut<'a, W: ?Sized> {
        // Writer being polled.
        writer: &'a mut W,
        // Bytes offered to the writer on every poll.
        buf: &'a [u8],
    }
}
/// Creates a [`WriteFut`] that writes `buf` once to `writer`.
fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteFut<'a, W>
where
    W: Write + Unpin + ?Sized,
{
    WriteFut { buf, writer }
}
impl<W> Future for WriteFut<'_, W>
where
    W: Write + Unpin + ?Sized,
{
    type Output = io::Result<usize>;

    /// Polls the writer once with the whole buffer and forwards its result.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        let this = self.project();
        let writer = Pin::new(&mut *this.writer);
        writer.poll_write(cx, this.buf)
    }
}
trait WriteExt: Write {
fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self>
where
Self: Unpin,
{
write(self, src)
}
}
/// Enables `.read(..).await` on a pinned [`TimeoutReader`].
impl<R> ReadExt for Pin<&mut TimeoutReader<R>>
where
    R: Read,
{
    fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> {
        read::<Self>(self, buf)
    }
}
/// Enables `.write(..).await` on a pinned [`TimeoutWriter`].
impl<W> WriteExt for Pin<&mut TimeoutWriter<W>>
where
    W: Write,
{
    fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> {
        write::<Self>(self, src)
    }
}
/// Enables `.read(..).await` on a pinned [`TimeoutStream`].
impl<S> ReadExt for Pin<&mut TimeoutStream<S>>
where
    S: Read + Write,
{
    fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> {
        read::<Self>(self, buf)
    }
}
/// Enables `.write(..).await` on a pinned [`TimeoutStream`].
impl<S> WriteExt for Pin<&mut TimeoutStream<S>>
where
    S: Read + Write,
{
    fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> {
        write::<Self>(self, src)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use hyper_util::rt::TokioIo;
    // Needed for `write_all` on the std `TcpStream` in `tcp_read`. This explicit
    // import shadows the glob-imported `hyper::rt::Write`, so the `DelayStream`
    // impl below names `hyper::rt::Write` by its full path.
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Test double whose reads and writes stay `Pending` until the sleep
        // elapses. After that, reads complete without filling any bytes and
        // writes report the whole buffer as written.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        // Creates a stream that becomes ready at `until`.
        fn new(until: Instant) -> Self {
            DelayStream {
                sleep: sleep_until(until),
            }
        }
    }

    impl Read for DelayStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: ReadBufCursor,
        ) -> Poll<Result<(), io::Error>> {
            match self.project().sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Ok(())),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    impl hyper::rt::Write for DelayStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            match self.project().sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Ok(buf.len())),
                Poll::Pending => Poll::Pending,
            }
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    // The inner reader needs 500ms but the timeout is 100ms, so the read must fail.
    #[tokio::test]
    async fn read_timeout() {
        let reader = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut reader = TimeoutReader::new(reader);
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);
        let r = reader.read(&mut [0, 1, 2]).await;
        assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut);
    }

    // The inner reader is ready after 100ms, well within the 500ms timeout.
    #[tokio::test]
    async fn read_ok() {
        let reader = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut reader = TimeoutReader::new(reader);
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);
        reader.read(&mut [0]).await.unwrap();
    }

    // The inner writer needs 500ms but the timeout is 100ms, so the write must fail.
    #[tokio::test]
    async fn write_timeout() {
        let writer = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut writer = TimeoutWriter::new(writer);
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);
        let r = writer.write(&[0]).await;
        assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut);
    }

    // The inner writer is ready after 100ms, well within the 500ms timeout.
    #[tokio::test]
    async fn write_ok() {
        let writer = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut writer = TimeoutWriter::new(writer);
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);
        writer.write(&[0]).await.unwrap();
    }

    // End-to-end check over a real socket. The peer sends one byte almost
    // immediately, then stays silent for 500ms. The first read succeeds; the
    // second outlasts the 100ms read timeout and must fail with `TimedOut`.
    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || {
            let mut socket = listener.accept().unwrap().0;
            thread::sleep(Duration::from_millis(10));
            socket.write_all(b"f").unwrap();
            // Keep the connection open (and silent) long enough for the client's
            // second read to time out.
            thread::sleep(Duration::from_millis(500));
            // Result ignored: presumably the client side has been dropped by now.
            let _ = socket.write_all(b"f");
        });
        let s = TcpStream::connect(&addr).await.unwrap();
        let s = TokioIo::new(s);
        let mut s = TimeoutStream::new(s);
        s.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(s);
        s.read(&mut [0]).await.unwrap();
        let r = s.read(&mut [0]).await;
        match r {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.kind() == io::ErrorKind::TimedOut => (),
            Err(e) => panic!("{:?}", e),
        }
    }
}