openssh/
stdio.rs

1use super::Error;
2
3#[cfg(feature = "native-mux")]
4use super::native_mux_impl;
5
6use std::fs::File;
7use std::io;
8use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
9use std::pin::Pin;
10use std::process;
11use std::task::{Context, Poll};
12use tokio::{
13    io::{AsyncRead, AsyncWrite, ReadBuf},
14    net::unix::pipe::{Receiver as PipeReader, Sender as PipeWriter},
15};
16
/// Crate-internal description of how one standard stream of the remote
/// child is wired up.
#[derive(Debug)]
pub(crate) enum StdioImpl {
    /// Attach the stream to `/dev/null`.
    Null,
    /// Create a brand-new pipe for the stream.
    Pipe,
    /// Use the contained file descriptor, which this value owns.
    Fd(OwnedFd),
    /// Reuse the parent's own stdin/stdout/stderr.
    Inherit,
}
28
29/// Describes what to do with a standard I/O stream for a remote child process
30/// when passed to the stdin, stdout, and stderr methods of Command.
31#[derive(Debug)]
32pub struct Stdio(pub(crate) StdioImpl);
33impl Stdio {
34    /// A new pipe should be arranged to connect the parent and remote child processes.
35    pub const fn piped() -> Self {
36        Self(StdioImpl::Pipe)
37    }
38
39    /// This stream will be ignored.
40    /// This is the equivalent of attaching the stream to /dev/null.
41    pub const fn null() -> Self {
42        Self(StdioImpl::Null)
43    }
44
45    /// The child inherits from the corresponding parent descriptor.
46    ///
47    /// NOTE that the stdio fd must be in blocking mode, otherwise
48    /// ssh might not flush all output since it considers
49    /// (`EAGAIN`/`EWOULDBLOCK`) as an error
50    pub const fn inherit() -> Self {
51        Self(StdioImpl::Inherit)
52    }
53
54    /// `Stdio::from_raw_fd_owned` takes ownership of the fd passed in
55    /// and closes the fd on drop.
56    ///
57    /// NOTE that the fd will be put into blocking mode, then it will be
58    /// closed when `Stdio` is dropped.
59    ///
60    /// # Safety
61    ///
62    /// * `fd` - must be a valid fd and must give its ownership to `Stdio`.
63    pub unsafe fn from_raw_fd_owned(fd: RawFd) -> Self {
64        Self(StdioImpl::Fd(OwnedFd::from_raw_fd(fd)))
65    }
66}
67
68impl From<Stdio> for process::Stdio {
69    fn from(stdio: Stdio) -> Self {
70        match stdio.0 {
71            StdioImpl::Null => process::Stdio::null(),
72            StdioImpl::Pipe => process::Stdio::piped(),
73            StdioImpl::Inherit => process::Stdio::inherit(),
74            StdioImpl::Fd(fd) => process::Stdio::from(fd),
75        }
76    }
77}
78
79impl From<OwnedFd> for Stdio {
80    fn from(fd: OwnedFd) -> Self {
81        Self(StdioImpl::Fd(fd))
82    }
83}
84
// Implements `From<$t> for Stdio` for a type that converts into an `OwnedFd`
// infallibly; the resulting `Stdio` owns that fd.
macro_rules! impl_from_for_stdio {
    ($t:ty) => {
        impl From<$t> for Stdio {
            fn from(value: $t) -> Self {
                let fd: OwnedFd = value.into();
                Self(StdioImpl::Fd(fd))
            }
        }
    };
}
94
// Implements `TryFrom<$t> for Stdio` for a type exposing a fallible
// `into_owned_fd()`; an I/O failure is reported as `Error::ChildIo`.
macro_rules! impl_try_from_for_stdio {
    ($t:ty) => {
        impl TryFrom<$t> for Stdio {
            type Error = Error;

            fn try_from(handle: $t) -> Result<Self, Self::Error> {
                let fd = handle.into_owned_fd().map_err(Error::ChildIo)?;
                Ok(Self(StdioImpl::Fd(fd)))
            }
        }
    };
}
107
108impl_from_for_stdio!(process::ChildStdin);
109impl_from_for_stdio!(process::ChildStdout);
110impl_from_for_stdio!(process::ChildStderr);
111
112impl_try_from_for_stdio!(ChildStdin);
113impl_try_from_for_stdio!(ChildStdout);
114impl_try_from_for_stdio!(ChildStderr);
115
116impl_from_for_stdio!(File);
117
118macro_rules! impl_try_from_tokio_process_child_for_stdio {
119    ($type:ident) => {
120        impl TryFrom<tokio::process::$type> for Stdio {
121            type Error = Error;
122
123            fn try_from(arg: tokio::process::$type) -> Result<Self, Self::Error> {
124                arg.into_owned_fd().map_err(Error::ChildIo).map(Into::into)
125            }
126        }
127    };
128}
129
130impl_try_from_tokio_process_child_for_stdio!(ChildStdin);
131impl_try_from_tokio_process_child_for_stdio!(ChildStdout);
132impl_try_from_tokio_process_child_for_stdio!(ChildStderr);
133
134/// Input for the remote child.
135#[derive(Debug)]
136pub struct ChildStdin(PipeWriter);
137
138/// Stdout for the remote child.
139#[derive(Debug)]
140pub struct ChildStdout(PipeReader);
141
142/// Stderr for the remote child.
143#[derive(Debug)]
144pub struct ChildStderr(PipeReader);
145
/// Crate-internal fallible conversion from a child I/O handle of type `T`.
pub(crate) trait TryFromChildIo<T>: Sized {
    /// Error produced when the conversion fails.
    type Error;

