mz_frontegg_mock/
server.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use crate::utils::RefreshTokenTarget;
11use axum::routing::{delete, get, post, put};
12use axum::{Router, middleware};
13use jsonwebtoken::{DecodingKey, EncodingKey};
14use mz_ore::now::NowFn;
15use mz_ore::retry::Retry;
16use mz_ore::task::JoinHandle;
17use std::borrow::Cow;
18use std::collections::BTreeMap;
19use std::future::IntoFuture;
20use std::net::{IpAddr, Ipv4Addr, SocketAddr};
21use std::sync::atomic::AtomicBool;
22use std::sync::{Arc, Mutex};
23use std::time::Duration;
24use tokio::net::TcpListener;
25use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
26
27use crate::handlers::*;
28use crate::middleware::*;
29use crate::models::*;
30
// ---------------------------------------------------------------------------
// Authentication routes.
// ---------------------------------------------------------------------------

/// API-token login (served by `handle_post_auth_api_token`).
const AUTH_API_TOKEN_PATH: &str = "/identity/resources/auth/v1/api-token";
/// Token refresh (served by `handle_post_token_refresh`).
const AUTH_API_TOKEN_REFRESH_PATH: &str =
    "/identity/resources/auth/v1/api-token/token/refresh";
/// User login (served by `handle_post_auth_user`).
const AUTH_USER_PATH: &str = "/identity/resources/auth/v1/user";

// ---------------------------------------------------------------------------
// User, member and role routes.
// ---------------------------------------------------------------------------

/// Profile of the calling user.
const USERS_ME_PATH: &str = "/identity/resources/users/v2/me";
/// A single user, addressed by `:id`.
const USER_PATH: &str = "/identity/resources/users/v1/:id";
/// User creation.
const USER_CREATE_PATH: &str = "/identity/resources/users/v2";
/// User listing (v3 API).
const USERS_V3_PATH: &str = "/identity/resources/users/v3";
/// Team members (role updates).
const MEMBERS_PATH: &str = "/frontegg/team/resources/members/v1";
/// Available roles.
const ROLES_PATH: &str = "/identity/resources/roles/v2";

// ---------------------------------------------------------------------------
// API token routes.
// ---------------------------------------------------------------------------

/// Collection of user API tokens.
const USERS_API_TOKENS_PATH: &str = "/identity/resources/users/api-tokens/v1";
/// A single user API token, addressed by `:id`.
const USER_API_TOKENS_PATH: &str = "/identity/resources/users/api-tokens/v1/:id";
/// Collection of tenant API tokens.
const TENANT_API_TOKENS_PATH: &str = "/identity/resources/tenants/api-tokens/v1";
/// A single tenant API token, addressed by `:id`.
const TENANT_API_TOKEN_PATH: &str = "/identity/resources/tenants/api-tokens/v1/:id";

// ---------------------------------------------------------------------------
// Group routes.
// ---------------------------------------------------------------------------

/// Collection of groups.
const GROUPS_PATH: &str = "/frontegg/identity/resources/groups/v1";
/// A single group, addressed by `:id`.
const GROUP_PATH: &str = "/frontegg/identity/resources/groups/v1/:id";
/// Same as [`GROUP_PATH`] but with a trailing slash, routed separately.
const GROUP_PATH_WITH_SLASH: &str = "/frontegg/identity/resources/groups/v1/:id/";
/// Roles attached to a group.
const GROUP_ROLES_PATH: &str = "/frontegg/identity/resources/groups/v1/:id/roles";
/// Users belonging to a group.
const GROUP_USERS_PATH: &str = "/frontegg/identity/resources/groups/v1/:id/users";

// ---------------------------------------------------------------------------
// SSO configuration routes.
// ---------------------------------------------------------------------------

/// Collection of SSO configurations.
const SSO_CONFIGS_PATH: &str = "/frontegg/team/resources/sso/v1/configurations";
/// A single SSO configuration, addressed by `:id`.
const SSO_CONFIG_PATH: &str = "/frontegg/team/resources/sso/v1/configurations/:id";
/// Domains of an SSO configuration.
const SSO_CONFIG_DOMAINS_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/:id/domains";
/// A single domain of an SSO configuration.
const SSO_CONFIG_DOMAIN_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/:id/domains/:domain_id";
/// Group mappings of an SSO configuration.
const SSO_CONFIG_GROUPS_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/:id/groups";
/// A single group mapping of an SSO configuration.
const SSO_CONFIG_GROUP_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/:id/groups/:group_id";
/// Default roles of an SSO configuration.
const SSO_CONFIG_ROLES_PATH: &str =
    "/frontegg/team/resources/sso/v1/configurations/:id/roles";

