1use std::borrow::Cow;
2use std::io::Write;
3use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
4
5use crate::util::AddressFamily;
6use crate::ParseConfig;
7use crate::{
8 Error::{self, *},
9 Protocol, ProxiedAddress, ProxyHeader,
10};
11
/// The 12-byte signature that opens every PROXY protocol v2 header
/// (`"\r\n\r\n\0\r\nQUIT\n"`).
const GREETING: &[u8] = &[
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
/// Size of an AF_UNIX address block: a 108-byte source path followed by a
/// 108-byte destination path.
const AF_UNIX_ADDRS_LEN: usize = 2 * 108;
14
15fn parse_addrs<T: AddressFamily>(
16 buf: &[u8],
17 pos: &mut usize,
18 rest: &mut usize,
19 protocol: Protocol,
20) -> Result<ProxiedAddress, Error> {
21 if buf.len() < *pos + T::BYTES * 2 + 4 {
22 return Err(BufferTooShort);
23 }
24 if *rest < T::BYTES * 2 + 4 {
25 return Err(Invalid);
26 }
27
28 let ret = ProxiedAddress {
29 protocol,
30 source: SocketAddr::new(
31 T::from_slice(&buf[*pos..*pos + T::BYTES]).to_ip_addr(),
32 u16::from_be_bytes([buf[*pos + T::BYTES * 2], buf[*pos + T::BYTES * 2 + 1]]),
33 ),
34 destination: SocketAddr::new(
35 T::from_slice(&buf[*pos + T::BYTES..*pos + T::BYTES * 2]).to_ip_addr(),
36 u16::from_be_bytes([buf[*pos + T::BYTES * 2 + 2], buf[*pos + T::BYTES * 2 + 3]]),
37 ),
38 };
39
40 *rest -= T::BYTES * 2 + 4;
41 *pos += T::BYTES * 2 + 4;
42
43 Ok(ret)
44}
45
46pub fn decode(buf: &[u8], config: ParseConfig) -> Result<(ProxyHeader, usize), Error> {
50 let mut pos = 0;
51
52 if buf.len() < 4 + GREETING.len() {
53 return Err(BufferTooShort);
54 }
55 if !buf.starts_with(GREETING) {
56 return Err(Invalid);
57 }
58 pos += GREETING.len();
59
60 let is_local = match buf[pos] {
61 0x20 => true,
62 0x21 => false,
63 _ => return Err(Invalid),
64 };
65 let protocol = buf[pos + 1];
66 let mut rest = u16::from_be_bytes([buf[pos + 2], buf[pos + 3]]) as usize;
67 pos += 4;
68
69 if buf.len() < pos + rest {
70 return Err(BufferTooShort);
71 }
72
73 use Protocol::{Datagram, Stream};
74 let addr_info = match protocol {
75 0x00 => None,
76 0x11 => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos, &mut rest, Stream)?),
77 0x12 => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos, &mut rest, Datagram)?),
78 0x21 => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos, &mut rest, Stream)?),
79 0x22 => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos, &mut rest, Datagram)?),
80 0x31 | 0x32 => {
81 if rest < AF_UNIX_ADDRS_LEN {
85 return Err(Invalid);
86 }
87 rest -= AF_UNIX_ADDRS_LEN;
88 pos += AF_UNIX_ADDRS_LEN;
89
90 None
91 }
92 _ => return Err(Invalid),
93 };
94
95 let tlv_data = if config.include_tlvs {
96 Cow::Borrowed(&buf[pos..pos + rest])
97 } else {
98 Default::default()
99 };
100
101 pos += rest;
102
103 let header = if is_local {
104 ProxyHeader(None, tlv_data)
105 } else {
106 ProxyHeader(addr_info, tlv_data)
107 };
108
109 Ok((header, pos))
110}
111
112pub fn encode<W: Write>(header: &ProxyHeader, buf: &mut W) -> Result<(), Error> {
113 buf.write_all(GREETING).map_err(|_| BufferTooShort)?;
114
115 match &header.0 {
116 Some(ProxiedAddress {
117 protocol,
118 source: SocketAddr::V4(src),
119 destination: SocketAddr::V4(dest),
120 }) => {
121 buf.write_all(b"\x21").map_err(|_| BufferTooShort)?;
122 match protocol {
123 Protocol::Stream => buf.write_all(b"\x11").map_err(|_| BufferTooShort)?,
124 Protocol::Datagram => buf.write_all(b"\x12").map_err(|_| BufferTooShort)?,
125 }
126
127 let len: u16 = (4 + 4 + 2 + 2 + header.1.len())
128 .try_into()
129 .map_err(|_| HeaderTooBig)?;
130 buf.write_all(&len.to_be_bytes())
131 .map_err(|_| BufferTooShort)?;
132
133 buf.write_all(&src.ip().octets())
134 .map_err(|_| BufferTooShort)?;
135 buf.write_all(&dest.ip().octets())
136 .map_err(|_| BufferTooShort)?;
137 buf.write_all(&src.port().to_be_bytes())
138 .map_err(|_| BufferTooShort)?;
139 buf.write_all(&dest.port().to_be_bytes())
140 .map_err(|_| BufferTooShort)?;
141 }
142 Some(ProxiedAddress {
143 protocol,
144 source: SocketAddr::V6(src),
145 destination: SocketAddr::V6(dest),
146 }) => {
147 buf.write_all(b"\x21").map_err(|_| BufferTooShort)?;
148 match protocol {
149 Protocol::Stream => buf.write_all(b"\x21").map_err(|_| BufferTooShort)?,
150 Protocol::Datagram => buf.write_all(b"\x22").map_err(|_| BufferTooShort)?,
151 }
152
153 let len: u16 = (16 + 16 + 2 + 2 + header.1.len())
154 .try_into()
155 .map_err(|_| HeaderTooBig)?;
156 buf.write_all(&len.to_be_bytes())
157 .map_err(|_| BufferTooShort)?;
158
159 buf.write_all(&src.ip().octets())
160 .map_err(|_| BufferTooShort)?;
161 buf.write_all(&dest.ip().octets())
162 .map_err(|_| BufferTooShort)?;
163 buf.write_all(&src.port().to_be_bytes())
164 .map_err(|_| BufferTooShort)?;
165 buf.write_all(&dest.port().to_be_bytes())
166 .map_err(|_| BufferTooShort)?;
167 }
168 None => {
169 buf.write_all(b"\x20\x00").map_err(|_| BufferTooShort)?;
170
171 let len: u16 = header.1.len().try_into().map_err(|_| HeaderTooBig)?;
172 buf.write_all(&len.to_be_bytes())
173 .map_err(|_| BufferTooShort)?;
174 }
175 _ => return Err(AddressFamilyMismatch),
176 }
177
178 buf.write_all(&header.1).map_err(|_| BufferTooShort)?;
179 Ok(())
180}
181
// Round-trip tests. Most cases encode a header through `encode_to_slice_v2`
// (defined elsewhere, presumably a slice-backed wrapper around `encode`),
// compare the wire bytes against a hand-written fixture, then decode the same
// buffer and check both the header and the consumed length.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddrV4, SocketAddrV6};

    // LOCAL header with no TLVs: signature, 0x20 command, 0x00 family, zero length.
    #[test]
    fn test_encode_local() {
        let mut buf = [0u8; 1024];
        let header = ProxyHeader::with_local();

        let len = header.encode_to_slice_v2(&mut buf).unwrap();
        assert_eq!(&buf[..len], b"\r\n\r\n\x00\r\nQUIT\n\x20\x00\x00\x00");

        // The buffer is longer than the header; decode must report exactly `len`.
        let decoded = decode(&buf, ParseConfig::default()).unwrap();
        assert_eq!(decoded.0, header);
        assert_eq!(decoded.1, len);
    }

    // IPv4 stream header. In the fixture, `!` is the 0x21 command byte, `\x11`
    // the IPv4/stream byte, and `\x00\x0c` the 12-byte length. The ports 1234
    // (0x04d2) and 5678 (0x162e, printed as `\x16.`) come last.
    #[test]
    fn test_encode_ipv4() {
        let mut buf = [0u8; 102400];
        let header = ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Stream,
            source: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234)),
            destination: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 5678)),
        });

        let len = header.encode_to_slice_v2(&mut buf).unwrap();
        assert_eq!(
            &buf[..len],
            b"\r\n\r\n\x00\r\nQUIT\n!\x11\x00\x0c\x7f\x00\x00\x01\x7f\x00\x00\x01\x04\xd2\x16."
        );

        let decoded = decode(&buf, ParseConfig::default()).unwrap();
        assert_eq!(decoded.0, header);
        assert_eq!(decoded.1, len);
    }

    // IPv6 datagram header. The fixture is in decimal: 33 (0x21) is the
    // command, 34 (0x22) the IPv6/datagram byte, and 0, 36 the 36-byte length.
    // The two 16-byte addresses and the two ports follow.
    #[test]
    fn test_encode_ipv6() {
        let mut buf = [0u8; 102400];
        let header = ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Datagram,
            source: SocketAddr::V6(SocketAddrV6::new(
                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
                1234,
                0,
                0,
            )),
            destination: SocketAddr::V6(SocketAddrV6::new(
                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
                5678,
                0,
                0,
            )),
        });

        let len = header.encode_to_slice_v2(&mut buf).unwrap();
        assert_eq!(
            &buf[..len],
            &[
                13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 34, 0, 36, 32, 1, 13, 184, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4,
                210, 22, 46
            ]
        );

        let decoded = decode(&buf, ParseConfig::default()).unwrap();
        assert_eq!(decoded.0, header);
        assert_eq!(decoded.1, len);
    }

    // TLVs survive a round trip only when decoding with `include_tlvs: true`;
    // the typed accessors are then checked against the appended values.
    #[test]
    fn test_tlvs() {
        let mut buf = [0u8; 102400];
        let mut header = ProxyHeader::with_local();
        header.append_tlv(crate::Tlv::UniqueId(b"unique"[..].into()));
        header.append_tlv(crate::Tlv::Crc32c(1234));

        let len = header.encode_to_slice_v2(&mut buf).unwrap();

        let decoded = decode(
            &buf,
            ParseConfig {
                include_tlvs: true,
                ..Default::default()
            },
        )
        .unwrap();

        assert_eq!(decoded.0, header);
        assert_eq!(decoded.1, len);

        assert_eq!(decoded.0.unique_id(), Some(&b"unique"[..]));
        assert_eq!(decoded.0.crc32c(), Some(1234));
    }

    // An IPv4 source paired with an IPv6 destination cannot be encoded.
    #[test]
    fn test_family_mismatch() {
        let mut buf = [0u8; 1024];
        let header = ProxyHeader::with_address(ProxiedAddress {
            protocol: Protocol::Stream,
            source: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234)),
            destination: SocketAddr::V6(SocketAddrV6::new(
                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
                5678,
                0,
                0,
            )),
        });

        assert_eq!(
            header.encode_to_slice_v2(&mut buf),
            Err(AddressFamilyMismatch)
        );
    }
}