tokio_postgres/
connect.rs

1use crate::client::{Addr, SocketConfig};
2use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::MakeTlsConnect;
6use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
7use futures_util::{future, pin_mut, Future, FutureExt, Stream};
8use rand::seq::SliceRandom;
9use std::task::Poll;
10use std::{cmp, io};
11use tokio::net;
12
13pub async fn connect<T>(
14    mut tls: T,
15    config: &Config,
16) -> Result<(Client, Connection<Socket, T::Stream>), Error>
17where
18    T: MakeTlsConnect<Socket>,
19{
20    if config.host.is_empty() && config.hostaddr.is_empty() {
21        return Err(Error::config("both host and hostaddr are missing".into()));
22    }
23
24    if !config.host.is_empty()
25        && !config.hostaddr.is_empty()
26        && config.host.len() != config.hostaddr.len()
27    {
28        let msg = format!(
29            "number of hosts ({}) is different from number of hostaddrs ({})",
30            config.host.len(),
31            config.hostaddr.len(),
32        );
33        return Err(Error::config(msg.into()));
34    }
35
36    // At this point, either one of the following two scenarios could happen:
37    // (1) either config.host or config.hostaddr must be empty;
38    // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal.
39    let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
40
41    if config.port.len() > 1 && config.port.len() != num_hosts {
42        return Err(Error::config("invalid number of ports".into()));
43    }
44
45    let mut indices = (0..num_hosts).collect::<Vec<_>>();
46    if config.load_balance_hosts == LoadBalanceHosts::Random {
47        indices.shuffle(&mut rand::thread_rng());
48    }
49
50    let mut error = None;
51    for i in indices {
52        let host = config.host.get(i);
53        let hostaddr = config.hostaddr.get(i);
54        let port = config
55            .port
56            .get(i)
57            .or_else(|| config.port.first())
58            .copied()
59            .unwrap_or(5432);
60
61        // The value of host is used as the hostname for TLS validation,
62        let hostname = match host {
63            Some(Host::Tcp(host)) => Some(host.clone()),
64            // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
65            #[cfg(unix)]
66            Some(Host::Unix(_)) => None,
67            None => None,
68        };
69
70        // Try to use the value of hostaddr to establish the TCP connection,
71        // fallback to host if hostaddr is not present.
72        let addr = match hostaddr {
73            Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
74            None => host.cloned().unwrap(),
75        };
76
77        match connect_host(addr, hostname, port, &mut tls, config).await {
78            Ok((client, connection)) => return Ok((client, connection)),
79            Err(e) => error = Some(e),
80        }
81    }
82
83    Err(error.unwrap())
84}
85
86async fn connect_host<T>(
87    host: Host,
88    hostname: Option<String>,
89    port: u16,
90    tls: &mut T,
91    config: &Config,
92) -> Result<(Client, Connection<Socket, T::Stream>), Error>
93where
94    T: MakeTlsConnect<Socket>,
95{
96    match host {
97        Host::Tcp(host) => {
98            let mut addrs = net::lookup_host((&*host, port))
99                .await
100                .map_err(Error::connect)?
101                .collect::<Vec<_>>();
102
103            if config.load_balance_hosts == LoadBalanceHosts::Random {
104                addrs.shuffle(&mut rand::thread_rng());
105            }
106
107            let mut last_err = None;
108            for addr in addrs {
109                match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
110                    .await
111                {
112                    Ok(stream) => return Ok(stream),
113                    Err(e) => {
114                        last_err = Some(e);
115                        continue;
116                    }
117                };
118            }
119
120            Err(last_err.unwrap_or_else(|| {
121                Error::connect(io::Error::new(
122                    io::ErrorKind::InvalidInput,
123                    "could not resolve any addresses",
124                ))
125            }))
126        }
127        #[cfg(unix)]
128        Host::Unix(path) => {
129            connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
130        }
131    }
132}
133
134async fn connect_once<T>(
135    addr: Addr,
136    hostname: Option<&str>,
137    port: u16,
138    tls: &mut T,
139    config: &Config,
140) -> Result<(Client, Connection<Socket, T::Stream>), Error>
141where
142    T: MakeTlsConnect<Socket>,
143{
144    let socket = connect_socket(
145        &addr,
146        port,
147        config.connect_timeout,
148        config.tcp_user_timeout,
149        if config.keepalives {
150            Some(&config.keepalive_config)
151        } else {
152            None
153        },
154    )
155    .await?;
156
157    let tls = tls
158        .make_tls_connect(hostname.unwrap_or(""))
159        .map_err(|e| Error::tls(e.into()))?;
160    let has_hostname = hostname.is_some();
161    let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;
162
163    if config.target_session_attrs != TargetSessionAttrs::Any {
164        let rows = client.simple_query_raw("SHOW transaction_read_only");
165        pin_mut!(rows);
166
167        let rows = future::poll_fn(|cx| {
168            if connection.poll_unpin(cx)?.is_ready() {
169                return Poll::Ready(Err(Error::closed()));
170            }
171
172            rows.as_mut().poll(cx)
173        })
174        .await?;
175        pin_mut!(rows);
176
177        loop {
178            let next = future::poll_fn(|cx| {
179                if connection.poll_unpin(cx)?.is_ready() {
180                    return Poll::Ready(Some(Err(Error::closed())));
181                }
182
183                rows.as_mut().poll_next(cx)
184            });
185
186            match next.await.transpose()? {
187                Some(SimpleQueryMessage::Row(row)) => {
188                    let read_only_result = row.try_get(0)?;
189                    if read_only_result == Some("on")
190                        && config.target_session_attrs == TargetSessionAttrs::ReadWrite
191                    {
192                        return Err(Error::connect(io::Error::new(
193                            io::ErrorKind::PermissionDenied,
194                            "database does not allow writes",
195                        )));
196                    } else if read_only_result == Some("off")
197                        && config.target_session_attrs == TargetSessionAttrs::ReadOnly
198                    {
199                        return Err(Error::connect(io::Error::new(
200                            io::ErrorKind::PermissionDenied,
201                            "database is not read only",
202                        )));
203                    } else {
204                        break;
205                    }
206                }
207                Some(_) => {}
208                None => return Err(Error::unexpected_message()),
209            }
210        }
211    }
212
213    client.set_socket_config(SocketConfig {
214        addr,
215        hostname: hostname.map(|s| s.to_string()),
216        port,
217        connect_timeout: config.connect_timeout,
218        tcp_user_timeout: config.tcp_user_timeout,
219        keepalive: if config.keepalives {
220            Some(config.keepalive_config.clone())
221        } else {
222            None
223        },
224    });
225
226    Ok((client, connection))
227}