1#![cfg_attr(sendfd_docs, feature(doc_cfg))]
2
3extern crate libc;
4#[cfg(feature = "tokio")]
5extern crate tokio;
6
7use std::os::unix::io::{AsRawFd, RawFd};
8use std::os::unix::net;
9use std::{alloc, io, mem, ptr};
10#[cfg(feature = "tokio")]
11use tokio::io::Interest;
12
13pub mod changelog;
14
/// Types that can send bytes together with file descriptors over a Unix-domain socket.
pub trait SendWithFd {
    /// Sends `bytes` together with the file descriptors in `fds`.
    ///
    /// Returns the number of bytes of `bytes` that were sent. The descriptors in `fds` remain
    /// open and owned by the caller; the receiving side gets its own new descriptors that refer
    /// to the same open files.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize>;
}
20
/// Types that can receive bytes together with file descriptors from a Unix-domain socket.
pub trait RecvWithFd {
    /// Receives data into `bytes` and file descriptors into `fds`.
    ///
    /// Returns `(bytes_received, descriptors_received)`. Only the first `descriptors_received`
    /// entries of `fds` are written; the remaining entries are left untouched. Every received
    /// descriptor is a new descriptor owned by the caller, who is responsible for closing it.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)>;
}
28
/// Returns the signed distance in bytes from `origin` to `this`, i.e. `this - origin`.
///
/// Only the pointers' addresses are compared and nothing is dereferenced, so this is a safe
/// function. Unlike `pointer::offset_from`, it does not require both pointers to lie in the same
/// allocation. The subtraction wraps on overflow.
fn ptr_offset_from(this: *const u8, origin: *const u8) -> isize {
    (this as isize).wrapping_sub(origin as isize)
}
33
34unsafe fn construct_msghdr_for(
45 iov: &mut libc::iovec,
46 fd_count: usize,
47) -> (libc::msghdr, alloc::Layout, usize) {
48 let fd_len = mem::size_of::<RawFd>() * fd_count;
49 let cmsg_buffer_len = libc::CMSG_SPACE(fd_len as u32) as usize;
50 let layout = alloc::Layout::from_size_align(cmsg_buffer_len, mem::align_of::<libc::cmsghdr>());
51 let (cmsg_buffer, cmsg_layout) = if let Ok(layout) = layout {
52 const NULL_MUT_U8: *mut u8 = ptr::null_mut();
53 match alloc::alloc(layout) {
54 NULL_MUT_U8 => alloc::handle_alloc_error(layout),
55 x => (x as *mut _, layout),
56 }
57 } else {
58 alloc::handle_alloc_error(alloc::Layout::from_size_align_unchecked(
62 cmsg_buffer_len,
63 mem::align_of::<libc::cmsghdr>(),
64 ))
65 };
66
67 let mut msghdr = mem::zeroed::<libc::msghdr>();
68 msghdr.msg_name = ptr::null_mut();
69 msghdr.msg_namelen = 0;
70 msghdr.msg_iov = iov as *mut _;
71 msghdr.msg_iovlen = 1;
72 msghdr.msg_control = cmsg_buffer;
73 msghdr.msg_controllen = cmsg_buffer_len as _;
74
75 (msghdr, cmsg_layout, fd_len)
76}
77
78fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {
81 unsafe {
82 let mut iov = libc::iovec {
83 iov_base: bs.as_ptr() as *const _ as *mut _,
86 iov_len: bs.len(),
87 };
88 let (mut msghdr, cmsg_layout, fd_len) = construct_msghdr_for(&mut iov, fds.len());
89 let cmsg_buffer = msghdr.msg_control;
90
91 let cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
93 let mut cmsghdr = mem::zeroed::<libc::cmsghdr>();
94 cmsghdr.cmsg_level = libc::SOL_SOCKET;
95 cmsghdr.cmsg_type = libc::SCM_RIGHTS;
96 cmsghdr.cmsg_len = libc::CMSG_LEN(fd_len as u32) as _;
97
98 ptr::write(cmsg_header, cmsghdr);
99
100 let cmsg_data = libc::CMSG_DATA(cmsg_header) as *mut RawFd;
101 for (i, fd) in fds.iter().enumerate() {
102 ptr::write_unaligned(cmsg_data.add(i), *fd);
103 }
104 let count = libc::sendmsg(socket, &msghdr as *const _, 0);
105 if count < 0 {
106 let error = io::Error::last_os_error();
107 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
108 Err(error)
109 } else {
110 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
111 Ok(count as usize)
112 }
113 }
114}
115
116fn recv_with_fd(socket: RawFd, bs: &mut [u8], mut fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
119 unsafe {
120 let mut iov = libc::iovec {
121 iov_base: bs.as_mut_ptr() as *mut _,
122 iov_len: bs.len(),
123 };
124 let (mut msghdr, cmsg_layout, _) = construct_msghdr_for(&mut iov, fds.len());
125 let cmsg_buffer = msghdr.msg_control;
126 let count = libc::recvmsg(socket, &mut msghdr as *mut _, 0);
127 if count < 0 {
128 let error = io::Error::last_os_error();
129 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
130 return Err(error);
131 }
132
133 let mut descriptor_count = 0;
136 let mut cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
137 while !cmsg_header.is_null() {
138 if (*cmsg_header).cmsg_level == libc::SOL_SOCKET
139 && (*cmsg_header).cmsg_type == libc::SCM_RIGHTS
140 {
141 let data_ptr = libc::CMSG_DATA(cmsg_header);
142 let data_offset = ptr_offset_from(data_ptr, cmsg_header as *const _);
143 debug_assert!(data_offset >= 0);
144 let data_byte_count = (*cmsg_header).cmsg_len as usize - data_offset as usize;
145 debug_assert!((*cmsg_header).cmsg_len as isize > data_offset);
146 debug_assert!(data_byte_count % mem::size_of::<RawFd>() == 0);
147 let rawfd_count = (data_byte_count / mem::size_of::<RawFd>()) as isize;
148 let fd_ptr = data_ptr as *const RawFd;
149 for i in 0..rawfd_count {
150 if let Some((dst, rest)) = { fds }.split_first_mut() {
151 *dst = ptr::read_unaligned(fd_ptr.offset(i));
152 descriptor_count += 1;
153 fds = rest;
154 } else {
155 unreachable!();
165 }
166 }
167 }
168 cmsg_header = libc::CMSG_NXTHDR(&mut msghdr as *mut _, cmsg_header);
169 }
170
171 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
172 Ok((count as usize, descriptor_count))
173 }
174}
175
176impl SendWithFd for net::UnixStream {
177 fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
182 send_with_fd(self.as_raw_fd(), bytes, fds)
183 }
184}
185
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl SendWithFd for tokio::net::UnixStream {
    /// Makes one non-blocking attempt to send `bytes` and `fds` through tokio's `try_io`.
    ///
    /// Per `try_io`'s contract, if tokio does not consider the socket writable the call fails
    /// with `io::ErrorKind::WouldBlock`, so await `writable()` before calling this.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        let attempt = || send_with_fd(self.as_raw_fd(), bytes, fds);
        self.try_io(Interest::WRITABLE, attempt)
    }
}
197
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl SendWithFd for tokio::net::unix::WriteHalf<'_> {
    /// Forwards to the `tokio::net::UnixStream` implementation on the stream this half
    /// borrows, so the same single-attempt, readiness-based semantics apply.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        let whole: &tokio::net::UnixStream = self.as_ref();
        <tokio::net::UnixStream as SendWithFd>::send_with_fd(whole, bytes, fds)
    }
}
210
211impl SendWithFd for net::UnixDatagram {
212 fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
218 send_with_fd(self.as_raw_fd(), bytes, fds)
219 }
220}
221
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl SendWithFd for tokio::net::UnixDatagram {
    /// Makes one non-blocking attempt to send a datagram with `fds` attached, via `try_io`.
    ///
    /// The socket must be connected, since no destination address is supplied. If tokio does
    /// not consider it writable, the call fails with `io::ErrorKind::WouldBlock`; await
    /// `writable()` first.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        let socket = self.as_raw_fd();
        self.try_io(Interest::WRITABLE, move || send_with_fd(socket, bytes, fds))
    }
}
234
235impl RecvWithFd for net::UnixStream {
236 fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
242 recv_with_fd(self.as_raw_fd(), bytes, fds)
243 }
244}
245
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl RecvWithFd for tokio::net::UnixStream {
    /// Makes one non-blocking attempt to receive data and descriptors through tokio's `try_io`.
    ///
    /// If tokio does not consider the socket readable, the call fails with
    /// `io::ErrorKind::WouldBlock`; await `readable()` before calling this.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        let attempt = || recv_with_fd(self.as_raw_fd(), bytes, fds);
        self.try_io(Interest::READABLE, attempt)
    }
}
258
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl RecvWithFd for tokio::net::unix::ReadHalf<'_> {
    /// Forwards to the `tokio::net::UnixStream` implementation on the stream this half
    /// borrows, so the same single-attempt, readiness-based semantics apply.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        let whole: &tokio::net::UnixStream = self.as_ref();
        <tokio::net::UnixStream as RecvWithFd>::recv_with_fd(whole, bytes, fds)
    }
}
272
273impl RecvWithFd for net::UnixDatagram {
274 fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
285 recv_with_fd(self.as_raw_fd(), bytes, fds)
286 }
287}
288
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl RecvWithFd for tokio::net::UnixDatagram {
    /// Makes one non-blocking attempt to receive a datagram and its descriptors via `try_io`.
    ///
    /// If tokio does not consider the socket readable, the call fails with
    /// `io::ErrorKind::WouldBlock`; await `readable()` before calling this.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        let socket = self.as_raw_fd();
        self.try_io(Interest::READABLE, move || recv_with_fd(socket, bytes, fds))
    }
}
306
#[cfg(test)]
mod tests {
    use super::{RecvWithFd, SendWithFd};
    use std::os::unix::io::{AsRawFd, FromRawFd};
    use std::os::unix::net;

