mz_ssh_util/tunnel_manager.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! SSH tunnel management.
11
12// NOTE(benesch): The synchronization in this module is tricky because SSH
13// tunnels 1) require an async `connect` method that can return errors and 2)
14// once connected, launch a long-running background task whose handle must be
15// managed. The manager would be far simpler if `connect` was neither async nor
16// fallible and instead synchronously returned a handle to the background task.
17// That would require a different means of asynchronously reporting SSH tunnel
18// errors, though, and that's a large project. A worthwhile project, though: at
19// present SSH tunnel errors that occur after the initial connection are
20// reported only to the logs, and not to users.
21
22use std::collections::{BTreeMap, btree_map};
23use std::ops::Deref;
24use std::sync::{Arc, Mutex};
25
26use mz_ore::future::{InTask, OreFutureExt};
27use scopeguard::ScopeGuard;
28use tokio::sync::watch;
29use tracing::{error, info};
30
31use crate::tunnel::{SshTimeoutConfig, SshTunnelConfig, SshTunnelHandle, SshTunnelStatus};
32
33/// Thread-safe manager of SSH tunnel connections.
34#[derive(Debug, Clone, Default)]
35pub struct SshTunnelManager {
36 tunnels: Arc<Mutex<BTreeMap<SshTunnelKey, SshTunnelState>>>,
37}
38
39impl SshTunnelManager {
40 /// Establishes an SSH tunnel for the given remote host and port using the
41 /// provided `tunnel` configuration.
42 ///
43 /// If there is an existing SSH tunnel, a handle to that tunnel is returned,
44 /// rather than establishing a new tunnel.
45 ///
46 /// The manager guarantees that there will never be more than one in flight
47 /// connection attempt for the same tunnel, even when this method is called
48 /// concurrently from multiple threads.
49 pub async fn connect(
50 &self,
51 config: SshTunnelConfig,
52 remote_host: &str,
53 remote_port: u16,
54 // This could be held behind a lock and updated within the global `SshTunnelManager`, but
55 // requiring all configuration at connection time is more consistent with how other
56 // connections work within the workspace.
57 timeout_config: SshTimeoutConfig,
58 // Whether or not to connect to ssh from a Tokio task (to ensure futures are
59 // polled promptly).
60 in_task: InTask,
61 ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
62 // An SSH tunnel connection is uniquely identified by the SSH tunnel
63 // configuration and the remote address.
64 let key = SshTunnelKey {
65 config: config.clone(),
66 remote_host: remote_host.to_string(),
67 remote_port,
68 };
69
70 loop {
71 // NOTE: this code is structured awkwardly to convince rustc that
72 // the lock is not held across an await point. rustc's analysis
73 // does not take into account explicit `drop` calls, so we have to
74 // structure such that the lock guard goes out of scope.
75 // See: https://github.com/rust-lang/rust/issues/69663
76 enum Action {
77 Return(ManagedSshTunnelHandle),
78 AwaitConnection(watch::Receiver<()>),
79 StartConnection(watch::Sender<()>),
80 }
81
82 let action = match self
83 .tunnels
84 .lock()
85 .expect("lock poisoned")
86 .entry(key.clone())
87 {
88 btree_map::Entry::Occupied(mut occupancy) => match occupancy.get_mut() {
89 // There is an existing tunnel.
90 SshTunnelState::Connected(handle) => Action::Return(ManagedSshTunnelHandle {
91 handle: Arc::clone(handle),
92 manager: self.clone(),
93 key: key.clone(),
94 }),
95 // There is an existing connection attempt.
96 SshTunnelState::Connecting(rx) => Action::AwaitConnection(rx.clone()),
97 },
98 btree_map::Entry::Vacant(vacancy) => {
99 // There is no existing tunnel or connection attempt. Record
100 // that we're starting one.
101 let (tx, rx) = watch::channel(());
102 vacancy.insert(SshTunnelState::Connecting(rx));
103 Action::StartConnection(tx)
104 }
105 };
106
107 match action {
108 Action::Return(handle) => {
109 if let SshTunnelStatus::Errored(e) = handle.check_status() {
110 error!(
111 "not using existing ssh tunnel \
112 ({}:{} via {}) because it's broken: {e}",
113 remote_host, remote_port, config
114 );
115
116 // This is bit unfortunate, as this method returns an
117 // `anyhow::Error`, but the SSH status needs to share a
118 // cloneable `String`. So we just package up the
119 // pre-`.to_string_with_causes()` error that is at the
120 // bottom of the stack. In the future we can probably
121 // make ALL SSH errors structured to avoid this.
122 return Err(anyhow::anyhow!(e));
123 }
124
125 info!(
126 "reusing existing ssh tunnel ({}:{} via {})",
127 remote_host, remote_port, config
128 );
129 return Ok(handle);
130 }
131 Action::AwaitConnection(mut rx) => {
132 // Wait for the connection attempt to finish. The next turn
133 // of the loop will determine whether the connection attempt
134 // succeeded or failed and proceed accordingly.
135 let _ = rx.changed().await;
136 }
137 Action::StartConnection(_tx) => {
138 // IMPORTANT: clear the `Connecting` state on scope exit.
139 // This is *required* for cancel safety. If the future is
140 // dropped at the following await point, we need to record
141 // that we are no longer attemping the connection.
142 let guard = scopeguard::guard((), |()| {
143 let mut tunnels = self.tunnels.lock().expect("lock poisoned");
144 tunnels.remove(&key);
145 });
146
147 // Try to connect.
148 info!(
149 "initiating new ssh tunnel ({}:{} via {})",
150 remote_host, remote_port, config
151 );
152
153 let config = config.clone();
154 let remote_host = remote_host.to_string();
155 let handle = async move {
156 config
157 .connect(&remote_host, remote_port, timeout_config)
158 .await
159 }
160 .run_in_task_if(in_task, || "ssh_connect".to_string())
161 .await?;
162
163 // Successful connection, so defuse the scope guard.
164 let _ = ScopeGuard::into_inner(guard);
165
166 // Record the tunnel handle for future threads.
167 let handle = Arc::new(handle);
168 let mut tunnels = self.tunnels.lock().expect("lock poisoned");
169 tunnels.insert(key.clone(), SshTunnelState::Connected(Arc::clone(&handle)));
170
171 // Return a handle to the tunnel.
172 return Ok(ManagedSshTunnelHandle {
173 handle,
174 manager: self.clone(),
175 key: key.clone(),
176 });
177 }
178 }
179 }
180 }
181}
182
183/// Identifies a connection to a remote host via an SSH tunnel.
184/// There are a couple of edge cases where this key format may result
185/// in extra connections being created:
186/// 1. If a host resolves to a different number of ips on different workers
187/// 2. Different workers connect to different upstream resolved ips if they
188/// appear connectable at different times.
189#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
190struct SshTunnelKey {
191 config: SshTunnelConfig,
192 remote_host: String,
193 remote_port: u16,
194}
195
196/// The state of an SSH tunnel connection.
197///
198/// There is an additional state not represented by this enum, which is the
199/// absence of an entry in the map entirely, indicating there is neither an
200/// existing tunnel nor an existing connection attempt.
201#[derive(Debug)]
202enum SshTunnelState {
203 /// An existing thread is connecting to the tunnel.
204 ///
205 /// The managing thread will resolve the enclosed future when the connection
206 /// attempt is complete. Only the thread that entered the `Connecting` state
207 /// is allowed to move out of this state.
208 Connecting(watch::Receiver<()>),
209 /// An existing thread has successfully established the tunnel.
210 ///
211 /// Only the last `ManagedSshTunnelHandle` is allowed to move out of this
212 /// state.
213 Connected(Arc<SshTunnelHandle>),
214}
215
216/// A clonable handle to an SSH tunnel managed by an [`SshTunnelManager`].
217///
218/// The tunnel will be automatically closed when all handles are dropped.
219#[derive(Debug, Clone)]
220pub struct ManagedSshTunnelHandle {
221 handle: Arc<SshTunnelHandle>,
222 manager: SshTunnelManager,
223 key: SshTunnelKey,
224}
225
226impl Deref for ManagedSshTunnelHandle {
227 type Target = SshTunnelHandle;
228
229 fn deref(&self) -> &SshTunnelHandle {
230 &self.handle
231 }
232}
233
234impl Drop for ManagedSshTunnelHandle {
235 fn drop(&mut self) {
236 let mut tunnels = self.manager.tunnels.lock().expect("lock poisoned");
237 // If there are only two strong references, the manager holds one and we
238 // hold the other, so this is the last handle.
239 //
240 // IMPORTANT: We must be holding the lock when we perform this check, to
241 // ensure no other threads can acquire a new handle via the manager.
242 if Arc::strong_count(&self.handle) == 2 {
243 tunnels.remove(&self.key);
244 }
245 }
246}