// ---------------------------------------------------------------------------
// SCIM configuration routes.
// ---------------------------------------------------------------------------

/// Collection of SCIM 2 configurations.
const SCIM_CONFIGURATIONS_PATH: &str = "/frontegg/directory/resources/v1/configurations/scim2";
/// A single SCIM 2 configuration, addressed by `:id`.
const SCIM_CONFIGURATION_PATH: &str =
    "/frontegg/directory/resources/v1/configurations/scim2/:id";

// ---------------------------------------------------------------------------
// Internal endpoints for testing (not part of the real Frontegg API).
// ---------------------------------------------------------------------------

/// Test-only helper route (served by `internal_handle_get_user_password`).
const INTERNAL_USER_PASSWORD_PATH: &str = "/api/internal-mock/user-password";
63
/// Handle to a running Frontegg mock server, returned by [`FronteggMockServer::start`].
///
/// The `Arc`-wrapped fields are the same allocations stored in the server's shared
/// [`Context`], so a test holding this handle can observe and influence the server
/// while it runs.
pub struct FronteggMockServer {
    /// Base URL of the server, formatted as `http://<bound local address>`.
    pub base_url: String,
    /// Counter shared with `Context::refreshes`; starts at 0. Presumably bumped by the
    /// token-refresh handler — TODO confirm in `handlers`.
    pub refreshes: Arc<Mutex<u64>>,
    /// Flag shared with `Context::enable_auth`; starts as `true`.
    pub enable_auth: Arc<AtomicBool>,
    /// Counter shared with `Context::auth_requests`; starts at 0 and is polled by
    /// `wait_for_auth`.
    pub auth_requests: Arc<Mutex<u64>>,
    /// Sending half of the role-update channel; the receiving half is stored in
    /// `Context::role_updates_rx`.
    pub role_updates_tx: UnboundedSender<(String, Vec<String>)>,
    /// Handle of the spawned server task; resolves with the result of `axum::serve`.
    pub handle: JoinHandle<Result<(), std::io::Error>>,
}
72
73impl FronteggMockServer {
74    /// Starts a [`FronteggMockServer`], must be started from within a [`tokio::runtime::Runtime`].
75    pub async fn start(
76        addr: Option<&SocketAddr>,
77        issuer: String,
78        encoding_key: EncodingKey,
79        decoding_key: DecodingKey,
80        users: BTreeMap<String, UserConfig>,
81        tenant_api_tokens: BTreeMap<ApiToken, TenantApiTokenConfig>,
82        role_permissions: Option<BTreeMap<String, Vec<String>>>,
83        now: NowFn,
84        expires_in_secs: i64,
85        latency: Option<Duration>,
86        roles: Option<Vec<UserRole>>,
87    ) -> Result<FronteggMockServer, anyhow::Error> {
88        let (role_updates_tx, role_updates_rx) = unbounded_channel();
89
90        let enable_auth = Arc::new(AtomicBool::new(true));
91        let refreshes = Arc::new(Mutex::new(0u64));
92        let auth_requests = Arc::new(Mutex::new(0u64));
93
94        let user_api_tokens: BTreeMap<ApiToken, String> = users
95            .iter()
96            .map(|(email, user)| {
97                user.initial_api_tokens
98                    .iter()
99                    .map(|token| (token.clone(), email.clone()))
100            })
101            .flatten()
102            .collect();
103        let role_permissions = role_permissions.unwrap_or_else(|| {
104            BTreeMap::from([
105                (
106                    "MaterializePlatformAdmin".to_owned(),
107                    vec![
108                        "materialize.environment.write".to_owned(),
109                        "materialize.invoice.read".to_owned(),
110                    ],
111                ),
112                (
113                    "MaterializePlatform".to_owned(),
114                    vec!["materialize.environment.read".to_owned()],
115                ),
116            ])
117        });
118
119        // Provide default roles if None is provided
120        let roles = roles.unwrap_or_else(|| {
121            vec![
122                UserRole {
123                    id: uuid::Uuid::new_v4().to_string(),
124                    name: "Organization Admin".to_string(),
125                    key: "MaterializePlatformAdmin".to_string(),
126                },
127                UserRole {
128                    id: uuid::Uuid::new_v4().to_string(),
129                    name: "Organization Member".to_string(),
130                    key: "MaterializePlatform".to_string(),
131                },
132            ]
133        });
134
135        let context = Arc::new(Context {
136            issuer,
137            encoding_key,
138            decoding_key,
139            users: Mutex::new(users),
140            user_api_tokens: Mutex::new(user_api_tokens),
141            tenant_api_tokens: Mutex::new(tenant_api_tokens),
142            role_updates_rx: Mutex::new(role_updates_rx),
143            role_permissions,
144            now,
145            expires_in_secs,
146            latency,
147            refresh_tokens: Mutex::new(BTreeMap::new()),
148            refreshes: Arc::clone(&refreshes),
149            enable_auth: Arc::clone(&enable_auth),
150            auth_requests: Arc::clone(&auth_requests),
151            roles: Arc::new(roles),
152            sso_configs: Mutex::new(BTreeMap::new()),
153            groups: Mutex::new(BTreeMap::new()),
154            scim_configurations: Mutex::new(BTreeMap::new()),
155        });
156
157        let router = Router::new()
158            .route(AUTH_API_TOKEN_PATH, post(handle_post_auth_api_token))
159            .route(AUTH_USER_PATH, post(handle_post_auth_user))
160            .route(AUTH_API_TOKEN_REFRESH_PATH, post(handle_post_token_refresh))
161            .route(USERS_ME_PATH, get(handle_get_user_profile))
162            .route(
163                USERS_API_TOKENS_PATH,
164                get(handle_list_user_api_tokens).post(handle_post_user_api_token),
165            )
166            .route(USER_API_TOKENS_PATH, delete(handle_delete_user_api_token))
167            .route(
168                TENANT_API_TOKENS_PATH,
169                get(handle_list_tenant_api_tokens).post(handle_create_tenant_api_token),
170            )
171            .route(
172                TENANT_API_TOKEN_PATH,
173                delete(handle_delete_tenant_api_token),
174            )
175            .route(USER_PATH, get(handle_get_user).delete(handle_delete_user))
176            .route(USER_CREATE_PATH, post(handle_create_user))
177            .route(USERS_V3_PATH, get(handle_get_users_v3))
178            .route(MEMBERS_PATH, put(handle_update_user_roles))
179            .route(ROLES_PATH, get(handle_roles_request))
180            .route(
181                SSO_CONFIGS_PATH,
182                get(handle_list_sso_configs).post(handle_create_sso_config),
183            )
184            .route(
185                SSO_CONFIG_PATH,
186                get(handle_get_sso_config)
187                    .patch(handle_update_sso_config)
188                    .delete(handle_delete_sso_config),
189            )
190            .route(
191                SSO_CONFIG_DOMAINS_PATH,
192                get(handle_list_domains).post(handle_create_domain),
193            )
194            .route(
195                SSO_CONFIG_DOMAIN_PATH,
196                get(handle_get_domain)
197                    .patch(handle_update_domain)
198                    .delete(handle_delete_domain),
199            )
200            .route(
201                SSO_CONFIG_GROUPS_PATH,
202                get(handle_list_group_mappings).post(handle_create_group_mapping),
203            )
204            .route(
205                SSO_CONFIG_GROUP_PATH,
206                get(handle_get_group_mapping)
207                    .patch(handle_update_group_mapping)
208                    .delete(handle_delete_group_mapping),
209            )
210            .route(
211                SSO_CONFIG_ROLES_PATH,
212                get(handle_get_default_roles).put(handle_set_default_roles),
213            )
214            .route(
215                GROUPS_PATH,
216                get(handle_list_groups).post(handle_create_group),
217            )
218            .route(
219                GROUP_PATH,
220                get(handle_get_group)
221                    .patch(handle_update_group)
222                    .delete(handle_delete_group),
223            )
224            .route(GROUP_PATH_WITH_SLASH, get(handle_get_group))
225            .route(
226                GROUP_ROLES_PATH,
227                post(handle_add_roles_to_group).delete(handle_remove_roles_from_group),
228            )
229            .route(
230                GROUP_USERS_PATH,
231                post(handle_add_users_to_group).delete(handle_remove_users_from_group),
232            )
233            .route(
234                SCIM_CONFIGURATIONS_PATH,
235                get(handle_list_scim_configurations).post(handle_create_scim_configuration),
236            )
237            .route(
238                SCIM_CONFIGURATION_PATH,
239                delete(handle_delete_scim_configuration),
240            )
241            .route(
242                INTERNAL_USER_PASSWORD_PATH,
243                post(internal_handle_get_user_password),
244            )
245            .layer(middleware::from_fn(logging_middleware))
246            .layer(middleware::from_fn_with_state(
247                Arc::clone(&context),
248                latency_middleware,
249            ))
250            .layer(middleware::from_fn_with_state(
251                Arc::clone(&context),
252                role_update_middleware,
253            ))
254            .with_state(context);
255
256        let addr = match addr {
257            Some(addr) => Cow::Borrowed(addr),
258            None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
259        };
260        let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
261            panic!("error binding to {}: {}", addr, e);
262        });
263        let base_url = format!("http://{}", listener.local_addr().unwrap());
264        let server = axum::serve(
265            listener,
266            router.into_make_service_with_connect_info::<SocketAddr>(),
267        );
268        let handle = mz_ore::task::spawn(|| "mzcloud-mock-server", server.into_future());
269
270        Ok(FronteggMockServer {
271            base_url,
272            refreshes,
273            enable_auth,
274            auth_requests,
275            role_updates_tx,
276            handle,
277        })
278    }
279
280    pub fn wait_for_auth(&self, expires_in_secs: u64) {
281        let expected = *self.auth_requests.lock().unwrap() + 1;
282        Retry::default()
283            .factor(1.0)
284            .max_duration(Duration::from_secs(expires_in_secs + 20))
285            .retry(|_| {
286                let refreshes = *self.auth_requests.lock().unwrap();
287                if refreshes >= expected {
288                    Ok(())
289                } else {
290                    Err(format!(
291                        "expected refresh count {}, got {}",
292                        expected, refreshes
293                    ))
294                }
295            })
296            .unwrap();
297    }
298
299    pub fn auth_api_token_url(&self) -> String {
300        format!("{}{}", &self.base_url, AUTH_API_TOKEN_PATH)
301    }
302}
303
/// State shared by every route handler and middleware of the mock server.
///
/// A single instance is built by [`FronteggMockServer::start`], wrapped in an `Arc`,
/// and installed as the router state. Mutable stores are behind `Mutex`es so handlers
/// can update them concurrently.
pub struct Context {
    /// Issuer string supplied to `start`; presumably written into minted tokens — TODO
    /// confirm in `handlers`.
    pub issuer: String,
    /// Key for encoding (signing) tokens.
    pub encoding_key: EncodingKey,
    /// Key for decoding (verifying) tokens.
    pub decoding_key: DecodingKey,
    /// Users keyed by the same string `start` treats as the user's email.
    pub users: Mutex<BTreeMap<String, UserConfig>>,
    /// User API tokens mapped to the owning user's key. Seeded in `start` from every
    /// user's `initial_api_tokens`.
    pub user_api_tokens: Mutex<BTreeMap<ApiToken, String>>,
    /// Tenant API tokens and their configuration.
    pub tenant_api_tokens: Mutex<BTreeMap<ApiToken, TenantApiTokenConfig>>,
    /// Receiving half of the role-update channel; the sender is exposed as
    /// `FronteggMockServer::role_updates_tx`. Presumably drained by
    /// `role_update_middleware` — TODO confirm.
    pub role_updates_rx: Mutex<UnboundedReceiver<(String, Vec<String>)>>,
    /// Role key to the permissions that role grants.
    pub role_permissions: BTreeMap<String, Vec<String>>,
    /// Clock used by the mock.
    pub now: NowFn,
    /// Value supplied to `start`; presumably the lifetime of issued tokens in seconds —
    /// TODO confirm.
    pub expires_in_secs: i64,
    /// Optional artificial latency, read by `latency_middleware`.
    pub latency: Option<Duration>,
    /// Refresh tokens known to the mock; starts empty.
    pub refresh_tokens: Mutex<BTreeMap<String, RefreshTokenTarget>>,
    /// Counter shared with `FronteggMockServer::refreshes`.
    pub refreshes: Arc<Mutex<u64>>,
    /// Flag shared with `FronteggMockServer::enable_auth`; starts as `true`.
    pub enable_auth: Arc<AtomicBool>,
    /// Counter shared with `FronteggMockServer::auth_requests`.
    pub auth_requests: Arc<Mutex<u64>>,
    /// Roles known to the mock.
    pub roles: Arc<Vec<UserRole>>,
    /// SSO configurations keyed by id; starts empty.
    pub sso_configs: Mutex<BTreeMap<String, SSOConfigStorage>>,
    /// Groups keyed by id; starts empty.
    pub groups: Mutex<BTreeMap<String, Group>>,
    /// SCIM 2 configurations keyed by id; starts empty.
    pub scim_configurations: Mutex<BTreeMap<String, SCIM2ConfigurationStorage>>,
}