tokio_io_utility/
async_write_utility.rs

1use super::IoSliceExt;
2
3use std::io::{self, IoSlice, Result};
4use tokio::io::{AsyncWrite, AsyncWriteExt};
5
6/// Return true if the `bufs` contains at least one byte.
7pub async fn write_vectored_all<Writer: AsyncWrite + Unpin + ?Sized>(
8    writer: &mut Writer,
9    mut bufs: &mut [IoSlice<'_>],
10) -> Result<()> {
11    if bufs.is_empty() {
12        return Ok(());
13    }
14
15    while bufs[0].is_empty() {
16        bufs = &mut bufs[1..];
17
18        if bufs.is_empty() {
19            return Ok(());
20        }
21    }
22
23    // Loop Invariant:
24    //  - bufs must not be empty;
25    //  - bufs contain at least one byte.
26    loop {
27        // bytes must be greater than 0 since bufs contain
28        // at least one byte.
29        let mut bytes = writer.write_vectored(bufs).await?;
30
31        if bytes == 0 {
32            return Err(io::ErrorKind::WriteZero.into());
33        }
34
35        // This loop would also skip all `IoSlice` that is empty
36        // until the first non-empty `IoSlice` is met.
37        while bufs[0].len() <= bytes {
38            bytes -= bufs[0].len();
39            bufs = &mut bufs[1..];
40
41            if bufs.is_empty() {
42                debug_assert_eq!(bytes, 0);
43                return Ok(());
44            }
45        }
46
47        bufs[0] = IoSlice::new(&bufs[0].into_inner()[bytes..]);
48    }
49}
50
#[cfg(test)]
mod tests {
    use super::write_vectored_all;

    use std::io::IoSlice;
    use std::slice::from_raw_parts;
    use tokio::io::AsyncReadExt;

    /// View the memory backing `v` as raw bytes wrapped in an [`IoSlice`].
    fn as_ioslice<T>(v: &[T]) -> IoSlice<'_> {
        let byte_len = std::mem::size_of_val(v);

        // SAFETY: the pointer and length cover exactly the memory of `v`,
        // which stays borrowed for as long as the returned `IoSlice` lives.
        // The test below only passes `u32` slices, which contain no
        // padding bytes.
        let bytes = unsafe { from_raw_parts(v.as_ptr().cast::<u8>(), byte_len) };

        IoSlice::new(bytes)
    }

    /// Write the numbers `0..2048` as native-endian `u32`s through a pipe
    /// twice, split across two slices, and check that the reader sees
    /// them in order both times.
    #[test]
    fn test() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async {
            let (mut r, mut w) = tokio_pipe::pipe().unwrap();

            let writer = tokio::spawn(async move {
                let first_half: Vec<u32> = (0..1024).collect();
                let second_half: Vec<u32> = (1024..2048).collect();

                // Send the same 2048 numbers twice; the slice array is
                // rebuilt each round because the call may modify it.
                for _ in 0..2 {
                    let mut slices = [as_ioslice(&first_half), as_ioslice(&second_half)];
                    write_vectored_all(&mut w, &mut slices).await.unwrap();
                }
            });

            let reader = tokio::spawn(async move {
                // Each read pulls in 128 `u32` values.
                let mut chunk = [0u8; 4 * 128];

                for _ in 0..2 {
                    let mut expected = 0u32;

                    while expected < 2048 {
                        r.read_exact(&mut chunk).await.unwrap();

                        for word in chunk.chunks_exact(4) {
                            assert_eq!(word, expected.to_ne_bytes());
                            expected += 1;
                        }
                    }
                }
            });

            reader.await.unwrap();
            writer.await.unwrap();
        });
    }
}