// azure_identity/token_credentials/cache.rs

1use async_lock::RwLock;
2use azure_core::auth::AccessToken;
3use futures::Future;
4use std::{collections::HashMap, time::Duration};
5use time::OffsetDateTime;
6use tracing::trace;
7
8fn is_expired(token: &AccessToken) -> bool {
9    token.expires_on < OffsetDateTime::now_utc() + Duration::from_secs(20)
10}
11
12#[derive(Debug)]
13pub(crate) struct TokenCache(RwLock<HashMap<Vec<String>, AccessToken>>);
14
15impl TokenCache {
16    pub(crate) fn new() -> Self {
17        Self(RwLock::new(HashMap::new()))
18    }
19
20    pub(crate) async fn clear(&self) -> azure_core::Result<()> {
21        let mut token_cache = self.0.write().await;
22        token_cache.clear();
23        Ok(())
24    }
25
26    pub(crate) async fn get_token(
27        &self,
28        scopes: &[&str],
29        callback: impl Future<Output = azure_core::Result<AccessToken>>,
30    ) -> azure_core::Result<AccessToken> {
31        // if the current cached token for this resource is good, return it.
32        let token_cache = self.0.read().await;
33        let scopes = scopes.iter().map(ToString::to_string).collect::<Vec<_>>();
34        if let Some(token) = token_cache.get(&scopes) {
35            if !is_expired(token) {
36                trace!("returning cached token");
37                return Ok(token.clone());
38            }
39        }
40
41        // otherwise, drop the read lock and get a write lock to refresh the token
42        drop(token_cache);
43        let mut token_cache = self.0.write().await;
44
45        // check again in case another thread refreshed the token while we were
46        // waiting on the write lock
47        if let Some(token) = token_cache.get(&scopes) {
48            if !is_expired(token) {
49                trace!("returning token that was updated while waiting on write lock");
50                return Ok(token.clone());
51            }
52        }
53
54        trace!("falling back to callback");
55        let token = callback.await?;
56
57        // NOTE: we do not check to see if the token is expired here, as at
58        // least one credential, `AzureCliCredential`, specifies the token is
59        // immediately expired after it is returned, which indicates the token
60        // should always be refreshed upon use.
61        token_cache.insert(scopes, token.clone());
62        Ok(token)
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::auth::Secret;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that mints a new token on every call. Each token embeds the
    /// requested scopes and a 1-based call counter, so tests can tell whether a
    /// token came from the cache or from a fresh fetch.
    #[derive(Debug)]
    struct MockCredential {
        token: AccessToken,
        get_token_call_count: AtomicUsize,
    }

    impl MockCredential {
        fn new(token: AccessToken) -> Self {
            Self {
                token,
                get_token_call_count: AtomicUsize::new(0),
            }
        }

        /// Returns a token whose secret is `"<scopes>-<base secret>:<call number>"`
        /// and whose expiry is copied from the template token.
        async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
            // `fetch_add` yields the previous value; add one for this call's 1-based number.
            let call_count = self.get_token_call_count.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(AccessToken {
                token: Secret::new(format!(
                    "{}-{}:{}",
                    scopes.join(" "),
                    self.token.token.secret(),
                    call_count
                )),
                expires_on: self.token.expires_on,
            })
        }
    }

    const STORAGE_TOKEN_SCOPE: &str = "https://storage.azure.com/";
    const IOTHUB_TOKEN_SCOPE: &str = "https://iothubs.azure.net";

    /// The secret `MockCredential` produces for `scopes` on call number `call`.
    fn expected_secret(scopes: &[&str], secret: &str, call: usize) -> String {
        format!("{}-{}:{}", scopes.join(" "), secret, call)
    }

    #[tokio::test]
    async fn test_get_token_different_resources() -> azure_core::Result<()> {
        let resource1 = &[STORAGE_TOKEN_SCOPE];
        let resource2 = &[IOTHUB_TOKEN_SCOPE];
        let secret_string = "test-token";
        let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(300);
        let access_token = AccessToken::new(Secret::new(secret_string), expires_on);

        let mock_credential = MockCredential::new(access_token);

        let cache = TokenCache::new();

        // Test that querying a token for the same resource twice returns the same (cached) token on the second call
        let token1 = cache
            .get_token(resource1, mock_credential.get_token(resource1))
            .await?;
        let token2 = cache
            .get_token(resource1, mock_credential.get_token(resource1))
            .await?;

        let expected_token = expected_secret(resource1, secret_string, 1);
        assert_eq!(token1.token.secret(), expected_token);
        assert_eq!(token2.token.secret(), expected_token);

        // Test that querying a token for a second resource returns a different token, as the cache is per-resource.
        // Also test that the same token is returned (cached) on a second call.
        let token3 = cache
            .get_token(resource2, mock_credential.get_token(resource2))
            .await?;
        let token4 = cache
            .get_token(resource2, mock_credential.get_token(resource2))
            .await?;
        let expected_token = expected_secret(resource2, secret_string, 2);
        assert_eq!(token3.token.secret(), expected_token);
        assert_eq!(token4.token.secret(), expected_token);

        // Fetching the second resource must not evict the first one.
        let token5 = cache
            .get_token(resource1, mock_credential.get_token(resource1))
            .await?;
        assert_eq!(
            token5.token.secret(),
            expected_secret(resource1, secret_string, 1)
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_refresh_expired_token() -> azure_core::Result<()> {
        let resource = &[STORAGE_TOKEN_SCOPE];
        let access_token = "test-token";
        let expires_on = OffsetDateTime::now_utc();
        let token_response = AccessToken::new(Secret::new(access_token), expires_on);

        let mock_credential = MockCredential::new(token_response);

        let cache = TokenCache::new();

        // Test that querying an expired token returns a new token
        for i in 1..5 {
            let token = cache
                .get_token(resource, mock_credential.get_token(resource))
                .await?;
            assert_eq!(
                token.token.secret(),
                expected_secret(resource, access_token, i)
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_clear_forces_refresh() -> azure_core::Result<()> {
        let resource = &[STORAGE_TOKEN_SCOPE];
        let secret_string = "test-token";
        let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(300);
        let access_token = AccessToken::new(Secret::new(secret_string), expires_on);

        let mock_credential = MockCredential::new(access_token);

        let cache = TokenCache::new();

        // A still-valid token is cached until the cache is cleared...
        let token1 = cache
            .get_token(resource, mock_credential.get_token(resource))
            .await?;
        cache.clear().await?;

        // ...after which the callback must be invoked again.
        let token2 = cache
            .get_token(resource, mock_credential.get_token(resource))
            .await?;

        assert_eq!(
            token1.token.secret(),
            expected_secret(resource, secret_string, 1)
        );
        assert_eq!(
            token2.token.secret(),
            expected_secret(resource, secret_string, 2)
        );

        Ok(())
    }
}