Skip to main content

mz_frontegg_mock/
utils.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use crate::models::{ApiToken, AuthUserRequest, UserRole};
11use crate::server::Context;
12use jsonwebtoken::TokenData;
13use mz_frontegg_auth::{ClaimMetadata, ClaimTokenType, Claims};
14use std::collections::BTreeMap;
15use uuid::Uuid;
16
17pub fn decode_access_token(
18    context: &Context,
19    token: &str,
20) -> Result<TokenData<Claims>, jsonwebtoken::errors::Error> {
21    jsonwebtoken::decode(
22        token,
23        &context.decoding_key,
24        &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256),
25    )
26}
27
28pub fn generate_access_token(
29    context: &Context,
30    token_type: ClaimTokenType,
31    sub: Uuid,
32    email: Option<String>,
33    user_id: Option<Uuid>,
34    tenant_id: Uuid,
35    roles: Vec<String>,
36    groups: Vec<String>,
37    metadata: Option<ClaimMetadata>,
38) -> String {
39    let mut permissions = Vec::new();
40    roles.iter().for_each(|role| {
41        if let Some(role_permissions) = context.role_permissions.get(role.as_str()) {
42            permissions.extend_from_slice(role_permissions);
43        }
44    });
45    permissions.sort();
46    permissions.dedup();
47    // Stamp groups under the `groups` JWT claim (the default `GROUP_CLAIM`
48    // dyncfg value) via the flattened `unknown_claims` bag so the mock can
49    // exercise the dyncfg-driven group extraction path. Always emit the claim
50    // (even as `[]`) so revocation-from-all-groups produces a present-but-empty
51    // claim rather than an omitted claim, matching the semantics group sync
52    // relies on (None = no sync; Some([]) = revoke).
53    let mut unknown_claims = BTreeMap::new();
54    unknown_claims.insert(
55        "groups".to_string(),
56        serde_json::Value::Array(groups.into_iter().map(serde_json::Value::String).collect()),
57    );
58    jsonwebtoken::encode(
59        &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
60        &Claims {
61            token_type,
62            exp: context.now.as_secs() + context.expires_in_secs,
63            email,
64            iss: context.issuer.clone(),
65            sub,
66            user_id,
67            tenant_id,
68            roles,
69            permissions,
70            metadata,
71            unknown_claims,
72        },
73        &context.encoding_key,
74    )
75    .unwrap()
76}
77
78pub fn generate_refresh_token(context: &Context, target: RefreshTokenTarget) -> String {
79    let refresh_token = Uuid::new_v4().to_string();
80    context
81        .refresh_tokens
82        .lock()
83        .unwrap()
84        .insert(refresh_token.clone(), target);
85    refresh_token
86}
87
88pub fn get_user_groups(context: &Context, user_id: &Uuid) -> Vec<String> {
89    let user_id_str = user_id.to_string();
90    let mut groups: Vec<String> = context
91        .groups
92        .lock()
93        .unwrap()
94        .values()
95        .filter(|g| g.users.iter().any(|u| u.id == user_id_str))
96        .map(|g| g.name.clone())
97        .collect();
98    groups.sort();
99    groups.dedup();
100    groups
101}
102
103pub fn get_user_roles(
104    role_ids_or_names: &[String],
105    role_mapping: &BTreeMap<String, UserRole>,
106) -> Vec<UserRole> {
107    role_ids_or_names
108        .iter()
109        .map(|id_or_name| {
110            role_mapping
111                .get(id_or_name)
112                .cloned()
113                .unwrap_or_else(|| UserRole {
114                    id: id_or_name.clone(),
115                    name: id_or_name.clone(),
116                    key: id_or_name.clone(),
117                })
118        })
119        .collect()
120}
121
/// What a mock refresh token refers to.
///
/// [`generate_refresh_token`] stores a value of this type in
/// `Context::refresh_tokens`, keyed by the newly generated token string.
pub enum RefreshTokenTarget {
    /// A refresh token issued for a user, wrapping an `AuthUserRequest`.
    User(AuthUserRequest),
    /// A refresh token issued for an API token, wrapping the `ApiToken`.
    ApiToken(ApiToken),
}