1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
// Copyright Materialize, Inc. and contributors. All rights reserved.
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
//! SSH tunnel management.
// NOTE(benesch): The synchronization in this module is tricky because SSH
// tunnels 1) require an async `connect` method that can return errors and 2)
// once connected, launch a long-running background task whose handle must be
// managed. The manager would be far simpler if `connect` was neither async nor
// fallible and instead synchronously returned a handle to the background task.
// That would require a different means of asynchronously reporting SSH tunnel
// errors, though, and that's a large project. A worthwhile project, though: at
// present SSH tunnel errors that occur after the initial connection are
// reported only to the logs, and not to users.
use std::collections::{btree_map, BTreeMap};
use std::ops::Deref;
use std::sync::{Arc, Mutex};
use mz_ore::future::{InTask, OreFutureExt};
use scopeguard::ScopeGuard;
use tokio::sync::watch;
use tracing::{error, info};
use crate::tunnel::{SshTimeoutConfig, SshTunnelConfig, SshTunnelHandle, SshTunnelStatus};
/// Thread-safe manager of SSH tunnel connections.
///
/// Clones of the manager share the same underlying map of tunnels, since the
/// map lives behind an `Arc`.
#[derive(Debug, Clone, Default)]
pub struct SshTunnelManager {
/// Shared, mutex-protected map from tunnel key to the state of that tunnel
/// (either an in-flight connection attempt or an established tunnel).
tunnels: Arc<Mutex<BTreeMap<SshTunnelKey, SshTunnelState>>>,
impl SshTunnelManager {
/// Establishes an SSH tunnel for the given remote host and port using the
/// provided `config` configuration.
///
/// If there is an existing SSH tunnel, a handle to that tunnel is returned,
/// rather than establishing a new tunnel.
///
/// The manager guarantees that there will never be more than one in flight
/// connection attempt for the same tunnel, even when this method is called
/// concurrently from multiple threads.
///
/// # Errors
///
/// Returns an error if an existing tunnel for the same key reports an
/// errored status; such a tunnel is not reused.
pub async fn connect(
config: SshTunnelConfig,
remote_host: &str,
remote_port: u16,
// This could be held behind a lock and updated within the global `SshTunnelManager`, but
// requiring all configuration at connection time is more consistent with how other
// connections work within the workspace.
timeout_config: SshTimeoutConfig,
// Whether or not to connect to ssh from a Tokio task (to ensure futures are
// polled promptly).
in_task: InTask,
) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
// An SSH tunnel connection is uniquely identified by the SSH tunnel
// configuration and the remote address.
let key = SshTunnelKey {
config: config.clone(),
remote_host: remote_host.to_string(),
// Each turn of the loop inspects the tunnel map while holding the lock and
// then acts on the result after the lock guard has gone out of scope.
loop {
// NOTE: this code is structured awkwardly to convince rustc that
// the lock is not held across an await point. rustc's analysis
// does not take into account explicit `drop` calls, so we have to
// structure it such that the lock guard goes out of scope.
// See:
// The decision made while the lock is held; it is acted upon in the
// `match action` below.
enum Action {
let action = match self
.expect("lock poisoned")
btree_map::Entry::Occupied(mut occupancy) => match occupancy.get_mut() {
// There is an existing tunnel.
SshTunnelState::Connected(handle) => Action::Return(ManagedSshTunnelHandle {
handle: Arc::clone(handle),
manager: self.clone(),
key: key.clone(),
// There is an existing connection attempt.
SshTunnelState::Connecting(rx) => Action::AwaitConnection(rx.clone()),
btree_map::Entry::Vacant(vacancy) => {
// There is no existing tunnel or connection attempt. Record
// that we're starting one.
let (tx, rx) = watch::channel(());
match action {
Action::Return(handle) => {
if let SshTunnelStatus::Errored(e) = handle.check_status() {
"not using existing ssh tunnel \
({}:{} via {}) because it's broken: {e}",
remote_host, remote_port, config
// This is a bit unfortunate, as this method returns an
// `anyhow::Error`, but the SSH status needs to share a
// cloneable `String`. So we just package up the
// pre-`.to_string_with_causes()` error that is at the
// bottom of the stack. In the future we can probably
// make ALL SSH errors structured to avoid this.
return Err(anyhow::anyhow!(e));
"reusing existing ssh tunnel ({}:{} via {})",
remote_host, remote_port, config
return Ok(handle);
Action::AwaitConnection(mut rx) => {
// Wait for the connection attempt to finish. The next turn
// of the loop will determine whether the connection attempt
// succeeded or failed and proceed accordingly.
// The result of `changed` is deliberately ignored: whatever it
// reports, the map is re-inspected on the next iteration.
let _ = rx.changed().await;
Action::StartConnection(_tx) => {
// NOTE(review): `_tx` is kept alive for the duration of this arm;
// presumably dropping it is what wakes tasks waiting in
// `AwaitConnection` -- confirm against the `watch` semantics.
// IMPORTANT: clear the `Connecting` state on scope exit.
// This is *required* for cancel safety. If the future is
// dropped at the following await point, we need to record
// that we are no longer attempting the connection.
let guard = scopeguard::guard((), |()| {
let mut tunnels = self.tunnels.lock().expect("lock poisoned");
// Try to connect.
"initiating new ssh tunnel ({}:{} via {})",
remote_host, remote_port, config
// Take owned copies so the `async move` connection future below does
// not borrow from this function's arguments.
let config = config.clone();
let remote_host = remote_host.to_string();
let handle = async move {
.connect(&remote_host, remote_port, timeout_config)
.run_in_task_if(in_task, || "ssh_connect".to_string())
// Successful connection, so defuse the scope guard.
let _ = ScopeGuard::into_inner(guard);
// Record the tunnel handle for future threads.
let handle = Arc::new(handle);
let mut tunnels = self.tunnels.lock().expect("lock poisoned");
tunnels.insert(key.clone(), SshTunnelState::Connected(Arc::clone(&handle)));
// Return a handle to the tunnel.
return Ok(ManagedSshTunnelHandle {
manager: self.clone(),
key: key.clone(),
/// Identifies a connection to a remote host via an SSH tunnel.
///
/// There are a couple of edge cases where this key format may result
/// in extra connections being created:
///
/// 1. If a host resolves to a different number of IPs on different workers
/// 2. Different workers connect to different upstream resolved IPs if they
/// appear connectable at different times.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
struct SshTunnelKey {
/// The SSH tunnel configuration used to reach the remote host.
config: SshTunnelConfig,
/// The remote host the tunnel forwards to.
remote_host: String,
/// The remote port the tunnel forwards to.
remote_port: u16,
/// The state of an SSH tunnel connection.
///
/// There is an additional state not represented by this enum, which is the
/// absence of an entry in the map entirely, indicating there is neither an
/// existing tunnel nor an existing connection attempt.
enum SshTunnelState {
/// An existing thread is connecting to the tunnel.
///
/// Other threads wait on the enclosed watch receiver until the connection
/// attempt is complete. Only the thread that entered the `Connecting` state
/// is allowed to move out of this state.
/// An existing thread has successfully established the tunnel.
///
/// Only the last `ManagedSshTunnelHandle` is allowed to move out of this
/// state.
/// A clonable handle to an SSH tunnel managed by an [`SshTunnelManager`].
///
/// The tunnel will be automatically closed when all handles are dropped.
#[derive(Debug, Clone)]
pub struct ManagedSshTunnelHandle {
/// Shared reference to the underlying tunnel handle. The manager's map holds
/// another reference for as long as the tunnel is registered there.
handle: Arc<SshTunnelHandle>,
/// The manager this handle was obtained from; its tunnel map is locked when
/// the handle is dropped.
manager: SshTunnelManager,
/// The key under which this tunnel is registered in the manager's map.
key: SshTunnelKey,
// Lets callers use a `ManagedSshTunnelHandle` wherever a `&SshTunnelHandle`
// is expected, e.g. to call tunnel methods directly on the managed handle.
impl Deref for ManagedSshTunnelHandle {
type Target = SshTunnelHandle;
fn deref(&self) -> &SshTunnelHandle {
impl Drop for ManagedSshTunnelHandle {
fn drop(&mut self) {
let mut tunnels = self.manager.tunnels.lock().expect("lock poisoned");
// If there are only two strong references, the manager holds one and we
// hold the other, so this is the last handle.
// IMPORTANT: We must be holding the lock when we perform this check, to
// ensure no other threads can acquire a new handle via the manager.
if Arc::strong_count(&self.handle) == 2 {