mz_frontegg_auth/client/
tokens.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::time::Instant;
11
12use mz_ore::instrument;
13use uuid::Uuid;
14
15use crate::metrics::Metrics;
16use crate::{Client, Error};
17
18/// Frontegg includes a trace id in the headers of a response to aid in debugging.
19const FRONTEGG_TRACE_ID_HEADER: &str = "frontegg-trace-id";
20
21impl Client {
22    /// Exchanges a client id and secret for a jwt token.
23    #[instrument]
24    pub async fn exchange_client_secret_for_token(
25        &self,
26        request: ApiTokenArgs,
27        admin_api_token_url: &str,
28        metrics: &Metrics,
29    ) -> Result<ApiTokenResponse, Error> {
30        let name = "exchange_secret_for_token";
31        let histogram = metrics.request_duration_seconds.with_label_values(&[name]);
32
33        let start = Instant::now();
34        let response = self
35            .client
36            .post(admin_api_token_url)
37            .json(&request)
38            .send()
39            .await?;
40        let duration = start.elapsed();
41
42        // Authentication is on the blocking path for connection startup so we
43        // want to make sure it stays fast.
44        histogram.observe(duration.as_secs_f64());
45
46        let status = response.status().to_string();
47        metrics
48            .http_request_count
49            .with_label_values(&[name, &status])
50            .inc();
51
52        let frontegg_trace_id = response
53            .headers()
54            .get(FRONTEGG_TRACE_ID_HEADER)
55            .and_then(|v| v.to_str().ok())
56            .map(|v| v.to_string());
57
58        match response.error_for_status_ref() {
59            Ok(_) => {
60                tracing::debug!(
61                    ?request.client_id,
62                    frontegg_trace_id,
63                    ?duration,
64                    "request success",
65                );
66                Ok(response.json().await?)
67            }
68            Err(e) => {
69                let body = response
70                    .text()
71                    .await
72                    .unwrap_or_else(|_| "failed to deserialize body".to_string());
73                tracing::warn!(frontegg_trace_id, body, "request failed");
74                return Err(e.into());
75            }
76        }
77    }
78}
79
80#[derive(Clone, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
81#[serde(rename_all = "camelCase")]
82pub struct ApiTokenArgs {
83    pub client_id: Uuid,
84    pub secret: Uuid,
85}
86
87#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
88#[serde(rename_all = "camelCase")]
89pub struct ApiTokenResponse {
90    pub expires: String,
91    pub expires_in: i64,
92    pub access_token: String,
93    pub refresh_token: String,
94}
95
#[cfg(test)]
mod tests {
    use axum::http::StatusCode;
    use axum::{Router, routing::post};
    use mz_ore::metrics::MetricsRegistry;
    use mz_ore::{assert_err, assert_ok};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::net::TcpListener;
    use uuid::Uuid;

    use super::ApiTokenResponse;
    use crate::metrics::Metrics;
    use crate::{ApiTokenArgs, Client};

    /// Checks retry behavior of `exchange_client_secret_for_token`.
    ///
    /// Requests failing with 500, 502, 429 or 408 are expected to be retried
    /// until they succeed. Requests failing with 404 or 400 are expected to
    /// return an error after a single attempt.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `TLS_method` on OS `linux`
    async fn response_retries() {
        let count = Arc::new(AtomicUsize::new(0));
        let count_ = Arc::clone(&count);

        // Fake server: fails with the status code given in the path until it
        // has seen two requests since the counter was last reset, then
        // succeeds with a valid token response.
        let app = Router::new().route(
            "/{status_code}",
            post(
                |axum::extract::Path(code): axum::extract::Path<u16>| async move {
                    let cnt = count_.fetch_add(1, Ordering::Relaxed);
                    println!("cnt: {cnt}");

                    if cnt >= 2 {
                        let resp = ApiTokenResponse {
                            expires: "test".to_string(),
                            expires_in: 0,
                            access_token: "test".to_string(),
                            refresh_token: "test".to_string(),
                        };
                        Ok(serde_json::to_string(&resp).unwrap())
                    } else {
                        Err(StatusCode::from_u16(code).unwrap())
                    }
                },
            ),
        );

        // Use port 0 to get a dynamically assigned port.
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
        let tcp = TcpListener::bind(addr).await.expect("able to bind");
        let addr = tcp.local_addr().expect("valid addr");
        mz_ore::task::spawn(|| "test-server", async move {
            axum::serve(tcp, app.into_make_service()).await.unwrap();
        });

        let client = Client::default();

        /// Performs one exchange against the fake server with status `code`.
        /// Asserts how many requests the server saw, resets the counter, and
        /// returns the exchange result.
        async fn test_case(
            client: &Client,
            addr: &SocketAddr,
            count: &Arc<AtomicUsize>,
            code: u16,
            should_retry: bool,
        ) -> Result<(), String> {
            let registry = MetricsRegistry::new();
            let metrics = Metrics::register_into(&registry);

            let args = ApiTokenArgs {
                client_id: Uuid::new_v4(),
                secret: Uuid::new_v4(),
            };
            let exchange_result = client
                .exchange_client_secret_for_token(args, &format!("http://{addr}/{code}"), &metrics)
                .await
                .map(|_| ())
                .map_err(|e| e.to_string());
            let prev_count = count.swap(0, Ordering::Relaxed);
            // Two failures plus the eventual success when retried; otherwise
            // exactly one attempt.
            let expected_count = if should_retry { 3 } else { 1 };
            assert_eq!(prev_count, expected_count);

            exchange_result
        }

        // Should get retried which results in eventual success.
        assert_ok!(test_case(&client, &addr, &count, 500, true).await);
        assert_ok!(test_case(&client, &addr, &count, 502, true).await);
        assert_ok!(test_case(&client, &addr, &count, 429, true).await);
        assert_ok!(test_case(&client, &addr, &count, 408, true).await);

        // Should not get retried, and thus return an error.
        assert_err!(test_case(&client, &addr, &count, 404, false).await);
        assert_err!(test_case(&client, &addr, &count, 400, false).await);
    }
}