domain/net/client/
protocol.rs

1//! Underlying transport protocols.
2
3use core::future::Future;
4use core::pin::Pin;
5use std::boxed::Box;
6use std::io;
7use std::net::SocketAddr;
8use std::task::{Context, Poll};
9use tokio::io::ReadBuf;
10use tokio::net::{TcpStream, UdpSocket};
11
/// Number of times binding a local UDP socket to a random (OS-chosen)
/// port is retried after a failed attempt.
///
/// This exists for the case where the port is already in use ("address in
/// use"). Because it counts retries, binding is attempted at most
/// `RETRY_RANDOM_PORT + 1` times in total.
const RETRY_RANDOM_PORT: usize = 10;
14
15//------------ AsyncConnect --------------------------------------------------
16
/// Establish a connection asynchronously.
///
/// Implementors act as connection factories: each call to
/// [`connect`](Self::connect) hands back a future for a new connection.
pub trait AsyncConnect {
    /// The type of an established connection.
    type Connection;

    /// The future that resolves to an established connection or an I/O
    /// error.
    type Fut: Future<Output = io::Result<Self::Connection>> + Send + Sync;

    /// Returns a future that establishes a connection.
    fn connect(&self) -> Self::Fut;
}
32
33//------------ TcpConnect --------------------------------------------------
34
/// Factory for new TCP connections to a fixed remote address.
#[derive(Clone, Copy, Debug)]
pub struct TcpConnect {
    /// Address of the remote end that connections are made to.
    addr: SocketAddr,
}

impl TcpConnect {
    /// Creates a factory whose connections go to `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        TcpConnect { addr }
    }
}
50
51impl AsyncConnect for TcpConnect {
52    type Connection = TcpStream;
53    type Fut = Pin<
54        Box<
55            dyn Future<Output = Result<Self::Connection, std::io::Error>>
56                + Send
57                + Sync,
58        >,
59    >;
60
61    fn connect(&self) -> Self::Fut {
62        Box::pin(TcpStream::connect(self.addr))
63    }
64}
65
66//------------ TlsConnect -----------------------------------------------------
67
/// Create new TLS connections.
///
/// Each connection is a TCP connection to `addr` with a TLS session
/// established on top of it.
#[cfg(feature = "tokio-rustls")]
#[derive(Clone, Debug)]
pub struct TlsConnect {
    /// Shared rustls client configuration, handed to every handshake.
    client_config: std::sync::Arc<tokio_rustls::rustls::ClientConfig>,

    /// Server name used for certificate verification.
    server_name: tokio_rustls::rustls::pki_types::ServerName<'static>,

    /// Remote address to connect to.
    addr: SocketAddr,
}

#[cfg(feature = "tokio-rustls")]
impl TlsConnect {
    /// Creates a factory for TLS connections.
    ///
    /// * `client_config` is anything convertible into an
    ///   `Arc<ClientConfig>`, for example a `ClientConfig` or an existing
    ///   `Arc`.
    /// * `server_name` is the name the server certificate is verified
    ///   against.
    /// * `addr` is the remote address to connect to.
    pub fn new<Conf>(
        client_config: Conf,
        server_name: tokio_rustls::rustls::pki_types::ServerName<'static>,
        addr: SocketAddr,
    ) -> Self
    where
        Conf: Into<std::sync::Arc<tokio_rustls::rustls::ClientConfig>>,
    {
        let client_config = client_config.into();
        TlsConnect {
            addr,
            server_name,
            client_config,
        }
    }
}
100
#[cfg(feature = "tokio-rustls")]
impl AsyncConnect for TlsConnect {
    type Connection = tokio_rustls::client::TlsStream<TcpStream>;
    type Fut = Pin<
        Box<
            dyn Future<Output = Result<Self::Connection, std::io::Error>>
                + Send
                + Sync,
        >,
    >;

