proxy_header/
lib.rs

1//! PROXY protocol decoder and encoder
2//!
3//! This crate provides a decoder and encoder for the
4//! [PROXY protocol](https://www.haproxy.org/download/2.8/doc/proxy-protocol.txt),
5//! which is used to preserve original client connection information when proxying TCP
6//! connections for protocols that do not support this higher up in the stack.
7//!
//! The PROXY protocol is supported by many load balancers and proxies, including HAProxy,
//! the Amazon Elastic Load Balancing Classic and Network Load Balancers, and others.
10//!
//! This crate implements the entire specification except for decoding `AF_UNIX` addresses:
//! headers carrying them are validated and parsed, but the addresses themselves are not
//! decoded or exposed in the API.
14//!
15//! # Usage
16//!
17//! ## Decoding
18//!
19//! To decode a PROXY protocol header from an existing buffer, use [`ProxyHeader::parse`]:
20//! ```
21//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
22//! use proxy_header::{ProxyHeader, ParseConfig};
23//!
24//! let buf = b"PROXY TCP6 2001:db8:1::1 2001:db8:2::1 52953 25\r\nHELO example.com\r\n";
25//!
26//! let (header, len) = ProxyHeader::parse(buf, ParseConfig::default())?;
27//! match header.proxied_address() {
28//!    Some(addr) => {
29//!       println!("Proxied connection from {} to {}", addr.source, addr.destination);
30//!    }
31//!    None => {
32//!       println!("Local connection (e.g. healthcheck)");
33//!   }
34//! }
35//!
36//! println!("Client sent: {:?}", &buf[len..]);
37//! # Ok(())
38//! # }
39//! ```
40//!
41//! In addition to the address information, the PROXY protocol version 2 header can contain
42//! additional information in the form of TLV (type-length-value) fields. These can be accessed
43//! through the [`ProxyHeader::tlvs`] iterator or through convenience accessors such as [`ProxyHeader::authority`].
44//!
45//! See [`Tlv`] for more information on the different types of TLV fields.
46//!
47//! ```
48//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
49//! # use proxy_header::{ProxyHeader, ParseConfig};
50//! # let buf = b"PROXY TCP4 10.0.0.1 10.0.0.2 52953 25\r\nHELO example.com\r\n";
51//! # let (header, _) = ProxyHeader::parse(buf, ParseConfig::default()).unwrap();
52//! use proxy_header::Tlv;
53//!
54//! for tlv in header.tlvs() {
55//!     match tlv? {  // TLV can be malformed
56//!         Tlv::UniqueId(v) => {
57//!             println!("Unique connection ID: {:?}", v);
58//!         }
59//!         Tlv::Authority(v) => {
60//!             println!("Authority string (SNI): {:?}", v);
61//!         }
62//!         _ => {}
63//!     }
64//! }
65//! # Ok(())
66//! # }
67//! ```
68//!
//! See also the [`io`] module for a stream wrapper that can automatically parse the PROXY protocol header.
70//!
71//! ## Encoding
72//!
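//! To encode a PROXY protocol header, use [`ProxyHeader::encode_v1`] for version 1 headers and
//! [`ProxyHeader::encode_v2`] for version 2 headers. To write into an existing buffer instead of
//! a `Vec`, use [`ProxyHeader::encode_to_slice_v1`] or [`ProxyHeader::encode_to_slice_v2`], which
//! return the number of bytes written.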
75//!
76//! ```
77//! use proxy_header::{ProxyHeader, ProxiedAddress, Protocol};
78//!
79//! let addrs = ProxiedAddress::stream(
80//!    "[2001:db8::1:1]:51234".parse().unwrap(),
81//!    "[2001:db8::2:1]:443".parse().unwrap()
82//! );
83//! let header = ProxyHeader::with_address(addrs);
84//!
85//! let mut buf = [0u8; 1024];
86//! let len = header.encode_to_slice_v2(&mut buf).unwrap();
87//! ```
88#![cfg_attr(docsrs, feature(doc_cfg))]
89#![cfg_attr(docsrs, allow(unused_attributes))]
90
91mod util;
92mod v1;
93mod v2;
94
95pub mod io;
96
97use crate::util::{tlv, tlv_borrowed};
98use std::borrow::Cow;
99use std::fmt;
100use std::net::SocketAddr;
101
/// Transport protocol of a proxied connection
///
/// Distinguishes connection-oriented (stream) from connectionless (datagram) traffic.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Protocol {
    /// Stream protocol (TCP)
    Stream,
    /// Datagram protocol (UDP)
    Datagram,
}
110
111/// Address information from a PROXY protocol header
112#[derive(Debug, PartialEq, Eq, Clone, Hash)]
113pub struct ProxiedAddress {
114    /// Protocol type (TCP or UDP)
115    pub protocol: Protocol,
116    /// Source address (this is the address of the actual client)
117    pub source: SocketAddr,
118    /// Destination address (this is the address of the proxy)
119    pub destination: SocketAddr,
120}
121
122impl ProxiedAddress {
123    pub fn stream(source: SocketAddr, destination: SocketAddr) -> Self {
124        Self {
125            protocol: Protocol::Stream,
126            source,
127            destination,
128        }
129    }
130
131    pub fn datagram(source: SocketAddr, destination: SocketAddr) -> Self {
132        Self {
133            protocol: Protocol::Datagram,
134            source,
135            destination,
136        }
137    }
138}
139
140/// Iterator over PROXY protocol TLV (type-length-value) fields
141pub struct Tlvs<'a> {
142    buf: &'a [u8],
143}
144
145impl<'a> Iterator for Tlvs<'a> {
146    type Item = Result<Tlv<'a>, Error>;
147
148    fn next(&mut self) -> Option<Self::Item> {
149        if self.buf.is_empty() {
150            return None;
151        }
152
153        let kind = self.buf[0];
154        match self
155            .buf
156            .get(1..3)
157            .map(|s| u16::from_be_bytes(s.try_into().unwrap()) as usize)
158        {
159            Some(u) if u + 3 <= self.buf.len() => {
160                let (ret, new) = self.buf.split_at(3 + u);
161                self.buf = new;
162
163                Some(Tlv::decode(kind, &ret[3..]))
164            }
165            _ => {
166                // Malformed TLV, we cannot continue
167                self.buf = &[];
168                Some(Err(Error::Invalid))
169            }
170        }
171    }
172}
173
174/// SSL information from a PROXY protocol header
175#[derive(PartialEq, Eq, Clone)]
176pub struct SslInfo<'a>(u8, u32, Cow<'a, [u8]>);
177
178impl<'a> SslInfo<'a> {
179    /// Create a new SSL information struct
180    pub fn new(
181        client_ssl: bool,
182        client_cert_conn: bool,
183        client_cert_sess: bool,
184        verify: u32,
185    ) -> Self {
186        Self(
187            (client_ssl as u8) | (client_cert_conn as u8) << 1 | (client_cert_sess as u8) << 2,
188            verify,
189            Default::default(),
190        )
191    }
192
193    /// Client connected over SSL/TLS
194    ///
195    /// The PP2_CLIENT_SSL flag indicates that the client connected over SSL/TLS. When
196    /// this field is present, the US-ASCII string representation of the TLS version is
197    /// appended at the end of the field in the TLV format using the type
198    /// PP2_SUBTYPE_SSL_VERSION.
199    pub fn client_ssl(&self) -> bool {
200        self.0 & 0x01 != 0
201    }
202
203    /// Client certificate presented in the connection
204    ///
205    /// PP2_CLIENT_CERT_CONN indicates that the client provided a certificate over the
206    /// current connection.
207    pub fn client_cert_conn(&self) -> bool {
208        self.0 & 0x02 != 0
209    }
210
211    /// Client certificate presented in the session
212    ///
213    /// PP2_CLIENT_CERT_SESS indicates that the client provided a
214    /// certificate at least once over the TLS session this connection belongs to.
215    pub fn client_cert_sess(&self) -> bool {
216        self.0 & 0x04 != 0
217    }
218
219    /// Whether the certificate was verified
220    ///
221    /// The verify field will be zero if the client presented a certificate
222    /// and it was successfully verified, and non-zero otherwise.
223    pub fn verify(&self) -> u32 {
224        self.1
225    }
226
227    /// Iterator over all TLV (type-length-value) fields
228    pub fn tlvs(&self) -> Tlvs<'_> {
229        Tlvs { buf: &self.2 }
230    }
231
232    // Convenience accessors for common TLVs
233
234    /// SSL version
235    ///
236    /// See [`Tlv::SslVersion`] for more information.
237    pub fn version(&self) -> Option<&str> {
238        tlv_borrowed!(self, SslVersion)
239    }
240
241    /// SSL CN
242    ///
243    /// See [`Tlv::SslCn`] for more information.
244    pub fn cn(&self) -> Option<&str> {
245        tlv_borrowed!(self, SslCn)
246    }
247
248    /// SSL cipher
249    ///
250    /// See [`Tlv::SslCipher`] for more information.
251    pub fn cipher(&self) -> Option<&str> {
252        tlv_borrowed!(self, SslCipher)
253    }
254
255    /// SSL signature algorithm
256    ///
257    /// See [`Tlv::SslSigAlg`] for more information.
258    pub fn sig_alg(&self) -> Option<&str> {
259        tlv_borrowed!(self, SslSigAlg)
260    }
261
262    /// SSL key algorithm
263    ///
264    /// See [`Tlv::SslKeyAlg`] for more information.
265    pub fn key_alg(&self) -> Option<&str> {
266        tlv_borrowed!(self, SslKeyAlg)
267    }
268
269    /// Returns an owned version of this struct
270    pub fn into_owned(self) -> SslInfo<'static> {
271        SslInfo(self.0, self.1, Cow::Owned(self.2.into_owned()))
272    }
273
274    /// Appends an additional sub-TLV field
275    ///
276    /// See [`ProxyHeader::append_tlv`] for more information.
277    pub fn append_tlv(&mut self, tlv: Tlv<'_>) {
278        tlv.encode(self.2.to_mut());
279    }
280}
281
282impl fmt::Debug for SslInfo<'_> {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        f.debug_struct("Ssl")
285            .field("verify", &self.verify())
286            .field("client_ssl", &self.client_ssl())
287            .field("client_cert_conn", &self.client_cert_conn())
288            .field("client_cert_sess", &self.client_cert_sess())
289            .field("fields", &self.tlvs().collect::<Vec<_>>())
290            .finish()
291    }
292}
293
294/// Typed TLV (type-length-value) field
295///
296/// Represents the currently known types of TLV fields from the PROXY protocol specification.
297/// Non-recognized TLV fields are represented as [`Tlv::Custom`].
298#[non_exhaustive]
299#[derive(Debug, PartialEq, Eq, Clone)]
300pub enum Tlv<'a> {
301    /// Application-Layer Protocol Negotiation (ALPN). It is a byte sequence defining
302    /// the upper layer protocol in use over the connection. The most common use case
303    /// will be to pass the exact copy of the ALPN extension of the Transport Layer
304    /// Security (TLS) protocol as defined by RFC7301.
305    Alpn(Cow<'a, [u8]>),
306
307    /// Contains the host name value passed by the client, as an UTF8-encoded string.
308    /// In case of TLS being used on the client connection, this is the exact copy of
309    /// the "server_name" extension as defined by RFC3546, section 3.1, often
310    /// referred to as "SNI". There are probably other situations where an authority
311    /// can be mentioned on a connection without TLS being involved at all.
312    Authority(Cow<'a, str>),
313
314    /// The value of the type PP2_TYPE_CRC32C is a 32-bit number storing the CRC32c
315    /// checksum of the PROXY protocol header.
316    ///
317    /// When the checksum is supported by the sender after constructing the header
318    /// the sender MUST:
319    ///
320    /// - initialize the checksum field to '0's.
321    ///
322    /// - calculate the CRC32c checksum of the PROXY header as described in RFC4960,
323    /// Appendix B.
324    ///
325    /// - put the resultant value into the checksum field, and leave the rest of
326    /// the bits unchanged.
327    ///
328    /// If the checksum is provided as part of the PROXY header and the checksum
329    /// functionality is supported by the receiver, the receiver MUST:
330    ///
331    /// - store the received CRC32c checksum value aside.
332    ///
333    /// - replace the 32 bits of the checksum field in the received PROXY header with
334    /// all '0's and calculate a CRC32c checksum value of the whole PROXY header.
335    ///
336    /// - verify that the calculated CRC32c checksum is the same as the received
337    /// CRC32c checksum. If it is not, the receiver MUST treat the TCP connection
338    /// providing the header as invalid.
339    ///
340    /// The default procedure for handling an invalid TCP connection is to abort it.
341    Crc32c(u32),
342
343    /// The TLV of this type should be ignored when parsed. The value is zero or more
344    /// bytes. Can be used for data padding or alignment. Note that it can be used
345    /// to align only by 3 or more bytes because a TLV can not be smaller than that.
346    Noop(usize),
347
348    /// The value of the type PP2_TYPE_UNIQUE_ID is an opaque byte sequence of up to
349    /// 128 bytes generated by the upstream proxy that uniquely identifies the
350    /// connection.
351    ///
352    /// The unique ID can be used to easily correlate connections across multiple
353    /// layers of proxies, without needing to look up IP addresses and port numbers.
354    UniqueId(Cow<'a, [u8]>),
355
356    /// SSL (TLS) information
357    ///
358    /// See [`SslInfo`] for more information.
359    Ssl(SslInfo<'a>),
360
361    /// The type PP2_TYPE_NETNS defines the value as the US-ASCII string representation
362    /// of the namespace's name.
363    Netns(Cow<'a, str>),
364
365    // These can only appear as a sub-TLV of SslInfo
366    /// SSL/TLS version
367    SslVersion(Cow<'a, str>),
368
369    /// In all cases, the string representation (in UTF8) of the Common Name field
370    /// (OID: 2.5.4.3) of the client certificate's Distinguished Name, is appended
371    /// using the TLV format and the type PP2_SUBTYPE_SSL_CN. E.g. "example.com".
372    SslCn(Cow<'a, str>),
373
374    /// The second level TLV PP2_SUBTYPE_SSL_CIPHER provides the US-ASCII string name
375    /// of the used cipher, for example "ECDHE-RSA-AES128-GCM-SHA256".
376    SslCipher(Cow<'a, str>),
377
378    /// The second level TLV PP2_SUBTYPE_SSL_SIG_ALG provides the US-ASCII string name
379    /// of the algorithm used to sign the certificate presented by the frontend when
380    /// the incoming connection was made over an SSL/TLS transport layer, for example
381    /// "SHA256".
382    SslSigAlg(Cow<'a, str>),
383
384    /// The second level TLV PP2_SUBTYPE_SSL_KEY_ALG provides the US-ASCII string name
385    /// of the algorithm used to generate the key of the certificate presented by the
386    /// frontend when the incoming connection was made over an SSL/TLS transport layer,
387    /// for example "RSA2048".
388    SslKeyAlg(Cow<'a, str>),
389
390    /// Unrecognized or custom TLV field
391    Custom(u8, Cow<'a, [u8]>),
392}
393
394impl<'a> Tlv<'a> {
395    /// Decode a TLV field from the given buffer
396    ///
397    /// Returns an error if the field is malformed.
398    pub fn decode(kind: u8, data: &'a [u8]) -> Result<Tlv<'a>, Error> {
399        use std::str::from_utf8;
400        use Tlv::*;
401
402        match kind {
403            0x01 => Ok(Alpn(data.into())),
404            0x02 => Ok(Authority(
405                from_utf8(data).map_err(|_| Error::Invalid)?.into(),
406            )),
407            0x03 => Ok(Crc32c(u32::from_be_bytes(
408                data.try_into().map_err(|_| Error::Invalid)?,
409            ))),
410            0x04 => Ok(Noop(data.len())),
411            0x05 => Ok(UniqueId(data.into())),
412            0x20 => Ok(Ssl(SslInfo(
413                *data.first().ok_or(Error::Invalid)?,
414                u32::from_be_bytes(
415                    data.get(1..5)
416                        .ok_or(Error::Invalid)?
417                        .try_into()
418                        .map_err(|_| Error::Invalid)?,
419                ),
420                data.get(5..).ok_or(Error::Invalid)?.into(),
421            ))),
422            0x21 => Ok(SslVersion(
423                from_utf8(data).map_err(|_| Error::Invalid)?.into(),
424            )),
425            0x22 => Ok(SslCn(from_utf8(data).map_err(|_| Error::Invalid)?.into())),
426            0x23 => Ok(SslCipher(
427                from_utf8(data).map_err(|_| Error::Invalid)?.into(),
428            )),
429            0x24 => Ok(SslSigAlg(
430                from_utf8(data).map_err(|_| Error::Invalid)?.into(),
431            )),
432            0x25 => Ok(SslKeyAlg(
433                from_utf8(data).map_err(|_| Error::Invalid)?.into(),
434            )),
435            0x30 => Ok(Netns(from_utf8(data).map_err(|_| Error::Invalid)?.into())),
436            a => Ok(Custom(a, data.into())),
437        }
438    }
439
440    /// Returns the raw kind of this TLV field
441    pub fn kind(&self) -> u8 {
442        match self {
443            Tlv::Alpn(_) => 0x01,
444            Tlv::Authority(_) => 0x02,
445            Tlv::Crc32c(_) => 0x03,
446            Tlv::Noop(_) => 0x04,
447            Tlv::UniqueId(_) => 0x05,
448            Tlv::Ssl(_) => 0x20,
449            Tlv::Netns(_) => 0x30,
450            Tlv::SslVersion(_) => 0x21,
451            Tlv::SslCn(_) => 0x22,
452            Tlv::SslCipher(_) => 0x23,
453            Tlv::SslSigAlg(_) => 0x24,
454            Tlv::SslKeyAlg(_) => 0x25,
455            Tlv::Custom(a, _) => *a,
456        }
457    }
458
459    /// Encode this TLV field into the given buffer
460    ///
461    /// # Panics
462    /// Panics if the field is too long for its length to fit in a [`u16`].
463    pub fn encode(&self, buf: &mut Vec<u8>) {
464        let initial = buf.len();
465
466        buf.extend_from_slice(&[self.kind(), 0, 0]);
467        match self {
468            Tlv::Alpn(v) => buf.extend_from_slice(v),
469            Tlv::Authority(v) => buf.extend_from_slice(v.as_bytes()),
470            Tlv::Crc32c(v) => buf.extend_from_slice(&v.to_be_bytes()),
471            Tlv::Noop(len) => {
472                buf.resize(buf.len() + len, 0);
473            }
474            Tlv::UniqueId(v) => buf.extend_from_slice(v),
475            Tlv::Ssl(v) => {
476                buf.push(v.0);
477                buf.extend_from_slice(&v.1.to_be_bytes());
478                buf.extend_from_slice(&v.2);
479            }
480            Tlv::Netns(v) => buf.extend_from_slice(v.as_bytes()),
481            Tlv::SslVersion(v) => buf.extend_from_slice(v.as_bytes()),
482            Tlv::SslCn(v) => buf.extend_from_slice(v.as_bytes()),
483            Tlv::SslCipher(v) => buf.extend_from_slice(v.as_bytes()),
484            Tlv::SslSigAlg(v) => buf.extend_from_slice(v.as_bytes()),
485            Tlv::SslKeyAlg(v) => buf.extend_from_slice(v.as_bytes()),
486            Tlv::Custom(_, v) => buf.extend_from_slice(v),
487        }
488
489        let len = buf.len() - initial - 3;
490        if len > u16::MAX as usize {
491            panic!("TLV field too long");
492        }
493
494        buf[initial + 1] = ((len >> 8) & 0xff) as u8;
495        buf[initial + 2] = (len & 0xff) as u8;
496    }
497
498    /// Returns an owned version of this struct
499    pub fn into_owned(self) -> Tlv<'static> {
500        match self {
501            Tlv::Alpn(v) => Tlv::Alpn(Cow::Owned(v.into_owned())),
502            Tlv::Authority(v) => Tlv::Authority(Cow::Owned(v.into_owned())),
503            Tlv::Crc32c(v) => Tlv::Crc32c(v),
504            Tlv::Noop(v) => Tlv::Noop(v),
505            Tlv::UniqueId(v) => Tlv::UniqueId(Cow::Owned(v.into_owned())),
506            Tlv::Ssl(v) => Tlv::Ssl(v.into_owned()),
507            Tlv::Netns(v) => Tlv::Netns(Cow::Owned(v.into_owned())),
508            Tlv::SslVersion(v) => Tlv::SslVersion(Cow::Owned(v.into_owned())),
509            Tlv::SslCn(v) => Tlv::SslCn(Cow::Owned(v.into_owned())),
510            Tlv::SslCipher(v) => Tlv::SslCipher(Cow::Owned(v.into_owned())),
511            Tlv::SslSigAlg(v) => Tlv::SslSigAlg(Cow::Owned(v.into_owned())),
512            Tlv::SslKeyAlg(v) => Tlv::SslKeyAlg(Cow::Owned(v.into_owned())),
513            Tlv::Custom(a, v) => Tlv::Custom(a, Cow::Owned(v.into_owned())),
514        }
515    }
516}
517
/// Options controlling which PROXY protocol headers [`ProxyHeader::parse`] accepts
///
/// The [`Default`] configuration accepts both protocol versions and keeps TLV fields.
#[derive(Clone, Copy, Debug)]
pub struct ParseConfig {
    /// Keep the TLV (type-length-value) section in the parsed header
    ///
    /// TLV fields are only decoded lazily when accessed, but turning this off can
    /// still save an allocation.
    pub include_tlvs: bool,

    /// Accept version 1 (text) headers
    pub allow_v1: bool,

    /// Accept version 2 (binary) headers
    pub allow_v2: bool,
}

impl Default for ParseConfig {
    /// Permissive configuration: both versions accepted, TLV fields kept.
    fn default() -> Self {
        ParseConfig {
            allow_v1: true,
            allow_v2: true,
            include_tlvs: true,
        }
    }
}
543
544/// A PROXY protocol header
545#[derive(Default, PartialEq, Eq, Clone)]
546pub struct ProxyHeader<'a>(Option<ProxiedAddress>, Cow<'a, [u8]>);
547
548impl<'a> ProxyHeader<'a> {
549    /// Create a new PROXY protocol header (local mode)
550    pub fn with_local() -> Self {
551        Default::default()
552    }
553
554    /// Create a new PROXY protocol header (proxied mode)
555    pub fn with_address(addr: ProxiedAddress) -> Self {
556        Self(Some(addr), Cow::Owned(Vec::new()))
557    }
558
559    /// Create a new PROXY protocol header with the given TLV fields
560    ///
561    /// ```
562    /// use proxy_header::{ProxyHeader, ProxiedAddress, Tlv, Protocol, SslInfo};
563    ///
564    /// let addrs = ProxiedAddress::stream(
565    ///     "[2001:db8::1:1]:51234".parse().unwrap(),
566    ///     "[2001:db8::2:1]:443".parse().unwrap()
567    /// );
568    /// let header = ProxyHeader::with_tlvs(
569    ///    Some(addrs), [
570    ///         Tlv::Authority("example.com".into()),
571    ///         Tlv::Ssl(SslInfo::new(true, false, false, 0)),
572    ///      ]
573    /// );
574    ///
575    /// println!("{:?}", header);
576    /// ```
577    pub fn with_tlvs<'b>(
578        addr: Option<ProxiedAddress>,
579        tlvs: impl IntoIterator<Item = Tlv<'b>>,
580    ) -> Self {
581        let mut buf = Vec::with_capacity(64);
582        for tlv in tlvs {
583            tlv.encode(&mut buf);
584        }
585
586        Self(addr, Cow::Owned(buf))
587    }
588
589    /// Attempt to parse a PROXY protocol header from the given buffer
590    ///
591    /// Returns the parsed header and the number of bytes consumed from the buffer. If the header
592    /// is incomplete, returns [`Error::BufferTooShort`] so more data can be read from the socket.
593    ///
594    /// If the header is malformed or unsupported, returns [`Error::Invalid`].
595    ///
596    /// This function will borrow the buffer for the lifetime of the returned header. If
597    /// you need to keep the header around for longer than the buffer, use [`ProxyHeader::into_owned`].
598    pub fn parse(buf: &'a [u8], config: ParseConfig) -> Result<(Self, usize), Error> {
599        match buf.first() {
600            Some(b'P') if config.allow_v1 => v1::decode(buf),
601            Some(b'\r') if config.allow_v2 => v2::decode(buf, config),
602            None => Err(Error::BufferTooShort),
603            _ => Err(Error::Invalid),
604        }
605    }
606
607    /// Proxied address information
608    ///
609    /// If `None`, this indicates so-called "local" mode, where the connection is not proxied.
610    /// This is usually the case when the connection is initiated by the proxy itself, e.g. for
611    /// health checks.
612    pub fn proxied_address(&self) -> Option<&ProxiedAddress> {
613        self.0.as_ref()
614    }
615
616    /// Iterator that yields all extension TLV (type-length-value) fields present in the header
617    ///
618    /// See [`Tlv`] for more information on the different types of TLV fields.
619    pub fn tlvs(&self) -> Tlvs<'_> {
620        Tlvs { buf: &self.1 }
621    }
622
623    // Convenience accessors for common fields
624
625    /// Raw ALPN extension data
626    ///
627    /// See [`Tlv::Alpn`] for more information.
628    pub fn alpn(&self) -> Option<&[u8]> {
629        tlv_borrowed!(self, Alpn)
630    }
631
632    /// Authority - typically the hostname of the client (SNI)
633    ///
634    /// See [`Tlv::Authority`] for more information.
635    pub fn authority(&self) -> Option<&str> {
636        tlv_borrowed!(self, Authority)
637    }
638
639    /// CRC32c checksum of the address information
640    ///
641    /// See [`Tlv::Crc32c`] for more information.
642    pub fn crc32c(&self) -> Option<u32> {
643        tlv!(self, Crc32c)
644    }
645
646    /// Unique ID of the connection
647    ///
648    /// See [`Tlv::UniqueId`] for more information.
649    pub fn unique_id(&self) -> Option<&[u8]> {
650        tlv_borrowed!(self, UniqueId)
651    }
652
653    /// SSL information
654    ///
655    /// See [`Tlv::Ssl`] for more information.
656    pub fn ssl(&self) -> Option<SslInfo<'_>> {
657        tlv!(self, Ssl)
658    }
659
660    /// Network namespace
661    ///
662    /// See [`Tlv::Netns`] for more information.
663    pub fn netns(&self) -> Option<&str> {
664        tlv_borrowed!(self, Netns)
665    }
666
667    /// Returns an owned version of this struct
668    pub fn into_owned(self) -> ProxyHeader<'static> {
669        ProxyHeader(self.0, Cow::Owned(self.1.into_owned()))
670    }
671
672    /// Appends an additional TLV field
673    pub fn append_tlv(&mut self, tlv: Tlv<'_>) {
674        tlv.encode(self.1.to_mut());
675    }
676
677    /// Encode this PROXY protocol header into a [`Vec`] in version 1 format.
678    ///
679    /// Returns [`Error::V1UnsupportedTlv`] if the header contains any TLV fields and
680    /// [`Error::V1UnsupportedProtocol`] if the header contains a non-TCP protocol, as
681    /// version 1 PROXY protocol does not support either of these.
682    pub fn encode_v1(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
683        v1::encode(self, buf)
684    }
685
686    /// Encode this PROXY protocol header into a [`Vec`] in version 2 format.
687    pub fn encode_v2(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
688        v2::encode(self, buf)
689    }
690
691    /// Encode this PROXY protocol header into an existing buffer in version 1 format.
692    ///
693    /// If the buffer is too small to contain the entire header, returns [`Error::BufferTooShort`].
694    ///
695    /// See [`ProxyHeader::encode_v1`] for more information.
696    pub fn encode_to_slice_v1(&self, buf: &mut [u8]) -> Result<usize, Error> {
697        let mut cursor = std::io::Cursor::new(buf);
698        v1::encode(self, &mut cursor)?;
699
700        Ok(cursor.position() as usize)
701    }
702
703    /// Encode this PROXY protocol header into an existing buffer in version 2 format.
704    ///
705    /// If the buffer is too small to contain the entire header, returns [`Error::BufferTooShort`].
706    ///
707    /// See [`ProxyHeader::encode_v2`] for more information.
708    pub fn encode_to_slice_v2(&self, buf: &mut [u8]) -> Result<usize, Error> {
709        let mut cursor = std::io::Cursor::new(buf);
710        v2::encode(self, &mut cursor)?;
711
712        Ok(cursor.position() as usize)
713    }
714}
715
716impl fmt::Debug for ProxyHeader<'_> {
717    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
718        f.debug_struct("ProxyHeader")
719            .field("address_info", &self.proxied_address())
720            .field("fields", &self.tlvs().collect::<Vec<_>>())
721            .finish()
722    }
723}
724
/// Errors returned when parsing or encoding PROXY protocol headers
///
/// The enum carries no data, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Error {
    /// The buffer is too short to contain a complete PROXY protocol header
    BufferTooShort,
    /// The PROXY protocol header is malformed
    Invalid,
    /// The source and destination address families do not match
    AddressFamilyMismatch,
    /// The total size of the PROXY protocol header would exceed the maximum allowed size
    HeaderTooBig,
    /// The PROXY protocol header contains a TLV field, which is not supported in version 1
    V1UnsupportedTlv,
    /// The PROXY protocol header contains a non-TCP protocol, which is not supported in version 1
    V1UnsupportedProtocol,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::BufferTooShort => "buffer too short",
            Error::Invalid => "invalid PROXY header",
            Error::AddressFamilyMismatch => {
                "source and destination address families do not match"
            }
            Error::HeaderTooBig => "PROXY header too big",
            Error::V1UnsupportedTlv => "TLV fields are not supported in v1 header",
            Error::V1UnsupportedProtocol => {
                "protocols other than TCP are not supported in v1 header"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
760
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::*;

    const V1_UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";

    const V1_TCPV4: &[u8] = b"PROXY TCP4 127.0.0.1 192.168.0.1 12345 443\r\n";
    const V1_TCPV6: &[u8] = b"PROXY TCP6 2001:db8::1 ::1 12345 443\r\n";

    // LOCAL command carrying a CRC32c TLV and an SSL TLV with no flags or sub-TLVs.
    const V2_LOCAL: &[u8] =
        b"\r\n\r\n\0\r\nQUIT\n \0\0\x0f\x03\0\x04\x88\x9d\xa1\xdf \0\x05\0\0\0\0\0";

    const V2_TCPV4: &[u8] = &[
        13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 12, 127, 0, 0, 1, 192, 168, 0, 1,
        48, 57, 1, 187,
    ];
    const V2_TCPV6: &[u8] = &[
        13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 33, 0, 36, 32, 1, 13, 184, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 48, 57, 1, 187,
    ];
    // TCP over IPv4 with CRC32c, unique ID and SSL (five sub-TLVs) fields.
    const V2_TCPV4_TLV: &[u8] = &[
        13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 104, 127, 0, 0, 1, 192, 168, 0,
        1, 48, 57, 1, 187, 3, 0, 4, 211, 153, 216, 216, 5, 0, 4, 49, 50, 51, 52, 32, 0, 75, 7, 0,
        0, 0, 0, 33, 0, 7, 84, 76, 83, 118, 49, 46, 51, 34, 0, 9, 108, 111, 99, 97, 108, 104, 111,
        115, 116, 37, 0, 7, 82, 83, 65, 52, 48, 57, 54, 36, 0, 10, 82, 83, 65, 45, 83, 72, 65, 50,
        53, 54, 35, 0, 22, 84, 76, 83, 95, 65, 69, 83, 95, 50, 53, 54, 95, 71, 67, 77, 95, 83, 72,
        65, 51, 56, 52,
    ];

    /// Addresses used by the encoding round-trip tests
    fn sample_addresses() -> Vec<ProxiedAddress> {
        vec![
            ProxiedAddress::stream(
                "127.0.0.1:12345".parse().unwrap(),
                "192.168.0.1:443".parse().unwrap(),
            ),
            ProxiedAddress::stream(
                "[2001:db8::1]:12345".parse().unwrap(),
                "[::1]:443".parse().unwrap(),
            ),
            ProxiedAddress::datagram(
                "10.0.0.1:5353".parse().unwrap(),
                "10.0.0.2:53".parse().unwrap(),
            ),
        ]
    }

    #[test]
    fn test_parse_proxy_header_too_short() {
        for case in [
            V1_TCPV4,
            V1_TCPV6,
            V1_UNKNOWN,
            V2_TCPV4,
            V2_TCPV6,
            V2_TCPV4_TLV,
            V2_LOCAL,
        ]
        .iter()
        {
            for i in 0..case.len() {
                assert_eq!(
                    ProxyHeader::parse(&case[..i], Default::default()).err(),
                    Some(Error::BufferTooShort),
                    "prefix of {} bytes should be reported as incomplete",
                    i
                );
            }

            let (_, consumed) =
                ProxyHeader::parse(case, Default::default()).expect("complete header should parse");
            assert_eq!(consumed, case.len());
        }
    }

    #[test]
    fn test_parse_proxy_header_v1_unterminated() {
        let line = b"PROXY TCP4 THISISSTORYALLABOUTHOWMYLIFEGOTFLIPPEDTURNEDUPSIDEDOWNANDIDLIKETOTAKEAMINUTEJUSTSITRIGHTTHEREANDILLTELLYOUHOWIGOTTHEPRINCEOFAIR\r\n";
        assert!(matches!(
            ProxyHeader::parse(line, Default::default()),
            Err(Error::Invalid)
        ));
    }

    #[test]
    fn test_parse_proxy_header_v1() {
        let (res, consumed) = ProxyHeader::parse(V1_TCPV4, Default::default()).unwrap();
        assert_eq!(consumed, V1_TCPV4.len());
        assert_eq!(
            res.0,
            Some(ProxiedAddress {
                protocol: Protocol::Stream,
                source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
                destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
            })
        );
        assert!(res.1.is_empty());

        let (res, consumed) = ProxyHeader::parse(V1_TCPV6, Default::default()).unwrap();

        assert_eq!(consumed, V1_TCPV6.len());
        assert_eq!(
            res.0,
            Some(ProxiedAddress {
                protocol: Protocol::Stream,
                source: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                    12345
                ),
                destination: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    443
                ),
            })
        );
        assert!(res.1.is_empty());
    }

    #[test]
    fn test_parse_proxy_header_v2() {
        let (res, consumed) = ProxyHeader::parse(V2_LOCAL, Default::default()).unwrap();
        assert_eq!(consumed, V2_LOCAL.len());
        assert_eq!(res.0, None);
        assert_eq!(res.crc32c(), Some(0x889d_a1df));

        let (res, consumed) = ProxyHeader::parse(V2_TCPV4, Default::default()).unwrap();
        assert_eq!(consumed, V2_TCPV4.len());
        assert_eq!(
            res.0,
            Some(ProxiedAddress {
                protocol: Protocol::Stream,
                source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
                destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
            })
        );

        let (res, consumed) = ProxyHeader::parse(V2_TCPV6, Default::default()).unwrap();
        assert_eq!(consumed, V2_TCPV6.len());
        assert_eq!(
            res.0,
            Some(ProxiedAddress {
                protocol: Protocol::Stream,
                source: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                    12345
                ),
                destination: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    443
                ),
            })
        );
    }

    #[test]
    fn test_parse_proxy_header_with_tlvs() {
        let (res, _) = ProxyHeader::parse(
            V2_TCPV4_TLV,
            ParseConfig {
                include_tlvs: true,
                ..Default::default()
            },
        )
        .unwrap();

        use Tlv::*;

        let mut fields = res.tlvs();

        assert_eq!(fields.next(), Some(Ok(Crc32c(0xd399d8d8))));
        assert_eq!(fields.next(), Some(Ok(UniqueId(b"1234"[..].into()))));

        let ssl = fields.next().unwrap().unwrap();
        let ssl = match ssl {
            Tlv::Ssl(ssl) => ssl,
            _ => panic!("expected SSL TLV"),
        };

        assert_eq!(ssl.verify(), 0);
        assert!(ssl.client_ssl());
        assert!(ssl.client_cert_conn());
        assert!(ssl.client_cert_sess());

        let mut f = ssl.tlvs();

        assert_eq!(f.next(), Some(Ok(SslVersion("TLSv1.3".into()))));
        assert_eq!(f.next(), Some(Ok(SslCn("localhost".into()))));
        assert_eq!(f.next(), Some(Ok(SslKeyAlg("RSA4096".into()))));
        assert_eq!(f.next(), Some(Ok(SslSigAlg("RSA-SHA256".into()))));
        assert_eq!(
            f.next(),
            Some(Ok(SslCipher("TLS_AES_256_GCM_SHA384".into())))
        );
        assert!(f.next().is_none());

        assert!(fields.next().is_none());
    }

    #[test]
    fn test_tlvs_malformed_framing_stops_iteration() {
        // Declared length (5) runs past the end of the buffer.
        let mut it = Tlvs {
            buf: &[0x02, 0x00, 0x05, b'a'],
        };
        assert_eq!(it.next(), Some(Err(Error::Invalid)));
        assert_eq!(it.next(), None);

        // Truncated type/length header.
        let mut it = Tlvs { buf: &[0x04, 0x00] };
        assert_eq!(it.next(), Some(Err(Error::Invalid)));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn test_tlvs_invalid_value_continues_iteration() {
        // An Authority field with invalid UTF-8, followed by an empty Noop field.
        let mut it = Tlvs {
            buf: &[0x02, 0x00, 0x01, 0xff, 0x04, 0x00, 0x00],
        };
        assert_eq!(it.next(), Some(Err(Error::Invalid)));
        assert_eq!(it.next(), Some(Ok(Tlv::Noop(0))));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn test_encode_v1_roundtrip() {
        let mut headers = vec![ProxyHeader::with_local()];
        headers.extend(
            sample_addresses()
                .into_iter()
                .filter(|addr| addr.protocol == Protocol::Stream)
                .map(ProxyHeader::with_address),
        );

        for header in headers {
            let mut buf = Vec::new();
            header.encode_v1(&mut buf).unwrap();

            let (parsed, consumed) = ProxyHeader::parse(&buf, Default::default()).unwrap();
            assert_eq!(consumed, buf.len());
            assert_eq!(parsed.proxied_address(), header.proxied_address());
        }
    }

    #[test]
    fn test_encode_v1_rejects_unsupported() {
        let tcp = ProxiedAddress::stream(
            "127.0.0.1:1234".parse().unwrap(),
            "127.0.0.1:443".parse().unwrap(),
        );
        let with_tlv = ProxyHeader::with_tlvs(Some(tcp), [Tlv::Authority("example.com".into())]);
        assert_eq!(
            with_tlv.encode_v1(&mut Vec::new()),
            Err(Error::V1UnsupportedTlv)
        );

        let udp = ProxyHeader::with_address(ProxiedAddress::datagram(
            "127.0.0.1:1234".parse().unwrap(),
            "127.0.0.1:53".parse().unwrap(),
        ));
        assert_eq!(
            udp.encode_v1(&mut Vec::new()),
            Err(Error::V1UnsupportedProtocol)
        );
    }

    #[test]
    fn test_encode_v2_roundtrip() {
        let mut headers = vec![ProxyHeader::with_local()];
        headers.extend(sample_addresses().into_iter().map(ProxyHeader::with_address));

        for header in headers {
            let mut buf = Vec::new();
            header.encode_v2(&mut buf).unwrap();

            let (parsed, consumed) = ProxyHeader::parse(&buf, Default::default()).unwrap();
            assert_eq!(consumed, buf.len());
            assert_eq!(parsed.proxied_address(), header.proxied_address());
        }
    }

    #[test]
    fn test_encode_v2_with_tlvs_roundtrip() {
        let addr = ProxiedAddress::stream(
            "[2001:db8::1:1]:51234".parse().unwrap(),
            "[2001:db8::2:1]:443".parse().unwrap(),
        );
        let header = ProxyHeader::with_tlvs(
            Some(addr),
            [
                Tlv::Authority("example.com".into()),
                Tlv::UniqueId(b"abcd"[..].into()),
            ],
        );

        let mut buf = Vec::new();
        header.encode_v2(&mut buf).unwrap();

        let (parsed, consumed) = ProxyHeader::parse(&buf, Default::default()).unwrap();
        assert_eq!(consumed, buf.len());
        assert_eq!(parsed.proxied_address(), header.proxied_address());
        assert_eq!(parsed.authority(), Some("example.com"));
        assert_eq!(parsed.unique_id(), Some(&b"abcd"[..]));
    }

    #[test]
    fn test_append_tlv() {
        let mut header = ProxyHeader::with_local();
        header.append_tlv(Tlv::Netns("blue".into()));

        assert_eq!(header.netns(), Some("blue"));
        assert_eq!(header.tlvs().count(), 1);
    }

    #[test]
    fn test_ssl_info_roundtrip() {
        let mut ssl = SslInfo::new(true, false, true, 7);
        ssl.append_tlv(Tlv::SslVersion("TLSv1.3".into()));
        ssl.append_tlv(Tlv::SslCn("example.com".into()));

        assert!(ssl.client_ssl());
        assert!(!ssl.client_cert_conn());
        assert!(ssl.client_cert_sess());
        assert_eq!(ssl.verify(), 7);
        assert_eq!(ssl.version(), Some("TLSv1.3"));
        assert_eq!(ssl.cn(), Some("example.com"));

        let mut buf = Vec::new();
        Tlv::Ssl(ssl.clone()).encode(&mut buf);

        let mut fields = Tlvs { buf: &buf };
        assert_eq!(fields.next(), Some(Ok(Tlv::Ssl(ssl))));
        assert!(fields.next().is_none());
    }
}