1use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
2use crate::codec::EncodeBody;
3use crate::metadata::GRPC_CONTENT_TYPE;
4use crate::{
5 body::BoxBody,
6 client::GrpcService,
7 codec::{Codec, Decoder, Streaming},
8 request::SanitizeHeaders,
9 Code, Request, Response, Status,
10};
11use http::{
12 header::{HeaderValue, CONTENT_TYPE, TE},
13 uri::{PathAndQuery, Uri},
14};
15use http_body::Body;
16use std::{fmt, future, pin::pin};
17use tokio_stream::{Stream, StreamExt};
18
19pub struct Grpc<T> {
33 inner: T,
34 config: GrpcConfig,
35}
36
37struct GrpcConfig {
38 origin: Uri,
39 accept_compression_encodings: EnabledCompressionEncodings,
41 send_compression_encodings: Option<CompressionEncoding>,
43 max_decoding_message_size: Option<usize>,
45 max_encoding_message_size: Option<usize>,
47}
48
49impl<T> Grpc<T> {
50 pub fn new(inner: T) -> Self {
52 Self::with_origin(inner, Uri::default())
53 }
54
55 pub fn with_origin(inner: T, origin: Uri) -> Self {
60 Self {
61 inner,
62 config: GrpcConfig {
63 origin,
64 send_compression_encodings: None,
65 accept_compression_encodings: EnabledCompressionEncodings::default(),
66 max_decoding_message_size: None,
67 max_encoding_message_size: None,
68 },
69 }
70 }
71
72 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
99 self.config.send_compression_encodings = Some(encoding);
100 self
101 }
102
103 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
130 self.config.accept_compression_encodings.enable(encoding);
131 self
132 }
133
134 pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
160 self.config.max_decoding_message_size = Some(limit);
161 self
162 }
163
164 pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
190 self.config.max_encoding_message_size = Some(limit);
191 self
192 }
193
194 pub async fn ready(&mut self) -> Result<(), T::Error>
200 where
201 T: GrpcService<BoxBody>,
202 {
203 future::poll_fn(|cx| self.inner.poll_ready(cx)).await
204 }
205
206 pub async fn unary<M1, M2, C>(
208 &mut self,
209 request: Request<M1>,
210 path: PathAndQuery,
211 codec: C,
212 ) -> Result<Response<M2>, Status>
213 where
214 T: GrpcService<BoxBody>,
215 T::ResponseBody: Body + Send + 'static,
216 <T::ResponseBody as Body>::Error: Into<crate::Error>,
217 C: Codec<Encode = M1, Decode = M2>,
218 M1: Send + Sync + 'static,
219 M2: Send + Sync + 'static,
220 {
221 let request = request.map(|m| tokio_stream::once(m));
222 self.client_streaming(request, path, codec).await
223 }
224
225 pub async fn client_streaming<S, M1, M2, C>(
227 &mut self,
228 request: Request<S>,
229 path: PathAndQuery,
230 codec: C,
231 ) -> Result<Response<M2>, Status>
232 where
233 T: GrpcService<BoxBody>,
234 T::ResponseBody: Body + Send + 'static,
235 <T::ResponseBody as Body>::Error: Into<crate::Error>,
236 S: Stream<Item = M1> + Send + 'static,
237 C: Codec<Encode = M1, Decode = M2>,
238 M1: Send + Sync + 'static,
239 M2: Send + Sync + 'static,
240 {
241 let (mut parts, body, extensions) =
242 self.streaming(request, path, codec).await?.into_parts();
243
244 let mut body = pin!(body);
245
246 let message = body
247 .try_next()
248 .await
249 .map_err(|mut status| {
250 status.metadata_mut().merge(parts.clone());
251 status
252 })?
253 .ok_or_else(|| Status::internal("Missing response message."))?;
254
255 if let Some(trailers) = body.trailers().await? {
256 parts.merge(trailers);
257 }
258
259 Ok(Response::from_parts(parts, message, extensions))
260 }
261
262 pub async fn server_streaming<M1, M2, C>(
264 &mut self,
265 request: Request<M1>,
266 path: PathAndQuery,
267 codec: C,
268 ) -> Result<Response<Streaming<M2>>, Status>
269 where
270 T: GrpcService<BoxBody>,
271 T::ResponseBody: Body + Send + 'static,
272 <T::ResponseBody as Body>::Error: Into<crate::Error>,
273 C: Codec<Encode = M1, Decode = M2>,
274 M1: Send + Sync + 'static,
275 M2: Send + Sync + 'static,
276 {
277 let request = request.map(|m| tokio_stream::once(m));
278 self.streaming(request, path, codec).await
279 }
280
281 pub async fn streaming<S, M1, M2, C>(
283 &mut self,
284 request: Request<S>,
285 path: PathAndQuery,
286 mut codec: C,
287 ) -> Result<Response<Streaming<M2>>, Status>
288 where
289 T: GrpcService<BoxBody>,
290 T::ResponseBody: Body + Send + 'static,
291 <T::ResponseBody as Body>::Error: Into<crate::Error>,
292 S: Stream<Item = M1> + Send + 'static,
293 C: Codec<Encode = M1, Decode = M2>,
294 M1: Send + Sync + 'static,
295 M2: Send + Sync + 'static,
296 {
297 let request = request
298 .map(|s| {
299 EncodeBody::new_client(
300 codec.encoder(),
301 s.map(Ok),
302 self.config.send_compression_encodings,
303 self.config.max_encoding_message_size,
304 )
305 })
306 .map(BoxBody::new);
307
308 let request = self.config.prepare_request(request, path);
309
310 let response = self
311 .inner
312 .call(request)
313 .await
314 .map_err(Status::from_error_generic)?;
315
316 let decoder = codec.decoder();
317
318 self.create_response(decoder, response)
319 }
320
321 fn create_response<M2>(
324 &self,
325 decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static,
326 response: http::Response<T::ResponseBody>,
327 ) -> Result<Response<Streaming<M2>>, Status>
328 where
329 T: GrpcService<BoxBody>,
330 T::ResponseBody: Body + Send + 'static,
331 <T::ResponseBody as Body>::Error: Into<crate::Error>,
332 {
333 let encoding = CompressionEncoding::from_encoding_header(
334 response.headers(),
335 self.config.accept_compression_encodings,
336 )?;
337
338 let status_code = response.status();
339 let trailers_only_status = Status::from_header_map(response.headers());
340
341 let expect_additional_trailers = if let Some(status) = trailers_only_status {
344 if status.code() != Code::Ok {
345 return Err(status);
346 }
347
348 false
349 } else {
350 true
351 };
352
353 let response = response.map(|body| {
354 if expect_additional_trailers {
355 Streaming::new_response(
356 decoder,
357 body,
358 status_code,
359 encoding,
360 self.config.max_decoding_message_size,
361 )
362 } else {
363 Streaming::new_empty(decoder, body)
364 }
365 });
366
367 Ok(Response::from_http(response))
368 }
369}
370
371impl GrpcConfig {
372 fn prepare_request(
373 &self,
374 request: Request<BoxBody>,
375 path: PathAndQuery,
376 ) -> http::Request<BoxBody> {
377 let mut parts = self.origin.clone().into_parts();
378
379 match &parts.path_and_query {
380 Some(pnq) if pnq != "/" => {
381 parts.path_and_query = Some(
382 format!("{}{}", pnq.path(), path)
383 .parse()
384 .expect("must form valid path_and_query"),
385 )
386 }
387 _ => {
388 parts.path_and_query = Some(path);
389 }
390 }
391
392 let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
393
394 let mut request = request.into_http(
395 uri,
396 http::Method::POST,
397 http::Version::HTTP_2,
398 SanitizeHeaders::Yes,
399 );
400
401 request
403 .headers_mut()
404 .insert(TE, HeaderValue::from_static("trailers"));
405
406 request
408 .headers_mut()
409 .insert(CONTENT_TYPE, GRPC_CONTENT_TYPE);
410
411 #[cfg(any(feature = "gzip", feature = "zstd"))]
412 if let Some(encoding) = self.send_compression_encodings {
413 request.headers_mut().insert(
414 crate::codec::compression::ENCODING_HEADER,
415 encoding.into_header_value(),
416 );
417 }
418
419 if let Some(header_value) = self
420 .accept_compression_encodings
421 .into_accept_encoding_header_value()
422 {
423 request.headers_mut().insert(
424 crate::codec::compression::ACCEPT_ENCODING_HEADER,
425 header_value,
426 );
427 }
428
429 request
430 }
431}
432
433impl<T: Clone> Clone for Grpc<T> {
434 fn clone(&self) -> Self {
435 Self {
436 inner: self.inner.clone(),
437 config: GrpcConfig {
438 origin: self.config.origin.clone(),
439 send_compression_encodings: self.config.send_compression_encodings,
440 accept_compression_encodings: self.config.accept_compression_encodings,
441 max_encoding_message_size: self.config.max_encoding_message_size,
442 max_decoding_message_size: self.config.max_decoding_message_size,
443 },
444 }
445 }
446}
447
448impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
449 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450 let mut f = f.debug_struct("Grpc");
451
452 f.field("inner", &self.inner);
453
454 f.field("origin", &self.config.origin);
455
456 f.field(
457 "compression_encoding",
458 &self.config.send_compression_encodings,
459 );
460
461 f.field(
462 "accept_compression_encodings",
463 &self.config.accept_compression_encodings,
464 );
465
466 f.field(
467 "max_decoding_message_size",
468 &self.config.max_decoding_message_size,
469 );
470
471 f.field(
472 "max_encoding_message_size",
473 &self.config.max_encoding_message_size,
474 );
475
476 f.finish()
477 }
478}