mz_frontegg_mock/handlers/
auth.rs1use crate::models::*;
11use crate::server::Context;
12use crate::utils::{
13 RefreshTokenTarget, generate_access_token, generate_refresh_token, get_user_groups,
14};
15use axum::{Json, extract::State, http::StatusCode};
16use mz_frontegg_auth::{ApiTokenResponse, ClaimTokenType};
17use std::sync::Arc;
18use std::sync::atomic::Ordering;
19
20pub async fn handle_post_auth_api_token(
21 State(context): State<Arc<Context>>,
22 Json(request): Json<ApiToken>,
23) -> Result<Json<ApiTokenResponse>, StatusCode> {
24 *context.auth_requests.lock().unwrap() += 1;
25
26 if !context.enable_auth.load(Ordering::Relaxed) {
27 return Err(StatusCode::UNAUTHORIZED);
28 }
29
30 let user_api_tokens = context.user_api_tokens.lock().unwrap();
31 let access_token = match user_api_tokens
32 .iter()
33 .find(|(token, _)| token.client_id == request.client_id && token.secret == request.secret)
34 {
35 Some((_, email)) => {
36 let users = context.users.lock().unwrap();
37 let user = users
38 .get(email)
39 .expect("API tokens are only created by logged in valid users.");
40 let groups = get_user_groups(&context, &user.id);
41 generate_access_token(
42 &context,
43 ClaimTokenType::UserApiToken,
44 request.client_id,
45 Some(email.to_owned()),
46 Some(user.id),
47 user.tenant_id,
48 user.roles.clone(),
49 groups,
50 None,
51 )
52 }
53 None => {
54 let tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
55 match tenant_api_tokens.iter().find(|(token, _)| {
56 token.client_id == request.client_id && token.secret == request.secret
57 }) {
58 Some((_, config)) => generate_access_token(
59 &context,
60 ClaimTokenType::TenantApiToken,
61 request.client_id,
62 None,
63 None,
64 config.tenant_id,
65 config.roles.clone(),
66 Vec::new(),
67 config.metadata.clone(),
68 ),
69 None => return Err(StatusCode::UNAUTHORIZED),
70 }
71 }
72 };
73 let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::ApiToken(request));
74 Ok(Json(ApiTokenResponse {
75 expires: "".to_string(),
76 expires_in: context.expires_in_secs,
77 access_token,
78 refresh_token,
79 }))
80}
81
82pub async fn handle_post_auth_user(
83 State(context): State<Arc<Context>>,
84 Json(request): Json<AuthUserRequest>,
85) -> Result<Json<ApiTokenResponse>, StatusCode> {
86 *context.auth_requests.lock().unwrap() += 1;
87
88 if !context.enable_auth.load(Ordering::Relaxed) {
89 return Err(StatusCode::UNAUTHORIZED);
90 }
91
92 let users = context.users.lock().unwrap();
93 let user = match users.get(&request.email) {
94 Some(user) if request.password == user.password => user.to_owned(),
95 _ => return Err(StatusCode::UNAUTHORIZED),
96 };
97 let groups = get_user_groups(&context, &user.id);
98 let access_token = generate_access_token(
99 &context,
100 ClaimTokenType::UserToken,
101 user.id,
102 Some(request.email.clone()),
103 Some(user.id),
104 user.tenant_id,
105 user.roles.clone(),
106 groups,
107 None,
108 );
109 let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::User(request));
110 Ok(Json(ApiTokenResponse {
111 expires: "".to_string(),
112 expires_in: context.expires_in_secs,
113 access_token,
114 refresh_token,
115 }))
116}
117
118pub async fn handle_post_token_refresh(
119 State(context): State<Arc<Context>>,
120 Json(previous_refresh_token): Json<RefreshTokenRequest>,
121) -> Result<Json<ApiTokenResponse>, StatusCode> {
122 *context.refreshes.lock().unwrap() += 1;
124
125 if !context.enable_auth.load(Ordering::Relaxed) {
126 return Err(StatusCode::UNAUTHORIZED);
127 }
128
129 let maybe_target = context
130 .refresh_tokens
131 .lock()
132 .unwrap()
133 .remove(&previous_refresh_token.refresh_token);
134 let Some(target) = maybe_target else {
135 return Err(StatusCode::UNAUTHORIZED);
136 };
137
138 match target {
139 RefreshTokenTarget::User(request) => {
140 handle_post_auth_user(State(context), Json(request)).await
141 }
142 RefreshTokenTarget::ApiToken(request) => {
143 handle_post_auth_api_token(State(context), Json(request)).await
144 }
145 }
146}