mz_frontegg_mock/handlers/
auth.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use crate::models::*;
11use crate::server::Context;
12use crate::utils::{RefreshTokenTarget, generate_access_token, generate_refresh_token};
13use axum::{Json, extract::State, http::StatusCode};
14use mz_frontegg_auth::{ApiTokenResponse, ClaimTokenType};
15use std::sync::Arc;
16use std::sync::atomic::Ordering;
17
18pub async fn handle_post_auth_api_token(
19    State(context): State<Arc<Context>>,
20    Json(request): Json<ApiToken>,
21) -> Result<Json<ApiTokenResponse>, StatusCode> {
22    *context.auth_requests.lock().unwrap() += 1;
23
24    if !context.enable_auth.load(Ordering::Relaxed) {
25        return Err(StatusCode::UNAUTHORIZED);
26    }
27
28    let user_api_tokens = context.user_api_tokens.lock().unwrap();
29    let access_token = match user_api_tokens
30        .iter()
31        .find(|(token, _)| token.client_id == request.client_id && token.secret == request.secret)
32    {
33        Some((_, email)) => {
34            let users = context.users.lock().unwrap();
35            let user = users
36                .get(email)
37                .expect("API tokens are only created by logged in valid users.");
38            generate_access_token(
39                &context,
40                ClaimTokenType::UserApiToken,
41                request.client_id,
42                Some(email.to_owned()),
43                Some(user.id),
44                user.tenant_id,
45                user.roles.clone(),
46                None,
47            )
48        }
49        None => {
50            let tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
51            match tenant_api_tokens.iter().find(|(token, _)| {
52                token.client_id == request.client_id && token.secret == request.secret
53            }) {
54                Some((_, config)) => generate_access_token(
55                    &context,
56                    ClaimTokenType::TenantApiToken,
57                    request.client_id,
58                    None,
59                    None,
60                    config.tenant_id,
61                    config.roles.clone(),
62                    config.metadata.clone(),
63                ),
64                None => return Err(StatusCode::UNAUTHORIZED),
65            }
66        }
67    };
68    let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::ApiToken(request));
69    Ok(Json(ApiTokenResponse {
70        expires: "".to_string(),
71        expires_in: context.expires_in_secs,
72        access_token,
73        refresh_token,
74    }))
75}
76
77pub async fn handle_post_auth_user(
78    State(context): State<Arc<Context>>,
79    Json(request): Json<AuthUserRequest>,
80) -> Result<Json<ApiTokenResponse>, StatusCode> {
81    *context.auth_requests.lock().unwrap() += 1;
82
83    if !context.enable_auth.load(Ordering::Relaxed) {
84        return Err(StatusCode::UNAUTHORIZED);
85    }
86
87    let users = context.users.lock().unwrap();
88    let user = match users.get(&request.email) {
89        Some(user) if request.password == user.password => user.to_owned(),
90        _ => return Err(StatusCode::UNAUTHORIZED),
91    };
92    let access_token = generate_access_token(
93        &context,
94        ClaimTokenType::UserToken,
95        user.id,
96        Some(request.email.clone()),
97        Some(user.id),
98        user.tenant_id,
99        user.roles.clone(),
100        None,
101    );
102    let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::User(request));
103    Ok(Json(ApiTokenResponse {
104        expires: "".to_string(),
105        expires_in: context.expires_in_secs,
106        access_token,
107        refresh_token,
108    }))
109}
110
111pub async fn handle_post_token_refresh(
112    State(context): State<Arc<Context>>,
113    Json(previous_refresh_token): Json<RefreshTokenRequest>,
114) -> Result<Json<ApiTokenResponse>, StatusCode> {
115    // Always count refresh attempts, even if enable_refresh is false.
116    *context.refreshes.lock().unwrap() += 1;
117
118    if !context.enable_auth.load(Ordering::Relaxed) {
119        return Err(StatusCode::UNAUTHORIZED);
120    }
121
122    let maybe_target = context
123        .refresh_tokens
124        .lock()
125        .unwrap()
126        .remove(&previous_refresh_token.refresh_token);
127    let Some(target) = maybe_target else {
128        return Err(StatusCode::UNAUTHORIZED);
129    };
130
131    match target {
132        RefreshTokenTarget::User(request) => {
133            handle_post_auth_user(State(context), Json(request)).await
134        }
135        RefreshTokenTarget::ApiToken(request) => {
136            handle_post_auth_api_token(State(context), Json(request)).await
137        }
138    }
139}