mz_frontegg_mock/
server.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use crate::utils::RefreshTokenTarget;
11use axum::routing::{delete, get, post, put};
12use axum::{Router, middleware};
13use jsonwebtoken::{DecodingKey, EncodingKey};
14use mz_ore::now::NowFn;
15use mz_ore::retry::Retry;
16use mz_ore::task::JoinHandle;
17use std::borrow::Cow;
18use std::collections::BTreeMap;
19use std::future::IntoFuture;
20use std::net::{IpAddr, Ipv4Addr, SocketAddr};
21use std::sync::atomic::AtomicBool;
22use std::sync::{Arc, Mutex};
23use std::time::Duration;
24use tokio::net::TcpListener;
25use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
26
27use crate::handlers::*;
28use crate::middleware::*;
29use crate::models::*;
30
// Route paths served by the mock. Every constant below is registered on the
// router built in `FronteggMockServer::start`; the comments list the HTTP
// methods wired up there.
//
// NOTE(review): the placeholders are spelled `{:id}` (colon inside the braces),
// which presumably names the parameter `:id` rather than `id`. That is harmless
// if handlers extract path parameters positionally -- confirm against the
// handlers module before relying on the parameter names.

// POST -> `handle_post_auth_api_token`.
const AUTH_API_TOKEN_PATH: &str = "/identity/resources/auth/v1/api-token";
// POST -> `handle_post_auth_user`.
const AUTH_USER_PATH: &str = "/identity/resources/auth/v1/user";
// POST -> `handle_post_token_refresh`.
const AUTH_API_TOKEN_REFRESH_PATH: &str = "/identity/resources/auth/v1/api-token/token/refresh";
// GET (list) / POST (create) on the group collection.
const GROUPS_PATH: &str = "/frontegg/identity/resources/groups/v1";
// GET / PATCH / DELETE on a single group.
const GROUP_PATH: &str = "/frontegg/identity/resources/groups/v1/{:id}";
// POST (add roles) / DELETE (remove roles) for a group.
const GROUP_ROLES_PATH: &str = "/frontegg/identity/resources/groups/v1/{:id}/roles";
// POST (add users) / DELETE (remove users) for a group.
const GROUP_USERS_PATH: &str = "/frontegg/identity/resources/groups/v1/{:id}/users";
// Trailing-slash variant of `GROUP_PATH`; only GET is routed, to the same
// `handle_get_group` handler.
const GROUP_PATH_WITH_SLASH: &str = "/frontegg/identity/resources/groups/v1/{:id}/";
// PUT -> `handle_update_user_roles`.
const MEMBERS_PATH: &str = "/frontegg/team/resources/members/v1";
// GET -> `handle_get_user_profile`.
const USERS_ME_PATH: &str = "/identity/resources/users/v2/me";
// GET (list) / POST (create) user API tokens.
const USERS_API_TOKENS_PATH: &str = "/identity/resources/users/api-tokens/v1";
// DELETE a single user API token.
const USER_API_TOKENS_PATH: &str = "/identity/resources/users/api-tokens/v1/{:id}";
// GET (list) / POST (create) tenant API tokens.
const TENANT_API_TOKENS_PATH: &str = "/identity/resources/tenants/api-tokens/v1";
// DELETE a single tenant API token.
const TENANT_API_TOKEN_PATH: &str = "/identity/resources/tenants/api-tokens/v1/{:id}";
// GET / DELETE a single user.
const USER_PATH: &str = "/identity/resources/users/v1/{:id}";
// POST -> `handle_create_user`.
const USER_CREATE_PATH: &str = "/identity/resources/users/v2";
// GET -> `handle_get_users_v3`.
const USERS_V3_PATH: &str = "/identity/resources/users/v3";
// GET -> `handle_roles_request`.
const ROLES_PATH: &str = "/identity/resources/roles/v2";
// GET (list) / POST (create) SCIM configurations.
const SCIM_CONFIGURATIONS_PATH: &str = "/frontegg/directory/resources/v1/configurations/scim2";
// DELETE a single SCIM configuration.
const SCIM_CONFIGURATION_PATH: &str = "/frontegg/directory/resources/v1/configurations/scim2/{:id}";
// GET (list) / POST (create) SSO configurations.
const SSO_CONFIGS_PATH: &str = "/frontegg/team/resources/sso/v1/configurations";
// GET / PATCH / DELETE a single SSO configuration.
const SSO_CONFIG_PATH: &str = "/frontegg/team/resources/sso/v1/configurations/{:id}";
// GET (list) / POST (create) domains of an SSO configuration.
const SSO_CONFIG_DOMAINS_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/{:id}/domains";
// GET / PATCH / DELETE a single domain of an SSO configuration.
const SSO_CONFIG_DOMAIN_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/{:id}/domains/{:domain_id}";
// GET (list) / POST (create) group mappings of an SSO configuration.
const SSO_CONFIG_GROUPS_PATH: &str = "/frontegg/team/resources/sso/v1/configurations/{:id}/groups";
// GET / PATCH / DELETE a single group mapping of an SSO configuration.
const SSO_CONFIG_GROUP_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/{:id}/groups/{:group_id}";
// GET / PUT the default roles of an SSO configuration.
const SSO_CONFIG_ROLES_PATH: &str = "/frontegg/team/resources/sso/v1/configurations/{:id}/roles";

