// proxy_header/v2.rs

1use std::borrow::Cow;
2use std::io::Write;
3use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
4
5use crate::util::AddressFamily;
6use crate::ParseConfig;
7use crate::{
8    Error::{self, *},
9    Protocol, ProxiedAddress, ProxyHeader,
10};
11
/// The 12-byte signature that opens every version 2 PROXY header
/// (`\r\n\r\n\0\r\nQUIT\n`).
const GREETING: &[u8] = &[
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
/// Size of the AF_UNIX address block: source and destination paths of 108 bytes each.
const AF_UNIX_ADDRS_LEN: usize = 2 * 108;
14
15fn parse_addrs<T: AddressFamily>(
16    buf: &[u8],
17    pos: &mut usize,
18    rest: &mut usize,
19    protocol: Protocol,
20) -> Result<ProxiedAddress, Error> {
21    if buf.len() < *pos + T::BYTES * 2 + 4 {
22        return Err(BufferTooShort);
23    }
24    if *rest < T::BYTES * 2 + 4 {
25        return Err(Invalid);
26    }
27
28    let ret = ProxiedAddress {
29        protocol,
30        source: SocketAddr::new(
31            T::from_slice(&buf[*pos..*pos + T::BYTES]).to_ip_addr(),
32            u16::from_be_bytes([buf[*pos + T::BYTES * 2], buf[*pos + T::BYTES * 2 + 1]]),
33        ),
34        destination: SocketAddr::new(
35            T::from_slice(&buf[*pos + T::BYTES..*pos + T::BYTES * 2]).to_ip_addr(),
36            u16::from_be_bytes([buf[*pos + T::BYTES * 2 + 2], buf[*pos + T::BYTES * 2 + 3]]),
37        ),
38    };
39
40    *rest -= T::BYTES * 2 + 4;
41    *pos += T::BYTES * 2 + 4;
42
43    Ok(ret)
44}
45
46/// Decode a version 2 PROXY header from a buffer.
47///
48/// Returns the decoded header and the number of bytes consumed from the buffer.
49pub fn decode(buf: &[u8], config: ParseConfig) -> Result<(ProxyHeader, usize), Error> {
50    let mut pos = 0;
51
52    if buf.len() < 4 + GREETING.len() {
53        return Err(BufferTooShort);
54    }
55    if !buf.starts_with(GREETING) {
56        return Err(Invalid);
57    }
58    pos += GREETING.len();
59
60    let is_local = match buf[pos] {
61        0x20 => true,
62        0x21 => false,
63        _ => return Err(Invalid),
64    };
65    let protocol = buf[pos + 1];
66    let mut rest = u16::from_be_bytes([buf[pos + 2], buf[pos + 3]]) as usize;
67    pos += 4;
68
69    if buf.len() < pos + rest {
70        return Err(BufferTooShort);
71    }
72
73    use Protocol::{Datagram, Stream};
74    let addr_info = match protocol {
75        0x00 => None,
76        0x11 => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos, &mut rest, Stream)?),
77        0x12 => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos, &mut rest, Datagram)?),
78        0x21 => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos, &mut rest, Stream)?),
79        0x22 => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos, &mut rest, Datagram)?),
80        0x31 | 0x32 => {
81            // AF_UNIX - we do not parse this, but don't reject it either in case
82            // someone needs the TLVs
83
84            if rest < AF_UNIX_ADDRS_LEN {
85                return Err(Invalid);
86            }
87            rest -= AF_UNIX_ADDRS_LEN;
88            pos += AF_UNIX_ADDRS_LEN;
89
90            None
91        }
92        _ => return Err(Invalid),
93    };
94
95    let tlv_data = if config.include_tlvs {
96        Cow::Borrowed(&buf[pos..pos + rest])
97    } else {
98        Default::default()
99    };
100
101    pos += rest;
102
103    let header = if is_local {
104        ProxyHeader(None, tlv_data)
105    } else {
106        ProxyHeader(addr_info, tlv_data)
107    };
108
109    Ok((header, pos))
110}
111
112pub fn encode<W: Write>(header: &ProxyHeader, buf: &mut W) -> Result<(), Error> {
113    buf.write_all(GREETING).map_err(|_| BufferTooShort)?;
114
115    match &header.0 {
116        Some(ProxiedAddress {
117            protocol,
118            source: SocketAddr::V4(src),
119            destination: SocketAddr::V4(dest),
120        }) => {
121            buf.write_all(b"\x21").map_err(|_| BufferTooShort)?;
122            match protocol {
123                Protocol::Stream => buf.write_all(b"\x11").map_err(|_| BufferTooShort)?,
124                Protocol::Datagram => buf.write_all(b"\x12").map_err(|_| BufferTooShort)?,
125            }
126
127            let len: u16 = (4 + 4 + 2 + 2 + header.1.len())
128                .try_into()
129                .map_err(|_| HeaderTooBig)?;
130            buf.write_all(&len.to_be_bytes())
131                .map_err(|_| BufferTooShort)?;
132
133            buf.write_all(&src.ip().octets())
134                .map_err(|_| BufferTooShort)?;
135            buf.write_all(&dest.ip().octets())
136                .map_err(|_| BufferTooShort)?;
137            buf.write_all(&src.port().to_be_bytes())
138                .map_err(|_| BufferTooShort)?;
139            buf.write_all(&dest.port().to_be_bytes())
140                .map_err(|_| BufferTooShort)?;
141        }
142        Some(ProxiedAddress {
143            protocol,
144            source: SocketAddr::V6(src),
145            destination: SocketAddr::V6(dest),
146        }) => {
147            buf.write_all(b"\x21").map_err(|_| BufferTooShort)?;
148            match protocol {
149                Protocol::Stream => buf.write_all(b"\x21").map_err(|_| BufferTooShort)?,
150                Protocol::Datagram => buf.write_all(b"\x22").map_err(|_| BufferTooShort)?,
151            }
152
153            let len: u16 = (16 + 16 + 2 + 2 + header.1.len())
154                .try_into()
155                .map_err(|_| HeaderTooBig)?;
156            buf.write_all(&len.to_be_bytes())
157                .map_err(|_| BufferTooShort)?;
158
159            buf.write_all(&src.ip().octets())
160                .map_err(|_| BufferTooShort)?;
161            buf.write_all(&dest.ip().octets())
162                .map_err(|_| BufferTooShort)?;
163            buf.write_all(&src.port().to_be_bytes())
164                .map_err(|_| BufferTooShort)?;
165            buf.write_all(&dest.port().to_be_bytes())
166                .map_err(|_| BufferTooShort)?;
167        }
168        None => {
169            buf.write_all(b"\x20\x00").map_err(|_| BufferTooShort)?;
170
171            let len: u16 = header.1.len().try_into().map_err(|_| HeaderTooBig)?;
172            buf.write_all(&len.to_be_bytes())
173                .map_err(|_| BufferTooShort)?;
174        }
175        _ => return Err(AddressFamilyMismatch),
176    }
177
178    buf.write_all(&header.1).map_err(|_| BufferTooShort)?;
179    Ok(())
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddrV4, SocketAddrV6};

    /// Encodes `header` as v2 into a scratch buffer and, when `expected` is given,
    /// compares the emitted bytes with it. It then decodes with the default config
    /// and checks that both the header and the consumed length round-trip.
    fn assert_roundtrip(header: &ProxyHeader, expected: Option<&[u8]>) {
        let mut wire = [0u8; 1024];
        let written = header.encode_to_slice_v2(&mut wire).unwrap();

        if let Some(bytes) = expected {
            assert_eq!(&wire[..written], bytes);
        }

        let (decoded, consumed) = decode(&wire, ParseConfig::default()).unwrap();
        assert_eq!(&decoded, header);
        assert_eq!(consumed, written);
    }

    #[test]
    fn test_encode_local() {
        assert_roundtrip(
            &ProxyHeader::with_local(),
            Some(&b"\r\n\r\n\x00\r\nQUIT\n\x20\x00\x00\x00"[..]),
        );
    }

    #[test]
    fn test_encode_ipv4() {
        let localhost = Ipv4Addr::new(127, 0, 0, 1);
        let header = ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Stream,
            source: SocketAddr::V4(SocketAddrV4::new(localhost, 1234)),
            destination: SocketAddr::V4(SocketAddrV4::new(localhost, 5678)),
        });

        assert_roundtrip(
            &header,
            Some(
                &b"\r\n\r\n\x00\r\nQUIT\n!\x11\x00\x0c\x7f\x00\x00\x01\x7f\x00\x00\x01\x04\xd2\x16."
                    [..],
            ),
        );
    }

    #[test]
    fn test_encode_ipv6() {
        let header = ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Datagram,
            source: SocketAddr::V6(SocketAddrV6::new(
                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
                1234,
                0,
                0,
            )),
            destination: SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 5678, 0, 0)),
        });

        let expected: &[u8] = &[
            13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 34, 0, 36, 32, 1, 13, 184, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4,
            210, 22, 46
        ];
        assert_roundtrip(&header, Some(expected));
    }

    #[test]
    fn test_tlvs() {
        let mut header = ProxyHeader::with_local();
        header.append_tlv(crate::Tlv::UniqueId(b"unique"[..].into()));
        header.append_tlv(crate::Tlv::Crc32c(1234));

        let mut wire = [0u8; 1024];
        let written = header.encode_to_slice_v2(&mut wire).unwrap();

        let config = ParseConfig {
            include_tlvs: true,
            ..Default::default()
        };
        let (decoded, consumed) = decode(&wire, config).unwrap();

        assert_eq!(decoded, header);
        assert_eq!(consumed, written);
        assert_eq!(decoded.unique_id(), Some(&b"unique"[..]));
        assert_eq!(decoded.crc32c(), Some(1234));
    }

    #[test]
    fn test_family_mismatch() {
        let header = ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Stream,
            source: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234)),
            destination: SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 5678, 0, 0)),
        });

        let mut wire = [0u8; 1024];
        let outcome = header.encode_to_slice_v2(&mut wire);
        assert_eq!(outcome, Err(AddressFamilyMismatch));
    }
}