mz_ssh_util/tunnel_manager.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! SSH tunnel management.

// NOTE(benesch): The synchronization in this module is tricky because SSH
// tunnels 1) require an async `connect` method that can return errors and 2)
// once connected, launch a long-running background task whose handle must be
// managed. The manager would be far simpler if `connect` were neither async nor
// fallible and instead synchronously returned a handle to the background task.
// That would require a different means of asynchronously reporting SSH tunnel
// errors, though, and that is a larger project. It would be a worthwhile one:
// at present, SSH tunnel errors that occur after the initial connection are
// reported only to the logs, and not to users.

use std::collections::{BTreeMap, btree_map};
use std::ops::Deref;
use std::sync::{Arc, Mutex};

use mz_ore::future::{InTask, OreFutureExt};
use scopeguard::ScopeGuard;
use tokio::sync::watch;
use tracing::{error, info};

use crate::tunnel::{SshTimeoutConfig, SshTunnelConfig, SshTunnelHandle, SshTunnelStatus};

/// Thread-safe manager of SSH tunnel connections.
#[derive(Debug, Clone, Default)]
pub struct SshTunnelManager {
    tunnels: Arc<Mutex<BTreeMap<SshTunnelKey, SshTunnelState>>>,
}

impl SshTunnelManager {
    /// Establishes an SSH tunnel for the given remote host and port using the
    /// provided SSH tunnel `config`.
    ///
    /// If there is an existing SSH tunnel for the same configuration and
    /// remote address, a handle to that tunnel is returned rather than
    /// establishing a new tunnel.
    ///
    /// The manager guarantees that there is never more than one in-flight
    /// connection attempt for the same tunnel, even when this method is called
    /// concurrently from multiple threads.
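    ///
    /// # Example
    ///
    /// A minimal usage sketch (marked `ignore`, so it is not compiled as a
    /// doctest; the `ssh_config` and `timeouts` values, the remote address,
    /// and the choice of `InTask::Yes` are assumptions supplied by the
    /// caller):
    ///
    /// ```ignore
    /// let manager = SshTunnelManager::default();
    /// // Reuses an existing tunnel for the same (config, host, port) key, or
    /// // establishes a new one if none exists.
    /// let tunnel = manager
    ///     .connect(ssh_config, "db.internal.example.com", 5432, timeouts, InTask::Yes)
    ///     .await?;
    /// // `ManagedSshTunnelHandle` derefs to `SshTunnelHandle`, so the tunnel's
    /// // health can be checked directly on the returned handle.
    /// if let SshTunnelStatus::Errored(e) = tunnel.check_status() {
    ///     return Err(anyhow::anyhow!(e));
    /// }
    /// ```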
    pub async fn connect(
        &self,
        config: SshTunnelConfig,
        remote_host: &str,
        remote_port: u16,
        // This could be held behind a lock and updated within the global `SshTunnelManager`, but
        // requiring all configuration at connection time is more consistent with how other
        // connections work within the workspace.
        timeout_config: SshTimeoutConfig,
        // Whether or not to connect to SSH from a Tokio task (to ensure futures are
        // polled promptly).
        in_task: InTask,
    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
        // An SSH tunnel connection is uniquely identified by the SSH tunnel
        // configuration and the remote address.
        let key = SshTunnelKey {
            config: config.clone(),
            remote_host: remote_host.to_string(),
            remote_port,
        };

        loop {
            // NOTE: this code is structured awkwardly to convince rustc that
            // the lock is not held across an await point. rustc's analysis
            // does not take into account explicit `drop` calls, so we have to
            // structure the code so that the lock guard goes out of scope.
            // See: https://github.com/rust-lang/rust/issues/69663
            enum Action {
                Return(ManagedSshTunnelHandle),
                AwaitConnection(watch::Receiver<()>),
                StartConnection(watch::Sender<()>),
            }

            let action = match self
                .tunnels
                .lock()
                .expect("lock poisoned")
                .entry(key.clone())
            {
                btree_map::Entry::Occupied(mut occupancy) => match occupancy.get_mut() {
                    // There is an existing tunnel.
                    SshTunnelState::Connected(handle) => Action::Return(ManagedSshTunnelHandle {
                        handle: Arc::clone(handle),
                        manager: self.clone(),
                        key: key.clone(),
                    }),
                    // There is an existing connection attempt.
                    SshTunnelState::Connecting(rx) => Action::AwaitConnection(rx.clone()),
                },
                btree_map::Entry::Vacant(vacancy) => {
                    // There is no existing tunnel or connection attempt. Record
                    // that we're starting one.
                    let (tx, rx) = watch::channel(());
                    vacancy.insert(SshTunnelState::Connecting(rx));
                    Action::StartConnection(tx)
                }
            };

            match action {
                Action::Return(handle) => {
                    if let SshTunnelStatus::Errored(e) = handle.check_status() {
                        error!(
                            "not using existing ssh tunnel \
                            ({}:{} via {}) because it's broken: {e}",
                            remote_host, remote_port, config
                        );

                        // This is a bit unfortunate, as this method returns an
                        // `anyhow::Error`, but the SSH status needs to share a
                        // cloneable `String`. So we just package up the
                        // pre-`.to_string_with_causes()` error that is at the
                        // bottom of the stack. In the future we can probably
                        // make ALL SSH errors structured to avoid this.
                        return Err(anyhow::anyhow!(e));
                    }

                    info!(
                        "reusing existing ssh tunnel ({}:{} via {})",
                        remote_host, remote_port, config
                    );
                    return Ok(handle);
                }
                Action::AwaitConnection(mut rx) => {
                    // Wait for the connection attempt to finish. The next turn
                    // of the loop will determine whether the connection attempt
                    // succeeded or failed and proceed accordingly.
                    let _ = rx.changed().await;
                }
                Action::StartConnection(_tx) => {
                    // IMPORTANT: clear the `Connecting` state on scope exit.
                    // This is *required* for cancel safety. If the future is
                    // dropped at the following await point, we need to record
                    // that we are no longer attempting the connection.
                    let guard = scopeguard::guard((), |()| {
                        let mut tunnels = self.tunnels.lock().expect("lock poisoned");
                        tunnels.remove(&key);
                    });

                    // Try to connect.
                    info!(
                        "initiating new ssh tunnel ({}:{} via {})",
                        remote_host, remote_port, config
                    );

                    let config = config.clone();
                    let remote_host = remote_host.to_string();
                    let handle = async move {
                        config
                            .connect(&remote_host, remote_port, timeout_config)
                            .await
                    }
                    .run_in_task_if(in_task, || "ssh_connect".to_string())
                    .await?;

                    // Successful connection, so defuse the scope guard.
                    let _ = ScopeGuard::into_inner(guard);

                    // Record the tunnel handle for future threads.
                    let handle = Arc::new(handle);
                    let mut tunnels = self.tunnels.lock().expect("lock poisoned");
                    tunnels.insert(key.clone(), SshTunnelState::Connected(Arc::clone(&handle)));

                    // Return a handle to the tunnel.
                    return Ok(ManagedSshTunnelHandle {
                        handle,
                        manager: self.clone(),
                        key: key.clone(),
                    });
                }
            }
        }
    }
}

/// Identifies a connection to a remote host via an SSH tunnel.
///
/// There are a couple of edge cases where this key format may result in extra
/// connections being created:
/// 1. A host resolves to a different number of IPs on different workers.
/// 2. Different workers connect to different upstream resolved IPs if they
///    appear connectable at different times.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
struct SshTunnelKey {
    config: SshTunnelConfig,
    remote_host: String,
    remote_port: u16,
}

/// The state of an SSH tunnel connection.
///
/// There is an additional state not represented by this enum, which is the
/// absence of an entry in the map entirely, indicating there is neither an
/// existing tunnel nor an existing connection attempt.
#[derive(Debug)]
enum SshTunnelState {
    /// An existing thread is connecting to the tunnel.
    ///
    /// The connecting thread notifies waiters via the enclosed watch channel
    /// when the connection attempt is complete. Only the thread that entered
    /// the `Connecting` state is allowed to move out of this state.
    Connecting(watch::Receiver<()>),
    /// An existing thread has successfully established the tunnel.
    ///
    /// Only the last `ManagedSshTunnelHandle` is allowed to move out of this
    /// state.
    Connected(Arc<SshTunnelHandle>),
}

/// A cloneable handle to an SSH tunnel managed by an [`SshTunnelManager`].
///
/// The tunnel will be automatically closed when all handles are dropped.
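///
/// # Example
///
/// A minimal sketch of the drop semantics (marked `ignore`, not compiled as a
/// doctest; `manager`, `ssh_config`, `timeouts`, and the remote address are
/// assumptions supplied by the caller):
///
/// ```ignore
/// let a = manager
///     .connect(ssh_config, "db.internal.example.com", 5432, timeouts, InTask::Yes)
///     .await?;
/// let b = a.clone();
/// drop(a); // the tunnel stays open: `b` still holds a handle
/// drop(b); // the last handle is gone, so the manager closes the tunnel
/// ```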
#[derive(Debug, Clone)]
pub struct ManagedSshTunnelHandle {
    handle: Arc<SshTunnelHandle>,
    manager: SshTunnelManager,
    key: SshTunnelKey,
}

impl Deref for ManagedSshTunnelHandle {
    type Target = SshTunnelHandle;

    fn deref(&self) -> &SshTunnelHandle {
        &self.handle
    }
}

impl Drop for ManagedSshTunnelHandle {
    fn drop(&mut self) {
        let mut tunnels = self.manager.tunnels.lock().expect("lock poisoned");
        // If there are only two strong references, the manager holds one and we
        // hold the other, so this is the last handle.
        //
        // IMPORTANT: We must be holding the lock when we perform this check, to
        // ensure no other threads can acquire a new handle via the manager.
        if Arc::strong_count(&self.handle) == 2 {
            tunnels.remove(&self.key);
        }
    }
}