mz_postgres_util/
tunnel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::net::IpAddr;
12use std::ops::{Deref, DerefMut};
13use std::time::Duration;
14
15use mz_ore::future::{InTask, OreFutureExt};
16use mz_ore::netio::DUMMY_DNS_PORT;
17use mz_ore::option::OptionExt;
18use mz_ore::task;
19use mz_repr::CatalogItemId;
20use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
21use mz_ssh_util::tunnel_manager::SshTunnelManager;
22use tokio::io::{AsyncRead, AsyncWrite};
23use tokio::net::TcpStream as TokioTcpStream;
24use tokio_postgres::config::{Host, ReplicationMode};
25use tokio_postgres::tls::MakeTlsConnect;
26use tracing::{info, warn};
27
28use crate::PostgresError;
29
/// Returns early from the enclosing function with
/// `Err(PostgresError::Generic(..))`, building the inner error with
/// `anyhow::anyhow!`.
///
/// Accepts either a format string followed by arguments, or a single
/// message/error expression.
macro_rules! bail_generic {
    // Format string plus arguments, e.g. `bail_generic!("bad value: {}", v)`.
    // Arms are tried in order; any call with a comma after the first
    // expression matches here.
    ($fmt:expr, $($arg:tt)*) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($fmt, $($arg)*)))
    };
    // A single message or error expression, with an optional trailing comma.
    ($err:expr $(,)?) => {
        return Err(PostgresError::Generic(anyhow::anyhow!($err)))
    };
}
38
/// Configures an optional tunnel for use when connecting to a PostgreSQL
/// database.
#[derive(Debug, PartialEq, Clone)]
pub enum TunnelConfig {
    /// Establish a direct TCP connection to the database host.
    /// If `resolved_ips` is not `None`, the provided IPs will be used
    /// rather than resolving the hostname.
    Direct {
        /// Pre-resolved addresses for the configured host. Each address is
        /// used as a TCP target while the configured hostname is kept for TLS
        /// verification.
        resolved_ips: Option<BTreeSet<IpAddr>>,
    },
    /// Establish a TCP connection to the database via an SSH tunnel.
    /// This means first establishing an SSH connection to a bastion host,
    /// and then opening a separate connection from that host to the database.
    /// This is commonly referred to by vendors as a "direct SSH tunnel", as
    /// opposed to a "reverse SSH tunnel", which is currently unsupported.
    ///
    /// `config` holds the SSH tunnel settings handed to the SSH tunnel manager.
    Ssh { config: SshTunnelConfig },
    /// Establish a TCP connection to the database via an AWS PrivateLink
    /// service.
    AwsPrivatelink {
        /// Catalog ID of the AWS PrivateLink connection. The VPC endpoint
        /// hostname to connect through is derived from it.
        connection_id: CatalogItemId,
    },
}
62
/// Default statement timeout for snapshot queries: [`Duration::ZERO`].
///
/// NOTE(review): the use site is not visible here. Presumably this value is
/// sent as PostgreSQL's `statement_timeout`, where zero disables the timeout
/// entirely — confirm at the call site.
pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::ZERO;
64
/// A wrapper for [`tokio_postgres::Client`] that can report the server version.
///
/// Dereferences to the wrapped client, so every `tokio_postgres::Client`
/// method is available directly on this type.
pub struct Client {
    /// The underlying `tokio-postgres` client.
    inner: tokio_postgres::Client,
    /// The `server_version` parameter reported on the connection, if any,
    /// captured when this wrapper was constructed.
    server_version: Option<String>,
}
70
71impl Client {
72    fn new<S, T>(
73        client: tokio_postgres::Client,
74        connection: &tokio_postgres::Connection<S, T>,
75    ) -> Client
76    where
77        S: AsyncRead + AsyncWrite + Unpin,
78        T: AsyncRead + AsyncWrite + Unpin,
79    {
80        let server_version = connection
81            .parameter("server_version")
82            .map(|v| v.to_string());
83        Client {
84            inner: client,
85            server_version,
86        }
87    }
88
89    /// Reports the value of the `server_version` parameter reported by the
90    /// server.
91    pub fn server_version(&self) -> Option<&str> {
92        self.server_version.as_deref()
93    }
94}
95
96impl Deref for Client {
97    type Target = tokio_postgres::Client;
98
99    fn deref(&self) -> &Self::Target {
100        &self.inner
101    }
102}
103
104impl DerefMut for Client {
105    fn deref_mut(&mut self) -> &mut Self::Target {
106        &mut self.inner
107    }
108}
109
/// Configuration for PostgreSQL connections.
///
/// This wraps [`tokio_postgres::Config`] to allow the configuration of a
/// tunnel via a [`TunnelConfig`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Connection parameters: host, port, user, database name, plus whatever
    /// `mz_tls_util::make_tls` reads to build the TLS connector.
    inner: tokio_postgres::Config,
    /// How the TCP connection to the server is established.
    tunnel: TunnelConfig,
    /// Forwarded to `run_in_task_if` for the connection handshake, and to the
    /// SSH tunnel manager.
    in_task: InTask,
    /// Timeouts passed to the SSH tunnel manager; only consulted for
    /// [`TunnelConfig::Ssh`].
    ssh_timeout_config: SshTimeoutConfig,
}
121
122impl Config {
123    pub fn new(
124        inner: tokio_postgres::Config,
125        tunnel: TunnelConfig,
126        ssh_timeout_config: SshTimeoutConfig,
127        in_task: InTask,
128    ) -> Result<Self, PostgresError> {
129        let config = Self {
130            inner,
131            tunnel,
132            in_task,
133            ssh_timeout_config,
134        };
135
136        // Early validate that the configuration contains only a single TCP
137        // server.
138        config.address()?;
139
140        Ok(config)
141    }
142
143    /// Connects to the configured PostgreSQL database.
144    pub async fn connect(
145        &self,
146        task_name: &str,
147        ssh_tunnel_manager: &SshTunnelManager,
148    ) -> Result<Client, PostgresError> {
149        self.connect_traced(task_name, |_| (), ssh_tunnel_manager)
150            .await
151    }
152
153    /// Starts a replication connection to the configured PostgreSQL database.
154    pub async fn connect_replication(
155        &self,
156        ssh_tunnel_manager: &SshTunnelManager,
157    ) -> Result<Client, PostgresError> {
158        self.connect_traced(
159            "postgres_connect_replication",
160            |config| {
161                config.replication_mode(ReplicationMode::Logical);
162            },
163            ssh_tunnel_manager,
164        )
165        .await
166    }
167
168    fn address(&self) -> Result<(&str, u16), PostgresError> {
169        match (self.inner.get_hosts(), self.inner.get_ports()) {
170            ([Host::Tcp(host)], [port]) => Ok((host, *port)),
171            _ => bail_generic!("only TCP connections to a single PostgreSQL server are supported"),
172        }
173    }
174
175    async fn connect_traced<F>(
176        &self,
177        task_name: &str,
178        configure: F,
179        ssh_tunnel_manager: &SshTunnelManager,
180    ) -> Result<Client, PostgresError>
181    where
182        F: FnOnce(&mut tokio_postgres::Config),
183    {
184        let (host, port) = self.address()?;
185        let address = format!(
186            "{}@{}:{}/{}",
187            self.get_user().display_or("<unknown-user>"),
188            host,
189            port,
190            self.get_dbname().display_or("<unknown-dbname>")
191        );
192        info!(%task_name, %address, "connecting");
193        match self
194            .connect_internal(task_name, configure, ssh_tunnel_manager)
195            .await
196        {
197            Ok(t) => {
198                let backend_pid = t.backend_pid();
199                info!(%task_name, %address, %backend_pid, "connected");
200                Ok(t)
201            }
202            Err(e) => {
203                warn!(%task_name, %address, "connection failed: {e:#}");
204                Err(e)
205            }
206        }
207    }
208
209    async fn connect_internal<F>(
210        &self,
211        task_name: &str,
212        configure: F,
213        ssh_tunnel_manager: &SshTunnelManager,
214    ) -> Result<Client, PostgresError>
215    where
216        F: FnOnce(&mut tokio_postgres::Config),
217    {
218        let mut postgres_config = self.inner.clone();
219        configure(&mut postgres_config);
220
221        let mut tls = mz_tls_util::make_tls(&postgres_config).map_err(|tls_err| match tls_err {
222            mz_tls_util::TlsError::Generic(e) => PostgresError::Generic(e),
223            mz_tls_util::TlsError::OpenSsl(e) => PostgresError::PostgresSsl(e),
224        })?;
225
226        match &self.tunnel {
227            TunnelConfig::Direct { resolved_ips } => {
228                if let Some(ips) = resolved_ips {
229                    let host = match postgres_config.get_hosts() {
230                        [Host::Tcp(host)] => host,
231                        _ => bail_generic!(
232                            "only TCP connections to a single PostgreSQL server are supported"
233                        ),
234                    }
235                    .to_owned();
236                    // Associate each resolved ip with the exact same, singular host, for tls
237                    // verification. We are required to do this dance because `tokio-postgres`
238                    // enforces that the number of 'host' and 'hostaddr' values must be the same.
239                    for (idx, ip) in ips.iter().enumerate() {
240                        if idx != 0 {
241                            postgres_config.host(&host);
242                        }
243                        postgres_config.hostaddr(ip.clone());
244                    }
245                };
246
247                let (client, connection) = async move { postgres_config.connect(tls).await }
248                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
249                    .await?;
250                let client = Client::new(client, &connection);
251                task::spawn(|| task_name, connection);
252                Ok(client)
253            }
254            TunnelConfig::Ssh { config } => {
255                let (host, port) = self.address()?;
256                let tunnel = ssh_tunnel_manager
257                    .connect(
258                        config.clone(),
259                        host,
260                        port,
261                        self.ssh_timeout_config,
262                        self.in_task,
263                    )
264                    .await
265                    .map_err(PostgresError::Ssh)?;
266
267                let tls = MakeTlsConnect::<TokioTcpStream>::make_tls_connect(&mut tls, host)?;
268                let tcp_stream = TokioTcpStream::connect(tunnel.local_addr())
269                    .await
270                    .map_err(PostgresError::SshIo)?;
271                // Because we are connecting to a local host/port, we don't configure any TCP
272                // keepalive settings. The connection is entirely local to the machine running the
273                // process and we trust the kernel to keep a local connection alive without keepalives.
274                //
275                // Ideally we'd be able to configure SSH to enable TCP keepalives on the other
276                // end of the tunnel, between the SSH bastion host and the PostgreSQL server,
277                // but SSH does not expose an option for this.
278                let (client, connection) =
279                    async move { postgres_config.connect_raw(tcp_stream, tls).await }
280                        .run_in_task_if(self.in_task, || "pg_connect".to_string())
281                        .await?;
282                let client = Client::new(client, &connection);
283                task::spawn(|| task_name, async {
284                    let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
285
286                    if let Err(e) = connection.await {
287                        warn!("postgres connection failed: {e}");
288                    }
289                });
290                Ok(client)
291            }
292            TunnelConfig::AwsPrivatelink { connection_id } => {
293                // This section of code is somewhat subtle. We are overriding the host
294                // for the actual TCP connection to be the PrivateLink host, but leaving the host
295                // for TLS verification as the original host. Managing the
296                // `tokio_postgres::Config` to do this is somewhat confusing, and requires we edit
297                // the singular host in place.
298
299                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
300                let privatelink_addrs =
301                    tokio::net::lookup_host((privatelink_host, DUMMY_DNS_PORT)).await?;
302
303                // Override the actual IPs to connect to for the TCP connection, leaving the original host in-place
304                // for TLS verification
305                let host = match postgres_config.get_hosts() {
306                    [Host::Tcp(host)] => host,
307                    _ => bail_generic!(
308                        "only TCP connections to a single PostgreSQL server are supported"
309                    ),
310                }
311                .to_owned();
312                // Associate each resolved ip with the exact same, singular host, for tls
313                // verification. We are required to do this dance because `tokio-postgres`
314                // enforces that the number of 'host' and 'hostaddr' values must be the same.
315                for (idx, addr) in privatelink_addrs.enumerate() {
316                    if idx != 0 {
317                        postgres_config.host(&host);
318                    }
319                    postgres_config.hostaddr(addr.ip());
320                }
321
322                let (client, connection) = async move { postgres_config.connect(tls).await }
323                    .run_in_task_if(self.in_task, || "pg_connect".to_string())
324                    .await?;
325                let client = Client::new(client, &connection);
326                task::spawn(|| task_name, connection);
327                Ok(client)
328            }
329        }
330    }
331
332    pub fn get_user(&self) -> Option<&str> {
333        self.inner.get_user()
334    }
335
336    pub fn get_dbname(&self) -> Option<&str> {
337        self.inner.get_dbname()
338    }
339}