1use aws_types::SdkConfig;
11use mysql_async::{Conn, Opts, OptsBuilder};
12use std::collections::BTreeSet;
13use std::net::IpAddr;
14use std::ops::{Deref, DerefMut};
15use std::time::Duration;
16
17use mz_ore::future::{InTask, TimeoutError};
18use mz_ore::option::OptionExt;
19use mz_ore::task::{JoinHandleExt, spawn};
20use mz_repr::CatalogItemId;
21use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
22use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
23use serde::{Deserialize, Serialize};
24use tracing::{error, info, warn};
25
26use crate::MySqlError;
27use crate::aws_rds::rds_auth_token;
28
29#[derive(Debug, PartialEq, Clone)]
32pub enum TunnelConfig {
33 Direct {
37 resolved_ips: Option<BTreeSet<IpAddr>>,
38 },
39 Ssh { config: SshTunnelConfig },
45 AwsPrivatelink {
48 connection_id: CatalogItemId,
50 },
51}
52
/// Default TCP keepalive setting handed to the MySQL client: one minute.
pub const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_millis(60_000);
/// Default snapshot max execution time: zero. If this is applied as MySQL's
/// `max_execution_time`, zero disables the limit (assumption — confirm at the use site).
pub const DEFAULT_SNAPSHOT_MAX_EXECUTION_TIME: Duration = Duration::from_secs(0);
/// Default snapshot lock wait timeout: one hour.
pub const DEFAULT_SNAPSHOT_LOCK_WAIT_TIMEOUT: Duration = Duration::from_secs(60 * 60);
/// Default upper bound on establishing a connection: one minute.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(60_000);
57
58#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
59pub struct TimeoutConfig {
60 pub snapshot_max_execution_time: Option<Duration>,
62 pub snapshot_lock_wait_timeout: Option<Duration>,
63
64 pub tcp_keepalive: Option<Duration>,
66
67 pub connect_timeout: Option<Duration>,
71 }
75
76impl Default for TimeoutConfig {
77 fn default() -> Self {
78 Self {
79 snapshot_max_execution_time: Some(DEFAULT_SNAPSHOT_MAX_EXECUTION_TIME),
80 snapshot_lock_wait_timeout: Some(DEFAULT_SNAPSHOT_LOCK_WAIT_TIMEOUT),
81 tcp_keepalive: Some(DEFAULT_TCP_KEEPALIVE),
82 connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
83 }
84 }
85}
86
87impl TimeoutConfig {
88 pub fn build(
89 snapshot_max_execution_time: Duration,
90 snapshot_lock_wait_timeout: Duration,
91 tcp_keepalive: Duration,
92 connect_timeout: Duration,
93 ) -> Self {
94 let snapshot_lock_wait_timeout = if snapshot_lock_wait_timeout.as_secs() > 31536000 {
100 error!(
101 "snapshot_lock_wait_timeout is too large: {}. Maximum is 31536000.",
102 snapshot_lock_wait_timeout.as_secs()
103 );
104 Some(DEFAULT_SNAPSHOT_LOCK_WAIT_TIMEOUT)
105 } else {
106 Some(snapshot_lock_wait_timeout)
107 };
108
109 let snapshot_max_execution_time = if snapshot_max_execution_time.as_millis() > 4294967295 {
111 error!(
112 "snapshot_max_execution_time is too large: {}. Maximum is 4294967295.",
113 snapshot_max_execution_time.as_secs()
114 );
115 Some(DEFAULT_SNAPSHOT_MAX_EXECUTION_TIME)
116 } else {
117 Some(snapshot_max_execution_time)
118 };
119
120 let tcp_keepalive = match u32::try_from(tcp_keepalive.as_millis()) {
121 Err(_) => {
122 error!(
123 "tcp_keepalive is too large: {}. Maximum is {}.",
124 tcp_keepalive.as_millis(),
125 u32::MAX,
126 );
127 Some(DEFAULT_TCP_KEEPALIVE)
128 }
129 Ok(_) => Some(tcp_keepalive),
130 };
131
132 let connect_timeout = match u32::try_from(connect_timeout.as_millis()) {
133 Err(_) => {
134 error!(
135 "connect_timeout is too large: {}. Maximum is {}.",
136 connect_timeout.as_millis(),
137 u32::MAX,
138 );
139 Some(DEFAULT_CONNECT_TIMEOUT)
140 }
141 Ok(_) => Some(connect_timeout),
142 };
143
144 Self {
145 snapshot_max_execution_time,
146 snapshot_lock_wait_timeout,
147 tcp_keepalive,
148 connect_timeout,
149 }
150 }
151
152 pub fn apply_to_opts(&self, mut opts_builder: OptsBuilder) -> Result<OptsBuilder, MySqlError> {
154 if let Some(tcp_keepalive) = self.tcp_keepalive {
155 opts_builder = opts_builder.tcp_keepalive(Some(
156 u32::try_from(tcp_keepalive.as_millis()).map_err(|e| {
157 MySqlError::InvalidClientConfig(format!(
158 "invalid tcp_keepalive duration: {}",
159 e
160 ))
161 })?,
162 ));
163 }
164 Ok(opts_builder)
165 }
166}
167
168#[derive(Debug)]
174pub struct MySqlConn {
175 conn: Conn,
176 _ssh_tunnel_handle: Option<ManagedSshTunnelHandle>,
177}
178
179impl Deref for MySqlConn {
180 type Target = Conn;
181
182 fn deref(&self) -> &Self::Target {
183 &self.conn
184 }
185}
186
187impl DerefMut for MySqlConn {
188 fn deref_mut(&mut self) -> &mut Self::Target {
189 &mut self.conn
190 }
191}
192
193impl MySqlConn {
194 pub async fn disconnect(mut self) -> Result<(), MySqlError> {
195 self.conn.disconnect().await?;
196 self._ssh_tunnel_handle.take();
197 Ok(())
198 }
199
200 pub fn take(self) -> (Conn, Option<ManagedSshTunnelHandle>) {
201 (self.conn, self._ssh_tunnel_handle)
202 }
203}
204
205#[derive(Clone, Debug)]
210pub struct Config {
211 inner: Opts,
212 tunnel: TunnelConfig,
213 in_task: InTask,
217 ssh_timeout_config: SshTimeoutConfig,
218 mysql_timeout_config: TimeoutConfig,
219 aws_config: Option<SdkConfig>,
220}
221
222impl Config {
223 pub fn new(
224 builder: OptsBuilder,
225 tunnel: TunnelConfig,
226 ssh_timeout_config: SshTimeoutConfig,
227 in_task: InTask,
228 mysql_timeout_config: TimeoutConfig,
229 aws_config: Option<SdkConfig>,
230 ) -> Result<Self, MySqlError> {
231 let opts = mysql_timeout_config.apply_to_opts(builder)?;
232 Ok(Self {
233 inner: opts.into(),
234 tunnel,
235 in_task,
236 ssh_timeout_config,
237 mysql_timeout_config,
238 aws_config,
239 })
240 }
241
242 pub async fn connect(
243 &self,
244 task_name: &str,
245 ssh_tunnel_manager: &SshTunnelManager,
246 ) -> Result<MySqlConn, MySqlError> {
247 let address = format!(
248 "mysql:://{}@{}:{}/{}",
249 self.inner.user().display_or("<unknown-user>"),
250 self.inner.ip_or_hostname(),
251 self.inner.tcp_port(),
252 self.inner.db_name().display_or("<unknown-dbname>"),
253 );
254 info!(%task_name, %address, "connecting");
255 match self.connect_internal(ssh_tunnel_manager).await {
256 Ok(t) => {
257 info!(%task_name, %address, "connected");
258 Ok(t)
259 }
260 Err(e) => {
261 warn!(%task_name, %address, "connection failed: {e:#}");
262 Err(e)
263 }
264 }
265 }
266
267 fn address(&self) -> (&str, u16) {
268 (self.inner.ip_or_hostname(), self.inner.tcp_port())
269 }
270
271 async fn connect_internal(
272 &self,
273 ssh_tunnel_manager: &SshTunnelManager,
274 ) -> Result<MySqlConn, MySqlError> {
275 let mut opts_builder = OptsBuilder::from_opts(self.inner.clone());
276
277 if let Some(aws_config) = &self.aws_config {
278 let (host, port) = self.address();
279 let username = self.inner.user().expect("MySQL: username required");
280
281 let token = rds_auth_token(host, port, username, aws_config).await?;
282 opts_builder = opts_builder
286 .pass(Some(token.to_string()))
287 .enable_cleartext_plugin(true);
288 }
289
290 match &self.tunnel {
291 TunnelConfig::Direct { resolved_ips } => {
292 opts_builder = opts_builder.resolved_ips(
293 resolved_ips
294 .clone()
295 .map(|ips| ips.into_iter().collect::<Vec<_>>()),
296 );
297
298 Ok(MySqlConn {
299 conn: self.connect_with_timeout(opts_builder).await?,
300 _ssh_tunnel_handle: None,
301 })
302 }
303 TunnelConfig::Ssh { config } => {
304 let (host, port) = self.address();
305 let tunnel = ssh_tunnel_manager
306 .connect(
307 config.clone(),
308 host,
309 port,
310 self.ssh_timeout_config,
311 self.in_task,
312 )
313 .await
314 .map_err(MySqlError::Ssh)?;
315
316 let tunnel_addr = tunnel.local_addr();
317 opts_builder = opts_builder
320 .ip_or_hostname(tunnel_addr.ip().to_string())
321 .tcp_port(tunnel_addr.port());
322
323 if let Some(ssl_opts) = self.inner.ssl_opts() {
324 if !ssl_opts.skip_domain_validation() {
325 opts_builder = opts_builder.ssl_opts(Some(
329 ssl_opts.clone().with_danger_tls_hostname_override(Some(
330 self.inner.ip_or_hostname().to_string(),
331 )),
332 ));
333 }
334 }
335
336 Ok(MySqlConn {
337 conn: self.connect_with_timeout(opts_builder).await?,
338 _ssh_tunnel_handle: Some(tunnel),
339 })
340 }
341 TunnelConfig::AwsPrivatelink { connection_id } => {
342 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
343
344 let mut opts_builder = opts_builder.ip_or_hostname(privatelink_host);
347
348 if let Some(ssl_opts) = self.inner.ssl_opts() {
349 if !ssl_opts.skip_domain_validation() {
350 opts_builder = opts_builder.ssl_opts(Some(
354 ssl_opts.clone().with_danger_tls_hostname_override(Some(
355 self.inner.ip_or_hostname().to_string(),
356 )),
357 ));
358 }
359 }
360
361 Ok(MySqlConn {
362 conn: self.connect_with_timeout(opts_builder).await?,
363 _ssh_tunnel_handle: None,
364 })
365 }
366 }
367 }
368
369 async fn connect_with_timeout(
370 &self,
371 opts_builder: OptsBuilder,
372 ) -> Result<mysql_async::Conn, MySqlError> {
373 let connection_future = if let InTask::Yes = self.in_task {
374 spawn(|| "mysql_connect".to_string(), Conn::new(opts_builder))
375 .abort_on_drop()
376 .wait_and_assert_finished()
377 } else {
378 Conn::new(opts_builder)
379 };
380
381 if let Some(connect_timeout) = self.mysql_timeout_config.connect_timeout {
382 mz_ore::future::timeout(connect_timeout, connection_future)
383 .await
384 .map_err(|err| match err {
385 TimeoutError::DeadlineElapsed => MySqlError::ConnectionTimeout(connect_timeout),
387 TimeoutError::Inner(e) => MySqlError::from(e),
388 })
389 } else {
390 connection_future.await.map_err(MySqlError::from)
391 }
392 }
393}