    /// Attempts to build `Self` out of `arg`.
    fn try_from(arg: T) -> Result<Self, Self::Error>;
}
151
152macro_rules! impl_from_impl_child_io {
153    (process, $type:ident, $inner:ty) => {
154        impl TryFromChildIo<tokio::process::$type> for $type {
155            type Error = Error;
156
157            fn try_from(arg: tokio::process::$type) -> Result<Self, Self::Error> {
158                let fd = arg.into_owned_fd().map_err(Error::ChildIo)?;
159
160                <$inner>::from_owned_fd(fd)
161                    .map(Self)
162                    .map_err(Error::ChildIo)
163            }
164        }
165    };
166
167    (native_mux, $type:ident) => {
168        #[cfg(feature = "native-mux")]
169        impl TryFromChildIo<native_mux_impl::$type> for $type {
170            type Error = Error;
171
172            fn try_from(arg: native_mux_impl::$type) -> Result<Self, Self::Error> {
173                Ok(Self(arg))
174            }
175        }
176    };
177}
178
179impl_from_impl_child_io!(process, ChildStdin, PipeWriter);
180impl_from_impl_child_io!(process, ChildStdout, PipeReader);
181impl_from_impl_child_io!(process, ChildStderr, PipeReader);
182
183impl_from_impl_child_io!(native_mux, ChildStdin);
184impl_from_impl_child_io!(native_mux, ChildStdout);
185impl_from_impl_child_io!(native_mux, ChildStderr);
186
187macro_rules! impl_child_stdio {
188    (AsRawFd, $type:ty) => {
189        impl AsRawFd for $type {
190            fn as_raw_fd(&self) -> RawFd {
191                self.0.as_raw_fd()
192            }
193        }
194    };
195
196    (AsFd, $type:ty) => {
197        impl AsFd for $type {
198            fn as_fd(&self) -> BorrowedFd<'_> {
199                self.0.as_fd()
200            }
201        }
202    };
203
204    (into_owned_fd, $type:ty) => {
205        impl $type {
206            /// Convert into an owned fd, it'd be deregisted from tokio and in blocking mode.
207            pub fn into_owned_fd(self) -> io::Result<OwnedFd> {
208                self.0.into_blocking_fd()
209            }
210        }
211    };
212
213    (AsyncRead, $type:ty) => {
214        impl_child_stdio!(AsRawFd, $type);
215        impl_child_stdio!(AsFd, $type);
216        impl_child_stdio!(into_owned_fd, $type);
217
218        impl AsyncRead for $type {
219            fn poll_read(
220                mut self: Pin<&mut Self>,
221                cx: &mut Context<'_>,
222                buf: &mut ReadBuf<'_>,
223            ) -> Poll<io::Result<()>> {
224                Pin::new(&mut self.0).poll_read(cx, buf)
225            }
226        }
227    };
228
229    (AsyncWrite, $type: ty) => {
230        impl_child_stdio!(AsRawFd, $type);
231        impl_child_stdio!(AsFd, $type);
232        impl_child_stdio!(into_owned_fd, $type);
233
234        impl AsyncWrite for $type {
235            fn poll_write(
236                mut self: Pin<&mut Self>,
237                cx: &mut Context<'_>,
238                buf: &[u8],
239            ) -> Poll<io::Result<usize>> {
240                Pin::new(&mut self.0).poll_write(cx, buf)
241            }
242
243            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
244                Pin::new(&mut self.0).poll_flush(cx)
245            }
246
247            fn poll_shutdown(
248                mut self: Pin<&mut Self>,
249                cx: &mut Context<'_>,
250            ) -> Poll<io::Result<()>> {
251                Pin::new(&mut self.0).poll_shutdown(cx)
252            }
253
254            fn poll_write_vectored(
255                mut self: Pin<&mut Self>,
256                cx: &mut Context<'_>,
257                bufs: &[io::IoSlice<'_>],
258            ) -> Poll<io::Result<usize>> {
259                Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
260            }
261
262            fn is_write_vectored(&self) -> bool {
263                self.0.is_write_vectored()
264            }
265        }
266    };
267}
268
269impl_child_stdio!(AsyncWrite, ChildStdin);
270impl_child_stdio!(AsyncRead, ChildStdout);
271impl_child_stdio!(AsyncRead, ChildStderr);