1use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder};
2use crate::codec::EncodeBuf;
3use crate::Status;
4use prost::Message;
5use std::marker::PhantomData;
6
/// A codec that encodes messages of type `T` and decodes messages of type `U`
/// using [`prost::Message`].
///
/// The codec is a zero-sized marker: it stores only a [`PhantomData`] that
/// records the two message types.
#[derive(Debug)]
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}

// `Clone` is written by hand rather than derived: `#[derive(Clone)]` would add
// `T: Clone` and `U: Clone` bounds even though no `T` or `U` value is stored.
impl<T, U> Clone for ProstCodec<T, U> {
    fn clone(&self) -> Self {
        Self { _pd: PhantomData }
    }
}
12
13impl<T, U> ProstCodec<T, U> {
14 pub fn new() -> Self {
17 Self { _pd: PhantomData }
18 }
19}
20
21impl<T, U> Default for ProstCodec<T, U> {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl<T, U> ProstCodec<T, U>
28where
29 T: Message + Send + 'static,
30 U: Message + Default + Send + 'static,
31{
32 pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder {
35 ProstEncoder {
36 _pd: PhantomData,
37 buffer_settings,
38 }
39 }
40
41 pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder {
44 ProstDecoder {
45 _pd: PhantomData,
46 buffer_settings,
47 }
48 }
49}
50
51impl<T, U> Codec for ProstCodec<T, U>
52where
53 T: Message + Send + 'static,
54 U: Message + Default + Send + 'static,
55{
56 type Encode = T;
57 type Decode = U;
58
59 type Encoder = ProstEncoder<T>;
60 type Decoder = ProstDecoder<U>;
61
62 fn encoder(&mut self) -> Self::Encoder {
63 ProstEncoder {
64 _pd: PhantomData,
65 buffer_settings: BufferSettings::default(),
66 }
67 }
68
69 fn decoder(&mut self) -> Self::Decoder {
70 ProstDecoder {
71 _pd: PhantomData,
72 buffer_settings: BufferSettings::default(),
73 }
74 }
75}
76
77#[derive(Debug, Clone, Default)]
79pub struct ProstEncoder<T> {
80 _pd: PhantomData<T>,
81 buffer_settings: BufferSettings,
82}
83
84impl<T> ProstEncoder<T> {
85 pub fn new(buffer_settings: BufferSettings) -> Self {
87 Self {
88 _pd: PhantomData,
89 buffer_settings,
90 }
91 }
92}
93
94impl<T: Message> Encoder for ProstEncoder<T> {
95 type Item = T;
96 type Error = Status;
97
98 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
99 item.encode(buf)
100 .expect("Message only errors if not enough space");
101
102 Ok(())
103 }
104
105 fn buffer_settings(&self) -> BufferSettings {
106 self.buffer_settings
107 }
108}
109
110#[derive(Debug, Clone, Default)]
112pub struct ProstDecoder<U> {
113 _pd: PhantomData<U>,
114 buffer_settings: BufferSettings,
115}
116
117impl<U> ProstDecoder<U> {
118 pub fn new(buffer_settings: BufferSettings) -> Self {
120 Self {
121 _pd: PhantomData,
122 buffer_settings,
123 }
124 }
125}
126
127impl<U: Message + Default> Decoder for ProstDecoder<U> {
128 type Item = U;
129 type Error = Status;
130
131 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
132 let item = Message::decode(buf)
133 .map(Option::Some)
134 .map_err(from_decode_error)?;
135
136 Ok(item)
137 }
138
139 fn buffer_settings(&self) -> BufferSettings {
140 self.buffer_settings
141 }
142}
143
144fn from_decode_error(error: prost::DecodeError) -> crate::Status {
145 Status::internal(error.to_string())
148}
149
#[cfg(test)]
mod tests {
    use crate::codec::compression::SingleMessageCompressionOverride;
    use crate::codec::{
        DecodeBuf, Decoder, EncodeBody, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
    };
    use crate::Status;
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;
    use http_body_util::BodyExt as _;
    use std::pin::pin;

    /// Payload size for the happy-path decode test; also the number of bytes
    /// `MockDecoder` consumes on every `decode` call.
    const LEN: usize = 10000;
    /// Message size limit (2 MiB) used by the `*_max_message_size_exceeded` tests.
    const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;

    /// Builds the wire form of one message: a single `0` byte, the payload
    /// length as a big-endian `u32`, then the payload itself.
    fn framed(payload: &[u8]) -> BytesMut {
        let mut wire = BytesMut::new();
        wire.reserve(payload.len() + HEADER_SIZE);
        wire.put_u8(0);
        wire.put_u32(payload.len() as u32);
        wire.put(payload);
        wire
    }

    /// A single framed message of `LEN` bytes is decoded exactly once.
    #[tokio::test]
    async fn decode() {
        let payload = vec![0u8; LEN];
        let wire = framed(&payload);

        // 10005 = 5 header bytes (u8 + u32) + LEN payload bytes, so the mock
        // body hands over the whole buffer in its first frame.
        let body = body::MockBody::new(&wire[..], 10005, 0);
        let mut stream = Streaming::new_request(MockDecoder::default(), body, None, None);

        let mut decoded = 0usize;
        while let Some(message) = stream.message().await.unwrap() {
            assert_eq!(message.len(), payload.len());
            decoded += 1;
        }
        assert_eq!(decoded, 1);
    }

