tonic/transport/channel/service/
user_agent.rs

1use http::{header::USER_AGENT, HeaderValue, Request};
2use std::task::{Context, Poll};
3use tower_service::Service;
4
5const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION"));
6
7#[derive(Debug)]
8pub(crate) struct UserAgent<T> {
9    inner: T,
10    user_agent: HeaderValue,
11}
12
13impl<T> UserAgent<T> {
14    pub(crate) fn new(inner: T, user_agent: Option<HeaderValue>) -> Self {
15        let user_agent = user_agent
16            .map(|value| {
17                let mut buf = Vec::new();
18                buf.extend(value.as_bytes());
19                buf.push(b' ');
20                buf.extend(TONIC_USER_AGENT.as_bytes());
21                HeaderValue::from_bytes(&buf).expect("user-agent should be valid")
22            })
23            .unwrap_or_else(|| HeaderValue::from_static(TONIC_USER_AGENT));
24
25        Self { inner, user_agent }
26    }
27}
28
29impl<T, ReqBody> Service<Request<ReqBody>> for UserAgent<T>
30where
31    T: Service<Request<ReqBody>>,
32{
33    type Response = T::Response;
34    type Error = T::Error;
35    type Future = T::Future;
36
37    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38        self.inner.poll_ready(cx)
39    }
40
41    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
42        req.headers_mut()
43            .insert(USER_AGENT, self.user_agent.clone());
44
45        self.inner.call(req)
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in for an inner service; these tests only inspect the computed
    /// header value, so it is never called.
    struct Svc;

    #[test]
    fn sets_default_if_no_custom_user_agent() {
        let wrapped = UserAgent::new(Svc, None);
        let expected = HeaderValue::from_static(TONIC_USER_AGENT);

        assert_eq!(wrapped.user_agent, expected);
    }

    #[test]
    fn prepends_custom_user_agent_to_default() {
        let custom = HeaderValue::from_static("Greeter 1.1");
        let wrapped = UserAgent::new(Svc, Some(custom));

        let expected_text = format!("Greeter 1.1 {}", TONIC_USER_AGENT);
        let expected = HeaderValue::from_str(&expected_text).unwrap();

        assert_eq!(wrapped.user_agent, expected);
    }
}