openssh/
builder.rs

1use super::{Error, Session};
2
3use std::borrow::Cow;
4use std::ffi::OsString;
5use std::iter::IntoIterator;
6use std::ops::Deref;
7use std::path::{Path, PathBuf};
8use std::process::Stdio;
9use std::str;
10use std::{fs, io};
11
12use once_cell::sync::OnceCell;
13use tempfile::{Builder, TempDir};
14use tokio::process;
15
/// Return the base directory for persistent application state, following the
/// XDG Base Directory specification.
///
/// `$XDG_STATE_HOME` is used when it holds an absolute path. When it is unset,
/// empty, or relative (the spec says such values must be ignored), this falls
/// back to `$HOME/.local/state`. Returns `None` if neither source yields an
/// absolute path. The directory is not created here.
#[cfg(not(windows))]
fn state_dir() -> Option<PathBuf> {
    fn get_absolute_path(path: PathBuf) -> Option<PathBuf> {
        path.is_absolute().then_some(path)
    }

    if let Some(xdg) = std::env::var_os("XDG_STATE_HOME")
        .map(PathBuf::from)
        .and_then(get_absolute_path)
    {
        return Some(xdg);
    }

    // `home_dir` was deprecated because of its Windows behaviour; this function
    // is only compiled for non-Windows targets.
    #[allow(deprecated)]
    let home = std::env::home_dir()?;
    get_absolute_path(home).map(|home| home.join(".local/state"))
}
31}
32
/// Windows has no state directory lookup here: this always returns `None`.
#[cfg(windows)]
fn state_dir() -> Option<PathBuf> {
    Option::None
}
37
38/// The returned `&'static Path` can be coreced to any lifetime.
39fn get_default_control_dir<'a>() -> Result<&'a Path, Error> {
40    static DEFAULT_CONTROL_DIR: OnceCell<Option<Box<Path>>> = OnceCell::new();
41
42    DEFAULT_CONTROL_DIR
43        .get_or_try_init(|| {
44            if let Some(state_dir) = state_dir() {
45                fs::create_dir_all(&state_dir).map_err(Error::Connect)?;
46
47                Ok(Some(state_dir.into_boxed_path()))
48            } else {
49                Ok(None)
50            }
51        })
52        .map(|default_control_dir| {
53            default_control_dir
54                .as_deref()
55                .unwrap_or_else(|| Path::new("./"))
56        })
57}
58
/// Remove every sub-directory of `socketdir` whose name starts with `prefix`.
///
/// Removal is best-effort:
/// - entries that cannot be read, or whose type cannot be determined, are skipped;
/// - failures to delete a matching directory are ignored.
///
/// Names are compared after a lossy UTF-8 conversion. `DirEntry::file_type` does
/// not traverse symlinks, so a symlink carrying the prefix is left alone.
///
/// # Errors
///
/// Fails only if `socketdir` itself cannot be listed.
fn clean_history_control_dir(socketdir: &Path, prefix: &str) -> io::Result<()> {
    for entry in fs::read_dir(socketdir)?.flatten() {
        let is_dir = match entry.file_type() {
            Ok(kind) => kind.is_dir(),
            Err(_) => false,
        };
        if !is_dir {
            continue;
        }
        if !entry.file_name().to_string_lossy().starts_with(prefix) {
            continue;
        }
        // Best effort: a directory that cannot be removed is simply left behind.
        let _ = fs::remove_dir_all(entry.path());
    }
    Ok(())
}
78
79/// Build a [`Session`] with options.
80#[derive(Debug, Clone)]
81pub struct SessionBuilder {
82    user: Option<String>,
83    port: Option<String>,
84    keyfile: Option<PathBuf>,
85    connect_timeout: Option<String>,
86    server_alive_interval: Option<u64>,
87    known_hosts_check: KnownHosts,
88    control_dir: Option<PathBuf>,
89    control_persist: ControlPersist,
90    clean_history_control_dir: bool,
91    config_file: Option<PathBuf>,
92    compression: Option<bool>,
93    jump_hosts: Vec<Box<str>>,
94    user_known_hosts_file: Option<Box<Path>>,
95    ssh_auth_sock: Option<Box<Path>>,
96}
97
98impl Default for SessionBuilder {
99    fn default() -> Self {
100        Self {
101            user: None,
102            port: None,
103            keyfile: None,
104            connect_timeout: None,
105            server_alive_interval: None,
106            known_hosts_check: KnownHosts::Add,
107            control_dir: None,
108            control_persist: ControlPersist::Forever,
109            clean_history_control_dir: false,
110            config_file: None,
111            compression: None,
112            jump_hosts: Vec::new(),
113            user_known_hosts_file: None,
114            ssh_auth_sock: None,
115        }
116    }
117}
118
119impl SessionBuilder {
120    /// Return the user set in builder.
121    pub fn get_user(&self) -> Option<&str> {
122        self.user.as_deref()
123    }
124
125    /// Return the port set in builder.
126    pub fn get_port(&self) -> Option<&str> {
127        self.port.as_deref()
128    }
129
130    /// Set the ssh user (`ssh -l`).
131    ///
132    /// Defaults to `None`.
133    pub fn user(&mut self, user: String) -> &mut Self {
134        self.user = Some(user);
135        self
136    }
137
138    /// Set the port to connect on (`ssh -p`).
139    ///
140    /// Defaults to `None`.
141    pub fn port(&mut self, port: u16) -> &mut Self {
142        self.port = Some(format!("{}", port));
143        self
144    }
145
146    /// Set the keyfile to use (`ssh -i`).
147    ///
148    /// Defaults to `None`.
149    pub fn keyfile(&mut self, p: impl AsRef<Path>) -> &mut Self {
150        self.keyfile = Some(p.as_ref().to_path_buf());
151        self
152    }
153
154    /// See [`KnownHosts`].
155    ///
156    /// Default `KnownHosts::Add`.
157    pub fn known_hosts_check(&mut self, k: KnownHosts) -> &mut Self {
158        self.known_hosts_check = k;
159        self
160    }
161
162    /// Set the connection timeout (`ssh -o ConnectTimeout`).
163    ///
164    /// This value is specified in seconds. Any sub-second duration remainder will be ignored.
165    /// Defaults to `None`.
166    pub fn connect_timeout(&mut self, d: std::time::Duration) -> &mut Self {
167        self.connect_timeout = Some(d.as_secs().to_string());
168        self
169    }
170
171    /// Set the timeout interval after which if no data has been received from the server, ssh
172    /// will request a response from the server (`ssh -o ServerAliveInterval`).
173    ///
174    /// This value is specified in seconds. Any sub-second duration remainder will be ignored.
175    /// Defaults to `None`.
176    pub fn server_alive_interval(&mut self, d: std::time::Duration) -> &mut Self {
177        self.server_alive_interval = Some(d.as_secs());
178        self
179    }
180
181    /// Set the directory in which the temporary directory containing the control socket will
182    /// be created.
183    ///
184    /// If not set, openssh will try to use `$XDG_STATE_HOME`, `$HOME/.local/state` on unix, and fallback to
185    /// `./` (the current directory) if it failed.
186    ///
187    #[cfg(not(windows))]
188    #[cfg_attr(docsrs, doc(cfg(not(windows))))]
189    pub fn control_directory(&mut self, p: impl AsRef<Path>) -> &mut Self {
190        self.control_dir = Some(p.as_ref().to_path_buf());
191        self
192    }
193
194    /// Clean up the temporary directories with the `.ssh-connection` prefix
195    /// in directory specified by [`SessionBuilder::control_directory`], created by
196    /// previous `openssh::Session` that is not cleaned up for some reasons
197    /// (e.g. process getting killed, abort on panic, etc)
198    ///
199    /// Use this with caution, do not enable this if you don't understand
200    /// what it does,
201    #[cfg(not(windows))]
202    #[cfg_attr(docsrs, doc(cfg(not(windows))))]
203    pub fn clean_history_control_directory(&mut self, clean: bool) -> &mut Self {
204        self.clean_history_control_dir = clean;
205        self
206    }
207
208    /// Set the ControlPersist option to configure how long the controlling
209    /// ssh session should stay alive.
210    ///
211    /// Defaults to `ControlPersist::Forever`.
212    ///
213    pub fn control_persist(&mut self, value: ControlPersist) -> &mut Self {
214        self.control_persist = value;
215        self
216    }
217
218    /// Set an alternative per-user configuration file.
219    ///
220    /// By default, ssh uses `~/.ssh/config`. This is equivalent to `ssh -F <p>`.
221    ///
222    /// Defaults to `None`.
223    pub fn config_file(&mut self, p: impl AsRef<Path>) -> &mut Self {
224        self.config_file = Some(p.as_ref().to_path_buf());
225        self
226    }
227
228    /// Enable or disable compression (including stdin, stdout, stderr, data
229    /// for forwarded TCP and unix-domain connections, sftp and scp
230    /// connections).
231    ///
232    /// Note that the ssh server can forcibly disable the compression.
233    ///
234    /// By default, ssh uses configure value set in `~/.ssh/config`.
235    ///
236    /// If `~/.ssh/config` does not enable compression, then it is disabled
237    /// by default.
238    pub fn compression(&mut self, compression: bool) -> &mut Self {
239        self.compression = Some(compression);
240        self
241    }
242
243    /// Specify one or multiple jump hosts.
244    ///
245    /// Connect to the target host by first making a ssh connection to the
246    /// jump host described by destination and then establishing a TCP
247    /// forwarding to the ultimate destination from there.
248    ///
249    /// Multiple jump hops may be specified.
250    /// This is a shortcut to specify a ProxyJump configuration directive.
251    ///
252    /// Note that configuration directives specified by [`SessionBuilder`]
253    /// do not apply to the jump hosts.
254    ///
255    /// Use ~/.ssh/config to specify configuration for jump hosts.
256    pub fn jump_hosts<T: AsRef<str>>(&mut self, hosts: impl IntoIterator<Item = T>) -> &mut Self {
257        self.jump_hosts = hosts
258            .into_iter()
259            .map(|s| s.as_ref().to_string().into_boxed_str())
260            .collect();
261        self
262    }
263
264    /// Specify the path to the `known_hosts` file.
265    ///
266    /// The path provided may use tilde notation (`~`) to refer to the user's
267    /// home directory.
268    ///
269    /// The default is `~/.ssh/known_hosts` and `~/.ssh/known_hosts2`.
270    pub fn user_known_hosts_file(&mut self, user_known_hosts_file: impl AsRef<Path>) -> &mut Self {
271        self.user_known_hosts_file =
272            Some(user_known_hosts_file.as_ref().to_owned().into_boxed_path());
273        self
274    }
275
276    /// Specify the path to the ssh-agent.
277    ///
278    /// The path provided may use tilde notation (`~`) to refer to the user's
279    /// home directory.
280    ///
281    /// The default is `None`.
282    pub fn ssh_auth_sock(&mut self, ssh_auth_sock: impl AsRef<Path>) -> &mut Self {
283        self.ssh_auth_sock = Some(ssh_auth_sock.as_ref().to_owned().into_boxed_path());
284        self
285    }
286
287    /// Connect to the host at the given `host` over SSH using process impl, which will
288    /// spawn a new ssh process for each `Child` created.
289    ///
290    /// The format of `destination` is the same as the `destination` argument to `ssh`. It may be
291    /// specified as either `[user@]hostname` or a URI of the form `ssh://[user@]hostname[:port]`.
292    /// A username or port that is specified in the connection string overrides the one set in the
293    /// builder (but does not change the builder).
294    ///
295    /// If connecting requires interactive authentication based on `STDIN` (such as reading a
296    /// password), the connection will fail. Consider setting up keypair-based authentication
297    /// instead.
298    #[cfg(feature = "process-mux")]
299    #[cfg_attr(docsrs, doc(cfg(feature = "process-mux")))]
300    pub async fn connect<S: AsRef<str>>(&self, destination: S) -> Result<Session, Error> {
301        self.connect_impl(destination.as_ref(), Session::new_process_mux)
302            .await
303    }
304
305    /// Connect to the host at the given `host` over SSH using native mux, which will
306    /// create a new local socket connection for each `Child` created.
307    ///
308    /// See the crate-level documentation for more details on the difference between native and process-based mux.
309    ///
310    /// The format of `destination` is the same as the `destination` argument to `ssh`. It may be
311    /// specified as either `[user@]hostname` or a URI of the form `ssh://[user@]hostname[:port]`.
312    /// A username or port that is specified in the connection string overrides the one set in the
313    /// builder (but does not change the builder).
314    ///
315    /// If connecting requires interactive authentication based on `STDIN` (such as reading a
316    /// password), the connection will fail. Consider setting up keypair-based authentication
317    /// instead.
318    #[cfg(feature = "native-mux")]
319    #[cfg_attr(docsrs, doc(cfg(feature = "native-mux")))]
320    pub async fn connect_mux<S: AsRef<str>>(&self, destination: S) -> Result<Session, Error> {
321        self.connect_impl(destination.as_ref(), Session::new_native_mux)
322            .await
323    }
324
325    async fn connect_impl(
326        &self,
327        destination: &str,
328        f: fn(TempDir) -> Session,
329    ) -> Result<Session, Error> {
330        let (builder, destination) = self.resolve(destination);
331        let tempdir = builder.launch_master(destination).await?;
332        Ok(f(tempdir))
333    }
334
335    /// [`SessionBuilder`] support for `destination` parsing.
336    /// The format of `destination` is the same as the `destination` argument to `ssh`.
337    ///
338    /// # Examples
339    ///
340    /// ```rust
341    /// use openssh::SessionBuilder;
342    /// let b = SessionBuilder::default();
343    /// let (b, d) = b.resolve("ssh://test-user@127.0.0.1:2222");
344    /// assert_eq!(b.get_port().as_deref(), Some("2222"));
345    /// assert_eq!(b.get_user().as_deref(), Some("test-user"));
346    /// assert_eq!(d, "127.0.0.1");
347    /// ```
348    pub fn resolve<'a, 'b>(&'a self, mut destination: &'b str) -> (Cow<'a, Self>, &'b str) {
349        // the "new" ssh://user@host:port form is not supported by all versions of ssh,
350        // so we always translate it into the option form.
351        let mut user = None;
352        let mut port = None;
353        if destination.starts_with("ssh://") {
354            destination = &destination[6..];
355            if let Some(at) = destination.rfind('@') {
356                // specified a username -- extract it:
357                user = Some(&destination[..at]);
358                destination = &destination[(at + 1)..];
359            }
360            if let Some(colon) = destination.rfind(':') {
361                let p = &destination[(colon + 1)..];
362                if let Ok(p) = p.parse() {
363                    // user specified a port -- extract it:
364                    port = Some(p);
365                    destination = &destination[..colon];
366                }
367            }
368        }
369
370        if user.is_none() && port.is_none() {
371            return (Cow::Borrowed(self), destination);
372        }
373
374        let mut with_overrides = self.clone();
375        if let Some(user) = user {
376            with_overrides.user(user.to_owned());
377        }
378
379        if let Some(port) = port {
380            with_overrides.port(port);
381        }
382
383        (Cow::Owned(with_overrides), destination)
384    }
385
386    /// Create ssh master session and return [`TempDir`] which
387    /// contains the ssh control socket.
388    pub async fn launch_master(&self, destination: &str) -> Result<TempDir, Error> {
389        let socketdir = if let Some(socketdir) = self.control_dir.as_ref() {
390            socketdir
391        } else {
392            get_default_control_dir()?
393        };
394
395        let prefix = ".ssh-connection";
396
397        if self.clean_history_control_dir {
398            let _ = clean_history_control_dir(socketdir, prefix);
399        }
400
401        let dir = Builder::new()
402            .prefix(prefix)
403            .tempdir_in(socketdir)
404            .map_err(Error::Master)?;
405
406        let log = dir.path().join("log");
407
408        let mut init = process::Command::new("ssh");
409
410        init.stdin(Stdio::null())
411            .stdout(Stdio::null())
412            .stderr(Stdio::null())
413            .arg("-E")
414            .arg(&log)
415            .arg("-S")
416            .arg(dir.path().join("master"))
417            .arg("-M")
418            .arg("-f")
419            .arg("-N")
420            .arg("-o")
421            .arg(self.control_persist.as_option().deref())
422            .arg("-o")
423            .arg("BatchMode=yes")
424            .arg("-o")
425            .arg(self.known_hosts_check.as_option());
426
427        if let Some(ref timeout) = self.connect_timeout {
428            init.arg("-o").arg(format!("ConnectTimeout={}", timeout));
429        }
430
431        if let Some(ref interval) = self.server_alive_interval {
432            init.arg("-o")
433                .arg(format!("ServerAliveInterval={}", interval));
434        }
435
436        if let Some(ref port) = self.port {
437            init.arg("-p").arg(port);
438        }
439
440        if let Some(ref user) = self.user {
441            init.arg("-l").arg(user);
442        }
443
444        if let Some(ref k) = self.keyfile {
445            // if the user gives a keyfile, _only_ use that keyfile
446            init.arg("-o").arg("IdentitiesOnly=yes");
447            init.arg("-i").arg(k);
448        }
449
450        if let Some(ref config_file) = self.config_file {
451            init.arg("-F").arg(config_file);
452        }
453
454        if let Some(compression) = self.compression {
455            let arg = if compression { "yes" } else { "no" };
456
457            init.arg("-o").arg(format!("Compression={}", arg));
458        }
459
460        if let Some(ssh_auth_sock) = self.ssh_auth_sock.as_deref() {
461            init.env("SSH_AUTH_SOCK", ssh_auth_sock);
462        }
463
464        let mut it = self.jump_hosts.iter();
465
466        if let Some(jump_host) = it.next() {
467            let s = jump_host.to_string();
468
469            let dest = it.fold(s, |mut s, jump_host| {
470                s.push(',');
471                s.push_str(jump_host);
472                s
473            });
474
475            init.arg("-J").arg(&dest);
476        }
477
478        if let Some(user_known_hosts_file) = &self.user_known_hosts_file {
479            let mut option: OsString = "UserKnownHostsFile=".into();
480            option.push(&**user_known_hosts_file);
481            init.arg("-o").arg(option);
482        }
483
484        init.arg(destination);
485
486        // we spawn and immediately wait, because the process is supposed to fork.
487        let status = init.status().await.map_err(Error::Connect)?;
488
489        if !status.success() {
490            let output = fs::read_to_string(log).map_err(Error::Connect)?;
491
492            Err(Error::interpret_ssh_error(&output))
493        } else {
494            Ok(dir)
495        }
496    }
497}
498
/// Specifies how long the controlling ssh process should stay alive.
///
/// Passed to the master `ssh` process as its `ControlPersist` option.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub enum ControlPersist {
    /// Will stay alive indefinitely (`ControlPersist=yes`).
    #[default]
    Forever,
    /// Will be closed after the initial connection is closed (`ControlPersist=no`).
    ClosedAfterInitialConnection,
    /// If the ssh control server has been idle for specified duration
    /// (in seconds), it will exit (`ControlPersist=<n>s`).
    IdleFor(std::num::NonZeroUsize),
}

