1use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::option::OptionExt;
17use mz_ore::task::{self, AbortOnDropHandle};
18use mz_repr::CatalogItemId;
19use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
20use mz_ssh_util::tunnel_manager::SshTunnelManager;
21use tokio::net::TcpStream as TokioTcpStream;
22use tokio_postgres::config::{Host, ReplicationMode};
23use tokio_postgres::tls::MakeTlsConnect;
24use tracing::{info, warn};
25
26use crate::PostgresError;
27
/// Returns early from the enclosing function with `Err(PostgresError::Generic(..))`.
///
/// Accepts either a format string followed by arguments (forwarded to `anyhow::anyhow!`), or a
/// single message/error expression, optionally followed by a trailing comma. Because it expands
/// to `return Err(PostgresError::Generic(..))`, the enclosing function must return
/// `Result<_, PostgresError>`.
macro_rules! bail_generic {
    // Format string plus arguments, e.g. `bail_generic!("bad value: {}", v)`.
    ($fmt:expr, $($arg:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($fmt, $($arg)*)))
    };
    // A single message or error expression, e.g. `bail_generic!("unsupported")`.
    ($err:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($err)))
    };
}
36
37#[derive(Debug, PartialEq, Clone)]
40pub enum TunnelConfig {
41 Direct {
45 resolved_ips: Option<BTreeSet<IpAddr>>,
46 },
47 Ssh { config: SshTunnelConfig },
53 AwsPrivatelink {
56 connection_id: CatalogItemId,
58 },
59}
60
/// Default `statement_timeout` for snapshot queries: zero.
///
/// PostgreSQL treats a `statement_timeout` of zero as "no timeout". Presumably this value is
/// applied as the session's `statement_timeout` while snapshotting — confirm at the use site.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::from_secs(0);
62
/// A connected PostgreSQL client.
///
/// Dereferences to [`tokio_postgres::Client`], so it can be used wherever one is expected.
pub struct Client {
    /// The `tokio_postgres` client used to issue queries.
    inner: tokio_postgres::Client,
    /// Handle to the task that drives the underlying connection (for SSH connections, that task
    /// also owns the tunnel). Held only for its drop behavior; presumably, per the name
    /// `AbortOnDropHandle`, dropping this `Client` aborts that task — confirm in `mz_ore::task`.
    _connection_handle: AbortOnDropHandle<()>,
}
71
72impl Deref for Client {
73 type Target = tokio_postgres::Client;
74
75 fn deref(&self) -> &Self::Target {
76 &self.inner
77 }
78}
79
80impl DerefMut for Client {
81 fn deref_mut(&mut self) -> &mut Self::Target {
82 &mut self.inner
83 }
84}
85
#[derive(Clone, Debug)]
/// Settings for connecting to a PostgreSQL server: the `tokio_postgres` connection settings
/// plus how the TCP connection reaches the server (directly, over SSH, or via AWS PrivateLink).
pub struct Config {
    /// Underlying `tokio_postgres` settings (hosts, ports, user, dbname, TLS options, ...).
    inner: tokio_postgres::Config,
    /// How the TCP connection to the server is established.
    tunnel: TunnelConfig,
    /// Passed to `run_in_task_if` and to the SSH tunnel manager; presumably decides whether
    /// connection work runs in a separate task — TODO confirm against `mz_ore::future::InTask`.
    in_task: InTask,
    /// Timeouts handed to the SSH tunnel manager when `tunnel` is `TunnelConfig::Ssh`.
    ssh_timeout_config: SshTimeoutConfig,
}
97
98impl Config {
99 pub fn new(
100 inner: tokio_postgres::Config,
101 tunnel: TunnelConfig,
102 ssh_timeout_config: SshTimeoutConfig,
103 in_task: InTask,
104 ) -> Result<Self, PostgresError> {
105 let config = Self {
106 inner,
107 tunnel,
108 in_task,
109 ssh_timeout_config,
110 };
111
112 config.address()?;
115
116 Ok(config)
117 }
118
119 pub async fn connect(
121 &self,
122 task_name: &str,
123 ssh_tunnel_manager: &SshTunnelManager,
124 ) -> Result<Client, PostgresError> {
125 self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
126 .await
127 }
128
129 pub async fn connect_replication(
131 &self,
132 ssh_tunnel_manager: &SshTunnelManager,
133 ) -> Result<Client, PostgresError> {
134 self.connect_traced(
135 "postgres_connect_replication",
136 |config| {
137 config.replication_mode(ReplicationMode::Logical);
138 },
139 ssh_tunnel_manager,
140 )
141 .await
142 }
143
144 fn address(&self) -> Result<(&str, u16), PostgresError> {
145 match (self.inner.get_hosts(), self.inner.get_ports()) {
146 ([Host::Tcp(host)], [port]) => Ok((host, *port)),
147 _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
148 }
149 }
150
151 async fn connect_traced<F>(
152 &self,
153 task_name: &str,
154 configure: F,
155 ssh_tunnel_manager: &SshTunnelManager,
156 ) -> Result<Client, PostgresError>
157 where
158 F: FnOnce(&mut tokio_postgres::Config),
159 {
160 let (host, port) = self.address()?;
161 let address = format!(
162 "{}@{}:{}/{}",
163 self.get_user().display_or("<unknown-user>"),
164 host,
165 port,
166 self.get_dbname().display_or("<unknown-dbname>")
167 );
168 info!(%task_name, %address, "connecting");
169 match self
170 .connect_internal(task_name, configure, ssh_tunnel_manager)
171 .await
172 {
173 Ok(t) => {
174 let backend_pid = t.backend_pid();
175 info!(%task_name, %address, %backend_pid, "connected");
176 Ok(t)
177 }
178 Err(e) => {
179 warn!(%task_name, %address, "connection failed: {e:#}");
180 Err(e)
181 }
182 }
183 }
184
185 async fn connect_internal<F>(
186 &self,
187 task_name: &str,
188 configure: F,
189 ssh_tunnel_manager: &SshTunnelManager,
190 ) -> Result<Client, PostgresError>
191 where
192 F: FnOnce(&mut tokio_postgres::Config),
193 {
194 let mut postgres_config = self.inner.clone();
195 configure(&mut postgres_config);
196
197 let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
198 mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
199 mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
200 })?;
201
202 match &self.tunnel {
203 TunnelConfig::Direct { resolved_ips } => {
204 if let Some(ips) = resolved_ips {
205 let host = match postgres_config.get_hosts() {
206 [Host::Tcp(host)] => host,
207 _ => bail_generic!(
208 "only TCP connections to a single PostgreSQL server are supported"
209 ),
210 }
211 .to_owned();
212 for (idx, ip) in ips.iter().enumerate() {
216 if idx != 0 {
217 postgres_config.host(&host);
218 }
219 postgres_config.hostaddr(ip.clone());
220 }
221 };
222
223 let (client, connection) = async move { postgres_config.connect(tls).await }
224 .run_in_task_if(self.in_task, || "pg_connect".to_string())
225 .await?;
226
227 let client = Client {
228 inner: client,
229 _connection_handle: task::spawn(|| task_name, async {
230 if let Err(e) = connection.await {
231 warn!("postgres direct connection failed: {e}");
232 }
233 })
234 .abort_on_drop(),
235 };
236 Ok(client)
237 }
238 TunnelConfig::Ssh { config } => {
239 let (host, port) = self.address()?;
240 let tunnel = ssh_tunnel_manager
241 .connect(
242 config.clone(),
243 host,
244 port,
245 self.ssh_timeout_config,
246 self.in_task,
247 )
248 .await
249 .map_err(PostgresError::Ssh)?;
250
251 let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
252 let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
253 .await
254 .map_err(PostgresError::SshIo)?;
255 let (client, connection) =
263 async move { postgres_config.connect_raw(tcp_stream, tls).await }
264 .run_in_task_if(self.in_task, || "pg_connect".to_string())
265 .await?;
266
267 let client = Client {
268 inner: client,
269 _connection_handle: task::spawn(|| task_name, async {
270 let _tunnel = tunnel; if let Err(e) = connection.await {
272 warn!("postgres via SSH tunnel connection failed: {e}");
273 }
274 })
275 .abort_on_drop(),
276 };
277 Ok(client)
278 }
279 TunnelConfig::AwsPrivatelink { connection_id } => {
280 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
287 let privatelink_addrs = tokio::net::lookup_host((privatelink_host, 0)).await?;
288
289 let host = match postgres_config.get_hosts() {
292 [Host::Tcp(host)] => host,
293 _ => bail_generic!(
294 "only TCP connections to a single PostgreSQL server are supported"
295 ),
296 }
297 .to_owned();
298 for (idx, addr) in privatelink_addrs.enumerate() {
302 if idx != 0 {
303 postgres_config.host(&host);
304 }
305 postgres_config.hostaddr(addr.ip());
306 }
307
308 let (client, connection) = async move { postgres_config.connect(tls).await }
309 .run_in_task_if(self.in_task, || "pg_connect".to_string())
310 .await?;
311
312 let client = Client {
313 inner: client,
314 _connection_handle: task::spawn(|| task_name, async {
315 if let Err(e) = connection.await {
316 warn!("postgres AWS link connection failed: {e}");
317 }
318 })
319 .abort_on_drop(),
320 };
321 Ok(client)
322 }
323 }
324 }
325
326 pub fn get_user(&self) -> Option<&str> {
327 self.inner.get_user()
328 }
329
330 pub fn get_dbname(&self) -> Option<&str> {
331 self.inner.get_dbname()
332 }
333}