1use crate::client::{Addr, SocketConfig};
2use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::MakeTlsConnect;
6use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
7use futures_util::{future, pin_mut, Future, FutureExt, Stream};
8use rand::seq::SliceRandom;
9use std::task::Poll;
10use std::{cmp, io};
11use tokio::net;
12
13pub async fn connect<T>(
14 mut tls: T,
15 config: &Config,
16) -> Result<(Client, Connection<Socket, T::Stream>), Error>
17where
18 T: MakeTlsConnect<Socket>,
19{
20 if config.host.is_empty() && config.hostaddr.is_empty() {
21 return Err(Error::config("both host and hostaddr are missing".into()));
22 }
23
24 if !config.host.is_empty()
25 && !config.hostaddr.is_empty()
26 && config.host.len() != config.hostaddr.len()
27 {
28 let msg = format!(
29 "number of hosts ({}) is different from number of hostaddrs ({})",
30 config.host.len(),
31 config.hostaddr.len(),
32 );
33 return Err(Error::config(msg.into()));
34 }
35
36 let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
40
41 if config.port.len() > 1 && config.port.len() != num_hosts {
42 return Err(Error::config("invalid number of ports".into()));
43 }
44
45 let mut indices = (0..num_hosts).collect::<Vec<_>>();
46 if config.load_balance_hosts == LoadBalanceHosts::Random {
47 indices.shuffle(&mut rand::thread_rng());
48 }
49
50 let mut error = None;
51 for i in indices {
52 let host = config.host.get(i);
53 let hostaddr = config.hostaddr.get(i);
54 let port = config
55 .port
56 .get(i)
57 .or_else(|| config.port.first())
58 .copied()
59 .unwrap_or(5432);
60
61 let hostname = match host {
63 Some(Host::Tcp(host)) => Some(host.clone()),
64 #[cfg(unix)]
66 Some(Host::Unix(_)) => None,
67 None => None,
68 };
69
70 let addr = match hostaddr {
73 Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
74 None => host.cloned().unwrap(),
75 };
76
77 match connect_host(addr, hostname, port, &mut tls, config).await {
78 Ok((client, connection)) => return Ok((client, connection)),
79 Err(e) => error = Some(e),
80 }
81 }
82
83 Err(error.unwrap())
84}
85
86async fn connect_host<T>(
87 host: Host,
88 hostname: Option<String>,
89 port: u16,
90 tls: &mut T,
91 config: &Config,
92) -> Result<(Client, Connection<Socket, T::Stream>), Error>
93where
94 T: MakeTlsConnect<Socket>,
95{
96 match host {
97 Host::Tcp(host) => {
98 let mut addrs = net::lookup_host((&*host, port))
99 .await
100 .map_err(Error::connect)?
101 .collect::<Vec<_>>();
102
103 if config.load_balance_hosts == LoadBalanceHosts::Random {
104 addrs.shuffle(&mut rand::thread_rng());
105 }
106
107 let mut last_err = None;
108 for addr in addrs {
109 match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
110 .await
111 {
112 Ok(stream) => return Ok(stream),
113 Err(e) => {
114 last_err = Some(e);
115 continue;
116 }
117 };
118 }
119
120 Err(last_err.unwrap_or_else(|| {
121 Error::connect(io::Error::new(
122 io::ErrorKind::InvalidInput,
123 "could not resolve any addresses",
124 ))
125 }))
126 }
127 #[cfg(unix)]
128 Host::Unix(path) => {
129 connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
130 }
131 }
132}
133
/// Makes a single connection attempt to `addr`.
///
/// The attempt opens the socket, runs the startup handshake, and optionally checks the
/// session's read/write mode.
///
/// Steps:
/// 1. Opens the socket via `connect_socket`. The connect timeout and TCP user timeout are
///    forwarded; the keepalive config is passed only when `config.keepalives` is set.
/// 2. Builds a TLS connector for `hostname` (the empty string when there is none), then runs
///    `connect_raw`, telling it whether a hostname was supplied.
/// 3. When `config.target_session_attrs` is not `Any`, runs `SHOW transaction_read_only`.
///    The connection is rejected with a `PermissionDenied` connect error in two cases:
///    the server reports `on` but `ReadWrite` was requested, or it reports `off` but
///    `ReadOnly` was requested.
/// 4. Stores the address, hostname, port, timeouts and keepalive settings on the client
///    via `set_socket_config`.
///
/// # Errors
///
/// - Propagates socket, TLS, handshake and query errors.
/// - Returns `Error::closed()` if the `connection` future completes while the check query
///   is still pending.
/// - Returns `Error::unexpected_message()` if the check query's response stream ends
///   without a data row.
async fn connect_once<T>(
    addr: Addr,
    hostname: Option<&str>,
    port: u16,
    tls: &mut T,
    config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
    T: MakeTlsConnect<Socket>,
{
    let socket = connect_socket(
        &addr,
        port,
        config.connect_timeout,
        config.tcp_user_timeout,
        if config.keepalives {
            Some(&config.keepalive_config)
        } else {
            None
        },
    )
    .await?;

    let tls = tls
        .make_tls_connect(hostname.unwrap_or(""))
        .map_err(|e| Error::tls(e.into()))?;
    let has_hostname = hostname.is_some();
    let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;

    if config.target_session_attrs != TargetSessionAttrs::Any {
        let rows = client.simple_query_raw("SHOW transaction_read_only");
        pin_mut!(rows);

        // Poll `connection` before each poll of the query future. Presumably the
        // connection future is what drives the socket I/O that delivers the query's
        // responses. If it completes first, report the connection as closed.
        let rows = future::poll_fn(|cx| {
            if connection.poll_unpin(cx)?.is_ready() {
                return Poll::Ready(Err(Error::closed()));
            }

            rows.as_mut().poll(cx)
        })
        .await?;
        pin_mut!(rows);

        // Skip non-row messages until the first data row, whose column 0 holds the setting.
        loop {
            // Same interleaving as above, but for the streamed query responses.
            let next = future::poll_fn(|cx| {
                if connection.poll_unpin(cx)?.is_ready() {
                    return Poll::Ready(Some(Err(Error::closed())));
                }

                rows.as_mut().poll_next(cx)
            });

            match next.await.transpose()? {
                Some(SimpleQueryMessage::Row(row)) => {
                    let read_only_result = row.try_get(0)?;
                    if read_only_result == Some("on")
                        && config.target_session_attrs == TargetSessionAttrs::ReadWrite
                    {
                        return Err(Error::connect(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            "database does not allow writes",
                        )));
                    } else if read_only_result == Some("off")
                        && config.target_session_attrs == TargetSessionAttrs::ReadOnly
                    {
                        return Err(Error::connect(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            "database is not read only",
                        )));
                    } else {
                        // Any other combination (including a `None` value) is accepted.
                        break;
                    }
                }
                // Non-row messages carry nothing to check; keep reading.
                Some(_) => {}
                // The stream ended without producing a data row.
                None => return Err(Error::unexpected_message()),
            }
        }
    }

    // Record the parameters used for this socket on the client.
    client.set_socket_config(SocketConfig {
        addr,
        hostname: hostname.map(|s| s.to_string()),
        port,
        connect_timeout: config.connect_timeout,
        tcp_user_timeout: config.tcp_user_timeout,
        keepalive: if config.keepalives {
            Some(config.keepalive_config.clone())
        } else {
            None
        },
    });

    Ok((client, connection))
}