azure_core/
bytes_stream.rs

1use crate::SeekableStream;
2use bytes::Bytes;
3use futures::io::AsyncRead;
4use futures::stream::Stream;
5use std::pin::Pin;
6use std::task::Poll;
7
8/// Convenience struct that maps a `bytes::Bytes` buffer into a stream.
9///
10/// This struct implements both `Stream` and `SeekableStream` for an
11/// immutable bytes buffer. It's cheap to clone but remember to `reset`
12/// the stream position if you clone it.
13#[derive(Debug, Clone)]
14pub struct BytesStream {
15    bytes: Bytes,
16    bytes_read: usize,
17}
18
19impl BytesStream {
20    pub fn new(bytes: impl Into<Bytes>) -> Self {
21        Self {
22            bytes: bytes.into(),
23            bytes_read: 0,
24        }
25    }
26
27    /// Creates a stream that resolves immediately with no data.
28    pub fn new_empty() -> Self {
29        Self::new(Bytes::new())
30    }
31}
32
33impl From<Bytes> for BytesStream {
34    fn from(bytes: Bytes) -> Self {
35        Self::new(bytes)
36    }
37}
38
39impl Stream for BytesStream {
40    type Item = crate::Result<Bytes>;
41
42    fn poll_next(
43        self: Pin<&mut Self>,
44        _cx: &mut std::task::Context<'_>,
45    ) -> Poll<Option<Self::Item>> {
46        let self_mut = self.get_mut();
47
48        // we return all the available bytes in one call.
49        if self_mut.bytes_read < self_mut.bytes.len() {
50            let bytes_read = self_mut.bytes_read;
51            self_mut.bytes_read = self_mut.bytes.len();
52            Poll::Ready(Some(Ok(self_mut.bytes.slice(bytes_read..))))
53        } else {
54            Poll::Ready(None)
55        }
56    }
57}
58
59#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
60#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
61impl SeekableStream for BytesStream {
62    async fn reset(&mut self) -> crate::Result<()> {
63        self.bytes_read = 0;
64        Ok(())
65    }
66
67    fn len(&self) -> usize {
68        self.bytes.len()
69    }
70}
71
72impl AsyncRead for BytesStream {
73    fn poll_read(
74        self: Pin<&mut Self>,
75        _cx: &mut std::task::Context<'_>,
76        buf: &mut [u8],
77    ) -> Poll<std::io::Result<usize>> {
78        let self_mut = self.get_mut();
79
80        if self_mut.bytes_read < self_mut.bytes.len() {
81            let bytes_read = self_mut.bytes_read;
82            let remaining_bytes = self_mut.bytes.len() - bytes_read;
83
84            let bytes_to_copy = std::cmp::min(remaining_bytes, buf.len());
85            let bytes_to_read_end = self_mut.bytes_read + bytes_to_copy;
86
87            for (buf_byte, bytes_byte) in buf
88                .iter_mut()
89                .zip(self_mut.bytes.slice(self_mut.bytes_read..bytes_to_read_end))
90            {
91                *buf_byte = bytes_byte;
92            }
93
94            self_mut.bytes_read += bytes_to_copy;
95
96            Poll::Ready(Ok(bytes_to_copy))
97        } else {
98            Poll::Ready(Ok(0))
99        }
100    }
101}
102
// Unit tests
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::io::AsyncReadExt;
    use futures::stream::StreamExt;

    // Test BytesStream Stream: one chunk with everything, then end-of-stream.
    #[test]
    fn test_bytes_stream() {
        let bytes = Bytes::from("hello world");
        let mut stream = BytesStream::new(bytes.clone());

        let mut buf = Vec::new();
        let mut bytes_read = 0;
        while let Some(chunk) = block_on(stream.next()) {
            let chunk = chunk.unwrap();
            buf.extend_from_slice(&chunk);
            bytes_read += chunk.len();
        }

        assert_eq!(bytes_read, bytes.len());
        assert_eq!(buf, bytes);
        // Polling again after exhaustion keeps reporting end-of-stream.
        assert!(block_on(stream.next()).is_none());
    }

    // An empty stream yields no chunks and reads report EOF immediately.
    #[test]
    fn test_empty_stream() {
        let mut stream = BytesStream::new_empty();
        assert!(block_on(stream.next()).is_none());

        let mut buf = [0; 4];
        assert_eq!(block_on(stream.read(&mut buf)).unwrap(), 0);
    }

    // Test BytesStream AsyncRead, all bytes at once
    #[test]
    fn test_async_read_all_bytes_at_once() {
        let bytes = Bytes::from("hello world");
        let mut stream = BytesStream::new(bytes.clone());

        let mut buf = [0; 11];
        let bytes_read = block_on(stream.read(&mut buf)).unwrap();
        assert_eq!(bytes_read, 11);
        assert_eq!(&buf[..], &bytes);

        // Buffer fully drained: the next read is EOF.
        assert_eq!(block_on(stream.read(&mut buf)).unwrap(), 0);
    }

    // Test BytesStream AsyncRead, one byte at a time
    #[test]
    fn test_async_read_one_byte_at_a_time() {
        let bytes = Bytes::from("hello world");
        let mut stream = BytesStream::new(bytes.clone());

        for &expected in bytes.iter() {
            let mut buf = [0; 1];
            let bytes_read = block_on(stream.read(&mut buf)).unwrap();
            assert_eq!(bytes_read, 1);
            assert_eq!(buf[0], expected);
        }

        let mut buf = [0; 1];
        assert_eq!(block_on(stream.read(&mut buf)).unwrap(), 0);
    }

    // Test BytesStream AsyncRead with a buffer smaller than the payload.
    #[test]
    fn test_async_read_partial_chunks() {
        let bytes = Bytes::from("hello world");
        let mut stream = BytesStream::new(bytes.clone());

        let mut buf = [0; 4];
        let mut out = Vec::new();
        let mut sizes = Vec::new();
        loop {
            let n = block_on(stream.read(&mut buf)).unwrap();
            if n == 0 {
                break;
            }
            sizes.push(n);
            out.extend_from_slice(&buf[..n]);
        }

        assert_eq!(sizes, vec![4, 4, 3]);
        assert_eq!(out, bytes);
    }

    // `reset` rewinds the cursor so the same content can be read again.
    #[test]
    fn test_reset_rewinds_position() {
        let bytes = Bytes::from("hello world");
        let mut stream = BytesStream::new(bytes.clone());

        let mut first = Vec::new();
        block_on(stream.read_to_end(&mut first)).unwrap();
        assert_eq!(first, bytes);

        block_on(stream.reset()).unwrap();

        let mut second = Vec::new();
        block_on(stream.read_to_end(&mut second)).unwrap();
        assert_eq!(second, bytes);
    }
}