mz_ssh_util/
tunnel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::fmt;
12use std::fs::{self, File};
13use std::io::Write;
14use std::net::{Ipv4Addr, SocketAddr};
15use std::os::unix::fs::PermissionsExt;
16use std::sync::atomic::{AtomicU16, Ordering};
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19
20use anyhow::bail;
21use itertools::Itertools;
22use mz_ore::error::ErrorExt;
23use mz_ore::task::{self, AbortOnDropHandle};
24use openssh::{ForwardType, Session};
25use rand::rngs::StdRng;
26use rand::{Rng, SeedableRng};
27use serde::{Deserialize, Serialize};
28use tokio::time;
29use tracing::{info, warn};
30
31use crate::keys::SshKeyPair;
32
// TODO(benesch): allow configuring the following connection parameters via
// server configuration parameters.

/// Default interval between health checks of an established SSH session.
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(30);

/// Default timeout for establishing the connection to the SSH server.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Default idle time after which the SSH client sends a keepalive to the
/// server.
///
/// TCP idle timeouts of 30s are common in the wild. An idle timeout of 10s
/// is comfortably beneath that threshold without being overly chatty.
pub const DEFAULT_KEEPALIVES_IDLE: Duration = Duration::from_secs(10);
42
43/// Configuration of Ssh session and tunnel timeouts.
44#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
45pub struct SshTimeoutConfig {
46    /// How often to check whether the SSH session is still alive.
47    pub check_interval: Duration,
48    /// The timeout to use when establishing the connection to the SSH server.
49    pub connect_timeout: Duration,
50    /// The idle time after which the SSH control leader process should send a
51    /// keepalive packet to the SSH server to determine whether the server is
52    /// still alive.
53    pub keepalives_idle: Duration,
54}
55
56impl Default for SshTimeoutConfig {
57    fn default() -> SshTimeoutConfig {
58        SshTimeoutConfig {
59            check_interval: DEFAULT_CHECK_INTERVAL,
60            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
61            keepalives_idle: DEFAULT_KEEPALIVES_IDLE,
62        }
63    }
64}
65
66/// Specifies an SSH tunnel.
67#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
68pub struct SshTunnelConfig {
69    /// The hostname/IP of the SSH bastion server.
70    /// If multiple hosts are specified, they are tried in order.
71    pub host: BTreeSet<String>,
72    /// The port to connect to.
73    pub port: u16,
74    /// The name of the user to connect as.
75    pub user: String,
76    /// The SSH key pair to authenticate with.
77    pub key_pair: SshKeyPair,
78}
79
80impl fmt::Debug for SshTunnelConfig {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.debug_struct("Tunnel")
83            .field("host", &self.host)
84            .field("port", &self.port)
85            .field("user", &self.user)
86            // Omit keys from debug output.
87            .finish()
88    }
89}
90
91impl fmt::Display for SshTunnelConfig {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(
94            f,
95            "{}@{}:{}",
96            self.user,
97            self.host.iter().join(","),
98            self.port
99        )
100    }
101}
102
/// The health of a running SSH tunnel, as last recorded by its background
/// monitoring task.
#[derive(Clone, Debug)]
pub enum SshTunnelStatus {
    /// The SSH tunnel is believed to be healthy.
    Running,
    /// The SSH tunnel is broken. The payload is the error message (including
    /// its causes) describing why.
    Errored(String),
}
111
112impl SshTunnelConfig {
113    /// Establishes a connection to the specified host and port via the
114    /// configured SSH tunnel.
115    ///
116    /// Returns a handle to the SSH tunnel. The SSH tunnel is automatically shut
117    /// down when the handle is dropped.
118    pub async fn connect(
119        &self,
120        remote_host: &str,
121        remote_port: u16,
122        timeout_config: SshTimeoutConfig,
123    ) -> Result<SshTunnelHandle, anyhow::Error> {
124        let tunnel_id = format!("{}:{} via {}", remote_host, remote_port, self);
125
126        // N.B.
127        //
128        // We could probably move this into the look and use the above channel to report this
129        // initial connection error, but this is simpler and easier to read!
130        info!(%tunnel_id, "connecting to ssh tunnel");
131        let mut session = match connect(self, timeout_config).await {
132            Ok(s) => s,
133            Err(e) => {
134                warn!(%tunnel_id, "failed to connect to ssh tunnel: {}", e.display_with_causes());
135                return Err(e);
136            }
137        };
138        let local_port = match port_forward(&session, remote_host, remote_port).await {
139            Ok(local_port) => local_port,
140            Err(e) => {
141                warn!(%tunnel_id, "failed to forward port through ssh tunnel: {}", e.display_with_causes());
142                return Err(e);
143            }
144        };
145        info!(%tunnel_id, %local_port, "connected to ssh tunnel");
146        let local_port = Arc::new(AtomicU16::new(local_port));
147        let status = Arc::new(Mutex::new(SshTunnelStatus::Running));
148
149        let join_handle = task::spawn(|| format!("ssh_session_{remote_host}:{remote_port}"), {
150            let config = self.clone();
151            let remote_host = remote_host.to_string();
152            let local_port = Arc::clone(&local_port);
153            let status = Arc::clone(&status);
154            async move {
155                scopeguard::defer! {
156                    info!(%tunnel_id, "terminating ssh tunnel");
157                }
158                let mut interval = time::interval(timeout_config.check_interval);
159                // Just in case checking takes a long time.
160                interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
161                // The first tick happens immediately.
162                interval.tick().await;
163                loop {
164                    interval.tick().await;
165                    if let Err(e) = session.check().await {
166                        warn!(%tunnel_id, "ssh tunnel unhealthy: {}", e.display_with_causes());
167                        let s = match connect(&config, timeout_config).await {
168                            Ok(s) => s,
169                            Err(e) => {
170                                warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
171                                *status.lock().expect("poisoned") =
172                                    SshTunnelStatus::Errored(e.to_string_with_causes());
173                                continue;
174                            }
175                        };
176                        let lp = match port_forward(&s, &remote_host, remote_port).await {
177                            Ok(lp) => lp,
178                            Err(e) => {
179                                warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
180                                *status.lock().expect("poisoned") =
181                                    SshTunnelStatus::Errored(e.to_string_with_causes());
182                                continue;
183                            }
184                        };
185                        session = s;
186                        local_port.store(lp, Ordering::SeqCst);
187                        *status.lock().expect("poisoned") = SshTunnelStatus::Running;
188                    }
189                }
190            }
191        });
192
193        Ok(SshTunnelHandle {
194            local_port,
195            status,
196            _join_handle: join_handle.abort_on_drop(),
197        })
198    }
199
200    /// Validates the SSH configuration by establishing a connection to the intermediate SSH
201    /// bastion host. It does not set up a port forwarding tunnel.
202    pub async fn validate(&self, timeout_config: SshTimeoutConfig) -> Result<(), anyhow::Error> {
203        connect(self, timeout_config).await?;
204        Ok(())
205    }
206}
207
208/// A handle to a running SSH tunnel.
209#[derive(Debug)]
210pub struct SshTunnelHandle {
211    local_port: Arc<AtomicU16>,
212    status: Arc<Mutex<SshTunnelStatus>>,
213    _join_handle: AbortOnDropHandle<()>,
214}
215
216impl SshTunnelHandle {
217    /// Returns the local address at which the SSH tunnel is listening.
218    pub fn local_addr(&self) -> SocketAddr {
219        let port = self.local_port.load(Ordering::SeqCst);
220        // Force use of IPv4 loopback. Do not use the hostname `localhost`, as
221        // that can resolve to IPv6, and the SSH tunnel is only listening for
222        // IPv4 connections.
223        SocketAddr::from((Ipv4Addr::LOCALHOST, port))
224    }
225
226    /// Returns the current status of the SSH tunnel.
227    ///
228    /// Note this status may be stale, as the health of the underlying SSH
229    /// tunnel is only checked periodically.
230    pub fn check_status(&self) -> SshTunnelStatus {
231        self.status.lock().expect("poisoned").clone()
232    }
233}
234
235async fn connect(
236    config: &SshTunnelConfig,
237    timeout_config: SshTimeoutConfig,
238) -> Result<Session, anyhow::Error> {
239    let tempdir = tempfile::Builder::new()
240        .prefix("ssh-tunnel-key")
241        .tempdir()?;
242    let path = tempdir.path().join("key");
243    let mut tempfile = File::create(&path)?;
244    // Grant read and write permissions on the file.
245    tempfile.set_permissions(std::fs::Permissions::from_mode(0o600))?;
246    tempfile.write_all(config.key_pair.ssh_private_key().as_bytes())?;
247    // Remove write permissions as soon as the key is written.
248    // Mostly helpful to ensure the file is not accidentally overwritten.
249    tempfile.set_permissions(std::fs::Permissions::from_mode(0o400))?;
250
251    // Try connecting to each host in turn.
252    let mut connect_err = None;
253    for host in &config.host {
254        // Bastion hosts (and therefore keys) tend to change, so we don't want
255        // to lock ourselves into trusting only the first we see. In any case,
256        // recording a known host would only last as long as the life of a
257        // storage pod, so it doesn't offer any protection.
258        match openssh::SessionBuilder::default()
259            .known_hosts_check(openssh::KnownHosts::Accept)
260            .user_known_hosts_file("/dev/null")
261            .user(config.user.clone())
262            .port(config.port)
263            .keyfile(&path)
264            .server_alive_interval(timeout_config.keepalives_idle)
265            .connect_timeout(timeout_config.connect_timeout)
266            .connect_mux(host.clone())
267            .await
268        {
269            Ok(session) => {
270                // Delete the private key for safety: since `ssh` still has an open
271                // handle to it, it still has access to the key.
272                drop(tempfile);
273                fs::remove_file(&path)?;
274                drop(tempdir);
275
276                // Ensure session is healthy.
277                session.check().await?;
278
279                return Ok(session);
280            }
281            Err(err) => {
282                connect_err = Some(err);
283            }
284        }
285    }
286    Err(connect_err
287        .map(Into::into)
288        .unwrap_or_else(|| anyhow::anyhow!("no hosts to connect to")))
289}
290
291async fn port_forward(session: &Session, host: &str, port: u16) -> Result<u16, anyhow::Error> {
292    // Loop trying to find an open port.
293    for _ in 0..50 {
294        // Choose a dynamic port according to RFC 6335.
295        let mut rng = StdRng::from_entropy();
296        let local_port: u16 = rng.gen_range(49152..65535);
297
298        // Force use of IPv4 loopback. Do not use the hostname `localhost`,
299        // as that can resolve to IPv6, and the SSH tunnel is only listening
300        // for IPv4 connections.
301        let local = openssh::Socket::from((Ipv4Addr::LOCALHOST, local_port));
302        let remote = openssh::Socket::new(host, port);
303
304        match session
305            .request_port_forward(ForwardType::Local, local, remote)
306            .await
307        {
308            Ok(_) => return Ok(local_port),
309            Err(err) => match err {
310                openssh::Error::SshMux(openssh_mux_client::Error::RequestFailure(e))
311                    if &*e == "Port forwarding failed" =>
312                {
313                    info!("port {local_port} already in use; testing another port");
314                }
315                _ => {
316                    warn!("ssh connection failed: {}", err.display_with_causes());
317                    bail!("failed to open SSH tunnel: {}", err.display_with_causes())
318                }
319            },
320        };
321    }
322    // If we failed to find an open port after 50 attempts,
323    // something is seriously wrong.
324    bail!("failed to find an open port for SSH tunnel")
325}