impl ControlPersist {
    /// Render this policy as the value of an `ssh -o` argument.
    ///
    /// The fixed variants borrow a static string; only `IdleFor` allocates.
    fn as_option(&self) -> Cow<'_, str> {
        match *self {
            Self::IdleFor(secs) => Cow::Owned(format!("ControlPersist={}s", secs.get())),
            Self::ClosedAfterInitialConnection => Cow::Borrowed("ControlPersist=no"),
            Self::Forever => Cow::Borrowed("ControlPersist=yes"),
        }
    }
}
522
/// Specifies how the host's key fingerprint should be handled.
///
/// The default is [`KnownHosts::Add`], the same policy [`SessionBuilder`] starts with.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum KnownHosts {
    /// The host's fingerprint must match what is in the known hosts file.
    ///
    /// If the host is not in the known hosts file, the connection is rejected.
    ///
    /// This corresponds to `ssh -o StrictHostKeyChecking=yes`.
    Strict,
    /// Strict, but if the host is not already in the known hosts file, it will be added.
    ///
    /// This corresponds to `ssh -o StrictHostKeyChecking=accept-new`.
    #[default]
    Add,
    /// Accept whatever key the server provides and add it to the known hosts file.
    ///
    /// This corresponds to `ssh -o StrictHostKeyChecking=no`.
    Accept,
}