    /// A message one byte over the configured limit is rejected with an
    /// `out_of_range` status and the expected message text.
    #[tokio::test]
    async fn decode_max_message_size_exceeded() {
        let payload = vec![0u8; MAX_MESSAGE_SIZE + 1];
        let wire = framed(&payload);

        let body = body::MockBody::new(&wire[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0);
        let mut stream = Streaming::new_request(
            MockDecoder::default(),
            body,
            None,
            Some(MAX_MESSAGE_SIZE),
        );

        let actual = stream.message().await.unwrap_err();
        let expected = Status::out_of_range(format!(
            "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
            payload.len(),
            MAX_MESSAGE_SIZE
        ));

        assert_eq!(actual.code(), expected.code());
        assert_eq!(actual.message(), expected.message());
    }

    /// Encoding 10000 messages of 1 KiB each yields no frame-level error.
    #[tokio::test]
    async fn encode() {
        let chunk = vec![0u8; 1024];
        let source = tokio_stream::iter(
            std::iter::repeat_with(move || Ok::<_, Status>(chunk.clone())).take(10000),
        );

        let mut body = pin!(EncodeBody::new_server(
            MockEncoder::default(),
            source,
            None,
            SingleMessageCompressionOverride::default(),
            None,
        ));

        while let Some(frame) = body.frame().await {
            frame.unwrap();
        }
    }

    /// A message one byte over the configured limit produces trailers whose
    /// `grpc-status` is "11" (OUT_OF_RANGE in the gRPC status code table),
    /// and the body then reports end of stream.
    #[tokio::test]
    async fn encode_max_message_size_exceeded() {
        let oversized = vec![0u8; MAX_MESSAGE_SIZE + 1];
        let source = tokio_stream::iter(std::iter::once(Ok::<_, Status>(oversized)));

        let mut body = pin!(EncodeBody::new_server(
            MockEncoder::default(),
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(MAX_MESSAGE_SIZE),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        let trailers = frame.into_trailers().expect("got trailers");
        let grpc_status = trailers
            .get(Status::GRPC_STATUS)
            .expect("grpc-status header");

        assert_eq!(grpc_status, "11");
        assert!(body.is_end_stream());
    }

    /// A message whose length does not fit in a `u32` produces trailers whose
    /// `grpc-status` is "8" (RESOURCE_EXHAUSTED in the gRPC status code
    /// table), even with the size limit set to `usize::MAX`.
    ///
    /// Not compiled on Windows targets. The test allocates just over 4 GiB.
    #[cfg(not(target_family = "windows"))]
    #[tokio::test]
    async fn encode_too_big() {
        let oversized = vec![0u8; u32::MAX as usize + 1];
        let source = tokio_stream::iter(std::iter::once(Ok::<_, Status>(oversized)));

        let mut body = pin!(EncodeBody::new_server(
            MockEncoder::default(),
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(usize::MAX),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        let trailers = frame.into_trailers().expect("got trailers");
        let grpc_status = trailers
            .get(Status::GRPC_STATUS)
            .expect("grpc-status header");

        assert_eq!(grpc_status, "8");
        assert!(body.is_end_stream());
    }

    /// Test encoder whose items are raw byte vectors written to the buffer verbatim.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder {}

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        /// Copies every byte of `item` into `buf`; never fails.
        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(item.as_slice());
            Ok(())
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            crate::codec::BufferSettings::default()
        }
    }

    /// Test decoder that returns the buffer's current chunk as a byte vector.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder {}

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        /// Copies the buffer's current chunk, then consumes `LEN` bytes
        /// regardless of how long that chunk was.
        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            let bytes = buf.chunk().to_vec();
            buf.advance(LEN);
            Ok(Some(bytes))
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            crate::codec::BufferSettings::default()
        }
    }

    mod body {
        use crate::Status;
        use bytes::Bytes;
        use http_body::{Body, Frame};
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        /// In-memory body that alternates between yielding data and
        /// returning `Poll::Pending`.
        #[derive(Debug)]
        pub(super) struct MockBody {
            /// Bytes not yet handed out.
            data: Bytes,

            /// Size of the chunk yielded when `count` is zero.
            partial_len: usize,

            /// Number of polls made so far while data was still available.
            count: usize,
        }

        impl MockBody {
            /// Copies `b` into a new body. `partial_len` is the size of the
            /// chunk yielded when the counter is zero, and `count` is the
            /// counter's starting value.
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            /// While data remains, polls made at an even counter value yield
            /// data: `partial_len` bytes when the counter is zero, everything
            /// left otherwise. Polls at an odd counter value wake the task and
            /// return `Pending`. Once the data is exhausted the body ends.
            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                let remaining = self.data.len();
                if remaining == 0 {
                    return Poll::Ready(None);
                }

                let polls_so_far = self.count;
                if polls_so_far % 2 != 0 {
                    // Wake immediately so the caller polls again right away.
                    cx.waker().wake_by_ref();
                    self.count += 1;
                    return Poll::Pending;
                }

                let take = if polls_so_far == 0 {
                    self.partial_len
                } else {
                    remaining
                };
                let chunk = self.data.split_to(take);
                self.count += 1;
                Poll::Ready(Some(Ok(Frame::data(chunk))))
            }
        }
    }
}