// azure_identity/token_credentials/imds_managed_identity_credentials.rs

1use crate::{token_credentials::cache::TokenCache, TokenCredentialOptions};
2use azure_core::{
3    auth::{AccessToken, Secret, TokenCredential},
4    error::{Error, ErrorKind},
5    from_json,
6    headers::HeaderName,
7    HttpClient, Method, Request, StatusCode, Url,
8};
9use serde::{
10    de::{self, Deserializer},
11    Deserialize,
12};
13use std::{str, sync::Arc};
14use time::OffsetDateTime;
15
/// Selects which managed identity a token request is made for.
///
/// `SystemAssigned` adds no identity selector to the request; every other
/// variant carries the identifier value that `get_token` sends as its own
/// query parameter (`client_id`, `object_id` or `msi_res_id`).
#[derive(Debug)]
pub(crate) enum ImdsId {
    /// System-assigned identity: no identity selector is added to the request.
    SystemAssigned,

    /// User-assigned identity selected by client ID (sent as `client_id`).
    // NOTE(review): `dead_code` is allowed because this variant appears not to be
    // constructed anywhere in the crate yet — confirm before removing the allow.
    #[allow(dead_code)]
    ClientId(String),

    /// User-assigned identity selected by object ID (sent as `object_id`).
    // NOTE(review): same `dead_code` rationale as `ClientId`.
    #[allow(dead_code)]
    ObjectId(String),

