1use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::option::OptionExt;
17use mz_ore::task;
18use mz_proto::{RustType, TryFromProtoError};
19use mz_repr::CatalogItemId;
20use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
21use mz_ssh_util::tunnel_manager::SshTunnelManager;
22use proptest_derive::Arbitrary;
23use serde::{Deserialize, Serialize};
24use tokio::io::{AsyncRead, AsyncWrite};
25use tokio::net::TcpStream as TokioTcpStream;
26use tokio_postgres::config::{Host, ReplicationMode};
27use tokio_postgres::tls::MakeTlsConnect;
28use tracing::{info, warn};
29
30use crate::PostgresError;
31
32include!(concat!(env!("OUT_DIR"), "/mz_postgres_util.tunnel.rs"));
33
/// Returns early from the enclosing function with a
/// `PostgresError::Generic` wrapping an error built by `anyhow::anyhow!`.
///
/// Two invocation shapes are accepted:
///
/// * a format string followed by a comma and format arguments, and
/// * a single expression (a message or an error value).
macro_rules! bail_generic {
    // Format string followed by its arguments.
    ($template:expr, $($rest:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($template, $($rest)*)))
    };
    // A lone message or error expression.
    ($message:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($message)))
    };
}
42
43#[derive(Debug, PartialEq, Clone)]
46pub enum TunnelConfig {
47 Direct {
51 resolved_ips: Option<BTreeSet<IpAddr>>,
52 },
53 Ssh { config: SshTunnelConfig },
59 AwsPrivatelink {
62 connection_id: CatalogItemId,
64 },
65}
66
/// Default snapshot statement timeout: a zero-length duration.
///
/// NOTE(review): zero is presumably interpreted as "no timeout" by the
/// consumer of this value; confirm at the use site.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::from_secs(0);
68
69pub struct Client {
71 inner: tokio_postgres::Client,
72 server_version: Option<String>,
73}
74
75impl Client {
76 fn new<S, T>(
77 client: tokio_postgres::Client,
78 connection: &tokio_postgres::Connection<S, T>,
79 ) -> Client
80 where
81 S: AsyncRead + AsyncWrite + Unpin,
82 T: AsyncRead + AsyncWrite + Unpin,
83 {
84 let server_version = connection
85 .parameter("server_version")
86 .map(|v| v.to_string());
87 Client {
88 inner: client,
89 server_version,
90 }
91 }
92
93 pub fn server_version(&self) -> Option<&str> {
96 self.server_version.as_deref()
97 }
98
99 pub fn server_flavor(&self) -> PostgresFlavor {
101 match self.server_version.as_ref() {
102 Some(v) if v.contains("-YB-") => PostgresFlavor::Yugabyte,
103 _ => PostgresFlavor::Vanilla,
104 }
105 }
106}
107
108impl Deref for Client {
109 type Target = tokio_postgres::Client;
110
111 fn deref(&self) -> &Self::Target {
112 &self.inner
113 }
114}
115
116impl DerefMut for Client {
117 fn deref_mut(&mut self) -> &mut Self::Target {
118 &mut self.inner
119 }
120}
121
122#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
123pub enum PostgresFlavor {
124 Vanilla,
126 Yugabyte,
128}
129
130impl RustType<ProtoPostgresFlavor> for PostgresFlavor {
131 fn into_proto(&self) -> ProtoPostgresFlavor {
132 let kind = match self {
133 PostgresFlavor::Vanilla => proto_postgres_flavor::Kind::Vanilla(()),
134 PostgresFlavor::Yugabyte => proto_postgres_flavor::Kind::Yugabyte(()),
135 };
136 ProtoPostgresFlavor { kind: Some(kind) }
137 }
138
139 fn from_proto(proto: ProtoPostgresFlavor) -> Result<Self, TryFromProtoError> {
140 let flavor = proto
141 .kind
142 .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
143 Ok(match flavor {
144 proto_postgres_flavor::Kind::Vanilla(()) => PostgresFlavor::Vanilla,
145 proto_postgres_flavor::Kind::Yugabyte(()) => PostgresFlavor::Yugabyte,
146 })
147 }
148}
149
150#[derive(Clone, Debug)]
155pub struct Config {
156 inner: tokio_postgres::Config,
157 tunnel: TunnelConfig,
158 in_task: InTask,
159 ssh_timeout_config: SshTimeoutConfig,
160}
161
162impl Config {
163 pub fn new(
164 inner: tokio_postgres::Config,
165 tunnel: TunnelConfig,
166 ssh_timeout_config: SshTimeoutConfig,
167 in_task: InTask,
168 ) -> Result<Self, PostgresError> {
169 let config = Self {
170 inner,
171 tunnel,
172 in_task,
173 ssh_timeout_config,
174 };
175
176 config.address()?;
179
180 Ok(config)
181 }
182
183 pub async fn connect(
185 &self,
186 task_name: &str,
187 ssh_tunnel_manager: &SshTunnelManager,
188 ) -> Result<Client, PostgresError> {
189 self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
190 .await
191 }
192
193 pub async fn connect_replication(
195 &self,
196 ssh_tunnel_manager: &SshTunnelManager,
197 ) -> Result<Client, PostgresError> {
198 self.connect_traced(
199 "postgres_connect_replication",
200 |config| {
201 config.replication_mode(ReplicationMode::Logical);
202 },
203 ssh_tunnel_manager,
204 )
205 .await
206 }
207
208 fn address(&self) -> Result<(&str, u16), PostgresError> {
209 match (self.inner.get_hosts(), self.inner.get_ports()) {
210 ([Host::Tcp(host)], [port]) => Ok((host, *port)),
211 _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
212 }
213 }
214
215 async fn connect_traced<F>(
216 &self,
217 task_name: &str,
218 configure: F,
219 ssh_tunnel_manager: &SshTunnelManager,
220 ) -> Result<Client, PostgresError>
221 where
222 F: FnOnce(&mut tokio_postgres::Config),
223 {
224 let (host, port) = self.address()?;
225 let address = format!(
226 "{}@{}:{}/{}",
227 self.get_user().display_or("<unknown-user>"),
228 host,
229 port,
230 self.get_dbname().display_or("<unknown-dbname>")
231 );
232 info!(%task_name, %address, "connecting");
233 match self
234 .connect_internal(task_name, configure, ssh_tunnel_manager)
235 .await
236 {
237 Ok(t) => {
238 let backend_pid = t.backend_pid();
239 info!(%task_name, %address, %backend_pid, "connected");
240 Ok(t)
241 }
242 Err(e) => {
243 warn!(%task_name, %address, "connection failed: {e:#}");
244 Err(e)
245 }
246 }
247 }
248
249 async fn connect_internal<F>(
250 &self,
251 task_name: &str,
252 configure: F,
253 ssh_tunnel_manager: &SshTunnelManager,
254 ) -> Result<Client, PostgresError>
255 where
256 F: FnOnce(&mut tokio_postgres::Config),
257 {
258 let mut postgres_config = self.inner.clone();
259 configure(&mut postgres_config);
260
261 let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
262 mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
263 mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
264 })?;
265
266 match &self.tunnel {
267 TunnelConfig::Direct { resolved_ips } => {
268 if let Some(ips) = resolved_ips {
269 let host = match postgres_config.get_hosts() {
270 [Host::Tcp(host)] => host,
271 _ => bail_generic!(
272 "only TCP connections to a single PostgreSQL server are supported"
273 ),
274 }
275 .to_owned();
276 for (idx, ip) in ips.iter().enumerate() {
280 if idx != 0 {
281 postgres_config.host(&host);
282 }
283 postgres_config.hostaddr(ip.clone());
284 }
285 };
286
287 let (client, connection) = async move { postgres_config.connect(tls).await }
288 .run_in_task_if(self.in_task, || "pg_connect".to_string())
289 .await?;
290 let client = Client::new(client, &connection);
291 task::spawn(|| task_name, connection);
292 Ok(client)
293 }
294 TunnelConfig::Ssh { config } => {
295 let (host, port) = self.address()?;
296 let tunnel = ssh_tunnel_manager
297 .connect(
298 config.clone(),
299 host,
300 port,
301 self.ssh_timeout_config,
302 self.in_task,
303 )
304 .await
305 .map_err(PostgresError::Ssh)?;
306
307 let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
308 let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
309 .await
310 .map_err(PostgresError::SshIo)?;
311 let (client, connection) =
319 async move { postgres_config.connect_raw(tcp_stream, tls).await }
320 .run_in_task_if(self.in_task, || "pg_connect".to_string())
321 .await?;
322 let client = Client::new(client, &connection);
323 task::spawn(|| task_name, async {
324 let _tunnel = tunnel; if let Err(e) = connection.await {
327 warn!("postgres connection failed: {e}");
328 }
329 });
330 Ok(client)
331 }
332 TunnelConfig::AwsPrivatelink { connection_id } => {
333 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
340 let privatelink_addrs = tokio::net::lookup_host((privatelink_host, 11111)).await?;
343
344 let host = match postgres_config.get_hosts() {
347 [Host::Tcp(host)] => host,
348 _ => bail_generic!(
349 "only TCP connections to a single PostgreSQL server are supported"
350 ),
351 }
352 .to_owned();
353 for (idx, addr) in privatelink_addrs.enumerate() {
357 if idx != 0 {
358 postgres_config.host(&host);
359 }
360 postgres_config.hostaddr(addr.ip());
361 }
362
363 let (client, connection) = async move { postgres_config.connect(tls).await }
364 .run_in_task_if(self.in_task, || "pg_connect".to_string())
365 .await?;
366 let client = Client::new(client, &connection);
367 task::spawn(|| task_name, connection);
368 Ok(client)
369 }
370 }
371 }
372
373 pub fn get_user(&self) -> Option<&str> {
374 self.inner.get_user()
375 }
376
377 pub fn get_dbname(&self) -> Option<&str> {
378 self.inner.get_dbname()
379 }
380}