tonic/
request.rs

1use crate::metadata::{MetadataMap, MetadataValue};
2#[cfg(feature = "server")]
3use crate::transport::server::TcpConnectInfo;
4#[cfg(all(feature = "server", feature = "tls"))]
5use crate::transport::server::TlsConnectInfo;
6use http::Extensions;
7#[cfg(feature = "server")]
8use std::net::SocketAddr;
9#[cfg(all(feature = "server", feature = "tls"))]
10use std::sync::Arc;
11use std::time::Duration;
12#[cfg(all(feature = "server", feature = "tls"))]
13use tokio_rustls::rustls::pki_types::CertificateDer;
14use tokio_stream::Stream;
15
16/// A gRPC request and metadata from an RPC call.
17#[derive(Debug)]
18pub struct Request<T> {
19    metadata: MetadataMap,
20    message: T,
21    extensions: Extensions,
22}
23
24/// Trait implemented by RPC request types.
25///
26/// Types implementing this trait can be used as arguments to client RPC
27/// methods without explicitly wrapping them into `tonic::Request`s. The purpose
28/// is to make client calls slightly more convenient to write.
29///
30/// Tonic's code generation and blanket implementations handle this for you,
31/// so it is not necessary to implement this trait directly.
32///
33/// # Example
34///
35/// Given the following gRPC method definition:
36/// ```proto
37/// rpc GetFeature(Point) returns (Feature) {}
38/// ```
39///
40/// we can call `get_feature` in two equivalent ways:
41/// ```rust
42/// # pub struct Point {}
43/// # pub struct Client {}
44/// # impl Client {
45/// #   fn get_feature(&self, r: impl tonic::IntoRequest<Point>) {}
46/// # }
47/// # let client = Client {};
48/// use tonic::Request;
49///
50/// client.get_feature(Point {});
51/// client.get_feature(Request::new(Point {}));
52/// ```
53pub trait IntoRequest<T>: sealed::Sealed {
54    /// Wrap the input message `T` in a `tonic::Request`
55    fn into_request(self) -> Request<T>;
56}
57
58/// Trait implemented by RPC streaming request types.
59///
60/// Types implementing this trait can be used as arguments to client streaming
61/// RPC methods without explicitly wrapping them into `tonic::Request`s. The
62/// purpose is to make client calls slightly more convenient to write.
63///
64/// Tonic's code generation and blanket implementations handle this for you,
65/// so it is not necessary to implement this trait directly.
66///
67/// # Example
68///
69/// Given the following gRPC service method definition:
70/// ```proto
71/// rpc RecordRoute(stream Point) returns (RouteSummary) {}
72/// ```
73/// we can call `record_route` in two equivalent ways:
74///
75/// ```rust
76/// # #[derive(Clone)]
77/// # pub struct Point {};
78/// # pub struct Client {};
79/// # impl Client {
80/// #   fn record_route(&self, r: impl tonic::IntoStreamingRequest<Message = Point>) {}
81/// # }
82/// # let client = Client {};
83/// use tonic::Request;
84///
85/// let messages = vec![Point {}, Point {}];
86///
87/// client.record_route(Request::new(tokio_stream::iter(messages.clone())));
88/// client.record_route(tokio_stream::iter(messages));
89/// ```
90pub trait IntoStreamingRequest: sealed::Sealed {
91    /// The RPC request stream type
92    type Stream: Stream<Item = Self::Message> + Send + 'static;
93
94    /// The RPC request type
95    type Message;
96
97    /// Wrap the stream of messages in a `tonic::Request`
98    fn into_streaming_request(self) -> Request<Self::Stream>;
99}
100
101impl<T> Request<T> {
102    /// Create a new gRPC request.
103    ///
104    /// ```rust
105    /// # use tonic::Request;
106    /// # pub struct HelloRequest {
107    /// #   pub name: String,
108    /// # }
109    /// Request::new(HelloRequest {
110    ///    name: "Bob".into(),
111    /// });
112    /// ```
113    pub fn new(message: T) -> Self {
114        Request {
115            metadata: MetadataMap::new(),
116            message,
117            extensions: Extensions::new(),
118        }
119    }
120
121    /// Get a reference to the message
122    pub fn get_ref(&self) -> &T {
123        &self.message
124    }
125
126    /// Get a mutable reference to the message
127    pub fn get_mut(&mut self) -> &mut T {
128        &mut self.message
129    }
130
131    /// Get a reference to the custom request metadata.
132    pub fn metadata(&self) -> &MetadataMap {
133        &self.metadata
134    }
135
136    /// Get a mutable reference to the request metadata.
137    pub fn metadata_mut(&mut self) -> &mut MetadataMap {
138        &mut self.metadata
139    }
140
141    /// Consumes `self`, returning the message
142    pub fn into_inner(self) -> T {
143        self.message
144    }
145
146    /// Consumes `self` returning the parts of the request.
147    pub fn into_parts(self) -> (MetadataMap, Extensions, T) {
148        (self.metadata, self.extensions, self.message)
149    }
150
151    /// Create a new gRPC request from metadata, extensions and message.
152    pub fn from_parts(metadata: MetadataMap, extensions: Extensions, message: T) -> Self {
153        Self {
154            metadata,
155            extensions,
156            message,
157        }
158    }
159
160    pub(crate) fn from_http_parts(parts: http::request::Parts, message: T) -> Self {
161        Request {
162            metadata: MetadataMap::from_headers(parts.headers),
163            message,
164            extensions: parts.extensions,
165        }
166    }
167
168    /// Convert an HTTP request to a gRPC request
169    pub fn from_http(http: http::Request<T>) -> Self {
170        let (parts, message) = http.into_parts();
171        Request::from_http_parts(parts, message)
172    }
173
174    pub(crate) fn into_http(
175        self,
176        uri: http::Uri,
177        method: http::Method,
178        version: http::Version,
179        sanitize_headers: SanitizeHeaders,
180    ) -> http::Request<T> {
181        let mut request = http::Request::new(self.message);
182
183        *request.version_mut() = version;
184        *request.method_mut() = method;
185        *request.uri_mut() = uri;
186        *request.headers_mut() = match sanitize_headers {
187            SanitizeHeaders::Yes => self.metadata.into_sanitized_headers(),
188            SanitizeHeaders::No => self.metadata.into_headers(),
189        };
190        *request.extensions_mut() = self.extensions;
191
192        request
193    }
194
195    #[doc(hidden)]
196    pub fn map<F, U>(self, f: F) -> Request<U>
197    where
198        F: FnOnce(T) -> U,
199    {
200        let message = f(self.message);
201
202        Request {
203            metadata: self.metadata,
204            message,
205            extensions: self.extensions,
206        }
207    }
208
209    /// Get the local address of this connection.
210    ///
211    /// This will return `None` if the `IO` type used
212    /// does not implement `Connected` or when using a unix domain socket.
213    /// This currently only works on the server side.
214    #[cfg(feature = "server")]
215    pub fn local_addr(&self) -> Option<SocketAddr> {
216        let addr = self
217            .extensions()
218            .get::<TcpConnectInfo>()
219            .and_then(|i| i.local_addr());
220
221        #[cfg(feature = "tls")]
222        let addr = addr.or_else(|| {
223            self.extensions()
224                .get::<TlsConnectInfo<TcpConnectInfo>>()
225                .and_then(|i| i.get_ref().local_addr())
226        });
227
228        addr
229    }
230
231    /// Get the remote address of this connection.
232    ///
233    /// This will return `None` if the `IO` type used
234    /// does not implement `Connected` or when using a unix domain socket.
235    /// This currently only works on the server side.
236    #[cfg(feature = "server")]
237    pub fn remote_addr(&self) -> Option<SocketAddr> {
238        let addr = self
239            .extensions()
240            .get::<TcpConnectInfo>()
241            .and_then(|i| i.remote_addr());
242
243        #[cfg(feature = "tls")]
244        let addr = addr.or_else(|| {
245            self.extensions()
246                .get::<TlsConnectInfo<TcpConnectInfo>>()
247                .and_then(|i| i.get_ref().remote_addr())
248        });
249
250        addr
251    }
252
253    /// Get the peer certificates of the connected client.
254    ///
255    /// This is used to fetch the certificates from the TLS session
256    /// and is mostly used for mTLS. This currently only returns
257    /// `Some` on the server side of the `transport` server with
258    /// TLS enabled connections.
259    #[cfg(all(feature = "server", feature = "tls"))]
260    pub fn peer_certs(&self) -> Option<Arc<Vec<CertificateDer<'static>>>> {
261        self.extensions()
262            .get::<TlsConnectInfo<TcpConnectInfo>>()
263            .and_then(|i| i.peer_certs())
264    }
265
266    /// Set the max duration the request is allowed to take.
267    ///
268    /// Requires the server to support the `grpc-timeout` metadata, which Tonic does.
269    ///
270    /// The duration will be formatted according to [the spec] and use the most precise unit
271    /// possible.
272    ///
273    /// Example:
274    ///
275    /// ```rust
276    /// use std::time::Duration;
277    /// use tonic::Request;
278    ///
279    /// let mut request = Request::new(());
280    ///
281    /// request.set_timeout(Duration::from_secs(30));
282    ///
283    /// let value = request.metadata().get("grpc-timeout").unwrap();
284    ///
285    /// assert_eq!(
286    ///     value,
287    ///     // equivalent to 30 seconds
288    ///     "30000000u"
289    /// );
290    /// ```
291    ///
292    /// [the spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
293    pub fn set_timeout(&mut self, deadline: Duration) {
294        let value: MetadataValue<_> = duration_to_grpc_timeout(deadline).parse().unwrap();
295        self.metadata_mut()
296            .insert(crate::metadata::GRPC_TIMEOUT_HEADER, value);
297    }
298
299    /// Returns a reference to the associated extensions.
300    pub fn extensions(&self) -> &Extensions {
301        &self.extensions
302    }
303
304    /// Returns a mutable reference to the associated extensions.
305    ///
306    /// # Example
307    ///
308    /// Extensions can be set in interceptors:
309    ///
310    /// ```no_run
311    /// use tonic::{Request, service::interceptor};
312    ///
313    /// #[derive(Clone)] // Extensions must be Clone
314    /// struct MyExtension {
315    ///     some_piece_of_data: String,
316    /// }
317    ///
318    /// interceptor(|mut request: Request<()>| {
319    ///     request.extensions_mut().insert(MyExtension {
320    ///         some_piece_of_data: "foo".to_string(),
321    ///     });
322    ///
323    ///     Ok(request)
324    /// });
325    /// ```
326    ///
327    /// And picked up by RPCs:
328    ///
329    /// ```no_run
330    /// use tonic::{async_trait, Status, Request, Response};
331    /// #
332    /// # struct Output {}
333    /// # struct Input;
334    /// # struct MyService;
335    /// # struct MyExtension;
336    /// # #[async_trait]
337    /// # trait TestService {
338    /// #     async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status>;
339    /// # }
340    ///
341    /// #[async_trait]
342    /// impl TestService for MyService {
343    ///     async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
344    ///         let value: &MyExtension = req.extensions().get::<MyExtension>().unwrap();
345    ///
346    ///         Ok(Response::new(Output {}))
347    ///     }
348    /// }
349    /// ```
350    pub fn extensions_mut(&mut self) -> &mut Extensions {
351        &mut self.extensions
352    }
353}
354
355impl<T> IntoRequest<T> for T {
356    fn into_request(self) -> Request<Self> {
357        Request::new(self)
358    }
359}
360
361impl<T> IntoRequest<T> for Request<T> {
362    fn into_request(self) -> Request<T> {
363        self
364    }
365}
366
367impl<T> IntoStreamingRequest for T
368where
369    T: Stream + Send + 'static,
370{
371    type Stream = T;
372    type Message = T::Item;
373
374    fn into_streaming_request(self) -> Request<Self> {
375        Request::new(self)
376    }
377}
378
379impl<T> IntoStreamingRequest for Request<T>
380where
381    T: Stream + Send + 'static,
382{
383    type Stream = T;
384    type Message = T::Item;
385
386    fn into_streaming_request(self) -> Self {
387        self
388    }
389}
390
mod sealed {
    /// Private supertrait of `IntoRequest` and `IntoStreamingRequest`.
    pub trait Sealed {}
}

