mz_frontegg_auth/client/
tokens.rs1use std::time::Instant;
11
12use mz_ore::instrument;
13use uuid::Uuid;
14
15use crate::metrics::Metrics;
16use crate::{Client, Error};
17
/// Name of the response header from which a Frontegg trace ID is read; when present its value
/// is attached to the log lines emitted for the token-exchange request.
const FRONTEGG_TRACE_ID_HEADER: &str = "frontegg-trace-id";

21impl Client {
22 #[instrument]
24 pub async fn exchange_client_secret_for_token(
25 &self,
26 request: ApiTokenArgs,
27 admin_api_token_url: &str,
28 metrics: &Metrics,
29 ) -> Result<ApiTokenResponse, Error> {
30 let name = "exchange_secret_for_token";
31 let histogram = metrics.request_duration_seconds.with_label_values(&[name]);
32
33 let start = Instant::now();
34 let response = self
35 .client
36 .post(admin_api_token_url)
37 .json(&request)
38 .send()
39 .await?;
40 let duration = start.elapsed();
41
42 histogram.observe(duration.as_secs_f64());
45
46 let status = response.status().to_string();
47 metrics
48 .http_request_count
49 .with_label_values(&[name, &status])
50 .inc();
51
52 let frontegg_trace_id = response
53 .headers()
54 .get(FRONTEGG_TRACE_ID_HEADER)
55 .and_then(|v| v.to_str().ok())
56 .map(|v| v.to_string());
57
58 match response.error_for_status_ref() {
59 Ok(_) => {
60 tracing::debug!(
61 ?request.client_id,
62 frontegg_trace_id,
63 ?duration,
64 "request success",
65 );
66 Ok(response.json().await?)
67 }
68 Err(e) => {
69 let body = response
70 .text()
71 .await
72 .unwrap_or_else(|_| "failed to deserialize body".to_string());
73 tracing::warn!(frontegg_trace_id, body, "request failed");
74 return Err(e.into());
75 }
76 }
77 }
78}
79
80#[derive(
81 Clone,
82 Debug,
83 Hash,
84 PartialEq,
85 Eq,
86 serde::Serialize,
87 serde::Deserialize
88)]
89#[serde(rename_all = "camelCase")]
90pub struct ApiTokenArgs {
91 pub client_id: Uuid,
92 pub secret: Uuid,
93}
94
95#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub struct ApiTokenResponse {
98 pub expires: String,
99 pub expires_in: i64,
100 pub access_token: String,
101 pub refresh_token: String,
102}
103
#[cfg(test)]
mod tests {
    use axum::http::StatusCode;
    use axum::{Router, routing::post};
    use mz_ore::metrics::MetricsRegistry;
    use mz_ore::{assert_err, assert_ok};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::net::TcpListener;
    use uuid::Uuid;

    use super::ApiTokenResponse;
    use crate::metrics::Metrics;
    use crate::{ApiTokenArgs, Client};

    /// Checks which HTTP status codes cause the client to retry a token exchange.
    ///
    /// The test server answers the first two requests with the status code taken from the URL
    /// path and returns a valid token payload from the third request on. A retried status
    /// therefore yields exactly three requests and a success. A status that is not retried
    /// yields one request and an error.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably because miri cannot run real sockets — confirm.
    async fn response_retries() {
        let count = Arc::new(AtomicUsize::new(0));
        let count_ = Arc::clone(&count);

        let app = Router::new().route(
            "/{:status_code}",
            post(
                |axum::extract::Path(code): axum::extract::Path<u16>| async move {
                    let cnt = count_.fetch_add(1, Ordering::Relaxed);

                    if cnt >= 2 {
                        let resp = ApiTokenResponse {
                            expires: "test".to_string(),
                            expires_in: 0,
                            access_token: "test".to_string(),
                            refresh_token: "test".to_string(),
                        };
                        Ok(serde_json::to_string(&resp).unwrap())
                    } else {
                        Err(StatusCode::from_u16(code).unwrap())
                    }
                },
            ),
        );

        // Port 0 lets the OS pick a free port; read the real address back from the listener.
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
        let tcp = TcpListener::bind(addr).await.expect("able to bind");
        let addr = tcp.local_addr().expect("valid addr");
        mz_ore::task::spawn(|| "test-server", async move {
            axum::serve(tcp, app.into_make_service()).await.unwrap();
        });

        let client = Client::default();

        /// Runs one exchange against `/{code}` and asserts how many requests the server saw.
        ///
        /// Resets the shared counter afterwards so each case starts from zero.
        async fn test_case(
            client: &Client,
            addr: SocketAddr,
            count: &AtomicUsize,
            code: u16,
            should_retry: bool,
        ) -> Result<(), String> {
            let registry = MetricsRegistry::new();
            let metrics = Metrics::register_into(&registry);

            let args = ApiTokenArgs {
                client_id: Uuid::new_v4(),
                secret: Uuid::new_v4(),
            };
            let exchange_result = client
                .exchange_client_secret_for_token(args, &format!("http://{addr}/{code}"), &metrics)
                .await
                .map(|_| ())
                .map_err(|e| e.to_string());

            let prev_count = count.swap(0, Ordering::Relaxed);
            let expected_count = if should_retry { 3 } else { 1 };
            assert_eq!(
                prev_count, expected_count,
                "unexpected number of requests for status {code}"
            );

            exchange_result
        }

        assert_ok!(test_case(&client, addr, &count, 500, true).await);
        assert_ok!(test_case(&client, addr, &count, 502, true).await);
        assert_ok!(test_case(&client, addr, &count, 429, true).await);
        assert_ok!(test_case(&client, addr, &count, 408, true).await);

        assert_err!(test_case(&client, addr, &count, 404, false).await);
        assert_err!(test_case(&client, addr, &count, 400, false).await);
    }
}