tower/make/make_service/
shared.rs

1use std::convert::Infallible;
2use std::task::{Context, Poll};
3use tower_service::Service;
4
/// A [`MakeService`] whose produced services are clones of one inner service.
///
/// `Shared` stores a single service value. Each time it is called it hands back
/// a clone of that value: the target passed to the call is ignored, readiness is
/// reported immediately, and the error type is [`Infallible`].
///
/// [`MakeService`]: super::MakeService
/// [`Infallible`]: std::convert::Infallible
///
/// # Example
///
/// ```
/// # use std::task::{Context, Poll};
/// # use std::pin::Pin;
/// # use std::convert::Infallible;
/// use tower::make::{MakeService, Shared};
/// use tower::buffer::Buffer;
/// use tower::Service;
/// use futures::future::{Ready, ready};
///
/// // Stand-in for whatever a server accepts (e.g. a socket)
/// struct Connection {}
///
/// // Stand-in request type
/// struct Request {}
///
/// // Stand-in response type
/// struct Response {}
///
/// // A service that is deliberately not `Clone`
/// struct MyService;
///
/// impl Service<Request> for MyService {
///     type Response = Response;
///     type Error = Infallible;
///     type Future = Ready<Result<Response, Infallible>>;
///
///     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
///         Poll::Ready(Ok(()))
///     }
///
///     fn call(&mut self, _req: Request) -> Self::Future {
///         ready(Ok(Response {}))
///     }
/// }
///
/// // A server-like function: it would accept connections and ask `Make` for a
/// // service per connection, possibly one that is bound to that connection.
/// //
/// // hyper exposes an API of roughly this shape.
/// async fn serve_make_service<Make>(make: Make)
/// where
///     Make: MakeService<Connection, Request>
/// {
///     // ...
/// }
///
/// # async {
/// // The service we want to serve
/// let svc = MyService;
///
/// // Putting a channel in front of it gives us something that is `Clone`
/// let buffered = Buffer::new(svc, 1024);
///
/// // Turn the cloneable service into a `MakeService`
/// let make = Shared::new(buffered);
///
/// // Serve it; the `Connection`s are simply ignored since `MyService` has no use for them
/// serve_make_service(make).await;
/// # };
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Shared<S> {
    /// The service that gets cloned for every target.
    service: S,
}
75
76impl<S> Shared<S> {
77    /// Create a new [`Shared`] from a service.
78    pub const fn new(service: S) -> Self {
79        Self { service }
80    }
81}
82
83impl<S, T> Service<T> for Shared<S>
84where
85    S: Clone,
86{
87    type Response = S;
88    type Error = Infallible;
89    type Future = SharedFuture<S>;
90
91    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92        Poll::Ready(Ok(()))
93    }
94
95    fn call(&mut self, _target: T) -> Self::Future {
96        SharedFuture::new(futures_util::future::ready(Ok(self.service.clone())))
97    }
98}
99
100opaque_future! {
101    /// Response future from [`Shared`] services.
102    pub type SharedFuture<S> = futures_util::future::Ready<Result<S, Infallible>>;
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::make::MakeService;
    use crate::service_fn;
    use futures::future::poll_fn;

    /// Service function that hands its request straight back.
    async fn echo<R>(req: R) -> Result<R, Infallible> {
        Ok(req)
    }

    /// Uses `Shared` through the `MakeService` API and checks that the service
    /// it makes echoes a request.
    #[tokio::test]
    async fn as_make_service() {
        let mut maker = Shared::new(service_fn(echo::<&'static str>));

        let maker_ready = poll_fn(|cx| MakeService::<(), _>::poll_ready(&mut maker, cx)).await;
        maker_ready.unwrap();
        let mut echo_svc = maker.make_service(()).await.unwrap();

        let svc_ready = poll_fn(|cx| echo_svc.poll_ready(cx)).await;
        svc_ready.unwrap();
        let reply = echo_svc.call("foo").await.unwrap();

        assert_eq!(reply, "foo");
    }

    /// Same flow, but after converting the `MakeService` into a plain `Service`
    /// with `into_service`.
    #[tokio::test]
    async fn as_make_service_into_service() {
        let maker = Shared::new(service_fn(echo::<&'static str>));
        let mut as_service = MakeService::<(), _>::into_service(maker);

        let maker_ready = poll_fn(|cx| Service::<()>::poll_ready(&mut as_service, cx)).await;
        maker_ready.unwrap();
        let mut echo_svc = as_service.call(()).await.unwrap();

        let svc_ready = poll_fn(|cx| echo_svc.poll_ready(cx)).await;
        svc_ready.unwrap();
        let reply = echo_svc.call("foo").await.unwrap();

        assert_eq!(reply, "foo");
    }
}