// NOTE(review): because `Sealed` is blanket-implemented for every type, it does not on
// its own stop downstream crates from implementing the request traits — confirm intent.
impl<T> sealed::Sealed for T {}
396
/// Encode `duration` as a `grpc-timeout` header value.
///
/// The gRPC-over-HTTP/2 spec encodes a timeout as an ASCII integer of at most 8 digits
/// followed by a unit: `n` (nanoseconds), `u` (microseconds), `m` (milliseconds),
/// `S` (seconds), `M` (minutes) or `H` (hours). The most precise unit whose value fits
/// in 8 digits is chosen; any remainder smaller than that unit is truncated.
///
/// `Duration` can represent spans far longer than `99_999_999` hours (a little over
/// 11_415 years). Such durations are clamped to the largest encodable value,
/// `99999999H`, instead of panicking, so e.g. `Duration::MAX` stays usable as an
/// "effectively unbounded" timeout.
fn duration_to_grpc_timeout(duration: Duration) -> String {
    // The gRPC spec specifies that the timeout must be at most 8 digits, so this is the
    // largest value any unit may carry before a coarser unit is needed.
    const MAX_TIMEOUT_VALUE: u128 = 99_999_999;

    /// Format `duration` in `unit` if the converted value fits in 8 digits.
    fn try_format<T: Into<u128>>(
        duration: Duration,
        unit: char,
        convert: impl FnOnce(Duration) -> T,
    ) -> Option<String> {
        let value: u128 = convert(duration).into();
        if value > MAX_TIMEOUT_VALUE {
            None
        } else {
            Some(format!("{}{}", value, unit))
        }
    }

    // Pick the most precise unit whose value needs at most 8 digits.
    try_format(duration, 'n', |d| d.as_nanos())
        .or_else(|| try_format(duration, 'u', |d| d.as_micros()))
        .or_else(|| try_format(duration, 'm', |d| d.as_millis()))
        .or_else(|| try_format(duration, 'S', |d| d.as_secs()))
        .or_else(|| try_format(duration, 'M', |d| d.as_secs() / 60))
        .or_else(|| try_format(duration, 'H', |d| d.as_secs() / 60 / 60))
        // Only reached for durations over ~11_415 years: saturate rather than panic.
        .unwrap_or_else(|| format!("{}H", MAX_TIMEOUT_VALUE))
}
430
/// Whether gRPC-reserved headers should be stripped when a `tonic::Request` is
/// converted into an `http::Request`.
pub(crate) enum SanitizeHeaders {
    /// Remove reserved headers from the metadata before building the HTTP headers.
    Yes,
    /// Copy every metadata entry into the HTTP headers unchanged.
    No,
}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{MetadataKey, MetadataValue};

    use http::Uri;

    #[test]
    fn reserved_headers_are_excluded() {
        let mut r = Request::new(1);

        for header in &MetadataMap::GRPC_RESERVED_HEADERS {
            r.metadata_mut().insert(
                MetadataKey::unchecked_from_header_name(header.clone()),
                MetadataValue::from_static("invalid"),
            );
        }

        let http_request = r.into_http(
            Uri::default(),
            http::Method::POST,
            http::Version::HTTP_2,
            SanitizeHeaders::Yes,
        );
        assert!(http_request.headers().is_empty());
    }

    #[test]
    fn duration_to_grpc_timeout_less_than_second() {
        let timeout = Duration::from_millis(500);
        let value = duration_to_grpc_timeout(timeout);
        assert_eq!(value, format!("{}u", timeout.as_micros()));
    }

    #[test]
    fn duration_to_grpc_timeout_more_than_second() {
        let timeout = Duration::from_secs(30);
        let value = duration_to_grpc_timeout(timeout);
        assert_eq!(value, format!("{}u", timeout.as_micros()));
    }

    #[test]
    fn duration_to_grpc_timeout_a_very_long_time() {
        let one_hour = Duration::from_secs(60 * 60);
        let value = duration_to_grpc_timeout(one_hour);
        assert_eq!(value, format!("{}m", one_hour.as_millis()));
    }

    #[test]
    fn duration_to_grpc_timeout_zero() {
        assert_eq!(duration_to_grpc_timeout(Duration::ZERO), "0n");
    }

    #[test]
    fn duration_to_grpc_timeout_eight_digit_boundary() {
        // 99_999_999 ns is the largest value that still fits the 8-digit limit.
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_nanos(99_999_999)),
            "99999999n"
        );
        // One more nanosecond needs 9 digits, so the encoder moves to microseconds.
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_nanos(100_000_000)),
            "100000u"
        );
    }

    #[test]
    fn duration_to_grpc_timeout_coarse_units() {
        // 100_000 s is 100_000_000 ms: one digit too many for `m`.
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_secs(100_000)),
            "100000S"
        );
        // 100_000_000 s needs 9 digits; it truncates to 1_666_666 whole minutes.
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_secs(100_000_000)),
            "1666666M"
        );
        // 6_000_000_000 s is 100_000_000 minutes (too many for `M`); 1_666_666 hours.
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_secs(6_000_000_000)),
            "1666666H"
        );
    }

    #[test]
    fn set_timeout_inserts_grpc_timeout_metadata() {
        let mut r = Request::new(());
        r.set_timeout(Duration::from_secs(30));
        assert_eq!(r.metadata().get("grpc-timeout").unwrap(), "30000000u");
    }
}