mz_frontegg_mock/handlers/
user.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use crate::models::*;
11use crate::server::Context;
12use crate::utils::get_user_roles;
13use crate::utils::{decode_access_token, generate_access_token};
14use axum::{
15    Json,
16    extract::{Path, Query, State},
17    http::StatusCode,
18};
19use axum_extra::TypedHeader;
20use axum_extra::headers::Authorization;
21use axum_extra::headers::authorization::Bearer;
22use chrono::Utc;
23use jsonwebtoken::TokenData;
24use mz_frontegg_auth::{ClaimTokenType, Claims};
25use std::collections::BTreeMap;
26use std::sync::Arc;
27use uuid::Uuid;
28
29// https://docs.frontegg.com/reference/userscontrollerv2_getuserprofile
30pub async fn handle_get_user_profile(
31    State(context): State<Arc<Context>>,
32    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
33) -> Result<Json<UserProfileResponse>, StatusCode> {
34    let claims: Claims = match decode_access_token(&context, authorization.token()) {
35        Ok(TokenData { claims, .. }) => claims,
36        Err(_) => return Err(StatusCode::UNAUTHORIZED),
37    };
38    Ok(Json(UserProfileResponse {
39        tenant_id: claims.tenant_id,
40    }))
41}
42
43// https://docs.frontegg.com/reference/userapitokensv1controller_createtenantapitoken
44pub async fn handle_post_user_api_token(
45    State(context): State<Arc<Context>>,
46    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
47    Json(request): Json<UserApiTokenRequest>,
48) -> Result<(StatusCode, Json<UserApiTokenResponse>), StatusCode> {
49    let claims: Claims = match decode_access_token(&context, authorization.token()) {
50        Ok(TokenData { claims, .. }) => claims,
51        Err(_) => return Err(StatusCode::UNAUTHORIZED),
52    };
53    let mut tokens = context.user_api_tokens.lock().unwrap();
54    let new_token = ApiToken {
55        client_id: Uuid::new_v4(),
56        secret: Uuid::new_v4(),
57        description: request.description.clone(),
58        created_at: Utc::now(),
59    };
60    tokens.insert(new_token.clone(), claims.email.unwrap());
61
62    let response = UserApiTokenResponse {
63        client_id: new_token.client_id.to_string(),
64        description: new_token.description.unwrap(),
65        created_at: new_token.created_at,
66        secret: new_token.secret.to_string(),
67    };
68
69    Ok((StatusCode::CREATED, Json(response)))
70}
71
72// https://docs.frontegg.com/reference/userapitokensv1controller_getapitokens
73pub async fn handle_list_user_api_tokens(
74    State(context): State<Arc<Context>>,
75    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
76) -> Result<Json<Vec<UserApiTokenResponse>>, StatusCode> {
77    let claims = match decode_access_token(&context, authorization.token()) {
78        Ok(TokenData { claims, .. }) => claims,
79        Err(_) => return Err(StatusCode::UNAUTHORIZED),
80    };
81
82    let user_api_tokens = context.user_api_tokens.lock().unwrap();
83    let tokens: Vec<UserApiTokenResponse> = user_api_tokens
84        .iter()
85        .filter(|(_, email)| *email == claims.email.as_ref().unwrap())
86        .map(|(token, _)| UserApiTokenResponse {
87            client_id: token.client_id.to_string(),
88            description: token.description.clone().unwrap_or_default(),
89            created_at: token.created_at,
90            secret: token.secret.to_string(),
91        })
92        .collect();
93
94    Ok(Json(tokens))
95}
96
97// https://docs.frontegg.com/reference/userapitokensv1controller_deleteapitoken
98pub async fn handle_delete_user_api_token(
99    State(context): State<Arc<Context>>,
100    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
101    Path(token_id): Path<Uuid>,
102) -> StatusCode {
103    let claims = match decode_access_token(&context, authorization.token()) {
104        Ok(TokenData { claims, .. }) => claims,
105        Err(_) => return StatusCode::UNAUTHORIZED,
106    };
107
108    let mut user_api_tokens = context.user_api_tokens.lock().unwrap();
109
110    let removed = user_api_tokens
111        .iter()
112        .find(|(token, email)| {
113            token.client_id == token_id && *email == claims.email.as_ref().unwrap()
114        })
115        .map(|(token, _)| token.clone());
116
117    if let Some(token_to_remove) = removed {
118        user_api_tokens.remove(&token_to_remove);
119        StatusCode::OK
120    } else {
121        StatusCode::NOT_FOUND
122    }
123}
124
125// https://docs.frontegg.com/reference/tenantapitokensv1controller_gettenantsapitokens
126pub async fn handle_list_tenant_api_tokens(
127    State(context): State<Arc<Context>>,
128    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
129) -> Result<Json<Vec<TenantApiTokenResponse>>, StatusCode> {
130    let _claims = match decode_access_token(&context, authorization.token()) {
131        Ok(TokenData { claims, .. }) => claims,
132        Err(_) => return Err(StatusCode::UNAUTHORIZED),
133    };
134
135    let tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
136    let tokens: Vec<TenantApiTokenResponse> = tenant_api_tokens
137        .iter()
138        .map(|(api_token, config)| TenantApiTokenResponse {
139            client_id: api_token.client_id,
140            description: api_token.description.clone().unwrap_or_default(),
141            secret: api_token.secret,
142            created_by_user_id: config.created_by_user_id,
143            metadata: config.metadata.clone(),
144            created_at: config.created_at,
145            role_ids: config.roles.clone(),
146        })
147        .collect();
148
149    Ok(Json(tokens))
150}
151
152// https://docs.frontegg.com/reference/tenantapitokensv1controller_createtenantapitoken
153pub async fn handle_create_tenant_api_token(
154    State(context): State<Arc<Context>>,
155    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
156    Json(request): Json<CreateTenantApiTokenRequest>,
157) -> Result<(StatusCode, Json<TenantApiTokenResponse>), StatusCode> {
158    let claims = match decode_access_token(&context, authorization.token()) {
159        Ok(TokenData { claims, .. }) => claims,
160        Err(_) => return Err(StatusCode::UNAUTHORIZED),
161    };
162
163    let new_token = ApiToken {
164        client_id: Uuid::new_v4(),
165        secret: Uuid::new_v4(),
166        description: Some(request.description.clone()),
167        created_at: Utc::now(),
168    };
169
170    let config = TenantApiTokenConfig {
171        tenant_id: claims.tenant_id,
172        metadata: request.metadata.clone(),
173        roles: request.role_ids.clone(),
174        description: Some(request.description.clone()),
175        created_by_user_id: claims.sub,
176        created_at: new_token.created_at,
177    };
178
179    let mut tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
180    tenant_api_tokens.insert(new_token.clone(), config.clone());
181
182    let _access_token = generate_access_token(
183        &context,
184        ClaimTokenType::TenantApiToken,
185        new_token.client_id,
186        None,
187        None,
188        config.tenant_id,
189        config.roles.clone(),
190        config.metadata.clone(),
191    );
192
193    let response = TenantApiTokenResponse {
194        client_id: new_token.client_id,
195        description: new_token.description.unwrap(),
196        secret: new_token.secret,
197        created_by_user_id: config.created_by_user_id,
198        metadata: config.metadata,
199        created_at: new_token.created_at,
200        role_ids: config.roles,
201    };
202
203    Ok((StatusCode::CREATED, Json(response)))
204}
205
206// https://docs.frontegg.com/reference/tenantapitokensv1controller_deletetenantapitoken
207pub async fn handle_delete_tenant_api_token(
208    State(context): State<Arc<Context>>,
209    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
210    Path(token_id): Path<Uuid>,
211) -> StatusCode {
212    let _claims = match decode_access_token(&context, authorization.token()) {
213        Ok(TokenData { claims, .. }) => claims,
214        Err(_) => return StatusCode::UNAUTHORIZED,
215    };
216
217    let mut tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
218
219    let token_to_remove = tenant_api_tokens
220        .keys()
221        .find(|token| token.client_id == token_id)
222        .cloned();
223
224    if let Some(token) = token_to_remove {
225        tenant_api_tokens.remove(&token);
226        StatusCode::OK
227    } else {
228        StatusCode::NOT_FOUND
229    }
230}
231
232// https://docs.frontegg.com/reference/userscontrollerv2_getuserbyid
233pub async fn handle_get_user(
234    State(context): State<Arc<Context>>,
235    Path(user_id): Path<Uuid>,
236) -> Result<Json<UserResponse>, StatusCode> {
237    let users = context.users.lock().unwrap();
238    let role_mapping: BTreeMap<String, UserRole> = context
239        .roles
240        .iter()
241        .map(|role| (role.id.clone(), role.clone()))
242        .collect();
243
244    match users.iter().find(|(_, user)| user.id == user_id) {
245        Some((_, user)) => {
246            let roles = get_user_roles(&user.roles, &role_mapping);
247
248            let user_response = UserResponse {
249                id: user.id,
250                email: user.email.clone(),
251                verified: user.verified.unwrap_or(true),
252                metadata: user.metadata.clone().unwrap_or_default(),
253                provider: user.auth_provider.clone().unwrap_or_default(),
254                roles,
255            };
256
257            Ok(Json(user_response))
258        }
259        None => Err(StatusCode::NOT_FOUND),
260    }
261}
262
263// https://docs.frontegg.com/reference/userscontrollerv2_createuser
264pub async fn handle_create_user(
265    State(context): State<Arc<Context>>,
266    Json(new_user): Json<UserCreate>,
267) -> Result<(StatusCode, Json<UserResponse>), StatusCode> {
268    let mut users = context.users.lock().unwrap();
269    let role_mapping: BTreeMap<String, UserRole> = context
270        .roles
271        .iter()
272        .map(|role| (role.id.clone(), role.clone()))
273        .collect();
274
275    if users.contains_key(&new_user.email) {
276        return Err(StatusCode::CONFLICT);
277    }
278
279    let default_tenant_id = Uuid::new_v4();
280    let user_id = Uuid::new_v4();
281
282    let role_ids = new_user.role_ids.as_deref().unwrap_or(&[]);
283    let mut role_names = Vec::new();
284
285    for role_id in role_ids {
286        match role_mapping.get(role_id) {
287            Some(role) => role_names.push(role.name.clone()),
288            None => return Err(StatusCode::BAD_REQUEST),
289        }
290    }
291
292    let user_config = UserConfig {
293        id: user_id,
294        email: new_user.email.clone(),
295        password: Uuid::new_v4().to_string(),
296        tenant_id: default_tenant_id,
297        initial_api_tokens: vec![],
298        roles: role_names.clone(),
299        auth_provider: None,
300        verified: Some(true),
301        metadata: None,
302    };
303
304    users.insert(new_user.email.clone(), user_config);
305
306    let user_roles = role_ids
307        .iter()
308        .map(|role_id| role_mapping.get(role_id).unwrap().clone())
309        .collect();
310
311    let user_response = UserResponse {
312        id: user_id,
313        email: new_user.email.clone(),
314        verified: true,
315        metadata: String::new(),
316        provider: String::new(),
317        roles: user_roles,
318    };
319
320    Ok((StatusCode::CREATED, Json(user_response)))
321}
322
323// https://docs.frontegg.com/reference/userscontrollerv1_removeuserfromtenant
324pub async fn handle_delete_user(
325    State(context): State<Arc<Context>>,
326    Path(user_id): Path<Uuid>,
327) -> StatusCode {
328    let mut users = context.users.lock().unwrap();
329    let initial_count = users.len();
330    users.retain(|_, user| user.id != user_id);
331
332    if users.len() < initial_count {
333        StatusCode::OK
334    } else {
335        StatusCode::NOT_FOUND
336    }
337}
338
339// https://docs.frontegg.com/reference/userscontrollerv3_getusers
340pub async fn handle_get_users_v3(
341    State(context): State<Arc<Context>>,
342    Query(query): Query<UsersV3Query>,
343) -> Result<Json<UsersV3Response>, (StatusCode, Json<ErrorResponse>)> {
344    let users = context.users.lock().unwrap();
345    let role_mapping: BTreeMap<String, UserRole> = context
346        .roles
347        .iter()
348        .map(|role| (role.id.clone(), role.clone()))
349        .collect();
350
351    let mut filtered_users: Vec<UserResponse> = users
352        .iter()
353        .filter(|(email, user)| {
354            query
355                .email
356                .as_ref()
357                .map_or(true, |q_email| *email == q_email)
358                && query.ids.as_ref().map_or(true, |ids| {
359                    ids.split(',').any(|id| id == user.id.to_string())
360                })
361                && query.tenant_id.as_ref().map_or(true, |q_tenant_id| {
362                    &user.tenant_id == q_tenant_id || query.include_sub_tenants.unwrap_or(false)
363                })
364        })
365        .map(|(_, user)| UserResponse {
366            id: user.id,
367            email: user.email.clone(),
368            verified: user.verified.unwrap_or(true),
369            metadata: user.metadata.clone().unwrap_or_default(),
370            provider: user.auth_provider.clone().unwrap_or_default(),
371            roles: get_user_roles(&user.roles, &role_mapping),
372        })
373        .collect();
374
375    // Sort users if sort_by is provided
376    if let Some(sort_by) = &query.sort_by {
377        let sort_by = SortBy::try_from(sort_by.as_str()).map_err(|_| {
378            (
379                StatusCode::BAD_REQUEST,
380                Json(ErrorResponse {
381                    errors: vec!["_sortBy must be a valid enum value".to_string()],
382                }),
383            )
384        })?;
385
386        let order = query
387            .order
388            .as_deref()
389            .map(Order::try_from)
390            .transpose()
391            .map_err(|_| {
392                (
393                    StatusCode::BAD_REQUEST,
394                    Json(ErrorResponse {
395                        errors: vec![
396                            "_order must be one of the following values: ASC, DESC".to_string(),
397                        ],
398                    }),
399                )
400            })?;
401
402        filtered_users.sort_by(|a, b| {
403            let cmp = match sort_by {
404                SortBy::Email => a.email.cmp(&b.email),
405                SortBy::Id => a.id.cmp(&b.id),
406            };
407            if order == Some(Order::DESC) {
408                cmp.reverse()
409            } else {
410                cmp
411            }
412        });
413    }
414
415    let total_items = filtered_users.len();
416
417    // Apply pagination
418    let offset = query.offset.unwrap_or(0);
419    let limit = query.limit.unwrap_or(total_items);
420    filtered_users = filtered_users
421        .into_iter()
422        .skip(offset)
423        .take(limit)
424        .collect();
425
426    Ok(Json(UsersV3Response {
427        items: filtered_users,
428        _metadata: UsersV3Metadata { total_items },
429    }))
430}
431
432pub async fn handle_update_user_roles(
433    State(context): State<Arc<Context>>,
434    Json(request): Json<UpdateUserRolesRequest>,
435) -> Result<Json<UserResponse>, StatusCode> {
436    let mut users = context.users.lock().unwrap();
437    let role_mapping: BTreeMap<String, UserRole> = context
438        .roles
439        .iter()
440        .map(|role| (role.id.clone(), role.clone()))
441        .collect();
442
443    if let Some(user) = users.get_mut(&request.email) {
444        user.roles.clone_from(&request.role_ids);
445
446        let updated_roles = get_user_roles(&user.roles, &role_mapping);
447
448        let user_response = UserResponse {
449            id: user.id,
450            email: user.email.clone(),
451            verified: user.verified.unwrap_or(true),
452            metadata: user.metadata.clone().unwrap_or_default(),
453            provider: user.auth_provider.clone().unwrap_or_default(),
454            roles: updated_roles,
455        };
456
457        Ok(Json(user_response))
458    } else {
459        Err(StatusCode::NOT_FOUND)
460    }
461}
462
463// https://docs.frontegg.com/reference/permissionscontrollerv2_getallroles
464pub async fn handle_roles_request(State(context): State<Arc<Context>>) -> Json<UserRolesResponse> {
465    let roles = Arc::<Vec<UserRole>>::clone(&context.roles);
466
467    let response = UserRolesResponse {
468        items: roles.to_vec(),
469        _metadata: UserRolesMetadata {
470            total_items: roles.len(),
471            total_pages: 1,
472        },
473    };
474
475    Json(response)
476}
477
478// https://docs.frontegg.com/reference/groupscontrollerv1_adduserstogroup
479pub async fn handle_add_users_to_group(
480    State(context): State<Arc<Context>>,
481    Path(group_id): Path<String>,
482    Json(payload): Json<AddUsersToGroupParams>,
483) -> Result<StatusCode, StatusCode> {
484    let mut groups = context.groups.lock().unwrap();
485    if let Some(group) = groups.get_mut(&group_id) {
486        for user_id in payload.user_ids {
487            if !group.users.iter().any(|u| u.id == user_id) {
488                group.users.push(User {
489                    id: user_id,
490                    name: "".to_string(),
491                    email: "".to_string(),
492                });
493            }
494        }
495        Ok(StatusCode::CREATED)
496    } else {
497        Err(StatusCode::NOT_FOUND)
498    }
499}
500
501// https://docs.frontegg.com/reference/groupscontrollerv1_removeusersfromgroup
502pub async fn handle_remove_users_from_group(
503    State(context): State<Arc<Context>>,
504    Path(group_id): Path<String>,
505    Json(payload): Json<RemoveUsersFromGroupParams>,
506) -> StatusCode {
507    let mut groups = context.groups.lock().unwrap();
508
509    if let Some(group) = groups.get_mut(&group_id) {
510        group
511            .users
512            .retain(|user| !payload.user_ids.contains(&user.id));
513        StatusCode::OK
514    } else {
515        StatusCode::NOT_FOUND
516    }
517}
518
519pub async fn internal_handle_get_user_password(
520    State(context): State<Arc<Context>>,
521    Json(request): Json<GetUserPasswordRequest>,
522) -> Result<Json<GetUserPasswordResponse>, StatusCode> {
523    let users = context.users.lock().unwrap();
524
525    if let Some(user) = users.get(&request.email) {
526        Ok(Json(GetUserPasswordResponse {
527            email: user.email.clone(),
528            password: user.password.clone(),
529        }))
530    } else {
531        Err(StatusCode::NOT_FOUND)
532    }
533}