use std::time::{Duration, SystemTime};

use jsonwebtoken::jwk::JwkSet;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use mz_frontegg_auth::{AppPassword, Claims};
use reqwest::{Method, RequestBuilder};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use url::Url;
32
33use crate::config::{ClientBuilder, ClientConfig};
34use crate::error::{ApiError, Error};
35
36pub mod app_password;
37pub mod role;
38pub mod user;
39
/// Path segments of the endpoint that exchanges an email and password for an
/// access token (see `Authentication::Credentials`).
const CREDENTIALS_AUTH_PATH: [&str; 5] = ["identity", "resources", "auth", "v1", "user"];

/// Path segments of the endpoint that exchanges an app password's client ID and
/// secret key for an access token (see `Authentication::AppPassword`).
const APP_PASSWORD_AUTH_PATH: [&str; 5] = ["identity", "resources", "auth", "v1", "api-token"];

/// Path segments of the endpoint that exchanges a refresh token for a new
/// access token.
///
/// NOTE(review): this api-token refresh path is used regardless of which
/// `Authentication` method created the session — confirm it is also the
/// correct endpoint for sessions started with email/password credentials.
const REFRESH_AUTH_PATH: [&str; 7] = [
    "identity",
    "resources",
    "auth",
    "v1",
    "api-token",
    "token",
    "refresh",
];
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(rename_all = "camelCase")]
55struct AuthenticationResponse {
56 access_token: String,
57 expires: String,
58 expires_in: i64,
59 refresh_token: String,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(rename_all = "camelCase")]
64struct CredentialsAuthenticationRequest<'a> {
65 email: &'a str,
66 password: &'a str,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71struct AppPasswordAuthenticationRequest<'a> {
72 client_id: &'a str,
73 secret: &'a str,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
77#[serde(rename_all = "camelCase")]
78struct RefreshRequest<'a> {
79 refresh_token: &'a str,
80}
81
/// A cached authentication session: the current access token, when it should
/// be refreshed, and the refresh token used to obtain the next one.
///
/// `Debug` is implemented by hand so that the access and refresh tokens are
/// redacted and never end up in logs or error messages.
#[derive(Clone)]
pub(crate) struct Auth {
    /// Bearer access token attached to authenticated requests.
    token: String,
    /// Point in time after which `token` should be refreshed.
    refresh_at: SystemTime,
    /// Token exchanged for a new access token once `refresh_at` has passed.
    refresh_token: String,
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Tokens are bearer credentials: printing them would leak a usable session.
        f.debug_struct("Auth")
            .field("token", &"<redacted>")
            .field("refresh_at", &self.refresh_at)
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}
89
90pub enum Authentication {
97 Credentials(Credentials),
100 AppPassword(AppPassword),
103}
104
/// An email address and password pair, used by
/// [`Authentication::Credentials`].
pub struct Credentials {
    /// The user's email address.
    pub email: String,
    /// The user's password.
    pub password: String,
}
114
115pub struct Client {
123 pub(crate) inner: reqwest::Client,
124 pub(crate) authentication: Authentication,
125 pub(crate) endpoint: Url,
126 pub(crate) auth: Mutex<Option<Auth>>,
127}
128
129impl Client {
130 pub fn new(config: ClientConfig) -> Client {
132 ClientBuilder::default().build(config)
133 }
134
135 pub fn builder() -> ClientBuilder {
138 ClientBuilder::default()
139 }
140
141 fn build_request<P>(&self, method: Method, path: P) -> RequestBuilder
143 where
144 P: IntoIterator,
145 P::Item: AsRef<str>,
146 {
147 let mut url = self.endpoint.clone();
148 url.path_segments_mut()
149 .expect("builder validated URL can be a base")
150 .clear()
151 .extend(path);
152 self.inner.request(method, url)
153 }
154
155 async fn send_request<T>(&self, req: RequestBuilder) -> Result<T, Error>
157 where
158 T: DeserializeOwned,
159 {
160 let token = self.auth().await?;
161 let req = req.bearer_auth(token);
162 self.send_unauthenticated_request(req).await
163 }
164
165 async fn send_unauthenticated_request<T>(&self, req: RequestBuilder) -> Result<T, Error>
166 where
167 T: DeserializeOwned,
168 {
169 #[derive(Deserialize)]
170 #[serde(rename_all = "camelCase")]
171 struct ErrorResponse {
172 #[serde(default)]
173 message: Option<String>,
174 #[serde(default)]
175 errors: Vec<String>,
176 }
177
178 let res = req.send().await?;
179 let status_code = res.status();
180 if status_code.is_success() {
181 Ok(res.json().await?)
182 } else {
183 match res.json::<ErrorResponse>().await {
184 Ok(e) => {
185 let mut messages = e.errors;
186 messages.extend(e.message);
187 Err(Error::Api(ApiError {
188 status_code,
189 messages,
190 }))
191 }
192 Err(_) => Err(Error::Api(ApiError {
193 status_code,
194 messages: vec!["unable to decode error details".into()],
195 })),
196 }
197 }
198 }
199
200 pub async fn auth(&self) -> Result<String, Error> {
203 let mut auth = self.auth.lock().await;
204 let mut req;
205
206 match &*auth {
207 Some(auth) => {
208 if SystemTime::now() < auth.refresh_at {
209 return Ok(auth.token.clone());
210 } else {
211 req = self.build_request(Method::POST, REFRESH_AUTH_PATH);
213 let refresh_request = RefreshRequest {
214 refresh_token: auth.refresh_token.as_str(),
215 };
216 req = req.json(&refresh_request);
217 }
218 }
219 None => {
220 match &self.authentication {
221 Authentication::Credentials(credentials) => {
222 req = self.build_request(Method::POST, CREDENTIALS_AUTH_PATH);
224
225 let authentication_request = CredentialsAuthenticationRequest {
226 email: &credentials.email,
227 password: &credentials.password,
228 };
229 req = req.json(&authentication_request);
230 }
231 Authentication::AppPassword(app_password) => {
232 req = self.build_request(Method::POST, APP_PASSWORD_AUTH_PATH);
233
234 let authentication_request = AppPasswordAuthenticationRequest {
235 client_id: &app_password.client_id.to_string(),
236 secret: &app_password.secret_key.to_string(),
237 };
238 req = req.json(&authentication_request);
239 }
240 }
241 }
242 }
243
244 let res: AuthenticationResponse = self.send_unauthenticated_request(req).await?;
246
247 *auth = Some(Auth {
248 token: res.access_token.clone(),
249 refresh_at: SystemTime::now()
251 + (Duration::from_secs(res.expires_in.try_into().unwrap()) / 2),
252 refresh_token: res.refresh_token,
253 });
254 Ok(res.access_token)
255 }
256
257 async fn get_jwks(&self) -> Result<JwkSet, Error> {
259 let well_known = vec![".well-known", "jwks.json"];
260 let req = self.build_request(Method::GET, well_known);
261 let jwks: JwkSet = self.send_request(req).await?;
262 Ok(jwks)
263 }
264
265 pub async fn claims(&self) -> Result<Claims, Error> {
268 let jwks = self.get_jwks().await.map_err(|_| Error::FetchingJwks)?;
269 let jwk = jwks.keys.first().ok_or_else(|| Error::EmptyJwks)?;
270 let token = self.auth().await?;
271
272 let mut validation = Validation::new(Algorithm::RS256);
273
274 validation.validate_aud = false;
288
289 let token_data = decode::<Claims>(
290 &token,
291 &DecodingKey::from_jwk(jwk).map_err(|_| Error::ConvertingJwks)?,
292 &validation,
293 )
294 .map_err(Error::DecodingClaims)?;
295
296 Ok(token_data.claims)
297 }
298}