tonic/client/
grpc.rs

1use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
2use crate::codec::EncodeBody;
3use crate::metadata::GRPC_CONTENT_TYPE;
4use crate::{
5    body::BoxBody,
6    client::GrpcService,
7    codec::{Codec, Decoder, Streaming},
8    request::SanitizeHeaders,
9    Code, Request, Response, Status,
10};
11use http::{
12    header::{HeaderValue, CONTENT_TYPE, TE},
13    uri::{PathAndQuery, Uri},
14};
15use http_body::Body;
16use std::{fmt, future, pin::pin};
17use tokio_stream::{Stream, StreamExt};
18
19/// A gRPC client dispatcher.
20///
21/// This will wrap some inner [`GrpcService`] and will encode/decode
22/// messages via the provided codec.
23///
24/// Each request method takes a [`Request`], a [`PathAndQuery`], and a
25/// [`Codec`]. The request contains the message to send via the
26/// [`Codec::encoder`]. The path determines the fully qualified path
27/// that will be append to the outgoing uri. The path must follow
28/// the conventions explained in the [gRPC protocol definition] under `Path →`. An
29/// example of this path could look like `/greeter.Greeter/SayHello`.
30///
31/// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
32pub struct Grpc<T> {
33    inner: T,
34    config: GrpcConfig,
35}
36
37struct GrpcConfig {
38    origin: Uri,
39    /// Which compression encodings does the client accept?
40    accept_compression_encodings: EnabledCompressionEncodings,
41    /// The compression encoding that will be applied to requests.
42    send_compression_encodings: Option<CompressionEncoding>,
43    /// Limits the maximum size of a decoded message.
44    max_decoding_message_size: Option<usize>,
45    /// Limits the maximum size of an encoded message.
46    max_encoding_message_size: Option<usize>,
47}
48
49impl<T> Grpc<T> {
50    /// Creates a new gRPC client with the provided [`GrpcService`].
51    pub fn new(inner: T) -> Self {
52        Self::with_origin(inner, Uri::default())
53    }
54
55    /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`.
56    ///
57    /// The provided Uri will use only the scheme and authority parts as the
58    /// path_and_query portion will be set for each method.
59    pub fn with_origin(inner: T, origin: Uri) -> Self {
60        Self {
61            inner,
62            config: GrpcConfig {
63                origin,
64                send_compression_encodings: None,
65                accept_compression_encodings: EnabledCompressionEncodings::default(),
66                max_decoding_message_size: None,
67                max_encoding_message_size: None,
68            },
69        }
70    }
71
72    /// Compress requests with the provided encoding.
73    ///
74    /// Requires the server to accept the specified encoding, otherwise it might return an error.
75    ///
76    /// # Example
77    ///
78    /// The most common way of using this is through a client generated by tonic-build:
79    ///
80    /// ```rust
81    /// use tonic::transport::Channel;
82    /// # enum CompressionEncoding { Gzip }
83    /// # struct TestClient<T>(T);
84    /// # impl<T> TestClient<T> {
85    /// #     fn new(channel: T) -> Self { Self(channel) }
86    /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
87    /// # }
88    ///
89    /// # async {
90    /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
91    ///     .connect()
92    ///     .await
93    ///     .unwrap();
94    ///
95    /// let client = TestClient::new(channel).send_compressed(CompressionEncoding::Gzip);
96    /// # };
97    /// ```
98    pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
99        self.config.send_compression_encodings = Some(encoding);
100        self
101    }
102
103    /// Enable accepting compressed responses.
104    ///
105    /// Requires the server to also support sending compressed responses.
106    ///
107    /// # Example
108    ///
109    /// The most common way of using this is through a client generated by tonic-build:
110    ///
111    /// ```rust
112    /// use tonic::transport::Channel;
113    /// # enum CompressionEncoding { Gzip }
114    /// # struct TestClient<T>(T);
115    /// # impl<T> TestClient<T> {
116    /// #     fn new(channel: T) -> Self { Self(channel) }
117    /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
118    /// # }
119    ///
120    /// # async {
121    /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
122    ///     .connect()
123    ///     .await
124    ///     .unwrap();
125    ///
126    /// let client = TestClient::new(channel).accept_compressed(CompressionEncoding::Gzip);
127    /// # };
128    /// ```
129    pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
130        self.config.accept_compression_encodings.enable(encoding);
131        self
132    }
133
134    /// Limits the maximum size of a decoded message.
135    ///
136    /// # Example
137    ///
138    /// The most common way of using this is through a client generated by tonic-build:
139    ///
140    /// ```rust
141    /// use tonic::transport::Channel;
142    /// # struct TestClient<T>(T);
143    /// # impl<T> TestClient<T> {
144    /// #     fn new(channel: T) -> Self { Self(channel) }
145    /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
146    /// # }
147    ///
148    /// # async {
149    /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
150    ///     .connect()
151    ///     .await
152    ///     .unwrap();
153    ///
154    /// // Set the limit to 2MB, Defaults to 4MB.
155    /// let limit = 2 * 1024 * 1024;
156    /// let client = TestClient::new(channel).max_decoding_message_size(limit);
157    /// # };
158    /// ```
159    pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
160        self.config.max_decoding_message_size = Some(limit);
161        self
162    }
163
164    /// Limits the maximum size of an encoded message.
165    ///
166    /// # Example
167    ///
168    /// The most common way of using this is through a client generated by tonic-build:
169    ///
170    /// ```rust
171    /// use tonic::transport::Channel;
172    /// # struct TestClient<T>(T);
173    /// # impl<T> TestClient<T> {
174    /// #     fn new(channel: T) -> Self { Self(channel) }
175    /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
176    /// # }
177    ///
178    /// # async {
179    /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
180    ///     .connect()
181    ///     .await
182    ///     .unwrap();
183    ///
184    /// // Set the limit to 2MB, Defaults to 4MB.
185    /// let limit = 2 * 1024 * 1024;
186    /// let client = TestClient::new(channel).max_encoding_message_size(limit);
187    /// # };
188    /// ```
189    pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
190        self.config.max_encoding_message_size = Some(limit);
191        self
192    }
193
194    /// Check if the inner [`GrpcService`] is able to accept a  new request.
195    ///
196    /// This will call [`GrpcService::poll_ready`] until it returns ready or
197    /// an error. If this returns ready the inner [`GrpcService`] is ready to
198    /// accept one more request.
199    pub async fn ready(&mut self) -> Result<(), T::Error>
200    where
201        T: GrpcService<BoxBody>,
202    {
203        future::poll_fn(|cx| self.inner.poll_ready(cx)).await
204    }
205
206    /// Send a single unary gRPC request.
207    pub async fn unary<M1, M2, C>(
208        &mut self,
209        request: Request<M1>,
210        path: PathAndQuery,
211        codec: C,
212    ) -> Result<Response<M2>, Status>
213    where
214        T: GrpcService<BoxBody>,
215        T::ResponseBody: Body + Send + 'static,
216        <T::ResponseBody as Body>::Error: Into<crate::Error>,
217        C: Codec<Encode = M1, Decode = M2>,
218        M1: Send + Sync + 'static,
219        M2: Send + Sync + 'static,
220    {
221        let request = request.map(|m| tokio_stream::once(m));
222        self.client_streaming(request, path, codec).await
223    }
224
225    /// Send a client side streaming gRPC request.
226    pub async fn client_streaming<S, M1, M2, C>(
227        &mut self,
228        request: Request<S>,
229        path: PathAndQuery,
230        codec: C,
231    ) -> Result<Response<M2>, Status>
232    where
233        T: GrpcService<BoxBody>,
234        T::ResponseBody: Body + Send + 'static,
235        <T::ResponseBody as Body>::Error: Into<crate::Error>,
236        S: Stream<Item = M1> + Send + 'static,
237        C: Codec<Encode = M1, Decode = M2>,
238        M1: Send + Sync + 'static,
239        M2: Send + Sync + 'static,
240    {
241        let (mut parts, body, extensions) =
242            self.streaming(request, path, codec).await?.into_parts();
243
244        let mut body = pin!(body);
245
246        let message = body
247            .try_next()
248            .await
249            .map_err(|mut status| {
250                status.metadata_mut().merge(parts.clone());
251                status
252            })?
253            .ok_or_else(|| Status::internal("Missing response message."))?;
254
255        if let Some(trailers) = body.trailers().await? {
256            parts.merge(trailers);
257        }
258
259        Ok(Response::from_parts(parts, message, extensions))
260    }
261
262    /// Send a server side streaming gRPC request.
263    pub async fn server_streaming<M1, M2, C>(
264        &mut self,
265        request: Request<M1>,
266        path: PathAndQuery,
267        codec: C,
268    ) -> Result<Response<Streaming<M2>>, Status>
269    where
270        T: GrpcService<BoxBody>,
271        T::ResponseBody: Body + Send + 'static,
272        <T::ResponseBody as Body>::Error: Into<crate::Error>,
273        C: Codec<Encode = M1, Decode = M2>,
274        M1: Send + Sync + 'static,
275        M2: Send + Sync + 'static,
276    {
277        let request = request.map(|m| tokio_stream::once(m));
278        self.streaming(request, path, codec).await
279    }
280
281    /// Send a bi-directional streaming gRPC request.
282    pub async fn streaming<S, M1, M2, C>(
283        &mut self,
284        request: Request<S>,
285        path: PathAndQuery,
286        mut codec: C,
287    ) -> Result<Response<Streaming<M2>>, Status>
288    where
289        T: GrpcService<BoxBody>,
290        T::ResponseBody: Body + Send + 'static,
291        <T::ResponseBody as Body>::Error: Into<crate::Error>,
292        S: Stream<Item = M1> + Send + 'static,
293        C: Codec<Encode = M1, Decode = M2>,
294        M1: Send + Sync + 'static,
295        M2: Send + Sync + 'static,
296    {
297        let request = request
298            .map(|s| {
299                EncodeBody::new_client(
300                    codec.encoder(),
301                    s.map(Ok),
302                    self.config.send_compression_encodings,
303                    self.config.max_encoding_message_size,
304                )
305            })
306            .map(BoxBody::new);
307
308        let request = self.config.prepare_request(request, path);
309
310        let response = self
311            .inner
312            .call(request)
313            .await
314            .map_err(Status::from_error_generic)?;
315
316        let decoder = codec.decoder();
317
318        self.create_response(decoder, response)
319    }
320
321    // Keeping this code in a separate function from Self::streaming lets functions that return the
322    // same output share the generated binary code
323    fn create_response<M2>(
324        &self,
325        decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static,
326        response: http::Response<T::ResponseBody>,
327    ) -> Result<Response<Streaming<M2>>, Status>
328    where
329        T: GrpcService<BoxBody>,
330        T::ResponseBody: Body + Send + 'static,
331        <T::ResponseBody as Body>::Error: Into<crate::Error>,
332    {
333        let encoding = CompressionEncoding::from_encoding_header(
334            response.headers(),
335            self.config.accept_compression_encodings,
336        )?;
337
338        let status_code = response.status();
339        let trailers_only_status = Status::from_header_map(response.headers());
340
341        // We do not need to check for trailers if the `grpc-status` header is present
342        // with a valid code.
343        let expect_additional_trailers = if let Some(status) = trailers_only_status {
344            if status.code() != Code::Ok {
345                return Err(status);
346            }
347
348            false
349        } else {
350            true
351        };
352
353        let response = response.map(|body| {
354            if expect_additional_trailers {
355                Streaming::new_response(
356                    decoder,
357                    body,
358                    status_code,
359                    encoding,
360                    self.config.max_decoding_message_size,
361                )
362            } else {
363                Streaming::new_empty(decoder, body)
364            }
365        });
366
367        Ok(Response::from_http(response))
368    }
369}
370
371impl GrpcConfig {
372    fn prepare_request(
373        &self,
374        request: Request<BoxBody>,
375        path: PathAndQuery,
376    ) -> http::Request<BoxBody> {
377        let mut parts = self.origin.clone().into_parts();
378
379        match &parts.path_and_query {
380            Some(pnq) if pnq != "/" => {
381                parts.path_and_query = Some(
382                    format!("{}{}", pnq.path(), path)
383                        .parse()
384                        .expect("must form valid path_and_query"),
385                )
386            }
387            _ => {
388                parts.path_and_query = Some(path);
389            }
390        }
391
392        let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
393
394        let mut request = request.into_http(
395            uri,
396            http::Method::POST,
397            http::Version::HTTP_2,
398            SanitizeHeaders::Yes,
399        );
400
401        // Add the gRPC related HTTP headers
402        request
403            .headers_mut()
404            .insert(TE, HeaderValue::from_static("trailers"));
405
406        // Set the content type
407        request
408            .headers_mut()
409            .insert(CONTENT_TYPE, GRPC_CONTENT_TYPE);
410
411        #[cfg(any(feature = "gzip", feature = "zstd"))]
412        if let Some(encoding) = self.send_compression_encodings {
413            request.headers_mut().insert(
414                crate::codec::compression::ENCODING_HEADER,
415                encoding.into_header_value(),
416            );
417        }
418
419        if let Some(header_value) = self
420            .accept_compression_encodings
421            .into_accept_encoding_header_value()
422        {
423            request.headers_mut().insert(
424                crate::codec::compression::ACCEPT_ENCODING_HEADER,
425                header_value,
426            );
427        }
428
429        request
430    }
431}
432
433impl<T: Clone> Clone for Grpc<T> {
434    fn clone(&self) -> Self {
435        Self {
436            inner: self.inner.clone(),
437            config: GrpcConfig {
438                origin: self.config.origin.clone(),
439                send_compression_encodings: self.config.send_compression_encodings,
440                accept_compression_encodings: self.config.accept_compression_encodings,
441                max_encoding_message_size: self.config.max_encoding_message_size,
442                max_decoding_message_size: self.config.max_decoding_message_size,
443            },
444        }
445    }
446}
447
448impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
449    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450        let mut f = f.debug_struct("Grpc");
451
452        f.field("inner", &self.inner);
453
454        f.field("origin", &self.config.origin);
455
456        f.field(
457            "compression_encoding",
458            &self.config.send_compression_encodings,
459        );
460
461        f.field(
462            "accept_compression_encodings",
463            &self.config.accept_compression_encodings,
464        );
465
466        f.field(
467            "max_decoding_message_size",
468            &self.config.max_decoding_message_size,
469        );
470
471        f.field(
472            "max_encoding_message_size",
473            &self.config.max_encoding_message_size,
474        );
475
476        f.finish()
477    }
478}