1use crate::models::*;
11use crate::server::Context;
12use crate::utils::get_user_roles;
13use crate::utils::{decode_access_token, generate_access_token};
14use axum::{
15 Json,
16 extract::{Path, Query, State},
17 http::StatusCode,
18};
19use axum_extra::TypedHeader;
20use axum_extra::headers::Authorization;
21use axum_extra::headers::authorization::Bearer;
22use chrono::Utc;
23use jsonwebtoken::TokenData;
24use mz_frontegg_auth::{ClaimTokenType, Claims};
25use std::collections::BTreeMap;
26use std::sync::Arc;
27use uuid::Uuid;
28
29pub async fn handle_get_user_profile(
31 State(context): State<Arc<Context>>,
32 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
33) -> Result<Json<UserProfileResponse>, StatusCode> {
34 let claims: Claims = match decode_access_token(&context, authorization.token()) {
35 Ok(TokenData { claims, .. }) => claims,
36 Err(_) => return Err(StatusCode::UNAUTHORIZED),
37 };
38 Ok(Json(UserProfileResponse {
39 tenant_id: claims.tenant_id,
40 }))
41}
42
43pub async fn handle_post_user_api_token(
45 State(context): State<Arc<Context>>,
46 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
47 Json(request): Json<UserApiTokenRequest>,
48) -> Result<(StatusCode, Json<UserApiTokenResponse>), StatusCode> {
49 let claims: Claims = match decode_access_token(&context, authorization.token()) {
50 Ok(TokenData { claims, .. }) => claims,
51 Err(_) => return Err(StatusCode::UNAUTHORIZED),
52 };
53 let mut tokens = context.user_api_tokens.lock().unwrap();
54 let new_token = ApiToken {
55 client_id: Uuid::new_v4(),
56 secret: Uuid::new_v4(),
57 description: request.description.clone(),
58 created_at: Utc::now(),
59 };
60 tokens.insert(new_token.clone(), claims.email.unwrap());
61
62 let response = UserApiTokenResponse {
63 client_id: new_token.client_id.to_string(),
64 description: new_token.description.unwrap(),
65 created_at: new_token.created_at,
66 secret: new_token.secret.to_string(),
67 };
68
69 Ok((StatusCode::CREATED, Json(response)))
70}
71
72pub async fn handle_list_user_api_tokens(
74 State(context): State<Arc<Context>>,
75 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
76) -> Result<Json<Vec<UserApiTokenResponse>>, StatusCode> {
77 let claims = match decode_access_token(&context, authorization.token()) {
78 Ok(TokenData { claims, .. }) => claims,
79 Err(_) => return Err(StatusCode::UNAUTHORIZED),
80 };
81
82 let user_api_tokens = context.user_api_tokens.lock().unwrap();
83 let tokens: Vec<UserApiTokenResponse> = user_api_tokens
84 .iter()
85 .filter(|(_, email)| *email == claims.email.as_ref().unwrap())
86 .map(|(token, _)| UserApiTokenResponse {
87 client_id: token.client_id.to_string(),
88 description: token.description.clone().unwrap_or_default(),
89 created_at: token.created_at,
90 secret: token.secret.to_string(),
91 })
92 .collect();
93
94 Ok(Json(tokens))
95}
96
97pub async fn handle_delete_user_api_token(
99 State(context): State<Arc<Context>>,
100 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
101 Path(token_id): Path<Uuid>,
102) -> StatusCode {
103 let claims = match decode_access_token(&context, authorization.token()) {
104 Ok(TokenData { claims, .. }) => claims,
105 Err(_) => return StatusCode::UNAUTHORIZED,
106 };
107
108 let mut user_api_tokens = context.user_api_tokens.lock().unwrap();
109
110 let removed = user_api_tokens
111 .iter()
112 .find(|(token, email)| {
113 token.client_id == token_id && *email == claims.email.as_ref().unwrap()
114 })
115 .map(|(token, _)| token.clone());
116
117 if let Some(token_to_remove) = removed {
118 user_api_tokens.remove(&token_to_remove);
119 StatusCode::OK
120 } else {
121 StatusCode::NOT_FOUND
122 }
123}
124
125pub async fn handle_list_tenant_api_tokens(
127 State(context): State<Arc<Context>>,
128 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
129) -> Result<Json<Vec<TenantApiTokenResponse>>, StatusCode> {
130 let _claims = match decode_access_token(&context, authorization.token()) {
131 Ok(TokenData { claims, .. }) => claims,
132 Err(_) => return Err(StatusCode::UNAUTHORIZED),
133 };
134
135 let tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
136 let tokens: Vec<TenantApiTokenResponse> = tenant_api_tokens
137 .iter()
138 .map(|(api_token, config)| TenantApiTokenResponse {
139 client_id: api_token.client_id,
140 description: api_token.description.clone().unwrap_or_default(),
141 secret: api_token.secret,
142 created_by_user_id: config.created_by_user_id,
143 metadata: config.metadata.clone(),
144 created_at: config.created_at,
145 role_ids: config.roles.clone(),
146 })
147 .collect();
148
149 Ok(Json(tokens))
150}
151
152pub async fn handle_create_tenant_api_token(
154 State(context): State<Arc<Context>>,
155 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
156 Json(request): Json<CreateTenantApiTokenRequest>,
157) -> Result<(StatusCode, Json<TenantApiTokenResponse>), StatusCode> {
158 let claims = match decode_access_token(&context, authorization.token()) {
159 Ok(TokenData { claims, .. }) => claims,
160 Err(_) => return Err(StatusCode::UNAUTHORIZED),
161 };
162
163 let new_token = ApiToken {
164 client_id: Uuid::new_v4(),
165 secret: Uuid::new_v4(),
166 description: Some(request.description.clone()),
167 created_at: Utc::now(),
168 };
169
170 let config = TenantApiTokenConfig {
171 tenant_id: claims.tenant_id,
172 metadata: request.metadata.clone(),
173 roles: request.role_ids.clone(),
174 description: Some(request.description.clone()),
175 created_by_user_id: claims.sub,
176 created_at: new_token.created_at,
177 };
178
179 let mut tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
180 tenant_api_tokens.insert(new_token.clone(), config.clone());
181
182 let _access_token = generate_access_token(
183 &context,
184 ClaimTokenType::TenantApiToken,
185 new_token.client_id,
186 None,
187 None,
188 config.tenant_id,
189 config.roles.clone(),
190 config.metadata.clone(),
191 );
192
193 let response = TenantApiTokenResponse {
194 client_id: new_token.client_id,
195 description: new_token.description.unwrap(),
196 secret: new_token.secret,
197 created_by_user_id: config.created_by_user_id,
198 metadata: config.metadata,
199 created_at: new_token.created_at,
200 role_ids: config.roles,
201 };
202
203 Ok((StatusCode::CREATED, Json(response)))
204}
205
206pub async fn handle_delete_tenant_api_token(
208 State(context): State<Arc<Context>>,
209 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
210 Path(token_id): Path<Uuid>,
211) -> StatusCode {
212 let _claims = match decode_access_token(&context, authorization.token()) {
213 Ok(TokenData { claims, .. }) => claims,
214 Err(_) => return StatusCode::UNAUTHORIZED,
215 };
216
217 let mut tenant_api_tokens = context.tenant_api_tokens.lock().unwrap();
218
219 let token_to_remove = tenant_api_tokens
220 .keys()
221 .find(|token| token.client_id == token_id)
222 .cloned();
223
224 if let Some(token) = token_to_remove {
225 tenant_api_tokens.remove(&token);
226 StatusCode::OK
227 } else {
228 StatusCode::NOT_FOUND
229 }
230}
231
232pub async fn handle_get_user(
234 State(context): State<Arc<Context>>,
235 Path(user_id): Path<Uuid>,
236) -> Result<Json<UserResponse>, StatusCode> {
237 let users = context.users.lock().unwrap();
238 let role_mapping: BTreeMap<String, UserRole> = context
239 .roles
240 .iter()
241 .map(|role| (role.id.clone(), role.clone()))
242 .collect();
243
244 match users.iter().find(|(_, user)| user.id == user_id) {
245 Some((_, user)) => {
246 let roles = get_user_roles(&user.roles, &role_mapping);
247
248 let user_response = UserResponse {
249 id: user.id,
250 email: user.email.clone(),
251 verified: user.verified.unwrap_or(true),
252 metadata: user.metadata.clone().unwrap_or_default(),
253 provider: user.auth_provider.clone().unwrap_or_default(),
254 roles,
255 };
256
257 Ok(Json(user_response))
258 }
259 None => Err(StatusCode::NOT_FOUND),
260 }
261}
262
263pub async fn handle_create_user(
265 State(context): State<Arc<Context>>,
266 Json(new_user): Json<UserCreate>,
267) -> Result<(StatusCode, Json<UserResponse>), StatusCode> {
268 let mut users = context.users.lock().unwrap();
269 let role_mapping: BTreeMap<String, UserRole> = context
270 .roles
271 .iter()
272 .map(|role| (role.id.clone(), role.clone()))
273 .collect();
274
275 if users.contains_key(&new_user.email) {
276 return Err(StatusCode::CONFLICT);
277 }
278
279 let default_tenant_id = Uuid::new_v4();
280 let user_id = Uuid::new_v4();
281
282 let role_ids = new_user.role_ids.as_deref().unwrap_or(&[]);
283 let mut role_names = Vec::new();
284
285 for role_id in role_ids {
286 match role_mapping.get(role_id) {
287 Some(role) => role_names.push(role.name.clone()),
288 None => return Err(StatusCode::BAD_REQUEST),
289 }
290 }
291
292 let user_config = UserConfig {
293 id: user_id,
294 email: new_user.email.clone(),
295 password: Uuid::new_v4().to_string(),
296 tenant_id: default_tenant_id,
297 initial_api_tokens: vec![],
298 roles: role_names.clone(),
299 auth_provider: None,
300 verified: Some(true),
301 metadata: None,
302 };
303
304 users.insert(new_user.email.clone(), user_config);
305
306 let user_roles = role_ids
307 .iter()
308 .map(|role_id| role_mapping.get(role_id).unwrap().clone())
309 .collect();
310
311 let user_response = UserResponse {
312 id: user_id,
313 email: new_user.email.clone(),
314 verified: true,
315 metadata: String::new(),
316 provider: String::new(),
317 roles: user_roles,
318 };
319
320 Ok((StatusCode::CREATED, Json(user_response)))
321}
322
323pub async fn handle_delete_user(
325 State(context): State<Arc<Context>>,
326 Path(user_id): Path<Uuid>,
327) -> StatusCode {
328 let mut users = context.users.lock().unwrap();
329 let initial_count = users.len();
330 users.retain(|_, user| user.id != user_id);
331
332 if users.len() < initial_count {
333 StatusCode::OK
334 } else {
335 StatusCode::NOT_FOUND
336 }
337}
338
339pub async fn handle_get_users_v3(
341 State(context): State<Arc<Context>>,
342 Query(query): Query<UsersV3Query>,
343) -> Result<Json<UsersV3Response>, (StatusCode, Json<ErrorResponse>)> {
344 let users = context.users.lock().unwrap();
345 let role_mapping: BTreeMap<String, UserRole> = context
346 .roles
347 .iter()
348 .map(|role| (role.id.clone(), role.clone()))
349 .collect();
350
351 let mut filtered_users: Vec<UserResponse> = users
352 .iter()
353 .filter(|(email, user)| {
354 query
355 .email
356 .as_ref()
357 .map_or(true, |q_email| *email == q_email)
358 && query.ids.as_ref().map_or(true, |ids| {
359 ids.split(',').any(|id| id == user.id.to_string())
360 })
361 && query.tenant_id.as_ref().map_or(true, |q_tenant_id| {
362 &user.tenant_id == q_tenant_id || query.include_sub_tenants.unwrap_or(false)
363 })
364 })
365 .map(|(_, user)| UserResponse {
366 id: user.id,
367 email: user.email.clone(),
368 verified: user.verified.unwrap_or(true),
369 metadata: user.metadata.clone().unwrap_or_default(),
370 provider: user.auth_provider.clone().unwrap_or_default(),
371 roles: get_user_roles(&user.roles, &role_mapping),
372 })
373 .collect();
374
375 if let Some(sort_by) = &query.sort_by {
377 let sort_by = SortBy::try_from(sort_by.as_str()).map_err(|_| {
378 (
379 StatusCode::BAD_REQUEST,
380 Json(ErrorResponse {
381 errors: vec!["_sortBy must be a valid enum value".to_string()],
382 }),
383 )
384 })?;
385
386 let order = query
387 .order
388 .as_deref()
389 .map(Order::try_from)
390 .transpose()
391 .map_err(|_| {
392 (
393 StatusCode::BAD_REQUEST,
394 Json(ErrorResponse {
395 errors: vec![
396 "_order must be one of the following values: ASC, DESC".to_string(),
397 ],
398 }),
399 )
400 })?;
401
402 filtered_users.sort_by(|a, b| {
403 let cmp = match sort_by {
404 SortBy::Email => a.email.cmp(&b.email),
405 SortBy::Id => a.id.cmp(&b.id),
406 };
407 if order == Some(Order::DESC) {
408 cmp.reverse()
409 } else {
410 cmp
411 }
412 });
413 }
414
415 let total_items = filtered_users.len();
416
417 let offset = query.offset.unwrap_or(0);
419 let limit = query.limit.unwrap_or(total_items);
420 filtered_users = filtered_users
421 .into_iter()
422 .skip(offset)
423 .take(limit)
424 .collect();
425
426 Ok(Json(UsersV3Response {
427 items: filtered_users,
428 _metadata: UsersV3Metadata { total_items },
429 }))
430}
431
432pub async fn handle_update_user_roles(
433 State(context): State<Arc<Context>>,
434 Json(request): Json<UpdateUserRolesRequest>,
435) -> Result<Json<UserResponse>, StatusCode> {
436 let mut users = context.users.lock().unwrap();
437 let role_mapping: BTreeMap<String, UserRole> = context
438 .roles
439 .iter()
440 .map(|role| (role.id.clone(), role.clone()))
441 .collect();
442
443 if let Some(user) = users.get_mut(&request.email) {
444 user.roles.clone_from(&request.role_ids);
445
446 let updated_roles = get_user_roles(&user.roles, &role_mapping);
447
448 let user_response = UserResponse {
449 id: user.id,
450 email: user.email.clone(),
451 verified: user.verified.unwrap_or(true),
452 metadata: user.metadata.clone().unwrap_or_default(),
453 provider: user.auth_provider.clone().unwrap_or_default(),
454 roles: updated_roles,
455 };
456
457 Ok(Json(user_response))
458 } else {
459 Err(StatusCode::NOT_FOUND)
460 }
461}
462
463pub async fn handle_roles_request(State(context): State<Arc<Context>>) -> Json<UserRolesResponse> {
465 let roles = Arc::<Vec<UserRole>>::clone(&context.roles);
466
467 let response = UserRolesResponse {
468 items: roles.to_vec(),
469 _metadata: UserRolesMetadata {
470 total_items: roles.len(),
471 total_pages: 1,
472 },
473 };
474
475 Json(response)
476}
477
478pub async fn handle_add_users_to_group(
480 State(context): State<Arc<Context>>,
481 Path(group_id): Path<String>,
482 Json(payload): Json<AddUsersToGroupParams>,
483) -> Result<StatusCode, StatusCode> {
484 let mut groups = context.groups.lock().unwrap();
485 if let Some(group) = groups.get_mut(&group_id) {
486 for user_id in payload.user_ids {
487 if !group.users.iter().any(|u| u.id == user_id) {
488 group.users.push(User {
489 id: user_id,
490 name: "".to_string(),
491 email: "".to_string(),
492 });
493 }
494 }
495 Ok(StatusCode::CREATED)
496 } else {
497 Err(StatusCode::NOT_FOUND)
498 }
499}
500
501pub async fn handle_remove_users_from_group(
503 State(context): State<Arc<Context>>,
504 Path(group_id): Path<String>,
505 Json(payload): Json<RemoveUsersFromGroupParams>,
506) -> StatusCode {
507 let mut groups = context.groups.lock().unwrap();
508
509 if let Some(group) = groups.get_mut(&group_id) {
510 group
511 .users
512 .retain(|user| !payload.user_ids.contains(&user.id));
513 StatusCode::OK
514 } else {
515 StatusCode::NOT_FOUND
516 }
517}
518
519pub async fn internal_handle_get_user_password(
520 State(context): State<Arc<Context>>,
521 Json(request): Json<GetUserPasswordRequest>,
522) -> Result<Json<GetUserPasswordResponse>, StatusCode> {
523 let users = context.users.lock().unwrap();
524
525 if let Some(user) = users.get(&request.email) {
526 Ok(Json(GetUserPasswordResponse {
527 email: user.email.clone(),
528 password: user.password.clone(),
529 }))
530 } else {
531 Err(StatusCode::NOT_FOUND)
532 }
533}