1use super::{Error, Session};
2
3use std::borrow::Cow;
4use std::ffi::OsString;
5use std::iter::IntoIterator;
6use std::ops::Deref;
7use std::path::{Path, PathBuf};
8use std::process::Stdio;
9use std::str;
10use std::{fs, io};
11
12use once_cell::sync::OnceCell;
13use tempfile::{Builder, TempDir};
14use tokio::process;
15
/// Returns the base directory under which default control directories are placed.
///
/// Uses `$XDG_STATE_HOME` when it holds an absolute path, and otherwise
/// `$HOME/.local/state`. Returns `None` when neither source yields an absolute path.
#[cfg(not(windows))]
fn state_dir() -> Option<PathBuf> {
    // `std::env::home_dir` is marked deprecated on many toolchains. Its caveats concern
    // Windows, and this function is never compiled there.
    #[allow(deprecated)]
    let home = std::env::home_dir();

    state_dir_from(std::env::var_os("XDG_STATE_HOME"), home)
}

/// Pure resolution logic behind [`state_dir`].
///
/// Per the XDG Base Directory specification, an unset, empty or relative `$XDG_STATE_HOME`
/// is invalid and must be ignored. In that case this falls back to `<home>/.local/state`
/// instead of giving up. A relative `home` is rejected as well.
#[cfg(not(windows))]
fn state_dir_from(xdg_state_home: Option<OsString>, home: Option<PathBuf>) -> Option<PathBuf> {
    let xdg = xdg_state_home
        .map(PathBuf::from)
        .filter(|path| path.is_absolute());

    xdg.or_else(|| {
        home.filter(|path| path.is_absolute())
            .map(|home| home.join(".local/state"))
    })
}
32
/// Windows variant: no state directory is looked up, so this always yields `None`.
#[cfg(windows)]
fn state_dir() -> Option<PathBuf> {
    Option::None
}
37
38fn get_default_control_dir<'a>() -> Result<&'a Path, Error> {
40 static DEFAULT_CONTROL_DIR: OnceCell<Option<Box<Path>>> = OnceCell::new();
41
42 DEFAULT_CONTROL_DIR
43 .get_or_try_init(|| {
44 if let Some(state_dir) = state_dir() {
45 fs::create_dir_all(&state_dir).map_err(Error::Connect)?;
46
47 Ok(Some(state_dir.into_boxed_path()))
48 } else {
49 Ok(None)
50 }
51 })
52 .map(|default_control_dir| {
53 default_control_dir
54 .as_deref()
55 .unwrap_or_else(|| Path::new("./"))
56 })
57}
58
/// Removes leftover control directories from `socketdir` on a best-effort basis.
///
/// Every sub-directory of `socketdir` whose file name starts with `prefix` is removed
/// recursively. The function does not check whether a matching directory is still in use.
/// Unreadable entries, entries whose type cannot be determined, and removal failures are
/// skipped silently.
///
/// # Errors
///
/// Only a failure to open `socketdir` itself is reported.
fn clean_history_control_dir(socketdir: &Path, prefix: &str) -> io::Result<()> {
    for entry in fs::read_dir(socketdir)?.flatten() {
        let is_dir = matches!(entry.file_type(), Ok(kind) if kind.is_dir());

        if is_dir && entry.file_name().to_string_lossy().starts_with(prefix) {
            // Deliberately ignored: cleanup must never abort the caller.
            let _ = fs::remove_dir_all(entry.path());
        }
    }

    Ok(())
}
78
/// Builder for the `ssh` master connection that a `Session` is created from.
///
/// Each field is turned into command-line arguments (or an environment variable) for the
/// `ssh` process spawned by `SessionBuilder::launch_master`.
#[derive(Debug, Clone)]
pub struct SessionBuilder {
    /// Remote user name, passed as `-l <user>`.
    user: Option<String>,
    /// Remote port, stored as its decimal string and passed as `-p <port>`.
    port: Option<String>,
    /// Identity file, passed as `-i <path>` together with `-o IdentitiesOnly=yes`.
    keyfile: Option<PathBuf>,
    /// Whole seconds for the `ConnectTimeout` option, already rendered as a string.
    connect_timeout: Option<String>,
    /// Whole seconds for the `ServerAliveInterval` option.
    server_alive_interval: Option<u64>,
    /// Policy rendered as the `StrictHostKeyChecking` option.
    known_hosts_check: KnownHosts,
    /// Directory in which the per-connection temporary directory is created; `None`
    /// selects a default location.
    control_dir: Option<PathBuf>,
    /// Policy rendered as the `ControlPersist` option.
    control_persist: ControlPersist,
    /// Whether to delete leftover `.ssh-connection*` directories from the control
    /// directory before a new master is launched.
    clean_history_control_dir: bool,
    /// Alternative ssh configuration file, passed as `-F <path>`.
    config_file: Option<PathBuf>,
    /// `Compression=yes` / `Compression=no` when set; the option is omitted when `None`.
    compression: Option<bool>,
    /// Jump hosts, joined with `,` into a single `-J` argument; empty means no `-J`.
    jump_hosts: Vec<Box<str>>,
    /// Passed as `-o UserKnownHostsFile=<path>`.
    user_known_hosts_file: Option<Box<Path>>,
    /// Exported to the spawned `ssh` process as the `SSH_AUTH_SOCK` environment variable.
    ssh_auth_sock: Option<Box<Path>>,
}
97
98impl Default for SessionBuilder {
99 fn default() -> Self {
100 Self {
101 user: None,
102 port: None,
103 keyfile: None,
104 connect_timeout: None,
105 server_alive_interval: None,
106 known_hosts_check: KnownHosts::Add,
107 control_dir: None,
108 control_persist: ControlPersist::Forever,
109 clean_history_control_dir: false,
110 config_file: None,
111 compression: None,
112 jump_hosts: Vec::new(),
113 user_known_hosts_file: None,
114 ssh_auth_sock: None,
115 }
116 }
117}
118
119impl SessionBuilder {
120 pub fn get_user(&self) -> Option<&str> {
122 self.user.as_deref()
123 }
124
125 pub fn get_port(&self) -> Option<&str> {
127 self.port.as_deref()
128 }
129
130 pub fn user(&mut self, user: String) -> &mut Self {
134 self.user = Some(user);
135 self
136 }
137
138 pub fn port(&mut self, port: u16) -> &mut Self {
142 self.port = Some(format!("{}", port));
143 self
144 }
145
146 pub fn keyfile(&mut self, p: impl AsRef<Path>) -> &mut Self {
150 self.keyfile = Some(p.as_ref().to_path_buf());
151 self
152 }
153
154 pub fn known_hosts_check(&mut self, k: KnownHosts) -> &mut Self {
158 self.known_hosts_check = k;
159 self
160 }
161
162 pub fn connect_timeout(&mut self, d: std::time::Duration) -> &mut Self {
167 self.connect_timeout = Some(d.as_secs().to_string());
168 self
169 }
170
171 pub fn server_alive_interval(&mut self, d: std::time::Duration) -> &mut Self {
177 self.server_alive_interval = Some(d.as_secs());
178 self
179 }
180
181 #[cfg(not(windows))]
188 #[cfg_attr(docsrs, doc(cfg(not(windows))))]
189 pub fn control_directory(&mut self, p: impl AsRef<Path>) -> &mut Self {
190 self.control_dir = Some(p.as_ref().to_path_buf());
191 self
192 }
193
194 #[cfg(not(windows))]
202 #[cfg_attr(docsrs, doc(cfg(not(windows))))]
203 pub fn clean_history_control_directory(&mut self, clean: bool) -> &mut Self {
204 self.clean_history_control_dir = clean;
205 self
206 }
207
208 pub fn control_persist(&mut self, value: ControlPersist) -> &mut Self {
214 self.control_persist = value;
215 self
216 }
217
218 pub fn config_file(&mut self, p: impl AsRef<Path>) -> &mut Self {
224 self.config_file = Some(p.as_ref().to_path_buf());
225 self
226 }
227
228 pub fn compression(&mut self, compression: bool) -> &mut Self {
239 self.compression = Some(compression);
240 self
241 }
242
243 pub fn jump_hosts<T: AsRef<str>>(&mut self, hosts: impl IntoIterator<Item = T>) -> &mut Self {
257 self.jump_hosts = hosts
258 .into_iter()
259 .map(|s| s.as_ref().to_string().into_boxed_str())
260 .collect();
261 self
262 }
263
264 pub fn user_known_hosts_file(&mut self, user_known_hosts_file: impl AsRef<Path>) -> &mut Self {
271 self.user_known_hosts_file =
272 Some(user_known_hosts_file.as_ref().to_owned().into_boxed_path());
273 self
274 }
275
276 pub fn ssh_auth_sock(&mut self, ssh_auth_sock: impl AsRef<Path>) -> &mut Self {
283 self.ssh_auth_sock = Some(ssh_auth_sock.as_ref().to_owned().into_boxed_path());
284 self
285 }
286
287 #[cfg(feature = "process-mux")]
299 #[cfg_attr(docsrs, doc(cfg(feature = "process-mux")))]
300 pub async fn connect<S: AsRef<str>>(&self, destination: S) -> Result<Session, Error> {
301 self.connect_impl(destination.as_ref(), Session::new_process_mux)
302 .await
303 }
304
305 #[cfg(feature = "native-mux")]
319 #[cfg_attr(docsrs, doc(cfg(feature = "native-mux")))]
320 pub async fn connect_mux<S: AsRef<str>>(&self, destination: S) -> Result<Session, Error> {
321 self.connect_impl(destination.as_ref(), Session::new_native_mux)
322 .await
323 }
324
325 async fn connect_impl(
326 &self,
327 destination: &str,
328 f: fn(TempDir) -> Session,
329 ) -> Result<Session, Error> {
330 let (builder, destination) = self.resolve(destination);
331 let tempdir = builder.launch_master(destination).await?;
332 Ok(f(tempdir))
333 }
334
335 pub fn resolve<'a, 'b>(&'a self, mut destination: &'b str) -> (Cow<'a, Self>, &'b str) {
349 let mut user = None;
352 let mut port = None;
353 if destination.starts_with("ssh://") {
354 destination = &destination[6..];
355 if let Some(at) = destination.rfind('@') {
356 user = Some(&destination[..at]);
358 destination = &destination[(at + 1)..];
359 }
360 if let Some(colon) = destination.rfind(':') {
361 let p = &destination[(colon + 1)..];
362 if let Ok(p) = p.parse() {
363 port = Some(p);
365 destination = &destination[..colon];
366 }
367 }
368 }
369
370 if user.is_none() && port.is_none() {
371 return (Cow::Borrowed(self), destination);
372 }
373
374 let mut with_overrides = self.clone();
375 if let Some(user) = user {
376 with_overrides.user(user.to_owned());
377 }
378
379 if let Some(port) = port {
380 with_overrides.port(port);
381 }
382
383 (Cow::Owned(with_overrides), destination)
384 }
385
386 pub async fn launch_master(&self, destination: &str) -> Result<TempDir, Error> {
389 let socketdir = if let Some(socketdir) = self.control_dir.as_ref() {
390 socketdir
391 } else {
392 get_default_control_dir()?
393 };
394
395 let prefix = ".ssh-connection";
396
397 if self.clean_history_control_dir {
398 let _ = clean_history_control_dir(socketdir, prefix);
399 }
400
401 let dir = Builder::new()
402 .prefix(prefix)
403 .tempdir_in(socketdir)
404 .map_err(Error::Master)?;
405
406 let log = dir.path().join("log");
407
408 let mut init = process::Command::new("ssh");
409
410 init.stdin(Stdio::null())
411 .stdout(Stdio::null())
412 .stderr(Stdio::null())
413 .arg("-E")
414 .arg(&log)
415 .arg("-S")
416 .arg(dir.path().join("master"))
417 .arg("-M")
418 .arg("-f")
419 .arg("-N")
420 .arg("-o")
421 .arg(self.control_persist.as_option().deref())
422 .arg("-o")
423 .arg("BatchMode=yes")
424 .arg("-o")
425 .arg(self.known_hosts_check.as_option());
426
427 if let Some(ref timeout) = self.connect_timeout {
428 init.arg("-o").arg(format!("ConnectTimeout={}", timeout));
429 }
430
431 if let Some(ref interval) = self.server_alive_interval {
432 init.arg("-o")
433 .arg(format!("ServerAliveInterval={}", interval));
434 }
435
436 if let Some(ref port) = self.port {
437 init.arg("-p").arg(port);
438 }
439
440 if let Some(ref user) = self.user {
441 init.arg("-l").arg(user);
442 }
443
444 if let Some(ref k) = self.keyfile {
445 init.arg("-o").arg("IdentitiesOnly=yes");
447 init.arg("-i").arg(k);
448 }
449
450 if let Some(ref config_file) = self.config_file {
451 init.arg("-F").arg(config_file);
452 }
453
454 if let Some(compression) = self.compression {
455 let arg = if compression { "yes" } else { "no" };
456
457 init.arg("-o").arg(format!("Compression={}", arg));
458 }
459
460 if let Some(ssh_auth_sock) = self.ssh_auth_sock.as_deref() {
461 init.env("SSH_AUTH_SOCK", ssh_auth_sock);
462 }
463
464 let mut it = self.jump_hosts.iter();
465
466 if let Some(jump_host) = it.next() {
467 let s = jump_host.to_string();
468
469 let dest = it.fold(s, |mut s, jump_host| {
470 s.push(',');
471 s.push_str(jump_host);
472 s
473 });
474
475 init.arg("-J").arg(&dest);
476 }
477
478 if let Some(user_known_hosts_file) = &self.user_known_hosts_file {
479 let mut option: OsString = "UserKnownHostsFile=".into();
480 option.push(&**user_known_hosts_file);
481 init.arg("-o").arg(option);
482 }
483
484 init.arg(destination);
485
486 let status = init.status().await.map_err(Error::Connect)?;
488
489 if !status.success() {
490 let output = fs::read_to_string(log).map_err(Error::Connect)?;
491
492 Err(Error::interpret_ssh_error(&output))
493 } else {
494 Ok(dir)
495 }
496 }
497}
498
/// Lifetime policy for the master connection, rendered as OpenSSH's `ControlPersist` option.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub enum ControlPersist {
    /// Rendered as `ControlPersist=yes`. This is the default.
    #[default]
    Forever,
    /// Rendered as `ControlPersist=no`.
    ClosedAfterInitialConnection,
    /// Rendered as `ControlPersist=<n>s`, where `<n>` is the wrapped number of seconds.
    IdleFor(std::num::NonZeroUsize),
}

