hyper_openssl/
lib.rs

1//! Hyper TLS support via OpenSSL.
2#![warn(missing_docs)]
3
4use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write};
5use openssl::error::ErrorStack;
6use openssl::ssl::{self, ErrorCode, Ssl, SslRef};
7use std::fmt;
8use std::future;
9use std::io::{self, Write as _};
10use std::pin::Pin;
11use std::ptr;
12use std::task::{Context, Poll};
13
14#[cfg(feature = "client-legacy")]
15pub mod client;
16#[cfg(test)]
17mod test;
18
/// Adapter that presents a Hyper `Read`/`Write` stream to OpenSSL through the blocking-style
/// `std::io` traits.
///
/// `context` is the task context of the poll currently driving the stream. It is installed by
/// `SslStream::with_context` for the duration of one operation and reset to null afterwards.
struct StreamWrapper<S> {
    context: *mut Context<'static>,
    stream: S,
}
23
24unsafe impl<S> Sync for StreamWrapper<S> where S: Sync {}
25unsafe impl<S> Send for StreamWrapper<S> where S: Send {}
26
27impl<S> fmt::Debug for StreamWrapper<S>
28where
29    S: fmt::Debug,
30{
31    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
32        fmt::Debug::fmt(&self.stream, fmt)
33    }
34}
35
36impl<S> StreamWrapper<S> {
37    /// # Safety
38    ///
39    /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
40    /// wrapper must be pinned in memory.
41    unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
42        debug_assert_ne!(self.context, ptr::null_mut());
43        let stream = Pin::new_unchecked(&mut self.stream);
44        let context = &mut *self.context.cast();
45        (stream, context)
46    }
47}
48
49impl<S> io::Read for StreamWrapper<S>
50where
51    S: Read,
52{
53    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
54        let (stream, cx) = unsafe { self.parts() };
55        let mut buf = ReadBuf::new(buf);
56        match stream.poll_read(cx, buf.unfilled())? {
57            Poll::Ready(()) => Ok(buf.filled().len()),
58            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
59        }
60    }
61}
62
63impl<S> io::Write for StreamWrapper<S>
64where
65    S: Write,
66{
67    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
68        let (stream, cx) = unsafe { self.parts() };
69        match stream.poll_write(cx, buf) {
70            Poll::Ready(r) => r,
71            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
72        }
73    }
74
75    fn flush(&mut self) -> io::Result<()> {
76        let (stream, cx) = unsafe { self.parts() };
77        match stream.poll_flush(cx) {
78            Poll::Ready(r) => r,
79            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
80        }
81    }
82}
83
84/// A Hyper stream type using OpenSSL.
85#[derive(Debug)]
86pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
87
88impl<S> SslStream<S>
89where
90    S: Read + Write,
91{
92    /// Like [`SslStream::new`](ssl::SslStream::new).
93    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
94        ssl::SslStream::new(
95            ssl,
96            StreamWrapper {
97                stream,
98                context: ptr::null_mut(),
99            },
100        )
101        .map(SslStream)
102    }
103
104    /// Like [`SslStream::connect`](ssl::SslStream::connect).
105    pub fn poll_connect(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108    ) -> Poll<Result<(), ssl::Error>> {
109        self.with_context(cx, |s| cvt_ossl(s.connect()))
110    }
111
112    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
113    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
114        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
115    }
116
117    /// Like [`SslStream::accept`](ssl::SslStream::accept).
118    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
119        self.with_context(cx, |s| cvt_ossl(s.accept()))
120    }
121
122    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
123    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
124        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
125    }
126
127    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
128    pub fn poll_do_handshake(
129        self: Pin<&mut Self>,
130        cx: &mut Context<'_>,
131    ) -> Poll<Result<(), ssl::Error>> {
132        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
133    }
134
135    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
136    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
137        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
138    }
139}
140
141impl<S> SslStream<S> {
142    /// Returns a shared reference to the `Ssl` object associated with this stream.
143    pub fn ssl(&self) -> &SslRef {
144        self.0.ssl()
145    }
146
147    /// Returns a shared reference to the underlying stream.
148    pub fn get_ref(&self) -> &S {
149        &self.0.get_ref().stream
150    }
151
152    /// Returns a mutable reference to the underlying stream.
153    pub fn get_mut(&mut self) -> &mut S {
154        &mut self.0.get_mut().stream
155    }
156
157    /// Returns a pinned mutable reference to the underlying stream.
158    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
159        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
160    }
161
162    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
163    where
164        F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
165    {
166        unsafe {
167            let this = self.get_unchecked_mut();
168            this.0.get_mut().context = (ctx as *mut Context<'_>).cast();
169            let r = f(&mut this.0);
170            this.0.get_mut().context = ptr::null_mut();
171            r
172        }
173    }
174}
175
176impl<S> Read for SslStream<S>
177where
178    S: Read + Write,
179{
180    fn poll_read(
181        self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183        mut buf: ReadBufCursor<'_>,
184    ) -> Poll<io::Result<()>> {
185        // SAFETY: read_uninit doesn't de-initialize the buffer and guarantees that nread bytes have
186        // been initialized.
187        self.with_context(cx, |s| unsafe {
188            match cvt(s.read_uninit(buf.as_mut()))? {
189                Poll::Ready(nread) => {
190                    buf.advance(nread);
191                    Poll::Ready(Ok(()))
192                }
193                Poll::Pending => Poll::Pending,
194            }
195        })
196    }
197}
198
199impl<S> Write for SslStream<S>
200where
201    S: Read + Write,
202{
203    fn poll_write(
204        self: Pin<&mut Self>,
205        cx: &mut Context<'_>,
206        buf: &[u8],
207    ) -> Poll<io::Result<usize>> {
208        self.with_context(cx, |s| cvt(s.write(buf)))
209    }
210
211    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
212        self.with_context(cx, |s| cvt(s.flush()))
213    }
214
215    fn poll_shutdown(
216        self: Pin<&mut Self>,
217        cx: &mut Context<'_>,
218    ) -> Poll<Result<(), std::io::Error>> {
219        self.with_context(cx, |s| {
220            match s.shutdown() {
221                Ok(_) => {}
222                Err(e) => match e.code() {
223                    ErrorCode::ZERO_RETURN => {}
224                    ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending,
225                    _ => {
226                        return Poll::Ready(Err(e
227                            .into_io_error()
228                            .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))))
229                    }
230                },
231            }
232
233            let (stream, cx) = unsafe { s.get_mut().parts() };
234            stream.poll_shutdown(cx)
235        })
236    }
237}
238
/// Converts a blocking-style I/O result into a `Poll`.
///
/// `io::ErrorKind::WouldBlock` becomes `Poll::Pending`; every other outcome (success or error)
/// is returned as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(ref e) = r {
        if e.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }
    Poll::Ready(r)
}
246
247fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
248    match r {
249        Ok(v) => Poll::Ready(Ok(v)),
250        Err(e) => match e.code() {
251            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
252            _ => Poll::Ready(Err(e)),
253        },
254    }
255}