1use std::collections::BTreeSet;
11use std::fmt;
12use std::fs::{self, File};
13use std::io::Write;
14use std::net::{Ipv4Addr, SocketAddr};
15use std::os::unix::fs::PermissionsExt;
16use std::sync::atomic::{AtomicU16, Ordering};
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19
20use anyhow::bail;
21use itertools::Itertools;
22use mz_ore::error::ErrorExt;
23use mz_ore::task::{self, AbortOnDropHandle};
24use openssh::{ForwardType, Session};
25use rand::rngs::StdRng;
26use rand::{Rng, SeedableRng};
27use serde::{Deserialize, Serialize};
28use tokio::time;
29use tracing::{info, warn};
30
31use crate::keys::SshKeyPair;
32
/// Default for [`SshTimeoutConfig::check_interval`]: how often an open tunnel's session is
/// health-checked.
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::new(30, 0);
/// Default for [`SshTimeoutConfig::connect_timeout`]: the timeout given to the SSH connection
/// attempt.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::new(30, 0);

/// Default for [`SshTimeoutConfig::keepalives_idle`]: used as the session's server-alive
/// interval.
pub const DEFAULT_KEEPALIVES_IDLE: Duration = Duration::new(10, 0);
42
/// Timeouts and intervals that govern how an SSH tunnel is established and monitored.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SshTimeoutConfig {
    /// Period of the background health check performed on an established tunnel's session.
    pub check_interval: Duration,
    /// Timeout passed to the SSH connection attempt for each host.
    pub connect_timeout: Duration,
    /// Passed to the SSH session as its server-alive (keepalive) interval.
    pub keepalives_idle: Duration,
}
55
56impl Default for SshTimeoutConfig {
57    fn default() -> SshTimeoutConfig {
58        SshTimeoutConfig {
59            check_interval: DEFAULT_CHECK_INTERVAL,
60            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
61            keepalives_idle: DEFAULT_KEEPALIVES_IDLE,
62        }
63    }
64}
65
/// Describes how to reach an SSH server: candidate hosts, port, login user, and the key pair
/// used to authenticate.
///
/// The `Debug` impl omits `key_pair`; the `Display` impl renders `user@host1,host2:port`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SshTunnelConfig {
    /// Candidate SSH hosts. Connection attempts walk this set in its (sorted) iteration order and
    /// stop at the first host whose connection succeeds.
    pub host: BTreeSet<String>,
    /// Port of the SSH server; the same port is used for every host.
    pub port: u16,
    /// User to log in as.
    pub user: String,
    /// Key pair whose private key is written to a temporary file and used to authenticate.
    pub key_pair: SshKeyPair,
}
79
80impl fmt::Debug for SshTunnelConfig {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.debug_struct("Tunnel")
83            .field("host", &self.host)
84            .field("port", &self.port)
85            .field("user", &self.user)
86            .finish()
88    }
89}
90
91impl fmt::Display for SshTunnelConfig {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(
94            f,
95            "{}@{}:{}",
96            self.user,
97            self.host.iter().join(","),
98            self.port
99        )
100    }
101}
102
/// Health of an SSH tunnel, as published by its background monitoring task.
#[derive(Clone, Debug)]
pub enum SshTunnelStatus {
    /// The tunnel is up. This is the initial status and is restored after a successful
    /// reconnection.
    Running,
    /// The latest reconnection attempt failed. The payload is the error rendered with its causes.
    Errored(String),
}
111
impl SshTunnelConfig {
    /// Opens an SSH tunnel that forwards a local port to `remote_host:remote_port`.
    ///
    /// The function connects to one of the configured SSH hosts and requests a local port
    /// forward on a randomly chosen `127.0.0.1` port. It then spawns a background task that runs
    /// `session.check()` every `timeout_config.check_interval`. If the check fails, the task
    /// reconnects and re-forwards.
    ///
    /// A reconnection can land on a different local port. Callers should re-read
    /// [`SshTunnelHandle::local_addr`] instead of caching it. Reconnection outcomes are published
    /// through [`SshTunnelHandle::check_status`].
    ///
    /// The returned handle owns the task through an `AbortOnDropHandle`. Dropping the handle is
    /// expected to abort the task; confirm against `mz_ore::task`.
    ///
    /// # Errors
    ///
    /// Returns an error (after logging it) if the initial connection or the initial port
    /// forward fails.
    pub async fn connect(
        &self,
        remote_host: &str,
        remote_port: u16,
        timeout_config: SshTimeoutConfig,
    ) -> Result<SshTunnelHandle, anyhow::Error> {
        let tunnel_id = format!("{}:{} via {}", remote_host, remote_port, self);

        info!(%tunnel_id, "connecting to ssh tunnel");
        let mut session = match connect(self, timeout_config).await {
            Ok(s) => s,
            Err(e) => {
                warn!(%tunnel_id, "failed to connect to ssh tunnel: {}", e.display_with_causes());
                return Err(e);
            }
        };
        let local_port = match port_forward(&session, remote_host, remote_port).await {
            Ok(local_port) => local_port,
            Err(e) => {
                warn!(%tunnel_id, "failed to forward port through ssh tunnel: {}", e.display_with_causes());
                return Err(e);
            }
        };
        info!(%tunnel_id, %local_port, "connected to ssh tunnel");
        // Shared with the background task, which stores the new port after a reconnection.
        let local_port = Arc::new(AtomicU16::new(local_port));
        // Shared with the background task, which records reconnection success or failure here.
        let status = Arc::new(Mutex::new(SshTunnelStatus::Running));

        let join_handle = task::spawn(|| format!("ssh_session_{remote_host}:{remote_port}"), {
            let config = self.clone();
            let remote_host = remote_host.to_string();
            let local_port = Arc::clone(&local_port);
            let status = Arc::clone(&status);
            async move {
                // Runs when this future finishes or is dropped (e.g. when the task is aborted).
                scopeguard::defer! {
                    info!(%tunnel_id, "terminating ssh tunnel");
                }
                let mut interval = time::interval(timeout_config.check_interval);
                // If a check/reconnect overruns the period, skip the missed ticks rather than
                // firing them back-to-back.
                interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
                // The first tick of a tokio interval completes immediately. Consume it so the
                // first health check happens one full period after connecting.
                interval.tick().await;
                loop {
                    interval.tick().await;
                    if let Err(e) = session.check().await {
                        warn!(%tunnel_id, "ssh tunnel unhealthy: {}", e.display_with_causes());
                        let s = match connect(&config, timeout_config).await {
                            Ok(s) => s,
                            Err(e) => {
                                warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
                                *status.lock().expect("poisoned") =
                                    SshTunnelStatus::Errored(e.to_string_with_causes());
                                // Keep the old session. The next tick re-checks it and retries.
                                continue;
                            }
                        };
                        let lp = match port_forward(&s, &remote_host, remote_port).await {
                            Ok(lp) => lp,
                            Err(e) => {
                                warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
                                *status.lock().expect("poisoned") =
                                    SshTunnelStatus::Errored(e.to_string_with_causes());
                                // The new session `s` is dropped here; the old one is kept.
                                continue;
                            }
                        };
                        // Swap in the new session (dropping the old one), then publish its port
                        // and the healthy status.
                        session = s;
                        local_port.store(lp, Ordering::SeqCst);
                        *status.lock().expect("poisoned") = SshTunnelStatus::Running;
                    }
                }
            }
        });

        Ok(SshTunnelHandle {
            local_port,
            status,
            _join_handle: join_handle.abort_on_drop(),
        })
    }