impl ControlPersist {
    /// Returns the `ssh -o` argument for this policy.
    ///
    /// The string is borrowed for the fixed variants and allocated only for `IdleFor`.
    fn as_option(&self) -> Cow<'_, str> {
        use ControlPersist::*;

        match *self {
            Forever => "ControlPersist=yes".into(),
            ClosedAfterInitialConnection => "ControlPersist=no".into(),
            IdleFor(secs) => format!("ControlPersist={}s", secs).into(),
        }
    }
}
522
/// Host-key verification policy, rendered as OpenSSH's `StrictHostKeyChecking` option.
///
/// See ssh_config(5) for the precise semantics of each value.
#[derive(Clone, Debug)]
pub enum KnownHosts {
    /// `StrictHostKeyChecking=yes`: per ssh_config(5), host keys are never added
    /// automatically, and unknown or changed keys are refused.
    Strict,
    /// `StrictHostKeyChecking=accept-new`: per ssh_config(5), new host keys are added
    /// automatically, but changed keys are refused.
    Add,
    /// `StrictHostKeyChecking=no`: per ssh_config(5), new keys are added and connections
    /// to hosts with changed keys are allowed (with some restrictions).
    Accept,
}

impl KnownHosts {
    /// Returns the `ssh -o` argument for this policy.
    fn as_option(&self) -> &'static str {
        match self {
            Self::Strict => "StrictHostKeyChecking=yes",
            Self::Add => "StrictHostKeyChecking=accept-new",
            Self::Accept => "StrictHostKeyChecking=no",
        }
    }
}
551
#[cfg(test)]
mod tests {
    use super::SessionBuilder;

    /// Table-driven check of `SessionBuilder::resolve`.
    ///
    /// Each row is `(input, expected port, expected user, expected destination)`.
    #[test]
    fn resolve() {
        const CASES: &[(&str, Option<&str>, Option<&str>, &str)] = &[
            ("ssh://test-user@127.0.0.1:2222", Some("2222"), Some("test-user"), "127.0.0.1"),
            ("ssh://test-user@opensshtest:2222", Some("2222"), Some("test-user"), "opensshtest"),
            ("ssh://opensshtest:2222", Some("2222"), None, "opensshtest"),
            ("ssh://test-user@opensshtest", None, Some("test-user"), "opensshtest"),
            ("ssh://opensshtest", None, None, "opensshtest"),
            ("opensshtest", None, None, "opensshtest"),
        ];

        for &(input, port, user, destination) in CASES {
            let builder = SessionBuilder::default();
            let (resolved, rest) = builder.resolve(input);

            assert_eq!(resolved.port.as_deref(), port, "port of {:?}", input);
            assert_eq!(resolved.user.as_deref(), user, "user of {:?}", input);
            assert_eq!(rest, destination, "destination of {:?}", input);
        }
    }
}