Skip to main content

mz_auth/
group_claims.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Shared JWT group-claim extraction used by both OIDC and Frontegg
11//! authenticators. Encapsulates the dot-separated claim-path resolution and
12//! the array-vs-string normalization so the two authenticators behave
13//! identically.
14
15use std::collections::{BTreeMap, BTreeSet};
16
17use tracing::warn;
18
19/// Extracts group names from a JWT's unknown-claims map.
20///
21/// `claim_path` may be a bare claim name (e.g. `"groups"`) or a
22/// dot-separated path into nested JSON objects (e.g.
23/// `"customClaims.groups"`). Keys that contain a literal `.` are not
24/// reachable; this is a known limitation matching CockroachDB's
25/// `group_claim` semantics. Empty path segments (leading/trailing/double
26/// dots, or an empty path) yield `None` and emit a `warn!`-level log so
27/// misconfiguration is visible.
28///
29/// Returns `None` if the claim is absent (skip sync, preserve current state),
30/// `Some(vec![])` if the claim is present but empty (revoke all sync-granted
31/// roles), or `Some(vec![...])` with deduplicated, sorted group names
32/// (exact case preserved — matching against catalog role names is
33/// case-sensitive).
34///
35/// Accepts arrays of strings, single strings, or mixed arrays (non-string
36/// elements are filtered out). Other JSON types are treated as absent.
37pub fn extract_groups(
38    claims: &BTreeMap<String, serde_json::Value>,
39    claim_path: &str,
40) -> Option<Vec<String>> {
41    let value = resolve_claim_path(claims, claim_path)?;
42
43    let raw_groups: Vec<String> = match value {
44        serde_json::Value::Array(arr) => arr
45            .iter()
46            .filter_map(|v| v.as_str().map(String::from))
47            .collect(),
48        serde_json::Value::String(s) => {
49            if s.is_empty() {
50                vec![]
51            } else {
52                vec![s.clone()]
53            }
54        }
55        _ => {
56            warn!(
57                claim_path,
58                "JWT group claim has unexpected type; skipping group sync"
59            );
60            return None;
61        }
62    };
63
64    let groups: Vec<String> = raw_groups
65        .into_iter()
66        .filter(|g| !g.is_empty())
67        .collect::<BTreeSet<_>>()
68        .into_iter()
69        .collect();
70
71    Some(groups)
72}
73
74/// Walks a dot-separated claim path into nested JSON objects. Returns
75/// `None` if the path is empty, any segment is empty, an intermediate
76/// segment is missing, or an intermediate segment resolves to a
77/// non-object value.
78fn resolve_claim_path<'a>(
79    claims: &'a BTreeMap<String, serde_json::Value>,
80    claim_path: &str,
81) -> Option<&'a serde_json::Value> {
82    let mut segments = claim_path.split('.');
83    let first = segments
84        .next()
85        .expect("str::split always yields at least one segment");
86    if first.is_empty() {
87        warn!(
88            claim_path,
89            "JWT group claim path has an empty segment; skipping group sync"
90        );
91        return None;
92    }
93    let mut current = claims.get(first)?;
94    for segment in segments {
95        if segment.is_empty() {
96            warn!(
97                claim_path,
98                "JWT group claim path has an empty segment; skipping group sync"
99            );
100            return None;
101        }
102        let obj = match current {
103            serde_json::Value::Object(map) => map,
104            _ => {
105                warn!(
106                    claim_path,
107                    segment,
108                    "JWT group claim intermediate segment is not an object; skipping group sync"
109                );
110                return None;
111            }
112        };
113        current = obj.get(segment)?;
114    }
115    Some(current)
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON object literal into the claims-map shape that
    /// `extract_groups` consumes. Panics if `json` is not valid.
    fn claims_from(json: &str) -> BTreeMap<String, serde_json::Value> {
        serde_json::from_str(json).unwrap()
    }

    /// Runs `extract_groups` against the claims described by `json`.
    fn extract(json: &str, path: &str) -> Option<Vec<String>> {
        extract_groups(&claims_from(json), path)
    }

    /// Builds the expected `Some(vec![...])` value from string literals.
    fn some_groups(names: &[&str]) -> Option<Vec<String>> {
        Some(names.iter().map(|name| name.to_string()).collect())
    }

    // --- Basic shapes of a top-level claim ---

    #[mz_ore::test]
    fn test_groups_array() {
        assert_eq!(
            extract(r#"{"groups":["analytics","platform_eng"]}"#, "groups"),
            some_groups(&["analytics", "platform_eng"])
        );
    }

    #[mz_ore::test]
    fn test_groups_single_string() {
        assert_eq!(
            extract(r#"{"groups":"analytics"}"#, "groups"),
            some_groups(&["analytics"])
        );
    }

    #[mz_ore::test]
    fn test_groups_missing() {
        assert_eq!(extract(r#"{"other":"x"}"#, "groups"), None);
    }

    #[mz_ore::test]
    fn test_groups_empty_array() {
        assert_eq!(extract(r#"{"groups":[]}"#, "groups"), some_groups(&[]));
    }

    #[mz_ore::test]
    fn test_groups_empty_string() {
        assert_eq!(extract(r#"{"groups":""}"#, "groups"), some_groups(&[]));
    }

    #[mz_ore::test]
    fn test_groups_mixed_case_preserved() {
        assert_eq!(
            extract(
                r#"{"groups":["Analytics","PLATFORM_ENG","analytics"]}"#,
                "groups"
            ),
            some_groups(&["Analytics", "PLATFORM_ENG", "analytics"])
        );
    }

    #[mz_ore::test]
    fn test_groups_non_string_filtered() {
        assert_eq!(
            extract(r#"{"groups":["valid",123,true,"also_valid"]}"#, "groups"),
            some_groups(&["also_valid", "valid"])
        );
    }

    #[mz_ore::test]
    fn test_groups_array_all_non_strings() {
        assert_eq!(
            extract(r#"{"groups":[1,2,true,null]}"#, "groups"),
            some_groups(&[])
        );
    }

    #[mz_ore::test]
    fn test_groups_non_array_non_string() {
        assert_eq!(extract(r#"{"groups":42}"#, "groups"), None);
    }

    #[mz_ore::test]
    fn test_groups_null() {
        assert_eq!(extract(r#"{"groups":null}"#, "groups"), None);
    }

    #[mz_ore::test]
    fn test_groups_object_claim() {
        assert_eq!(extract(r#"{"groups":{"team":"eng"}}"#, "groups"), None);
    }

    #[mz_ore::test]
    fn test_groups_dedup_sorted() {
        assert_eq!(
            extract(r#"{"groups":["zebra","alpha","alpha","beta"]}"#, "groups"),
            some_groups(&["alpha", "beta", "zebra"])
        );
    }

    // --- Nested (dot-separated) claim paths ---

    #[mz_ore::test]
    fn test_groups_nested_path_array() {
        assert_eq!(
            extract(
                r#"{"customClaims":{"groups":["analytics","platform_eng"]}}"#,
                "customClaims.groups"
            ),
            some_groups(&["analytics", "platform_eng"])
        );
    }

    #[mz_ore::test]
    fn test_groups_nested_path_deep() {
        assert_eq!(
            extract(r#"{"a":{"b":{"c":["eng"]}}}"#, "a.b.c"),
            some_groups(&["eng"])
        );
    }

    #[mz_ore::test]
    fn test_groups_nested_path_missing() {
        assert_eq!(
            extract(r#"{"customClaims":{}}"#, "customClaims.groups"),
            None
        );
    }

    #[mz_ore::test]
    fn test_groups_nested_intermediate_not_object() {
        assert_eq!(
            extract(r#"{"customClaims":"not_an_object"}"#, "customClaims.groups"),
            None
        );
    }

    // --- Malformed claim paths ---

    #[mz_ore::test]
    fn test_groups_path_leading_dot() {
        assert_eq!(extract(r#"{"groups":["eng"]}"#, ".groups"), None);
    }

    #[mz_ore::test]
    fn test_groups_path_trailing_dot() {
        assert_eq!(
            extract(
                r#"{"customClaims":{"groups":["eng"]}}"#,
                "customClaims.groups."
            ),
            None
        );
    }

    #[mz_ore::test]
    fn test_groups_path_double_dot() {
        assert_eq!(
            extract(
                r#"{"customClaims":{"groups":["eng"]}}"#,
                "customClaims..groups"
            ),
            None
        );
    }

    #[mz_ore::test]
    fn test_groups_path_empty() {
        assert_eq!(extract(r#"{"groups":["eng"]}"#, ""), None);
    }

    // --- Unsupported scalar claim types ---

    #[mz_ore::test]
    fn test_groups_boolean_claim() {
        assert_eq!(extract(r#"{"groups":true}"#, "groups"), None);
    }

    #[mz_ore::test]
    fn test_groups_float_claim() {
        assert_eq!(extract(r#"{"groups":3.14}"#, "groups"), None);
    }

    // --- Arrays containing elements that must be dropped ---

    #[mz_ore::test]
    fn test_groups_array_with_nested_arrays() {
        assert_eq!(
            extract(r#"{"groups":[["nested"],"valid"]}"#, "groups"),
            some_groups(&["valid"])
        );
    }

    #[mz_ore::test]
    fn test_groups_array_with_null_elements() {
        assert_eq!(
            extract(r#"{"groups":["eng",null,"ops",null]}"#, "groups"),
            some_groups(&["eng", "ops"])
        );
    }

    #[mz_ore::test]
    fn test_groups_array_with_object_elements() {
        assert_eq!(
            extract(r#"{"groups":["eng",{"name":"ops"},"analytics"]}"#, "groups"),
            some_groups(&["analytics", "eng"])
        );
    }

    #[mz_ore::test]
    fn test_groups_array_with_empty_strings() {
        assert_eq!(
            extract(r#"{"groups":["","eng",""]}"#, "groups"),
            some_groups(&["eng"])
        );
    }

    // --- Names are passed through verbatim ---

    #[mz_ore::test]
    fn test_groups_whitespace_only_single_string() {
        assert_eq!(
            extract(r#"{"groups":"  "}"#, "groups"),
            some_groups(&["  "])
        );
    }

    #[mz_ore::test]
    fn test_groups_whitespace_names() {
        assert_eq!(
            extract(r#"{"groups":["  spaces  ","eng"]}"#, "groups"),
            some_groups(&["  spaces  ", "eng"])
        );
    }

    #[mz_ore::test]
    fn test_groups_unicode_names() {
        assert_eq!(
            extract(r#"{"groups":["Développeurs","INGÉNIEURS"]}"#, "groups"),
            some_groups(&["Développeurs", "INGÉNIEURS"])
        );
    }

    #[mz_ore::test]
    fn test_groups_special_characters() {
        assert_eq!(
            extract(
                r#"{"groups":["team-platform.eng","org_data-science","role/admin"]}"#,
                "groups"
            ),
            some_groups(&["org_data-science", "role/admin", "team-platform.eng"])
        );
    }

    #[mz_ore::test]
    fn test_groups_no_case_folding() {
        assert_eq!(
            extract(r#"{"groups":["Eng","eng","ENG","eNg"]}"#, "groups"),
            some_groups(&["ENG", "Eng", "eNg", "eng"])
        );
    }

    #[mz_ore::test]
    fn test_groups_large_array() {
        let quoted = (0..100)
            .map(|i| format!("\"group_{}\"", i))
            .collect::<Vec<_>>()
            .join(",");
        let result = extract(&format!(r#"{{"groups":[{}]}}"#, quoted), "groups").unwrap();

        assert_eq!(result.len(), 100);
        // Ordering is lexicographic, so "group_99" (not "group_100"-style
        // numeric order) is what lands in the final slot.
        assert_eq!(result.first().map(String::as_str), Some("group_0"));
        assert_eq!(result.last().map(String::as_str), Some("group_99"));
    }

    // --- Nested paths ending in non-array values ---

    #[mz_ore::test]
    fn test_groups_nested_path_single_string() {
        assert_eq!(
            extract(
                r#"{"customClaims":{"groups":"analytics"}}"#,
                "customClaims.groups"
            ),
            some_groups(&["analytics"])
        );
    }

    #[mz_ore::test]
    fn test_groups_nested_path_terminal_not_array_or_string() {
        assert_eq!(
            extract(r#"{"customClaims":{"groups":42}}"#, "customClaims.groups"),
            None
        );
    }
}