// azure_identity/device_code_flow/mod.rs

//! Authorize using the device authorization grant flow.
//!
//! This flow allows users to sign in to input-constrained devices such as a smart TV, `IoT` device, or printer.
//!
//! You can learn more about this authorization flow in the [Microsoft identity platform documentation](https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-device-code).
6mod device_code_responses;
7
8use azure_core::{
9    content_type,
10    error::{Error, ErrorKind},
11    from_json, headers, sleep, HttpClient, Method, Request, Response, Url,
12};
13pub use device_code_responses::*;
14use futures::stream::unfold;
15use serde::Deserialize;
16use std::{borrow::Cow, pin::Pin, sync::Arc, time::Duration};
17use url::form_urlencoded;
18
19/// Start the device authorization grant flow.
20/// The user has only 15 minutes to sign in (the usual value for `expires_in`).
21pub async fn start<'a, 'b, T>(
22    http_client: Arc<dyn HttpClient>,
23    tenant_id: T,
24    client_id: &str,
25    scopes: &'b [&'b str],
26) -> azure_core::Result<DeviceCodePhaseOneResponse<'a>>
27where
28    T: Into<Cow<'a, str>>,
29{
30    let tenant_id = tenant_id.into();
31    let url = &format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/devicecode");
32
33    let encoded = form_urlencoded::Serializer::new(String::new())
34        .append_pair("client_id", client_id)
35        .append_pair("scope", &scopes.join(" "))
36        .finish();
37
38    let rsp = post_form(http_client.clone(), url, encoded).await?;
39    let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
40    let rsp_body = rsp_body.collect().await?;
41    if !rsp_status.is_success() {
42        return Err(
43            ErrorKind::http_response_from_parts(rsp_status, &rsp_headers, &rsp_body).into_error(),
44        );
45    }
46    let device_code_response: DeviceCodePhaseOneResponse = from_json(&rsp_body)?;
47
48    // we need to capture some variables that will be useful in
49    // the second phase (the client, the tenant_id and the client_id)
50    Ok(DeviceCodePhaseOneResponse {
51        device_code: device_code_response.device_code,
52        user_code: device_code_response.user_code,
53        verification_uri: device_code_response.verification_uri,
54        expires_in: device_code_response.expires_in,
55        interval: device_code_response.interval,
56        message: device_code_response.message,
57        http_client: Some(http_client),
58        tenant_id,
59        client_id: client_id.to_string(),
60    })
61}
62
63/// Contains the required information to allow a user to sign in.
64#[derive(Debug, Clone, Deserialize)]
65pub struct DeviceCodePhaseOneResponse<'a> {
66    device_code: String,
67    user_code: String,
68    verification_uri: String,
69    expires_in: u64,
70    interval: u64,
71    message: String,
72    // The skipped fields below do not come from the Azure answer.
73    // They will be added manually after deserialization
74    #[serde(skip)]
75    http_client: Option<Arc<dyn HttpClient>>,
76    #[serde(skip)]
77    tenant_id: Cow<'a, str>,
78    // We store the ClientId as string instead of the original type, because it
79    // does not implement Default, and it's in another crate
80    #[serde(skip)]
81    client_id: String,
82}
83
84impl<'a> DeviceCodePhaseOneResponse<'a> {
85    /// The message containing human readable instructions for the user.
86    pub fn message(&self) -> &str {
87        &self.message
88    }
89
90    /// Polls the token endpoint while the user signs in.
91    /// This will continue until either success or error is returned.
92    pub fn stream(
93        &self,
94    ) -> Pin<Box<impl futures::Stream<Item = azure_core::Result<DeviceCodeAuthorization>> + '_>>
95    {
96        #[derive(Debug, Clone, PartialEq, Eq)]
97        enum NextState {
98            Continue,
99            Finish,
100        }
101
102        Box::pin(unfold(
103            NextState::Continue,
104            move |state: NextState| async move {
105                match state {
106                    NextState::Continue => {
107                        let url = &format!(
108                            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
109                            self.tenant_id,
110                        );
111
112                        // Throttle down as specified by Azure. This could be
113                        // smarter: we could calculate the elapsed time since the
114                        // last poll and wait only the delta.
115                        sleep(Duration::from_secs(self.interval)).await;
116
117                        let encoded = form_urlencoded::Serializer::new(String::new())
118                            .append_pair(
119                                "grant_type",
120                                "urn:ietf:params:oauth:grant-type:device_code",
121                            )
122                            .append_pair("client_id", self.client_id.as_str())
123                            .append_pair("device_code", &self.device_code)
124                            .finish();
125
126                        let http_client = self.http_client.clone().unwrap();
127
128                        match post_form(http_client.clone(), url, encoded).await {
129                            Ok(rsp) => {
130                                let rsp_status = rsp.status();
131                                let rsp_body = match rsp.into_body().collect().await {
132                                    Ok(b) => b,
133                                    Err(e) => return Some((Err(e), NextState::Finish)),
134                                };
135                                if rsp_status.is_success() {
136                                    match from_json::<_, DeviceCodeAuthorization>(&rsp_body) {
137                                        Ok(authorization) => {
138                                            Some((Ok(authorization), NextState::Finish))
139                                        }
140                                        Err(error) => Some((Err(error), NextState::Finish)),
141                                    }
142                                } else {
143                                    match from_json::<_, DeviceCodeErrorResponse>(&rsp_body) {
144                                        Ok(error_rsp) => {
145                                            let next_state =
146                                                if error_rsp.error == "authorization_pending" {
147                                                    NextState::Continue
148                                                } else {
149                                                    NextState::Finish
150                                                };
151                                            Some((
152                                                Err(Error::new(ErrorKind::Credential, error_rsp)),
153                                                next_state,
154                                            ))
155                                        }
156                                        Err(error) => Some((Err(error), NextState::Finish)),
157                                    }
158                                }
159                            }
160                            Err(error) => Some((Err(error), NextState::Finish)),
161                        }
162                    }
163                    NextState::Finish => None,
164                }
165            },
166        ))
167    }
168}
169
170async fn post_form(
171    http_client: Arc<dyn HttpClient>,
172    url: &str,
173    form_body: String,
174) -> azure_core::Result<Response> {
175    let url = Url::parse(url)?;
176    let mut req = Request::new(url, Method::Post);
177    req.insert_header(
178        headers::CONTENT_TYPE,
179        content_type::APPLICATION_X_WWW_FORM_URLENCODED,
180    );
181    req.set_body(form_body);
182    http_client.execute_request(&req).await
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T` is `Send`; the value itself is simply dropped.
    fn assert_send<T: Send>(_value: T) {}

    #[test]
    fn ensure_that_start_is_send() {
        // The future is never polled: building it is enough for the type check.
        let pending_start = start(azure_core::new_http_client(), "UNUSED", "UNUSED", &[]);
        assert_send(pending_start);
    }
}
200}