1use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::netio::DUMMY_DNS_PORT;
17use mz_ore::option::OptionExt;
18use mz_ore::task;
19use mz_repr::CatalogItemId;
20use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
21use mz_ssh_util::tunnel_manager::SshTunnelManager;
22use tokio::io::{AsyncRead, AsyncWrite};
23use tokio::net::TcpStream as TokioTcpStream;
24use tokio_postgres::config::{Host, ReplicationMode};
25use tokio_postgres::tls::MakeTlsConnect;
26use tracing::{info, warn};
27
28use crate::PostgresError;
29
/// Returns early from the enclosing function with `Err(PostgresError::Generic(..))`.
///
/// Accepts either a format string followed by arguments (forwarded to `anyhow::anyhow!`), or a
/// single expression that `anyhow::anyhow!` can turn into an error (a message or an error value).
/// `PostgresError` is resolved at the call site, so it must be in scope there.
macro_rules! bail_generic {
    // Format string plus one or more format arguments.
    ($fmt:expr, $($arg:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($fmt, $($arg)*)))
    };
    // A single message/error expression, with an optional trailing comma.
    ($err:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($err)))
    };
}
38
/// How a [`Config`] reaches the PostgreSQL server.
#[derive(Debug, PartialEq, Clone)]
pub enum TunnelConfig {
    /// Connect directly to the host named in the `tokio_postgres` config.
    Direct {
        /// Optional pre-resolved addresses for the configured host.
        ///
        /// When `Some`, the connection dials these addresses (as `hostaddr`s) while keeping the
        /// configured host name; when `None`, no `hostaddr` is set.
        resolved_ips: Option<BTreeSet<IpAddr>>,
    },
    /// Connect through an SSH tunnel opened with the given tunnel configuration.
    Ssh { config: SshTunnelConfig },
    /// Connect through an AWS PrivateLink VPC endpoint.
    ///
    /// The endpoint host name is derived from `connection_id` and resolved via DNS; the
    /// resulting addresses are dialed while the configured host name is kept.
    AwsPrivatelink {
        /// ID of the catalog item describing the AWS PrivateLink connection.
        connection_id: CatalogItemId,
    },
}
62
/// Default statement timeout, presumably for snapshot queries (per the name; not used in this file).
///
/// NOTE(review): a zero duration presumably means "no timeout" (PostgreSQL treats
/// `statement_timeout = 0` as disabled) — confirm how callers apply this value.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::ZERO;
64
/// A [`tokio_postgres::Client`] paired with the server version reported when it connected.
///
/// Dereferences (mutably and immutably) to the wrapped client, so its query methods can be
/// called directly on this type.
pub struct Client {
    /// The wrapped `tokio_postgres` client.
    inner: tokio_postgres::Client,
    /// Value of the `server_version` parameter read from the connection, if present.
    server_version: Option<String>,
}
70
71impl Client {
72 fn new<S, T>(
73 client: tokio_postgres::Client,
74 connection: &tokio_postgres::Connection<S, T>,
75 ) -> Client
76 where
77 S: AsyncRead + AsyncWrite + Unpin,
78 T: AsyncRead + AsyncWrite + Unpin,
79 {
80 let server_version = connection
81 .parameter("server_version")
82 .map(|v| v.to_string());
83 Client {
84 inner: client,
85 server_version,
86 }
87 }
88
89 pub fn server_version(&self) -> Option<&str> {
92 self.server_version.as_deref()
93 }
94}
95
96impl Deref for Client {
97 type Target = tokio_postgres::Client;
98
99 fn deref(&self) -> &Self::Target {
100 &self.inner
101 }
102}
103
104impl DerefMut for Client {
105 fn deref_mut(&mut self) -> &mut Self::Target {
106 &mut self.inner
107 }
108}
109
/// Configuration for connecting to a PostgreSQL server, optionally through a tunnel.
///
/// Only configurations naming exactly one TCP host and one port are accepted by `Config::new`.
#[derive(Clone, Debug)]
pub struct Config {
    /// Base `tokio_postgres` settings; cloned and adjusted for each connection attempt.
    inner: tokio_postgres::Config,
    /// How the server is reached: directly, via SSH, or via AWS PrivateLink.
    tunnel: TunnelConfig,
    /// Passed to `run_in_task_if` for the connect futures and to the SSH tunnel manager.
    in_task: InTask,
    /// Timeouts passed to the SSH tunnel manager when connecting through an SSH tunnel.
    ssh_timeout_config: SshTimeoutConfig,
}
121
122impl Config {
123 pub fn new(
124 inner: tokio_postgres::Config,
125 tunnel: TunnelConfig,
126 ssh_timeout_config: SshTimeoutConfig,
127 in_task: InTask,
128 ) -> Result<Self, PostgresError> {
129 let config = Self {
130 inner,
131 tunnel,
132 in_task,
133 ssh_timeout_config,
134 };
135
136 config.address()?;
139
140 Ok(config)
141 }
142
143 pub async fn connect(
145 &self,
146 task_name: &str,
147 ssh_tunnel_manager: &SshTunnelManager,
148 ) -> Result<Client, PostgresError> {
149 self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
150 .await
151 }
152
153 pub async fn connect_replication(
155 &self,
156 ssh_tunnel_manager: &SshTunnelManager,
157 ) -> Result<Client, PostgresError> {
158 self.connect_traced(
159 "postgres_connect_replication",
160 |config| {
161 config.replication_mode(ReplicationMode::Logical);
162 },
163 ssh_tunnel_manager,
164 )
165 .await
166 }
167
168 fn address(&self) -> Result<(&str, u16), PostgresError> {
169 match (self.inner.get_hosts(), self.inner.get_ports()) {
170 ([Host::Tcp(host)], [port]) => Ok((host, *port)),
171 _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
172 }
173 }
174
175 async fn connect_traced<F>(
176 &self,
177 task_name: &str,
178 configure: F,
179 ssh_tunnel_manager: &SshTunnelManager,
180 ) -> Result<Client, PostgresError>
181 where
182 F: FnOnce(&mut tokio_postgres::Config),
183 {
184 let (host, port) = self.address()?;
185 let address = format!(
186 "{}@{}:{}/{}",
187 self.get_user().display_or("<unknown-user>"),
188 host,
189 port,
190 self.get_dbname().display_or("<unknown-dbname>")
191 );
192 info!(%task_name, %address, "connecting");
193 match self
194 .connect_internal(task_name, configure, ssh_tunnel_manager)
195 .await
196 {
197 Ok(t) => {
198 let backend_pid = t.backend_pid();
199 info!(%task_name, %address, %backend_pid, "connected");
200 Ok(t)
201 }
202 Err(e) => {
203 warn!(%task_name, %address, "connection failed: {e:#}");
204 Err(e)
205 }
206 }
207 }
208
209 async fn connect_internal<F>(
210 &self,
211 task_name: &str,
212 configure: F,
213 ssh_tunnel_manager: &SshTunnelManager,
214 ) -> Result<Client, PostgresError>
215 where
216 F: FnOnce(&mut tokio_postgres::Config),
217 {
218 let mut postgres_config = self.inner.clone();
219 configure(&mut postgres_config);
220
221 let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
222 mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
223 mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
224 })?;
225
226 match &self.tunnel {
227 TunnelConfig::Direct { resolved_ips } => {
228 if let Some(ips) = resolved_ips {
229 let host = match postgres_config.get_hosts() {
230 [Host::Tcp(host)] => host,
231 _ => bail_generic!(
232 "only TCP connections to a single PostgreSQL server are supported"
233 ),
234 }
235 .to_owned();
236 for (idx, ip) in ips.iter().enumerate() {
240 if idx != 0 {
241 postgres_config.host(&host);
242 }
243 postgres_config.hostaddr(ip.clone());
244 }
245 };
246
247 let (client, connection) = async move { postgres_config.connect(tls).await }
248 .run_in_task_if(self.in_task, || "pg_connect".to_string())
249 .await?;
250 let client = Client::new(client, &connection);
251 task::spawn(|| task_name, connection);
252 Ok(client)
253 }
254 TunnelConfig::Ssh { config } => {
255 let (host, port) = self.address()?;
256 let tunnel = ssh_tunnel_manager
257 .connect(
258 config.clone(),
259 host,
260 port,
261 self.ssh_timeout_config,
262 self.in_task,
263 )
264 .await
265 .map_err(PostgresError::Ssh)?;
266
267 let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
268 let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
269 .await
270 .map_err(PostgresError::SshIo)?;
271 let (client, connection) =
279 async move { postgres_config.connect_raw(tcp_stream, tls).await }
280 .run_in_task_if(self.in_task, || "pg_connect".to_string())
281 .await?;
282 let client = Client::new(client, &connection);
283 task::spawn(|| task_name, async {
284 let _tunnel = tunnel; if let Err(e) = connection.await {
287 warn!("postgres connection failed: {e}");
288 }
289 });
290 Ok(client)
291 }
292 TunnelConfig::AwsPrivatelink { connection_id } => {
293 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
300 let privatelink_addrs =
301 tokio::net::lookup_host((privatelink_host, DUMMY_DNS_PORT)).await?;
302
303 let host = match postgres_config.get_hosts() {
306 [Host::Tcp(host)] => host,
307 _ => bail_generic!(
308 "only TCP connections to a single PostgreSQL server are supported"
309 ),
310 }
311 .to_owned();
312 for (idx, addr) in privatelink_addrs.enumerate() {
316 if idx != 0 {
317 postgres_config.host(&host);
318 }
319 postgres_config.hostaddr(addr.ip());
320 }
321
322 let (client, connection) = async move { postgres_config.connect(tls).await }
323 .run_in_task_if(self.in_task, || "pg_connect".to_string())
324 .await?;
325 let client = Client::new(client, &connection);
326 task::spawn(|| task_name, connection);
327 Ok(client)
328 }
329 }
330 }
331
332 pub fn get_user(&self) -> Option<&str> {
333 self.inner.get_user()
334 }
335
336 pub fn get_dbname(&self) -> Option<&str> {
337 self.inner.get_dbname()
338 }
339}