1use std::io::Write;
2use std::net::SocketAddr;
3use std::str::from_utf8;
4use std::{
5 net::{Ipv4Addr, Ipv6Addr},
6 str::FromStr,
7};
8
9use crate::util::{read_until, AddressFamily};
10use crate::{
11 Error::{self, *},
12 Protocol, ProxiedAddress, ProxyHeader,
13};
14
/// Buffer length at which `decode` stops reporting `BufferTooShort` and
/// reports `Invalid` instead.
// NOTE(review): 107 presumably is the maximum length of a v1 header line
// including the trailing CRLF — confirm against the protocol specification.
const MAX_LENGTH: usize = 107;
/// Literal bytes every header handled by this module starts with: `decode_inner`
/// requires them as a prefix and `encode` writes them first.
const GREETING: &[u8] = b"PROXY";
17
18fn parse_addr<T: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<T, Error> {
19 let Some(address) = read_until(&buf[*pos..], b' ') else {
20 return Err(BufferTooShort);
21 };
22
23 let addr = from_utf8(address)
24 .map_err(|_| Invalid)
25 .and_then(|s| T::from_str(s).map_err(|_| Invalid))?;
26 *pos += address.len() + 1;
27
28 Ok(addr)
29}
30
31fn parse_port(buf: &[u8], pos: &mut usize, terminator: u8) -> Result<u16, Error> {
32 let Some(port) = read_until(&buf[*pos..], terminator) else {
33 return Err(BufferTooShort);
34 };
35
36 let p = from_utf8(port)
37 .map_err(|_| Invalid)
38 .and_then(|s| u16::from_str(s).map_err(|_| Invalid))?;
39 *pos += port.len() + 1;
40
41 Ok(p)
42}
43
44fn parse_addrs<T: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<ProxiedAddress, Error> {
45 let src_addr: T = parse_addr(buf, pos)?;
46 let dst_addr: T = parse_addr(buf, pos)?;
47 let src_port = parse_port(buf, pos, b' ')?;
48 let dst_port = parse_port(buf, pos, b'\r')?;
49
50 Ok(ProxiedAddress {
51 protocol: Protocol::Stream, source: SocketAddr::new(src_addr.to_ip_addr(), src_port),
53 destination: SocketAddr::new(dst_addr.to_ip_addr(), dst_port),
54 })
55}
56
57fn decode_inner(buf: &[u8]) -> Result<(ProxyHeader, usize), Error> {
58 let mut pos = 0;
59
60 if buf.len() < b"PROXY UNKNOWN\r\n".len() {
61 return Err(BufferTooShort);
63 }
64 if !buf.starts_with(GREETING) {
65 return Err(Invalid);
66 }
67 pos += GREETING.len() + 1;
68
69 let addrs = if buf[pos..].starts_with(b"UNKNOWN") {
70 let Some(rest) = read_until(&buf[pos..], b'\r') else {
71 return Err(BufferTooShort);
72 };
73 pos += rest.len() + 1;
74
75 None
76 } else {
77 let proto = &buf[pos..pos + 5];
78 pos += 5;
79
80 match proto {
81 b"TCP4 " => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos)?),
82 b"TCP6 " => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos)?),
83 _ => return Err(Invalid),
84 }
85 };
86
87 match buf.get(pos) {
88 Some(b'\n') => pos += 1,
89 None => return Err(BufferTooShort),
90 _ => return Err(Invalid),
91 }
92
93 Ok((ProxyHeader(addrs, Default::default()), pos))
94}
95
96pub fn decode(buf: &[u8]) -> Result<(ProxyHeader, usize), Error> {
100 match decode_inner(buf) {
104 Err(Error::BufferTooShort) if buf.len() >= MAX_LENGTH => Err(Error::Invalid),
105 other => other,
106 }
107}
108
109pub fn encode<W: Write>(header: &ProxyHeader, writer: &mut W) -> Result<(), Error> {
110 if !header.1.is_empty() {
111 return Err(V1UnsupportedTlv);
112 }
113 writer.write_all(GREETING).map_err(|_| BufferTooShort)?;
114 writer.write_all(b" ").map_err(|_| BufferTooShort)?;
115
116 match header.0 {
117 Some(ProxiedAddress {
118 protocol: Protocol::Stream,
119 source,
120 destination,
121 }) => match (source, destination) {
122 (SocketAddr::V4(src), SocketAddr::V4(dst)) => {
123 write!(
124 writer,
125 "TCP4 {} {} {} {}\r\n",
126 src.ip(),
127 dst.ip(),
128 src.port(),
129 dst.port()
130 )
131 .map_err(|_| BufferTooShort)?;
132 }
133 (SocketAddr::V6(src), SocketAddr::V6(dst)) => {
134 write!(
135 writer,
136 "TCP6 {} {} {} {}\r\n",
137 src.ip(),
138 dst.ip(),
139 src.port(),
140 dst.port()
141 )
142 .map_err(|_| BufferTooShort)?;
143 }
144 _ => return Err(AddressFamilyMismatch),
145 },
146 None => {
147 writer
148 .write_all(b"UNKNOWN\r\n")
149 .map_err(|_| BufferTooShort)?;
150 }
151 _ => return Err(V1UnsupportedProtocol),
152 }
153
154 Ok(())
155}
156
#[cfg(test)]
mod tests {
    use std::net::{SocketAddrV4, SocketAddrV6};