    /// User-assigned identity selected by resource ID (sent as `msi_res_id`).
    // NOTE(review): same `dead_code` rationale as `ClientId`.
    #[allow(dead_code)]
    MsiResId(String),
}
26
27/// Attempts authentication using a managed identity that has been assigned to the deployment environment.
28///
29/// This authentication type works in Azure VMs, App Service and Azure Functions applications, as well as the Azure Cloud Shell
30///
31/// Built up from docs at [https://docs.microsoft.com/azure/app-service/overview-managed-identity#using-the-rest-protocol](https://docs.microsoft.com/azure/app-service/overview-managed-identity#using-the-rest-protocol)
32#[derive(Debug)]
33pub(crate) struct ImdsManagedIdentityCredential {
34    http_client: Arc<dyn HttpClient>,
35    endpoint: Url,
36    api_version: String,
37    secret_header: HeaderName,
38    secret_env: String,
39    id: ImdsId,
40    cache: TokenCache,
41}
42
43impl ImdsManagedIdentityCredential {
44    pub fn new(
45        options: impl Into<TokenCredentialOptions>,
46        endpoint: Url,
47        api_version: &str,
48        secret_header: HeaderName,
49        secret_env: &str,
50        id: ImdsId,
51    ) -> Self {
52        let options = options.into();
53        Self {
54            http_client: options.http_client(),
55            endpoint,
56            api_version: api_version.to_owned(),
57            secret_header: secret_header.to_owned(),
58            secret_env: secret_env.to_owned(),
59            id,
60            cache: TokenCache::new(),
61        }
62    }
63
64    async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
65        let resource = scopes_to_resource(scopes)?;
66
67        let mut query_items = vec![
68            ("api-version", self.api_version.as_str()),
69            ("resource", resource),
70        ];
71
72        match self.id {
73            ImdsId::SystemAssigned => (),
74            ImdsId::ClientId(ref client_id) => query_items.push(("client_id", client_id)),
75            ImdsId::ObjectId(ref object_id) => query_items.push(("object_id", object_id)),
76            ImdsId::MsiResId(ref msi_res_id) => query_items.push(("msi_res_id", msi_res_id)),
77        }
78
79        let mut url = self.endpoint.clone();
80        url.query_pairs_mut().extend_pairs(query_items);
81
82        let mut req = Request::new(url, Method::Get);
83
84        req.insert_header("metadata", "true");
85
86        let msi_secret = std::env::var(&self.secret_env);
87        if let Ok(val) = msi_secret {
88            req.insert_header(self.secret_header.clone(), val);
89        };
90
91        let rsp = self.http_client.execute_request(&req).await?;
92
93        let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
94        let rsp_body = rsp_body.collect().await?;
95
96        if !rsp_status.is_success() {
97            match rsp_status {
98                StatusCode::BadRequest => {
99                    return Err(Error::message(
100                        ErrorKind::Credential,
101                        "the requested identity has not been assigned to this resource",
102                    ))
103                }
104                StatusCode::BadGateway | StatusCode::GatewayTimeout => {
105                    return Err(Error::message(
106                        ErrorKind::Credential,
107                        "the request failed due to a gateway error",
108                    ))
109                }
110                rsp_status => {
111                    return Err(ErrorKind::http_response_from_parts(
112                        rsp_status,
113                        &rsp_headers,
114                        &rsp_body,
115                    )
116                    .into_error())
117                }
118            }
119        }
120
121        let token_response: MsiTokenResponse = from_json(&rsp_body)?;
122        Ok(AccessToken::new(
123            token_response.access_token,
124            token_response.expires_on,
125        ))
126    }
127}
128
129#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
130#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
131impl TokenCredential for ImdsManagedIdentityCredential {
132    async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
133        self.cache.get_token(scopes, self.get_token(scopes)).await
134    }
135
136    async fn clear_cache(&self) -> azure_core::Result<()> {
137        self.cache.clear().await
138    }
139}
140
141fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result<OffsetDateTime, D::Error>
142where
143    D: Deserializer<'de>,
144{
145    let v = String::deserialize(deserializer)?;
146    let as_i64 = v.parse::<i64>().map_err(de::Error::custom)?;
147    OffsetDateTime::from_unix_timestamp(as_i64).map_err(de::Error::custom)
148}
149
150/// Convert a `AADv2` scope to an `AADv1` resource
151///
152/// Directly based on the `azure-sdk-for-python` implementation:
153/// ref: <https://github.com/Azure/azure-sdk-for-python/blob/d6aeefef46c94b056419613f1a5cc9eaa3af0d22/sdk/identity/azure-identity/azure/identity/_internal/__init__.py#L22>
154fn scopes_to_resource<'a>(scopes: &'a [&'a str]) -> azure_core::Result<&'a str> {
155    if scopes.len() != 1 {
156        return Err(Error::message(
157            ErrorKind::Credential,
158            "only one scope is supported for IMDS authentication",
159        ));
160    }
161
162    let Some(scope) = scopes.first() else {
163        return Err(Error::message(
164            ErrorKind::Credential,
165            "no scopes were provided",
166        ));
167    };
168
169    Ok(scope.strip_suffix("/.default").unwrap_or(*scope))
170}
171
172// NOTE: expires_on is a String version of unix epoch time, not an integer.
173// https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=dotnet#rest-protocol-examples
174#[derive(Debug, Clone, Deserialize)]
175#[allow(unused)]
176struct MsiTokenResponse {
177    pub access_token: Secret,
178    #[serde(deserialize_with = "expires_on_string")]
179    pub expires_on: OffsetDateTime,
180    pub token_type: String,
181    pub resource: String,
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use time::macros::datetime;

    /// Minimal wrapper so `expires_on_string` can be exercised through `from_json`.
    #[derive(Debug, Deserialize)]
    struct TestExpires {
        #[serde(deserialize_with = "expires_on_string")]
        date: OffsetDateTime,
    }

    #[test]
    fn check_expires_on_string() -> azure_core::Result<()> {
        let as_string = r#"{"date": "1586984735"}"#;
        let expected = datetime!(2020-4-15 21:5:35 UTC);
        let parsed: TestExpires = from_json(as_string)?;
        assert_eq!(expected, parsed.date);
        Ok(())
    }

    #[test]
    fn expires_on_string_rejects_non_numeric() {
        let parsed: azure_core::Result<TestExpires> = from_json(r#"{"date": "soon"}"#);
        assert!(parsed.is_err());
    }

    #[test]
    fn scopes_to_resource_strips_default_suffix() -> azure_core::Result<()> {
        assert_eq!(
            scopes_to_resource(&["https://management.azure.com/.default"])?,
            "https://management.azure.com"
        );
        // Scopes without the `/.default` suffix pass through unchanged.
        assert_eq!(
            scopes_to_resource(&["https://vault.azure.net"])?,
            "https://vault.azure.net"
        );
        Ok(())
    }

    #[test]
    fn scopes_to_resource_requires_exactly_one_scope() {
        let none: [&str; 0] = [];
        assert!(scopes_to_resource(&none).is_err());
        assert!(scopes_to_resource(&["a", "b"]).is_err());
    }
}