azure_identity/
refresh_token.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//! Refresh token utilities

use azure_core::{
    auth::Secret,
    content_type,
    error::{Error, ErrorKind, ResultExt},
    headers, HttpClient, Request, Url,
};
use azure_core::{from_json, Method};
use serde::Deserialize;
use std::fmt;
use std::sync::Arc;
use url::form_urlencoded;

/// Exchange a refresh token for a new access token and refresh token
pub async fn exchange(
    http_client: Arc<dyn HttpClient>,
    tenant_id: &str,
    client_id: &str,
    client_secret: Option<&str>,
    refresh_token: &Secret,
) -> azure_core::Result<RefreshTokenResponse> {
    let encoded = {
        let mut encoded = &mut form_urlencoded::Serializer::new(String::new());
        encoded = encoded
            .append_pair("grant_type", "refresh_token")
            .append_pair("client_id", client_id)
            .append_pair("refresh_token", refresh_token.secret());
        // optionally add the client secret
        if let Some(client_secret) = client_secret {
            encoded = encoded.append_pair("client_secret", client_secret);
        };
        encoded.finish()
    };

    let url = Url::parse(&format!(
        "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"
    ))?;

    let mut req = Request::new(url, Method::Post);
    req.insert_header(
        headers::CONTENT_TYPE,
        content_type::APPLICATION_X_WWW_FORM_URLENCODED,
    );
    req.set_body(encoded);

    let rsp = http_client.execute_request(&req).await?;
    let rsp_status = rsp.status();

    if rsp_status.is_success() {
        rsp.json().await.map_kind(ErrorKind::Credential)
    } else {
        let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
        let rsp_body = rsp_body.collect().await?;
        let token_error: RefreshTokenError = from_json(&rsp_body).map_err(|_| {
            ErrorKind::http_response_from_parts(rsp_status, &rsp_headers, &rsp_body)
        })?;
        Err(Error::new(ErrorKind::Credential, token_error))
    }
}

/// The deserialized body of a successful token response, as returned by
/// [`exchange`].
///
/// All fields are private; read them through the accessor methods on this
/// type.
#[derive(Debug, Clone, Deserialize)]
pub struct RefreshTokenResponse {
    /// The `token_type` value from the response.
    token_type: String,
    /// Scopes, read from the response's `scope` field and broken into
    /// individual entries by `deserialize::split`.
    #[serde(rename = "scope", deserialize_with = "deserialize::split")]
    scopes: Vec<String>,
    /// The `expires_in` value from the response (the accessor documents it as
    /// a number of seconds).
    expires_in: u64,
    /// The `ext_expires_in` value from the response.
    ext_expires_in: u64,
    /// The newly issued access token.
    access_token: Secret,
    /// The newly issued refresh token.
    refresh_token: Secret,
}

/// Read-only accessors for the fields of a token response.
impl RefreshTokenResponse {
    /// The `token_type` of the response. Always `Bearer` for Azure AD.
    pub fn token_type(&self) -> &str {
        self.token_type.as_str()
    }

    /// The scopes the `access_token` is valid for.
    pub fn scopes(&self) -> &[String] {
        self.scopes.as_slice()
    }

    /// How long the `access_token` is valid for, in seconds.
    pub fn expires_in(&self) -> u64 {
        self.expires_in
    }

    /// The extended lifetime of the `access_token`.
    pub fn ext_expires_in(&self) -> u64 {
        self.ext_expires_in
    }

    /// The access token issued for the requested scopes.
    pub fn access_token(&self) -> &Secret {
        &self.access_token
    }

    /// The new refresh token, which should replace the old refresh token.
    pub fn refresh_token(&self) -> &Secret {
        &self.refresh_token
    }
}

/// Custom serde helpers for the token response.
mod deserialize {
    use serde::Deserializer;

    /// Deserializes a space-delimited scope string into its individual scopes.
    ///
    /// Empty segments are skipped, so an empty string yields an empty list
    /// and leading, trailing or repeated spaces never produce empty scope
    /// entries.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error when the value cannot be read as a
    /// string.
    pub fn split<'de, D>(scope: D) -> Result<Vec<String>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw: String = serde::Deserialize::deserialize(scope)?;
        Ok(raw
            .split(' ')
            // `str::split` yields "" for an empty input and between adjacent
            // separators; those are not scopes.
            .filter(|segment| !segment.is_empty())
            .map(ToOwned::to_owned)
            .collect())
    }
}

/// An error response body when there is an error requesting a token.
///
/// `error`, `error_description` and `error_codes` are required: a body that
/// lacks any of them does not deserialize into this type. The remaining
/// fields are optional.
#[derive(Debug, Clone, Deserialize)]
// Only `error`, `suberror` and `error_description` are read in this file (by
// the `Display` impl); the other fields are deserialized but not otherwise
// used here, which is presumably why unused warnings are silenced.
#[allow(unused)]
pub struct RefreshTokenError {
    /// The `error` value from the response body.
    error: String,
    /// The `error_description` value from the response body.
    error_description: String,
    /// The `error_codes` list from the response body.
    error_codes: Vec<i64>,
    /// The `timestamp` value, if present (kept as the raw string).
    timestamp: Option<String>,
    /// The `trace_id` value, if present.
    trace_id: Option<String>,
    /// The `correlation_id` value, if present.
    correlation_id: Option<String>,
    /// The `suberror` value, if present.
    suberror: Option<String>,
    /// The `claims` value, if present.
    claims: Option<String>,
}

// Marker impl with no overridden methods; it lets `exchange` pass a
// `RefreshTokenError` to `Error::new` as the underlying error value.
impl std::error::Error for RefreshTokenError {}

/// Renders the error as newline-terminated `key: value` lines: `error`, then
/// `suberror` when one is present, then `description`.
impl fmt::Display for RefreshTokenError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            error,
            error_description,
            suberror,
            ..
        } = self;

        writeln!(f, "error: {error}")?;
        // The suberror line is emitted only when the response carried one.
        if let Some(detail) = suberror.as_deref() {
            writeln!(f, "suberror: {detail}")?;
        }
        writeln!(f, "description: {error_description}")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time probe: only accepts values whose type is `Send`.
    fn assert_send<T>(_value: T)
    where
        T: Send,
    {
    }

    /// The future returned by `exchange` must be `Send`. It is only
    /// constructed here, never awaited, so no request is issued.
    #[test]
    fn ensure_that_exchange_is_send() {
        let refresh_token = Secret::new("UNUSED");
        let pending = exchange(
            azure_core::new_http_client(),
            "UNUSED",
            "UNUSED",
            None,
            &refresh_token,
        );
        assert_send(pending);
    }
}