Skip to main content

mz_frontegg_mock/handlers/
auth.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use crate::models::*;
11use crate::server::Context;
12use crate::utils::{
13    RefreshTokenTarget, generate_access_token, generate_refresh_token, get_user_groups,
14};
15use axum::{Json, extract::State, http::StatusCode};
16use mz_frontegg_auth::{ApiTokenResponse, ClaimTokenType};
17use std::sync::Arc;
18use std::sync::atomic::Ordering;
19
20pub async fn handle_post_auth_api_token(
21    State(context): State<Arc<Context>>,
22    Json(request): Json<ApiToken>,
23) -> Result<Json<ApiTokenResponse>, StatusCode> {
24    *context.auth_requests.lock().unwrap() += 1;
25
26    if !context.enable_auth.load(Ordering::Relaxed) {
27        return Err(StatusCode::UNAUTHORIZED);
28    }
29
30    let user_api_tokens = context.user_api_tokens.lock().unwrap();
31    let access_token = match user_api_tokens
32        .iter()
33        .find(|(token, _)| token.client_id == request.client_id && token.secret == request.secret)
34    {
35        Some((_, email)) => {
36            let users = context.users.lock().unwrap();
37            let user = users
38                .get(email)
39                .expect("API tokens are only created by logged in valid users.");
40            let groups = get_user_groups(&context, &user.id);
41            generate_access_token(
42                &context,
43                ClaimTokenType::UserApiToken,
44                request.client_id,
45                Some(email.to_owned()),
46                Some(user.id),
47                user.tenant_id,
48                user.roles.clone(),
49                groups,
50                None,
51            )
52        }
53        None => {
54            let tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
55            match tenant_api_tokens.iter().find(|(token, _)| {
56                token.client_id == request.client_id && token.secret == request.secret
57            }) {
58                Some((_, config)) => generate_access_token(
59                    &context,
60                    ClaimTokenType::TenantApiToken,
61                    request.client_id,
62                    None,
63                    None,
64                    config.tenant_id,
65                    config.roles.clone(),
66                    Vec::new(),
67                    config.metadata.clone(),
68                ),
69                None => return Err(StatusCode::UNAUTHORIZED),
70            }
71        }
72    };
73    let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::ApiToken(request));
74    Ok(Json(ApiTokenResponse {
75        expires: "".to_string(),
76        expires_in: context.expires_in_secs,
77        access_token,
78        refresh_token,
79    }))
80}
81
82pub async fn handle_post_auth_user(
83    State(context): State<Arc<Context>>,
84    Json(request): Json<AuthUserRequest>,
85) -> Result<Json<ApiTokenResponse>, StatusCode> {
86    *context.auth_requests.lock().unwrap() += 1;
87
88    if !context.enable_auth.load(Ordering::Relaxed) {
89        return Err(StatusCode::UNAUTHORIZED);
90    }
91
92    let users = context.users.lock().unwrap();
93    let user = match users.get(&request.email) {
94        Some(user) if request.password == user.password => user.to_owned(),
95        _ => return Err(StatusCode::UNAUTHORIZED),
96    };
97    let groups = get_user_groups(&context, &user.id);
98    let access_token = generate_access_token(
99        &context,
100        ClaimTokenType::UserToken,
101        user.id,
102        Some(request.email.clone()),
103        Some(user.id),
104        user.tenant_id,
105        user.roles.clone(),
106        groups,
107        None,
108    );
109    let refresh_token = generate_refresh_token(&context, RefreshTokenTarget::User(request));
110    Ok(Json(ApiTokenResponse {
111        expires: "".to_string(),
112        expires_in: context.expires_in_secs,
113        access_token,
114        refresh_token,
115    }))
116}
117
118pub async fn handle_post_token_refresh(
119    State(context): State<Arc<Context>>,
120    Json(previous_refresh_token): Json<RefreshTokenRequest>,
121) -> Result<Json<ApiTokenResponse>, StatusCode> {
122    // Always count refresh attempts, even if enable_refresh is false.
123    *context.refreshes.lock().unwrap() += 1;
124
125    if !context.enable_auth.load(Ordering::Relaxed) {
126        return Err(StatusCode::UNAUTHORIZED);
127    }
128
129    let maybe_target = context
130        .refresh_tokens
131        .lock()
132        .unwrap()
133        .remove(&previous_refresh_token.refresh_token);
134    let Some(target) = maybe_target else {
135        return Err(StatusCode::UNAUTHORIZED);
136    };
137
138    match target {
139        RefreshTokenTarget::User(request) => {
140            handle_post_auth_user(State(context), Json(request)).await
141        }
142        RefreshTokenTarget::ApiToken(request) => {
143            handle_post_auth_api_token(State(context), Json(request)).await
144        }
145    }
146}