mz_frontegg_mock/handlers/
sso.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use crate::models::*;
11use crate::server::Context;
12use axum::{
13    Json,
14    extract::{Path, State},
15    http::StatusCode,
16};
17use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
18use chrono::Utc;
19use std::sync::Arc;
20use uuid::Uuid;
21
22// https://docs.frontegg.com/reference/ssoconfigurationcontrollerv1_getssoconfigurations
23pub async fn handle_list_sso_configs(
24    State(context): State<Arc<Context>>,
25) -> Result<Json<Vec<SSOConfigResponse>>, StatusCode> {
26    let configs = context.sso_configs.lock().unwrap();
27    let config_list: Vec<SSOConfigResponse> = configs
28        .values()
29        .cloned()
30        .map(SSOConfigResponse::from)
31        .collect();
32    Ok(Json(config_list))
33}
34
35// https://docs.frontegg.com/reference/ssoconfigurationcontrollerv1_createssoconfiguration
36pub async fn handle_create_sso_config(
37    State(context): State<Arc<Context>>,
38    Json(new_config): Json<SSOConfigCreateRequest>,
39) -> Result<(StatusCode, Json<SSOConfigResponse>), StatusCode> {
40    let config_storage = SSOConfigStorage {
41        id: Uuid::new_v4().to_string(),
42        enabled: new_config.enabled,
43        sso_endpoint: new_config.sso_endpoint.unwrap_or_default(),
44        public_certificate: new_config
45            .public_certificate
46            .map(|cert| BASE64.encode(cert.as_bytes()))
47            .unwrap_or_default(),
48        sign_request: new_config.sign_request,
49        acs_url: new_config.acs_url.unwrap_or_default(),
50        sp_entity_id: new_config.sp_entity_id.unwrap_or_default(),
51        config_type: new_config.config_type.unwrap_or_else(|| "saml".to_string()),
52        oidc_client_id: new_config.oidc_client_id.unwrap_or_default(),
53        oidc_secret: new_config.oidc_secret.unwrap_or_default(),
54        domains: Vec::new(),
55        groups: Vec::new(),
56        default_roles: DefaultRoles {
57            role_ids: Vec::new(),
58        },
59        generated_verification: Some(Uuid::new_v4().to_string()),
60        created_at: Some(Utc::now()),
61        updated_at: Some(Utc::now()),
62        config_metadata: None,
63        override_active_tenant: Some(true),
64        sub_account_access_limit: Some(0),
65        skip_email_domain_validation: Some(false),
66        role_ids: Vec::new(),
67    };
68
69    let mut configs = context.sso_configs.lock().unwrap();
70    configs.insert(config_storage.id.clone(), config_storage.clone());
71
72    let response = SSOConfigResponse::from(config_storage);
73    Ok((StatusCode::CREATED, Json(response)))
74}
75
76pub async fn handle_get_sso_config(
77    State(context): State<Arc<Context>>,
78    Path(id): Path<String>,
79) -> Result<Json<SSOConfigResponse>, StatusCode> {
80    let configs = context.sso_configs.lock().unwrap();
81    configs
82        .get(&id)
83        .cloned()
84        .map(SSOConfigResponse::from)
85        .map(Json)
86        .ok_or(StatusCode::NOT_FOUND)
87}
88
89pub async fn handle_update_sso_config(
90    State(context): State<Arc<Context>>,
91    Path(id): Path<String>,
92    Json(updated_config): Json<SSOConfigUpdateRequest>,
93) -> Result<Json<SSOConfigResponse>, StatusCode> {
94    let mut configs = context.sso_configs.lock().unwrap();
95    if let Some(config) = configs.get_mut(&id) {
96        if let Some(enabled) = updated_config.enabled {
97            config.enabled = enabled;
98        }
99        if let Some(sso_endpoint) = updated_config.sso_endpoint {
100            config.sso_endpoint = sso_endpoint;
101        }
102        if let Some(public_certificate) = updated_config.public_certificate {
103            config.public_certificate = BASE64.encode(public_certificate.as_bytes());
104        }
105        if let Some(sign_request) = updated_config.sign_request {
106            config.sign_request = sign_request;
107        }
108        if let Some(acs_url) = updated_config.acs_url {
109            config.acs_url = acs_url;
110        }
111        if let Some(sp_entity_id) = updated_config.sp_entity_id {
112            config.sp_entity_id = sp_entity_id;
113        }
114        if let Some(config_type) = updated_config.config_type {
115            config.config_type = config_type;
116        }
117        if let Some(oidc_client_id) = updated_config.oidc_client_id {
118            config.oidc_client_id = oidc_client_id;
119        }
120        if let Some(oidc_secret) = updated_config.oidc_secret {
121            config.oidc_secret = oidc_secret;
122        }
123
124        config.updated_at = Some(Utc::now());
125
126        let response = SSOConfigResponse::from(config.clone());
127        Ok(Json(response))
128    } else {
129        Err(StatusCode::NOT_FOUND)
130    }
131}
132
133// https://docs.frontegg.com/reference/ssoconfigurationcontrollerv1_deletessoconfiguration
134pub async fn handle_delete_sso_config(
135    State(context): State<Arc<Context>>,
136    Path(id): Path<String>,
137) -> StatusCode {
138    let mut configs = context.sso_configs.lock().unwrap();
139    if configs.remove(&id).is_some() {
140        StatusCode::OK
141    } else {
142        StatusCode::NOT_FOUND
143    }
144}
145
146pub async fn handle_list_domains(
147    State(context): State<Arc<Context>>,
148    Path(config_id): Path<String>,
149) -> Result<Json<Vec<DomainResponse>>, StatusCode> {
150    let configs = context.sso_configs.lock().unwrap();
151    if let Some(config) = configs.get(&config_id) {
152        let domains: Vec<DomainResponse> = config
153            .domains
154            .iter()
155            .cloned()
156            .map(DomainResponse::from)
157            .collect();
158        Ok(Json(domains))
159    } else {
160        Err(StatusCode::NOT_FOUND)
161    }
162}
163
164// https://docs.frontegg.com/reference/ssodomaincontrollerv1_createssodomain
165pub async fn handle_create_domain(
166    State(context): State<Arc<Context>>,
167    Path(config_id): Path<String>,
168    Json(mut new_domain): Json<Domain>,
169) -> Result<Json<DomainResponse>, StatusCode> {
170    let mut configs = context.sso_configs.lock().unwrap();
171    if let Some(config) = configs.get_mut(&config_id) {
172        new_domain.id = Uuid::new_v4().to_string();
173        new_domain.sso_config_id = config_id;
174        new_domain.validated = false;
175        config.domains.push(new_domain.clone());
176        Ok(Json(DomainResponse::from(new_domain)))
177    } else {
178        Err(StatusCode::NOT_FOUND)
179    }
180}
181
182// https://docs.frontegg.com/reference/ssorolescontrollerv1_getssodefaultroles
183pub async fn handle_get_default_roles(
184    State(context): State<Arc<Context>>,
185    Path(config_id): Path<String>,
186) -> Result<Json<DefaultRoles>, StatusCode> {
187    let configs = context.sso_configs.lock().unwrap();
188    if let Some(config) = configs.get(&config_id) {
189        Ok(Json(config.default_roles.clone()))
190    } else {
191        Err(StatusCode::NOT_FOUND)
192    }
193}
194
195// https://docs.frontegg.com/reference/ssorolescontrollerv1_setssodefaultroles
196pub async fn handle_set_default_roles(
197    State(context): State<Arc<Context>>,
198    Path(config_id): Path<String>,
199    Json(default_roles): Json<DefaultRoles>,
200) -> Result<(StatusCode, Json<DefaultRoles>), StatusCode> {
201    let mut configs = context.sso_configs.lock().unwrap();
202    if let Some(config) = configs.get_mut(&config_id) {
203        config.default_roles = default_roles.clone();
204        for role_id in &default_roles.role_ids {
205            if !config.role_ids.contains(role_id) {
206                config.role_ids.push(role_id.clone());
207            }
208        }
209        Ok((StatusCode::CREATED, Json(default_roles)))
210    } else {
211        Err(StatusCode::NOT_FOUND)
212    }
213}
214
215pub async fn handle_get_domain(
216    State(context): State<Arc<Context>>,
217    Path((config_id, domain_id)): Path<(String, String)>,
218) -> Result<Json<DomainResponse>, StatusCode> {
219    let configs = context.sso_configs.lock().unwrap();
220    if let Some(config) = configs.get(&config_id) {
221        config
222            .domains
223            .iter()
224            .find(|domain| domain.id == domain_id)
225            .cloned()
226            .map(DomainResponse::from)
227            .map(Json)
228            .ok_or(StatusCode::NOT_FOUND)
229    } else {
230        Err(StatusCode::NOT_FOUND)
231    }
232}
233
234pub async fn handle_update_domain(
235    State(context): State<Arc<Context>>,
236    Path((config_id, domain_id)): Path<(String, String)>,
237    Json(updated_domain): Json<DomainUpdateRequest>,
238) -> Result<Json<DomainResponse>, StatusCode> {
239    let mut configs = context.sso_configs.lock().unwrap();
240    if let Some(config) = configs.get_mut(&config_id) {
241        if let Some(domain) = config.domains.iter_mut().find(|d| d.id == domain_id) {
242            if let Some(new_domain) = updated_domain.domain {
243                domain.domain = new_domain;
244            }
245            if let Some(new_validated) = updated_domain.validated {
246                domain.validated = new_validated;
247            }
248            Ok(Json(DomainResponse::from(domain.clone())))
249        } else {
250            Err(StatusCode::NOT_FOUND)
251        }
252    } else {
253        Err(StatusCode::NOT_FOUND)
254    }
255}
256
257// https://docs.frontegg.com/reference/ssodomaincontrollerv1_deletessodomain
258pub async fn handle_delete_domain(
259    State(context): State<Arc<Context>>,
260    Path((config_id, domain_id)): Path<(String, String)>,
261) -> StatusCode {
262    let mut configs = context.sso_configs.lock().unwrap();
263    if let Some(config) = configs.get_mut(&config_id) {
264        let initial_len = config.domains.len();
265        config.domains.retain(|d| d.id != domain_id);
266        if config.domains.len() < initial_len {
267            StatusCode::OK
268        } else {
269            StatusCode::NOT_FOUND
270        }
271    } else {
272        StatusCode::NOT_FOUND
273    }
274}
275
276// https://docs.frontegg.com/reference/ssogroupscontrollerv1_getssogroup
277pub async fn handle_list_group_mappings(
278    State(context): State<Arc<Context>>,
279    Path(config_id): Path<String>,
280) -> Result<Json<Vec<GroupMappingResponse>>, StatusCode> {
281    let configs = context.sso_configs.lock().unwrap();
282    if let Some(config) = configs.get(&config_id) {
283        let groups: Vec<GroupMappingResponse> = config
284            .groups
285            .iter()
286            .cloned()
287            .map(GroupMappingResponse::from)
288            .collect();
289        Ok(Json(groups))
290    } else {
291        Err(StatusCode::NOT_FOUND)
292    }
293}
294
295// https://docs.frontegg.com/reference/ssogroupscontrollerv1_createssogroup
296pub async fn handle_create_group_mapping(
297    State(context): State<Arc<Context>>,
298    Path(config_id): Path<String>,
299    Json(new_group): Json<GroupMapping>,
300) -> Result<Json<GroupMappingResponse>, StatusCode> {
301    let mut configs = context.sso_configs.lock().unwrap();
302    if let Some(config) = configs.get_mut(&config_id) {
303        let group = GroupMapping {
304            id: Uuid::new_v4().to_string(),
305            group: new_group.group,
306            role_ids: new_group.role_ids,
307            sso_config_id: config_id,
308            enabled: true,
309        };
310        config.groups.push(group.clone());
311        Ok(Json(GroupMappingResponse::from(group)))
312    } else {
313        Err(StatusCode::NOT_FOUND)
314    }
315}
316
317// https://docs.frontegg.com/reference/ssogroupscontrollerv1_getssogroup
318pub async fn handle_get_group_mapping(
319    State(context): State<Arc<Context>>,
320    Path((config_id, group_id)): Path<(String, String)>,
321) -> Result<Json<GroupMappingResponse>, StatusCode> {
322    let configs = context.sso_configs.lock().unwrap();
323    if let Some(config) = configs.get(&config_id) {
324        config
325            .groups
326            .iter()
327            .find(|g| g.id == group_id)
328            .cloned()
329            .map(GroupMappingResponse::from)
330            .map(Json)
331            .ok_or(StatusCode::NOT_FOUND)
332    } else {
333        Err(StatusCode::NOT_FOUND)
334    }
335}
336
337// https://docs.frontegg.com/reference/ssogroupscontrollerv1_updatessogroup
338pub async fn handle_update_group_mapping(
339    State(context): State<Arc<Context>>,
340    Path((config_id, group_id)): Path<(String, String)>,
341    Json(updated_group): Json<GroupMappingUpdateRequest>,
342) -> Result<Json<GroupMappingResponse>, StatusCode> {
343    let mut configs = context.sso_configs.lock().unwrap();
344    if let Some(config) = configs.get_mut(&config_id) {
345        if let Some(group) = config.groups.iter_mut().find(|g| g.id == group_id) {
346            if let Some(new_group) = updated_group.group {
347                group.group = new_group;
348            }
349            if let Some(new_role_ids) = updated_group.role_ids {
350                group.role_ids = new_role_ids;
351            }
352            if let Some(new_enabled) = updated_group.enabled {
353                group.enabled = new_enabled;
354            }
355            Ok(Json(GroupMappingResponse::from(group.clone())))
356        } else {
357            Err(StatusCode::NOT_FOUND)
358        }
359    } else {
360        Err(StatusCode::NOT_FOUND)
361    }
362}
363
364// https://docs.frontegg.com/reference/ssogroupscontrollerv1_deletessogroup
365pub async fn handle_delete_group_mapping(
366    State(context): State<Arc<Context>>,
367    Path((config_id, group_id)): Path<(String, String)>,
368) -> StatusCode {
369    let mut configs = context.sso_configs.lock().unwrap();
370    if let Some(config) = configs.get_mut(&config_id) {
371        let initial_len = config.groups.len();
372        config.groups.retain(|g| g.id != group_id);
373        if config.groups.len() < initial_len {
374            StatusCode::OK
375        } else {
376            StatusCode::NOT_FOUND
377        }
378    } else {
379        StatusCode::NOT_FOUND
380    }
381}