Skip to main content

mz_ssh_util/
tunnel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::fmt;
12use std::fs::{self, File};
13use std::io::Write;
14use std::net::{Ipv4Addr, SocketAddr};
15use std::os::unix::fs::PermissionsExt;
16use std::sync::atomic::{AtomicU16, Ordering};
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19
20use anyhow::bail;
21use itertools::Itertools;
22use mz_ore::error::ErrorExt;
23use mz_ore::task::{self, AbortOnDropHandle};
24use openssh::{ForwardType, Session};
25use rand::rngs::StdRng;
26use rand::{Rng, SeedableRng};
27use serde::{Deserialize, Serialize};
28use tokio::time;
29use tracing::{info, warn};
30
31use crate::keys::SshKeyPair;
32
// TODO(benesch): allow configuring the following connection parameters via
// server configuration parameters.

/// Default interval between health checks of an established SSH session
/// (see `SshTimeoutConfig::check_interval`).
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(30);

/// Default upper bound on establishing the connection to the SSH server
/// (see `SshTimeoutConfig::connect_timeout`).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Default idle time before the SSH control process sends a keepalive
/// (see `SshTimeoutConfig::keepalives_idle`).
///
/// TCP idle timeouts of 30s are common in the wild. An idle timeout of 10s
/// is comfortably beneath that threshold without being overly chatty.
pub const DEFAULT_KEEPALIVES_IDLE: Duration = Duration::from_secs(10);
42
43/// Configuration of Ssh session and tunnel timeouts.
44#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
45pub struct SshTimeoutConfig {
46    /// How often to check whether the SSH session is still alive.
47    pub check_interval: Duration,
48    /// The timeout to use when establishing the connection to the SSH server.
49    pub connect_timeout: Duration,
50    /// The idle time after which the SSH control leader process should send a
51    /// keepalive packet to the SSH server to determine whether the server is
52    /// still alive.
53    pub keepalives_idle: Duration,
54}
55
56impl Default for SshTimeoutConfig {
57    fn default() -> SshTimeoutConfig {
58        SshTimeoutConfig {
59            check_interval: DEFAULT_CHECK_INTERVAL,
60            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
61            keepalives_idle: DEFAULT_KEEPALIVES_IDLE,
62        }
63    }
64}
65
66/// Specifies an SSH tunnel.
67#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
68pub struct SshTunnelConfig {
69    /// The hostname/IP of the SSH bastion server.
70    /// If multiple hosts are specified, they are tried in order.
71    pub host: BTreeSet<String>,
72    /// The port to connect to.
73    pub port: u16,
74    /// The name of the user to connect as.
75    pub user: String,
76    /// The SSH key pair to authenticate with.
77    pub key_pair: SshKeyPair,
78}
79
80impl fmt::Debug for SshTunnelConfig {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.debug_struct("Tunnel")
83            .field("host", &self.host)
84            .field("port", &self.port)
85            .field("user", &self.user)
86            // Omit keys from debug output.
87            .finish()
88    }
89}
90
91impl fmt::Display for SshTunnelConfig {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(
94            f,
95            "{}@{}:{}",
96            self.user,
97            self.host.iter().join(","),
98            self.port
99        )
100    }
101}
102
/// Health of a running SSH tunnel, as last recorded by its monitoring task.
#[derive(Debug, Clone)]
pub enum SshTunnelStatus {
    /// The tunnel is healthy.
    Running,
    /// The tunnel is broken; the payload is a description of the error.
    Errored(String),
}
111
112impl SshTunnelConfig {
113    /// Establishes a connection to the specified host and port via the
114    /// configured SSH tunnel.
115    ///
116    /// Returns a handle to the SSH tunnel. The SSH tunnel is automatically shut
117    /// down when the handle is dropped.
118    pub async fn connect(
119        &self,
120        remote_host: &str,
121        remote_port: u16,
122        timeout_config: SshTimeoutConfig,
123    ) -> Result<SshTunnelHandle, anyhow::Error> {
124        let tunnel_id = format!("{}:{} via {}", remote_host, remote_port, self);
125
126        // N.B.
127        //
128        // We could probably move this into the look and use the above channel to report this
129        // initial connection error, but this is simpler and easier to read!
130        info!(%tunnel_id, "connecting to ssh tunnel");
131        let mut session = match connect(self, timeout_config).await {
132            Ok(s) => s,
133            Err(e) => {
134                warn!(%tunnel_id, "failed to connect to ssh tunnel: {}", e.display_with_causes());
135                return Err(e);
136            }
137        };
138        let local_port = match port_forward(&session, remote_host, remote_port).await {
139            Ok(local_port) => local_port,
140            Err(e) => {
141                warn!(%tunnel_id, "failed to forward port through ssh tunnel: {}", e.display_with_causes());
142                return Err(e);
143            }
144        };
145        info!(%tunnel_id, %local_port, "connected to ssh tunnel");
146        let local_port = Arc::new(AtomicU16::new(local_port));
147        let status = Arc::new(Mutex::new(SshTunnelStatus::Running));
148
149        let join_handle = task::spawn(|| format!("ssh_session_{remote_host}:{remote_port}"), {
150            let config = self.clone();
151            let remote_host = remote_host.to_string();
152            let local_port = Arc::clone(&local_port);
153            let status = Arc::clone(&status);
154            async move {
155                scopeguard::defer! {
156                    info!(%tunnel_id, "terminating ssh tunnel");
157                }
158                let mut interval = time::interval(timeout_config.check_interval);
159                // Just in case checking takes a long time.
160                interval.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
161                // The first tick happens immediately.
162                interval.tick().await;
163                loop {
164                    interval.tick().await;
165                    if let Err(e) = session.check().await {
166                        warn!(%tunnel_id, "ssh tunnel unhealthy: {}", e.display_with_causes());
167                        let s = match connect(&config, timeout_config).await {
168                            Ok(s) => s,
169                            Err(e) => {
170                                warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
171                                *status.lock().expect("poisoned") =
172                                    SshTunnelStatus::Errored(e.to_string_with_causes());
173                                continue;
174                            }
175                        };
176                        let lp = match port_forward(&s, &remote_host, remote_port).await {
177                            Ok(lp) => lp,
178                            Err(e) => {
179                                warn!(%tunnel_id, "reconnection to ssh tunnel failed: {}", e.display_with_causes());
180                                *status.lock().expect("poisoned") =
181                                    SshTunnelStatus::Errored(e.to_string_with_causes());
182                                continue;
183                            }
184                        };
185                        session = s;
186                        local_port.store(lp, Ordering::SeqCst);
187                        *status.lock().expect("poisoned") = SshTunnelStatus::Running;
188                    }
189                }
190            }
191        });
192
193        Ok(SshTunnelHandle {
194            local_port,
195            status,
196            _join_handle: join_handle.abort_on_drop(),
197        })
198    }
199
200    /// Validates the SSH configuration by establishing a connection to the intermediate SSH
201    /// bastion host. It does not set up a port forwarding tunnel.
202    pub async fn validate(&self, timeout_config: SshTimeoutConfig) -> Result<(), anyhow::Error> {
203        connect(self, timeout_config).await?;
204        Ok(())
205    }
206}
207
/// A handle to a running SSH tunnel.
///
/// The handle shares its port and status with the background task spawned by
/// `SshTunnelConfig::connect`. That task owns the SSH session and is aborted
/// when the handle is dropped.
#[derive(Debug)]
pub struct SshTunnelHandle {
    /// Local loopback port currently forwarded to the remote endpoint. The
    /// background task overwrites it after a successful reconnection.
    local_port: Arc<AtomicU16>,
    /// Most recently recorded tunnel health, written by the background task.
    status: Arc<Mutex<SshTunnelStatus>>,
    /// Handle to the background task; dropping it aborts the task (created via
    /// `abort_on_drop`).
    _join_handle: AbortOnDropHandle<()>,
}
215
216impl SshTunnelHandle {
217    /// Returns the local address at which the SSH tunnel is listening.
218    pub fn local_addr(&self) -> SocketAddr {
219        let port = self.local_port.load(Ordering::SeqCst);
220        // Force use of IPv4 loopback. Do not use the hostname `localhost`, as
221        // that can resolve to IPv6, and the SSH tunnel is only listening for
222        // IPv4 connections.
223        SocketAddr::from((Ipv4Addr::LOCALHOST, port))
224    }
225
226    /// Returns the current status of the SSH tunnel.
227    ///
228    /// Note this status may be stale, as the health of the underlying SSH
229    /// tunnel is only checked periodically.
230    pub fn check_status(&self) -> SshTunnelStatus {
231        self.status.lock().expect("poisoned").clone()
232    }
233}
234
/// Reports whether FIPS mode was requested through the `MZ_FIPS` environment
/// variable.
///
/// Only the exact values `1` and `true` enable FIPS mode. Any other value,
/// including a non-Unicode value or an unset variable, leaves it disabled.
fn fips_mode_enabled() -> bool {
    matches!(std::env::var("MZ_FIPS").as_deref(), Ok("1" | "true"))
}
239
240/// Writes a temporary SSH config file that restricts algorithms to FIPS 140-3
241/// approved choices only. Returns the path to the config file.
242fn write_fips_ssh_config(dir: &std::path::Path) -> Result<std::path::PathBuf, anyhow::Error> {
243    let config_path = dir.join("ssh_config");
244    let config_contents = "\
245# FIPS 140-3 compliant SSH configuration.
246# Only NIST-approved algorithms are permitted.
247Ciphers aes256-gcm@openssh.com,aes128-gcm@openssh.com,aes256-ctr,aes192-ctr,aes128-ctr
248KexAlgorithms ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512,diffie-hellman-group14-sha256
249MACs hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha2-256,hmac-sha2-512
250HostKeyAlgorithms ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-256,rsa-sha2-512
251PubkeyAcceptedAlgorithms ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-256,rsa-sha2-512,ssh-ed25519
252";
253    fs::write(&config_path, config_contents)?;
254    Ok(config_path)
255}
256
/// Opens an SSH control-master session to the first reachable host in
/// `config.host`, authenticating as `config.user` with `config.key_pair`.
///
/// Key and config handling:
///
/// - The private key is written to `key` inside a fresh temporary directory
///   and passed to the session builder via `keyfile`.
/// - When `fips_mode_enabled()` is true, a restrictive `ssh_config` is written
///   into the same directory and passed via `config_file`.
/// - After the first successful `connect_mux`, the key file is removed and the
///   temporary directory is dropped.
///
/// Hosts are attempted in `BTreeSet` iteration order (ascending). Host keys
/// are accepted without verification and are not persisted.
///
/// # Errors
///
/// - The temporary key file cannot be created, chmod-ed, or written.
/// - The FIPS config cannot be written.
/// - The key file cannot be removed after connecting.
/// - The new session fails its initial `check()`. In that case the remaining
///   hosts are NOT tried.
/// - Every host fails to connect. The error from the last attempted host is
///   returned, or "no hosts to connect to" if `config.host` is empty.
async fn connect(
    config: &SshTunnelConfig,
    timeout_config: SshTimeoutConfig,
) -> Result<Session, anyhow::Error> {
    let tempdir = tempfile::Builder::new()
        .prefix("ssh-tunnel-key")
        .tempdir()?;
    let path = tempdir.path().join("key");
    let mut tempfile = File::create(&path)?;
    // Restrict the file to owner read/write (0o600) before any key material
    // is written into it.
    tempfile.set_permissions(std::fs::Permissions::from_mode(0o600))?;
    tempfile.write_all(config.key_pair.ssh_private_key().as_bytes())?;
    // Remove write permissions as soon as the key is written.
    // Mostly helpful to ensure the file is not accidentally overwritten.
    tempfile.set_permissions(std::fs::Permissions::from_mode(0o400))?;

    // In FIPS mode, write a restrictive SSH config that only allows
    // NIST-approved algorithms.
    let fips_config_path = if fips_mode_enabled() {
        Some(write_fips_ssh_config(tempdir.path())?)
    } else {
        None
    };

    // Try connecting to each host in turn (ascending `BTreeSet` order).
    let mut connect_err = None;
    for host in &config.host {
        // Bastion hosts (and therefore keys) tend to change, so we don't want
        // to lock ourselves into trusting only the first we see. In any case,
        // recording a known host would only last as long as the life of a
        // storage pod, so it doesn't offer any protection.
        let mut builder = openssh::SessionBuilder::default();
        builder
            .known_hosts_check(openssh::KnownHosts::Accept)
            .user_known_hosts_file("/dev/null")
            .user(config.user.clone())
            .port(config.port)
            .keyfile(&path)
            .server_alive_interval(timeout_config.keepalives_idle)
            .connect_timeout(timeout_config.connect_timeout);

        if let Some(ref fips_config) = fips_config_path {
            builder.config_file(fips_config);
        }

        match builder.connect_mux(host.clone()).await {
            Ok(session) => {
                // Delete the private key for safety: since `ssh` still has an open
                // handle to it, it still has access to the key.
                drop(tempfile);
                fs::remove_file(&path)?;
                // Dropping the `TempDir` removes the directory, including the
                // FIPS config file if one was written.
                drop(tempdir);

                // Ensure session is healthy. A failure here is returned as-is;
                // the key file is already gone, so remaining hosts are not tried.
                session.check().await?;

                return Ok(session);
            }
            Err(err) => {
                // Only the most recent failure is kept; earlier ones are
                // overwritten.
                connect_err = Some(err);
            }
        }
    }
    Err(connect_err
        .map(Into::into)
        .unwrap_or_else(|| anyhow::anyhow!("no hosts to connect to")))
}
324
325async fn port_forward(session: &Session, host: &str, port: u16) -> Result<u16, anyhow::Error> {
326    // Loop trying to find an open port.
327    for _ in 0..50 {
328        // Choose a dynamic port according to RFC 6335.
329        let mut rng = StdRng::from_os_rng();
330        let local_port: u16 = rng.random_range(49152..65535);
331
332        // Force use of IPv4 loopback. Do not use the hostname `localhost`,
333        // as that can resolve to IPv6, and the SSH tunnel is only listening
334        // for IPv4 connections.
335        let local = openssh::Socket::from((Ipv4Addr::LOCALHOST, local_port));
336        let remote = openssh::Socket::new(host, port);
337
338        match session
339            .request_port_forward(ForwardType::Local, local, remote)
340            .await
341        {
342            Ok(_) => return Ok(local_port),
343            Err(err) => match err {
344                openssh::Error::SshMux(openssh_mux_client::Error::RequestFailure(e))
345                    if &*e == "Port forwarding failed" =>
346                {
347                    info!("port {local_port} already in use; testing another port");
348                }
349                _ => {
350                    warn!("ssh connection failed: {}", err.display_with_causes());
351                    bail!("failed to open SSH tunnel: {}", err.display_with_causes())
352                }
353            },
354        };
355    }
356    // If we failed to find an open port after 50 attempts,
357    // something is seriously wrong.
358    bail!("failed to find an open port for SSH tunnel")
359}