Skip to main content

mz_frontegg_auth/client/
tokens.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::time::Instant;
11
12use mz_ore::instrument;
13use uuid::Uuid;
14
15use crate::metrics::Metrics;
16use crate::{Client, Error};
17
/// Name of the response header in which Frontegg reports a trace id; the id
/// is attached to log events so a request can be correlated when debugging.
const FRONTEGG_TRACE_ID_HEADER: &str = "frontegg-trace-id";
20
21impl Client {
22    /// Exchanges a client id and secret for a jwt token.
23    #[instrument]
24    pub async fn exchange_client_secret_for_token(
25        &self,
26        request: ApiTokenArgs,
27        admin_api_token_url: &str,
28        metrics: &Metrics,
29    ) -> Result<ApiTokenResponse, Error> {
30        let name = "exchange_secret_for_token";
31        let histogram = metrics.request_duration_seconds.with_label_values(&[name]);
32
33        let start = Instant::now();
34        let response = self
35            .client
36            .post(admin_api_token_url)
37            .json(&request)
38            .send()
39            .await?;
40        let duration = start.elapsed();
41
42        // Authentication is on the blocking path for connection startup so we
43        // want to make sure it stays fast.
44        histogram.observe(duration.as_secs_f64());
45
46        let status = response.status().to_string();
47        metrics
48            .http_request_count
49            .with_label_values(&[name, &status])
50            .inc();
51
52        let frontegg_trace_id = response
53            .headers()
54            .get(FRONTEGG_TRACE_ID_HEADER)
55            .and_then(|v| v.to_str().ok())
56            .map(|v| v.to_string());
57
58        match response.error_for_status_ref() {
59            Ok(_) => {
60                tracing::debug!(
61                    ?request.client_id,
62                    frontegg_trace_id,
63                    ?duration,
64                    "request success",
65                );
66                Ok(response.json().await?)
67            }
68            Err(e) => {
69                let body = response
70                    .text()
71                    .await
72                    .unwrap_or_else(|_| "failed to deserialize body".to_string());
73                tracing::warn!(frontegg_trace_id, body, "request failed");
74                return Err(e.into());
75            }
76        }
77    }
78}
79
80#[derive(
81    Clone,
82    Debug,
83    Hash,
84    PartialEq,
85    Eq,
86    serde::Serialize,
87    serde::Deserialize
88)]
89#[serde(rename_all = "camelCase")]
90pub struct ApiTokenArgs {
91    pub client_id: Uuid,
92    pub secret: Uuid,
93}
94
95#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub struct ApiTokenResponse {
98    pub expires: String,
99    pub expires_in: i64,
100    pub access_token: String,
101    pub refresh_token: String,
102}
103
#[cfg(test)]
mod tests {
    use axum::http::StatusCode;
    use axum::{Router, routing::post};
    use mz_ore::metrics::MetricsRegistry;
    use mz_ore::{assert_err, assert_ok};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::net::TcpListener;
    use uuid::Uuid;

    use super::ApiTokenResponse;
    use crate::metrics::Metrics;
    use crate::{ApiTokenArgs, Client};

    /// Checks which error statuses cause the client to retry the token
    /// exchange, by counting the requests that reach a fake server.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `TLS_method` on OS `linux`
    async fn response_retries() {
        // Number of requests the fake server has received since the last
        // reset; shared between the server handler and the test cases.
        let count = Arc::new(AtomicUsize::new(0));
        let count_ = Arc::clone(&count);

        // Fake server that answers the first two requests (since the last
        // counter reset) with the status code taken from the request path and
        // every later request with a successful token response.
        let app = Router::new().route(
            "/{:status_code}",
            post(
                |axum::extract::Path(code): axum::extract::Path<u16>| async move {
                    let cnt = count_.fetch_add(1, Ordering::Relaxed);
                    println!("cnt: {cnt}");

                    let resp = ApiTokenResponse {
                        expires: "test".to_string(),
                        expires_in: 0,
                        access_token: "test".to_string(),
                        refresh_token: "test".to_string(),
                    };
                    let resp = serde_json::to_string(&resp).unwrap();

                    if cnt >= 2 {
                        Ok(resp)
                    } else {
                        Err(StatusCode::from_u16(code).unwrap())
                    }
                },
            ),
        );

        // Use port 0 to get a dynamically assigned port.
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
        let tcp = TcpListener::bind(addr).await.expect("able to bind");
        let addr = tcp.local_addr().expect("valid addr");
        mz_ore::task::spawn(|| "test-server", async move {
            axum::serve(tcp, app.into_make_service()).await.unwrap();
        });

        let client = Client::default();

        /// Runs one token exchange against the fake server at `addr`, asking
        /// it to fail with `code`, and asserts how many requests the server
        /// saw: three when `should_retry` (two failures, then the success),
        /// otherwise exactly one. The counter is reset to zero afterwards so
        /// consecutive cases are independent.
        async fn test_case(
            client: &Client,
            addr: &SocketAddr,
            count: &Arc<AtomicUsize>,
            code: u16,
            should_retry: bool,
        ) -> Result<(), String> {
            let registry = MetricsRegistry::new();
            let metrics = Metrics::register_into(&registry);

            let args = ApiTokenArgs {
                client_id: Uuid::new_v4(),
                secret: Uuid::new_v4(),
            };
            let exchange_result = client
                .exchange_client_secret_for_token(args, &format!("http://{addr}/{code}"), &metrics)
                .await
                .map(|_| ())
                .map_err(|e| e.to_string());

            // Read the request count and reset it in one step.
            let prev_count = count.swap(0, Ordering::Relaxed);
            let expected_count = if should_retry { 3 } else { 1 };
            assert_eq!(prev_count, expected_count);

            exchange_result
        }

        // Should get retried which results in eventual success.
        assert_ok!(test_case(&client, &addr, &count, 500, true).await);
        assert_ok!(test_case(&client, &addr, &count, 502, true).await);
        assert_ok!(test_case(&client, &addr, &count, 429, true).await);
        assert_ok!(test_case(&client, &addr, &count, 408, true).await);

        // Should not get retried, and thus return an error.
        assert_err!(test_case(&client, &addr, &count, 404, false).await);
        assert_err!(test_case(&client, &addr, &count, 400, false).await);
    }
}