1use std::collections::BTreeSet;
11use std::fmt;
12use std::fs::{self, File};
13use std::io::Write;
14use std::net::{Ipv4Addr, SocketAddr};
15use std::os::unix::fs::PermissionsExt;
16use std::sync::atomic::{AtomicU16, Ordering};
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19
20use anyhow::bail;
21use itertools::Itertools;
22use mz_ore::error::ErrorExt;
23use mz_ore::task::{self, AbortOnDropHandle};
24use openssh::{ForwardType, Session};
25use rand::rngs::StdRng;
26use rand::{Rng, SeedableRng};
27use serde::{Deserialize, Serialize};
28use tokio::time;
29use tracing::{info, warn};
30
31use crate::keys::SshKeyPair;
32
33pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(30);
37pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
38
39pub const DEFAULT_KEEPALIVES_IDLE: Duration = Duration::from_secs(10);
42
43#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
45pub struct SshTimeoutConfig {
46 pub check_interval: Duration,
48 pub connect_timeout: Duration,
50 pub keepalives_idle: Duration,
54}
55
56impl Default for SshTimeoutConfig {
57 fn default() -> SshTimeoutConfig {
58 SshTimeoutConfig {
59 check_interval: DEFAULT_CHECK_INTERVAL,
60 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
61 keepalives_idle: DEFAULT_KEEPALIVES_IDLE,
62 }
63 }
64}
65
66#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
68pub struct SshTunnelConfig {
69 pub host: BTreeSet<String>,
72 pub port: u16,
74 pub user: String,
76 pub key_pair: SshKeyPair,
78}
79
80impl fmt::Debug for SshTunnelConfig {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 f.debug_struct("Tunnel")
83 .field("host", &self.host)
84 .field("port", &self.port)
85 .field("user", &self.user)
86 .finish()
88 }
89}
90
91impl fmt::Display for SshTunnelConfig {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 write!(
94 f,
95 "{}@{}:{}",
96 self.user,
97 self.host.iter().join(","),
98 self.port
99 )
100 }
101}
102
/// Health of an SSH tunnel as last recorded by its background task.
#[derive(Debug, Clone)]
pub enum SshTunnelStatus {
    /// The tunnel is up: it was established, or most recently re-established, successfully.
    Running,
    /// The most recent attempt to re-establish the tunnel failed.
    ///
    /// Holds the error message rendered together with its causes.
    Errored(String),
}
111
112impl SshTunnelConfig {
113 pub async fn connect(
119 &self,
120 remote_host: &str,
121 remote_port: u16,
122 timeout_config: SshTimeoutConfig,
123 ) -> Result<SshTunnelHandle, anyhow::Error> {
124 let tunnel_id = format!("{}:{} via {}", remote_host, remote_port, self);
125
126 info!(%tunnel_id, "connecting to ssh tunnel");
131 let mut session = match connect(self, timeout_config).await {
132 Ok(s) => s,
133 Err(e) => {
134 warn!(%tunnel_id, "failed to connect to ssh tunnel: {}", e.display_with_causes());
135 return Err(e);
136 }
137 };
138 let local_port = match port_forward(&session, remote_host, remote_port).await {
139 Ok(local_port) => local_port,
140 Err(e) => {
141 warn!(%tunnel_id, "failed to forward port through ssh tunnel: {}", e.display_with_causes());
142 return Err(e);
143 }
144 };
145 info!(%tunnel_id, %local_port, "connected to ssh tunnel");
146 let local_port = Arc::new(AtomicU16::new(local_port));
147 let status = Arc::new(Mutex::new(SshTunnelStatus::Running));
148
149 let join_handle = task::spawn(|| format!("ssh_session_{remote_host}:{remote_port}"), {
150 let config = self.clone();
151 let remote_host = remote_host.to_string();
152 let local_port = Arc::clone(&local_port);
153 let status = Arc::clone(&status);
154 async move {
155 scopeguard::defer! {
156 info!(%tunnel_id, "terminating ssh tunnel");
157 }
158 let mut interval = time::interval(timeout_config.check_interval);
159 interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
161 interval.tick().await;
163 loop {
164 interval.tick().await;
165 if let Err(e) = session.check().await {
166 warn!(%tunnel_id, "ssh tunnel unhealthy: {}", e.display_with_causes());
167 let s = match connect(&config, timeout_config).await {
168 Ok(s) => s,
169 Err(e) => {
170 warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
171 *status.lock().expect("poisoned") =
172 SshTunnelStatus::Errored(e.to_string_with_causes());
173 continue;
174 }
175 };
176 let lp = match port_forward(&s, &remote_host, remote_port).await {
177 Ok(lp) => lp,
178 Err(e) => {
179 warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
180 *status.lock().expect("poisoned") =
181 SshTunnelStatus::Errored(e.to_string_with_causes());
182 continue;
183 }
184 };
185 session = s;
186 local_port.store(lp, Ordering::SeqCst);
187 *status.lock().expect("poisoned") = SshTunnelStatus::Running;
188 }
189 }
190 }
191 });
192
193 Ok(SshTunnelHandle {
194 local_port,
195 status,
196 _join_handle: join_handle.abort_on_drop(),
197 })
198 }
199
200 pub async fn validate(&self, timeout_config: SshTimeoutConfig) -> Result<(), anyhow::Error> {
203 connect(self, timeout_config).await?;
204 Ok(())
205 }
206}
207
208#[derive(Debug)]
210pub struct SshTunnelHandle {
211 local_port: Arc<AtomicU16>,
212 status: Arc<Mutex<SshTunnelStatus>>,
213 _join_handle: AbortOnDropHandle<()>,
214}
215
216impl SshTunnelHandle {
217 pub fn local_addr(&self) -> SocketAddr {
219 let port = self.local_port.load(Ordering::SeqCst);
220 SocketAddr::from((Ipv4Addr::LOCALHOST, port))
224 }
225
226 pub fn check_status(&self) -> SshTunnelStatus {
231 self.status.lock().expect("poisoned").clone()
232 }
233}
234
/// Reports whether FIPS mode is requested through the `MZ_FIPS` environment variable.
///
/// Only the exact values `1` and `true` enable it. Any other value, an unset
/// variable, or a value that is not valid Unicode leaves FIPS mode off.
fn fips_mode_enabled() -> bool {
    matches!(std::env::var("MZ_FIPS").as_deref(), Ok("1" | "true"))
}
239
240fn write_fips_ssh_config(dir: &std::path::Path) -> Result<std::path::PathBuf, anyhow::Error> {
243 let config_path = dir.join("ssh_config");
244 let config_contents = "\
245# FIPS 140-3 compliant SSH configuration.
246# Only NIST-approved algorithms are permitted.
247Ciphers aes256-gcm@openssh.com,aes128-gcm@openssh.com,aes256-ctr,aes192-ctr,aes128-ctr
248KexAlgorithms ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512,diffie-hellman-group14-sha256
249MACs hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha2-256,hmac-sha2-512
250HostKeyAlgorithms ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-256,rsa-sha2-512
251PubkeyAcceptedAlgorithms ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-256,rsa-sha2-512,ssh-ed25519
252";
253 fs::write(&config_path, config_contents)?;
254 Ok(config_path)
255}
256
257async fn connect(
258 config: &SshTunnelConfig,
259 timeout_config: SshTimeoutConfig,
260) -> Result<Session, anyhow::Error> {
261 let tempdir = tempfile::Builder::new()
262 .prefix("ssh-tunnel-key")
263 .tempdir()?;
264 let path = tempdir.path().join("key");
265 let mut tempfile = File::create(&path)?;
266 tempfile.set_permissions(std::fs::Permissions::from_mode(0o600))?;
268 tempfile.write_all(config.key_pair.ssh_private_key().as_bytes())?;
269 tempfile.set_permissions(std::fs::Permissions::from_mode(0o400))?;
272
273 let fips_config_path = if fips_mode_enabled() {
276 Some(write_fips_ssh_config(tempdir.path())?)
277 } else {
278 None
279 };
280
281 let mut connect_err = None;
283 for host in &config.host {
284 let mut builder = openssh::SessionBuilder::default();
289 builder
290 .known_hosts_check(openssh::KnownHosts::Accept)
291 .user_known_hosts_file("/dev/null")
292 .user(config.user.clone())
293 .port(config.port)
294 .keyfile(&path)
295 .server_alive_interval(timeout_config.keepalives_idle)
296 .connect_timeout(timeout_config.connect_timeout);
297
298 if let Some(ref fips_config) = fips_config_path {
299 builder.config_file(fips_config);
300 }
301
302 match builder.connect_mux(host.clone()).await {
303 Ok(session) => {
304 drop(tempfile);
307 fs::remove_file(&path)?;
308 drop(tempdir);
309
310 session.check().await?;
312
313 return Ok(session);
314 }
315 Err(err) => {
316 connect_err = Some(err);
317 }
318 }
319 }
320 Err(connect_err
321 .map(Into::into)
322 .unwrap_or_else(|| anyhow::anyhow!("no hosts to connect to")))
323}
324
325async fn port_forward(session: &Session, host: &str, port: u16) -> Result<u16, anyhow::Error> {
326 for _ in 0..50 {
328 let mut rng = StdRng::from_os_rng();
330 let local_port: u16 = rng.random_range(49152..65535);
331
332 let local = openssh::Socket::from((Ipv4Addr::LOCALHOST, local_port));
336 let remote = openssh::Socket::new(host, port);
337
338 match session
339 .request_port_forward(ForwardType::Local, local, remote)
340 .await
341 {
342 Ok(_) => return Ok(local_port),
343 Err(err) => match err {
344 openssh::Error::SshMux(openssh_mux_client::Error::RequestFailure(e))
345 if &*e == "Port forwarding failed" =>
346 {
347 info!("port {local_port} already in use; testing another port");
348 }
349 _ => {
350 warn!("ssh connection failed: {}", err.display_with_causes());
351 bail!("failed to open SSH tunnel: {}", err.display_with_causes())
352 }
353 },
354 };
355 }
356 bail!("failed to find an open port for SSH tunnel")
359}