tonic/server/grpc.rs

1use crate::codec::compression::{
2    CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3};
4use crate::codec::EncodeBody;
5use crate::metadata::GRPC_CONTENT_TYPE;
6use crate::{
7    body::BoxBody,
8    codec::{Codec, Streaming},
9    server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
10    Request, Status,
11};
12use http_body::Body;
13use std::{fmt, pin::pin};
14use tokio_stream::{Stream, StreamExt};
15
// Unwraps a `Result`, yielding the `Ok` value. On `Err`, it immediately
// returns from the *enclosing* function with the error converted by its
// `into_http()` method (in this file the error is always a `Status`).
macro_rules! t {
    ($result:expr) => {
        match $result {
            Err(err) => return err.into_http(),
            Ok(ok) => ok,
        }
    };
}
24
25/// A gRPC Server handler.
26///
27/// This will wrap some inner [`Codec`] and provide utilities to handle
28/// inbound unary, client side streaming, server side streaming, and
29/// bi-directional streaming.
30///
31/// Each request handler method accepts some service that implements the
32/// corresponding service trait and a http request that contains some body that
33/// implements some [`Body`].
34pub struct Grpc<T> {
35    codec: T,
36    /// Which compression encodings does the server accept for requests?
37    accept_compression_encodings: EnabledCompressionEncodings,
38    /// Which compression encodings might the server use for responses.
39    send_compression_encodings: EnabledCompressionEncodings,
40    /// Limits the maximum size of a decoded message.
41    max_decoding_message_size: Option<usize>,
42    /// Limits the maximum size of an encoded message.
43    max_encoding_message_size: Option<usize>,
44}
45
46impl<T> Grpc<T>
47where
48    T: Codec,
49{
50    /// Creates a new gRPC server with the provided [`Codec`].
51    pub fn new(codec: T) -> Self {
52        Self {
53            codec,
54            accept_compression_encodings: EnabledCompressionEncodings::default(),
55            send_compression_encodings: EnabledCompressionEncodings::default(),
56            max_decoding_message_size: None,
57            max_encoding_message_size: None,
58        }
59    }
60
61    /// Enable accepting compressed requests.
62    ///
63    /// If a request with an unsupported encoding is received the server will respond with
64    /// [`Code::UnUnimplemented`](crate::Code).
65    ///
66    /// # Example
67    ///
68    /// The most common way of using this is through a server generated by tonic-build:
69    ///
70    /// ```rust
71    /// # enum CompressionEncoding { Gzip }
72    /// # struct Svc;
73    /// # struct ExampleServer<T>(T);
74    /// # impl<T> ExampleServer<T> {
75    /// #     fn new(svc: T) -> Self { Self(svc) }
76    /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
77    /// # }
78    /// # #[tonic::async_trait]
79    /// # trait Example {}
80    ///
81    /// #[tonic::async_trait]
82    /// impl Example for Svc {
83    ///     // ...
84    /// }
85    ///
86    /// let service = ExampleServer::new(Svc).accept_compressed(CompressionEncoding::Gzip);
87    /// ```
88    pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
89        self.accept_compression_encodings.enable(encoding);
90        self
91    }
92
93    /// Enable sending compressed responses.
94    ///
95    /// Requires the client to also support receiving compressed responses.
96    ///
97    /// # Example
98    ///
99    /// The most common way of using this is through a server generated by tonic-build:
100    ///
101    /// ```rust
102    /// # enum CompressionEncoding { Gzip }
103    /// # struct Svc;
104    /// # struct ExampleServer<T>(T);
105    /// # impl<T> ExampleServer<T> {
106    /// #     fn new(svc: T) -> Self { Self(svc) }
107    /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
108    /// # }
109    /// # #[tonic::async_trait]
110    /// # trait Example {}
111    ///
112    /// #[tonic::async_trait]
113    /// impl Example for Svc {
114    ///     // ...
115    /// }
116    ///
117    /// let service = ExampleServer::new(Svc).send_compressed(CompressionEncoding::Gzip);
118    /// ```
119    pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
120        self.send_compression_encodings.enable(encoding);
121        self
122    }
123
124    /// Limits the maximum size of a decoded message.
125    ///
126    /// # Example
127    ///
128    /// The most common way of using this is through a server generated by tonic-build:
129    ///
130    /// ```rust
131    /// # struct Svc;
132    /// # struct ExampleServer<T>(T);
133    /// # impl<T> ExampleServer<T> {
134    /// #     fn new(svc: T) -> Self { Self(svc) }
135    /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
136    /// # }
137    /// # #[tonic::async_trait]
138    /// # trait Example {}
139    ///
140    /// #[tonic::async_trait]
141    /// impl Example for Svc {
142    ///     // ...
143    /// }
144    ///
145    /// // Set the limit to 2MB, Defaults to 4MB.
146    /// let limit = 2 * 1024 * 1024;
147    /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit);
148    /// ```
149    pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
150        self.max_decoding_message_size = Some(limit);
151        self
152    }
153
154    /// Limits the maximum size of a encoded message.
155    ///
156    /// # Example
157    ///
158    /// The most common way of using this is through a server generated by tonic-build:
159    ///
160    /// ```rust
161    /// # struct Svc;
162    /// # struct ExampleServer<T>(T);
163    /// # impl<T> ExampleServer<T> {
164    /// #     fn new(svc: T) -> Self { Self(svc) }
165    /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
166    /// # }
167    /// # #[tonic::async_trait]
168    /// # trait Example {}
169    ///
170    /// #[tonic::async_trait]
171    /// impl Example for Svc {
172    ///     // ...
173    /// }
174    ///
175    /// // Set the limit to 2MB, Defaults to 4MB.
176    /// let limit = 2 * 1024 * 1024;
177    /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit);
178    /// ```
179    pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
180        self.max_encoding_message_size = Some(limit);
181        self
182    }
183
184    #[doc(hidden)]
185    pub fn apply_compression_config(
186        self,
187        accept_encodings: EnabledCompressionEncodings,
188        send_encodings: EnabledCompressionEncodings,
189    ) -> Self {
190        let mut this = self;
191
192        for &encoding in CompressionEncoding::ENCODINGS {
193            if accept_encodings.is_enabled(encoding) {
194                this = this.accept_compressed(encoding);
195            }
196            if send_encodings.is_enabled(encoding) {
197                this = this.send_compressed(encoding);
198            }
199        }
200
201        this
202    }
203
204    #[doc(hidden)]
205    pub fn apply_max_message_size_config(
206        self,
207        max_decoding_message_size: Option<usize>,
208        max_encoding_message_size: Option<usize>,
209    ) -> Self {
210        let mut this = self;
211
212        if let Some(limit) = max_decoding_message_size {
213            this = this.max_decoding_message_size(limit);
214        }
215        if let Some(limit) = max_encoding_message_size {
216            this = this.max_encoding_message_size(limit);
217        }
218
219        this
220    }
221
222    /// Handle a single unary gRPC request.
223    pub async fn unary<S, B>(
224        &mut self,
225        mut service: S,
226        req: http::Request<B>,
227    ) -> http::Response<BoxBody>
228    where
229        S: UnaryService<T::Decode, Response = T::Encode>,
230        B: Body + Send + 'static,
231        B::Error: Into<crate::Error> + Send,
232    {
233        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
234            req.headers(),
235            self.send_compression_encodings,
236        );
237
238        let request = match self.map_request_unary(req).await {
239            Ok(r) => r,
240            Err(status) => {
241                return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
242                    Err(status),
243                    accept_encoding,
244                    SingleMessageCompressionOverride::default(),
245                    self.max_encoding_message_size,
246                );
247            }
248        };
249
250        let response = service
251            .call(request)
252            .await
253            .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
254
255        let compression_override = compression_override_from_response(&response);
256
257        self.map_response(
258            response,
259            accept_encoding,
260            compression_override,
261            self.max_encoding_message_size,
262        )
263    }
264
265    /// Handle a server side streaming request.
266    pub async fn server_streaming<S, B>(
267        &mut self,
268        mut service: S,
269        req: http::Request<B>,
270    ) -> http::Response<BoxBody>
271    where
272        S: ServerStreamingService<T::Decode, Response = T::Encode>,
273        S::ResponseStream: Send + 'static,
274        B: Body + Send + 'static,
275        B::Error: Into<crate::Error> + Send,
276    {
277        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
278            req.headers(),
279            self.send_compression_encodings,
280        );
281
282        let request = match self.map_request_unary(req).await {
283            Ok(r) => r,
284            Err(status) => {
285                return self.map_response::<S::ResponseStream>(
286                    Err(status),
287                    accept_encoding,
288                    SingleMessageCompressionOverride::default(),
289                    self.max_encoding_message_size,
290                );
291            }
292        };
293
294        let response = service.call(request).await;
295
296        self.map_response(
297            response,
298            accept_encoding,
299            // disabling compression of individual stream items must be done on
300            // the items themselves
301            SingleMessageCompressionOverride::default(),
302            self.max_encoding_message_size,
303        )
304    }
305
306    /// Handle a client side streaming gRPC request.
307    pub async fn client_streaming<S, B>(
308        &mut self,
309        mut service: S,
310        req: http::Request<B>,
311    ) -> http::Response<BoxBody>
312    where
313        S: ClientStreamingService<T::Decode, Response = T::Encode>,
314        B: Body + Send + 'static,
315        B::Error: Into<crate::Error> + Send + 'static,
316    {
317        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
318            req.headers(),
319            self.send_compression_encodings,
320        );
321
322        let request = t!(self.map_request_streaming(req));
323
324        let response = service
325            .call(request)
326            .await
327            .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
328
329        let compression_override = compression_override_from_response(&response);
330
331        self.map_response(
332            response,
333            accept_encoding,
334            compression_override,
335            self.max_encoding_message_size,
336        )
337    }
338
339    /// Handle a bi-directional streaming gRPC request.
340    pub async fn streaming<S, B>(
341        &mut self,
342        mut service: S,
343        req: http::Request<B>,
344    ) -> http::Response<BoxBody>
345    where
346        S: StreamingService<T::Decode, Response = T::Encode> + Send,
347        S::ResponseStream: Send + 'static,
348        B: Body + Send + 'static,
349        B::Error: Into<crate::Error> + Send,
350    {
351        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
352            req.headers(),
353            self.send_compression_encodings,
354        );
355
356        let request = t!(self.map_request_streaming(req));
357
358        let response = service.call(request).await;
359
360        self.map_response(
361            response,
362            accept_encoding,
363            SingleMessageCompressionOverride::default(),
364            self.max_encoding_message_size,
365        )
366    }
367
368    async fn map_request_unary<B>(
369        &mut self,
370        request: http::Request<B>,
371    ) -> Result<Request<T::Decode>, Status>
372    where
373        B: Body + Send + 'static,
374        B::Error: Into<crate::Error> + Send,
375    {
376        let request_compression_encoding = self.request_encoding_if_supported(&request)?;
377
378        let (parts, body) = request.into_parts();
379
380        let mut stream = pin!(Streaming::new_request(
381            self.codec.decoder(),
382            body,
383            request_compression_encoding,
384            self.max_decoding_message_size,
385        ));
386
387        let message = stream
388            .try_next()
389            .await?
390            .ok_or_else(|| Status::internal("Missing request message."))?;
391
392        let mut req = Request::from_http_parts(parts, message);
393
394        if let Some(trailers) = stream.trailers().await? {
395            req.metadata_mut().merge(trailers);
396        }
397
398        Ok(req)
399    }
400
401    fn map_request_streaming<B>(
402        &mut self,
403        request: http::Request<B>,
404    ) -> Result<Request<Streaming<T::Decode>>, Status>
405    where
406        B: Body + Send + 'static,
407        B::Error: Into<crate::Error> + Send,
408    {
409        let encoding = self.request_encoding_if_supported(&request)?;
410
411        let request = request.map(|body| {
412            Streaming::new_request(
413                self.codec.decoder(),
414                body,
415                encoding,
416                self.max_decoding_message_size,
417            )
418        });
419
420        Ok(Request::from_http(request))
421    }
422
423    fn map_response<B>(
424        &mut self,
425        response: Result<crate::Response<B>, Status>,
426        accept_encoding: Option<CompressionEncoding>,
427        compression_override: SingleMessageCompressionOverride,
428        max_message_size: Option<usize>,
429    ) -> http::Response<BoxBody>
430    where
431        B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
432    {
433        let response = t!(response);
434
435        let (mut parts, body) = response.into_http().into_parts();
436
437        // Set the content type
438        parts
439            .headers
440            .insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
441
442        #[cfg(any(feature = "gzip", feature = "zstd"))]
443        if let Some(encoding) = accept_encoding {
444            // Set the content encoding
445            parts.headers.insert(
446                crate::codec::compression::ENCODING_HEADER,
447                encoding.into_header_value(),
448            );
449        }
450
451        let body = EncodeBody::new_server(
452            self.codec.encoder(),
453            body,
454            accept_encoding,
455            compression_override,
456            max_message_size,
457        );
458
459        http::Response::from_parts(parts, BoxBody::new(body))
460    }
461
462    fn request_encoding_if_supported<B>(
463        &self,
464        request: &http::Request<B>,
465    ) -> Result<Option<CompressionEncoding>, Status> {
466        CompressionEncoding::from_encoding_header(
467            request.headers(),
468            self.accept_compression_encodings,
469        )
470    }
471}
472
473impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
474    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
475        let mut f = f.debug_struct("Grpc");
476
477        f.field("codec", &self.codec);
478
479        f.field(
480            "accept_compression_encodings",
481            &self.accept_compression_encodings,
482        );
483
484        f.field(
485            "send_compression_encodings",
486            &self.send_compression_encodings,
487        );
488
489        f.finish()
490    }
491}
492
493fn compression_override_from_response<B, E>(
494    res: &Result<crate::Response<B>, E>,
495) -> SingleMessageCompressionOverride {
496    res.as_ref()
497        .ok()
498        .and_then(|response| {
499            response
500                .extensions()
501                .get::<SingleMessageCompressionOverride>()
502                .copied()
503        })
504        .unwrap_or_default()
505}