1use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::netio::DUMMY_DNS_PORT;
17use mz_ore::option::OptionExt;
18use mz_ore::task::{self, AbortOnDropHandle};
19use mz_repr::CatalogItemId;
20use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
21use mz_ssh_util::tunnel_manager::SshTunnelManager;
22use tokio::net::TcpStream as TokioTcpStream;
23use tokio_postgres::config::{Host, ReplicationMode};
24use tokio_postgres::tls::MakeTlsConnect;
25use tracing::{info, warn};
26
27use crate::PostgresError;
28
/// Returns early from the enclosing function with
/// `Err(PostgresError::Generic(..))`, building the inner error with
/// `anyhow::anyhow!`.
///
/// Accepts either a format string followed by arguments, or a single
/// expression (e.g. a string literal) that is handed straight to `anyhow!`.
///
/// NOTE(review): because the format string is forwarded as an `expr`
/// fragment, inline `{name}` captures likely do not work here — pass format
/// arguments explicitly.
macro_rules! bail_generic {
    // Format string plus arguments, forwarded verbatim to `anyhow!`.
    ($fmt:expr, $($arg:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($fmt, $($arg)*)))
    };
    // A single expression, optionally followed by a trailing comma.
    ($err:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($err)))
    };
}
37
/// How the TCP connection to the PostgreSQL server is established.
#[derive(Debug, PartialEq, Clone)]
pub enum TunnelConfig {
    /// Connect directly to the configured host.
    Direct {
        /// When `Some`, the TCP connection is made to these IP addresses
        /// (installed as `hostaddr` entries). The configured host name stays
        /// in the config, presumably for TLS verification. When `None`, the
        /// host name is left for `tokio_postgres` to resolve.
        resolved_ips: Option<BTreeSet<IpAddr>>,
    },
    /// Connect through an SSH tunnel opened by the `SshTunnelManager` using
    /// `config`. Postgres traffic then flows over the tunnel's local address.
    Ssh { config: SshTunnelConfig },
    /// Connect through an AWS PrivateLink VPC endpoint.
    AwsPrivatelink {
        /// Catalog ID of the PrivateLink connection. It is used to derive the
        /// VPC endpoint host name, which is resolved to the IPs connected to.
        connection_id: CatalogItemId,
    },
}
61
/// Default statement timeout for snapshot queries (per the constant's name;
/// it is not used in this file).
///
/// NOTE(review): a zero duration presumably means "no timeout", matching
/// PostgreSQL's `statement_timeout = 0` — confirm at the use site.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::from_secs(0);
63
/// A connected PostgreSQL client bundled with the task that drives its
/// connection.
///
/// Derefs to [`tokio_postgres::Client`], so queries can be issued on it
/// directly.
pub struct Client {
    /// The underlying `tokio_postgres` client.
    inner: tokio_postgres::Client,
    /// Handle to the spawned task that polls the connection (and, for SSH,
    /// owns the tunnel). It is held as an `AbortOnDropHandle` so the task is
    /// torn down when this `Client` is dropped.
    _connection_handle: AbortOnDropHandle<()>,
}
72
73impl Deref for Client {
74 type Target = tokio_postgres::Client;
75
76 fn deref(&self) -> &Self::Target {
77 &self.inner
78 }
79}
80
81impl DerefMut for Client {
82 fn deref_mut(&mut self) -> &mut Self::Target {
83 &mut self.inner
84 }
85}
86
/// Configuration for connecting to a single PostgreSQL server, optionally
/// through a tunnel.
#[derive(Clone, Debug)]
pub struct Config {
    /// Base `tokio_postgres` settings (host, port, user, dbname, TLS, ...).
    inner: tokio_postgres::Config,
    /// How the TCP connection reaches the server.
    tunnel: TunnelConfig,
    /// Whether connection setup runs in a separate task (passed to
    /// `run_in_task_if`).
    in_task: InTask,
    /// Timeouts handed to the SSH tunnel manager when `tunnel` is `Ssh`.
    ssh_timeout_config: SshTimeoutConfig,
}
98
99impl Config {
100 pub fn new(
101 inner: tokio_postgres::Config,
102 tunnel: TunnelConfig,
103 ssh_timeout_config: SshTimeoutConfig,
104 in_task: InTask,
105 ) -> Result<Self, PostgresError> {
106 let config = Self {
107 inner,
108 tunnel,
109 in_task,
110 ssh_timeout_config,
111 };
112
113 config.address()?;
116
117 Ok(config)
118 }
119
120 pub async fn connect(
122 &self,
123 task_name: &str,
124 ssh_tunnel_manager: &SshTunnelManager,
125 ) -> Result<Client, PostgresError> {
126 self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
127 .await
128 }
129
130 pub async fn connect_replication(
132 &self,
133 ssh_tunnel_manager: &SshTunnelManager,
134 ) -> Result<Client, PostgresError> {
135 self.connect_traced(
136 "postgres_connect_replication",
137 |config| {
138 config.replication_mode(ReplicationMode::Logical);
139 },
140 ssh_tunnel_manager,
141 )
142 .await
143 }
144
145 fn address(&self) -> Result<(&str, u16), PostgresError> {
146 match (self.inner.get_hosts(), self.inner.get_ports()) {
147 ([Host::Tcp(host)], [port]) => Ok((host, *port)),
148 _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
149 }
150 }
151
152 async fn connect_traced<F>(
153 &self,
154 task_name: &str,
155 configure: F,
156 ssh_tunnel_manager: &SshTunnelManager,
157 ) -> Result<Client, PostgresError>
158 where
159 F: FnOnce(&mut tokio_postgres::Config),
160 {
161 let (host, port) = self.address()?;
162 let address = format!(
163 "{}@{}:{}/{}",
164 self.get_user().display_or("<unknown-user>"),
165 host,
166 port,
167 self.get_dbname().display_or("<unknown-dbname>")
168 );
169 info!(%task_name, %address, "connecting");
170 match self
171 .connect_internal(task_name, configure, ssh_tunnel_manager)
172 .await
173 {
174 Ok(t) => {
175 let backend_pid = t.backend_pid();
176 info!(%task_name, %address, %backend_pid, "connected");
177 Ok(t)
178 }
179 Err(e) => {
180 warn!(%task_name, %address, "connection failed: {e:#}");
181 Err(e)
182 }
183 }
184 }
185
186 async fn connect_internal<F>(
187 &self,
188 task_name: &str,
189 configure: F,
190 ssh_tunnel_manager: &SshTunnelManager,
191 ) -> Result<Client, PostgresError>
192 where
193 F: FnOnce(&mut tokio_postgres::Config),
194 {
195 let mut postgres_config = self.inner.clone();
196 configure(&mut postgres_config);
197
198 let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
199 mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
200 mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
201 })?;
202
203 match &self.tunnel {
204 TunnelConfig::Direct { resolved_ips } => {
205 if let Some(ips) = resolved_ips {
206 let host = match postgres_config.get_hosts() {
207 [Host::Tcp(host)] => host,
208 _ => bail_generic!(
209 "only TCP connections to a single PostgreSQL server are supported"
210 ),
211 }
212 .to_owned();
213 for (idx, ip) in ips.iter().enumerate() {
217 if idx != 0 {
218 postgres_config.host(&host);
219 }
220 postgres_config.hostaddr(ip.clone());
221 }
222 };
223
224 let (client, connection) = async move { postgres_config.connect(tls).await }
225 .run_in_task_if(self.in_task, || "pg_connect".to_string())
226 .await?;
227
228 let client = Client {
229 inner: client,
230 _connection_handle: task::spawn(|| task_name, async {
231 if let Err(e) = connection.await {
232 warn!("postgres direct connection failed: {e}");
233 }
234 })
235 .abort_on_drop(),
236 };
237 Ok(client)
238 }
239 TunnelConfig::Ssh { config } => {
240 let (host, port) = self.address()?;
241 let tunnel = ssh_tunnel_manager
242 .connect(
243 config.clone(),
244 host,
245 port,
246 self.ssh_timeout_config,
247 self.in_task,
248 )
249 .await
250 .map_err(PostgresError::Ssh)?;
251
252 let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
253 let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
254 .await
255 .map_err(PostgresError::SshIo)?;
256 let (client, connection) =
264 async move { postgres_config.connect_raw(tcp_stream, tls).await }
265 .run_in_task_if(self.in_task, || "pg_connect".to_string())
266 .await?;
267
268 let client = Client {
269 inner: client,
270 _connection_handle: task::spawn(|| task_name, async {
271 let _tunnel = tunnel; if let Err(e) = connection.await {
273 warn!("postgres via SSH tunnel connection failed: {e}");
274 }
275 })
276 .abort_on_drop(),
277 };
278 Ok(client)
279 }
280 TunnelConfig::AwsPrivatelink { connection_id } => {
281 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
288 let privatelink_addrs =
289 tokio::net::lookup_host((privatelink_host, DUMMY_DNS_PORT)).await?;
290
291 let host = match postgres_config.get_hosts() {
294 [Host::Tcp(host)] => host,
295 _ => bail_generic!(
296 "only TCP connections to a single PostgreSQL server are supported"
297 ),
298 }
299 .to_owned();
300 for (idx, addr) in privatelink_addrs.enumerate() {
304 if idx != 0 {
305 postgres_config.host(&host);
306 }
307 postgres_config.hostaddr(addr.ip());
308 }
309
310 let (client, connection) = async move { postgres_config.connect(tls).await }
311 .run_in_task_if(self.in_task, || "pg_connect".to_string())
312 .await?;
313
314 let client = Client {
315 inner: client,
316 _connection_handle: task::spawn(|| task_name, async {
317 if let Err(e) = connection.await {
318 warn!("postgres AWS link connection failed: {e}");
319 }
320 })
321 .abort_on_drop(),
322 };
323 Ok(client)
324 }
325 }
326 }
327
328 pub fn get_user(&self) -> Option<&str> {
329 self.inner.get_user()
330 }
331
332 pub fn get_dbname(&self) -> Option<&str> {
333 self.inner.get_dbname()
334 }
335}