// tokio_postgres/connect_socket.rs

1use crate::client::Addr;
2use crate::keepalive::KeepaliveConfig;
3use crate::{Error, Socket};
4use socket2::{SockRef, TcpKeepalive};
5use std::future::Future;
6use std::io;
7use std::time::Duration;
8use tokio::net::TcpStream;
9#[cfg(unix)]
10use tokio::net::UnixStream;
11use tokio::time;
12
13pub(crate) async fn connect_socket(
14    addr: &Addr,
15    port: u16,
16    connect_timeout: Option<Duration>,
17    #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option<
18        Duration,
19    >,
20    keepalive_config: Option<&KeepaliveConfig>,
21) -> Result<Socket, Error> {
22    match addr {
23        Addr::Tcp(ip) => {
24            let stream =
25                connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?;
26
27            stream.set_nodelay(true).map_err(Error::connect)?;
28
29            let sock_ref = SockRef::from(&stream);
30            #[cfg(target_os = "linux")]
31            {
32                sock_ref
33                    .set_tcp_user_timeout(tcp_user_timeout)
34                    .map_err(Error::connect)?;
35            }
36
37            if let Some(keepalive_config) = keepalive_config {
38                sock_ref
39                    .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config))
40                    .map_err(Error::connect)?;
41            }
42
43            Ok(Socket::new_tcp(stream))
44        }
45        #[cfg(unix)]
46        Addr::Unix(dir) => {
47            let path = dir.join(format!(".s.PGSQL.{}", port));
48            let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?;
49            Ok(Socket::new_unix(socket))
50        }
51    }
52}
53
54async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
55where
56    F: Future<Output = io::Result<T>>,
57{
58    match timeout {
59        Some(timeout) => match time::timeout(timeout, connect).await {
60            Ok(Ok(socket)) => Ok(socket),
61            Ok(Err(e)) => Err(Error::connect(e)),
62            Err(_) => Err(Error::connect(io::Error::new(
63                io::ErrorKind::TimedOut,
64                "connection timed out",
65            ))),
66        },
67        None => match connect.await {
68            Ok(socket) => Ok(socket),
69            Err(e) => Err(Error::connect(e)),
70        },
71    }
72}