proxy_header/
v1.rs

1use std::io::Write;
2use std::net::SocketAddr;
3use std::str::from_utf8;
4use std::{
5    net::{Ipv4Addr, Ipv6Addr},
6    str::FromStr,
7};
8
9use crate::util::{read_until, AddressFamily};
10use crate::{
11    Error::{self, *},
12    Protocol, ProxiedAddress, ProxyHeader,
13};
14
/// Literal that every version 1 header starts with. The space that separates it
/// from the protocol token is not part of this constant.
const GREETING: &[u8] = b"PROXY";
/// Buffer length from which `decode` reports a still-incomplete header as
/// `Invalid` instead of `BufferTooShort`.
const MAX_LENGTH: usize = 107;
17
18fn parse_addr<T: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<T, Error> {
19    let Some(address) = read_until(&buf[*pos..], b' ') else {
20        return Err(BufferTooShort);
21    };
22
23    let addr = from_utf8(address)
24        .map_err(|_| Invalid)
25        .and_then(|s| T::from_str(s).map_err(|_| Invalid))?;
26    *pos += address.len() + 1;
27
28    Ok(addr)
29}
30
31fn parse_port(buf: &[u8], pos: &mut usize, terminator: u8) -> Result<u16, Error> {
32    let Some(port) = read_until(&buf[*pos..], terminator) else {
33        return Err(BufferTooShort);
34    };
35
36    let p = from_utf8(port)
37        .map_err(|_| Invalid)
38        .and_then(|s| u16::from_str(s).map_err(|_| Invalid))?;
39    *pos += port.len() + 1;
40
41    Ok(p)
42}
43
44fn parse_addrs<T: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<ProxiedAddress, Error> {
45    let src_addr: T = parse_addr(buf, pos)?;
46    let dst_addr: T = parse_addr(buf, pos)?;
47    let src_port = parse_port(buf, pos, b' ')?;
48    let dst_port = parse_port(buf, pos, b'\r')?;
49
50    Ok(ProxiedAddress {
51        protocol: Protocol::Stream, // v1 header only supports TCP
52        source: SocketAddr::new(src_addr.to_ip_addr(), src_port),
53        destination: SocketAddr::new(dst_addr.to_ip_addr(), dst_port),
54    })
55}
56
57fn decode_inner(buf: &[u8]) -> Result<(ProxyHeader, usize), Error> {
58    let mut pos = 0;
59
60    if buf.len() < b"PROXY UNKNOWN\r\n".len() {
61        // All other valid PROXY headers are longer than this.
62        return Err(BufferTooShort);
63    }
64    if !buf.starts_with(GREETING) {
65        return Err(Invalid);
66    }
67    pos += GREETING.len() + 1;
68
69    let addrs = if buf[pos..].starts_with(b"UNKNOWN") {
70        let Some(rest) = read_until(&buf[pos..], b'\r') else {
71            return Err(BufferTooShort);
72        };
73        pos += rest.len() + 1;
74
75        None
76    } else {
77        let proto = &buf[pos..pos + 5];
78        pos += 5;
79
80        match proto {
81            b"TCP4 " => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos)?),
82            b"TCP6 " => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos)?),
83            _ => return Err(Invalid),
84        }
85    };
86
87    match buf.get(pos) {
88        Some(b'\n') => pos += 1,
89        None => return Err(BufferTooShort),
90        _ => return Err(Invalid),
91    }
92
93    Ok((ProxyHeader(addrs, Default::default()), pos))
94}
95
96/// Decode a version 1 PROXY header from a buffer.
97///
98/// Returns the decoded header and the number of bytes consumed from the buffer.
99pub fn decode(buf: &[u8]) -> Result<(ProxyHeader, usize), Error> {
100    // Guard against a malicious client sending a very long header, since it is a
101    // delimited protocol.
102
103    match decode_inner(buf) {
104        Err(Error::BufferTooShort) if buf.len() >= MAX_LENGTH => Err(Error::Invalid),
105        other => other,
106    }
107}
108
109pub fn encode<W: Write>(header: &ProxyHeader, writer: &mut W) -> Result<(), Error> {
110    if !header.1.is_empty() {
111        return Err(V1UnsupportedTlv);
112    }
113    writer.write_all(GREETING).map_err(|_| BufferTooShort)?;
114    writer.write_all(b" ").map_err(|_| BufferTooShort)?;
115
116    match header.0 {
117        Some(ProxiedAddress {
118            protocol: Protocol::Stream,
119            source,
120            destination,
121        }) => match (source, destination) {
122            (SocketAddr::V4(src), SocketAddr::V4(dst)) => {
123                write!(
124                    writer,
125                    "TCP4 {} {} {} {}\r\n",
126                    src.ip(),
127                    dst.ip(),
128                    src.port(),
129                    dst.port()
130                )
131                .map_err(|_| BufferTooShort)?;
132            }
133            (SocketAddr::V6(src), SocketAddr::V6(dst)) => {
134                write!(
135                    writer,
136                    "TCP6 {} {} {} {}\r\n",
137                    src.ip(),
138                    dst.ip(),
139                    src.port(),
140                    dst.port()
141                )
142                .map_err(|_| BufferTooShort)?;
143            }
144            _ => return Err(AddressFamilyMismatch),
145        },
146        None => {
147            writer
148                .write_all(b"UNKNOWN\r\n")
149                .map_err(|_| BufferTooShort)?;
150        }
151        _ => return Err(V1UnsupportedProtocol),
152    }
153
154    Ok(())
155}
156
#[cfg(test)]
mod tests {
    use std::net::{SocketAddrV4, SocketAddrV6};

