// mz_frontegg_auth/client/tokens.rs
1use std::time::Instant;
11
12use mz_ore::instrument;
13use uuid::Uuid;
14
15use crate::metrics::Metrics;
16use crate::{Client, Error};
17
18const FRONTEGG_TRACE_ID_HEADER: &str = "frontegg-trace-id";
20
21impl Client {
22 #[instrument]
24 pub async fn exchange_client_secret_for_token(
25 &self,
26 request: ApiTokenArgs,
27 admin_api_token_url: &str,
28 metrics: &Metrics,
29 ) -> Result<ApiTokenResponse, Error> {
30 let name = "exchange_secret_for_token";
31 let histogram = metrics.request_duration_seconds.with_label_values(&[name]);
32
33 let start = Instant::now();
34 let response = self
35 .client
36 .post(admin_api_token_url)
37 .json(&request)
38 .send()
39 .await?;
40 let duration = start.elapsed();
41
42 histogram.observe(duration.as_secs_f64());
45
46 let status = response.status().to_string();
47 metrics
48 .http_request_count
49 .with_label_values(&[name, &status])
50 .inc();
51
52 let frontegg_trace_id = response
53 .headers()
54 .get(FRONTEGG_TRACE_ID_HEADER)
55 .and_then(|v| v.to_str().ok())
56 .map(|v| v.to_string());
57
58 match response.error_for_status_ref() {
59 Ok(_) => {
60 tracing::debug!(
61 ?request.client_id,
62 frontegg_trace_id,
63 ?duration,
64 "request success",
65 );
66 Ok(response.json().await?)
67 }
68 Err(e) => {
69 let body = response
70 .text()
71 .await
72 .unwrap_or_else(|_| "failed to deserialize body".to_string());
73 tracing::warn!(frontegg_trace_id, body, "request failed");
74 return Err(e.into());
75 }
76 }
77 }
78}
79
80#[derive(Clone, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
81#[serde(rename_all = "camelCase")]
82pub struct ApiTokenArgs {
83 pub client_id: Uuid,
84 pub secret: Uuid,
85}
86
87#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
88#[serde(rename_all = "camelCase")]
89pub struct ApiTokenResponse {
90 pub expires: String,
91 pub expires_in: i64,
92 pub access_token: String,
93 pub refresh_token: String,
94}
95
#[cfg(test)]
mod tests {
    use axum::http::StatusCode;
    use axum::{Router, routing::post};
    use mz_ore::metrics::MetricsRegistry;
    use mz_ore::{assert_err, assert_ok};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::net::TcpListener;
    use uuid::Uuid;

    use super::ApiTokenResponse;
    use crate::metrics::Metrics;
    use crate::{ApiTokenArgs, Client};

    /// Checks which response statuses cause the client to retry.
    ///
    /// The local test server answers the first two requests it receives with
    /// the status code taken from the request path, and every later request
    /// with a valid token response. A retried request therefore reaches the
    /// server three times; a non-retried one reaches it exactly once.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // presumably because the test binds a real TCP socket
    async fn response_retries() {
        let count = Arc::new(AtomicUsize::new(0));
        let count_ = Arc::clone(&count);

        let app = Router::new().route(
            "/{status_code}",
            post(
                |axum::extract::Path(code): axum::extract::Path<u16>| async move {
                    let cnt = count_.fetch_add(1, Ordering::Relaxed);
                    println!("cnt: {cnt}");

                    if cnt >= 2 {
                        let resp = ApiTokenResponse {
                            expires: "test".to_string(),
                            expires_in: 0,
                            access_token: "test".to_string(),
                            refresh_token: "test".to_string(),
                        };
                        Ok(serde_json::to_string(&resp).unwrap())
                    } else {
                        Err(StatusCode::from_u16(code).unwrap())
                    }
                },
            ),
        );

        // Bind to an ephemeral port so concurrent test runs don't collide.
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
        let tcp = TcpListener::bind(addr).await.expect("able to bind");
        let addr = tcp.local_addr().expect("valid addr");
        mz_ore::task::spawn(|| "test-server", async move {
            axum::serve(tcp, app.into_make_service()).await.unwrap();
        });

        let client = Client::default();

        // Exchanges fresh credentials against `http://{addr}/{code}`, asserts
        // the number of attempts the server observed, and resets the counter
        // for the next case.
        async fn test_case(
            client: &Client,
            addr: &SocketAddr,
            count: &Arc<AtomicUsize>,
            code: u16,
            should_retry: bool,
        ) -> Result<(), String> {
            let registry = MetricsRegistry::new();
            let metrics = Metrics::register_into(&registry);

            let args = ApiTokenArgs {
                client_id: Uuid::new_v4(),
                secret: Uuid::new_v4(),
            };
            let exchange_result = client
                .exchange_client_secret_for_token(args, &format!("http://{addr}/{code}"), &metrics)
                .await
                .map(|_| ())
                .map_err(|e| e.to_string());

            // Two failures followed by a success when retried; one attempt otherwise.
            let attempts = count.swap(0, Ordering::Relaxed);
            let expected_attempts = if should_retry { 3 } else { 1 };
            assert_eq!(attempts, expected_attempts);

            exchange_result
        }

        // Transient statuses: expected to be retried until the server succeeds.
        assert_ok!(test_case(&client, &addr, &count, 500, true).await);
        assert_ok!(test_case(&client, &addr, &count, 502, true).await);
        assert_ok!(test_case(&client, &addr, &count, 429, true).await);
        assert_ok!(test_case(&client, &addr, &count, 408, true).await);

        // Other client errors: expected to fail after a single attempt.
        assert_err!(test_case(&client, &addr, &count, 404, false).await);
        assert_err!(test_case(&client, &addr, &count, 400, false).await);
    }
}