1use std::collections::BTreeSet;
11use std::fmt;
12use std::fs::{self, File};
13use std::io::Write;
14use std::net::{Ipv4Addr, SocketAddr};
15use std::os::unix::fs::PermissionsExt;
16use std::sync::atomic::{AtomicU16, Ordering};
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19
20use anyhow::bail;
21use itertools::Itertools;
22use mz_ore::error::ErrorExt;
23use mz_ore::task::{self, AbortOnDropHandle};
24use openssh::{ForwardType, Session};
25use rand::rngs::StdRng;
26use rand::{Rng, SeedableRng};
27use serde::{Deserialize, Serialize};
28use tokio::time;
29use tracing::{info, warn};
30
31use crate::keys::SshKeyPair;
32
/// Default period between health checks of an established SSH tunnel
/// (see [`SshTimeoutConfig::check_interval`]).
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(30);
/// Default timeout for each attempt to establish an SSH connection to a host
/// (see [`SshTimeoutConfig::connect_timeout`]).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Default server-alive (keepalive) interval handed to the SSH client
/// (see [`SshTimeoutConfig::keepalives_idle`]).
pub const DEFAULT_KEEPALIVES_IDLE: Duration = Duration::from_secs(10);
42
43#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
45pub struct SshTimeoutConfig {
46 pub check_interval: Duration,
48 pub connect_timeout: Duration,
50 pub keepalives_idle: Duration,
54}
55
56impl Default for SshTimeoutConfig {
57 fn default() -> SshTimeoutConfig {
58 SshTimeoutConfig {
59 check_interval: DEFAULT_CHECK_INTERVAL,
60 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
61 keepalives_idle: DEFAULT_KEEPALIVES_IDLE,
62 }
63 }
64}
65
66#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
68pub struct SshTunnelConfig {
69 pub host: BTreeSet<String>,
72 pub port: u16,
74 pub user: String,
76 pub key_pair: SshKeyPair,
78}
79
80impl fmt::Debug for SshTunnelConfig {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 f.debug_struct("Tunnel")
83 .field("host", &self.host)
84 .field("port", &self.port)
85 .field("user", &self.user)
86 .finish()
88 }
89}
90
91impl fmt::Display for SshTunnelConfig {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 write!(
94 f,
95 "{}@{}:{}",
96 self.user,
97 self.host.iter().join(","),
98 self.port
99 )
100 }
101}
102
/// Most recent health of an SSH tunnel, as recorded by its background task.
#[derive(Clone, Debug)]
pub enum SshTunnelStatus {
    /// The tunnel is believed to be up.
    Running,
    /// The last reconnection attempt failed; carries the error text
    /// (including its causes).
    Errored(String),
}
111
112impl SshTunnelConfig {
113 pub async fn connect(
119 &self,
120 remote_host: &str,
121 remote_port: u16,
122 timeout_config: SshTimeoutConfig,
123 ) -> Result<SshTunnelHandle, anyhow::Error> {
124 let tunnel_id = format!("{}:{} via {}", remote_host, remote_port, self);
125
126 info!(%tunnel_id, "connecting to ssh tunnel");
131 let mut session = match connect(self, timeout_config).await {
132 Ok(s) => s,
133 Err(e) => {
134 warn!(%tunnel_id, "failed to connect to ssh tunnel: {}", e.display_with_causes());
135 return Err(e);
136 }
137 };
138 let local_port = match port_forward(&session, remote_host, remote_port).await {
139 Ok(local_port) => local_port,
140 Err(e) => {
141 warn!(%tunnel_id, "failed to forward port through ssh tunnel: {}", e.display_with_causes());
142 return Err(e);
143 }
144 };
145 info!(%tunnel_id, %local_port, "connected to ssh tunnel");
146 let local_port = Arc::new(AtomicU16::new(local_port));
147 let status = Arc::new(Mutex::new(SshTunnelStatus::Running));
148
149 let join_handle = task::spawn(|| format!("ssh_session_{remote_host}:{remote_port}"), {
150 let config = self.clone();
151 let remote_host = remote_host.to_string();
152 let local_port = Arc::clone(&local_port);
153 let status = Arc::clone(&status);
154 async move {
155 scopeguard::defer! {
156 info!(%tunnel_id, "terminating ssh tunnel");
157 }
158 let mut interval = time::interval(timeout_config.check_interval);
159 interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
161 interval.tick().await;
163 loop {
164 interval.tick().await;
165 if let Err(e) = session.check().await {
166 warn!(%tunnel_id, "ssh tunnel unhealthy: {}", e.display_with_causes());
167 let s = match connect(&config, timeout_config).await {
168 Ok(s) => s,
169 Err(e) => {
170 warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
171 *status.lock().expect("poisoned") =
172 SshTunnelStatus::Errored(e.to_string_with_causes());
173 continue;
174 }
175 };
176 let lp = match port_forward(&s, &remote_host, remote_port).await {
177 Ok(lp) => lp,
178 Err(e) => {
179 warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
180 *status.lock().expect("poisoned") =
181 SshTunnelStatus::Errored(e.to_string_with_causes());
182 continue;
183 }
184 };
185 session = s;
186 local_port.store(lp, Ordering::SeqCst);
187 *status.lock().expect("poisoned") = SshTunnelStatus::Running;
188 }
189 }
190 }
191 });
192
193 Ok(SshTunnelHandle {
194 local_port,
195 status,
196 _join_handle: join_handle.abort_on_drop(),
197 })
198 }
199
200 pub async fn validate(&self, timeout_config: SshTimeoutConfig) -> Result<(), anyhow::Error> {
203 connect(self, timeout_config).await?;
204 Ok(())
205 }
206}
207
208#[derive(Debug)]
210pub struct SshTunnelHandle {
211 local_port: Arc<AtomicU16>,
212 status: Arc<Mutex<SshTunnelStatus>>,
213 _join_handle: AbortOnDropHandle<()>,
214}
215
216impl SshTunnelHandle {
217 pub fn local_addr(&self) -> SocketAddr {
219 let port = self.local_port.load(Ordering::SeqCst);
220 SocketAddr::from((Ipv4Addr::LOCALHOST, port))
224 }
225
226 pub fn check_status(&self) -> SshTunnelStatus {
231 self.status.lock().expect("poisoned").clone()
232 }
233}
234
235async fn connect(
236 config: &SshTunnelConfig,
237 timeout_config: SshTimeoutConfig,
238) -> Result<Session, anyhow::Error> {
239 let tempdir = tempfile::Builder::new()
240 .prefix("ssh-tunnel-key")
241 .tempdir()?;
242 let path = tempdir.path().join("key");
243 let mut tempfile = File::create(&path)?;
244 tempfile.set_permissions(std::fs::Permissions::from_mode(0o600))?;
246 tempfile.write_all(config.key_pair.ssh_private_key().as_bytes())?;
247 tempfile.set_permissions(std::fs::Permissions::from_mode(0o400))?;
250
251 let mut connect_err = None;
253 for host in &config.host {
254 match openssh::SessionBuilder::default()
259 .known_hosts_check(openssh::KnownHosts::Accept)
260 .user_known_hosts_file("/dev/null")
261 .user(config.user.clone())
262 .port(config.port)
263 .keyfile(&path)
264 .server_alive_interval(timeout_config.keepalives_idle)
265 .connect_timeout(timeout_config.connect_timeout)
266 .connect_mux(host.clone())
267 .await
268 {
269 Ok(session) => {
270 drop(tempfile);
273 fs::remove_file(&path)?;
274 drop(tempdir);
275
276 session.check().await?;
278
279 return Ok(session);
280 }
281 Err(err) => {
282 connect_err = Some(err);
283 }
284 }
285 }
286 Err(connect_err
287 .map(Into::into)
288 .unwrap_or_else(|| anyhow::anyhow!("no hosts to connect to")))
289}
290
291async fn port_forward(session: &Session, host: &str, port: u16) -> Result<u16, anyhow::Error> {
292 for _ in 0..50 {
294 let mut rng = StdRng::from_entropy();
296 let local_port: u16 = rng.gen_range(49152..65535);
297
298 let local = openssh::Socket::from((Ipv4Addr::LOCALHOST, local_port));
302 let remote = openssh::Socket::new(host, port);
303
304 match session
305 .request_port_forward(ForwardType::Local, local, remote)
306 .await
307 {
308 Ok(_) => return Ok(local_port),
309 Err(err) => match err {
310 openssh::Error::SshMux(openssh_mux_client::Error::RequestFailure(e))
311 if &*e == "Port forwarding failed" =>
312 {
313 info!("port {local_port} already in use; testing another port");
314 }
315 _ => {
316 warn!("ssh connection failed: {}", err.display_with_causes());
317 bail!("failed to open SSH tunnel: {}", err.display_with_causes())
318 }
319 },
320 };
321 }
322 bail!("failed to find an open port for SSH tunnel")
325}