1use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::option::OptionExt;
17use mz_ore::task::{self, AbortOnDropHandle};
18use mz_repr::CatalogItemId;
19use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
20use mz_ssh_util::tunnel_manager::SshTunnelManager;
21use tokio::net::TcpStream as TokioTcpStream;
22use tokio_postgres::config::{Host, ReplicationMode};
23use tokio_postgres::tls::MakeTlsConnect;
24use tracing::{info, warn};
25
26use crate::PostgresError;
27
/// Returns early from the enclosing function with
/// `Err(PostgresError::Generic(anyhow::anyhow!(msg)))`.
///
/// The enclosing function must return `Result<_, PostgresError>`.
macro_rules! bail_generic {
    ($msg:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($msg)))
    };
}
33
/// How the TCP connection to the PostgreSQL server is routed.
#[derive(Debug, PartialEq, Clone)]
pub enum TunnelConfig {
    /// Connect straight to the configured host.
    Direct {
        /// Optional pre-resolved addresses. When `Some`, each address is passed
        /// to tokio-postgres as a `hostaddr`, paired with the configured host
        /// name, so the TCP connection targets these IPs.
        resolved_ips: Option<BTreeSet<IpAddr>>,
    },
    /// Connect through an SSH tunnel opened with `config` via the
    /// `SshTunnelManager`; the PostgreSQL protocol then runs over a local TCP
    /// connection to the tunnel's local address.
    Ssh { config: SshTunnelConfig },
    /// Connect via an AWS PrivateLink VPC endpoint.
    AwsPrivatelink {
        /// Identifier from which the VPC endpoint host name is derived
        /// (through `mz_cloud_resources::vpc_endpoint_name`).
        connection_id: CatalogItemId,
    },
}
57
/// Default statement timeout used for snapshot queries: zero.
///
/// NOTE(review): presumably forwarded to PostgreSQL's `statement_timeout`,
/// where a value of zero disables the timeout — confirm at the use site.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::from_secs(0);
59
/// A connected PostgreSQL client.
///
/// Dereferences to [`tokio_postgres::Client`], so all of its query methods are
/// available directly.
pub struct Client {
    /// The underlying tokio-postgres client.
    inner: tokio_postgres::Client,
    /// Handle to the spawned task that drives the connection (and, for SSH,
    /// holds the tunnel). Dropping the `Client` drops this handle, which
    /// aborts that task.
    _connection_handle: AbortOnDropHandle<()>,
}
68
69impl Deref for Client {
70 type Target = tokio_postgres::Client;
71
72 fn deref(&self) -> &Self::Target {
73 &self.inner
74 }
75}
76
77impl DerefMut for Client {
78 fn deref_mut(&mut self) -> &mut Self::Target {
79 &mut self.inner
80 }
81}
82
/// Everything needed to open a connection to a PostgreSQL server: the
/// tokio-postgres settings plus how to route the TCP connection.
///
/// Build it with `Config::new`, which rejects configurations that do not name
/// exactly one TCP host and one port.
#[derive(Clone, Debug)]
pub struct Config {
    /// Base tokio-postgres configuration; cloned (and possibly adjusted) for
    /// every connection attempt.
    inner: tokio_postgres::Config,
    /// How the TCP connection is routed (direct, SSH, or AWS PrivateLink).
    tunnel: TunnelConfig,
    /// Passed to `run_in_task_if` for the connect futures and to the SSH
    /// tunnel manager.
    in_task: InTask,
    /// Timeouts handed to the SSH tunnel manager when opening a tunnel.
    ssh_timeout_config: SshTimeoutConfig,
}
94
95impl Config {
96 pub fn new(
97 inner: tokio_postgres::Config,
98 tunnel: TunnelConfig,
99 ssh_timeout_config: SshTimeoutConfig,
100 in_task: InTask,
101 ) -> Result<Self, PostgresError> {
102 let config = Self {
103 inner,
104 tunnel,
105 in_task,
106 ssh_timeout_config,
107 };
108
109 config.address()?;
112
113 Ok(config)
114 }
115
116 pub async fn connect(
118 &self,
119 task_name: &str,
120 ssh_tunnel_manager: &SshTunnelManager,
121 ) -> Result<Client, PostgresError> {
122 self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
123 .await
124 }
125
126 pub async fn connect_replication(
128 &self,
129 ssh_tunnel_manager: &SshTunnelManager,
130 ) -> Result<Client, PostgresError> {
131 self.connect_traced(
132 "postgres_connect_replication",
133 |config| {
134 config.replication_mode(ReplicationMode::Logical);
135 },
136 ssh_tunnel_manager,
137 )
138 .await
139 }
140
141 fn address(&self) -> Result<(&str, u16), PostgresError> {
142 match (self.inner.get_hosts(), self.inner.get_ports()) {
143 ([Host::Tcp(host)], [port]) => Ok((host, *port)),
144 _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
145 }
146 }
147
148 async fn connect_traced<F>(
149 &self,
150 task_name: &str,
151 configure: F,
152 ssh_tunnel_manager: &SshTunnelManager,
153 ) -> Result<Client, PostgresError>
154 where
155 F: FnOnce(&mut tokio_postgres::Config),
156 {
157 let (host, port) = self.address()?;
158 let address = format!(
159 "{}@{}:{}/{}",
160 self.get_user().display_or("<unknown-user>"),
161 host,
162 port,
163 self.get_dbname().display_or("<unknown-dbname>")
164 );
165 info!(%task_name, %address, "connecting");
166 match self
167 .connect_internal(task_name, configure, ssh_tunnel_manager)
168 .await
169 {
170 Ok(t) => {
171 let backend_pid = t.backend_pid();
172 info!(%task_name, %address, %backend_pid, "connected");
173 Ok(t)
174 }
175 Err(e) => {
176 warn!(%task_name, %address, "connection failed: {e:#}");
177 Err(e)
178 }
179 }
180 }
181
182 async fn connect_internal<F>(
183 &self,
184 task_name: &str,
185 configure: F,
186 ssh_tunnel_manager: &SshTunnelManager,
187 ) -> Result<Client, PostgresError>
188 where
189 F: FnOnce(&mut tokio_postgres::Config),
190 {
191 let mut postgres_config = self.inner.clone();
192 configure(&mut postgres_config);
193
194 let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
195 mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
196 mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
197 })?;
198
199 match &self.tunnel {
200 TunnelConfig::Direct { resolved_ips } => {
201 if let Some(ips) = resolved_ips {
202 let host = match postgres_config.get_hosts() {
203 [Host::Tcp(host)] => host,
204 _ => bail_generic!(
205 "only TCP connections to a single PostgreSQL server are supported"
206 ),
207 }
208 .to_owned();
209 for (idx, ip) in ips.iter().enumerate() {
213 if idx != 0 {
214 postgres_config.host(&host);
215 }
216 postgres_config.hostaddr(ip.clone());
217 }
218 };
219
220 let (client, connection) = async move { postgres_config.connect(tls).await }
221 .run_in_task_if(self.in_task, || "pg_connect".to_string())
222 .await?;
223
224 let client = Client {
225 inner: client,
226 _connection_handle: task::spawn(|| task_name, async {
227 if let Err(e) = connection.await {
228 warn!("postgres direct connection failed: {e}");
229 }
230 })
231 .abort_on_drop(),
232 };
233 Ok(client)
234 }
235 TunnelConfig::Ssh { config } => {
236 let (host, port) = self.address()?;
237 let tunnel = ssh_tunnel_manager
238 .connect(
239 config.clone(),
240 host,
241 port,
242 self.ssh_timeout_config,
243 self.in_task,
244 )
245 .await
246 .map_err(PostgresError::Ssh)?;
247
248 let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
249 let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
250 .await
251 .map_err(PostgresError::SshIo)?;
252 let (client, connection) =
260 async move { postgres_config.connect_raw(tcp_stream, tls).await }
261 .run_in_task_if(self.in_task, || "pg_connect".to_string())
262 .await?;
263
264 let client = Client {
265 inner: client,
266 _connection_handle: task::spawn(|| task_name, async {
267 let _tunnel = tunnel; if let Err(e) = connection.await {
269 warn!("postgres via SSH tunnel connection failed: {e}");
270 }
271 })
272 .abort_on_drop(),
273 };
274 Ok(client)
275 }
276 TunnelConfig::AwsPrivatelink { connection_id } => {
277 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
284 let privatelink_addrs = tokio::net::lookup_host((privatelink_host, 0)).await?;
285
286 let host = match postgres_config.get_hosts() {
289 [Host::Tcp(host)] => host,
290 _ => bail_generic!(
291 "only TCP connections to a single PostgreSQL server are supported"
292 ),
293 }
294 .to_owned();
295 for (idx, addr) in privatelink_addrs.enumerate() {
299 if idx != 0 {
300 postgres_config.host(&host);
301 }
302 postgres_config.hostaddr(addr.ip());
303 }
304
305 let (client, connection) = async move { postgres_config.connect(tls).await }
306 .run_in_task_if(self.in_task, || "pg_connect".to_string())
307 .await?;
308
309 let client = Client {
310 inner: client,
311 _connection_handle: task::spawn(|| task_name, async {
312 if let Err(e) = connection.await {
313 warn!("postgres AWS link connection failed: {e}");
314 }
315 })
316 .abort_on_drop(),
317 };
318 Ok(client)
319 }
320 }
321 }
322
323 pub fn get_user(&self) -> Option<&str> {
324 self.inner.get_user()
325 }
326
327 pub fn get_dbname(&self) -> Option<&str> {
328 self.inner.get_dbname()
329 }
330}