azure_identity/token_credentials/
cache.rs
1use async_lock::RwLock;
2use azure_core::auth::AccessToken;
3use futures::Future;
4use std::{collections::HashMap, time::Duration};
5use time::OffsetDateTime;
6use tracing::trace;
7
8fn is_expired(token: &AccessToken) -> bool {
9 token.expires_on < OffsetDateTime::now_utc() + Duration::from_secs(20)
10}
11
12#[derive(Debug)]
13pub(crate) struct TokenCache(RwLock<HashMap<Vec<String>, AccessToken>>);
14
15impl TokenCache {
16 pub(crate) fn new() -> Self {
17 Self(RwLock::new(HashMap::new()))
18 }
19
20 pub(crate) async fn clear(&self) -> azure_core::Result<()> {
21 let mut token_cache = self.0.write().await;
22 token_cache.clear();
23 Ok(())
24 }
25
26 pub(crate) async fn get_token(
27 &self,
28 scopes: &[&str],
29 callback: impl Future<Output = azure_core::Result<AccessToken>>,
30 ) -> azure_core::Result<AccessToken> {
31 let token_cache = self.0.read().await;
33 let scopes = scopes.iter().map(ToString::to_string).collect::<Vec<_>>();
34 if let Some(token) = token_cache.get(&scopes) {
35 if !is_expired(token) {
36 trace!("returning cached token");
37 return Ok(token.clone());
38 }
39 }
40
41 drop(token_cache);
43 let mut token_cache = self.0.write().await;
44
45 if let Some(token) = token_cache.get(&scopes) {
48 if !is_expired(token) {
49 trace!("returning token that was updated while waiting on write lock");
50 return Ok(token.clone());
51 }
52 }
53
54 trace!("falling back to callback");
55 let token = callback.await?;
56
57 token_cache.insert(scopes, token.clone());
62 Ok(token)
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::auth::Secret;
    use std::sync::Mutex;

    const STORAGE_TOKEN_SCOPE: &str = "https://storage.azure.com/";
    const IOTHUB_TOKEN_SCOPE: &str = "https://iothubs.azure.net";

    /// Credential stand-in that counts how often it is asked for a token and
    /// encodes the scopes and that count into the returned secret.
    #[derive(Debug)]
    struct MockCredential {
        token: AccessToken,
        get_token_call_count: Mutex<usize>,
    }

    impl MockCredential {
        fn new(token: AccessToken) -> Self {
            let get_token_call_count = Mutex::new(0);
            Self {
                token,
                get_token_call_count,
            }
        }

        /// Produces a token whose secret is `"<scopes>-<base secret>:<call number>"`
        /// and whose expiry is copied from the template token.
        async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
            let mut calls = self.get_token_call_count.lock().unwrap();
            *calls += 1;

            let secret = format!(
                "{}-{}:{}",
                scopes.join(" "),
                self.token.token.secret(),
                *calls
            );
            Ok(AccessToken::new(
                Secret::new(secret),
                self.token.expires_on,
            ))
        }
    }

    /// Secret the mock is expected to hand out on its `call`-th invocation.
    fn expected_secret(scopes: &[&str], base: &str, call: usize) -> String {
        format!("{}-{}:{}", scopes.join(" "), base, call)
    }

    #[tokio::test]
    async fn test_get_token_different_resources() -> azure_core::Result<()> {
        let storage = &[STORAGE_TOKEN_SCOPE];
        let iothub = &[IOTHUB_TOKEN_SCOPE];
        let base_secret = "test-token";
        let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(300);

        let credential =
            MockCredential::new(AccessToken::new(Secret::new(base_secret), expires_on));
        let cache = TokenCache::new();

        // Two requests for the same scopes: only the first reaches the credential.
        let first = cache
            .get_token(storage, credential.get_token(storage))
            .await?;
        let second = cache
            .get_token(storage, credential.get_token(storage))
            .await?;

        let want = expected_secret(storage, base_secret, 1);
        assert_eq!(first.token.secret(), want);
        assert_eq!(second.token.secret(), want);

        // Different scopes miss the cache once, then hit it.
        let third = cache
            .get_token(iothub, credential.get_token(iothub))
            .await?;
        let fourth = cache
            .get_token(iothub, credential.get_token(iothub))
            .await?;

        let want = expected_secret(iothub, base_secret, 2);
        assert_eq!(third.token.secret(), want);
        assert_eq!(fourth.token.secret(), want);

        Ok(())
    }

    #[tokio::test]
    async fn test_refresh_expired_token() -> azure_core::Result<()> {
        let scopes = &[STORAGE_TOKEN_SCOPE];
        let base_secret = "test-token";
        // A token expiring "now" is already inside the expiry margin, so every
        // lookup has to go back to the credential.
        let already_expired =
            AccessToken::new(Secret::new(base_secret), OffsetDateTime::now_utc());

        let credential = MockCredential::new(already_expired);
        let cache = TokenCache::new();

        for call in 1..=4 {
            let token = cache
                .get_token(scopes, credential.get_token(scopes))
                .await?;
            assert_eq!(
                token.token.secret(),
                expected_secret(scopes, base_secret, call)
            );
        }

        Ok(())
    }
}