// Internal endpoints for testing. These live under a mock-specific prefix and
// are presumably not part of the real Frontegg API.
// POST -> `internal_handle_get_user_password`.
const INTERNAL_USER_PASSWORD_PATH: &str = "/api/internal-mock/user-password";
64
/// Handle to a running mock Frontegg HTTP server.
///
/// Returned by [`FronteggMockServer::start`]. The counters and flag below are
/// the same `Arc`s that are stored in the server's [`Context`], so tests can
/// observe and steer the running server through them.
pub struct FronteggMockServer {
    /// `http://<addr>` where `<addr>` is the local address the listener
    /// actually bound to (relevant when an ephemeral port was requested).
    pub base_url: String,
    /// Counter shared with [`Context::refreshes`]; starts at zero.
    /// Presumably bumped by the token-refresh handler (not visible in this
    /// file).
    pub refreshes: Arc<Mutex<u64>>,
    /// Flag shared with [`Context::enable_auth`]; initialised to `true`.
    /// Presumably lets tests make the mock reject authentication -- confirm
    /// against the handlers.
    pub enable_auth: Arc<AtomicBool>,
    /// Counter shared with [`Context::auth_requests`]; starts at zero and is
    /// polled by [`FronteggMockServer::wait_for_auth`].
    pub auth_requests: Arc<Mutex<u64>>,
    /// Sending half of the unbounded channel whose receiving half is stored in
    /// [`Context::role_updates_rx`]. Items are `(String, Vec<String>)` pairs,
    /// presumably `(user email, new role list)` -- confirm against
    /// `role_update_middleware`.
    pub role_updates_tx: UnboundedSender<(String, Vec<String>)>,
    /// Join handle of the spawned task that drives the HTTP server.
    pub handle: JoinHandle<Result<(), std::io::Error>>,
}
73
74impl FronteggMockServer {
75    /// Starts a [`FronteggMockServer`], must be started from within a [`tokio::runtime::Runtime`].
76    pub async fn start(
77        addr: Option<&SocketAddr>,
78        issuer: String,
79        encoding_key: EncodingKey,
80        decoding_key: DecodingKey,
81        users: BTreeMap<String, UserConfig>,
82        tenant_api_tokens: BTreeMap<ApiToken, TenantApiTokenConfig>,
83        role_permissions: Option<BTreeMap<String, Vec<String>>>,
84        now: NowFn,
85        expires_in_secs: i64,
86        latency: Option<Duration>,
87        roles: Option<Vec<UserRole>>,
88    ) -> Result<FronteggMockServer, anyhow::Error> {
89        let (role_updates_tx, role_updates_rx) = unbounded_channel();
90
91        let enable_auth = Arc::new(AtomicBool::new(true));
92        let refreshes = Arc::new(Mutex::new(0u64));
93        let auth_requests = Arc::new(Mutex::new(0u64));
94
95        let user_api_tokens: BTreeMap<ApiToken, String> = users
96            .iter()
97            .map(|(email, user)| {
98                user.initial_api_tokens
99                    .iter()
100                    .map(|token| (token.clone(), email.clone()))
101            })
102            .flatten()
103            .collect();
104        let role_permissions = role_permissions.unwrap_or_else(|| {
105            BTreeMap::from([
106                (
107                    "MaterializePlatformAdmin".to_owned(),
108                    vec![
109                        "materialize.environment.write".to_owned(),
110                        "materialize.invoice.read".to_owned(),
111                    ],
112                ),
113                (
114                    "MaterializePlatform".to_owned(),
115                    vec!["materialize.environment.read".to_owned()],
116                ),
117            ])
118        });
119
120        // Provide default roles if None is provided
121        let roles = roles.unwrap_or_else(|| {
122            vec![
123                UserRole {
124                    id: uuid::Uuid::new_v4().to_string(),
125                    name: "Organization Admin".to_string(),
126                    key: "MaterializePlatformAdmin".to_string(),
127                },
128                UserRole {
129                    id: uuid::Uuid::new_v4().to_string(),
130                    name: "Organization Member".to_string(),
131                    key: "MaterializePlatform".to_string(),
132                },
133            ]
134        });
135
136        let context = Arc::new(Context {
137            issuer,
138            encoding_key,
139            decoding_key,
140            users: Mutex::new(users),
141            user_api_tokens: Mutex::new(user_api_tokens),
142            tenant_api_tokens: Mutex::new(tenant_api_tokens),
143            role_updates_rx: Mutex::new(role_updates_rx),
144            role_permissions,
145            now,
146            expires_in_secs,
147            latency,
148            refresh_tokens: Mutex::new(BTreeMap::new()),
149            refreshes: Arc::clone(&refreshes),
150            enable_auth: Arc::clone(&enable_auth),
151            auth_requests: Arc::clone(&auth_requests),
152            roles: Arc::new(roles),
153            sso_configs: Mutex::new(BTreeMap::new()),
154            groups: Mutex::new(BTreeMap::new()),
155            scim_configurations: Mutex::new(BTreeMap::new()),
156        });
157
158        let router = Router::new()
159            .route(AUTH_API_TOKEN_PATH, post(handle_post_auth_api_token))
160            .route(AUTH_USER_PATH, post(handle_post_auth_user))
161            .route(AUTH_API_TOKEN_REFRESH_PATH, post(handle_post_token_refresh))
162            .route(USERS_ME_PATH, get(handle_get_user_profile))
163            .route(
164                USERS_API_TOKENS_PATH,
165                get(handle_list_user_api_tokens).post(handle_post_user_api_token),
166            )
167            .route(USER_API_TOKENS_PATH, delete(handle_delete_user_api_token))
168            .route(
169                TENANT_API_TOKENS_PATH,
170                get(handle_list_tenant_api_tokens).post(handle_create_tenant_api_token),
171            )
172            .route(
173                TENANT_API_TOKEN_PATH,
174                delete(handle_delete_tenant_api_token),
175            )
176            .route(USER_PATH, get(handle_get_user).delete(handle_delete_user))
177            .route(USER_CREATE_PATH, post(handle_create_user))
178            .route(USERS_V3_PATH, get(handle_get_users_v3))
179            .route(MEMBERS_PATH, put(handle_update_user_roles))
180            .route(ROLES_PATH, get(handle_roles_request))
181            .route(
182                SSO_CONFIGS_PATH,
183                get(handle_list_sso_configs).post(handle_create_sso_config),
184            )
185            .route(
186                SSO_CONFIG_PATH,
187                get(handle_get_sso_config)
188                    .patch(handle_update_sso_config)
189                    .delete(handle_delete_sso_config),
190            )
191            .route(
192                SSO_CONFIG_DOMAINS_PATH,
193                get(handle_list_domains).post(handle_create_domain),
194            )
195            .route(
196                SSO_CONFIG_DOMAIN_PATH,
197                get(handle_get_domain)
198                    .patch(handle_update_domain)
199                    .delete(handle_delete_domain),
200            )
201            .route(
202                SSO_CONFIG_GROUPS_PATH,
203                get(handle_list_group_mappings).post(handle_create_group_mapping),
204            )
205            .route(
206                SSO_CONFIG_GROUP_PATH,
207                get(handle_get_group_mapping)
208                    .patch(handle_update_group_mapping)
209                    .delete(handle_delete_group_mapping),
210            )
211            .route(
212                SSO_CONFIG_ROLES_PATH,
213                get(handle_get_default_roles).put(handle_set_default_roles),
214            )
215            .route(
216                GROUPS_PATH,
217                get(handle_list_groups).post(handle_create_group),
218            )
219            .route(
220                GROUP_PATH,
221                get(handle_get_group)
222                    .patch(handle_update_group)
223                    .delete(handle_delete_group),
224            )
225            .route(GROUP_PATH_WITH_SLASH, get(handle_get_group))
226            .route(
227                GROUP_ROLES_PATH,
228                post(handle_add_roles_to_group).delete(handle_remove_roles_from_group),
229            )
230            .route(
231                GROUP_USERS_PATH,
232                post(handle_add_users_to_group).delete(handle_remove_users_from_group),
233            )
234            .route(
235                SCIM_CONFIGURATIONS_PATH,
236                get(handle_list_scim_configurations).post(handle_create_scim_configuration),
237            )
238            .route(
239                SCIM_CONFIGURATION_PATH,
240                delete(handle_delete_scim_configuration),
241            )
242            .route(
243                INTERNAL_USER_PASSWORD_PATH,
244                post(internal_handle_get_user_password),
245            )
246            .layer(middleware::from_fn(logging_middleware))
247            .layer(middleware::from_fn_with_state(
248                Arc::clone(&context),
249                latency_middleware,
250            ))
251            .layer(middleware::from_fn_with_state(
252                Arc::clone(&context),
253                role_update_middleware,
254            ))
255            .with_state(context);
256
257        let addr = match addr {
258            Some(addr) => Cow::Borrowed(addr),
259            None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
260        };
261        let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
262            panic!("error binding to {}: {}", addr, e);
263        });
264        let base_url = format!("http://{}", listener.local_addr().unwrap());
265        let server = axum::serve(
266            listener,
267            router.into_make_service_with_connect_info::<SocketAddr>(),
268        );
269        let handle = mz_ore::task::spawn(|| "mzcloud-mock-server", server.into_future());
270
271        Ok(FronteggMockServer {
272            base_url,
273            refreshes,
274            enable_auth,
275            auth_requests,
276            role_updates_tx,
277            handle,
278        })
279    }
280
281    pub fn wait_for_auth(&self, expires_in_secs: u64) {
282        let expected = *self.auth_requests.lock().unwrap() + 1;
283        Retry::default()
284            .factor(1.0)
285            .max_duration(Duration::from_secs(expires_in_secs + 20))
286            .retry(|_| {
287                let refreshes = *self.auth_requests.lock().unwrap();
288                if refreshes >= expected {
289                    Ok(())
290                } else {
291                    Err(format!(
292                        "expected refresh count {}, got {}",
293                        expected, refreshes
294                    ))
295                }
296            })
297            .unwrap();
298    }
299
300    pub fn auth_api_token_url(&self) -> String {
301        format!("{}{}", &self.base_url, AUTH_API_TOKEN_PATH)
302    }
303}
304
/// Shared state of the mock server.
///
/// Built once in [`FronteggMockServer::start`], wrapped in an `Arc`, and handed
/// to the router as its state and to the latency and role-update middleware.
/// Mutable collections are individually wrapped in a `Mutex`.
pub struct Context {
    /// Issuer string supplied by the caller of `start`; presumably used as the
    /// issuer claim of tokens the mock mints -- confirm against the handlers.
    pub issuer: String,
    /// JWT encoding key supplied by the caller of `start`.
    pub encoding_key: EncodingKey,
    /// JWT decoding key supplied by the caller of `start`.
    pub decoding_key: DecodingKey,
    /// Known users, keyed by email (the key is bound as `email` in `start`).
    pub users: Mutex<BTreeMap<String, UserConfig>>,
    /// User API token -> email of the owning user. Seeded in `start` from
    /// every user's `initial_api_tokens`.
    pub user_api_tokens: Mutex<BTreeMap<ApiToken, String>>,
    /// Tenant API tokens and their configuration, as supplied to `start`.
    pub tenant_api_tokens: Mutex<BTreeMap<ApiToken, TenantApiTokenConfig>>,
    /// Receiving half of the channel whose sender is exposed as
    /// `FronteggMockServer::role_updates_tx`. Presumably drained by
    /// `role_update_middleware` -- confirm there.
    pub role_updates_rx: Mutex<UnboundedReceiver<(String, Vec<String>)>>,
    /// Permission names granted per role. The built-in default is keyed by the
    /// same strings used as `UserRole::key` in the default roles.
    pub role_permissions: BTreeMap<String, Vec<String>>,
    /// Clock function supplied by the caller of `start`.
    pub now: NowFn,
    /// Value supplied by the caller of `start`; presumably the lifetime in
    /// seconds of issued tokens -- confirm against the handlers.
    pub expires_in_secs: i64,
    /// Optional duration supplied by the caller of `start`; presumably the
    /// artificial delay applied by `latency_middleware` -- confirm there.
    pub latency: Option<Duration>,
    /// Refresh-token table; starts empty. Keys are presumably the refresh
    /// token strings -- confirm against the handlers.
    pub refresh_tokens: Mutex<BTreeMap<String, RefreshTokenTarget>>,
    /// Counter shared with `FronteggMockServer::refreshes`; starts at zero.
    pub refreshes: Arc<Mutex<u64>>,
    /// Flag shared with `FronteggMockServer::enable_auth`; starts as `true`.
    pub enable_auth: Arc<AtomicBool>,
    /// Counter shared with `FronteggMockServer::auth_requests`; starts at zero.
    pub auth_requests: Arc<Mutex<u64>>,
    /// Roles known to the mock; immutable after startup.
    pub roles: Arc<Vec<UserRole>>,
    /// SSO configurations; starts empty. Presumably keyed by configuration id.
    pub sso_configs: Mutex<BTreeMap<String, SSOConfigStorage>>,
    /// Groups; starts empty. Presumably keyed by group id.
    pub groups: Mutex<BTreeMap<String, Group>>,
    /// SCIM configurations; starts empty. Presumably keyed by configuration id.
    pub scim_configurations: Mutex<BTreeMap<String, SCIM2ConfigurationStorage>>,
}