mz_postgres_util/
tunnel.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::netio::DUMMY_DNS_PORT;
17use mz_ore::option::OptionExt;
18use mz_ore::task;
19use mz_proto::{RustType, TryFromProtoError};
20use mz_repr::CatalogItemId;
21use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
22use mz_ssh_util::tunnel_manager::SshTunnelManager;
23use proptest_derive::Arbitrary;
24use serde::{Deserialize, Serialize};
25use tokio::io::{AsyncRead, AsyncWrite};
26use tokio::net::TcpStream as TokioTcpStream;
27use tokio_postgres::config::{Host, ReplicationMode};
28use tokio_postgres::tls::MakeTlsConnect;
29use tracing::{info, warn};
30
31use crate::PostgresError;
32
33include!(concat!(env!("OUT_DIR"), "/mz_postgres_util.tunnel.rs"));
34
/// Returns early from the enclosing function with a
/// [`PostgresError::Generic`] wrapping a freshly built `anyhow` error.
///
/// Two call shapes are accepted:
///
/// * a single message or error expression, optionally followed by a trailing
///   comma;
/// * a format string followed by a comma and its format arguments.
macro_rules! bail_generic {
    ($message:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($message)))
    };
    ($template:expr, $($rest:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($template, $($rest)*)))
    };
}
43
44/// Configures an optional tunnel for use when connecting to a PostgreSQL
45/// database.
46#[derive(Debug, PartialEq, Clone)]
47pub enum TunnelConfig {
48    /// Establish a direct TCP connection to the database host.
49    /// If `resolved_ips` is not None, the provided IPs will be used
50    /// rather than resolving the hostname.
51    Direct {
52        resolved_ips: Option<BTreeSet<IpAddr>>,
53    },
54    /// Establish a TCP connection to the database via an SSH tunnel.
55    /// This means first establishing an SSH connection to a bastion host,
56    /// and then opening a separate connection from that host to the database.
57    /// This is commonly referred by vendors as a "direct SSH tunnel", in
58    /// opposition to "reverse SSH tunnel", which is currently unsupported.
59    Ssh { config: SshTunnelConfig },
60    /// Establish a TCP connection to the database via an AWS PrivateLink
61    /// service.
62    AwsPrivatelink {
63        /// The ID of the AWS PrivateLink service.
64        connection_id: CatalogItemId,
65    },
66}
67
/// The default snapshot statement timeout: a zero-length duration.
///
/// NOTE(review): presumably a zero duration means "no timeout", as with
/// PostgreSQL's `statement_timeout` setting; confirm against the consumers of
/// this constant.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::from_secs(0);
69
70/// A wrapper for [`tokio_postgres::Client`] that can report the server version.
71pub struct Client {
72    inner: tokio_postgres::Client,
73    server_version: Option<String>,
74}
75
76impl Client {
77    fn new<S, T>(
78        client: tokio_postgres::Client,
79        connection: &tokio_postgres::Connection<S, T>,
80    ) -> Client
81    where
82        S: AsyncRead + AsyncWrite + Unpin,
83        T: AsyncRead + AsyncWrite + Unpin,
84    {
85        let server_version = connection
86            .parameter("server_version")
87            .map(|v| v.to_string());
88        Client {
89            inner: client,
90            server_version,
91        }
92    }
93
94    /// Reports the value of the `server_version` parameter reported by the
95    /// server.
96    pub fn server_version(&self) -> Option<&str> {
97        self.server_version.as_deref()
98    }
99
100    /// Reports the postgres flavor as indicated by the server version.
101    pub fn server_flavor(&self) -> PostgresFlavor {
102        match self.server_version.as_ref() {
103            Some(v) if v.contains("-YB-") => PostgresFlavor::Yugabyte,
104            _ => PostgresFlavor::Vanilla,
105        }
106    }
107}
108
109impl Deref for Client {
110    type Target = tokio_postgres::Client;
111
112    fn deref(&self) -> &Self::Target {
113        &self.inner
114    }
115}
116
117impl DerefMut for Client {
118    fn deref_mut(&mut self) -> &mut Self::Target {
119        &mut self.inner
120    }
121}
122
123#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
124pub enum PostgresFlavor {
125    /// A normal PostgreSQL server.
126    Vanilla,
127    /// A Yugabyte server.
128    Yugabyte,
129}
130
131impl RustType<ProtoPostgresFlavor> for PostgresFlavor {
132    fn into_proto(&self) -> ProtoPostgresFlavor {
133        let kind = match self {
134            PostgresFlavor::Vanilla => proto_postgres_flavor::Kind::Vanilla(()),
135            PostgresFlavor::Yugabyte => proto_postgres_flavor::Kind::Yugabyte(()),
136        };
137        ProtoPostgresFlavor { kind: Some(kind) }
138    }
139
140    fn from_proto(proto: ProtoPostgresFlavor) -> Result<Self, TryFromProtoError> {
141        let flavor = proto
142            .kind
143            .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
144        Ok(match flavor {
145            proto_postgres_flavor::Kind::Vanilla(()) => PostgresFlavor::Vanilla,
146            proto_postgres_flavor::Kind::Yugabyte(()) => PostgresFlavor::Yugabyte,
147        })
148    }
149}
150
151/// Configuration for PostgreSQL connections.
152///
153/// This wraps [`tokio_postgres::Config`] to allow the configuration of a
154/// tunnel via a [`TunnelConfig`].
155#[derive(Clone, Debug)]
156pub struct Config {
157    inner: tokio_postgres::Config,
158    tunnel: TunnelConfig,
159    in_task: InTask,
160    ssh_timeout_config: SshTimeoutConfig,
161}
162
163impl Config {
164    pub fn new(
165        inner: tokio_postgres::Config,
166        tunnel: TunnelConfig,
167        ssh_timeout_config: SshTimeoutConfig,
168        in_task: InTask,
169    ) -> Result<Self, PostgresError> {
170        let config = Self {
171            inner,
172            tunnel,
173            in_task,
174            ssh_timeout_config,
175        };
176
177        // Early validate that the configuration contains only a single TCP
178        // server.
179        config.address()?;
180
181        Ok(config)
182    }
183
184    /// Connects to the configured PostgreSQL database.
185    pub async fn connect(
186        &self,
187        task_name: &str,
188        ssh_tunnel_manager: &SshTunnelManager,
189    ) -> Result<Client, PostgresError> {
190        self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
191            .await
192    }
193
194    /// Starts a replication connection to the configured PostgreSQL database.
195    pub async fn connect_replication(
196        &self,
197        ssh_tunnel_manager: &SshTunnelManager,
198    ) -> Result<Client, PostgresError> {
199        self.connect_traced(
200            "postgres_connect_replication",
201            |config| {
202                config.replication_mode(ReplicationMode::Logical);
203            },
204            ssh_tunnel_manager,
205        )
206        .await
207    }
208
209    fn address(&self) -> Result<(&str, u16), PostgresError> {
210        match (self.inner.get_hosts(), self.inner.get_ports()) {
211            ([Host::Tcp(host)], [port]) => Ok((host, *port)),
212            _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
213        }
214    }
215
216    async fn connect_traced<F>(
217        &self,
218        task_name: &str,
219        configure: F,
220        ssh_tunnel_manager: &SshTunnelManager,
221    ) -> Result<Client, PostgresError>
222    where
223        F: FnOnce(&mut tokio_postgres::Config),
224    {
225        let (host, port) = self.address()?;
226        let address = format!(
227            "{}@{}:{}/{}",
228            self.get_user().display_or("<unknown-user>"),
229            host,
230            port,
231            self.get_dbname().display_or("<unknown-dbname>")
232        );
233        info!(%task_name, %address, "connecting");
234        match self
235            .connect_internal(task_name, configure, ssh_tunnel_manager)
236            .await
237        {
238            Ok(t) => {
239                let backend_pid = t.backend_pid();
240                info!(%task_name, %address, %backend_pid, "connected");
241                Ok(t)
242            }
243            Err(e) => {
244                warn!(%task_name, %address, "connection failed: {e:#}");
245                Err(e)
246            }
247        }
248    }
249
250    async fn connect_internal<F>(
251        &self,
252        task_name: &str,
253        configure: F,
254        ssh_tunnel_manager: &SshTunnelManager,
255    ) -> Result<Client, PostgresError>
256    where
257        F: FnOnce(&mut tokio_postgres::Config),
258    {
259        let mut postgres_config = self.inner.clone();
260        configure(&mut postgres_config);
261
262        let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
263            mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
264            mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
265        })?;
266
267        match &self.tunnel {
268            TunnelConfig::Direct { resolved_ips } => {
269                if let Some(ips) = resolved_ips {
270                    let host = match postgres_config.get_hosts() {
271                        [Host::Tcp(host)] => host,
272                        _ => bail_generic!(
273                            "only TCP connections to a single PostgreSQL server are supported"
274                        ),
275                    }
276                    .to_owned();
277                    // Associate each resolved ip with the exact same, singular host, for tls
278                    // verification. We are required to do this dance because `tokio-postgres`
279                    // enforces that the number of 'host' and 'hostaddr' values must be the same.
280                    for (idx, ip) in ips.iter().enumerate() {
281                        if idx != 0 {
282                            postgres_config.host(&host);
283                        }
284                        postgres_config.hostaddr(ip.clone());
285                    }
286                };
287
288                let (client, connection) = async move { postgres_config.connect(tls).await }
289                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
290                    .await?;
291                let client = Client::new(client, &connection);
292                task::spawn(|| task_name, connection);
293                Ok(client)
294            }
295            TunnelConfig::Ssh { config } => {
296                let (host, port) = self.address()?;
297                let tunnel = ssh_tunnel_manager
298                    .connect(
299                        config.clone(),
300                        host,
301                        port,
302                        self.ssh_timeout_config,
303                        self.in_task,
304                    )
305                    .await
306                    .map_err(PostgresError::Ssh)?;
307
308                let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
309                let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
310                    .await
311                    .map_err(PostgresError::SshIo)?;
312                // Because we are connecting to a local host/port, we don't configure any TCP
313                // keepalive settings. The connection is entirely local to the machine running the
314                // process and we trust the kernel to keep a local connection alive without keepalives.
315                //
316                // Ideally we'd be able to configure SSH to enable TCP keepalives on the other
317                // end of the tunnel, between the SSH bastion host and the PostgreSQL server,
318                // but SSH does not expose an option for this.
319                let (client, connection) =
320                    async move { postgres_config.connect_raw(tcp_stream, tls).await }
321                        .run_in_task_if(self.in_task, || "pg_connect".to_string())
322                        .await?;
323                let client = Client::new(client, &connection);
324                task::spawn(|| task_name, async {
325                    let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
326
327                    if let Err(e) = connection.await {
328                        warn!("postgres connection failed: {e}");
329                    }
330                });
331                Ok(client)
332            }
333            TunnelConfig::AwsPrivatelink { connection_id } => {
334                // This section of code is somewhat subtle. We are overriding the host
335                // for the actual TCP connection to be the PrivateLink host, but leaving the host
336                // for TLS verification as the original host. Managing the
337                // `tokio_postgres::Config` to do this is somewhat confusing, and requires we edit
338                // the singular host in place.
339
340                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
341                let privatelink_addrs =
342                    tokio::net::lookup_host((privatelink_host, DUMMY_DNS_PORT)).await?;
343
344                // Override the actual IPs to connect to for the TCP connection, leaving the original host in-place
345                // for TLS verification
346                let host = match postgres_config.get_hosts() {
347                    [Host::Tcp(host)] => host,
348                    _ => bail_generic!(
349                        "only TCP connections to a single PostgreSQL server are supported"
350                    ),
351                }
352                .to_owned();
353                // Associate each resolved ip with the exact same, singular host, for tls
354                // verification. We are required to do this dance because `tokio-postgres`
355                // enforces that the number of 'host' and 'hostaddr' values must be the same.
356                for (idx, addr) in privatelink_addrs.enumerate() {
357                    if idx != 0 {
358                        postgres_config.host(&host);
359                    }
360                    postgres_config.hostaddr(addr.ip());
361                }
362
363                let (client, connection) = async move { postgres_config.connect(tls).await }
364                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
365                    .await?;
366                let client = Client::new(client, &connection);
367                task::spawn(|| task_name, connection);
368                Ok(client)
369            }
370        }
371    }
372
373    pub fn get_user(&self) -> Option<&str> {
374        self.inner.get_user()
375    }
376
377    pub fn get_dbname(&self) -> Option<&str> {
378        self.inner.get_dbname()
379    }
380}