    /// Returns a future that opens a TCP connection to the configured
    /// address and then runs the TLS handshake over it, passing
    /// `server_name` to the connector.
    ///
    /// Errors from either the TCP connect or the handshake are returned
    /// as `io::Error`.
    fn connect(&self) -> Self::Fut {
        // Take owned copies of everything the future needs, so the
        // returned future does not borrow `self`. Cloning the config is
        // only an `Arc` refcount bump.
        let connector =
            tokio_rustls::TlsConnector::from(self.client_config.clone());
        let server_name = self.server_name.clone();
        let addr = self.addr;
        Box::pin(async move {
            // The connector is small and only used by reference, so it is
            // kept inline in the future rather than boxed separately.
            let tcp = TcpStream::connect(addr).await?;
            connector.connect(server_name, tcp).await
        })
    }
}
124
125//------------ UdpConnect --------------------------------------------------
126
127/// Create new UDP connections.
128#[derive(Clone, Copy, Debug)]
129pub struct UdpConnect {
130    /// Remote address to connect to.
131    addr: SocketAddr,
132}
133
134impl UdpConnect {
135    /// Create new UDP connections.
136    ///
137    /// addr is the destination address to connect to.
138    pub fn new(addr: SocketAddr) -> Self {
139        Self { addr }
140    }
141
142    /// Bind to a random local UDP port.
143    async fn bind_and_connect(self) -> Result<UdpSocket, io::Error> {
144        let mut i = 0;
145        let sock = loop {
146            let local: SocketAddr = if self.addr.is_ipv4() {
147                ([0u8; 4], 0).into()
148            } else {
149                ([0u16; 8], 0).into()
150            };
151            match UdpSocket::bind(&local).await {
152                Ok(sock) => break sock,
153                Err(err) => {
154                    if i == RETRY_RANDOM_PORT {
155                        return Err(err);
156                    } else {
157                        i += 1
158                    }
159                }
160            }
161        };
162        sock.connect(self.addr).await?;
163        Ok(sock)
164    }
165}
166
167impl AsyncConnect for UdpConnect {
168    type Connection = UdpSocket;
169    type Fut = Pin<
170        Box<
171            dyn Future<Output = Result<Self::Connection, std::io::Error>>
172                + Send
173                + Sync,
174        >,
175    >;
176
177    fn connect(&self) -> Self::Fut {
178        Box::pin(self.bind_and_connect())
179    }
180}
181
182//------------ AsyncDgramRecv -------------------------------------------------
183
184/// Receive a datagram packets asynchronously.
185pub trait AsyncDgramRecv {
186    /// Polled receive.
187    fn poll_recv(
188        &self,
189        cx: &mut Context<'_>,
190        buf: &mut ReadBuf<'_>,
191    ) -> Poll<Result<(), io::Error>>;
192}
193
194impl AsyncDgramRecv for UdpSocket {
195    fn poll_recv(
196        &self,
197        cx: &mut Context<'_>,
198        buf: &mut ReadBuf<'_>,
199    ) -> Poll<Result<(), io::Error>> {
200        UdpSocket::poll_recv(self, cx, buf)
201    }
202}
203
204//------------ AsyncDgramRecvEx -----------------------------------------------
205
206/// Convenvience trait to turn poll_recv into an asynchronous function.
207pub trait AsyncDgramRecvEx: AsyncDgramRecv {
208    /// Asynchronous receive function.
209    fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> DgramRecv<'a, Self>
210    where
211        Self: Unpin,
212    {
213        DgramRecv {
214            receiver: self,
215            buf,
216        }
217    }
218}
219
220impl<R: AsyncDgramRecv> AsyncDgramRecvEx for R {}
221
222//------------ DgramRecv -----------------------------------------------------
223
224/// Return value of recv. This captures the future for recv.
225pub struct DgramRecv<'a, R: ?Sized> {
226    /// The receiver of the datagram.
227    receiver: &'a R,
228
229    /// Buffer to store the datagram.
230    buf: &'a mut [u8],
231}
232
233impl<R: AsyncDgramRecv + Unpin> Future for DgramRecv<'_, R> {
234    type Output = io::Result<usize>;
235
236    fn poll(
237        mut self: Pin<&mut Self>,
238        cx: &mut Context<'_>,
239    ) -> Poll<io::Result<usize>> {
240        let receiver = self.receiver;
241        let mut buf = ReadBuf::new(self.buf);
242        match Pin::new(receiver).poll_recv(cx, &mut buf) {
243            Poll::Pending => return Poll::Pending,
244            Poll::Ready(res) => {
245                if let Err(err) = res {
246                    return Poll::Ready(Err(err));
247                }
248            }
249        }
250        Poll::Ready(Ok(buf.filled().len()))
251    }
252}
253
254//------------ AsyncDgramSend -------------------------------------------------
255
/// Send datagrams asynchronously.
///
/// This is the poll-based primitive. [`AsyncDgramSendEx`] builds an
/// awaitable `send` on top of it.
pub trait AsyncDgramSend {
    /// Polls to send the contents of `buf` as a datagram.
    ///
    /// The `UdpSocket` implementation in this module delegates to tokio,
    /// whose ready value is the number of bytes sent.
    fn poll_send(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>>;
}
267
268impl AsyncDgramSend for UdpSocket {
269    fn poll_send(
270        &self,
271        cx: &mut Context<'_>,
272        buf: &[u8],
273    ) -> Poll<Result<usize, io::Error>> {
274        UdpSocket::poll_send(self, cx, buf)
275    }
276}
277
278//------------ AsyncDgramSendEx ----------------------------------------------
279
280/// Convenience trait that turns poll_send into an asynchronous function.
281pub trait AsyncDgramSendEx: AsyncDgramSend {
282    /// Asynchronous function to send a packet.
283    fn send<'a>(&'a self, buf: &'a [u8]) -> DgramSend<'a, Self>
284    where
285        Self: Unpin,
286    {
287        DgramSend { sender: self, buf }
288    }
289}
290
291impl<S: AsyncDgramSend> AsyncDgramSendEx for S {}
292
293//------------ DgramSend -----------------------------------------------------
294
295/// This is the return value of send. It captures the future for send.
296pub struct DgramSend<'a, S: ?Sized> {
297    /// The datagram send object.
298    sender: &'a S,
299
300    /// The buffer that needs to be sent.
301    buf: &'a [u8],
302}
303
304impl<S: AsyncDgramSend + Unpin> Future for DgramSend<'_, S> {
305    type Output = io::Result<usize>;
306
307    fn poll(
308        self: Pin<&mut Self>,
309        cx: &mut Context<'_>,
310    ) -> Poll<io::Result<usize>> {
311        Pin::new(self.sender).poll_send(cx, self.buf)
312    }
313}