    /// Sends both ends of a stream pair through the pair itself and checks that each received
    /// descriptor refers to the same socket as the one sent: a read timeout set through the sent
    /// descriptor must be visible through the received one.
    #[test]
    fn stream_works() {
        let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            let expected_value = Some(std::time::Duration::from_secs(42));
            unsafe {
                // Borrow the sent descriptor only: `forget` leaves `l`/`r` as its owners.
                let s = net::UnixStream::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                std::mem::forget(s);
                // The received descriptor is new and owned here; dropping the temporary closes it.
                assert_eq!(
                    net::UnixStream::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    /// A message sent with no descriptors arrives with a zero descriptor count.
    #[test]
    fn stream_works_without_fds() {
        let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
        let sent_bytes = b"hello world!";
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &[])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 4];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), 0)
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
    }

    /// Datagram counterpart of `stream_works`.
    #[test]
    fn datagram_works() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0, 0, 0, 0, 0, 0, 0];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            let expected_value = Some(std::time::Duration::from_secs(42));
            unsafe {
                let s = net::UnixDatagram::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                std::mem::forget(s);
                assert_eq!(
                    net::UnixDatagram::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    /// Descriptors sent by a forked child arrive in the parent and refer to the same sockets.
    #[test]
    fn datagram_works_across_processes() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];

        unsafe {
            match libc::fork() {
                -1 => panic!("fork failed!"),
                0 => {
                    // Child: report the outcome through the exit status and leave via `_exit`,
                    // so neither a panic nor `exit` runs test-harness code, atexit handlers or
                    // stdio flushes in this forked copy of the process.
                    let status = match l.send_with_fd(&sent_bytes[..], &sent_fds[..]) {
                        Ok(n) if n == sent_bytes.len() => 0,
                        _ => 1,
                    };
                    libc::_exit(status);
                }
                child => {
                    // Reap the child so it does not linger as a zombie, and confirm its send
                    // succeeded before blocking in `recv_with_fd`. Otherwise a failed child would
                    // hang the test forever, because the parent still holds `l` open.
                    let mut status: libc::c_int = 0;
                    assert_eq!(libc::waitpid(child, &mut status, 0), child, "waitpid failed");
                    assert!(
                        libc::WIFEXITED(status) && libc::WEXITSTATUS(status) == 0,
                        "child failed to send: wait status {}",
                        status
                    );
                }
            }
            let mut recv_bytes = [0; 128];
            let mut recv_fds = [0, 0, 0, 0, 0, 0, 0];
            assert_eq!(
                r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                    .expect("recv should be successful"),
                (sent_bytes.len(), sent_fds.len())
            );
            assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
            for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
                let expected_value = Some(std::time::Duration::from_secs(42));
                let s = net::UnixDatagram::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                std::mem::forget(s);
                assert_eq!(
                    net::UnixDatagram::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    /// Invalid descriptors must be rejected by `sendmsg`.
    #[test]
    fn sending_junk_fails() {
        // Keep the peer alive: with it dropped, the send would fail for that reason alone and
        // this test would not exercise descriptor validation at all.
        let (l, _r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        assert!(
            l.send_with_fd(&sent_bytes[..], &[i32::max_value()][..]).is_err(),
            "expected an error when sending a junk file descriptor"
        );
        assert!(
            l.send_with_fd(&sent_bytes[..], &[0xffi32][..]).is_err(),
            "expected an error when sending a junk file descriptor"
        );
    }
}