mz_postgres_util/
tunnel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::netio::DUMMY_DNS_PORT;
17use mz_ore::option::OptionExt;
18use mz_ore::task::{self, AbortOnDropHandle};
19use mz_repr::CatalogItemId;
20use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
21use mz_ssh_util::tunnel_manager::SshTunnelManager;
22use tokio::net::TcpStream as TokioTcpStream;
23use tokio_postgres::config::{Host, ReplicationMode};
24use tokio_postgres::tls::MakeTlsConnect;
25use tracing::{info, warn};
26
27use crate::PostgresError;
28
/// Returns early from the enclosing function with `Err(PostgresError::Generic(..))`
/// wrapping an `anyhow::Error`.
///
/// Accepts either a format string followed by arguments (forwarded to
/// `anyhow::anyhow!`), or a single expression such as a message, with an
/// optional trailing comma.
///
/// The expansions intentionally have no trailing `;`, so the macro can be used in
/// expression position, e.g. as a `match` arm (`_ => bail_generic!(...)`).
macro_rules! bail_generic {
    // Format string plus format arguments.
    ($fmt:expr, $($arg:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($fmt, $($arg)*)))
    };
    // A single expression, optionally followed by a trailing comma.
    ($err:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($err)))
    };
}
37
38/// Configures an optional tunnel for use when connecting to a PostgreSQL
39/// database.
40#[derive(Debug, PartialEq, Clone)]
41pub enum TunnelConfig {
42    /// Establish a direct TCP connection to the database host.
43    /// If `resolved_ips` is not None, the provided IPs will be used
44    /// rather than resolving the hostname.
45    Direct {
46        resolved_ips: Option<BTreeSet<IpAddr>>,
47    },
48    /// Establish a TCP connection to the database via an SSH tunnel.
49    /// This means first establishing an SSH connection to a bastion host,
50    /// and then opening a separate connection from that host to the database.
51    /// This is commonly referred by vendors as a "direct SSH tunnel", in
52    /// opposition to "reverse SSH tunnel", which is currently unsupported.
53    Ssh { config: SshTunnelConfig },
54    /// Establish a TCP connection to the database via an AWS PrivateLink
55    /// service.
56    AwsPrivatelink {
57        /// The ID of the AWS PrivateLink service.
58        connection_id: CatalogItemId,
59    },
60}
61
/// Default statement timeout for snapshot queries: a zero duration.
///
/// NOTE(review): presumably a zero value means "no timeout" (PostgreSQL treats
/// `statement_timeout = 0` as disabled) — confirm at the use site.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::from_secs(0);
63
/// A wrapper for [`tokio_postgres::Client`] that owns the background task
/// driving its connection.
///
/// Dereferences (mutably and immutably) to the inner `tokio_postgres::Client`.
pub struct Client {
    inner: tokio_postgres::Client,
    // Holds a handle to the task with the connection to ensure that when
    // the client is dropped, the task can be aborted to close the connection.
    // This is also useful for maintaining the lifetimes of dependent objects (e.g. ssh tunnel).
    //
    // Struct fields are dropped in declaration order, so `inner` is dropped
    // before this handle aborts the connection task; reordering the fields
    // changes that sequence.
    _connection_handle: AbortOnDropHandle<()>,
}
72
73impl Deref for Client {
74    type Target = tokio_postgres::Client;
75
76    fn deref(&self) -> &Self::Target {
77        &self.inner
78    }
79}
80
81impl DerefMut for Client {
82    fn deref_mut(&mut self) -> &mut Self::Target {
83        &mut self.inner
84    }
85}
86
/// Configuration for PostgreSQL connections.
///
/// This wraps [`tokio_postgres::Config`] to allow the configuration of a
/// tunnel via a [`TunnelConfig`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Underlying connection parameters. `Config::new` rejects configurations
    /// that do not name exactly one TCP host and one port.
    inner: tokio_postgres::Config,
    /// How the TCP connection to the server is established.
    tunnel: TunnelConfig,
    /// Passed to `run_in_task_if` when connecting, and forwarded to the SSH
    /// tunnel manager.
    in_task: InTask,
    /// Timeout settings forwarded to the SSH tunnel manager when `tunnel` is
    /// `TunnelConfig::Ssh`.
    ssh_timeout_config: SshTimeoutConfig,
}
98
99impl Config {
100    pub fn new(
101        inner: tokio_postgres::Config,
102        tunnel: TunnelConfig,
103        ssh_timeout_config: SshTimeoutConfig,
104        in_task: InTask,
105    ) -> Result<Self, PostgresError> {
106        let config = Self {
107            inner,
108            tunnel,
109            in_task,
110            ssh_timeout_config,
111        };
112
113        // Early validate that the configuration contains only a single TCP
114        // server.
115        config.address()?;
116
117        Ok(config)
118    }
119
120    /// Connects to the configured PostgreSQL database.
121    pub async fn connect(
122        &self,
123        task_name: &str,
124        ssh_tunnel_manager: &SshTunnelManager,
125    ) -> Result<Client, PostgresError> {
126        self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
127            .await
128    }
129
130    /// Starts a replication connection to the configured PostgreSQL database.
131    pub async fn connect_replication(
132        &self,
133        ssh_tunnel_manager: &SshTunnelManager,
134    ) -> Result<Client, PostgresError> {
135        self.connect_traced(
136            "postgres_connect_replication",
137            |config| {
138                config.replication_mode(ReplicationMode::Logical);
139            },
140            ssh_tunnel_manager,
141        )
142        .await
143    }
144
145    fn address(&self) -> Result<(&str, u16), PostgresError> {
146        match (self.inner.get_hosts(), self.inner.get_ports()) {
147            ([Host::Tcp(host)], [port]) => Ok((host, *port)),
148            _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
149        }
150    }
151
152    async fn connect_traced<F>(
153        &self,
154        task_name: &str,
155        configure: F,
156        ssh_tunnel_manager: &SshTunnelManager,
157    ) -> Result<Client, PostgresError>
158    where
159        F: FnOnce(&mut tokio_postgres::Config),
160    {
161        let (host, port) = self.address()?;
162        let address = format!(
163            "{}@{}:{}/{}",
164            self.get_user().display_or("<unknown-user>"),
165            host,
166            port,
167            self.get_dbname().display_or("<unknown-dbname>")
168        );
169        info!(%task_name, %address, "connecting");
170        match self
171            .connect_internal(task_name, configure, ssh_tunnel_manager)
172            .await
173        {
174            Ok(t) => {
175                let backend_pid = t.backend_pid();
176                info!(%task_name, %address, %backend_pid, "connected");
177                Ok(t)
178            }
179            Err(e) => {
180                warn!(%task_name, %address, "connection failed: {e:#}");
181                Err(e)
182            }
183        }
184    }
185
186    async fn connect_internal<F>(
187        &self,
188        task_name: &str,
189        configure: F,
190        ssh_tunnel_manager: &SshTunnelManager,
191    ) -> Result<Client, PostgresError>
192    where
193        F: FnOnce(&mut tokio_postgres::Config),
194    {
195        let mut postgres_config = self.inner.clone();
196        configure(&mut postgres_config);
197
198        let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
199            mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
200            mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
201        })?;
202
203        match &self.tunnel {
204            TunnelConfig::Direct { resolved_ips } => {
205                if let Some(ips) = resolved_ips {
206                    let host = match postgres_config.get_hosts() {
207                        [Host::Tcp(host)] => host,
208                        _ => bail_generic!(
209                            "only TCP connections to a single PostgreSQL server are supported"
210                        ),
211                    }
212                    .to_owned();
213                    // Associate each resolved ip with the exact same, singular host, for tls
214                    // verification. We are required to do this dance because `tokio-postgres`
215                    // enforces that the number of 'host' and 'hostaddr' values must be the same.
216                    for (idx, ip) in ips.iter().enumerate() {
217                        if idx != 0 {
218                            postgres_config.host(&host);
219                        }
220                        postgres_config.hostaddr(ip.clone());
221                    }
222                };
223
224                let (client, connection) = async move { postgres_config.connect(tls).await }
225                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
226                    .await?;
227
228                let client = Client {
229                    inner: client,
230                    _connection_handle: task::spawn(|| task_name, async {
231                        if let Err(e) = connection.await {
232                            warn!("postgres direct connection failed: {e}");
233                        }
234                    })
235                    .abort_on_drop(),
236                };
237                Ok(client)
238            }
239            TunnelConfig::Ssh { config } => {
240                let (host, port) = self.address()?;
241                let tunnel = ssh_tunnel_manager
242                    .connect(
243                        config.clone(),
244                        host,
245                        port,
246                        self.ssh_timeout_config,
247                        self.in_task,
248                    )
249                    .await
250                    .map_err(PostgresError::Ssh)?;
251
252                let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
253                let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
254                    .await
255                    .map_err(PostgresError::SshIo)?;
256                // Because we are connecting to a local host/port, we don't configure any TCP
257                // keepalive settings. The connection is entirely local to the machine running the
258                // process and we trust the kernel to keep a local connection alive without keepalives.
259                //
260                // Ideally we'd be able to configure SSH to enable TCP keepalives on the other
261                // end of the tunnel, between the SSH bastion host and the PostgreSQL server,
262                // but SSH does not expose an option for this.
263                let (client, connection) =
264                    async move { postgres_config.connect_raw(tcp_stream, tls).await }
265                        .run_in_task_if(self.in_task, || "pg_connect".to_string())
266                        .await?;
267
268                let client = Client {
269                    inner: client,
270                    _connection_handle: task::spawn(|| task_name, async {
271                        let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
272                        if let Err(e) = connection.await {
273                            warn!("postgres via SSH tunnel connection failed: {e}");
274                        }
275                    })
276                    .abort_on_drop(),
277                };
278                Ok(client)
279            }
280            TunnelConfig::AwsPrivatelink { connection_id } => {
281                // This section of code is somewhat subtle. We are overriding the host
282                // for the actual TCP connection to be the PrivateLink host, but leaving the host
283                // for TLS verification as the original host. Managing the
284                // `tokio_postgres::Config` to do this is somewhat confusing, and requires we edit
285                // the singular host in place.
286
287                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
288                let privatelink_addrs =
289                    tokio::net::lookup_host((privatelink_host, DUMMY_DNS_PORT)).await?;
290
291                // Override the actual IPs to connect to for the TCP connection, leaving the original host in-place
292                // for TLS verification
293                let host = match postgres_config.get_hosts() {
294                    [Host::Tcp(host)] => host,
295                    _ => bail_generic!(
296                        "only TCP connections to a single PostgreSQL server are supported"
297                    ),
298                }
299                .to_owned();
300                // Associate each resolved ip with the exact same, singular host, for tls
301                // verification. We are required to do this dance because `tokio-postgres`
302                // enforces that the number of 'host' and 'hostaddr' values must be the same.
303                for (idx, addr) in privatelink_addrs.enumerate() {
304                    if idx != 0 {
305                        postgres_config.host(&host);
306                    }
307                    postgres_config.hostaddr(addr.ip());
308                }
309
310                let (client, connection) = async move { postgres_config.connect(tls).await }
311                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
312                    .await?;
313
314                let client = Client {
315                    inner: client,
316                    _connection_handle: task::spawn(|| task_name, async {
317                        if let Err(e) = connection.await {
318                            warn!("postgres AWS link connection failed: {e}");
319                        }
320                    })
321                    .abort_on_drop(),
322                };
323                Ok(client)
324            }
325        }
326    }
327
328    pub fn get_user(&self) -> Option<&str> {
329        self.inner.get_user()
330    }
331
332    pub fn get_dbname(&self) -> Option<&str> {
333        self.inner.get_dbname()
334    }
335}