mz_frontegg_mock/handlers/
auth.rs
1use crate::models::*;
11use crate::server::Context;
12use crate::utils::{RefreshTokenTarget, generate_access_token, generate_refresh_token};
13use axum::{Json, extract::State, http::StatusCode};
14use mz_frontegg_auth::{ApiTokenResponse, ClaimTokenType};
15use std::sync::Arc;
16use std::sync::atomic::Ordering;
17
18pub async fn handle_post_auth_api_token(
19 State(context): State<Arc<Context>>,
20 Json(request): Json<ApiToken>,
21) -> Result<Json<ApiTokenResponse>, StatusCode> {
22 *context.auth_requests.lock().unwrap() += 1;
23
24 if !context.enable_auth.load(Ordering::Relaxed) {
25 return Err(StatusCode::UNAUTHORIZED);
26 }
27
28 let user_api_tokens = context.user_api_tokens.lock().unwrap();
29 let access_token = match user_api_tokens
30 .iter()
31 .find(|(token, _)| token.client_id == request.client_id && token.secret == request.secret)
32 {
33 Some((_, email)) => {
34 let users = context.users.lock().unwrap();
35 let user = users
36 .get(email)
37 .expect("API tokens are only created by logged in valid users.");
38 generate_access_token(
39 &context,
40 ClaimTokenType::UserApiToken,
41 request.client_id,
42 Some(email.to_owned()),
43 Some(user.id),
44 user.tenant_id,
45 user.roles.clone(),
46 None,
47 )
48 }
49 None => {
50 let tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
51 match tenant_api_tokens.iter().find(|(token, _)| {
52 token.client_id == request.client_id && token.secret == request.secret
53 }) {
54 Some((_, config)) => generate_access_token(
55 &context,
56 ClaimTokenType::TenantApiToken,
57 request.client_id,
58 None,
59 None,
60 config.tenant_id,
61 config.roles.clone(),
62 config.metadata.clone(),
63 ),
64 None => return Err(StatusCode::UNAUTHORIZED),
65 }
66 }
67 };
68 let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::ApiToken(request));
69 Ok(Json(ApiTokenResponse {
70 expires: "".to_string(),
71 expires_in: context.expires_in_secs,
72 access_token,
73 refresh_token,
74 }))
75}
76
77pub async fn handle_post_auth_user(
78 State(context): State<Arc<Context>>,
79 Json(request): Json<AuthUserRequest>,
80) -> Result<Json<ApiTokenResponse>, StatusCode> {
81 *context.auth_requests.lock().unwrap() += 1;
82
83 if !context.enable_auth.load(Ordering::Relaxed) {
84 return Err(StatusCode::UNAUTHORIZED);
85 }
86
87 let users = context.users.lock().unwrap();
88 let user = match users.get(&request.email) {
89 Some(user) if request.password == user.password => user.to_owned(),
90 _ => return Err(StatusCode::UNAUTHORIZED),
91 };
92 let access_token = generate_access_token(
93 &context,
94 ClaimTokenType::UserToken,
95 user.id,
96 Some(request.email.clone()),
97 Some(user.id),
98 user.tenant_id,
99 user.roles.clone(),
100 None,
101 );
102 let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::User(request));
103 Ok(Json(ApiTokenResponse {
104 expires: "".to_string(),
105 expires_in: context.expires_in_secs,
106 access_token,
107 refresh_token,
108 }))
109}
110
111pub async fn handle_post_token_refresh(
112 State(context): State<Arc<Context>>,
113 Json(previous_refresh_token): Json<RefreshTokenRequest>,
114) -> Result<Json<ApiTokenResponse>, StatusCode> {
115 *context.refreshes.lock().unwrap() += 1;
117
118 if !context.enable_auth.load(Ordering::Relaxed) {
119 return Err(StatusCode::UNAUTHORIZED);
120 }
121
122 let maybe_target = context
123 .refresh_tokens
124 .lock()
125 .unwrap()
126 .remove(&previous_refresh_token.refresh_token);
127 let Some(target) = maybe_target else {
128 return Err(StatusCode::UNAUTHORIZED);
129 };
130
131 match target {
132 RefreshTokenTarget::User(request) => {
133 handle_post_auth_user(State(context), Json(request)).await
134 }
135 RefreshTokenTarget::ApiToken(request) => {
136 handle_post_auth_api_token(State(context), Json(request)).await
137 }
138 }
139}