Skip to main content

mz_authenticator/
oidc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC Authentication for pgwire connections.
11//!
12//! This module provides JWT-based authentication using OpenID Connect (OIDC).
13//! JWTs are validated locally using JWKS fetched from the configured provider.
14
15use std::collections::{BTreeMap, BTreeSet};
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{
22    OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_GROUP_CLAIM, OIDC_ISSUER,
23};
24use mz_auth::Authenticated;
25use mz_ore::secure::{Zeroize, ZeroizeOnDrop};
26use mz_ore::soft_panic_or_log;
27use mz_pgwire_common::{ErrorResponse, Severity};
28use reqwest::Client as HttpClient;
29use serde::{Deserialize, Deserializer, Serialize};
30use tokio_postgres::error::SqlState;
31
32use tracing::{debug, warn};
33use url::Url;
/// Errors that can occur while authenticating a pgwire connection with an
/// OIDC-issued JWT.
///
/// Every variant is reported to the client with SQLSTATE
/// `INVALID_AUTHORIZATION_SPECIFICATION`; see [`OidcError::into_response`].
#[derive(Debug)]
pub enum OidcError {
    /// The `oidc_issuer` system variable is not set.
    MissingIssuer,
    /// The configured issuer could not be parsed as a URL. Holds the raw
    /// issuer string.
    InvalidIssuerUrl(String),
    /// The `oidc_audience` system variable could not be parsed as a list of
    /// strings.
    AudienceParseError,
    /// A request to the identity provider failed.
    FetchFromProviderFailed {
        /// URL that was requested.
        url: String,
        /// Description of the transport, HTTP status, or decoding failure.
        error_message: String,
    },
    /// The JWT header does not carry a key ID (`kid`).
    MissingKid,
    /// The provider's JWKS has no key with the ID named in the JWT header.
    NoMatchingKey {
        /// Key ID taken from the JWT header.
        key_id: String,
    },
    /// The JWT does not contain the configured authentication claim as a
    /// string.
    NoMatchingAuthenticationClaim {
        /// Name of the claim expected to hold the user name.
        authentication_claim: String,
    },
    /// The JWT is malformed or failed validation for a reason not covered by
    /// a more specific variant.
    Jwt,
    /// The user named in the JWT differs from the user the client asked for.
    WrongUser,
    /// The JWT audience matches none of the configured audiences.
    InvalidAudience {
        /// Audiences that would have been accepted.
        expected_audiences: Vec<String>,
    },
    /// The JWT issuer differs from the configured issuer.
    InvalidIssuer {
        /// Issuer that was expected.
        expected_issuer: String,
    },
    /// The JWT failed expiry (`exp`) validation.
    ExpiredSignature,
    /// The role exists but lacks the LOGIN attribute.
    NonLogin,
    /// Checking whether the role may log in failed unexpectedly.
    LoginCheckError,
}
71
72impl std::fmt::Display for OidcError {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        match self {
75            OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
76            OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
77            OidcError::AudienceParseError => {
78                write!(f, "failed to parse OIDC_AUDIENCE system variable")
79            }
80            OidcError::FetchFromProviderFailed { .. } => {
81                write!(f, "failed to fetch OIDC provider configuration")
82            }
83            OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
84            OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
85            OidcError::NoMatchingAuthenticationClaim { .. } => {
86                write!(f, "no matching authentication claim found in the JWT")
87            }
88            OidcError::Jwt => write!(f, "failed to validate JWT"),
89            OidcError::WrongUser => write!(f, "wrong user"),
90            OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
91            OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
92            OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
93            OidcError::NonLogin => write!(f, "role is not allowed to login"),
94            OidcError::LoginCheckError => write!(f, "unexpected error checking if role can login"),
95        }
96    }
97}
98
99impl std::error::Error for OidcError {}
100
101impl OidcError {
102    pub fn code(&self) -> SqlState {
103        SqlState::INVALID_AUTHORIZATION_SPECIFICATION
104    }
105
106    pub fn detail(&self) -> Option<String> {
107        match self {
108            OidcError::InvalidIssuerUrl(issuer) => {
109                Some(format!("Could not parse \"{issuer}\" as a URL."))
110            }
111            OidcError::FetchFromProviderFailed { url, error_message } => {
112                Some(format!("Fetching \"{url}\" failed. {error_message}"))
113            }
114            OidcError::NoMatchingKey { key_id } => {
115                Some(format!("JWT key ID \"{key_id}\" was not found."))
116            }
117            OidcError::InvalidAudience { expected_audiences } => Some(format!(
118                "Expected one of audiences {:?} in the JWT.",
119                expected_audiences,
120            )),
121            OidcError::InvalidIssuer { expected_issuer } => {
122                Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
123            }
124            OidcError::NoMatchingAuthenticationClaim {
125                authentication_claim,
126            } => Some(format!(
127                "Expected authentication claim \"{authentication_claim}\" in the JWT.",
128            )),
129            OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
130            _ => None,
131        }
132    }
133
134    pub fn hint(&self) -> Option<String> {
135        match self {
136            OidcError::MissingIssuer => {
137                Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
138            }
139            _ => None,
140        }
141    }
142
143    pub fn into_response(self) -> ErrorResponse {
144        ErrorResponse {
145            severity: Severity::Fatal,
146            code: self.code(),
147            message: self.to_string(),
148            detail: self.detail(),
149            hint: self.hint(),
150            position: None,
151        }
152    }
153}
154
155fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
156where
157    D: Deserializer<'de>,
158{
159    #[derive(Deserialize)]
160    #[serde(untagged)]
161    enum StringOrVec {
162        String(String),
163        Vec(Vec<String>),
164    }
165
166    match StringOrVec::deserialize(deserializer)? {
167        StringOrVec::String(s) => Ok(vec![s]),
168        StringOrVec::Vec(v) => Ok(v),
169    }
170}
171/// Claims extracted from a validated JWT.
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct OidcClaims {
174    /// Issuer.
175    pub iss: String,
176    /// Expiration time (Unix timestamp).
177    pub exp: i64,
178    /// Issued at time (Unix timestamp).
179    #[serde(default)]
180    pub iat: Option<i64>,
181    /// Audience claim (can be single string or array in JWT).
182    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
183    pub aud: Vec<String>,
184    /// Additional claims from the JWT, captured for flexible username extraction.
185    #[serde(flatten)]
186    pub unknown_claims: BTreeMap<String, serde_json::Value>,
187}
188
189impl Zeroize for OidcClaims {
190    fn zeroize(&mut self) {
191        self.iss.zeroize();
192        self.exp.zeroize();
193        self.iat.zeroize();
194        for s in &mut self.aud {
195            s.zeroize();
196        }
197        self.aud.clear();
198        // serde_json::Value doesn't implement Zeroize; drain entries and
199        // zeroize keys/values before the backing allocations are freed.
200        while let Some((mut k, mut v)) = self.unknown_claims.pop_first() {
201            k.zeroize();
202            zeroize_json_value(&mut v);
203        }
204    }
205}
206
207impl Drop for OidcClaims {
208    fn drop(&mut self) {
209        self.zeroize();
210    }
211}
212
213/// `OidcClaims` implements both `Zeroize` and `Drop` (which calls `zeroize()`),
214/// satisfying the `ZeroizeOnDrop` contract.
215impl ZeroizeOnDrop for OidcClaims {}
216
217fn zeroize_json_value(v: &mut serde_json::Value) {
218    use serde_json::Value;
219    match v {
220        Value::String(s) => s.zeroize(),
221        Value::Array(a) => {
222            for item in a.iter_mut() {
223                zeroize_json_value(item);
224            }
225            a.clear();
226        }
227        Value::Object(map) => {
228            let taken = std::mem::take(map);
229            for (mut k, mut nested) in taken {
230                k.zeroize();
231                zeroize_json_value(&mut nested);
232            }
233        }
234        Value::Number(_) => {
235            *v = Value::Number(serde_json::Number::from(0u8));
236        }
237        Value::Bool(b) => *b = false,
238        Value::Null => {}
239    }
240}
241
242impl OidcClaims {
243    /// Extract the username from the OIDC claims.
244    fn user(&self, authentication_claim: &str) -> Option<&str> {
245        self.unknown_claims
246            .get(authentication_claim)
247            .and_then(|value| value.as_str())
248    }
249
250    /// Extracts group names from the specified JWT claim for group-to-role sync.
251    ///
252    /// `claim_path` may be a bare claim name (e.g. `"groups"`) or a
253    /// dot-separated path into nested JSON objects (e.g.
254    /// `"customClaims.groups"`). Keys that contain a literal `.` are not
255    /// reachable; this is a known limitation matching CockroachDB's
256    /// `group_claim` semantics. Empty path segments (leading/trailing/double
257    /// dots, or an empty path) yield `None` and emit a `warn!`-level log so
258    /// misconfiguration is visible.
259    ///
260    /// Returns `None` if the claim is absent (skip sync, preserve current state),
261    /// `Some(vec![])` if the claim is present but empty (revoke all sync-granted
262    /// roles), or `Some(vec![...])` with deduplicated, sorted group names
263    /// (exact case preserved — matching against catalog role names is
264    /// case-sensitive).
265    ///
266    /// Accepts arrays of strings, single strings, or mixed arrays (non-string
267    /// elements are filtered out). Other JSON types are treated as absent.
268    pub fn groups(&self, claim_path: &str) -> Option<Vec<String>> {
269        let value = self.resolve_claim_path(claim_path)?;
270
271        let raw_groups: Vec<String> = match value {
272            serde_json::Value::Array(arr) => arr
273                .iter()
274                .filter_map(|v| v.as_str().map(String::from))
275                .collect(),
276            serde_json::Value::String(s) => {
277                if s.is_empty() {
278                    vec![]
279                } else {
280                    vec![s.clone()]
281                }
282            }
283            _ => {
284                warn!(
285                    claim_path,
286                    "OIDC group claim has unexpected type; skipping group sync"
287                );
288                return None;
289            }
290        };
291
292        let groups: Vec<String> = raw_groups
293            .into_iter()
294            .filter(|g| !g.is_empty())
295            .collect::<BTreeSet<_>>()
296            .into_iter()
297            .collect();
298
299        Some(groups)
300    }
301
302    /// Walks a dot-separated claim path into nested JSON objects. Returns
303    /// `None` if the path is empty, any segment is empty, an intermediate
304    /// segment is missing, or an intermediate segment resolves to a
305    /// non-object value.
306    fn resolve_claim_path(&self, claim_path: &str) -> Option<&serde_json::Value> {
307        let mut segments = claim_path.split('.');
308        let first = segments
309            .next()
310            .expect("str::split always yields at least one segment");
311        if first.is_empty() {
312            warn!(
313                claim_path,
314                "OIDC group claim path has an empty segment; skipping group sync"
315            );
316            return None;
317        }
318        let mut current = self.unknown_claims.get(first)?;
319        for segment in segments {
320            if segment.is_empty() {
321                warn!(
322                    claim_path,
323                    "OIDC group claim path has an empty segment; skipping group sync"
324                );
325                return None;
326            }
327            let obj = match current {
328                serde_json::Value::Object(map) => map,
329                _ => {
330                    warn!(
331                        claim_path,
332                        segment,
333                        "OIDC group claim intermediate segment is not an object; skipping group sync"
334                    );
335                    return None;
336                }
337            };
338            current = obj.get(segment)?;
339        }
340        Some(current)
341    }
342}
343
344#[derive(Zeroize, ZeroizeOnDrop)]
345pub struct ValidatedClaims {
346    pub user: String,
347    /// Groups extracted from the JWT group claim. None if claim absent.
348    pub groups: Option<Vec<String>>,
349    // Prevent construction outside of `GenericOidcAuthenticator::validate_token`.
350    _private: (),
351}
352
353/// Wrapper around `jsonwebtoken::DecodingKey` with a redacted `Debug` impl.
354#[derive(Clone)]
355struct OidcDecodingKey(jsonwebtoken::DecodingKey);
356
357impl std::fmt::Debug for OidcDecodingKey {
358    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        f.debug_struct("OidcDecodingKey")
360            .field("key", &"<redacted>")
361            .finish()
362    }
363}
364
365/// OIDC Authenticator that validates JWTs using JWKS.
366///
367/// This implementation pre-fetches JWKS at construction time for synchronous
368/// token validation.
369#[derive(Clone, Debug)]
370pub struct GenericOidcAuthenticator {
371    inner: Arc<GenericOidcAuthenticatorInner>,
372}
373
374/// OpenID Connect Discovery document.
375/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html>
376#[derive(Debug, Deserialize)]
377struct OpenIdConfiguration {
378    /// URL of the JWKS endpoint.
379    jwks_uri: String,
380}
381
382#[derive(Debug)]
383pub struct GenericOidcAuthenticatorInner {
384    adapter_client: AdapterClient,
385    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
386    http_client: HttpClient,
387}
388
389impl GenericOidcAuthenticator {
390    /// Create a new [`GenericOidcAuthenticator`] with an [`AdapterClient`].
391    ///
392    /// The OIDC issuer and audience are fetched from system variables on each
393    /// authentication attempt.
394    pub fn new(adapter_client: AdapterClient) -> Self {
395        let http_client = HttpClient::new();
396
397        Self {
398            inner: Arc::new(GenericOidcAuthenticatorInner {
399                adapter_client,
400                decoding_keys: Mutex::new(BTreeMap::new()),
401                http_client,
402            }),
403        }
404    }
405}
406
407impl GenericOidcAuthenticatorInner {
408    async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
409        let openid_config_url = build_openid_config_url(issuer)?;
410
411        let openid_config_url_str = openid_config_url.to_string();
412
413        // Fetch OpenID configuration to get the JWKS URI
414        let response = self
415            .http_client
416            .get(openid_config_url)
417            .timeout(Duration::from_secs(10))
418            .send()
419            .await
420            .map_err(|e| OidcError::FetchFromProviderFailed {
421                url: openid_config_url_str.clone(),
422                error_message: e.to_string(),
423            })?;
424
425        if !response.status().is_success() {
426            return Err(OidcError::FetchFromProviderFailed {
427                url: openid_config_url_str.clone(),
428                error_message: response
429                    .error_for_status()
430                    .err()
431                    .map(|e| e.to_string())
432                    .unwrap_or_else(|| "Unknown error".to_string()),
433            });
434        }
435
436        let openid_config: OpenIdConfiguration =
437            response
438                .json()
439                .await
440                .map_err(|e| OidcError::FetchFromProviderFailed {
441                    url: openid_config_url_str,
442                    error_message: e.to_string(),
443                })?;
444
445        Ok(openid_config.jwks_uri)
446    }
447
448    /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys.
449    async fn fetch_jwks(
450        &self,
451        issuer: &str,
452    ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
453        let jwks_uri = self.fetch_jwks_uri(issuer).await?;
454        let response = self
455            .http_client
456            .get(&jwks_uri)
457            .timeout(Duration::from_secs(10))
458            .send()
459            .await
460            .map_err(|e| OidcError::FetchFromProviderFailed {
461                url: jwks_uri.clone(),
462                error_message: e.to_string(),
463            })?;
464
465        if !response.status().is_success() {
466            return Err(OidcError::FetchFromProviderFailed {
467                url: jwks_uri.clone(),
468                error_message: response
469                    .error_for_status()
470                    .err()
471                    .map(|e| e.to_string())
472                    .unwrap_or_else(|| "Unknown error".to_string()),
473            });
474        }
475
476        let jwks: JwkSet =
477            response
478                .json()
479                .await
480                .map_err(|e| OidcError::FetchFromProviderFailed {
481                    url: jwks_uri.clone(),
482                    error_message: e.to_string(),
483                })?;
484
485        let mut keys = BTreeMap::new();
486
487        for jwk in jwks.keys {
488            match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
489                Ok(key) => {
490                    if let Some(kid) = jwk.common.key_id {
491                        keys.insert(kid, OidcDecodingKey(key));
492                    }
493                }
494                Err(e) => {
495                    warn!("Failed to parse JWK: {}", e);
496                }
497            }
498        }
499
500        Ok(keys)
501    }
502
503    /// Find a decoding key matching the given key ID.
504    /// If the key is not found, fetch the JWKS and cache the keys.
505    async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
506        // Get the cached decoding key.
507        {
508            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
509
510            if let Some(key) = decoding_keys.get(kid) {
511                return Ok(key.clone());
512            }
513        }
514
515        // If not found, fetch the JWKS and cache the keys.
516        let new_decoding_keys = self.fetch_jwks(issuer).await?;
517
518        let decoding_key = new_decoding_keys.get(kid).cloned();
519
520        {
521            let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
522            *decoding_keys = new_decoding_keys;
523        }
524
525        if let Some(key) = decoding_key {
526            return Ok(key);
527        }
528
529        {
530            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
531            debug!(
532                "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
533            );
534            Err(OidcError::NoMatchingKey {
535                key_id: kid.to_string(),
536            })
537        }
538    }
539
540    pub async fn validate_token(
541        &self,
542        token: &str,
543        expected_user: Option<&str>,
544    ) -> Result<ValidatedClaims, OidcError> {
545        // Fetch current OIDC configuration from system variables
546        let system_vars = self.adapter_client.get_system_vars().await;
547        let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
548            return Err(OidcError::MissingIssuer);
549        };
550
551        let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
552
553        let expected_audiences: Vec<String> = {
554            let audiences: Vec<String> =
555                serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
556                    .map_err(|_| OidcError::AudienceParseError)?;
557
558            if audiences.is_empty() {
559                warn!(
560                    "Audience validation skipped. It is discouraged \
561                    to skip audience validation since it allows anyone \
562                    with a JWT issued by the same issuer to authenticate."
563                );
564            }
565            audiences
566        };
567
568        // Decode header to get key ID (kid) and the
569        // decoding algorithm
570        let header = jsonwebtoken::decode_header(token).map_err(|e| {
571            debug!("Failed to decode JWT header: {:?}", e);
572            OidcError::Jwt
573        })?;
574
575        let kid = header.kid.ok_or(OidcError::MissingKid)?;
576        // Find the matching key from our set of cached keys. If not found,
577        // fetch the JWKS from the provider and cache the keys
578        let decoding_key = self.find_key(&kid, &issuer).await?;
579
580        // Set up audience and issuer validation
581        let mut validation = jsonwebtoken::Validation::new(header.alg);
582        validation.set_issuer(&[&issuer]);
583        if !expected_audiences.is_empty() {
584            validation.set_audience(&expected_audiences);
585        } else {
586            validation.validate_aud = false;
587        }
588
589        // Decode and validate the token
590        let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
591            .map_err(|e| match e.kind() {
592                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
593                    if !expected_audiences.is_empty() {
594                        OidcError::InvalidAudience {
595                            expected_audiences
596                        }
597                    } else {
598                        soft_panic_or_log!(
599                            "received an audience validation error when audience validation is disabled"
600                        );
601                        OidcError::Jwt
602                    }
603                }
604                jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
605                    expected_issuer: issuer.clone(),
606                },
607                jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
608                _ => OidcError::Jwt,
609            })?;
610
611        let user = token_data.claims.user(&authentication_claim).ok_or(
612            OidcError::NoMatchingAuthenticationClaim {
613                authentication_claim,
614            },
615        )?;
616
617        // Optionally validate the expected user
618        if let Some(expected) = expected_user {
619            if user != expected {
620                return Err(OidcError::WrongUser);
621            }
622        }
623
624        // Extract groups from the configured claim name for group-to-role sync.
625        let group_claim = OIDC_GROUP_CLAIM.get(system_vars.dyncfgs());
626        let groups = token_data.claims.groups(&group_claim);
627
628        Ok(ValidatedClaims {
629            user: user.to_string(),
630            groups,
631            _private: (),
632        })
633    }
634
635    /// Checks whether the role has the LOGIN attribute. This is needed otherwise
636    /// a user can authenticate with an OIDC token to a role that isn't recognized
637    /// as a user.
638    async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
639        match self.adapter_client.role_can_login(role_name).await {
640            Ok(()) => Ok(()),
641            Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
642                // Role will be auto-provisioned during startup; allow login.
643                Ok(())
644            }
645            Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
646                Err(OidcError::NonLogin)
647            }
648            Err(e) => {
649                warn!(?e, "unexpected error checking OIDC role login");
650                Err(OidcError::LoginCheckError)
651            }
652        }
653    }
654}
655
656impl GenericOidcAuthenticator {
657    pub async fn authenticate(
658        &self,
659        token: &str,
660        expected_user: Option<&str>,
661    ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
662        let validated_claims = self.inner.validate_token(token, expected_user).await?;
663        self.inner.check_role_login(&validated_claims.user).await?;
664        Ok((validated_claims, Authenticated))
665    }
666}
667
668fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
669    let mut openid_config_url =
670        Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
671    {
672        let mut segments = openid_config_url
673            .path_segments_mut()
674            .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
675        // Remove trailing slash if it exists
676        segments.pop_if_empty();
677        segments.push(".well-known");
678        segments.push("openid-configuration");
679    }
680    Ok(openid_config_url)
681}
682#[cfg(test)]
683mod tests {
684    use super::*;
685
686    #[mz_ore::test]
687    fn test_aud_single_string() {
688        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
689        let claims: OidcClaims = serde_json::from_str(json).unwrap();
690        assert_eq!(claims.aud, vec!["my-app"]);
691    }
692
693    #[mz_ore::test]
694    fn test_aud_array() {
695        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
696        let claims: OidcClaims = serde_json::from_str(json).unwrap();
697        assert_eq!(claims.aud, vec!["app1", "app2"]);
698    }
699
700    #[mz_ore::test]
701    fn test_user() {
702        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#;
703        let claims: OidcClaims = serde_json::from_str(json).unwrap();
704        assert_eq!(claims.user("sub"), Some("user-123"));
705        assert_eq!(claims.user("email"), Some("alice@example.com"));
706        assert_eq!(claims.user("missing"), None);
707    }
708
709    #[mz_ore::test]
710    fn test_build_openid_config_url() {
711        let issuer = "https://dev-123456.okta.com/oauth2/default";
712        let url = build_openid_config_url(issuer).unwrap();
713        assert_eq!(
714            url.to_string(),
715            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
716        );
717    }
718
719    #[mz_ore::test]
720    fn test_build_openid_config_url_trailing_slash() {
721        let issuer = "https://dev-123456.okta.com/oauth2/default/";
722        let url = build_openid_config_url(issuer).unwrap();
723        assert_eq!(
724            url.to_string(),
725            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
726        );
727    }
728
729    #[mz_ore::test]
730    fn zeroize_clears_validated_claims() {
731        use mz_ore::secure::Zeroize;
732        let mut claims = ValidatedClaims {
733            user: "alice@example.com".to_string(),
734            groups: Some(vec!["eng".to_string()]),
735            _private: (),
736        };
737        claims.zeroize();
738        assert!(claims.user.is_empty());
739    }
740
741    #[mz_ore::test]
742    fn oidc_claims_implements_zeroize_on_drop() {
743        fn assert_zod<T: ZeroizeOnDrop>() {}
744        assert_zod::<OidcClaims>();
745        assert_zod::<ValidatedClaims>();
746    }
747
748    #[mz_ore::test]
749    fn zeroize_clears_oidc_claims() {
750        use mz_ore::secure::Zeroize;
751        let mut claims = OidcClaims {
752            iss: "https://issuer.example.com".to_string(),
753            exp: 1234567890,
754            iat: Some(1234567800),
755            aud: vec!["app1".to_string(), "app2".to_string()],
756            unknown_claims: BTreeMap::from([(
757                "email".to_string(),
758                serde_json::Value::String("alice@example.com".to_string()),
759            )]),
760        };
761        claims.zeroize();
762        assert!(claims.iss.is_empty());
763        assert_eq!(claims.exp, 0);
764        assert!(claims.iat.is_none());
765        assert!(claims.aud.is_empty());
766        assert!(claims.unknown_claims.is_empty());
767    }
768
769    #[mz_ore::test]
770    fn test_groups_array() {
771        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["analytics","platform_eng"]}"#;
772        let claims: OidcClaims = serde_json::from_str(json).unwrap();
773        assert_eq!(
774            claims.groups("groups"),
775            Some(vec!["analytics".to_string(), "platform_eng".to_string()])
776        );
777    }
778
779    #[mz_ore::test]
780    fn test_groups_single_string() {
781        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":"analytics"}"#;
782        let claims: OidcClaims = serde_json::from_str(json).unwrap();
783        assert_eq!(claims.groups("groups"), Some(vec!["analytics".to_string()]));
784    }
785
786    #[mz_ore::test]
787    fn test_groups_missing() {
788        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app"}"#;
789        let claims: OidcClaims = serde_json::from_str(json).unwrap();
790        assert_eq!(claims.groups("groups"), None);
791    }
792
793    #[mz_ore::test]
794    fn test_groups_empty_array() {
795        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[]}"#;
796        let claims: OidcClaims = serde_json::from_str(json).unwrap();
797        assert_eq!(claims.groups("groups"), Some(vec![]));
798    }
799
800    #[mz_ore::test]
801    fn test_groups_mixed_case() {
802        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["Analytics","PLATFORM_ENG","analytics"]}"#;
803        let claims: OidcClaims = serde_json::from_str(json).unwrap();
804        // Case is preserved; "Analytics" and "analytics" are distinct groups.
805        assert_eq!(
806            claims.groups("groups"),
807            Some(vec![
808                "Analytics".to_string(),
809                "PLATFORM_ENG".to_string(),
810                "analytics".to_string(),
811            ])
812        );
813    }
814
815    #[mz_ore::test]
816    fn test_groups_custom_claim_name() {
817        let json =
818            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","roles":["admin","viewer"]}"#;
819        let claims: OidcClaims = serde_json::from_str(json).unwrap();
820        assert_eq!(
821            claims.groups("roles"),
822            Some(vec!["admin".to_string(), "viewer".to_string()])
823        );
824        assert_eq!(claims.groups("groups"), None);
825    }
826
827    #[mz_ore::test]
828    fn test_groups_non_string_values_in_array() {
829        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["valid",123,true,"also_valid"]}"#;
830        let claims: OidcClaims = serde_json::from_str(json).unwrap();
831        assert_eq!(
832            claims.groups("groups"),
833            Some(vec!["also_valid".to_string(), "valid".to_string()])
834        );
835    }
836
837    #[mz_ore::test]
838    fn test_groups_non_array_non_string() {
839        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":42}"#;
840        let claims: OidcClaims = serde_json::from_str(json).unwrap();
841        assert_eq!(claims.groups("groups"), None);
842    }
843
844    #[mz_ore::test]
845    fn test_groups_empty_string() {
846        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":""}"#;
847        let claims: OidcClaims = serde_json::from_str(json).unwrap();
848        assert_eq!(claims.groups("groups"), Some(vec![]));
849    }
850
851    #[mz_ore::test]
852    fn test_groups_null_claim() {
853        // Explicit null → treated as absent (not a valid group representation)
854        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":null}"#;
855        let claims: OidcClaims = serde_json::from_str(json).unwrap();
856        assert_eq!(claims.groups("groups"), None);
857    }
858
859    #[mz_ore::test]
860    fn test_groups_boolean_claim() {
861        // Boolean value → not array or string, treated as absent
862        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":true}"#;
863        let claims: OidcClaims = serde_json::from_str(json).unwrap();
864        assert_eq!(claims.groups("groups"), None);
865    }
866
867    #[mz_ore::test]
868    fn test_groups_object_claim() {
869        // JSON object → not array or string, treated as absent
870        let json =
871            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":{"team":"eng"}}"#;
872        let claims: OidcClaims = serde_json::from_str(json).unwrap();
873        assert_eq!(claims.groups("groups"), None);
874    }
875
876    #[mz_ore::test]
877    fn test_groups_array_all_non_strings() {
878        // Array with zero valid string elements → Some(vec![]), not None
879        // (the claim *is* present, it just has no usable group names)
880        let json =
881            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[1,2,true,null]}"#;
882        let claims: OidcClaims = serde_json::from_str(json).unwrap();
883        assert_eq!(claims.groups("groups"), Some(vec![]));
884    }
885
886    #[mz_ore::test]
887    fn test_groups_array_with_nested_arrays() {
888        // Nested arrays are not strings, so they're filtered out
889        let json =
890            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[["nested"],"valid"]}"#;
891        let claims: OidcClaims = serde_json::from_str(json).unwrap();
892        assert_eq!(claims.groups("groups"), Some(vec!["valid".to_string()]));
893    }
894
895    #[mz_ore::test]
896    fn test_groups_array_with_empty_strings() {
897        // Empty strings are not valid role names and are filtered out,
898        // consistent with the single-string case where "" → Some(vec![]).
899        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["","eng",""]}"#;
900        let claims: OidcClaims = serde_json::from_str(json).unwrap();
901        assert_eq!(claims.groups("groups"), Some(vec!["eng".to_string()]));
902    }
903
904    #[mz_ore::test]
905    fn test_groups_whitespace_only_single_string() {
906        // Whitespace-only string is preserved as-is (exact matching, no trim).
907        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":"  "}"#;
908        let claims: OidcClaims = serde_json::from_str(json).unwrap();
909        assert_eq!(claims.groups("groups"), Some(vec!["  ".to_string()]));
910    }
911
912    #[mz_ore::test]
913    fn test_groups_whitespace_names() {
914        // Leading/trailing whitespace is preserved (exact matching, no trim).
915        let json =
916            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["  spaces  ","eng"]}"#;
917        let claims: OidcClaims = serde_json::from_str(json).unwrap();
918        assert_eq!(
919            claims.groups("groups"),
920            Some(vec!["  spaces  ".to_string(), "eng".to_string()])
921        );
922    }
923
924    #[mz_ore::test]
925    fn test_groups_unicode_names() {
926        // Unicode group names are preserved as-is (no case folding).
927        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["Développeurs","INGÉNIEURS"]}"#;
928        let claims: OidcClaims = serde_json::from_str(json).unwrap();
929        assert_eq!(
930            claims.groups("groups"),
931            Some(vec!["Développeurs".to_string(), "INGÉNIEURS".to_string()])
932        );
933    }
934
935    #[mz_ore::test]
936    fn test_groups_special_characters() {
937        // Group names with special characters (hyphens, underscores, dots)
938        // are common in enterprise IdPs like Azure AD / Okta
939        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["team-platform.eng","org_data-science","role/admin"]}"#;
940        let claims: OidcClaims = serde_json::from_str(json).unwrap();
941        assert_eq!(
942            claims.groups("groups"),
943            Some(vec![
944                "org_data-science".to_string(),
945                "role/admin".to_string(),
946                "team-platform.eng".to_string(),
947            ])
948        );
949    }
950
951    #[mz_ore::test]
952    fn test_groups_no_case_folding() {
953        // Case is preserved; "Eng", "eng", "ENG", "eNg" are four distinct groups.
954        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["Eng","eng","ENG","eNg"]}"#;
955        let claims: OidcClaims = serde_json::from_str(json).unwrap();
956        assert_eq!(
957            claims.groups("groups"),
958            Some(vec![
959                "ENG".to_string(),
960                "Eng".to_string(),
961                "eNg".to_string(),
962                "eng".to_string(),
963            ])
964        );
965    }
966
967    #[mz_ore::test]
968    fn test_groups_large_array() {
969        // Verify we handle a reasonably large group list without issues
970        let groups: Vec<String> = (0..100).map(|i| format!("\"group_{}\"", i)).collect();
971        let json = format!(
972            r#"{{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":[{}]}}"#,
973            groups.join(",")
974        );
975        let claims: OidcClaims = serde_json::from_str(&json).unwrap();
976        let result = claims.groups("groups").unwrap();
977        assert_eq!(result.len(), 100);
978        // BTreeSet ensures sorted order
979        assert_eq!(result[0], "group_0");
980        assert_eq!(result[99], "group_99");
981    }
982
983    #[mz_ore::test]
984    fn test_groups_float_claim() {
985        // Float value → not array or string, treated as absent
986        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":3.14}"#;
987        let claims: OidcClaims = serde_json::from_str(json).unwrap();
988        assert_eq!(claims.groups("groups"), None);
989    }
990
991    #[mz_ore::test]
992    fn test_groups_array_with_null_elements() {
993        // Null elements in array are not strings, filtered out
994        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["eng",null,"ops",null]}"#;
995        let claims: OidcClaims = serde_json::from_str(json).unwrap();
996        assert_eq!(
997            claims.groups("groups"),
998            Some(vec!["eng".to_string(), "ops".to_string()])
999        );
1000    }
1001
1002    #[mz_ore::test]
1003    fn test_groups_array_with_object_elements() {
1004        // Object elements in array are not strings, filtered out
1005        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["eng",{"name":"ops"},"analytics"]}"#;
1006        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1007        assert_eq!(
1008            claims.groups("groups"),
1009            Some(vec!["analytics".to_string(), "eng".to_string()])
1010        );
1011    }
1012
1013    #[mz_ore::test]
1014    fn test_groups_sorted_output() {
1015        // Verify output is sorted alphabetically regardless of input order
1016        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["zebra","alpha","mango","beta"]}"#;
1017        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1018        assert_eq!(
1019            claims.groups("groups"),
1020            Some(vec![
1021                "alpha".to_string(),
1022                "beta".to_string(),
1023                "mango".to_string(),
1024                "zebra".to_string(),
1025            ])
1026        );
1027    }
1028
1029    #[mz_ore::test]
1030    fn test_groups_nested_path_array() {
1031        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","customClaims":{"groups":["analytics","platform_eng"]}}"#;
1032        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1033        assert_eq!(
1034            claims.groups("customClaims.groups"),
1035            Some(vec!["analytics".to_string(), "platform_eng".to_string()])
1036        );
1037    }
1038
1039    #[mz_ore::test]
1040    fn test_groups_nested_path_single_string() {
1041        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","customClaims":{"groups":"analytics"}}"#;
1042        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1043        assert_eq!(
1044            claims.groups("customClaims.groups"),
1045            Some(vec!["analytics".to_string()])
1046        );
1047    }
1048
1049    #[mz_ore::test]
1050    fn test_groups_nested_path_deeply_nested() {
1051        let json =
1052            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","a":{"b":{"c":["eng"]}}}"#;
1053        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1054        assert_eq!(claims.groups("a.b.c"), Some(vec!["eng".to_string()]));
1055    }
1056
1057    #[mz_ore::test]
1058    fn test_groups_nested_path_missing_intermediate() {
1059        // The top-level key exists but the nested key doesn't.
1060        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","customClaims":{"other":["eng"]}}"#;
1061        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1062        assert_eq!(claims.groups("customClaims.groups"), None);
1063    }
1064
1065    #[mz_ore::test]
1066    fn test_groups_nested_path_missing_root() {
1067        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app"}"#;
1068        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1069        assert_eq!(claims.groups("customClaims.groups"), None);
1070    }
1071
1072    #[mz_ore::test]
1073    fn test_groups_nested_path_first_segment_not_object() {
1074        // `customClaims` is an array, can't be descended into.
1075        let json =
1076            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","customClaims":["nope"]}"#;
1077        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1078        assert_eq!(claims.groups("customClaims.groups"), None);
1079    }
1080
1081    #[mz_ore::test]
1082    fn test_groups_nested_path_terminal_not_array_or_string() {
1083        // The terminal value is a number, treated as absent (matches the
1084        // flat-path behavior for unexpected types).
1085        let json =
1086            r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","customClaims":{"groups":42}}"#;
1087        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1088        assert_eq!(claims.groups("customClaims.groups"), None);
1089    }
1090
1091    #[mz_ore::test]
1092    fn test_groups_path_leading_dot() {
1093        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["eng"]}"#;
1094        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1095        assert_eq!(claims.groups(".groups"), None);
1096    }
1097
1098    #[mz_ore::test]
1099    fn test_groups_path_trailing_dot() {
1100        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","customClaims":{"groups":["eng"]}}"#;
1101        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1102        assert_eq!(claims.groups("customClaims.groups."), None);
1103    }
1104
1105    #[mz_ore::test]
1106    fn test_groups_path_double_dot() {
1107        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","customClaims":{"groups":["eng"]}}"#;
1108        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1109        assert_eq!(claims.groups("customClaims..groups"), None);
1110    }
1111
1112    #[mz_ore::test]
1113    fn test_groups_path_empty() {
1114        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"app","groups":["eng"]}"#;
1115        let claims: OidcClaims = serde_json::from_str(json).unwrap();
1116        assert_eq!(claims.groups(""), None);
1117    }
1118}