proxy_header/
io.rs

1//! IO wrapper for proxied streams.
2//!
3//! PROXY protocol header is variable length so it is not possible to read a fixed number of bytes
4//! directly from the stream and reading it byte-by-byte can be inefficient. [`ProxiedStream`] reads
5//! enough bytes to parse the header and retains any extra bytes that may have been read.
6//!
7//! If the underlying stream is already buffered (i.e. [`std::io::BufRead`] or equivalent), it is
8//! probably a better idea to just decode the header directly instead of using [`ProxiedStream`].
9//!
10//! The wrapper is usable both with standard ([`std::io::Read`]) and Tokio streams ([`tokio::io::AsyncRead`]).
11//!
12//! ## Example (Tokio)
13//!
14//! ```no_run
15//! # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//! use tokio::io::{AsyncReadExt, AsyncWriteExt};
17//! use tokio::net::TcpListener;
18//! use proxy_header::io::ProxiedStream;
19//!
20//! let listener = TcpListener::bind("[::]:1234").await?;
21//!
22//! loop {
//!     let (socket, _) = listener.accept().await?;
24//!     tokio::spawn(async move {
25//!         // Read the proxy header first
26//!         let mut socket = ProxiedStream::create_from_tokio(socket, Default::default())
27//!             .await
28//!             .expect("failed to create proxied stream");
29//!
30//!         // We can now inspect the address
31//!         println!("proxy header: {:?}", socket.proxy_header());
32//!
//!         // Then process the protocol
34//!         let mut buf = vec![0; 1024];
35//!         loop {
36//!             let n = socket.read(&mut buf).await.unwrap();
37//!             if n == 0 {
38//!                 return;
39//!             }
40//!             socket.write_all(&buf[0..n]).await.unwrap();
41//!         }
42//!     });
43//! }
44//! # }
45//! ```
46use std::{
47    io::{self, BufRead, Read, Write},
48    mem::MaybeUninit,
49};
50
51#[cfg(any(unix, target_os = "wasi"))]
52use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
53
54#[cfg(feature = "tokio")]
55use std::{
56    pin::Pin,
57    task::{Context, Poll},
58};
59
60#[cfg(feature = "tokio")]
61use pin_project_lite::pin_project;
62
63#[cfg(feature = "tokio")]
64use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
65
66use crate::{Error, ParseConfig, ProxyHeader};
67
// Definition used when the `tokio` feature is enabled (outside of rustdoc builds).
// `pin_project!` generates the `project()` method that the async trait impls and
// `get_pin_mut` use to obtain `Pin<&mut IO>` for `io` while accessing the other fields
// as ordinary `&mut` references. Rustdoc builds use the plain, documented definition
// instead (see the `cfg(any(doc, not(feature = "tokio")))` struct).
#[cfg(all(feature = "tokio", not(doc)))]
pin_project! {
    #[derive(Debug)]
    pub struct ProxiedStream<IO> {
        // Wrapped stream; pinned because the async impls poll it through `Pin`.
        #[pin]
        io: IO,
        // Bytes read past the end of the proxy header while parsing it; reads serve these
        // before reading from `io` again.
        remaining: Vec<u8>,
        // Parsed proxy header, converted to an owned value so it does not borrow `remaining`.
        header: ProxyHeader<'static>,
    }
}
78
/// Wrapper around a stream that starts with a proxy header.
///
/// Construct it with [`create_from_std`](Self::create_from_std), `create_from_tokio`
/// (requires the `tokio` feature) or [`unproxied`](Self::unproxied). Any bytes that were
/// read past the header while parsing it are returned by subsequent reads before new data
/// is read from the underlying stream.
///
/// See [module level documentation](`crate::io`)
// Plain definition used when the `tokio` feature is disabled, and for rustdoc builds.
#[cfg(any(doc, not(feature = "tokio")))]
#[derive(Debug)]
pub struct ProxiedStream<IO> {
    // Wrapped stream.
    io: IO,
    // Bytes read past the end of the proxy header while parsing it; reads serve these
    // before reading from `io` again.
    remaining: Vec<u8>,
    // Parsed proxy header, converted to an owned value so it does not borrow `remaining`.
    header: ProxyHeader<'static>,
}
89
90impl<IO> ProxiedStream<IO> {
91    /// Create a new proxied stream from an stream that does not have a proxy header.
92    ///
93    /// This is useful if you want to use the same stream type for proxied and unproxied
94    /// connections.
95    pub fn unproxied(io: IO) -> Self {
96        Self {
97            io,
98            remaining: vec![],
99            header: Default::default(),
100        }
101    }
102
103    /// Get the proxy header.
104    pub fn proxy_header(&self) -> &ProxyHeader {
105        &self.header
106    }
107
108    /// Gets a reference to the underlying stream.
109    pub fn get_ref(&self) -> &IO {
110        &self.io
111    }
112
113    /// Gets a mutable reference to the underlying stream.
114    pub fn get_mut(&mut self) -> &mut IO {
115        &mut self.io
116    }
117
118    /// Gets a pinned mutable reference to the underlying stream.
119    #[cfg(feature = "tokio")]
120    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut IO> {
121        self.project().io
122    }
123
124    /// Consumes this wrapper, returning the underlying stream.
125    pub fn into_inner(self) -> IO {
126        self.io
127    }
128}
129
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> ProxiedStream<IO>
where
    IO: AsyncRead + Unpin,
{
    /// Reads the proxy header from an [`tokio::io::AsyncRead`] stream and returns a new [`ProxiedStream`].
    ///
    /// The stream is read repeatedly until a complete proxy header has been parsed. If the
    /// stream ends first, an [`io::Error`] of kind [`io::ErrorKind::UnexpectedEof`] is
    /// returned.
    ///
    /// Data that is not a valid proxy header yields an [`io::Error`] of kind
    /// [`io::ErrorKind::InvalidData`]. On any error the stream is dropped together with
    /// whatever bytes were already read (which usually means the connection is closed).
    pub async fn create_from_tokio(mut io: IO, config: ParseConfig) -> io::Result<Self> {
        use tokio::io::AsyncReadExt;

        // 256 bytes covers realistic headers including extensions; `read_buf` grows the
        // vector if more is needed. The theoretical maximum is 12 + 4 + 65535 = 65551 bytes.
        let mut pending: Vec<u8> = Vec::with_capacity(256);

        let (header, consumed) = loop {
            if io.read_buf(&mut pending).await? == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "end of stream",
                ));
            }

            match ProxyHeader::parse(&pending, config) {
                // Take ownership so the header no longer borrows `pending`.
                Ok((parsed, used)) => break (parsed.into_owned(), used),
                // Header incomplete: read more and try again.
                Err(Error::BufferTooShort) => {}
                Err(_) => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "invalid proxy header",
                    ))
                }
            }
        };

        // Whatever follows the header belongs to the proxied protocol.
        pending.drain(..consumed);

        Ok(Self {
            io,
            remaining: pending,
            header,
        })
    }
}
187
188impl<IO> ProxiedStream<IO>
189where
190    IO: Read,
191{
192    /// Reads the proxy header from a [`Read`] stream and returns a new `ProxiedStream`.
193    ///
194    /// Other than the fact that this method is synchronous, it is identical to [`create_from_tokio`](Self::create_from_tokio).
195    pub fn create_from_std(mut io: IO, config: ParseConfig) -> io::Result<Self> {
196        let mut bytes = Vec::with_capacity(256);
197
198        loop {
199            if bytes.capacity() == bytes.len() {
200                bytes.reserve(32);
201            }
202
203            // TODO: Get rid of this once read-buf is stabilized
204            // (https://github.com/rust-lang/rust/issues/78485)
205
206            let buf = bytes.spare_capacity_mut();
207            buf.fill(MaybeUninit::new(0));
208
209            // SAFETY: We just initialized the whole spare capacity
210            let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
211
212            let bytes_read = io.read(buf)?;
213            if bytes_read == 0 {
214                return Err(io::Error::new(
215                    io::ErrorKind::UnexpectedEof,
216                    "end of stream",
217                ));
218            }
219
220            // SAFETY: The bytes are initialized even if the reader lies about how many
221            // bytes were read.
222            unsafe {
223                assert!(bytes_read <= buf.len());
224                bytes.set_len(bytes.len() + bytes_read);
225            }
226
227            match ProxyHeader::parse(&bytes, config) {
228                Ok((ret, consumed)) => {
229                    let ret = ret.into_owned();
230                    bytes.drain(..consumed);
231
232                    return Ok(Self {
233                        io,
234                        remaining: bytes,
235                        header: ret,
236                    });
237                }
238                Err(Error::BufferTooShort) => continue,
239                Err(_) => {
240                    return Err(io::Error::new(
241                        io::ErrorKind::InvalidData,
242                        "invalid proxy header",
243                    ))
244                }
245            }
246        }
247    }
248}
249
250impl<IO> Read for ProxiedStream<IO>
251where
252    IO: Read,
253{
254    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
255        if !self.remaining.is_empty() {
256            let len = std::cmp::min(self.remaining.len(), buf.len());
257
258            buf[..len].copy_from_slice(&self.remaining[..len]);
259            self.remaining.drain(..len);
260
261            return Ok(len);
262        }
263
264        self.io.read(buf)
265    }
266}
267
268impl<IO> BufRead for ProxiedStream<IO>
269where
270    IO: BufRead,
271{
272    fn fill_buf(&mut self) -> io::Result<&[u8]> {
273        if !self.remaining.is_empty() {
274            return Ok(&self.remaining);
275        }
276        self.io.fill_buf()
277    }
278
279    fn consume(&mut self, mut amt: usize) {
280        if !self.remaining.is_empty() {
281            let len = std::cmp::min(self.remaining.len(), amt);
282            self.remaining.drain(..len);
283            amt -= len;
284        }
285        self.io.consume(amt);
286    }
287}
288
289impl<IO> Write for ProxiedStream<IO>
290where
291    IO: Write,
292{
293    #[inline]
294    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
295        self.io.write_vectored(bufs)
296    }
297
298    #[inline]
299    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
300        self.io.write_all(buf)
301    }
302
303    #[inline]
304    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> io::Result<()> {
305        self.io.write_fmt(fmt)
306    }
307
308    #[inline]
309    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
310        self.io.write(buf)
311    }
312
313    #[inline]
314    fn flush(&mut self) -> io::Result<()> {
315        self.io.flush()
316    }
317}
318
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> AsyncBufRead for ProxiedStream<IO>
where
    IO: AsyncBufRead,
{
    /// Returns the leftover bytes from header parsing while any remain, otherwise the
    /// underlying stream's buffer.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let me = self.project();

        if !me.remaining.is_empty() {
            return Poll::Ready(Ok(&me.remaining[..]));
        }

        me.io.poll_fill_buf(cx)
    }

    /// Marks `amt` bytes as consumed.
    ///
    /// Bytes are taken from the leftover buffer first; only the part of `amt` exceeding it
    /// is forwarded to the underlying stream. Forwarding the full `amt` would consume bytes
    /// of the underlying stream that were never returned to the caller.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let me = self.project();

        let from_remaining = std::cmp::min(me.remaining.len(), amt);
        me.remaining.drain(..from_remaining);

        me.io.consume(amt - from_remaining);
    }
}
346
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> AsyncRead for ProxiedStream<IO>
where
    IO: AsyncRead,
{
    /// Fills `buf` from the leftover header-parsing bytes while any remain (completing
    /// immediately), and polls the underlying stream once they are exhausted.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();

        if this.remaining.is_empty() {
            return this.io.poll_read(cx, buf);
        }

        let take = buf.remaining().min(this.remaining.len());
        buf.put_slice(&this.remaining[..take]);
        this.remaining.drain(..take);

        Poll::Ready(Ok(()))
    }
}
372
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> AsyncWrite for ProxiedStream<IO>
where
    IO: AsyncWrite,
{
    // The proxy header only concerns the read side; every write-side method is a plain
    // pass-through to the pinned underlying stream.

    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.get_pin_mut().poll_write(cx, buf)
    }

    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        self.get_pin_mut().poll_write_vectored(cx, bufs)
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        self.get_ref().is_write_vectored()
    }

    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_flush(cx)
    }

    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_shutdown(cx)
    }
}
412
413#[cfg(any(unix, target_os = "wasi"))]
414#[cfg_attr(docsrs, doc(cfg(any(unix, target_os = "wasi"))))]
415impl<IO> AsRawFd for ProxiedStream<IO>
416where
417    IO: AsRawFd,
418{
419    fn as_raw_fd(&self) -> RawFd {
420        self.io.as_raw_fd()
421    }
422}
423
424#[cfg(any(unix, target_os = "wasi"))]
425#[cfg_attr(docsrs, doc(cfg(any(unix, target_os = "wasi"))))]
426impl<IO> AsFd for ProxiedStream<IO>
427where
428    IO: AsFd,
429{
430    fn as_fd(&self) -> BorrowedFd<'_> {
431        self.io.as_fd()
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Protocol, ProxiedAddress, ProxyHeader};
    use std::{
        io::Cursor,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };

    /// Size of every test input: an encoded header followed by `0xFF` payload bytes.
    const INPUT_LEN: usize = 1024;

    /// Header shared by all tests: a stream proxied from 127.0.0.1:1234 to 8.8.4.4:5678.
    fn sample_header() -> ProxyHeader<'static> {
        ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Stream,
            source: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234)),
            destination: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 5678)),
        })
    }

    /// Encodes `header` (v2) at the start of an `INPUT_LEN` buffer and fills the rest with
    /// `0xFF`. Returns the buffer and the length of the encoded header.
    fn encode_with_payload(header: &ProxyHeader) -> ([u8; INPUT_LEN], usize) {
        let mut buf = [0; INPUT_LEN];
        let written_len = header.encode_to_slice_v2(&mut buf).unwrap();
        buf[written_len..].fill(255);
        (buf, written_len)
    }

    /// Asserts that `payload` is exactly the `0xFF` filler that followed the header.
    fn assert_payload(payload: Vec<u8>, written_len: usize) {
        assert_eq!(payload.len(), INPUT_LEN - written_len);
        assert!(payload.into_iter().all(|b| b == 255));
    }

    #[test]
    fn test_sync() {
        let header = sample_header();
        let (input, written_len) = encode_with_payload(&header);
        let mut stream = Cursor::new(&input);

        let mut proxied = ProxiedStream::create_from_std(&mut stream, Default::default()).unwrap();
        assert_eq!(proxied.proxy_header(), &header);

        let mut payload = Vec::new();
        proxied.read_to_end(&mut payload).unwrap();

        assert_payload(payload, written_len);
    }

    #[test]
    fn test_sync_buf_read() {
        let header = sample_header();
        let (input, written_len) = encode_with_payload(&header);
        let mut stream = Cursor::new(&input);

        let mut proxied = ProxiedStream::create_from_std(&mut stream, Default::default()).unwrap();
        assert_eq!(proxied.proxy_header(), &header);

        // Drain through `fill_buf`/`consume` so both the bytes left over from header parsing
        // and the inner reader's buffer are served via the `BufRead` impl.
        let mut payload = Vec::new();
        loop {
            let chunk = BufRead::fill_buf(&mut proxied).unwrap();
            if chunk.is_empty() {
                break;
            }
            let chunk_len = chunk.len();
            payload.extend_from_slice(chunk);
            BufRead::consume(&mut proxied, chunk_len);
        }

        assert_payload(payload, written_len);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_tokio() {
        use tokio::io::AsyncReadExt;

        let header = sample_header();
        let (input, written_len) = encode_with_payload(&header);
        let mut stream = Cursor::new(&input);

        let mut proxied = ProxiedStream::create_from_tokio(&mut stream, Default::default())
            .await
            .unwrap();
        assert_eq!(proxied.proxy_header(), &header);

        let mut payload = Vec::new();
        AsyncReadExt::read_to_end(&mut proxied, &mut payload)
            .await
            .unwrap();

        assert_payload(payload, written_len);
    }
}