tonic/transport/server/service/io.rs

1use crate::transport::server::Connected;
2use std::io;
3use std::io::IoSlice;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7#[cfg(feature = "tls")]
8use tokio_rustls::server::TlsStream;
9
/// A server-side connection stream: either the raw `IO` value or, when the
/// `tls` feature is enabled, a TLS stream layered over that `IO`.
pub(crate) enum ServerIo<IO> {
    /// The underlying `IO` stream, used directly.
    Io(IO),
    /// A `tokio_rustls` server `TlsStream` over `IO`. It is held behind a `Box`, so
    /// the enum stores only a pointer for this variant rather than the whole
    /// `TlsStream<IO>` inline.
    #[cfg(feature = "tls")]
    TlsIo(Box<TlsStream<IO>>),
}
15
16use tower::util::Either;
17
18#[cfg(feature = "tls")]
19type ServerIoConnectInfo<IO> =
20    Either<<IO as Connected>::ConnectInfo, <TlsStream<IO> as Connected>::ConnectInfo>;
21
22#[cfg(not(feature = "tls"))]
23type ServerIoConnectInfo<IO> = Either<<IO as Connected>::ConnectInfo, ()>;
24
25impl<IO> ServerIo<IO> {
26    pub(in crate::transport) fn new_io(io: IO) -> Self {
27        Self::Io(io)
28    }
29
30    #[cfg(feature = "tls")]
31    pub(in crate::transport) fn new_tls_io(io: TlsStream<IO>) -> Self {
32        Self::TlsIo(Box::new(io))
33    }
34
35    pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo<IO>
36    where
37        IO: Connected,
38    {
39        match self {
40            Self::Io(io) => Either::A(io.connect_info()),
41            #[cfg(feature = "tls")]
42            Self::TlsIo(io) => Either::B(io.connect_info()),
43        }
44    }
45}
46
47impl<IO> AsyncRead for ServerIo<IO>
48where
49    IO: AsyncWrite + AsyncRead + Unpin,
50{
51    fn poll_read(
52        mut self: Pin<&mut Self>,
53        cx: &mut Context<'_>,
54        buf: &mut ReadBuf<'_>,
55    ) -> Poll<io::Result<()>> {
56        match &mut *self {
57            Self::Io(io) => Pin::new(io).poll_read(cx, buf),
58            #[cfg(feature = "tls")]
59            Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf),
60        }
61    }
62}
63
64impl<IO> AsyncWrite for ServerIo<IO>
65where
66    IO: AsyncWrite + AsyncRead + Unpin,
67{
68    fn poll_write(
69        mut self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &[u8],
72    ) -> Poll<io::Result<usize>> {
73        match &mut *self {
74            Self::Io(io) => Pin::new(io).poll_write(cx, buf),
75            #[cfg(feature = "tls")]
76            Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf),
77        }
78    }
79
80    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81        match &mut *self {
82            Self::Io(io) => Pin::new(io).poll_flush(cx),
83            #[cfg(feature = "tls")]
84            Self::TlsIo(io) => Pin::new(io).poll_flush(cx),
85        }
86    }
87
88    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89        match &mut *self {
90            Self::Io(io) => Pin::new(io).poll_shutdown(cx),
91            #[cfg(feature = "tls")]
92            Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx),
93        }
94    }
95
96    fn poll_write_vectored(
97        mut self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99        bufs: &[IoSlice<'_>],
100    ) -> Poll<Result<usize, io::Error>> {
101        match &mut *self {
102            Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs),
103            #[cfg(feature = "tls")]
104            Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs),
105        }
106    }
107
108    fn is_write_vectored(&self) -> bool {
109        match self {
110            Self::Io(io) => io.is_write_vectored(),
111            #[cfg(feature = "tls")]
112            Self::TlsIo(io) => io.is_write_vectored(),
113        }
114    }
115}