    /// Checks that this config can open an SSH session by connecting once.
    ///
    /// The resulting session is dropped immediately, and no port forward is requested.
    ///
    /// # Errors
    ///
    /// Returns the connection error, if any.
    pub async fn validate(&self, timeout_config: SshTimeoutConfig) -> Result<(), anyhow::Error> {
        connect(self, timeout_config).await?;
        Ok(())
    }
}
207
/// Handle to a live SSH tunnel, returned by [`SshTunnelConfig::connect`].
#[derive(Debug)]
pub struct SshTunnelHandle {
    /// Local port currently forwarded through the tunnel. The background task overwrites it
    /// after a successful reconnection.
    local_port: Arc<AtomicU16>,
    /// Most recent status published by the background task.
    status: Arc<Mutex<SshTunnelStatus>>,
    /// Owns the background monitoring task. It is never read; it is held only for its drop
    /// behavior.
    _join_handle: AbortOnDropHandle<()>,
}
215
216impl SshTunnelHandle {
217    pub fn local_addr(&self) -> SocketAddr {
219        let port = self.local_port.load(Ordering::SeqCst);
220        SocketAddr::from((Ipv4Addr::LOCALHOST, port))
224    }
225
226    pub fn check_status(&self) -> SshTunnelStatus {
231        self.status.lock().expect("poisoned").clone()
232    }
233}
234
235async fn connect(
236    config: &SshTunnelConfig,
237    timeout_config: SshTimeoutConfig,
238) -> Result<Session, anyhow::Error> {
239    let tempdir = tempfile::Builder::new()
240        .prefix("ssh-tunnel-key")
241        .tempdir()?;
242    let path = tempdir.path().join("key");
243    let mut tempfile = File::create(&path)?;
244    tempfile.set_permissions(std::fs::Permissions::from_mode(0o600))?;
246    tempfile.write_all(config.key_pair.ssh_private_key().as_bytes())?;
247    tempfile.set_permissions(std::fs::Permissions::from_mode(0o400))?;
250
251    let mut connect_err = None;
253    for host in &config.host {
254        match openssh::SessionBuilder::default()
259            .known_hosts_check(openssh::KnownHosts::Accept)
260            .user_known_hosts_file("/dev/null")
261            .user(config.user.clone())
262            .port(config.port)
263            .keyfile(&path)
264            .server_alive_interval(timeout_config.keepalives_idle)
265            .connect_timeout(timeout_config.connect_timeout)
266            .connect_mux(host.clone())
267            .await
268        {
269            Ok(session) => {
270                drop(tempfile);
273                fs::remove_file(&path)?;
274                drop(tempdir);
275
276                session.check().await?;
278
279                return Ok(session);
280            }
281            Err(err) => {
282                connect_err = Some(err);
283            }
284        }
285    }
286    Err(connect_err
287        .map(Into::into)
288        .unwrap_or_else(|| anyhow::anyhow!("no hosts to connect to")))
289}
290
291async fn port_forward(session: &Session, host: &str, port: u16) -> Result<u16, anyhow::Error> {
292    for _ in 0..50 {
294        let mut rng = StdRng::from_entropy();
296        let local_port: u16 = rng.gen_range(49152..65535);
297
298        let local = openssh::Socket::from((Ipv4Addr::LOCALHOST, local_port));
302        let remote = openssh::Socket::new(host, port);
303
304        match session
305            .request_port_forward(ForwardType::Local, local, remote)
306            .await
307        {
308            Ok(_) => return Ok(local_port),
309            Err(err) => match err {
310                openssh::Error::SshMux(openssh_mux_client::Error::RequestFailure(e))
311                    if &*e == "Port forwarding failed" =>
312                {
313                    info!("port {local_port} already in use; testing another port");
314                }
315                _ => {
316                    warn!("ssh connection failed: {}", err.display_with_causes());
317                    bail!("failed to open SSH tunnel: {}", err.display_with_causes())
318                }
319            },
320        };
321    }
322    bail!("failed to find an open port for SSH tunnel")
325}