azure_identity/device_code_flow/
mod.rs
1mod device_code_responses;
7
8use azure_core::{
9 content_type,
10 error::{Error, ErrorKind},
11 from_json, headers, sleep, HttpClient, Method, Request, Response, Url,
12};
13pub use device_code_responses::*;
14use futures::stream::unfold;
15use serde::Deserialize;
16use std::{borrow::Cow, pin::Pin, sync::Arc, time::Duration};
17use url::form_urlencoded;
18
19pub async fn start<'a, 'b, T>(
22 http_client: Arc<dyn HttpClient>,
23 tenant_id: T,
24 client_id: &str,
25 scopes: &'b [&'b str],
26) -> azure_core::Result<DeviceCodePhaseOneResponse<'a>>
27where
28 T: Into<Cow<'a, str>>,
29{
30 let tenant_id = tenant_id.into();
31 let url = &format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/devicecode");
32
33 let encoded = form_urlencoded::Serializer::new(String::new())
34 .append_pair("client_id", client_id)
35 .append_pair("scope", &scopes.join(" "))
36 .finish();
37
38 let rsp = post_form(http_client.clone(), url, encoded).await?;
39 let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
40 let rsp_body = rsp_body.collect().await?;
41 if !rsp_status.is_success() {
42 return Err(
43 ErrorKind::http_response_from_parts(rsp_status, &rsp_headers, &rsp_body).into_error(),
44 );
45 }
46 let device_code_response: DeviceCodePhaseOneResponse = from_json(&rsp_body)?;
47
48 Ok(DeviceCodePhaseOneResponse {
51 device_code: device_code_response.device_code,
52 user_code: device_code_response.user_code,
53 verification_uri: device_code_response.verification_uri,
54 expires_in: device_code_response.expires_in,
55 interval: device_code_response.interval,
56 message: device_code_response.message,
57 http_client: Some(http_client),
58 tenant_id,
59 client_id: client_id.to_string(),
60 })
61}
62
63#[derive(Debug, Clone, Deserialize)]
65pub struct DeviceCodePhaseOneResponse<'a> {
66 device_code: String,
67 user_code: String,
68 verification_uri: String,
69 expires_in: u64,
70 interval: u64,
71 message: String,
72 #[serde(skip)]
75 http_client: Option<Arc<dyn HttpClient>>,
76 #[serde(skip)]
77 tenant_id: Cow<'a, str>,
78 #[serde(skip)]
81 client_id: String,
82}
83
84impl<'a> DeviceCodePhaseOneResponse<'a> {
85 pub fn message(&self) -> &str {
87 &self.message
88 }
89
90 pub fn stream(
93 &self,
94 ) -> Pin<Box<impl futures::Stream<Item = azure_core::Result<DeviceCodeAuthorization>> + '_>>
95 {
96 #[derive(Debug, Clone, PartialEq, Eq)]
97 enum NextState {
98 Continue,
99 Finish,
100 }
101
102 Box::pin(unfold(
103 NextState::Continue,
104 move |state: NextState| async move {
105 match state {
106 NextState::Continue => {
107 let url = &format!(
108 "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
109 self.tenant_id,
110 );
111
112 sleep(Duration::from_secs(self.interval)).await;
116
117 let encoded = form_urlencoded::Serializer::new(String::new())
118 .append_pair(
119 "grant_type",
120 "urn:ietf:params:oauth:grant-type:device_code",
121 )
122 .append_pair("client_id", self.client_id.as_str())
123 .append_pair("device_code", &self.device_code)
124 .finish();
125
126 let http_client = self.http_client.clone().unwrap();
127
128 match post_form(http_client.clone(), url, encoded).await {
129 Ok(rsp) => {
130 let rsp_status = rsp.status();
131 let rsp_body = match rsp.into_body().collect().await {
132 Ok(b) => b,
133 Err(e) => return Some((Err(e), NextState::Finish)),
134 };
135 if rsp_status.is_success() {
136 match from_json::<_, DeviceCodeAuthorization>(&rsp_body) {
137 Ok(authorization) => {
138 Some((Ok(authorization), NextState::Finish))
139 }
140 Err(error) => Some((Err(error), NextState::Finish)),
141 }
142 } else {
143 match from_json::<_, DeviceCodeErrorResponse>(&rsp_body) {
144 Ok(error_rsp) => {
145 let next_state =
146 if error_rsp.error == "authorization_pending" {
147 NextState::Continue
148 } else {
149 NextState::Finish
150 };
151 Some((
152 Err(Error::new(ErrorKind::Credential, error_rsp)),
153 next_state,
154 ))
155 }
156 Err(error) => Some((Err(error), NextState::Finish)),
157 }
158 }
159 }
160 Err(error) => Some((Err(error), NextState::Finish)),
161 }
162 }
163 NextState::Finish => None,
164 }
165 },
166 ))
167 }
168}
169
170async fn post_form(
171 http_client: Arc<dyn HttpClient>,
172 url: &str,
173 form_body: String,
174) -> azure_core::Result<Response> {
175 let url = Url::parse(url)?;
176 let mut req = Request::new(url, Method::Post);
177 req.insert_header(
178 headers::CONTENT_TYPE,
179 content_type::APPLICATION_X_WWW_FORM_URLENCODED,
180 );
181 req.set_body(form_body);
182 http_client.execute_request(&req).await
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: this only type-checks when `T` is `Send`.
    fn assert_send<T: Send>(_value: T) {}

    /// The future returned by `start` must be `Send`. It is never polled here.
    #[test]
    fn ensure_that_start_is_send() {
        let scopes: &[&str] = &[];
        let future = start(azure_core::new_http_client(), "UNUSED", "UNUSED", scopes);
        assert_send(future);
    }
}