tokio_openssl/
lib.rs

1//! Async TLS streams backed by OpenSSL.
2//!
3//! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`tokio`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
5//! blocking [`Read`] and [`Write`] traits.
6#![warn(missing_docs)]
7
8use openssl::error::ErrorStack;
9use openssl::ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef};
10use std::fmt;
11use std::future;
12use std::io::{self, Read, Write};
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17#[cfg(test)]
18mod test;
19
/// Adapter that carries a poll-based stream together with the task context of the poll that is
/// currently in progress.
///
/// The blocking `Read`/`Write` impls in this file use it to reach the stream's `poll_*` methods.
struct StreamWrapper<S> {
    /// The wrapped stream.
    stream: S,
    /// Address of the `Context` for the poll in progress, or `0` when no poll is active.
    context: usize,
}

impl<S> fmt::Debug for StreamWrapper<S>
where
    S: fmt::Debug,
{
    /// Formats exactly as the inner stream does; the `context` field is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.stream.fmt(f)
    }
}

impl<S> StreamWrapper<S> {
    /// Splits the wrapper into a pinned reference to the stream and the stored task context.
    ///
    /// # Safety
    ///
    /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
    /// wrapper must be pinned in memory.
    unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
        debug_assert_ne!(self.context, 0);
        let cx_ptr = self.context as *mut Context<'_>;
        // SAFETY: the caller guarantees the wrapper is pinned, so the stream inside it is too.
        let stream = unsafe { Pin::new_unchecked(&mut self.stream) };
        // SAFETY: the caller guarantees `context` holds the address of a live `Context`.
        let cx = unsafe { &mut *cx_ptr };
        (stream, cx)
    }
}
46
47impl<S> Read for StreamWrapper<S>
48where
49    S: AsyncRead,
50{
51    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
52        let (stream, cx) = unsafe { self.parts() };
53        let mut buf = ReadBuf::new(buf);
54        match stream.poll_read(cx, &mut buf)? {
55            Poll::Ready(()) => Ok(buf.filled().len()),
56            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
57        }
58    }
59}
60
61impl<S> Write for StreamWrapper<S>
62where
63    S: AsyncWrite,
64{
65    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
66        let (stream, cx) = unsafe { self.parts() };
67        match stream.poll_write(cx, buf) {
68            Poll::Ready(r) => r,
69            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
70        }
71    }
72
73    fn flush(&mut self) -> io::Result<()> {
74        let (stream, cx) = unsafe { self.parts() };
75        match stream.poll_flush(cx) {
76            Poll::Ready(r) => r,
77            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
78        }
79    }
80}
81
/// Translates a blocking-style I/O result into a [`Poll`].
///
/// An error of kind [`io::ErrorKind::WouldBlock`] becomes [`Poll::Pending`]; every other result,
/// success or failure, is passed through as [`Poll::Ready`].
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        other => Poll::Ready(other),
    }
}
89
90fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
91    match r {
92        Ok(v) => Poll::Ready(Ok(v)),
93        Err(e) => match e.code() {
94            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
95            _ => Poll::Ready(Err(e)),
96        },
97    }
98}
99
/// An asynchronous version of [`openssl::ssl::SslStream`].
///
/// This is a newtype around the blocking stream type, with the transport `S` held inside an
/// internal adapter. Handshake and I/O operations take `Pin<&mut Self>` and are offered both as
/// `poll_*` methods and as `async` convenience wrappers.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
103
104impl<S> SslStream<S>
105where
106    S: AsyncRead + AsyncWrite,
107{
108    /// Like [`SslStream::new`](ssl::SslStream::new).
109    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
110        ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
111    }
112
113    /// Like [`SslStream::connect`](ssl::SslStream::connect).
114    pub fn poll_connect(
115        self: Pin<&mut Self>,
116        cx: &mut Context<'_>,
117    ) -> Poll<Result<(), ssl::Error>> {
118        self.with_context(cx, |s| cvt_ossl(s.connect()))
119    }
120
121    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
122    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
123        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
124    }
125
126    /// Like [`SslStream::accept`](ssl::SslStream::accept).
127    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
128        self.with_context(cx, |s| cvt_ossl(s.accept()))
129    }
130
131    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
132    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
133        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
134    }
135
136    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
137    pub fn poll_do_handshake(
138        self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140    ) -> Poll<Result<(), ssl::Error>> {
141        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
142    }
143
144    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
145    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
146        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
147    }
148
149    /// Like [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek).
150    pub fn poll_peek(
151        self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153        buf: &mut [u8],
154    ) -> Poll<Result<usize, ssl::Error>> {
155        self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
156    }
157
158    /// A convenience method wrapping [`poll_peek`](Self::poll_peek).
159    pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
160        future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
161    }
162
163    /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data).
164    #[cfg(ossl111)]
165    pub fn poll_read_early_data(
166        self: Pin<&mut Self>,
167        cx: &mut Context<'_>,
168        buf: &mut [u8],
169    ) -> Poll<Result<usize, ssl::Error>> {
170        self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
171    }
172
173    /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data).
174    #[cfg(ossl111)]
175    pub async fn read_early_data(
176        mut self: Pin<&mut Self>,
177        buf: &mut [u8],
178    ) -> Result<usize, ssl::Error> {
179        future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
180    }
181
182    /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data).
183    #[cfg(ossl111)]
184    pub fn poll_write_early_data(
185        self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &[u8],
188    ) -> Poll<Result<usize, ssl::Error>> {
189        self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
190    }
191
192    /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data).
193    #[cfg(ossl111)]
194    pub async fn write_early_data(
195        mut self: Pin<&mut Self>,
196        buf: &[u8],
197    ) -> Result<usize, ssl::Error> {
198        future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
199    }
200}
201
202impl<S> SslStream<S> {
203    /// Returns a shared reference to the `Ssl` object associated with this stream.
204    pub fn ssl(&self) -> &SslRef {
205        self.0.ssl()
206    }
207
208    /// Returns a shared reference to the underlying stream.
209    pub fn get_ref(&self) -> &S {
210        &self.0.get_ref().stream
211    }
212
213    /// Returns a mutable reference to the underlying stream.
214    pub fn get_mut(&mut self) -> &mut S {
215        &mut self.0.get_mut().stream
216    }
217
218    /// Returns a pinned mutable reference to the underlying stream.
219    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
220        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
221    }
222
223    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
224    where
225        F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
226    {
227        let this = unsafe { self.get_unchecked_mut() };
228        this.0.get_mut().context = ctx as *mut _ as usize;
229        let r = f(&mut this.0);
230        this.0.get_mut().context = 0;
231        r
232    }
233}
234
235impl<S> AsyncRead for SslStream<S>
236where
237    S: AsyncRead + AsyncWrite,
238{
239    fn poll_read(
240        self: Pin<&mut Self>,
241        ctx: &mut Context<'_>,
242        buf: &mut ReadBuf<'_>,
243    ) -> Poll<io::Result<()>> {
244        self.with_context(ctx, |s| {
245            // SAFETY: read_uninit does not de-initialize the buffer.
246            match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
247                Poll::Ready(nread) => {
248                    // SAFETY: read_uninit guarantees that nread bytes have been initialized.
249                    unsafe { buf.assume_init(nread) };
250                    buf.advance(nread);
251                    Poll::Ready(Ok(()))
252                }
253                Poll::Pending => Poll::Pending,
254            }
255        })
256    }
257}
258
259impl<S> AsyncWrite for SslStream<S>
260where
261    S: AsyncRead + AsyncWrite,
262{
263    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
264        self.with_context(ctx, |s| cvt(s.write(buf)))
265    }
266
267    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
268        self.with_context(ctx, |s| cvt(s.flush()))
269    }
270
271    fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
272        match self.as_mut().with_context(ctx, |s| s.shutdown()) {
273            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
274            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
275            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
276                return Poll::Pending;
277            }
278            Err(e) => {
279                return Poll::Ready(Err(e
280                    .into_io_error()
281                    .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
282            }
283        }
284
285        self.get_pin_mut().poll_shutdown(ctx)
286    }
287}