1use std::{
47 io::{self, BufRead, Read, Write},
48 mem::MaybeUninit,
49};
50
51#[cfg(any(unix, target_os = "wasi"))]
52use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
53
54#[cfg(feature = "tokio")]
55use std::{
56 pin::Pin,
57 task::{Context, Poll},
58};
59
60#[cfg(feature = "tokio")]
61use pin_project_lite::pin_project;
62
63#[cfg(feature = "tokio")]
64use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
65
66use crate::{Error, ParseConfig, ProxyHeader};
67
// With the `tokio` feature (and outside rustdoc) the struct is declared through
// `pin_project!` so the async trait impls can call `project()` and obtain a pinned
// `io` alongside unpinned `remaining`/`header`. Rustdoc builds use the plain
// definition guarded by `any(doc, not(feature = "tokio"))` instead.
#[cfg(all(feature = "tokio", not(doc)))]
pin_project! {
    #[derive(Debug)]
    pub struct ProxiedStream<IO> {
        // The wrapped transport; structurally pinned so `project()` yields `Pin<&mut IO>`.
        #[pin]
        io: IO,
        // Bytes already read from `io` beyond the end of the PROXY header; they are
        // returned to readers before any further data is pulled from `io`.
        remaining: Vec<u8>,
        // The parsed header, owned (`'static`) so it no longer borrows the read buffer.
        header: ProxyHeader<'static>,
    }
}
78
#[cfg(any(doc, not(feature = "tokio")))]
/// A stream that has had its PROXY protocol header read and parsed.
///
/// Any bytes that were read from the underlying I/O object past the end of the header
/// are kept internally and returned by the read impls before more data is read from
/// the inner object. Writes go straight to the inner object.
#[derive(Debug)]
pub struct ProxiedStream<IO> {
    // The wrapped transport.
    io: IO,
    // Bytes already read from `io` beyond the end of the PROXY header.
    remaining: Vec<u8>,
    // The parsed header, owned so it does not borrow the read buffer.
    header: ProxyHeader<'static>,
}
89
90impl<IO> ProxiedStream<IO> {
91 pub fn unproxied(io: IO) -> Self {
96 Self {
97 io,
98 remaining: vec![],
99 header: Default::default(),
100 }
101 }
102
103 pub fn proxy_header(&self) -> &ProxyHeader {
105 &self.header
106 }
107
108 pub fn get_ref(&self) -> &IO {
110 &self.io
111 }
112
113 pub fn get_mut(&mut self) -> &mut IO {
115 &mut self.io
116 }
117
118 #[cfg(feature = "tokio")]
120 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut IO> {
121 self.project().io
122 }
123
124 pub fn into_inner(self) -> IO {
126 self.io
127 }
128}
129
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> ProxiedStream<IO>
where
    IO: AsyncRead + Unpin,
{
    /// Reads from `io` until a complete PROXY header has been parsed, then wraps it.
    ///
    /// Bytes read past the end of the header are kept and replayed by the read impls.
    ///
    /// # Errors
    ///
    /// - `UnexpectedEof` if `io` reaches end of stream before a full header arrives.
    /// - `InvalidData` if the bytes cannot be parsed as a PROXY header.
    /// - Any I/O error returned by `io`.
    pub async fn create_from_tokio(mut io: IO, config: ParseConfig) -> io::Result<Self> {
        use tokio::io::AsyncReadExt;

        let mut buffered = Vec::with_capacity(256);

        loop {
            let chunk_len = io.read_buf(&mut buffered).await?;
            if chunk_len == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "end of stream",
                ));
            }

            let (parsed, header_len) = match ProxyHeader::parse(&buffered, config) {
                Ok(found) => found,
                // Not enough data yet: read another chunk and try again.
                Err(Error::BufferTooShort) => continue,
                Err(_) => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "invalid proxy header",
                    ));
                }
            };

            // Detach the header from `buffered` before dropping the header bytes.
            let header = parsed.into_owned();
            buffered.drain(..header_len);

            return Ok(Self {
                io,
                remaining: buffered,
                header,
            });
        }
    }
}
187
188impl<IO> ProxiedStream<IO>
189where
190 IO: Read,
191{
192 pub fn create_from_std(mut io: IO, config: ParseConfig) -> io::Result<Self> {
196 let mut bytes = Vec::with_capacity(256);
197
198 loop {
199 if bytes.capacity() == bytes.len() {
200 bytes.reserve(32);
201 }
202
203 let buf = bytes.spare_capacity_mut();
207 buf.fill(MaybeUninit::new(0));
208
209 let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
211
212 let bytes_read = io.read(buf)?;
213 if bytes_read == 0 {
214 return Err(io::Error::new(
215 io::ErrorKind::UnexpectedEof,
216 "end of stream",
217 ));
218 }
219
220 unsafe {
223 assert!(bytes_read <= buf.len());
224 bytes.set_len(bytes.len() + bytes_read);
225 }
226
227 match ProxyHeader::parse(&bytes, config) {
228 Ok((ret, consumed)) => {
229 let ret = ret.into_owned();
230 bytes.drain(..consumed);
231
232 return Ok(Self {
233 io,
234 remaining: bytes,
235 header: ret,
236 });
237 }
238 Err(Error::BufferTooShort) => continue,
239 Err(_) => {
240 return Err(io::Error::new(
241 io::ErrorKind::InvalidData,
242 "invalid proxy header",
243 ))
244 }
245 }
246 }
247 }
248}
249
250impl<IO> Read for ProxiedStream<IO>
251where
252 IO: Read,
253{
254 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
255 if !self.remaining.is_empty() {
256 let len = std::cmp::min(self.remaining.len(), buf.len());
257
258 buf[..len].copy_from_slice(&self.remaining[..len]);
259 self.remaining.drain(..len);
260
261 return Ok(len);
262 }
263
264 self.io.read(buf)
265 }
266}
267
268impl<IO> BufRead for ProxiedStream<IO>
269where
270 IO: BufRead,
271{
272 fn fill_buf(&mut self) -> io::Result<&[u8]> {
273 if !self.remaining.is_empty() {
274 return Ok(&self.remaining);
275 }
276 self.io.fill_buf()
277 }
278
279 fn consume(&mut self, mut amt: usize) {
280 if !self.remaining.is_empty() {
281 let len = std::cmp::min(self.remaining.len(), amt);
282 self.remaining.drain(..len);
283 amt -= len;
284 }
285 self.io.consume(amt);
286 }
287}
288
289impl<IO> Write for ProxiedStream<IO>
290where
291 IO: Write,
292{
293 #[inline]
294 fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
295 self.io.write_vectored(bufs)
296 }
297
298 #[inline]
299 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
300 self.io.write_all(buf)
301 }
302
303 #[inline]
304 fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> io::Result<()> {
305 self.io.write_fmt(fmt)
306 }
307
308 #[inline]
309 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
310 self.io.write(buf)
311 }
312
313 #[inline]
314 fn flush(&mut self) -> io::Result<()> {
315 self.io.flush()
316 }
317}
318
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> AsyncBufRead for ProxiedStream<IO>
where
    IO: AsyncBufRead,
{
    /// Exposes the bytes buffered past the PROXY header if there are any, otherwise
    /// the inner stream's own buffer.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let me = self.project();

        if !me.remaining.is_empty() {
            return Poll::Ready(Ok(&me.remaining[..]));
        }

        me.io.poll_fill_buf(cx)
    }

    /// Marks `amt` bytes of the last `poll_fill_buf` result as consumed.
    ///
    /// Bytes come out of the header-overflow buffer first. Only the part of `amt`
    /// that exceeds it is forwarded to the inner stream.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let me = self.project();

        // Forwarding the full `amt` would also discard bytes from the inner stream's
        // buffer that were never handed to the caller.
        let from_remaining = std::cmp::min(me.remaining.len(), amt);
        me.remaining.drain(..from_remaining);

        me.io.consume(amt - from_remaining);
    }
}
346
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> AsyncRead for ProxiedStream<IO>
where
    IO: AsyncRead,
{
    /// Fills `buf` from the bytes buffered past the PROXY header while any remain;
    /// after that, polls the inner stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();

        if this.remaining.is_empty() {
            return this.io.poll_read(cx, buf);
        }

        let take = this.remaining.len().min(buf.remaining());
        buf.put_slice(&this.remaining[..take]);
        this.remaining.drain(..take);

        Poll::Ready(Ok(()))
    }
}
372
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<IO> AsyncWrite for ProxiedStream<IO>
where
    IO: AsyncWrite,
{
    // Writes are not affected by the PROXY header; each method projects to the pinned
    // inner stream and delegates.

    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        this.io.poll_write(cx, buf)
    }

    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.project();
        this.io.poll_write_vectored(cx, bufs)
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        AsyncWrite::is_write_vectored(&self.io)
    }

    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.project();
        this.io.poll_flush(cx)
    }

    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.project();
        this.io.poll_shutdown(cx)
    }
}
412
413#[cfg(any(unix, target_os = "wasi"))]
414#[cfg_attr(docsrs, doc(cfg(any(unix, target_os = "wasi"))))]
415impl<IO> AsRawFd for ProxiedStream<IO>
416where
417 IO: AsRawFd,
418{
419 fn as_raw_fd(&self) -> RawFd {
420 self.io.as_raw_fd()
421 }
422}
423
424#[cfg(any(unix, target_os = "wasi"))]
425#[cfg_attr(docsrs, doc(cfg(any(unix, target_os = "wasi"))))]
426impl<IO> AsFd for ProxiedStream<IO>
427where
428 IO: AsFd,
429{
430 fn as_fd(&self) -> BorrowedFd<'_> {
431 self.io.as_fd()
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Protocol, ProxiedAddress, ProxyHeader};
    use std::{
        io::Cursor,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };

    /// Total size of the in-memory stream: an encoded v2 header followed by filler.
    const STREAM_LEN: usize = 1024;

    /// Byte written after the header; everything read past the header must equal it.
    const FILLER: u8 = 255;

    /// TCP over IPv4, 127.0.0.1:1234 -> 8.8.4.4:5678.
    fn sample_address() -> ProxiedAddress {
        let source = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234);
        let destination = SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 5678);

        ProxiedAddress {
            protocol: Protocol::Stream,
            source: SocketAddr::V4(source),
            destination: SocketAddr::V4(destination),
        }
    }

    /// Encodes `header` as v2 at the front of a `STREAM_LEN` buffer, pads the rest with
    /// `FILLER`, and returns the buffer together with the encoded header length.
    fn encoded_stream(header: &ProxyHeader<'_>) -> ([u8; STREAM_LEN], usize) {
        let mut stream = [0; STREAM_LEN];
        let header_len = header.encode_to_slice_v2(&mut stream).unwrap();
        stream[header_len..].fill(FILLER);
        (stream, header_len)
    }

    /// Checks that `payload` is exactly the filler that followed the header.
    fn assert_filler(payload: &[u8], header_len: usize) {
        assert_eq!(payload.len(), STREAM_LEN - header_len);
        assert!(payload.iter().all(|&b| b == FILLER));
    }

    #[test]
    fn test_sync() {
        let header = ProxyHeader::with_address(sample_address());
        let (bytes, header_len) = encoded_stream(&header);
        let mut cursor = Cursor::new(&bytes);

        let mut proxied = ProxiedStream::create_from_std(&mut cursor, Default::default()).unwrap();
        assert_eq!(proxied.proxy_header(), &header);

        let mut payload = Vec::new();
        proxied.read_to_end(&mut payload).unwrap();
        assert_filler(&payload, header_len);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_tokio() {
        use tokio::io::AsyncReadExt;

        let header = ProxyHeader::with_address(sample_address());
        let (bytes, header_len) = encoded_stream(&header);
        let mut cursor = Cursor::new(&bytes);

        let mut proxied = ProxiedStream::create_from_tokio(&mut cursor, Default::default())
            .await
            .unwrap();
        assert_eq!(proxied.proxy_header(), &header);

        let mut payload = Vec::new();
        AsyncReadExt::read_to_end(&mut proxied, &mut payload)
            .await
            .unwrap();
        assert_filler(&payload, header_len);
    }
}