    use super::*;

    /// Builds an IPv4 socket address from its octets and a port.
    fn v4(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
    }

    /// Builds an IPv6 socket address (flow info and scope id zero) from its
    /// segments and a port.
    fn v6(segments: [u16; 8], port: u16) -> SocketAddr {
        SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(segments), port, 0, 0))
    }

    /// Builds a stream-protocol address pair.
    fn tcp(source: SocketAddr, destination: SocketAddr) -> ProxiedAddress {
        ProxiedAddress {
            protocol: Protocol::Stream,
            source,
            destination,
        }
    }

    /// Encodes `header`, checks the bytes on the wire, then decodes the whole
    /// buffer again and checks both the header and the consumed length.
    fn assert_round_trip(header: &ProxyHeader, wire: &[u8]) {
        let mut buf = [0u8; 1024];

        let len = header.encode_to_slice_v1(&mut buf).unwrap();
        assert_eq!(&buf[..len], wire);

        let (decoded, consumed) = decode(&buf).unwrap();
        assert_eq!(&decoded, header);
        assert_eq!(consumed, len);
    }

    #[test]
    fn test_encode_local() {
        let header = ProxyHeader::with_local();

        assert_round_trip(&header, b"PROXY UNKNOWN\r\n");
    }

    #[test]
    fn test_encode_ipv4() {
        let header = ProxyHeader::with_address(tcp(
            v4([127, 0, 0, 1], 1234),
            v4([8, 8, 4, 4], 5678),
        ));

        assert_round_trip(&header, b"PROXY TCP4 127.0.0.1 8.8.4.4 1234 5678\r\n");
    }

    #[test]
    fn test_encode_ipv6() {
        let header = ProxyHeader::with_address(tcp(
            v6([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1], 1234),
            v6([0, 0, 0, 0, 0, 0, 0, 1], 5678),
        ));

        assert_round_trip(&header, b"PROXY TCP6 2001:db8::1 ::1 1234 5678\r\n");
    }

    #[test]
    fn test_tlvs() {
        let mut buf = [0u8; 1024];
        let mut header = ProxyHeader::with_local();
        header.append_tlv(crate::Tlv::Noop(10));

        let result = header.encode_to_slice_v1(&mut buf);
        assert_eq!(result, Err(V1UnsupportedTlv));
    }

    #[test]
    fn test_family_mismatch() {
        let mut buf = [0u8; 1024];
        let header = ProxyHeader::with_address(tcp(
            v4([127, 0, 0, 1], 1234),
            v6([0, 0, 0, 0, 0, 0, 0, 1], 5678),
        ));

        let result = header.encode_to_slice_v1(&mut buf);
        assert_eq!(result, Err(AddressFamilyMismatch));
    }

    #[test]
    fn test_buffer_too_short() {
        let mut buf = [0u8; 1024];
        let header = ProxyHeader::with_address(tcp(
            v4([127, 0, 0, 1], 1234),
            v4([8, 8, 4, 4], 5678),
        ));

        // Ten bytes cannot hold even the shortest TCP4 line.
        let result = header.encode_to_slice_v1(&mut buf[..10]);
        assert_eq!(result, Err(BufferTooShort));
    }
}
263}