Skip to main content

mz_frontegg_mock/handlers/
user.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use crate::models::*;
11use crate::server::Context;
12use crate::utils::get_user_roles;
13use crate::utils::{decode_access_token, generate_access_token};
14use axum::{
15    Json,
16    extract::{Path, Query, State},
17    http::StatusCode,
18};
19use axum_extra::TypedHeader;
20use axum_extra::headers::Authorization;
21use axum_extra::headers::authorization::Bearer;
22use chrono::Utc;
23use jsonwebtoken::TokenData;
24use mz_frontegg_auth::{ClaimTokenType, Claims};
25use std::collections::BTreeMap;
26use std::sync::Arc;
27use uuid::Uuid;
28
29// https://docs.frontegg.com/reference/userscontrollerv2_getuserprofile
30pub async fn handle_get_user_profile(
31    State(context): State<Arc<Context>>,
32    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
33) -> Result<Json<UserProfileResponse>, StatusCode> {
34    let claims: Claims = match decode_access_token(&context, authorization.token()) {
35        Ok(TokenData { claims, .. }) => claims,
36        Err(_) => return Err(StatusCode::UNAUTHORIZED),
37    };
38    Ok(Json(UserProfileResponse {
39        tenant_id: claims.tenant_id,
40    }))
41}
42
43// https://docs.frontegg.com/reference/userapitokensv1controller_createtenantapitoken
44pub async fn handle_post_user_api_token(
45    State(context): State<Arc<Context>>,
46    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
47    Json(request): Json<UserApiTokenRequest>,
48) -> Result<(StatusCode, Json<UserApiTokenResponse>), StatusCode> {
49    let claims: Claims = match decode_access_token(&context, authorization.token()) {
50        Ok(TokenData { claims, .. }) => claims,
51        Err(_) => return Err(StatusCode::UNAUTHORIZED),
52    };
53    let mut tokens = context.user_api_tokens.lock().unwrap();
54    let new_token = ApiToken {
55        client_id: Uuid::new_v4(),
56        secret: Uuid::new_v4(),
57        description: request.description.clone(),
58        created_at: Utc::now(),
59    };
60    tokens.insert(new_token.clone(), claims.email.unwrap());
61
62    let response = UserApiTokenResponse {
63        client_id: new_token.client_id.to_string(),
64        description: new_token.description.unwrap(),
65        created_at: new_token.created_at,
66        secret: new_token.secret.to_string(),
67    };
68
69    Ok((StatusCode::CREATED, Json(response)))
70}
71
72// https://docs.frontegg.com/reference/userapitokensv1controller_getapitokens
73pub async fn handle_list_user_api_tokens(
74    State(context): State<Arc<Context>>,
75    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
76) -> Result<Json<Vec<UserApiTokenResponse>>, StatusCode> {
77    let claims = match decode_access_token(&context, authorization.token()) {
78        Ok(TokenData { claims, .. }) => claims,
79        Err(_) => return Err(StatusCode::UNAUTHORIZED),
80    };
81
82    let user_api_tokens = context.user_api_tokens.lock().unwrap();
83    let tokens: Vec<UserApiTokenResponse> = user_api_tokens
84        .iter()
85        .filter(|(_, email)| *email == claims.email.as_ref().unwrap())
86        .map(|(token, _)| UserApiTokenResponse {
87            client_id: token.client_id.to_string(),
88            description: token.description.clone().unwrap_or_default(),
89            created_at: token.created_at,
90            secret: token.secret.to_string(),
91        })
92        .collect();
93
94    Ok(Json(tokens))
95}
96
97// https://docs.frontegg.com/reference/userapitokensv1controller_deleteapitoken
98pub async fn handle_delete_user_api_token(
99    State(context): State<Arc<Context>>,
100    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
101    Path(token_id): Path<Uuid>,
102) -> StatusCode {
103    let claims = match decode_access_token(&context, authorization.token()) {
104        Ok(TokenData { claims, .. }) => claims,
105        Err(_) => return StatusCode::UNAUTHORIZED,
106    };
107
108    let mut user_api_tokens = context.user_api_tokens.lock().unwrap();
109
110    let removed = user_api_tokens
111        .iter()
112        .find(|(token, email)| {
113            token.client_id == token_id && *email == claims.email.as_ref().unwrap()
114        })
115        .map(|(token, _)| token.clone());
116
117    if let Some(token_to_remove) = removed {
118        user_api_tokens.remove(&token_to_remove);
119        StatusCode::OK
120    } else {
121        StatusCode::NOT_FOUND
122    }
123}
124
125// https://docs.frontegg.com/reference/tenantapitokensv1controller_gettenantsapitokens
126pub async fn handle_list_tenant_api_tokens(
127    State(context): State<Arc<Context>>,
128    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
129) -> Result<Json<Vec<TenantApiTokenResponse>>, StatusCode> {
130    let _claims = match decode_access_token(&context, authorization.token()) {
131        Ok(TokenData { claims, .. }) => claims,
132        Err(_) => return Err(StatusCode::UNAUTHORIZED),
133    };
134
135    let tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
136    let tokens: Vec<TenantApiTokenResponse> = tenant_api_tokens
137        .iter()
138        .map(|(api_token, config)| TenantApiTokenResponse {
139            client_id: api_token.client_id,
140            description: api_token.description.clone().unwrap_or_default(),
141            secret: api_token.secret,
142            created_by_user_id: config.created_by_user_id,
143            metadata: config.metadata.clone(),
144            created_at: config.created_at,
145            role_ids: config.roles.clone(),
146        })
147        .collect();
148
149    Ok(Json(tokens))
150}
151
152// https://docs.frontegg.com/reference/tenantapitokensv1controller_createtenantapitoken
153pub async fn handle_create_tenant_api_token(
154    State(context): State<Arc<Context>>,
155    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
156    Json(request): Json<CreateTenantApiTokenRequest>,
157) -> Result<(StatusCode, Json<TenantApiTokenResponse>), StatusCode> {
158    let claims = match decode_access_token(&context, authorization.token()) {
159        Ok(TokenData { claims, .. }) => claims,
160        Err(_) => return Err(StatusCode::UNAUTHORIZED),
161    };
162
163    let new_token = ApiToken {
164        client_id: Uuid::new_v4(),
165        secret: Uuid::new_v4(),
166        description: Some(request.description.clone()),
167        created_at: Utc::now(),
168    };
169
170    let config = TenantApiTokenConfig {
171        tenant_id: claims.tenant_id,
172        metadata: request.metadata.clone(),
173        roles: request.role_ids.clone(),
174        description: Some(request.description.clone()),
175        created_by_user_id: claims.sub,
176        created_at: new_token.created_at,
177    };
178
179    let mut tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
180    tenant_api_tokens.insert(new_token.clone(), config.clone());
181
182    let _access_token = generate_access_token(
183        &context,
184        ClaimTokenType::TenantApiToken,
185        new_token.client_id,
186        None,
187        None,
188        config.tenant_id,
189        config.roles.clone(),
190        Vec::new(),
191        config.metadata.clone(),
192    );
193
194    let response = TenantApiTokenResponse {
195        client_id: new_token.client_id,
196        description: new_token.description.unwrap(),
197        secret: new_token.secret,
198        created_by_user_id: config.created_by_user_id,
199        metadata: config.metadata,
200        created_at: new_token.created_at,
201        role_ids: config.roles,
202    };
203
204    Ok((StatusCode::CREATED, Json(response)))
205}
206
207// https://docs.frontegg.com/reference/tenantapitokensv1controller_deletetenantapitoken
208pub async fn handle_delete_tenant_api_token(
209    State(context): State<Arc<Context>>,
210    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
211    Path(token_id): Path<Uuid>,
212) -> StatusCode {
213    let _claims = match decode_access_token(&context, authorization.token()) {
214        Ok(TokenData { claims, .. }) => claims,
215        Err(_) => return StatusCode::UNAUTHORIZED,
216    };
217
218    let mut tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
219
220    let token_to_remove = tenant_api_tokens
221        .keys()
222        .find(|token| token.client_id == token_id)
223        .cloned();
224
225    if let Some(token) = token_to_remove {
226        tenant_api_tokens.remove(&token);
227        StatusCode::OK
228    } else {
229        StatusCode::NOT_FOUND
230    }
231}
232
233// https://docs.frontegg.com/reference/userscontrollerv2_getuserbyid
234pub async fn handle_get_user(
235    State(context): State<Arc<Context>>,
236    Path(user_id): Path<Uuid>,
237) -> Result<Json<UserResponse>, StatusCode> {
238    let users = context.users.lock().unwrap();
239    let role_mapping: BTreeMap<String, UserRole> = context
240        .roles
241        .iter()
242        .map(|role| (role.id.clone(), role.clone()))
243        .collect();
244
245    match users.iter().find(|(_, user)| user.id == user_id) {
246        Some((_, user)) => {
247            let roles = get_user_roles(&user.roles, &role_mapping);
248
249            let user_response = UserResponse {
250                id: user.id,
251                email: user.email.clone(),
252                verified: user.verified.unwrap_or(true),
253                metadata: user.metadata.clone().unwrap_or_default(),
254                provider: user.auth_provider.clone().unwrap_or_default(),
255                roles,
256            };
257
258            Ok(Json(user_response))
259        }
260        None => Err(StatusCode::NOT_FOUND),
261    }
262}
263
264// https://docs.frontegg.com/reference/userscontrollerv2_createuser
265pub async fn handle_create_user(
266    State(context): State<Arc<Context>>,
267    Json(new_user): Json<UserCreate>,
268) -> Result<(StatusCode, Json<UserResponse>), StatusCode> {
269    let mut users = context.users.lock().unwrap();
270    let role_mapping: BTreeMap<String, UserRole> = context
271        .roles
272        .iter()
273        .map(|role| (role.id.clone(), role.clone()))
274        .collect();
275
276    if users.contains_key(&new_user.email) {
277        return Err(StatusCode::CONFLICT);
278    }
279
280    let default_tenant_id = Uuid::new_v4();
281    let user_id = Uuid::new_v4();
282
283    let role_ids = new_user.role_ids.as_deref().unwrap_or(&[]);
284    let mut role_names = Vec::new();
285
286    for role_id in role_ids {
287        match role_mapping.get(role_id) {
288            Some(role) => role_names.push(role.name.clone()),
289            None => return Err(StatusCode::BAD_REQUEST),
290        }
291    }
292
293    let user_config = UserConfig {
294        id: user_id,
295        email: new_user.email.clone(),
296        password: Uuid::new_v4().to_string(),
297        tenant_id: default_tenant_id,
298        initial_api_tokens: vec![],
299        roles: role_names.clone(),
300        auth_provider: None,
301        verified: Some(true),
302        metadata: None,
303    };
304
305    users.insert(new_user.email.clone(), user_config);
306
307    let user_roles = role_ids
308        .iter()
309        .map(|role_id| role_mapping.get(role_id).unwrap().clone())
310        .collect();
311
312    let user_response = UserResponse {
313        id: user_id,
314        email: new_user.email.clone(),
315        verified: true,
316        metadata: String::new(),
317        provider: String::new(),
318        roles: user_roles,
319    };
320
321    Ok((StatusCode::CREATED, Json(user_response)))
322}
323
324// https://docs.frontegg.com/reference/userscontrollerv1_removeuserfromtenant
325pub async fn handle_delete_user(
326    State(context): State<Arc<Context>>,
327    Path(user_id): Path<Uuid>,
328) -> StatusCode {
329    let mut users = context.users.lock().unwrap();
330    let initial_count = users.len();
331    users.retain(|_, user| user.id != user_id);
332
333    if users.len() < initial_count {
334        StatusCode::OK
335    } else {
336        StatusCode::NOT_FOUND
337    }
338}
339
340// https://docs.frontegg.com/reference/userscontrollerv3_getusers
341pub async fn handle_get_users_v3(
342    State(context): State<Arc<Context>>,
343    Query(query): Query<UsersV3Query>,
344) -> Result<Json<UsersV3Response>, (StatusCode, Json<ErrorResponse>)> {
345    let users = context.users.lock().unwrap();
346    let role_mapping: BTreeMap<String, UserRole> = context
347        .roles
348        .iter()
349        .map(|role| (role.id.clone(), role.clone()))
350        .collect();
351
352    let mut filtered_users: Vec<UserResponse> = users
353        .iter()
354        .filter(|(email, user)| {
355            query
356                .email
357                .as_ref()
358                .map_or(true, |q_email| *email == q_email)
359                && query.ids.as_ref().map_or(true, |ids| {
360                    ids.split(',').any(|id| id == user.id.to_string())
361                })
362                && query.tenant_id.as_ref().map_or(true, |q_tenant_id| {
363                    &user.tenant_id == q_tenant_id || query.include_sub_tenants.unwrap_or(false)
364                })
365        })
366        .map(|(_, user)| UserResponse {
367            id: user.id,
368            email: user.email.clone(),
369            verified: user.verified.unwrap_or(true),
370            metadata: user.metadata.clone().unwrap_or_default(),
371            provider: user.auth_provider.clone().unwrap_or_default(),
372            roles: get_user_roles(&user.roles, &role_mapping),
373        })
374        .collect();
375
376    // Sort users if sort_by is provided
377    if let Some(sort_by) = &query.sort_by {
378        let sort_by = SortBy::try_from(sort_by.as_str()).map_err(|_| {
379            (
380                StatusCode::BAD_REQUEST,
381                Json(ErrorResponse {
382                    errors: vec!["_sortBy must be a valid enum value".to_string()],
383                }),
384            )
385        })?;
386
387        let order = query
388            .order
389            .as_deref()
390            .map(Order::try_from)
391            .transpose()
392            .map_err(|_| {
393                (
394                    StatusCode::BAD_REQUEST,
395                    Json(ErrorResponse {
396                        errors: vec![
397                            "_order must be one of the following values: ASC, DESC".to_string(),
398                        ],
399                    }),
400                )
401            })?;
402
403        filtered_users.sort_by(|a, b| {
404            let cmp = match sort_by {
405                SortBy::Email => a.email.cmp(&b.email),
406                SortBy::Id => a.id.cmp(&b.id),
407            };
408            if order == Some(Order::DESC) {
409                cmp.reverse()
410            } else {
411                cmp
412            }
413        });
414    }
415
416    let total_items = filtered_users.len();
417
418    // Apply pagination
419    let offset = query.offset.unwrap_or(0);
420    let limit = query.limit.unwrap_or(total_items);
421    filtered_users = filtered_users
422        .into_iter()
423        .skip(offset)
424        .take(limit)
425        .collect();
426
427    Ok(Json(UsersV3Response {
428        items: filtered_users,
429        _metadata: UsersV3Metadata { total_items },
430    }))
431}
432
433pub async fn handle_update_user_roles(
434    State(context): State<Arc<Context>>,
435    Json(request): Json<UpdateUserRolesRequest>,
436) -> Result<Json<UserResponse>, StatusCode> {
437    let mut users = context.users.lock().unwrap();
438    let role_mapping: BTreeMap<String, UserRole> = context
439        .roles
440        .iter()
441        .map(|role| (role.id.clone(), role.clone()))
442        .collect();
443
444    if let Some(user) = users.get_mut(&request.email) {
445        user.roles.clone_from(&request.role_ids);
446
447        let updated_roles = get_user_roles(&user.roles, &role_mapping);
448
449        let user_response = UserResponse {
450            id: user.id,
451            email: user.email.clone(),
452            verified: user.verified.unwrap_or(true),
453            metadata: user.metadata.clone().unwrap_or_default(),
454            provider: user.auth_provider.clone().unwrap_or_default(),
455            roles: updated_roles,
456        };
457
458        Ok(Json(user_response))
459    } else {
460        Err(StatusCode::NOT_FOUND)
461    }
462}
463
464// https://docs.frontegg.com/reference/permissionscontrollerv2_getallroles
465pub async fn handle_roles_request(State(context): State<Arc<Context>>) -> Json<UserRolesResponse> {
466    let roles = Arc::<Vec<UserRole>>::clone(&context.roles);
467
468    let response = UserRolesResponse {
469        items: roles.to_vec(),
470        _metadata: UserRolesMetadata {
471            total_items: roles.len(),
472            total_pages: 1,
473        },
474    };
475
476    Json(response)
477}
478
479// https://docs.frontegg.com/reference/groupscontrollerv1_adduserstogroup
480pub async fn handle_add_users_to_group(
481    State(context): State<Arc<Context>>,
482    Path(group_id): Path<String>,
483    Json(payload): Json<AddUsersToGroupParams>,
484) -> Result<StatusCode, StatusCode> {
485    let mut groups = context.groups.lock().unwrap();
486    if let Some(group) = groups.get_mut(&group_id) {
487        for user_id in payload.user_ids {
488            if !group.users.iter().any(|u| u.id == user_id) {
489                group.users.push(User {
490                    id: user_id,
491                    name: "".to_string(),
492                    email: "".to_string(),
493                });
494            }
495        }
496        Ok(StatusCode::CREATED)
497    } else {
498        Err(StatusCode::NOT_FOUND)
499    }
500}
501
502// https://docs.frontegg.com/reference/groupscontrollerv1_removeusersfromgroup
503pub async fn handle_remove_users_from_group(
504    State(context): State<Arc<Context>>,
505    Path(group_id): Path<String>,
506    Json(payload): Json<RemoveUsersFromGroupParams>,
507) -> StatusCode {
508    let mut groups = context.groups.lock().unwrap();
509
510    if let Some(group) = groups.get_mut(&group_id) {
511        group
512            .users
513            .retain(|user| !payload.user_ids.contains(&user.id));
514        StatusCode::OK
515    } else {
516        StatusCode::NOT_FOUND
517    }
518}
519
520pub async fn internal_handle_get_user_password(
521    State(context): State<Arc<Context>>,
522    Json(request): Json<GetUserPasswordRequest>,
523) -> Result<Json<GetUserPasswordResponse>, StatusCode> {
524    let users = context.users.lock().unwrap();
525
526    if let Some(user) = users.get(&request.email) {
527        Ok(Json(GetUserPasswordResponse {
528            email: user.email.clone(),
529            password: user.password.clone(),
530        }))
531    } else {
532        Err(StatusCode::NOT_FOUND)
533    }
534}