1use crate::codec::compression::{
2 CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3};
4use crate::codec::EncodeBody;
5use crate::metadata::GRPC_CONTENT_TYPE;
6use crate::{
7 body::BoxBody,
8 codec::{Codec, Streaming},
9 server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
10 Request, Status,
11};
12use http_body::Body;
13use std::{fmt, pin::pin};
14use tokio_stream::{Stream, StreamExt};
15
/// Evaluates to the `Ok` value of the given `Result`; on `Err`, returns early
/// from the enclosing function with the error converted by `into_http()`.
macro_rules! t {
    ($outcome:expr) => {
        match $outcome {
            Err(status) => return status.into_http(),
            Ok(value) => value,
        }
    };
}
24
25pub struct Grpc<T> {
35 codec: T,
36 accept_compression_encodings: EnabledCompressionEncodings,
38 send_compression_encodings: EnabledCompressionEncodings,
40 max_decoding_message_size: Option<usize>,
42 max_encoding_message_size: Option<usize>,
44}
45
46impl<T> Grpc<T>
47where
48 T: Codec,
49{
50 pub fn new(codec: T) -> Self {
52 Self {
53 codec,
54 accept_compression_encodings: EnabledCompressionEncodings::default(),
55 send_compression_encodings: EnabledCompressionEncodings::default(),
56 max_decoding_message_size: None,
57 max_encoding_message_size: None,
58 }
59 }
60
61 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
89 self.accept_compression_encodings.enable(encoding);
90 self
91 }
92
93 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
120 self.send_compression_encodings.enable(encoding);
121 self
122 }
123
124 pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
150 self.max_decoding_message_size = Some(limit);
151 self
152 }
153
154 pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
180 self.max_encoding_message_size = Some(limit);
181 self
182 }
183
184 #[doc(hidden)]
185 pub fn apply_compression_config(
186 self,
187 accept_encodings: EnabledCompressionEncodings,
188 send_encodings: EnabledCompressionEncodings,
189 ) -> Self {
190 let mut this = self;
191
192 for &encoding in CompressionEncoding::ENCODINGS {
193 if accept_encodings.is_enabled(encoding) {
194 this = this.accept_compressed(encoding);
195 }
196 if send_encodings.is_enabled(encoding) {
197 this = this.send_compressed(encoding);
198 }
199 }
200
201 this
202 }
203
204 #[doc(hidden)]
205 pub fn apply_max_message_size_config(
206 self,
207 max_decoding_message_size: Option<usize>,
208 max_encoding_message_size: Option<usize>,
209 ) -> Self {
210 let mut this = self;
211
212 if let Some(limit) = max_decoding_message_size {
213 this = this.max_decoding_message_size(limit);
214 }
215 if let Some(limit) = max_encoding_message_size {
216 this = this.max_encoding_message_size(limit);
217 }
218
219 this
220 }
221
222 pub async fn unary<S, B>(
224 &mut self,
225 mut service: S,
226 req: http::Request<B>,
227 ) -> http::Response<BoxBody>
228 where
229 S: UnaryService<T::Decode, Response = T::Encode>,
230 B: Body + Send + 'static,
231 B::Error: Into<crate::Error> + Send,
232 {
233 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
234 req.headers(),
235 self.send_compression_encodings,
236 );
237
238 let request = match self.map_request_unary(req).await {
239 Ok(r) => r,
240 Err(status) => {
241 return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
242 Err(status),
243 accept_encoding,
244 SingleMessageCompressionOverride::default(),
245 self.max_encoding_message_size,
246 );
247 }
248 };
249
250 let response = service
251 .call(request)
252 .await
253 .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
254
255 let compression_override = compression_override_from_response(&response);
256
257 self.map_response(
258 response,
259 accept_encoding,
260 compression_override,
261 self.max_encoding_message_size,
262 )
263 }
264
265 pub async fn server_streaming<S, B>(
267 &mut self,
268 mut service: S,
269 req: http::Request<B>,
270 ) -> http::Response<BoxBody>
271 where
272 S: ServerStreamingService<T::Decode, Response = T::Encode>,
273 S::ResponseStream: Send + 'static,
274 B: Body + Send + 'static,
275 B::Error: Into<crate::Error> + Send,
276 {
277 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
278 req.headers(),
279 self.send_compression_encodings,
280 );
281
282 let request = match self.map_request_unary(req).await {
283 Ok(r) => r,
284 Err(status) => {
285 return self.map_response::<S::ResponseStream>(
286 Err(status),
287 accept_encoding,
288 SingleMessageCompressionOverride::default(),
289 self.max_encoding_message_size,
290 );
291 }
292 };
293
294 let response = service.call(request).await;
295
296 self.map_response(
297 response,
298 accept_encoding,
299 SingleMessageCompressionOverride::default(),
302 self.max_encoding_message_size,
303 )
304 }
305
306 pub async fn client_streaming<S, B>(
308 &mut self,
309 mut service: S,
310 req: http::Request<B>,
311 ) -> http::Response<BoxBody>
312 where
313 S: ClientStreamingService<T::Decode, Response = T::Encode>,
314 B: Body + Send + 'static,
315 B::Error: Into<crate::Error> + Send + 'static,
316 {
317 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
318 req.headers(),
319 self.send_compression_encodings,
320 );
321
322 let request = t!(self.map_request_streaming(req));
323
324 let response = service
325 .call(request)
326 .await
327 .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
328
329 let compression_override = compression_override_from_response(&response);
330
331 self.map_response(
332 response,
333 accept_encoding,
334 compression_override,
335 self.max_encoding_message_size,
336 )
337 }
338
339 pub async fn streaming<S, B>(
341 &mut self,
342 mut service: S,
343 req: http::Request<B>,
344 ) -> http::Response<BoxBody>
345 where
346 S: StreamingService<T::Decode, Response = T::Encode> + Send,
347 S::ResponseStream: Send + 'static,
348 B: Body + Send + 'static,
349 B::Error: Into<crate::Error> + Send,
350 {
351 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
352 req.headers(),
353 self.send_compression_encodings,
354 );
355
356 let request = t!(self.map_request_streaming(req));
357
358 let response = service.call(request).await;
359
360 self.map_response(
361 response,
362 accept_encoding,
363 SingleMessageCompressionOverride::default(),
364 self.max_encoding_message_size,
365 )
366 }
367
368 async fn map_request_unary<B>(
369 &mut self,
370 request: http::Request<B>,
371 ) -> Result<Request<T::Decode>, Status>
372 where
373 B: Body + Send + 'static,
374 B::Error: Into<crate::Error> + Send,
375 {
376 let request_compression_encoding = self.request_encoding_if_supported(&request)?;
377
378 let (parts, body) = request.into_parts();
379
380 let mut stream = pin!(Streaming::new_request(
381 self.codec.decoder(),
382 body,
383 request_compression_encoding,
384 self.max_decoding_message_size,
385 ));
386
387 let message = stream
388 .try_next()
389 .await?
390 .ok_or_else(|| Status::internal("Missing request message."))?;
391
392 let mut req = Request::from_http_parts(parts, message);
393
394 if let Some(trailers) = stream.trailers().await? {
395 req.metadata_mut().merge(trailers);
396 }
397
398 Ok(req)
399 }
400
401 fn map_request_streaming<B>(
402 &mut self,
403 request: http::Request<B>,
404 ) -> Result<Request<Streaming<T::Decode>>, Status>
405 where
406 B: Body + Send + 'static,
407 B::Error: Into<crate::Error> + Send,
408 {
409 let encoding = self.request_encoding_if_supported(&request)?;
410
411 let request = request.map(|body| {
412 Streaming::new_request(
413 self.codec.decoder(),
414 body,
415 encoding,
416 self.max_decoding_message_size,
417 )
418 });
419
420 Ok(Request::from_http(request))
421 }
422
423 fn map_response<B>(
424 &mut self,
425 response: Result<crate::Response<B>, Status>,
426 accept_encoding: Option<CompressionEncoding>,
427 compression_override: SingleMessageCompressionOverride,
428 max_message_size: Option<usize>,
429 ) -> http::Response<BoxBody>
430 where
431 B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
432 {
433 let response = t!(response);
434
435 let (mut parts, body) = response.into_http().into_parts();
436
437 parts
439 .headers
440 .insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
441
442 #[cfg(any(feature = "gzip", feature = "zstd"))]
443 if let Some(encoding) = accept_encoding {
444 parts.headers.insert(
446 crate::codec::compression::ENCODING_HEADER,
447 encoding.into_header_value(),
448 );
449 }
450
451 let body = EncodeBody::new_server(
452 self.codec.encoder(),
453 body,
454 accept_encoding,
455 compression_override,
456 max_message_size,
457 );
458
459 http::Response::from_parts(parts, BoxBody::new(body))
460 }
461
462 fn request_encoding_if_supported<B>(
463 &self,
464 request: &http::Request<B>,
465 ) -> Result<Option<CompressionEncoding>, Status> {
466 CompressionEncoding::from_encoding_header(
467 request.headers(),
468 self.accept_compression_encodings,
469 )
470 }
471}
472
473impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
474 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
475 let mut f = f.debug_struct("Grpc");
476
477 f.field("codec", &self.codec);
478
479 f.field(
480 "accept_compression_encodings",
481 &self.accept_compression_encodings,
482 );
483
484 f.field(
485 "send_compression_encodings",
486 &self.send_compression_encodings,
487 );
488
489 f.finish()
490 }
491}
492
493fn compression_override_from_response<B, E>(
494 res: &Result<crate::Response<B>, E>,
495) -> SingleMessageCompressionOverride {
496 res.as_ref()
497 .ok()
498 .and_then(|response| {
499 response
500 .extensions()
501 .get::<SingleMessageCompressionOverride>()
502 .copied()
503 })
504 .unwrap_or_default()
505}