impl KnownHosts {
    /// Render this policy as the value of an `ssh -o` argument.
    fn as_option(&self) -> &'static str {
        match *self {
            KnownHosts::Strict => "StrictHostKeyChecking=yes",
            KnownHosts::Add => "StrictHostKeyChecking=accept-new",
            KnownHosts::Accept => "StrictHostKeyChecking=no",
        }
    }
}
551
#[cfg(test)]
mod tests {
    use super::SessionBuilder;

    /// Resolve `input` on a default builder and check the resulting user, port and destination.
    fn check(input: &str, user: Option<&str>, port: Option<&str>, destination: &str) {
        let builder = SessionBuilder::default();
        let (resolved, dest) = builder.resolve(input);
        assert_eq!(resolved.user.as_deref(), user, "user for {:?}", input);
        assert_eq!(resolved.port.as_deref(), port, "port for {:?}", input);
        assert_eq!(dest, destination, "destination for {:?}", input);
    }

    #[test]
    fn resolve() {
        check(
            "ssh://test-user@127.0.0.1:2222",
            Some("test-user"),
            Some("2222"),
            "127.0.0.1",
        );
        check(
            "ssh://test-user@opensshtest:2222",
            Some("test-user"),
            Some("2222"),
            "opensshtest",
        );
        check("ssh://opensshtest:2222", None, Some("2222"), "opensshtest");
        check("ssh://test-user@opensshtest", Some("test-user"), None, "opensshtest");
        check("ssh://opensshtest", None, None, "opensshtest");
        check("opensshtest", None, None, "opensshtest");
    }

    #[test]
    fn resolve_leaves_non_uri_and_invalid_ports_alone() {
        // Without the `ssh://` scheme the string is handed to ssh untouched.
        check("test-user@opensshtest", None, None, "test-user@opensshtest");
        // A suffix that does not parse as a `u16` is not treated as a port.
        check("ssh://opensshtest:notaport", None, None, "opensshtest:notaport");
        check("ssh://opensshtest:70000", None, None, "opensshtest:70000");
    }

    #[test]
    fn resolve_overrides_without_mutating_builder() {
        let mut builder = SessionBuilder::default();
        builder.user("preset".to_owned()).port(22);

        let (resolved, dest) = builder.resolve("ssh://opensshtest:2222");
        assert_eq!(resolved.user.as_deref(), Some("preset"));
        assert_eq!(resolved.port.as_deref(), Some("2222"));
        assert_eq!(dest, "opensshtest");

        // The original builder keeps its own settings.
        assert_eq!(builder.port.as_deref(), Some("22"));
    }
}