1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
//! Middleware types returned from `ConfigExt` methods.
use tower::{filter::AsyncFilterLayer, util::Either, Layer};
pub(crate) use tower_http::auth::AddAuthorizationLayer;

mod base_uri;
mod extra_headers;

pub use base_uri::{BaseUri, BaseUriLayer};
pub use extra_headers::{ExtraHeaders, ExtraHeadersLayer};

use super::auth::RefreshableToken;
/// Tower [`Layer`] that handles the `Authorization` header of outgoing requests
/// according to the client configuration.
pub struct AuthLayer(
    /// Authorization strategy: either a `tower_http` `AddAuthorizationLayer`, or an
    /// `AsyncFilterLayer` whose per-request predicate is a `RefreshableToken`.
    pub(crate) Either<AddAuthorizationLayer, AsyncFilterLayer<RefreshableToken>>,
);

impl<S> Layer<S> for AuthLayer {
    type Service = Either<
        <AddAuthorizationLayer as Layer<S>>::Service,
        <AsyncFilterLayer<RefreshableToken> as Layer<S>>::Service,
    >;

    fn layer(&self, inner: S) -> Self::Service {
        self.0.layer(inner)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::{matches, pin::pin, sync::Arc};

    use chrono::{Duration, Utc};
    use http::{header::AUTHORIZATION, HeaderValue, Request, Response};
    use secrecy::SecretString;
    use tokio::sync::Mutex;
    use tokio_test::assert_ready_ok;
    use tower::filter::AsyncFilterLayer;
    use tower_test::{mock, mock::Handle};

    use crate::{
        client::{AuthError, Body},
        config::AuthInfo,
    };

    /// A well-formed token must reach the inner service as `Authorization: Bearer <token>`.
    #[tokio::test(flavor = "current_thread")]
    async fn valid_token() {
        const TOKEN: &str = "test";
        let auth = test_token(TOKEN.into());
        let (mut service, handle): (_, Handle<Request<Body>, Response<Body>>) =
            mock::spawn_layer(AsyncFilterLayer::new(auth));

        // Acts as the inner service: inspects the forwarded request and responds.
        let spawned = tokio::spawn(async move {
            let mut handle = pin!(handle);
            let (request, send) = handle.next_request().await.expect("service not called");
            assert_eq!(
                request.headers().get(AUTHORIZATION).unwrap(),
                HeaderValue::try_from(format!("Bearer {TOKEN}")).unwrap()
            );
            send.send_response(Response::builder().body(Body::empty()).unwrap());
        });

        assert_ready_ok!(service.poll_ready());
        let response = service
            .call(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await;
        // Join the handle task before inspecting the response. If its header assertion
        // failed, the response sender was dropped and `response` would only carry an
        // opaque mock error. Re-raising the task's panic keeps the real failure message.
        if let Err(err) = spawned.await {
            std::panic::resume_unwind(err.into_panic());
        }
        response.unwrap();
    }

    /// A token that is not a valid header value must be rejected with
    /// `AuthError::InvalidBearerToken`.
    #[tokio::test(flavor = "current_thread")]
    async fn invalid_token() {
        // A newline is not permitted inside an HTTP header value.
        const TOKEN: &str = "\n";
        let auth = test_token(TOKEN.into());
        let (mut service, _handle) =
            mock::spawn_layer::<Request<Body>, Response<Body>, _>(AsyncFilterLayer::new(auth));

        // The `Service` contract requires observing readiness before `call`.
        assert_ready_ok!(service.poll_ready());
        let err = service
            .call(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap_err();

        let err = err
            .downcast::<AuthError>()
            .expect("filter should fail with an AuthError");
        assert!(matches!(*err, AuthError::InvalidBearerToken(_)));
    }

    /// Builds a `RefreshableToken::Exec` holding `token`, with an expiry one hour
    /// in the future and an `AuthInfo` whose `token` is the same secret.
    fn test_token(token: String) -> RefreshableToken {
        let expiry = Utc::now() + Duration::try_seconds(60 * 60).unwrap();
        let secret_token = SecretString::from(token);
        let info = AuthInfo {
            token: Some(secret_token.clone()),
            ..Default::default()
        };
        RefreshableToken::Exec(Arc::new(Mutex::new((secret_token, expiry, info))))
    }
}