Skip to main content

mz_postgres_util/
tunnel.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::option::OptionExt;
17use mz_ore::task::{self, AbortOnDropHandle};
18use mz_repr::CatalogItemId;
19use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
20use mz_ssh_util::tunnel_manager::SshTunnelManager;
21use tokio::net::TcpStream as TokioTcpStream;
22use tokio_postgres::config::{Host, ReplicationMode};
23use tokio_postgres::tls::MakeTlsConnect;
24use tracing::{info, warn};
25
26use crate::PostgresError;
27
/// Returns early from the enclosing function with an
/// `Err(PostgresError::Generic(..))` wrapping an `anyhow` error built from the
/// given message/expression. A trailing comma is accepted.
macro_rules! bail_generic {
    ($msg:expr $(,)?) => {
        return ::std::result::Result::Err(PostgresError::Generic(::anyhow::anyhow!($msg)))
    };
}
33
34/// Configures an optional tunnel for use when connecting to a PostgreSQL
35/// database.
36#[derive(Debug, PartialEq, Clone)]
37pub enum TunnelConfig {
38    /// Establish a direct TCP connection to the database host.
39    /// If `resolved_ips` is not None, the provided IPs will be used
40    /// rather than resolving the hostname.
41    Direct {
42        resolved_ips: Option<BTreeSet<IpAddr>>,
43    },
44    /// Establish a TCP connection to the database via an SSH tunnel.
45    /// This means first establishing an SSH connection to a bastion host,
46    /// and then opening a separate connection from that host to the database.
47    /// This is commonly referred by vendors as a "direct SSH tunnel", in
48    /// opposition to "reverse SSH tunnel", which is currently unsupported.
49    Ssh { config: SshTunnelConfig },
50    /// Establish a TCP connection to the database via an AWS PrivateLink
51    /// service.
52    AwsPrivatelink {
53        /// The ID of the AWS PrivateLink service.
54        connection_id: CatalogItemId,
55    },
56}
57
/// Default value for the snapshot statement timeout: a zero duration.
///
/// NOTE(review): presumably zero is interpreted as "no timeout" by the code
/// that applies it (PostgreSQL treats `statement_timeout = 0` as disabled) —
/// confirm at the use site.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::from_secs(0);
59
60/// A wrapper for [`tokio_postgres::Client`] that can report the server version.
61pub struct Client {
62    inner: tokio_postgres::Client,
63    // Holds a handle to the task with the connection to ensure that when
64    // the client is dropped, the task can be aborted to close the connection.
65    // This is also useful for maintaining the lifetimes of dependent object (e.g. ssh tunnel).
66    _connection_handle: AbortOnDropHandle<()>,
67}
68
69impl Deref for Client {
70    type Target = tokio_postgres::Client;
71
72    fn deref(&self) -> &Self::Target {
73        &self.inner
74    }
75}
76
77impl DerefMut for Client {
78    fn deref_mut(&mut self) -> &mut Self::Target {
79        &mut self.inner
80    }
81}
82
83/// Configuration for PostgreSQL connections.
84///
85/// This wraps [`tokio_postgres::Config`] to allow the configuration of a
86/// tunnel via a [`TunnelConfig`].
87#[derive(Clone, Debug)]
88pub struct Config {
89    inner: tokio_postgres::Config,
90    tunnel: TunnelConfig,
91    in_task: InTask,
92    ssh_timeout_config: SshTimeoutConfig,
93}
94
95impl Config {
96    pub fn new(
97        inner: tokio_postgres::Config,
98        tunnel: TunnelConfig,
99        ssh_timeout_config: SshTimeoutConfig,
100        in_task: InTask,
101    ) -> Result<Self, PostgresError> {
102        let config = Self {
103            inner,
104            tunnel,
105            in_task,
106            ssh_timeout_config,
107        };
108
109        // Early validate that the configuration contains only a single TCP
110        // server.
111        config.address()?;
112
113        Ok(config)
114    }
115
116    /// Connects to the configured PostgreSQL database.
117    pub async fn connect(
118        &self,
119        task_name: &str,
120        ssh_tunnel_manager: &SshTunnelManager,
121    ) -> Result<Client, PostgresError> {
122        self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
123            .await
124    }
125
126    /// Starts a replication connection to the configured PostgreSQL database.
127    pub async fn connect_replication(
128        &self,
129        ssh_tunnel_manager: &SshTunnelManager,
130    ) -> Result<Client, PostgresError> {
131        self.connect_traced(
132            "postgres_connect_replication",
133            |config| {
134                config.replication_mode(ReplicationMode::Logical);
135            },
136            ssh_tunnel_manager,
137        )
138        .await
139    }
140
141    fn address(&self) -> Result<(&str, u16), PostgresError> {
142        match (self.inner.get_hosts(), self.inner.get_ports()) {
143            ([Host::Tcp(host)], [port]) => Ok((host, *port)),
144            _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
145        }
146    }
147
148    async fn connect_traced<F>(
149        &self,
150        task_name: &str,
151        configure: F,
152        ssh_tunnel_manager: &SshTunnelManager,
153    ) -> Result<Client, PostgresError>
154    where
155        F: FnOnce(&mut tokio_postgres::Config),
156    {
157        let (host, port) = self.address()?;
158        let address = format!(
159            "{}@{}:{}/{}",
160            self.get_user().display_or("<unknown-user>"),
161            host,
162            port,
163            self.get_dbname().display_or("<unknown-dbname>")
164        );
165        info!(%task_name, %address, "connecting");
166        match self
167            .connect_internal(task_name, configure, ssh_tunnel_manager)
168            .await
169        {
170            Ok(t) => {
171                let backend_pid = t.backend_pid();
172                info!(%task_name, %address, %backend_pid, "connected");
173                Ok(t)
174            }
175            Err(e) => {
176                warn!(%task_name, %address, "connection failed: {e:#}");
177                Err(e)
178            }
179        }
180    }
181
182    async fn connect_internal<F>(
183        &self,
184        task_name: &str,
185        configure: F,
186        ssh_tunnel_manager: &SshTunnelManager,
187    ) -> Result<Client, PostgresError>
188    where
189        F: FnOnce(&mut tokio_postgres::Config),
190    {
191        let mut postgres_config = self.inner.clone();
192        configure(&mut postgres_config);
193
194        let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
195            mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
196            mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
197        })?;
198
199        match &self.tunnel {
200            TunnelConfig::Direct { resolved_ips } => {
201                if let Some(ips) = resolved_ips {
202                    let host = match postgres_config.get_hosts() {
203                        [Host::Tcp(host)] => host,
204                        _ => bail_generic!(
205                            "only TCP connections to a single PostgreSQL server are supported"
206                        ),
207                    }
208                    .to_owned();
209                    // Associate each resolved ip with the exact same, singular host, for tls
210                    // verification. We are required to do this dance because `tokio-postgres`
211                    // enforces that the number of 'host' and 'hostaddr' values must be the same.
212                    for (idx, ip) in ips.iter().enumerate() {
213                        if idx != 0 {
214                            postgres_config.host(&host);
215                        }
216                        postgres_config.hostaddr(ip.clone());
217                    }
218                };
219
220                let (client, connection) = async move { postgres_config.connect(tls).await }
221                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
222                    .await?;
223
224                let client = Client {
225                    inner: client,
226                    _connection_handle: task::spawn(|| task_name, async {
227                        if let Err(e) = connection.await {
228                            warn!("postgres direct connection failed: {e}");
229                        }
230                    })
231                    .abort_on_drop(),
232                };
233                Ok(client)
234            }
235            TunnelConfig::Ssh { config } => {
236                let (host, port) = self.address()?;
237                let tunnel = ssh_tunnel_manager
238                    .connect(
239                        config.clone(),
240                        host,
241                        port,
242                        self.ssh_timeout_config,
243                        self.in_task,
244                    )
245                    .await
246                    .map_err(PostgresError::Ssh)?;
247
248                let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
249                let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
250                    .await
251                    .map_err(PostgresError::SshIo)?;
252                // Because we are connecting to a local host/port, we don't configure any TCP
253                // keepalive settings. The connection is entirely local to the machine running the
254                // process and we trust the kernel to keep a local connection alive without keepalives.
255                //
256                // Ideally we'd be able to configure SSH to enable TCP keepalives on the other
257                // end of the tunnel, between the SSH bastion host and the PostgreSQL server,
258                // but SSH does not expose an option for this.
259                let (client, connection) =
260                    async move { postgres_config.connect_raw(tcp_stream, tls).await }
261                        .run_in_task_if(self.in_task, || "pg_connect".to_string())
262                        .await?;
263
264                let client = Client {
265                    inner: client,
266                    _connection_handle: task::spawn(|| task_name, async {
267                        let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
268                        if let Err(e) = connection.await {
269                            warn!("postgres via SSH tunnel connection failed: {e}");
270                        }
271                    })
272                    .abort_on_drop(),
273                };
274                Ok(client)
275            }
276            TunnelConfig::AwsPrivatelink { connection_id } => {
277                // This section of code is somewhat subtle. We are overriding the host
278                // for the actual TCP connection to be the PrivateLink host, but leaving the host
279                // for TLS verification as the original host. Managing the
280                // `tokio_postgres::Config` to do this is somewhat confusing, and requires we edit
281                // the singular host in place.
282
283                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
284                let privatelink_addrs = tokio::net::lookup_host((privatelink_host, 0)).await?;
285
286                // Override the actual IPs to connect to for the TCP connection, leaving the original host in-place
287                // for TLS verification
288                let host = match postgres_config.get_hosts() {
289                    [Host::Tcp(host)] => host,
290                    _ => bail_generic!(
291                        "only TCP connections to a single PostgreSQL server are supported"
292                    ),
293                }
294                .to_owned();
295                // Associate each resolved ip with the exact same, singular host, for tls
296                // verification. We are required to do this dance because `tokio-postgres`
297                // enforces that the number of 'host' and 'hostaddr' values must be the same.
298                for (idx, addr) in privatelink_addrs.enumerate() {
299                    if idx != 0 {
300                        postgres_config.host(&host);
301                    }
302                    postgres_config.hostaddr(addr.ip());
303                }
304
305                let (client, connection) = async move { postgres_config.connect(tls).await }
306                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
307                    .await?;
308
309                let client = Client {
310                    inner: client,
311                    _connection_handle: task::spawn(|| task_name, async {
312                        if let Err(e) = connection.await {
313                            warn!("postgres AWS link connection failed: {e}");
314                        }
315                    })
316                    .abort_on_drop(),
317                };
318                Ok(client)
319            }
320        }
321    }
322
323    pub fn get_user(&self) -> Option<&str> {
324        self.inner.get_user()
325    }
326
327    pub fn get_dbname(&self) -> Option<&str> {
328        self.inner.get_dbname()
329    }
330}