1use super::fuse::Fuse;
2use super::Encoder;
3use bytes::{Buf, BytesMut};
4use futures_sink::Sink;
5use futures_util::io::{AsyncRead, AsyncWrite};
6use futures_util::ready;
7use pin_project_lite::pin_project;
8use std::io::{Error, ErrorKind};
9use std::marker::Unpin;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
pin_project! {
    /// A [`Sink`] of frames: each item is encoded with the [`Encoder`] `E`
    /// and the resulting bytes are written to the `AsyncWrite` object `T`.
    #[derive(Debug)]
    pub struct FramedWrite<T, E> {
        // The I/O object and the encoder are fused together and wrapped in the
        // buffering layer (`FramedWrite2`); every `Sink` method of this type
        // forwards to it.
        #[pin]
        inner: FramedWrite2<Fuse<T, E>>,
    }
}
40
41impl<T, E> FramedWrite<T, E>
42where
43 T: AsyncWrite,
44 E: Encoder,
45{
46 pub fn new(inner: T, encoder: E) -> Self {
48 Self {
49 inner: framed_write_2(Fuse::new(inner, encoder), None),
50 }
51 }
52
53 pub fn from_parts(
57 FramedWriteParts {
58 io,
59 encoder,
60 buffer,
61 ..
62 }: FramedWriteParts<T, E>,
63 ) -> Self {
64 Self {
65 inner: framed_write_2(Fuse::new(io, encoder), Some(buffer)),
66 }
67 }
68
69 pub fn send_high_water_mark(&self) -> usize {
81 self.inner.high_water_mark
82 }
83
84 pub fn set_send_high_water_mark(&mut self, hwm: usize) {
100 self.inner.high_water_mark = hwm;
101 }
102
103 pub fn into_parts(self) -> FramedWriteParts<T, E> {
108 let (fuse, buffer) = self.inner.into_parts();
109 FramedWriteParts {
110 io: fuse.t,
111 encoder: fuse.u,
112 buffer,
113 _priv: (),
114 }
115 }
116
117 pub fn into_inner(self) -> T {
123 self.into_parts().io
124 }
125
126 pub fn encoder(&self) -> &E {
131 &self.inner.u
132 }
133
134 pub fn encoder_mut(&mut self) -> &mut E {
139 &mut self.inner.u
140 }
141}
142
143impl<T, E> Sink<E::Item> for FramedWrite<T, E>
144where
145 T: AsyncWrite + Unpin,
146 E: Encoder,
147{
148 type Error = E::Error;
149
150 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
151 self.project().inner.poll_ready(cx)
152 }
153 fn start_send(self: Pin<&mut Self>, item: E::Item) -> Result<(), Self::Error> {
154 self.project().inner.start_send(item)
155 }
156 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
157 self.project().inner.poll_flush(cx)
158 }
159 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
160 self.project().inner.poll_close(cx)
161 }
162}
163
164impl<T, E> Deref for FramedWrite<T, E> {
165 type Target = T;
166
167 fn deref(&self) -> &T {
168 &self.inner
169 }
170}
171
172impl<T, E> DerefMut for FramedWrite<T, E> {
173 fn deref_mut(&mut self) -> &mut T {
174 &mut self.inner
175 }
176}
177
pin_project! {
    /// Buffering write layer: items are encoded into `buffer` and the
    /// buffered bytes are written to `inner`.
    #[derive(Debug)]
    pub struct FramedWrite2<T> {
        // The wrapped value. Its `Sink` impl requires it to be both the byte
        // writer (`AsyncWrite`) and the `Encoder`.
        #[pin]
        pub inner: T,
        // `poll_ready` writes the buffer out while it holds at least this
        // many bytes.
        pub high_water_mark: usize,
        // Encoded bytes that `inner` has not accepted yet.
        buffer: BytesMut,
    }
}
187
188impl<T> Deref for FramedWrite2<T> {
189 type Target = T;
190
191 fn deref(&self) -> &T {
192 &self.inner
193 }
194}
195
196impl<T> DerefMut for FramedWrite2<T> {
197 fn deref_mut(&mut self) -> &mut T {
198 &mut self.inner
199 }
200}
201
/// Default send high-water mark, in bytes (128 KiB).
const DEFAULT_SEND_HIGH_WATER_MARK: usize = 128 * 1024;
205
206pub fn framed_write_2<T>(inner: T, buffer: Option<BytesMut>) -> FramedWrite2<T> {
207 FramedWrite2 {
208 inner,
209 high_water_mark: DEFAULT_SEND_HIGH_WATER_MARK,
210 buffer: buffer.unwrap_or_else(|| BytesMut::with_capacity(1028 * 8)),
211 }
212}
213
214impl<T: AsyncRead + Unpin> AsyncRead for FramedWrite2<T> {
215 fn poll_read(
216 self: Pin<&mut Self>,
217 cx: &mut Context<'_>,
218 buf: &mut [u8],
219 ) -> Poll<Result<usize, Error>> {
220 self.project().inner.poll_read(cx, buf)
221 }
222}
223
224impl<T> Sink<T::Item> for FramedWrite2<T>
225where
226 T: AsyncWrite + Encoder + Unpin,
227{
228 type Error = T::Error;
229
230 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
231 let this = &mut *self;
232 while this.buffer.len() >= this.high_water_mark {
233 let num_write = ready!(Pin::new(&mut this.inner).poll_write(cx, &this.buffer))?;
234
235 if num_write == 0 {
236 return Poll::Ready(Err(err_eof().into()));
237 }
238
239 this.buffer.advance(num_write);
240 }
241
242 Poll::Ready(Ok(()))
243 }
244 fn start_send(mut self: Pin<&mut Self>, item: T::Item) -> Result<(), Self::Error> {
245 let this = &mut *self;
246 this.inner.encode(item, &mut this.buffer)
247 }
248 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
249 let mut this = self.project();
250
251 while !this.buffer.is_empty() {
252 let num_write = ready!(Pin::new(&mut this.inner).poll_write(cx, &this.buffer))?;
253
254 if num_write == 0 {
255 return Poll::Ready(Err(err_eof().into()));
256 }
257
258 this.buffer.advance(num_write);
259 }
260
261 this.inner.poll_flush(cx).map_err(Into::into)
262 }
263 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
264 ready!(self.as_mut().poll_flush(cx))?;
265 self.project().inner.poll_close(cx).map_err(Into::into)
266 }
267}
268
269impl<T> FramedWrite2<T> {
270 pub fn into_parts(self) -> (T, BytesMut) {
271 (self.inner, self.buffer)
272 }
273}
274
/// Builds the `UnexpectedEof` error returned when a write to the underlying
/// object makes no progress.
fn err_eof() -> Error {
    let kind = ErrorKind::UnexpectedEof;
    Error::new(kind, "End of file")
}
278
279pub struct FramedWriteParts<T, E> {
281 pub io: T,
283 pub encoder: E,
285 pub buffer: BytesMut,
287 _priv: (),
289}
290
291impl<T, E> FramedWriteParts<T, E> {
292 pub fn map_encoder<G, F>(self, f: F) -> FramedWriteParts<T, G>
294 where
295 G: Encoder,
296 F: FnOnce(E) -> G,
297 {
298 FramedWriteParts {
299 io: self.io,
300 encoder: f(self.encoder),
301 buffer: self.buffer,
302 _priv: (),
303 }
304 }
305}