mz_mysql_util/
tunnel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use aws_types::SdkConfig;
11use mysql_async::{Conn, Opts, OptsBuilder};
12use std::collections::BTreeSet;
13use std::net::IpAddr;
14use std::ops::{Deref, DerefMut};
15use std::time::Duration;
16
17use mz_ore::future::{InTask, TimeoutError};
18use mz_ore::option::OptionExt;
19use mz_ore::task::{JoinHandleExt, spawn};
20use mz_repr::CatalogItemId;
21use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
22use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
23use serde::{Deserialize, Serialize};
24use tracing::{error, info, warn};
25
26use crate::MySqlError;
27use crate::aws_rds::rds_auth_token;
28
29/// Configures an optional tunnel for use when connecting to a MySQL
30/// database.
31#[derive(Debug, PartialEq, Clone)]
32pub enum TunnelConfig {
33    /// Establish a direct TCP connection to the database host.
34    /// If `resolved_ips` is not None, the provided IPs will be used
35    /// rather than resolving the hostname.
36    Direct {
37        resolved_ips: Option<BTreeSet<IpAddr>>,
38    },
39    /// Establish a TCP connection to the database via an SSH tunnel.
40    /// This means first establishing an SSH connection to a bastion host,
41    /// and then opening a separate connection from that host to the database.
42    /// This is commonly referred by vendors as a "direct SSH tunnel", in
43    /// opposition to "reverse SSH tunnel", which is currently unsupported.
44    Ssh { config: SshTunnelConfig },
45    /// Establish a TCP connection to the database via an AWS PrivateLink
46    /// service.
47    AwsPrivatelink {
48        /// The ID of the AWS PrivateLink service.
49        connection_id: CatalogItemId,
50    },
51}
52
/// Default value for `TimeoutConfig::tcp_keepalive`: one minute.
pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(60);
/// Default value for `TimeoutConfig::snapshot_max_execution_time`: zero.
// NOTE(review): zero presumably means "no limit", matching MySQL's
// `max_execution_time` semantics -- confirm against the snapshot code.
pub const DEFAULT_SNAPSHOT_MAX_EXECUTION_TIME: Duration = Duration::ZERO;
/// Default value for `TimeoutConfig::snapshot_lock_wait_timeout`: one hour.
pub const DEFAULT_SNAPSHOT_LOCK_WAIT_TIMEOUT: Duration = Duration::from_secs(60 * 60);
/// Default value for `TimeoutConfig::connect_timeout`: one minute.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
57
58#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
59pub struct TimeoutConfig {
60    // Snapshot-related configs
61    pub snapshot_max_execution_time: Option<Duration>,
62    pub snapshot_lock_wait_timeout: Option<Duration>,
63
64    // Socket-related configs
65    pub tcp_keepalive: Option<Duration>,
66
67    // Connection timeout.  This timeout covers creating an authenticated connection
68    // (e.g. includes network connection, TLS handshake, authentication, etc.).
69    // If the connection has not been established in that time, it is considered an error.
70    pub connect_timeout: Option<Duration>,
71    // There are other timeout options on `mysql_async::OptsBuilder`
72    // (e.g. `conn_ttl` and `wait_timeout`) that could be exposed
73    // but they only apply to connection pools, which we are not currently using.
74}
75
76impl Default for TimeoutConfig {
77    fn default() -> Self {
78        Self {
79            snapshot_max_execution_time: Some(DEFAULT_SNAPSHOT_MAX_EXECUTION_TIME),
80            snapshot_lock_wait_timeout: Some(DEFAULT_SNAPSHOT_LOCK_WAIT_TIMEOUT),
81            tcp_keepalive: Some(DEFAULT_TCP_KEEPALIVE),
82            connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
83        }
84    }
85}
86
87impl TimeoutConfig {
88    pub fn build(
89        snapshot_max_execution_time: Duration,
90        snapshot_lock_wait_timeout: Duration,
91        tcp_keepalive: Duration,
92        connect_timeout: Duration,
93    ) -> Self {
94        // Verify values are within valid ranges
95        // Note we error log but do not fail as this is called in a non-fallible
96        // LD-sync in the adapter.
97
98        // https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_lock_wait_timeout
99        let snapshot_lock_wait_timeout = if snapshot_lock_wait_timeout.as_secs() > 31536000 {
100            error!(
101                "snapshot_lock_wait_timeout is too large: {}. Maximum is 31536000.",
102                snapshot_lock_wait_timeout.as_secs()
103            );
104            Some(DEFAULT_SNAPSHOT_LOCK_WAIT_TIMEOUT)
105        } else {
106            Some(snapshot_lock_wait_timeout)
107        };
108
109        // https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_max_execution_time
110        let snapshot_max_execution_time = if snapshot_max_execution_time.as_millis() > 4294967295 {
111            error!(
112                "snapshot_max_execution_time is too large: {}. Maximum is 4294967295.",
113                snapshot_max_execution_time.as_secs()
114            );
115            Some(DEFAULT_SNAPSHOT_MAX_EXECUTION_TIME)
116        } else {
117            Some(snapshot_max_execution_time)
118        };
119
120        let tcp_keepalive = match u32::try_from(tcp_keepalive.as_millis()) {
121            Err(_) => {
122                error!(
123                    "tcp_keepalive is too large: {}. Maximum is {}.",
124                    tcp_keepalive.as_millis(),
125                    u32::MAX,
126                );
127                Some(DEFAULT_TCP_KEEPALIVE)
128            }
129            Ok(_) => Some(tcp_keepalive),
130        };
131
132        let connect_timeout = match u32::try_from(connect_timeout.as_millis()) {
133            Err(_) => {
134                error!(
135                    "connect_timeout is too large: {}. Maximum is {}.",
136                    connect_timeout.as_millis(),
137                    u32::MAX,
138                );
139                Some(DEFAULT_CONNECT_TIMEOUT)
140            }
141            Ok(_) => Some(connect_timeout),
142        };
143
144        Self {
145            snapshot_max_execution_time,
146            snapshot_lock_wait_timeout,
147            tcp_keepalive,
148            connect_timeout,
149        }
150    }
151
152    /// Apply relevant timeout configurations to a `mysql_async::OptsBuilder`.
153    pub fn apply_to_opts(&self, mut opts_builder: OptsBuilder) -> Result<OptsBuilder, MySqlError> {
154        if let Some(tcp_keepalive) = self.tcp_keepalive {
155            opts_builder = opts_builder.tcp_keepalive(Some(
156                u32::try_from(tcp_keepalive.as_millis()).map_err(|e| {
157                    MySqlError::InvalidClientConfig(format!(
158                        "invalid tcp_keepalive duration: {}",
159                        e
160                    ))
161                })?,
162            ));
163        }
164        Ok(opts_builder)
165    }
166}
167
168/// A MySQL connection with an optional SSH tunnel handle.
169///
170/// This wrapper is intended to be used in place of `mysql_async::Conn` to
171/// keep the SSH tunnel alive for the lifecycle of the connection by holding
172/// a reference to the tunnel handle.
173#[derive(Debug)]
174pub struct MySqlConn {
175    conn: Conn,
176    _ssh_tunnel_handle: Option<ManagedSshTunnelHandle>,
177}
178
179impl Deref for MySqlConn {
180    type Target = Conn;
181
182    fn deref(&self) -> &Self::Target {
183        &self.conn
184    }
185}
186
187impl DerefMut for MySqlConn {
188    fn deref_mut(&mut self) -> &mut Self::Target {
189        &mut self.conn
190    }
191}
192
193impl MySqlConn {
194    pub async fn disconnect(mut self) -> Result<(), MySqlError> {
195        self.conn.disconnect().await?;
196        self._ssh_tunnel_handle.take();
197        Ok(())
198    }
199
200    pub fn take(self) -> (Conn, Option<ManagedSshTunnelHandle>) {
201        (self.conn, self._ssh_tunnel_handle)
202    }
203}
204
205/// Configuration for MySQL connections.
206///
207/// This wraps [`mysql_async::Opts`] to allow the configuration of a
208/// tunnel via a [`TunnelConfig`].
209#[derive(Clone, Debug)]
210pub struct Config {
211    inner: Opts,
212    tunnel: TunnelConfig,
213    // Whether to poll I/O for this connection in a tokio task
214    // TODO(roshan): Make this apply to queries on the returned connection, not just the initial
215    // connection.
216    in_task: InTask,
217    ssh_timeout_config: SshTimeoutConfig,
218    mysql_timeout_config: TimeoutConfig,
219    aws_config: Option<SdkConfig>,
220}
221
222impl Config {
223    pub fn new(
224        builder: OptsBuilder,
225        tunnel: TunnelConfig,
226        ssh_timeout_config: SshTimeoutConfig,
227        in_task: InTask,
228        mysql_timeout_config: TimeoutConfig,
229        aws_config: Option<SdkConfig>,
230    ) -> Result<Self, MySqlError> {
231        let opts = mysql_timeout_config.apply_to_opts(builder)?;
232        Ok(Self {
233            inner: opts.into(),
234            tunnel,
235            in_task,
236            ssh_timeout_config,
237            mysql_timeout_config,
238            aws_config,
239        })
240    }
241
242    pub async fn connect(
243        &self,
244        task_name: &str,
245        ssh_tunnel_manager: &SshTunnelManager,
246    ) -> Result<MySqlConn, MySqlError> {
247        let address = format!(
248            "mysql:://{}@{}:{}/{}",
249            self.inner.user().display_or("<unknown-user>"),
250            self.inner.ip_or_hostname(),
251            self.inner.tcp_port(),
252            self.inner.db_name().display_or("<unknown-dbname>"),
253        );
254        info!(%task_name, %address, "connecting");
255        match self.connect_internal(ssh_tunnel_manager).await {
256            Ok(t) => {
257                info!(%task_name, %address, "connected");
258                Ok(t)
259            }
260            Err(e) => {
261                warn!(%task_name, %address, "connection failed: {e:#}");
262                Err(e)
263            }
264        }
265    }
266
267    fn address(&self) -> (&str, u16) {
268        (self.inner.ip_or_hostname(), self.inner.tcp_port())
269    }
270
271    async fn connect_internal(
272        &self,
273        ssh_tunnel_manager: &SshTunnelManager,
274    ) -> Result<MySqlConn, MySqlError> {
275        let mut opts_builder = OptsBuilder::from_opts(self.inner.clone());
276
277        if let Some(aws_config) = &self.aws_config {
278            let (host, port) = self.address();
279            let username = self.inner.user().expect("MySQL: username required");
280
281            let token = rds_auth_token(host, port, username, aws_config).await?;
282            // Cleartext plugin must be enabled for IAM authentication, for security,
283            // the network traffic is SSL/TLS encrypted.  The cleartext plugin is built
284            // into the MySQL client library.
285            opts_builder = opts_builder
286                .pass(Some(token.to_string()))
287                .enable_cleartext_plugin(true);
288        }
289
290        match &self.tunnel {
291            TunnelConfig::Direct { resolved_ips } => {
292                opts_builder = opts_builder.resolved_ips(
293                    resolved_ips
294                        .clone()
295                        .map(|ips| ips.into_iter().collect::<Vec<_>>()),
296                );
297
298                Ok(MySqlConn {
299                    conn: self.connect_with_timeout(opts_builder).await?,
300                    _ssh_tunnel_handle: None,
301                })
302            }
303            TunnelConfig::Ssh { config } => {
304                let (host, port) = self.address();
305                let tunnel = ssh_tunnel_manager
306                    .connect(
307                        config.clone(),
308                        host,
309                        port,
310                        self.ssh_timeout_config,
311                        self.in_task,
312                    )
313                    .await
314                    .map_err(MySqlError::Ssh)?;
315
316                let tunnel_addr = tunnel.local_addr();
317                // Override the connection host and port for the actual TCP connection to point to
318                // the local tunnel instead.
319                opts_builder = opts_builder
320                    .ip_or_hostname(tunnel_addr.ip().to_string())
321                    .tcp_port(tunnel_addr.port());
322
323                if let Some(ssl_opts) = self.inner.ssl_opts() {
324                    if !ssl_opts.skip_domain_validation() {
325                        // If the TLS configuration will validate the hostname, we need to set
326                        // the TLS hostname back to the actual upstream host and not the hostname
327                        // of the local SSH tunnel
328                        opts_builder = opts_builder.ssl_opts(Some(
329                            ssl_opts.clone().with_danger_tls_hostname_override(Some(
330                                self.inner.ip_or_hostname().to_string(),
331                            )),
332                        ));
333                    }
334                }
335
336                Ok(MySqlConn {
337                    conn: self.connect_with_timeout(opts_builder).await?,
338                    _ssh_tunnel_handle: Some(tunnel),
339                })
340            }
341            TunnelConfig::AwsPrivatelink { connection_id } => {
342                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
343
344                // Override the connection host for the actual TCP connection to point to
345                // the privatelink hostname instead.
346                let mut opts_builder = opts_builder.ip_or_hostname(privatelink_host);
347
348                if let Some(ssl_opts) = self.inner.ssl_opts() {
349                    if !ssl_opts.skip_domain_validation() {
350                        // If the TLS configuration will validate the hostname, we need to set
351                        // the TLS hostname back to the actual upstream host and not the
352                        // privatelink hostname.
353                        opts_builder = opts_builder.ssl_opts(Some(
354                            ssl_opts.clone().with_danger_tls_hostname_override(Some(
355                                self.inner.ip_or_hostname().to_string(),
356                            )),
357                        ));
358                    }
359                }
360
361                Ok(MySqlConn {
362                    conn: self.connect_with_timeout(opts_builder).await?,
363                    _ssh_tunnel_handle: None,
364                })
365            }
366        }
367    }
368
369    async fn connect_with_timeout(
370        &self,
371        opts_builder: OptsBuilder,
372    ) -> Result<mysql_async::Conn, MySqlError> {
373        let connection_future = if let InTask::Yes = self.in_task {
374            spawn(|| "mysql_connect".to_string(), Conn::new(opts_builder))
375                .abort_on_drop()
376                .wait_and_assert_finished()
377        } else {
378            Conn::new(opts_builder)
379        };
380
381        if let Some(connect_timeout) = self.mysql_timeout_config.connect_timeout {
382            mz_ore::future::timeout(connect_timeout, connection_future)
383                .await
384                .map_err(|err| match err {
385                    // match instead of impl From<> for MySqlError so we can capture the timeout value
386                    TimeoutError::DeadlineElapsed => MySqlError::ConnectionTimeout(connect_timeout),
387                    TimeoutError::Inner(e) => MySqlError::from(e),
388                })
389        } else {
390            connection_future.await.map_err(MySqlError::from)
391        }
392    }
393}