mz_postgres_util/
tunnel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::option::OptionExt;
17use mz_ore::task::{self, AbortOnDropHandle};
18use mz_repr::CatalogItemId;
19use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
20use mz_ssh_util::tunnel_manager::SshTunnelManager;
21use tokio::net::TcpStream as TokioTcpStream;
22use tokio_postgres::config::{Host, ReplicationMode};
23use tokio_postgres::tls::MakeTlsConnect;
24use tracing::{info, warn};
25
26use crate::PostgresError;
27
/// Returns early from the enclosing function with a
/// [`PostgresError::Generic`] built by `anyhow::anyhow!`.
///
/// Two invocation forms are accepted, and both are forwarded unchanged to
/// `anyhow::anyhow!`:
///
/// * a format expression followed by a comma and further arguments, and
/// * a single expression, optionally followed by a trailing comma.
macro_rules! bail_generic {
    // Format expression plus arguments.
    ($template:expr, $($rest:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($template, $($rest)*)))
    };
    // A lone message or error expression.
    ($cause:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($cause)))
    };
}
36
/// Configures an optional tunnel for use when connecting to a PostgreSQL
/// database.
///
/// The variant selects how the TCP connection to the database is
/// established when a connection is opened from a [`Config`].
#[derive(Debug, PartialEq, Clone)]
pub enum TunnelConfig {
    /// Establish a direct TCP connection to the database host.
    /// If `resolved_ips` is not `None`, the provided IPs will be used
    /// rather than resolving the hostname.
    Direct {
        /// Pre-resolved addresses for the database host. When present, each
        /// address is registered as a `hostaddr` for the configured host.
        resolved_ips: Option<BTreeSet<IpAddr>>,
    },
    /// Establish a TCP connection to the database via an SSH tunnel.
    /// This means first establishing an SSH connection to a bastion host,
    /// and then opening a separate connection from that host to the database.
    /// This is commonly referred to by vendors as a "direct SSH tunnel", as
    /// opposed to a "reverse SSH tunnel", which is currently unsupported.
    ///
    /// `config` describes the SSH tunnel to establish.
    Ssh { config: SshTunnelConfig },
    /// Establish a TCP connection to the database via an AWS PrivateLink
    /// service.
    AwsPrivatelink {
        /// The ID of the AWS PrivateLink service. The VPC endpoint host name
        /// that is actually dialed is derived from this ID.
        connection_id: CatalogItemId,
    },
}
60
/// Default statement timeout for snapshots: a zero-length duration.
///
/// NOTE(review): presumably a zero duration means "no timeout", mirroring
/// PostgreSQL's `statement_timeout = 0` — confirm at the use site, which is
/// not visible here.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::ZERO;
62
/// A wrapper for [`tokio_postgres::Client`] that also owns the background
/// task driving the underlying connection.
///
/// The wrapper dereferences to the inner [`tokio_postgres::Client`], so all of
/// its methods are available directly.
///
/// NOTE(review): this type was previously described as being able to "report
/// the server version"; no such method is visible in this part of the file —
/// confirm whether that description is still accurate.
pub struct Client {
    // The wrapped client that queries are issued through.
    inner: tokio_postgres::Client,
    // Holds a handle to the task with the connection to ensure that when
    // the client is dropped, the task can be aborted to close the connection.
    // This is also useful for maintaining the lifetimes of dependent objects
    // (e.g. an SSH tunnel), which are moved into the task.
    _connection_handle: AbortOnDropHandle<()>,
}
71
72impl Deref for Client {
73    type Target = tokio_postgres::Client;
74
75    fn deref(&self) -> &Self::Target {
76        &self.inner
77    }
78}
79
80impl DerefMut for Client {
81    fn deref_mut(&mut self) -> &mut Self::Target {
82        &mut self.inner
83    }
84}
85
/// Configuration for PostgreSQL connections.
///
/// This wraps [`tokio_postgres::Config`] to allow the configuration of a
/// tunnel via a [`TunnelConfig`].
#[derive(Clone, Debug)]
pub struct Config {
    /// The underlying PostgreSQL connection settings. A clone of this is
    /// adjusted and used for every connection attempt.
    inner: tokio_postgres::Config,
    /// How the database is reached: directly, through an SSH tunnel, or
    /// through AWS PrivateLink.
    tunnel: TunnelConfig,
    /// Passed to `run_in_task_if` for the PostgreSQL handshake and to the SSH
    /// tunnel manager when establishing a tunnel.
    in_task: InTask,
    /// Timeout settings handed to the SSH tunnel manager when the tunnel is
    /// [`TunnelConfig::Ssh`].
    ssh_timeout_config: SshTimeoutConfig,
}
97
98impl Config {
99    pub fn new(
100        inner: tokio_postgres::Config,
101        tunnel: TunnelConfig,
102        ssh_timeout_config: SshTimeoutConfig,
103        in_task: InTask,
104    ) -> Result<Self, PostgresError> {
105        let config = Self {
106            inner,
107            tunnel,
108            in_task,
109            ssh_timeout_config,
110        };
111
112        // Early validate that the configuration contains only a single TCP
113        // server.
114        config.address()?;
115
116        Ok(config)
117    }
118
119    /// Connects to the configured PostgreSQL database.
120    pub async fn connect(
121        &self,
122        task_name: &str,
123        ssh_tunnel_manager: &SshTunnelManager,
124    ) -> Result<Client, PostgresError> {
125        self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
126            .await
127    }
128
129    /// Starts a replication connection to the configured PostgreSQL database.
130    pub async fn connect_replication(
131        &self,
132        ssh_tunnel_manager: &SshTunnelManager,
133    ) -> Result<Client, PostgresError> {
134        self.connect_traced(
135            "postgres_connect_replication",
136            |config| {
137                config.replication_mode(ReplicationMode::Logical);
138            },
139            ssh_tunnel_manager,
140        )
141        .await
142    }
143
144    fn address(&self) -> Result<(&str, u16), PostgresError> {
145        match (self.inner.get_hosts(), self.inner.get_ports()) {
146            ([Host::Tcp(host)], [port]) => Ok((host, *port)),
147            _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
148        }
149    }
150
151    async fn connect_traced<F>(
152        &self,
153        task_name: &str,
154        configure: F,
155        ssh_tunnel_manager: &SshTunnelManager,
156    ) -> Result<Client, PostgresError>
157    where
158        F: FnOnce(&mut tokio_postgres::Config),
159    {
160        let (host, port) = self.address()?;
161        let address = format!(
162            "{}@{}:{}/{}",
163            self.get_user().display_or("<unknown-user>"),
164            host,
165            port,
166            self.get_dbname().display_or("<unknown-dbname>")
167        );
168        info!(%task_name, %address, "connecting");
169        match self
170            .connect_internal(task_name, configure, ssh_tunnel_manager)
171            .await
172        {
173            Ok(t) => {
174                let backend_pid = t.backend_pid();
175                info!(%task_name, %address, %backend_pid, "connected");
176                Ok(t)
177            }
178            Err(e) => {
179                warn!(%task_name, %address, "connection failed: {e:#}");
180                Err(e)
181            }
182        }
183    }
184
185    async fn connect_internal<F>(
186        &self,
187        task_name: &str,
188        configure: F,
189        ssh_tunnel_manager: &SshTunnelManager,
190    ) -> Result<Client, PostgresError>
191    where
192        F: FnOnce(&mut tokio_postgres::Config),
193    {
194        let mut postgres_config = self.inner.clone();
195        configure(&mut postgres_config);
196
197        let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
198            mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
199            mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
200        })?;
201
202        match &self.tunnel {
203            TunnelConfig::Direct { resolved_ips } => {
204                if let Some(ips) = resolved_ips {
205                    let host = match postgres_config.get_hosts() {
206                        [Host::Tcp(host)] => host,
207                        _ => bail_generic!(
208                            "only TCP connections to a single PostgreSQL server are supported"
209                        ),
210                    }
211                    .to_owned();
212                    // Associate each resolved ip with the exact same, singular host, for tls
213                    // verification. We are required to do this dance because `tokio-postgres`
214                    // enforces that the number of 'host' and 'hostaddr' values must be the same.
215                    for (idx, ip) in ips.iter().enumerate() {
216                        if idx != 0 {
217                            postgres_config.host(&host);
218                        }
219                        postgres_config.hostaddr(ip.clone());
220                    }
221                };
222
223                let (client, connection) = async move { postgres_config.connect(tls).await }
224                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
225                    .await?;
226
227                let client = Client {
228                    inner: client,
229                    _connection_handle: task::spawn(|| task_name, async {
230                        if let Err(e) = connection.await {
231                            warn!("postgres direct connection failed: {e}");
232                        }
233                    })
234                    .abort_on_drop(),
235                };
236                Ok(client)
237            }
238            TunnelConfig::Ssh { config } => {
239                let (host, port) = self.address()?;
240                let tunnel = ssh_tunnel_manager
241                    .connect(
242                        config.clone(),
243                        host,
244                        port,
245                        self.ssh_timeout_config,
246                        self.in_task,
247                    )
248                    .await
249                    .map_err(PostgresError::Ssh)?;
250
251                let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
252                let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
253                    .await
254                    .map_err(PostgresError::SshIo)?;
255                // Because we are connecting to a local host/port, we don't configure any TCP
256                // keepalive settings. The connection is entirely local to the machine running the
257                // process and we trust the kernel to keep a local connection alive without keepalives.
258                //
259                // Ideally we'd be able to configure SSH to enable TCP keepalives on the other
260                // end of the tunnel, between the SSH bastion host and the PostgreSQL server,
261                // but SSH does not expose an option for this.
262                let (client, connection) =
263                    async move { postgres_config.connect_raw(tcp_stream, tls).await }
264                        .run_in_task_if(self.in_task, || "pg_connect".to_string())
265                        .await?;
266
267                let client = Client {
268                    inner: client,
269                    _connection_handle: task::spawn(|| task_name, async {
270                        let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
271                        if let Err(e) = connection.await {
272                            warn!("postgres via SSH tunnel connection failed: {e}");
273                        }
274                    })
275                    .abort_on_drop(),
276                };
277                Ok(client)
278            }
279            TunnelConfig::AwsPrivatelink { connection_id } => {
280                // This section of code is somewhat subtle. We are overriding the host
281                // for the actual TCP connection to be the PrivateLink host, but leaving the host
282                // for TLS verification as the original host. Managing the
283                // `tokio_postgres::Config` to do this is somewhat confusing, and requires we edit
284                // the singular host in place.
285
286                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
287                let privatelink_addrs = tokio::net::lookup_host((privatelink_host, 0)).await?;
288
289                // Override the actual IPs to connect to for the TCP connection, leaving the original host in-place
290                // for TLS verification
291                let host = match postgres_config.get_hosts() {
292                    [Host::Tcp(host)] => host,
293                    _ => bail_generic!(
294                        "only TCP connections to a single PostgreSQL server are supported"
295                    ),
296                }
297                .to_owned();
298                // Associate each resolved ip with the exact same, singular host, for tls
299                // verification. We are required to do this dance because `tokio-postgres`
300                // enforces that the number of 'host' and 'hostaddr' values must be the same.
301                for (idx, addr) in privatelink_addrs.enumerate() {
302                    if idx != 0 {
303                        postgres_config.host(&host);
304                    }
305                    postgres_config.hostaddr(addr.ip());
306                }
307
308                let (client, connection) = async move { postgres_config.connect(tls).await }
309                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
310                    .await?;
311
312                let client = Client {
313                    inner: client,
314                    _connection_handle: task::spawn(|| task_name, async {
315                        if let Err(e) = connection.await {
316                            warn!("postgres AWS link connection failed: {e}");
317                        }
318                    })
319                    .abort_on_drop(),
320                };
321                Ok(client)
322            }
323        }
324    }
325
326    pub fn get_user(&self) -> Option<&str> {
327        self.inner.get_user()
328    }
329
330    pub fn get_dbname(&self) -> Option<&str> {
331        self.inner.get_dbname()
332    }
333}