mz_postgres_util/
tunnel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::option::OptionExt;
17use mz_ore::task;
18use mz_proto::{RustType, TryFromProtoError};
19use mz_repr::CatalogItemId;
20use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
21use mz_ssh_util::tunnel_manager::SshTunnelManager;
22use proptest_derive::Arbitrary;
23use serde::{Deserialize, Serialize};
24use tokio::io::{AsyncRead, AsyncWrite};
25use tokio::net::TcpStream as TokioTcpStream;
26use tokio_postgres::config::{Host, ReplicationMode};
27use tokio_postgres::tls::MakeTlsConnect;
28use tracing::{info, warn};
29
30use crate::PostgresError;
31
// Pulls in generated protobuf definitions for this module (e.g.
// `ProtoPostgresFlavor` and the `proto_postgres_flavor` module used below).
// NOTE(review): presumably emitted by this crate's build script into `OUT_DIR`.
include!(concat!(env!("OUT_DIR"), "/mz_postgres_util.tunnel.rs"));
33
/// Returns early from the enclosing function with
/// `Err(PostgresError::Generic(..))` wrapping an [`anyhow::Error`].
///
/// Two forms are accepted:
/// - a format string followed by format arguments, forwarded to
///   [`anyhow::anyhow!`] for formatting;
/// - a single expression (optionally followed by a trailing comma), handed to
///   [`anyhow::anyhow!`] as-is.
macro_rules! bail_generic {
    // The expansions deliberately have no trailing `;` so the macro can be
    // used in expression position, e.g. as a `match` arm.
    ($fmt:expr, $($arg:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($fmt, $($arg)*)))
    };
    ($err:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($err)))
    };
}
42
/// Configures an optional tunnel for use when connecting to a PostgreSQL
/// database.
#[derive(Debug, PartialEq, Clone)]
pub enum TunnelConfig {
    /// Establish a direct TCP connection to the database host.
    /// If `resolved_ips` is not None, the provided IPs will be used
    /// rather than resolving the hostname.
    Direct {
        /// Pre-resolved addresses of the configured host. When present, the TCP
        /// connection targets these addresses while the configured hostname is
        /// still used for TLS verification.
        resolved_ips: Option<BTreeSet<IpAddr>>,
    },
    /// Establish a TCP connection to the database via an SSH tunnel.
    /// This means first establishing an SSH connection to a bastion host,
    /// and then opening a separate connection from that host to the database.
    /// This is commonly referred to by vendors as a "direct SSH tunnel", in
    /// opposition to "reverse SSH tunnel", which is currently unsupported.
    Ssh {
        /// Configuration of the SSH bastion connection, handed to the SSH
        /// tunnel manager when connecting.
        config: SshTunnelConfig,
    },
    /// Establish a TCP connection to the database via an AWS PrivateLink
    /// service.
    AwsPrivatelink {
        /// The catalog ID of the AWS PrivateLink connection. The hostname of the
        /// VPC endpoint to connect through is derived from this ID.
        connection_id: CatalogItemId,
    },
}
66
/// Default statement timeout for snapshot queries (inferred from the name; it
/// is not referenced in the code visible here).
///
/// NOTE(review): a zero duration presumably means "no timeout", matching
/// PostgreSQL's treatment of `statement_timeout = 0` — confirm at the use sites.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::ZERO;
68
/// A wrapper for [`tokio_postgres::Client`] that can report the server version.
///
/// Dereferences to the wrapped [`tokio_postgres::Client`], so its methods can
/// be called directly on this type.
pub struct Client {
    /// The wrapped client.
    inner: tokio_postgres::Client,
    /// The `server_version` parameter read from the connection when the client
    /// was created, if the server reported one.
    server_version: Option<String>,
}
74
75impl Client {
76    fn new<S, T>(
77        client: tokio_postgres::Client,
78        connection: &tokio_postgres::Connection<S, T>,
79    ) -> Client
80    where
81        S: AsyncRead + AsyncWrite + Unpin,
82        T: AsyncRead + AsyncWrite + Unpin,
83    {
84        let server_version = connection
85            .parameter("server_version")
86            .map(|v| v.to_string());
87        Client {
88            inner: client,
89            server_version,
90        }
91    }
92
93    /// Reports the value of the `server_version` parameter reported by the
94    /// server.
95    pub fn server_version(&self) -> Option<&str> {
96        self.server_version.as_deref()
97    }
98
99    /// Reports the postgres flavor as indicated by the server version.
100    pub fn server_flavor(&self) -> PostgresFlavor {
101        match self.server_version.as_ref() {
102            Some(v) if v.contains("-YB-") => PostgresFlavor::Yugabyte,
103            _ => PostgresFlavor::Vanilla,
104        }
105    }
106}
107
108impl Deref for Client {
109    type Target = tokio_postgres::Client;
110
111    fn deref(&self) -> &Self::Target {
112        &self.inner
113    }
114}
115
116impl DerefMut for Client {
117    fn deref_mut(&mut self) -> &mut Self::Target {
118        &mut self.inner
119    }
120}
121
/// The flavor of PostgreSQL-compatible server on the other end of a
/// connection, as detected from its reported `server_version` by
/// [`Client::server_flavor`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub enum PostgresFlavor {
    /// A normal PostgreSQL server.
    Vanilla,
    /// A Yugabyte server.
    Yugabyte,
}
129
130impl RustType<ProtoPostgresFlavor> for PostgresFlavor {
131    fn into_proto(&self) -> ProtoPostgresFlavor {
132        let kind = match self {
133            PostgresFlavor::Vanilla => proto_postgres_flavor::Kind::Vanilla(()),
134            PostgresFlavor::Yugabyte => proto_postgres_flavor::Kind::Yugabyte(()),
135        };
136        ProtoPostgresFlavor { kind: Some(kind) }
137    }
138
139    fn from_proto(proto: ProtoPostgresFlavor) -> Result<Self, TryFromProtoError> {
140        let flavor = proto
141            .kind
142            .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
143        Ok(match flavor {
144            proto_postgres_flavor::Kind::Vanilla(()) => PostgresFlavor::Vanilla,
145            proto_postgres_flavor::Kind::Yugabyte(()) => PostgresFlavor::Yugabyte,
146        })
147    }
148}
149
/// Configuration for PostgreSQL connections.
///
/// This wraps [`tokio_postgres::Config`] to allow the configuration of a
/// tunnel via a [`TunnelConfig`].
#[derive(Clone, Debug)]
pub struct Config {
    /// The underlying `tokio-postgres` configuration. `Config::new` rejects
    /// configurations that do not name exactly one TCP host and one port.
    inner: tokio_postgres::Config,
    /// How the TCP connection to the server is routed.
    tunnel: TunnelConfig,
    /// Passed to `run_in_task_if` when establishing the PostgreSQL connection,
    /// and to the SSH tunnel manager when tunnelling over SSH.
    in_task: InTask,
    /// Timeouts handed to the SSH tunnel manager when tunnelling over SSH.
    ssh_timeout_config: SshTimeoutConfig,
}
161
162impl Config {
163    pub fn new(
164        inner: tokio_postgres::Config,
165        tunnel: TunnelConfig,
166        ssh_timeout_config: SshTimeoutConfig,
167        in_task: InTask,
168    ) -> Result<Self, PostgresError> {
169        let config = Self {
170            inner,
171            tunnel,
172            in_task,
173            ssh_timeout_config,
174        };
175
176        // Early validate that the configuration contains only a single TCP
177        // server.
178        config.address()?;
179
180        Ok(config)
181    }
182
183    /// Connects to the configured PostgreSQL database.
184    pub async fn connect(
185        &self,
186        task_name: &str,
187        ssh_tunnel_manager: &SshTunnelManager,
188    ) -> Result<Client, PostgresError> {
189        self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
190            .await
191    }
192
193    /// Starts a replication connection to the configured PostgreSQL database.
194    pub async fn connect_replication(
195        &self,
196        ssh_tunnel_manager: &SshTunnelManager,
197    ) -> Result<Client, PostgresError> {
198        self.connect_traced(
199            "postgres_connect_replication",
200            |config| {
201                config.replication_mode(ReplicationMode::Logical);
202            },
203            ssh_tunnel_manager,
204        )
205        .await
206    }
207
208    fn address(&self) -> Result<(&str, u16), PostgresError> {
209        match (self.inner.get_hosts(), self.inner.get_ports()) {
210            ([Host::Tcp(host)], [port]) => Ok((host, *port)),
211            _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
212        }
213    }
214
215    async fn connect_traced<F>(
216        &self,
217        task_name: &str,
218        configure: F,
219        ssh_tunnel_manager: &SshTunnelManager,
220    ) -> Result<Client, PostgresError>
221    where
222        F: FnOnce(&mut tokio_postgres::Config),
223    {
224        let (host, port) = self.address()?;
225        let address = format!(
226            "{}@{}:{}/{}",
227            self.get_user().display_or("<unknown-user>"),
228            host,
229            port,
230            self.get_dbname().display_or("<unknown-dbname>")
231        );
232        info!(%task_name, %address, "connecting");
233        match self
234            .connect_internal(task_name, configure, ssh_tunnel_manager)
235            .await
236        {
237            Ok(t) => {
238                let backend_pid = t.backend_pid();
239                info!(%task_name, %address, %backend_pid, "connected");
240                Ok(t)
241            }
242            Err(e) => {
243                warn!(%task_name, %address, "connection failed: {e:#}");
244                Err(e)
245            }
246        }
247    }
248
249    async fn connect_internal<F>(
250        &self,
251        task_name: &str,
252        configure: F,
253        ssh_tunnel_manager: &SshTunnelManager,
254    ) -> Result<Client, PostgresError>
255    where
256        F: FnOnce(&mut tokio_postgres::Config),
257    {
258        let mut postgres_config = self.inner.clone();
259        configure(&mut postgres_config);
260
261        let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
262            mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
263            mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
264        })?;
265
266        match &self.tunnel {
267            TunnelConfig::Direct { resolved_ips } => {
268                if let Some(ips) = resolved_ips {
269                    let host = match postgres_config.get_hosts() {
270                        [Host::Tcp(host)] => host,
271                        _ => bail_generic!(
272                            "only TCP connections to a single PostgreSQL server are supported"
273                        ),
274                    }
275                    .to_owned();
276                    // Associate each resolved ip with the exact same, singular host, for tls
277                    // verification. We are required to do this dance because `tokio-postgres`
278                    // enforces that the number of 'host' and 'hostaddr' values must be the same.
279                    for (idx, ip) in ips.iter().enumerate() {
280                        if idx != 0 {
281                            postgres_config.host(&host);
282                        }
283                        postgres_config.hostaddr(ip.clone());
284                    }
285                };
286
287                let (client, connection) = async move { postgres_config.connect(tls).await }
288                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
289                    .await?;
290                let client = Client::new(client, &connection);
291                task::spawn(|| task_name, connection);
292                Ok(client)
293            }
294            TunnelConfig::Ssh { config } => {
295                let (host, port) = self.address()?;
296                let tunnel = ssh_tunnel_manager
297                    .connect(
298                        config.clone(),
299                        host,
300                        port,
301                        self.ssh_timeout_config,
302                        self.in_task,
303                    )
304                    .await
305                    .map_err(PostgresError::Ssh)?;
306
307                let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
308                let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
309                    .await
310                    .map_err(PostgresError::SshIo)?;
311                // Because we are connecting to a local host/port, we don't configure any TCP
312                // keepalive settings. The connection is entirely local to the machine running the
313                // process and we trust the kernel to keep a local connection alive without keepalives.
314                //
315                // Ideally we'd be able to configure SSH to enable TCP keepalives on the other
316                // end of the tunnel, between the SSH bastion host and the PostgreSQL server,
317                // but SSH does not expose an option for this.
318                let (client, connection) =
319                    async move { postgres_config.connect_raw(tcp_stream, tls).await }
320                        .run_in_task_if(self.in_task, || "pg_connect".to_string())
321                        .await?;
322                let client = Client::new(client, &connection);
323                task::spawn(|| task_name, async {
324                    let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
325
326                    if let Err(e) = connection.await {
327                        warn!("postgres connection failed: {e}");
328                    }
329                });
330                Ok(client)
331            }
332            TunnelConfig::AwsPrivatelink { connection_id } => {
333                // This section of code is somewhat subtle. We are overriding the host
334                // for the actual TCP connection to be the PrivateLink host, but leaving the host
335                // for TLS verification as the original host. Managing the
336                // `tokio_postgres::Config` to do this is somewhat confusing, and requires we edit
337                // the singular host in place.
338
339                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
340                // `net::lookup_host` requires a port to be specified, but the port has no effect
341                // on the lookup so use a dummy one
342                let privatelink_addrs = tokio::net::lookup_host((privatelink_host, 11111)).await?;
343
344                // Override the actual IPs to connect to for the TCP connection, leaving the original host in-place
345                // for TLS verification
346                let host = match postgres_config.get_hosts() {
347                    [Host::Tcp(host)] => host,
348                    _ => bail_generic!(
349                        "only TCP connections to a single PostgreSQL server are supported"
350                    ),
351                }
352                .to_owned();
353                // Associate each resolved ip with the exact same, singular host, for tls
354                // verification. We are required to do this dance because `tokio-postgres`
355                // enforces that the number of 'host' and 'hostaddr' values must be the same.
356                for (idx, addr) in privatelink_addrs.enumerate() {
357                    if idx != 0 {
358                        postgres_config.host(&host);
359                    }
360                    postgres_config.hostaddr(addr.ip());
361                }
362
363                let (client, connection) = async move { postgres_config.connect(tls).await }
364                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
365                    .await?;
366                let client = Client::new(client, &connection);
367                task::spawn(|| task_name, connection);
368                Ok(client)
369            }
370        }
371    }
372
373    pub fn get_user(&self) -> Option<&str> {
374        self.inner.get_user()
375    }
376
377    pub fn get_dbname(&self) -> Option<&str> {
378        self.inner.get_dbname()
379    }
380}