tonic/transport/server/service/io.rs
1use crate::transport::server::Connected;
2use std::io;
3use std::io::IoSlice;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7#[cfg(feature = "tls")]
8use tokio_rustls::server::TlsStream;
9
/// A server-side connection: either the transport `IO` used directly, or
/// (only when the `tls` feature is enabled) a TLS stream layered over `IO`.
pub(crate) enum ServerIo<IO> {
    /// The transport itself, with no TLS layer added by this type.
    Io(IO),
    /// A `TlsStream` wrapping the transport, stored behind a `Box`.
    /// NOTE(review): boxing presumably keeps the enum from growing to the size
    /// of `TlsStream<IO>` — confirm.
    #[cfg(feature = "tls")]
    TlsIo(Box<TlsStream<IO>>),
}
15
16use tower::util::Either;
17
/// Connection metadata returned by `ServerIo::connect_info` when the `tls`
/// feature is enabled.
///
/// `Either::A` carries the plain transport's `Connected::ConnectInfo`.
/// `Either::B` carries the `TlsStream`'s `Connected::ConnectInfo`.
#[cfg(feature = "tls")]
type ServerIoConnectInfo<IO> = Either<
    <IO as Connected>::ConnectInfo,
    <TlsStream<IO> as Connected>::ConnectInfo,
>;
21
22#[cfg(not(feature = "tls"))]
23type ServerIoConnectInfo<IO> = Either<<IO as Connected>::ConnectInfo, ()>;
24
25impl<IO> ServerIo<IO> {
26 pub(in crate::transport) fn new_io(io: IO) -> Self {
27 Self::Io(io)
28 }
29
30 #[cfg(feature = "tls")]
31 pub(in crate::transport) fn new_tls_io(io: TlsStream<IO>) -> Self {
32 Self::TlsIo(Box::new(io))
33 }
34
35 pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo<IO>
36 where
37 IO: Connected,
38 {
39 match self {
40 Self::Io(io) => Either::A(io.connect_info()),
41 #[cfg(feature = "tls")]
42 Self::TlsIo(io) => Either::B(io.connect_info()),
43 }
44 }
45}
46
47impl<IO> AsyncRead for ServerIo<IO>
48where
49 IO: AsyncWrite + AsyncRead + Unpin,
50{
51 fn poll_read(
52 mut self: Pin<&mut Self>,
53 cx: &mut Context<'_>,
54 buf: &mut ReadBuf<'_>,
55 ) -> Poll<io::Result<()>> {
56 match &mut *self {
57 Self::Io(io) => Pin::new(io).poll_read(cx, buf),
58 #[cfg(feature = "tls")]
59 Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf),
60 }
61 }
62}
63
64impl<IO> AsyncWrite for ServerIo<IO>
65where
66 IO: AsyncWrite + AsyncRead + Unpin,
67{
68 fn poll_write(
69 mut self: Pin<&mut Self>,
70 cx: &mut Context<'_>,
71 buf: &[u8],
72 ) -> Poll<io::Result<usize>> {
73 match &mut *self {
74 Self::Io(io) => Pin::new(io).poll_write(cx, buf),
75 #[cfg(feature = "tls")]
76 Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf),
77 }
78 }
79
80 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81 match &mut *self {
82 Self::Io(io) => Pin::new(io).poll_flush(cx),
83 #[cfg(feature = "tls")]
84 Self::TlsIo(io) => Pin::new(io).poll_flush(cx),
85 }
86 }
87
88 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89 match &mut *self {
90 Self::Io(io) => Pin::new(io).poll_shutdown(cx),
91 #[cfg(feature = "tls")]
92 Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx),
93 }
94 }
95
96 fn poll_write_vectored(
97 mut self: Pin<&mut Self>,
98 cx: &mut Context<'_>,
99 bufs: &[IoSlice<'_>],
100 ) -> Poll<Result<usize, io::Error>> {
101 match &mut *self {
102 Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs),
103 #[cfg(feature = "tls")]
104 Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs),
105 }
106 }
107
108 fn is_write_vectored(&self) -> bool {
109 match self {
110 Self::Io(io) => io.is_write_vectored(),
111 #[cfg(feature = "tls")]
112 Self::TlsIo(io) => io.is_write_vectored(),
113 }
114 }
115}