Skip to main content

mz_authenticator/
oidc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC Authentication for pgwire connections.
11//!
12//! This module provides JWT-based authentication using OpenID Connect (OIDC).
13//! JWTs are validated locally using JWKS fetched from the configured provider.
14
15use std::collections::{BTreeMap, BTreeSet};
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use mz_ore::secure::{Zeroize, ZeroizeOnDrop};
24use mz_ore::soft_panic_or_log;
25use mz_pgwire_common::{ErrorResponse, Severity};
26use reqwest::Client as HttpClient;
27use serde::{Deserialize, Deserializer, Serialize};
28use tokio_postgres::error::SqlState;
29
30use tracing::{debug, warn};
31use url::Url;
/// Errors that can occur during OIDC authentication.
///
/// [`OidcError::into_response`] converts an error into a pgwire `ErrorResponse`.
#[derive(Debug)]
pub enum OidcError {
    /// The `oidc_issuer` system variable is not set.
    MissingIssuer,
    /// The configured issuer could not be turned into an OpenID configuration
    /// URL (it failed to parse, or cannot carry path segments). Holds the
    /// issuer string as configured.
    InvalidIssuerUrl(String),
    /// The `oidc_audience` system variable could not be parsed as a list of
    /// strings.
    AudienceParseError,
    /// Failed to fetch from the identity provider: the request failed, the
    /// response had a non-success status, or the body could not be parsed.
    FetchFromProviderFailed {
        /// URL that was being fetched.
        url: String,
        /// Description of the underlying failure.
        error_message: String,
    },
    /// The key ID is missing in the token header.
    MissingKid,
    /// No matching key was found in the JWKS, even after refetching it from the
    /// provider.
    NoMatchingKey {
        /// Key ID that was found in the JWT header.
        key_id: String,
    },
    /// The configured authentication claim is not present in the JWT, or its
    /// value is not a string.
    NoMatchingAuthenticationClaim {
        /// Name of the claim that was expected to hold the username.
        authentication_claim: String,
    },
    /// JWT decoding or validation failed for a reason not covered by a more
    /// specific variant (for example, a malformed header or a bad signature).
    Jwt,
    /// The username extracted from the JWT does not match the expected user
    /// supplied by the caller.
    WrongUser,
    /// The JWT's audience did not match any of the configured audiences.
    InvalidAudience {
        /// Audiences configured via the `oidc_audience` system variable.
        expected_audiences: Vec<String>,
    },
    /// The JWT's issuer does not match the configured issuer.
    InvalidIssuer {
        /// Issuer configured via the `oidc_issuer` system variable.
        expected_issuer: String,
    },
    /// The JWT has expired.
    ExpiredSignature,
    /// The role exists but does not have the LOGIN attribute.
    NonLogin,
    /// An unexpected error occurred while checking whether the role can log in.
    LoginCheckError,
}
69
70impl std::fmt::Display for OidcError {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        match self {
73            OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
74            OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
75            OidcError::AudienceParseError => {
76                write!(f, "failed to parse OIDC_AUDIENCE system variable")
77            }
78            OidcError::FetchFromProviderFailed { .. } => {
79                write!(f, "failed to fetch OIDC provider configuration")
80            }
81            OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
82            OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
83            OidcError::NoMatchingAuthenticationClaim { .. } => {
84                write!(f, "no matching authentication claim found in the JWT")
85            }
86            OidcError::Jwt => write!(f, "failed to validate JWT"),
87            OidcError::WrongUser => write!(f, "wrong user"),
88            OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
89            OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
90            OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
91            OidcError::NonLogin => write!(f, "role is not allowed to login"),
92            OidcError::LoginCheckError => write!(f, "unexpected error checking if role can login"),
93        }
94    }
95}
96
97impl std::error::Error for OidcError {}
98
99impl OidcError {
100    pub fn code(&self) -> SqlState {
101        SqlState::INVALID_AUTHORIZATION_SPECIFICATION
102    }
103
104    pub fn detail(&self) -> Option<String> {
105        match self {
106            OidcError::InvalidIssuerUrl(issuer) => {
107                Some(format!("Could not parse \"{issuer}\" as a URL."))
108            }
109            OidcError::FetchFromProviderFailed { url, error_message } => {
110                Some(format!("Fetching \"{url}\" failed. {error_message}"))
111            }
112            OidcError::NoMatchingKey { key_id } => {
113                Some(format!("JWT key ID \"{key_id}\" was not found."))
114            }
115            OidcError::InvalidAudience { expected_audiences } => Some(format!(
116                "Expected one of audiences {:?} in the JWT.",
117                expected_audiences,
118            )),
119            OidcError::InvalidIssuer { expected_issuer } => {
120                Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
121            }
122            OidcError::NoMatchingAuthenticationClaim {
123                authentication_claim,
124            } => Some(format!(
125                "Expected authentication claim \"{authentication_claim}\" in the JWT.",
126            )),
127            OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
128            _ => None,
129        }
130    }
131
132    pub fn hint(&self) -> Option<String> {
133        match self {
134            OidcError::MissingIssuer => {
135                Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
136            }
137            _ => None,
138        }
139    }
140
141    pub fn into_response(self) -> ErrorResponse {
142        ErrorResponse {
143            severity: Severity::Fatal,
144            code: self.code(),
145            message: self.to_string(),
146            detail: self.detail(),
147            hint: self.hint(),
148            position: None,
149        }
150    }
151}
152
153fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
154where
155    D: Deserializer<'de>,
156{
157    #[derive(Deserialize)]
158    #[serde(untagged)]
159    enum StringOrVec {
160        String(String),
161        Vec(Vec<String>),
162    }
163
164    match StringOrVec::deserialize(deserializer)? {
165        StringOrVec::String(s) => Ok(vec![s]),
166        StringOrVec::Vec(v) => Ok(v),
167    }
168}
169/// Claims extracted from a validated JWT.
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct OidcClaims {
172    /// Issuer.
173    pub iss: String,
174    /// Expiration time (Unix timestamp).
175    pub exp: i64,
176    /// Issued at time (Unix timestamp).
177    #[serde(default)]
178    pub iat: Option<i64>,
179    /// Audience claim (can be single string or array in JWT).
180    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
181    pub aud: Vec<String>,
182    /// Additional claims from the JWT, captured for flexible username extraction.
183    #[serde(flatten)]
184    pub unknown_claims: BTreeMap<String, serde_json::Value>,
185}
186
187impl Zeroize for OidcClaims {
188    fn zeroize(&mut self) {
189        self.iss.zeroize();
190        self.exp.zeroize();
191        self.iat.zeroize();
192        for s in &mut self.aud {
193            s.zeroize();
194        }
195        self.aud.clear();
196        // serde_json::Value doesn't implement Zeroize; drain entries and
197        // zeroize keys/values before the backing allocations are freed.
198        while let Some((mut k, mut v)) = self.unknown_claims.pop_first() {
199            k.zeroize();
200            zeroize_json_value(&mut v);
201        }
202    }
203}
204
205impl Drop for OidcClaims {
206    fn drop(&mut self) {
207        self.zeroize();
208    }
209}
210
211/// `OidcClaims` implements both `Zeroize` and `Drop` (which calls `zeroize()`),
212/// satisfying the `ZeroizeOnDrop` contract.
213impl ZeroizeOnDrop for OidcClaims {}
214
215fn zeroize_json_value(v: &mut serde_json::Value) {
216    use serde_json::Value;
217    match v {
218        Value::String(s) => s.zeroize(),
219        Value::Array(a) => {
220            for item in a.iter_mut() {
221                zeroize_json_value(item);
222            }
223            a.clear();
224        }
225        Value::Object(map) => {
226            let taken = std::mem::take(map);
227            for (mut k, mut nested) in taken {
228                k.zeroize();
229                zeroize_json_value(&mut nested);
230            }
231        }
232        Value::Number(_) => {
233            *v = Value::Number(serde_json::Number::from(0u8));
234        }
235        Value::Bool(b) => *b = false,
236        Value::Null => {}
237    }
238}
239
240impl OidcClaims {
241    /// Extract the username from the OIDC claims.
242    fn user(&self, authentication_claim: &str) -> Option<&str> {
243        self.unknown_claims
244            .get(authentication_claim)
245            .and_then(|value| value.as_str())
246    }
247
248    /// Extracts group names from the specified JWT claim for group-to-role sync.
249    ///
250    /// Returns `None` if the claim is absent (skip sync, preserve current state),
251    /// `Some(vec![])` if the claim is present but empty (revoke all sync-granted
252    /// roles), or `Some(vec![...])` with normalized (lowercased, deduplicated,
253    /// sorted) group names.
254    ///
255    /// Accepts arrays of strings, single strings, or mixed arrays (non-string
256    /// elements are filtered out). Other JSON types are treated as absent.
257    pub fn groups(&self, claim_name: &str) -> Option<Vec<String>> {
258        let value = self.unknown_claims.get(claim_name)?;
259
260        let raw_groups: Vec<String> = match value {
261            serde_json::Value::Array(arr) => arr
262                .iter()
263                .filter_map(|v| v.as_str().map(String::from))
264                .collect(),
265            serde_json::Value::String(s) => {
266                if s.is_empty() {
267                    vec![]
268                } else {
269                    vec![s.clone()]
270                }
271            }
272            _ => {
273                warn!(
274                    claim_name,
275                    "OIDC group claim has unexpected type; skipping group sync"
276                );
277                return None;
278            }
279        };
280
281        let normalized: Vec<String> = raw_groups
282            .into_iter()
283            .map(|g| g.trim().to_lowercase())
284            .filter(|g| !g.is_empty())
285            .collect::<BTreeSet<_>>()
286            .into_iter()
287            .collect();
288
289        Some(normalized)
290    }
291}
292
293#[derive(Zeroize, ZeroizeOnDrop)]
294pub struct ValidatedClaims {
295    pub user: String,
296    // Prevent construction outside of `GenericOidcAuthenticator::validate_token`.
297    _private: (),
298}
299
300/// Wrapper around `jsonwebtoken::DecodingKey` with a redacted `Debug` impl.
301#[derive(Clone)]
302struct OidcDecodingKey(jsonwebtoken::DecodingKey);
303
304impl std::fmt::Debug for OidcDecodingKey {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        f.debug_struct("OidcDecodingKey")
307            .field("key", &"<redacted>")
308            .finish()
309    }
310}
311
312/// OIDC Authenticator that validates JWTs using JWKS.
313///
314/// This implementation pre-fetches JWKS at construction time for synchronous
315/// token validation.
316#[derive(Clone, Debug)]
317pub struct GenericOidcAuthenticator {
318    inner: Arc<GenericOidcAuthenticatorInner>,
319}
320
321/// OpenID Connect Discovery document.
322/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html>
323#[derive(Debug, Deserialize)]
324struct OpenIdConfiguration {
325    /// URL of the JWKS endpoint.
326    jwks_uri: String,
327}
328
329#[derive(Debug)]
330pub struct GenericOidcAuthenticatorInner {
331    adapter_client: AdapterClient,
332    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
333    http_client: HttpClient,
334}
335
336impl GenericOidcAuthenticator {
337    /// Create a new [`GenericOidcAuthenticator`] with an [`AdapterClient`].
338    ///
339    /// The OIDC issuer and audience are fetched from system variables on each
340    /// authentication attempt.
341    pub fn new(adapter_client: AdapterClient) -> Self {
342        let http_client = HttpClient::new();
343
344        Self {
345            inner: Arc::new(GenericOidcAuthenticatorInner {
346                adapter_client,
347                decoding_keys: Mutex::new(BTreeMap::new()),
348                http_client,
349            }),
350        }
351    }
352}
353
354impl GenericOidcAuthenticatorInner {
355    async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
356        let openid_config_url = build_openid_config_url(issuer)?;
357
358        let openid_config_url_str = openid_config_url.to_string();
359
360        // Fetch OpenID configuration to get the JWKS URI
361        let response = self
362            .http_client
363            .get(openid_config_url)
364            .timeout(Duration::from_secs(10))
365            .send()
366            .await
367            .map_err(|e| OidcError::FetchFromProviderFailed {
368                url: openid_config_url_str.clone(),
369                error_message: e.to_string(),
370            })?;
371
372        if !response.status().is_success() {
373            return Err(OidcError::FetchFromProviderFailed {
374                url: openid_config_url_str.clone(),
375                error_message: response
376                    .error_for_status()
377                    .err()
378                    .map(|e| e.to_string())
379                    .unwrap_or_else(|| "Unknown error".to_string()),
380            });
381        }
382
383        let openid_config: OpenIdConfiguration =
384            response
385                .json()
386                .await
387                .map_err(|e| OidcError::FetchFromProviderFailed {
388                    url: openid_config_url_str,
389                    error_message: e.to_string(),
390                })?;
391
392        Ok(openid_config.jwks_uri)
393    }
394
395    /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys.
396    async fn fetch_jwks(
397        &self,
398        issuer: &str,
399    ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
400        let jwks_uri = self.fetch_jwks_uri(issuer).await?;
401        let response = self
402            .http_client
403            .get(&jwks_uri)
404            .timeout(Duration::from_secs(10))
405            .send()
406            .await
407            .map_err(|e| OidcError::FetchFromProviderFailed {
408                url: jwks_uri.clone(),
409                error_message: e.to_string(),
410            })?;
411
412        if !response.status().is_success() {
413            return Err(OidcError::FetchFromProviderFailed {
414                url: jwks_uri.clone(),
415                error_message: response
416                    .error_for_status()
417                    .err()
418                    .map(|e| e.to_string())
419                    .unwrap_or_else(|| "Unknown error".to_string()),
420            });
421        }
422
423        let jwks: JwkSet =
424            response
425                .json()
426                .await
427                .map_err(|e| OidcError::FetchFromProviderFailed {
428                    url: jwks_uri.clone(),
429                    error_message: e.to_string(),
430                })?;
431
432        let mut keys = BTreeMap::new();
433
434        for jwk in jwks.keys {
435            match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
436                Ok(key) => {
437                    if let Some(kid) = jwk.common.key_id {
438                        keys.insert(kid, OidcDecodingKey(key));
439                    }
440                }
441                Err(e) => {
442                    warn!("Failed to parse JWK: {}", e);
443                }
444            }
445        }
446
447        Ok(keys)
448    }
449
450    /// Find a decoding key matching the given key ID.
451    /// If the key is not found, fetch the JWKS and cache the keys.
452    async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
453        // Get the cached decoding key.
454        {
455            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
456
457            if let Some(key) = decoding_keys.get(kid) {
458                return Ok(key.clone());
459            }
460        }
461
462        // If not found, fetch the JWKS and cache the keys.
463        let new_decoding_keys = self.fetch_jwks(issuer).await?;
464
465        let decoding_key = new_decoding_keys.get(kid).cloned();
466
467        {
468            let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
469            *decoding_keys = new_decoding_keys;
470        }
471
472        if let Some(key) = decoding_key {
473            return Ok(key);
474        }
475
476        {
477            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
478            debug!(
479                "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
480            );
481            Err(OidcError::NoMatchingKey {
482                key_id: kid.to_string(),
483            })
484        }
485    }
486
487    pub async fn validate_token(
488        &self,
489        token: &str,
490        expected_user: Option<&str>,
491    ) -> Result<ValidatedClaims, OidcError> {
492        // Fetch current OIDC configuration from system variables
493        let system_vars = self.adapter_client.get_system_vars().await;
494        let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
495            return Err(OidcError::MissingIssuer);
496        };
497
498        let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
499
500        let expected_audiences: Vec<String> = {
501            let audiences: Vec<String> =
502                serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
503                    .map_err(|_| OidcError::AudienceParseError)?;
504
505            if audiences.is_empty() {
506                warn!(
507                    "Audience validation skipped. It is discouraged \
508                    to skip audience validation since it allows anyone \
509                    with a JWT issued by the same issuer to authenticate."
510                );
511            }
512            audiences
513        };
514
515        // Decode header to get key ID (kid) and the
516        // decoding algorithm
517        let header = jsonwebtoken::decode_header(token).map_err(|e| {
518            debug!("Failed to decode JWT header: {:?}", e);
519            OidcError::Jwt
520        })?;
521
522        let kid = header.kid.ok_or(OidcError::MissingKid)?;
523        // Find the matching key from our set of cached keys. If not found,
524        // fetch the JWKS from the provider and cache the keys
525        let decoding_key = self.find_key(&kid, &issuer).await?;
526
527        // Set up audience and issuer validation
528        let mut validation = jsonwebtoken::Validation::new(header.alg);
529        validation.set_issuer(&[&issuer]);
530        if !expected_audiences.is_empty() {
531            validation.set_audience(&expected_audiences);
532        } else {
533            validation.validate_aud = false;
534        }
535
536        // Decode and validate the token
537        let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
538            .map_err(|e| match e.kind() {
539                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
540                    if !expected_audiences.is_empty() {
541                        OidcError::InvalidAudience {
542                            expected_audiences
543                        }
544                    } else {
545                        soft_panic_or_log!(
546                            "received an audience validation error when audience validation is disabled"
547                        );
548                        OidcError::Jwt
549                    }
550                }
551                jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
552                    expected_issuer: issuer.clone(),
553                },
554                jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
555                _ => OidcError::Jwt,
556            })?;
557
558        let user = token_data.claims.user(&authentication_claim).ok_or(
559            OidcError::NoMatchingAuthenticationClaim {
560                authentication_claim,
561            },
562        )?;
563
564        // Optionally validate the expected user
565        if let Some(expected) = expected_user {
566            if user != expected {
567                return Err(OidcError::WrongUser);
568            }
569        }
570
571        Ok(ValidatedClaims {
572            user: user.to_string(),
573            _private: (),
574        })
575    }
576
577    /// Checks whether the role has the LOGIN attribute. This is needed otherwise
578    /// a user can authenticate with an OIDC token to a role that isn't recognized
579    /// as a user.
580    async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
581        match self.adapter_client.role_can_login(role_name).await {
582            Ok(()) => Ok(()),
583            Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
584                // Role will be auto-provisioned during startup; allow login.
585                Ok(())
586            }
587            Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
588                Err(OidcError::NonLogin)
589            }
590            Err(e) => {
591                warn!(?e, "unexpected error checking OIDC role login");
592                Err(OidcError::LoginCheckError)
593            }
594        }
595    }
596}
597
598impl GenericOidcAuthenticator {
599    pub async fn authenticate(
600        &self,
601        token: &str,
602        expected_user: Option<&str>,
603    ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
604        let validated_claims = self.inner.validate_token(token, expected_user).await?;
605        self.inner.check_role_login(&validated_claims.user).await?;
606        Ok((validated_claims, Authenticated))
607    }
608}
609
610fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
611    let mut openid_config_url =
612        Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
613    {
614        let mut segments = openid_config_url
615            .path_segments_mut()
616            .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
617        // Remove trailing slash if it exists
618        segments.pop_if_empty();
619        segments.push(".well-known");
620        segments.push("openid-configuration");
621    }
622    Ok(openid_config_url)
623}
624#[cfg(test)]
625mod tests {
626    use super::*;
627
628    #[mz_ore::test]
629    fn test_aud_single_string() {
630        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
631        let claims: OidcClaims = serde_json::from_str(json).unwrap();
632        assert_eq!(claims.aud, vec!["my-app"]);
633    }
634
635    #[mz_ore::test]
636    fn test_aud_array() {
637        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
638        let claims: OidcClaims = serde_json::from_str(json).unwrap();
639        assert_eq!(claims.aud, vec!["app1", "app2"]);
640    }
641
642    #[mz_ore::test]
643    fn test_user() {
644        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#;
645        let claims: OidcClaims = serde_json::from_str(json).unwrap();
646        assert_eq!(claims.user("sub"), Some("user-123"));
647        assert_eq!(claims.user("email"), Some("alice@example.com"));
648        assert_eq!(claims.user("missing"), None);
649    }
650
651    #[mz_ore::test]
652    fn test_build_openid_config_url() {
653        let issuer = "https://dev-123456.okta.com/oauth2/default";
654        let url = build_openid_config_url(issuer).unwrap();
655        assert_eq!(
656            url.to_string(),
657            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
658        );
659    }
660
661    #[mz_ore::test]
662    fn test_build_openid_config_url_trailing_slash() {
663        let issuer = "https://dev-123456.okta.com/oauth2/default/";
664        let url = build_openid_config_url(issuer).unwrap();
665        assert_eq!(
666            url.to_string(),
667            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
668        );
669    }
670
671    #[mz_ore::test]
672    fn zeroize_clears_validated_claims() {
673        use mz_ore::secure::Zeroize;
674        let mut claims = ValidatedClaims {
675            user: "alice@example.com".to_string(),
676            _private: (),
677        };
678        claims.zeroize();
679        assert!(claims.user.is_empty());
680    }
681
682    #[mz_ore::test]
683    fn oidc_claims_implements_zeroize_on_drop() {
684        fn assert_zod<T: ZeroizeOnDrop>() {}
685        assert_zod::<OidcClaims>();
686        assert_zod::<ValidatedClaims>();
687    }
688
689    #[mz_ore::test]
690    fn zeroize_clears_oidc_claims() {
691        use mz_ore::secure::Zeroize;
692        let mut claims = OidcClaims {
693            iss: "https://issuer.example.com".to_string(),
694            exp: 1234567890,
695            iat: Some(1234567800),
696            aud: vec!["app1".to_string(), "app2".to_string()],
697            unknown_claims: BTreeMap::from([(
698                "email".to_string(),
699                serde_json::Value::String("alice@example.com".to_string()),
700            )]),
701        };
702        claims.zeroize();
703        assert!(claims.iss.is_empty());
704        assert_eq!(claims.exp, 0);
705        assert!(claims.iat.is_none());
706        assert!(claims.aud.is_empty());
707        assert!(claims.unknown_claims.is_empty());
708    }
709
710    #[mz_ore::test]
711    fn test_groups_array() {
712        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["analytics","platform_eng"]}"#;
713        let claims: OidcClaims = serde_json::from_str(json).unwrap();
714        assert_eq!(
715            claims.groups("groups"),
716            Some(vec!["analytics".to_string(), "platform_eng".to_string()])
717        );
718    }
719
720    #[mz_ore::test]
721    fn test_groups_single_string() {
722        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":"analytics"}"#;
723        let claims: OidcClaims = serde_json::from_str(json).unwrap();
724        assert_eq!(claims.groups("groups"), Some(vec!["analytics".to_string()]));
725    }
726
727    #[mz_ore::test]
728    fn test_groups_missing() {
729        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app"}"#;
730        let claims: OidcClaims = serde_json::from_str(json).unwrap();
731        assert_eq!(claims.groups("groups"), None);
732    }
733
734    #[mz_ore::test]
735    fn test_groups_empty_array() {
736        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[]}"#;
737        let claims: OidcClaims = serde_json::from_str(json).unwrap();
738        assert_eq!(claims.groups("groups"), Some(vec![]));
739    }
740
741    #[mz_ore::test]
742    fn test_groups_mixed_case() {
743        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["Analytics","PLATFORM_ENG","analytics"]}"#;
744        let claims: OidcClaims = serde_json::from_str(json).unwrap();
745        assert_eq!(
746            claims.groups("groups"),
747            Some(vec!["analytics".to_string(), "platform_eng".to_string()])
748        );
749    }
750
751    #[mz_ore::test]
752    fn test_groups_custom_claim_name() {
753        let json =
754            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","roles":["admin","viewer"]}"#;
755        let claims: OidcClaims = serde_json::from_str(json).unwrap();
756        assert_eq!(
757            claims.groups("roles"),
758            Some(vec!["admin".to_string(), "viewer".to_string()])
759        );
760        assert_eq!(claims.groups("groups"), None);
761    }
762
763    #[mz_ore::test]
764    fn test_groups_non_string_values_in_array() {
765        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["valid",123,true,"also_valid"]}"#;
766        let claims: OidcClaims = serde_json::from_str(json).unwrap();
767        assert_eq!(
768            claims.groups("groups"),
769            Some(vec!["also_valid".to_string(), "valid".to_string()])
770        );
771    }
772
773    #[mz_ore::test]
774    fn test_groups_non_array_non_string() {
775        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":42}"#;
776        let claims: OidcClaims = serde_json::from_str(json).unwrap();
777        assert_eq!(claims.groups("groups"), None);
778    }
779
780    #[mz_ore::test]
781    fn test_groups_empty_string() {
782        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":""}"#;
783        let claims: OidcClaims = serde_json::from_str(json).unwrap();
784        assert_eq!(claims.groups("groups"), Some(vec![]));
785    }
786
787    #[mz_ore::test]
788    fn test_groups_null_claim() {
789        // Explicit null → treated as absent (not a valid group representation)
790        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":null}"#;
791        let claims: OidcClaims = serde_json::from_str(json).unwrap();
792        assert_eq!(claims.groups("groups"), None);
793    }
794
795    #[mz_ore::test]
796    fn test_groups_boolean_claim() {
797        // Boolean value → not array or string, treated as absent
798        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":true}"#;
799        let claims: OidcClaims = serde_json::from_str(json).unwrap();
800        assert_eq!(claims.groups("groups"), None);
801    }
802
803    #[mz_ore::test]
804    fn test_groups_object_claim() {
805        // JSON object → not array or string, treated as absent
806        let json =
807            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":{"team":"eng"}}"#;
808        let claims: OidcClaims = serde_json::from_str(json).unwrap();
809        assert_eq!(claims.groups("groups"), None);
810    }
811
812    #[mz_ore::test]
813    fn test_groups_array_all_non_strings() {
814        // Array with zero valid string elements → Some(vec![]), not None
815        // (the claim *is* present, it just has no usable group names)
816        let json =
817            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[1,2,true,null]}"#;
818        let claims: OidcClaims = serde_json::from_str(json).unwrap();
819        assert_eq!(claims.groups("groups"), Some(vec![]));
820    }
821
822    #[mz_ore::test]
823    fn test_groups_array_with_nested_arrays() {
824        // Nested arrays are not strings, so they're filtered out
825        let json =
826            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[["nested"],"valid"]}"#;
827        let claims: OidcClaims = serde_json::from_str(json).unwrap();
828        assert_eq!(claims.groups("groups"), Some(vec!["valid".to_string()]));
829    }
830
831    #[mz_ore::test]
832    fn test_groups_array_with_empty_strings() {
833        // Empty strings are not valid role names and are filtered out,
834        // consistent with the single-string case where "" → Some(vec![]).
835        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["","eng",""]}"#;
836        let claims: OidcClaims = serde_json::from_str(json).unwrap();
837        assert_eq!(claims.groups("groups"), Some(vec!["eng".to_string()]));
838    }
839
840    #[mz_ore::test]
841    fn test_groups_whitespace_only_single_string() {
842        // Whitespace-only string trims to empty and is filtered out.
843        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":"  "}"#;
844        let claims: OidcClaims = serde_json::from_str(json).unwrap();
845        assert_eq!(claims.groups("groups"), Some(vec![]));
846    }
847
848    #[mz_ore::test]
849    fn test_groups_whitespace_names() {
850        // Leading/trailing whitespace is trimmed from group names.
851        let json =
852            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["  spaces  ","eng"]}"#;
853        let claims: OidcClaims = serde_json::from_str(json).unwrap();
854        assert_eq!(
855            claims.groups("groups"),
856            Some(vec!["eng".to_string(), "spaces".to_string()])
857        );
858    }
859
860    #[mz_ore::test]
861    fn test_groups_unicode_names() {
862        // Unicode group names should be lowercased correctly
863        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["Développeurs","INGÉNIEURS"]}"#;
864        let claims: OidcClaims = serde_json::from_str(json).unwrap();
865        assert_eq!(
866            claims.groups("groups"),
867            Some(vec!["développeurs".to_string(), "ingénieurs".to_string()])
868        );
869    }
870
871    #[mz_ore::test]
872    fn test_groups_special_characters() {
873        // Group names with special characters (hyphens, underscores, dots)
874        // are common in enterprise IdPs like Azure AD / Okta
875        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["team-platform.eng","org_data-science","role/admin"]}"#;
876        let claims: OidcClaims = serde_json::from_str(json).unwrap();
877        assert_eq!(
878            claims.groups("groups"),
879            Some(vec![
880                "org_data-science".to_string(),
881                "role/admin".to_string(),
882                "team-platform.eng".to_string(),
883            ])
884        );
885    }
886
887    #[mz_ore::test]
888    fn test_groups_case_insensitive_dedup() {
889        // "Eng" and "eng" and "ENG" should all collapse to one "eng"
890        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["Eng","eng","ENG","eNg"]}"#;
891        let claims: OidcClaims = serde_json::from_str(json).unwrap();
892        assert_eq!(claims.groups("groups"), Some(vec!["eng".to_string()]));
893    }
894
895    #[mz_ore::test]
896    fn test_groups_large_array() {
897        // Verify we handle a reasonably large group list without issues
898        let groups: Vec<String> = (0..100).map(|i| format!("\"group_{}\"", i)).collect();
899        let json = format!(
900            r#"{{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[{}]}}"#,
901            groups.join(",")
902        );
903        let claims: OidcClaims = serde_json::from_str(&json).unwrap();
904        let result = claims.groups("groups").unwrap();
905        assert_eq!(result.len(), 100);
906        // BTreeSet ensures sorted order
907        assert_eq!(result[0], "group_0");
908        assert_eq!(result[99], "group_99");
909    }
910
911    #[mz_ore::test]
912    fn test_groups_float_claim() {
913        // Float value → not array or string, treated as absent
914        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":3.14}"#;
915        let claims: OidcClaims = serde_json::from_str(json).unwrap();
916        assert_eq!(claims.groups("groups"), None);
917    }
918
919    #[mz_ore::test]
920    fn test_groups_array_with_null_elements() {
921        // Null elements in array are not strings, filtered out
922        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["eng",null,"ops",null]}"#;
923        let claims: OidcClaims = serde_json::from_str(json).unwrap();
924        assert_eq!(
925            claims.groups("groups"),
926            Some(vec!["eng".to_string(), "ops".to_string()])
927        );
928    }
929
930    #[mz_ore::test]
931    fn test_groups_array_with_object_elements() {
932        // Object elements in array are not strings, filtered out
933        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["eng",{"name":"ops"},"analytics"]}"#;
934        let claims: OidcClaims = serde_json::from_str(json).unwrap();
935        assert_eq!(
936            claims.groups("groups"),
937            Some(vec!["analytics".to_string(), "eng".to_string()])
938        );
939    }
940
941    #[mz_ore::test]
942    fn test_groups_sorted_output() {
943        // Verify output is sorted alphabetically regardless of input order
944        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["zebra","alpha","mango","beta"]}"#;
945        let claims: OidcClaims = serde_json::from_str(json).unwrap();
946        assert_eq!(
947            claims.groups("groups"),
948            Some(vec![
949                "alpha".to_string(),
950                "beta".to_string(),
951                "mango".to_string(),
952                "zebra".to_string(),
953            ])
954        );
955    }
956}