1use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::netio::DUMMY_DNS_PORT;
17use mz_ore::option::OptionExt;
18use mz_ore::task;
19use mz_proto::{RustType, TryFromProtoError};
20use mz_repr::CatalogItemId;
21use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
22use mz_ssh_util::tunnel_manager::SshTunnelManager;
23use proptest_derive::Arbitrary;
24use serde::{Deserialize, Serialize};
25use tokio::io::{AsyncRead, AsyncWrite};
26use tokio::net::TcpStream as TokioTcpStream;
27use tokio_postgres::config::{Host, ReplicationMode};
28use tokio_postgres::tls::MakeTlsConnect;
29use tracing::{info, warn};
30
31use crate::PostgresError;
32
33include!(concat!(env!("OUT_DIR"), "/mz_postgres_util.tunnel.rs"));
34
/// Returns early from the enclosing function with `Err(PostgresError::Generic(..))`.
///
/// Accepts either `anyhow!`-style format arguments (`bail_generic!("fmt {}", x)`) or a
/// single message/error expression (`bail_generic!(err)`), which is handed to
/// `anyhow::anyhow!` as-is. `PostgresError` must be in scope at the invocation site.
macro_rules! bail_generic {
    ($format:expr, $($rest:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($format, $($rest)*)))
    };
    ($error:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($error)))
    };
}
43
/// How the TCP connection to the PostgreSQL server described by a [`Config`] is established.
#[derive(Debug, PartialEq, Clone)]
pub enum TunnelConfig {
    /// Connect directly over TCP to the configured host.
    Direct {
        /// When `Some`, each address is set as a `hostaddr` to dial, and the configured
        /// hostname is kept alongside every one of them. When `None`, the config is used as-is.
        ///
        /// NOTE(review): an empty set leaves the config untouched, which behaves like `None`
        /// (the hostname is resolved normally). Confirm callers never rely on `Some(empty)`
        /// to mean "do not connect".
        resolved_ips: Option<BTreeSet<IpAddr>>,
    },
    /// Connect through an SSH tunnel opened by the `SshTunnelManager` using `config`.
    /// The tunnel handle is held for as long as the connection is being driven.
    Ssh { config: SshTunnelConfig },
    /// Connect to the addresses of the AWS PrivateLink VPC endpoint derived from
    /// `connection_id`, while keeping the configured hostname in the config.
    AwsPrivatelink {
        /// ID passed to `mz_cloud_resources::vpc_endpoint_name` to obtain the endpoint's
        /// DNS name.
        connection_id: CatalogItemId,
    },
}
67
/// Default statement timeout, going by its name used for snapshot queries.
///
/// NOTE(review): the value is zero; presumably this means "no timeout", matching
/// PostgreSQL's `statement_timeout = 0`. Confirm at the call sites that apply it.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::ZERO;
69
/// A [`tokio_postgres::Client`] bundled with the `server_version` parameter the server
/// reported on its connection.
///
/// Dereferences (mutably and immutably) to the inner client, so its methods can be called
/// directly on this type.
pub struct Client {
    /// The underlying `tokio-postgres` client.
    inner: tokio_postgres::Client,
    /// The server's `server_version` parameter, or `None` if it was not reported.
    server_version: Option<String>,
}
75
76impl Client {
77 fn new<S, T>(
78 client: tokio_postgres::Client,
79 connection: &tokio_postgres::Connection<S, T>,
80 ) -> Client
81 where
82 S: AsyncRead + AsyncWrite + Unpin,
83 T: AsyncRead + AsyncWrite + Unpin,
84 {
85 let server_version = connection
86 .parameter("server_version")
87 .map(|v| v.to_string());
88 Client {
89 inner: client,
90 server_version,
91 }
92 }
93
94 pub fn server_version(&self) -> Option<&str> {
97 self.server_version.as_deref()
98 }
99
100 pub fn server_flavor(&self) -> PostgresFlavor {
102 match self.server_version.as_ref() {
103 Some(v) if v.contains("-YB-") => PostgresFlavor::Yugabyte,
104 _ => PostgresFlavor::Vanilla,
105 }
106 }
107}
108
109impl Deref for Client {
110 type Target = tokio_postgres::Client;
111
112 fn deref(&self) -> &Self::Target {
113 &self.inner
114 }
115}
116
117impl DerefMut for Client {
118 fn deref_mut(&mut self) -> &mut Self::Target {
119 &mut self.inner
120 }
121}
122
/// The flavor of PostgreSQL-compatible server a [`Client`] is connected to, as determined by
/// [`Client::server_flavor`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub enum PostgresFlavor {
    /// Any server not recognized as another flavor, including one with no reported version.
    Vanilla,
    /// A Yugabyte server, recognized by `-YB-` in its reported `server_version`.
    Yugabyte,
}
130
131impl RustType<ProtoPostgresFlavor> for PostgresFlavor {
132 fn into_proto(&self) -> ProtoPostgresFlavor {
133 let kind = match self {
134 PostgresFlavor::Vanilla => proto_postgres_flavor::Kind::Vanilla(()),
135 PostgresFlavor::Yugabyte => proto_postgres_flavor::Kind::Yugabyte(()),
136 };
137 ProtoPostgresFlavor { kind: Some(kind) }
138 }
139
140 fn from_proto(proto: ProtoPostgresFlavor) -> Result<Self, TryFromProtoError> {
141 let flavor = proto
142 .kind
143 .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
144 Ok(match flavor {
145 proto_postgres_flavor::Kind::Vanilla(()) => PostgresFlavor::Vanilla,
146 proto_postgres_flavor::Kind::Yugabyte(()) => PostgresFlavor::Yugabyte,
147 })
148 }
149}
150
/// Everything needed to open a connection to a PostgreSQL server, optionally through a
/// tunnel.
#[derive(Clone, Debug)]
pub struct Config {
    /// Connection parameters; `Config::new` only accepts a single TCP host and a single port.
    inner: tokio_postgres::Config,
    /// How the TCP connection to the server is established.
    tunnel: TunnelConfig,
    /// Passed to `run_in_task_if` for the connect futures, and to the SSH tunnel manager.
    in_task: InTask,
    /// Timeouts handed to the SSH tunnel manager for `TunnelConfig::Ssh` connections.
    ssh_timeout_config: SshTimeoutConfig,
}
162
163impl Config {
164 pub fn new(
165 inner: tokio_postgres::Config,
166 tunnel: TunnelConfig,
167 ssh_timeout_config: SshTimeoutConfig,
168 in_task: InTask,
169 ) -> Result<Self, PostgresError> {
170 let config = Self {
171 inner,
172 tunnel,
173 in_task,
174 ssh_timeout_config,
175 };
176
177 config.address()?;
180
181 Ok(config)
182 }
183
184 pub async fn connect(
186 &self,
187 task_name: &str,
188 ssh_tunnel_manager: &SshTunnelManager,
189 ) -> Result<Client, PostgresError> {
190 self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
191 .await
192 }
193
194 pub async fn connect_replication(
196 &self,
197 ssh_tunnel_manager: &SshTunnelManager,
198 ) -> Result<Client, PostgresError> {
199 self.connect_traced(
200 "postgres_connect_replication",
201 |config| {
202 config.replication_mode(ReplicationMode::Logical);
203 },
204 ssh_tunnel_manager,
205 )
206 .await
207 }
208
209 fn address(&self) -> Result<(&str, u16), PostgresError> {
210 match (self.inner.get_hosts(), self.inner.get_ports()) {
211 ([Host::Tcp(host)], [port]) => Ok((host, *port)),
212 _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
213 }
214 }
215
216 async fn connect_traced<F>(
217 &self,
218 task_name: &str,
219 configure: F,
220 ssh_tunnel_manager: &SshTunnelManager,
221 ) -> Result<Client, PostgresError>
222 where
223 F: FnOnce(&mut tokio_postgres::Config),
224 {
225 let (host, port) = self.address()?;
226 let address = format!(
227 "{}@{}:{}/{}",
228 self.get_user().display_or("<unknown-user>"),
229 host,
230 port,
231 self.get_dbname().display_or("<unknown-dbname>")
232 );
233 info!(%task_name, %address, "connecting");
234 match self
235 .connect_internal(task_name, configure, ssh_tunnel_manager)
236 .await
237 {
238 Ok(t) => {
239 let backend_pid = t.backend_pid();
240 info!(%task_name, %address, %backend_pid, "connected");
241 Ok(t)
242 }
243 Err(e) => {
244 warn!(%task_name, %address, "connection failed: {e:#}");
245 Err(e)
246 }
247 }
248 }
249
250 async fn connect_internal<F>(
251 &self,
252 task_name: &str,
253 configure: F,
254 ssh_tunnel_manager: &SshTunnelManager,
255 ) -> Result<Client, PostgresError>
256 where
257 F: FnOnce(&mut tokio_postgres::Config),
258 {
259 let mut postgres_config = self.inner.clone();
260 configure(&mut postgres_config);
261
262 let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
263 mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
264 mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
265 })?;
266
267 match &self.tunnel {
268 TunnelConfig::Direct { resolved_ips } => {
269 if let Some(ips) = resolved_ips {
270 let host = match postgres_config.get_hosts() {
271 [Host::Tcp(host)] => host,
272 _ => bail_generic!(
273 "only TCP connections to a single PostgreSQL server are supported"
274 ),
275 }
276 .to_owned();
277 for (idx, ip) in ips.iter().enumerate() {
281 if idx != 0 {
282 postgres_config.host(&host);
283 }
284 postgres_config.hostaddr(ip.clone());
285 }
286 };
287
288 let (client, connection) = async move { postgres_config.connect(tls).await }
289 .run_in_task_if(self.in_task, || "pg_connect".to_string())
290 .await?;
291 let client = Client::new(client, &connection);
292 task::spawn(|| task_name, connection);
293 Ok(client)
294 }
295 TunnelConfig::Ssh { config } => {
296 let (host, port) = self.address()?;
297 let tunnel = ssh_tunnel_manager
298 .connect(
299 config.clone(),
300 host,
301 port,
302 self.ssh_timeout_config,
303 self.in_task,
304 )
305 .await
306 .map_err(PostgresError::Ssh)?;
307
308 let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
309 let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
310 .await
311 .map_err(PostgresError::SshIo)?;
312 let (client, connection) =
320 async move { postgres_config.connect_raw(tcp_stream, tls).await }
321 .run_in_task_if(self.in_task, || "pg_connect".to_string())
322 .await?;
323 let client = Client::new(client, &connection);
324 task::spawn(|| task_name, async {
325 let _tunnel = tunnel; if let Err(e) = connection.await {
328 warn!("postgres connection failed: {e}");
329 }
330 });
331 Ok(client)
332 }
333 TunnelConfig::AwsPrivatelink { connection_id } => {
334 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
341 let privatelink_addrs =
342 tokio::net::lookup_host((privatelink_host, DUMMY_DNS_PORT)).await?;
343
344 let host = match postgres_config.get_hosts() {
347 [Host::Tcp(host)] => host,
348 _ => bail_generic!(
349 "only TCP connections to a single PostgreSQL server are supported"
350 ),
351 }
352 .to_owned();
353 for (idx, addr) in privatelink_addrs.enumerate() {
357 if idx != 0 {
358 postgres_config.host(&host);
359 }
360 postgres_config.hostaddr(addr.ip());
361 }
362
363 let (client, connection) = async move { postgres_config.connect(tls).await }
364 .run_in_task_if(self.in_task, || "pg_connect".to_string())
365 .await?;
366 let client = Client::new(client, &connection);
367 task::spawn(|| task_name, connection);
368 Ok(client)
369 }
370 }
371 }
372
373 pub fn get_user(&self) -> Option<&str> {
374 self.inner.get_user()
375 }
376
377 pub fn get_dbname(&self) -> Option<&str> {
378 self.inner.get_dbname()
379 }
380}