    use super::*;

    /// Builds an IPv4 socket address from raw octets and a port.
    fn v4(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
    }

    /// Builds an IPv6 socket address (flow info and scope id both zero).
    fn v6(segments: [u16; 8], port: u16) -> SocketAddr {
        SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(segments), port, 0, 0))
    }

    /// Builds a header carrying a stream address.
    fn stream_header(source: SocketAddr, destination: SocketAddr) -> ProxyHeader {
        ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Stream,
            source,
            destination,
        })
    }

    /// Encodes `header`, checks the bytes on the wire, then decodes the whole
    /// buffer again and checks both the header and the consumed length.
    fn assert_round_trip(header: ProxyHeader, wire: &[u8]) {
        let mut buf = [0u8; 1024];

        let len = header.encode_to_slice_v1(&mut buf).unwrap();
        assert_eq!(&buf[..len], wire);

        let decoded = decode(&buf).unwrap();
        assert_eq!(decoded.0, header);
        assert_eq!(decoded.1, len);
    }

    #[test]
    fn test_encode_local() {
        assert_round_trip(ProxyHeader::with_local(), b"PROXY UNKNOWN\r\n");
    }

    #[test]
    fn test_encode_ipv4() {
        let header = stream_header(v4([127, 0, 0, 1], 1234), v4([8, 8, 4, 4], 5678));

        assert_round_trip(header, b"PROXY TCP4 127.0.0.1 8.8.4.4 1234 5678\r\n");
    }

    #[test]
    fn test_encode_ipv6() {
        let header = stream_header(
            v6([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1], 1234),
            v6([0, 0, 0, 0, 0, 0, 0, 1], 5678),
        );

        assert_round_trip(header, b"PROXY TCP6 2001:db8::1 ::1 1234 5678\r\n");
    }

    #[test]
    fn test_tlvs() {
        let mut buf = [0u8; 1024];
        let mut header = ProxyHeader::with_local();
        header.append_tlv(crate::Tlv::Noop(10));

        // The v1 format has no room for TLVs.
        assert_eq!(header.encode_to_slice_v1(&mut buf), Err(V1UnsupportedTlv));
    }

    #[test]
    fn test_family_mismatch() {
        let mut buf = [0u8; 1024];
        let header = stream_header(
            v4([127, 0, 0, 1], 1234),
            v6([0, 0, 0, 0, 0, 0, 0, 1], 5678),
        );

        assert_eq!(
            header.encode_to_slice_v1(&mut buf),
            Err(AddressFamilyMismatch)
        );
    }

    #[test]
    fn test_buffer_too_short() {
        let mut buf = [0u8; 1024];
        let header = stream_header(v4([127, 0, 0, 1], 1234), v4([8, 8, 4, 4], 5678));

        // Only ten bytes of room: the encoded line cannot fit.
        assert_eq!(
            header.encode_to_slice_v1(&mut buf[0..10]),
            Err(BufferTooShort)
        );
    }
}