Skip to main content

mz_authenticator/
oidc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC Authentication for pgwire connections.
11//!
12//! This module provides JWT-based authentication using OpenID Connect (OIDC).
13//! JWTs are validated locally using JWKS fetched from the configured provider.
14
15use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{
22    OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_GROUP_CLAIM, OIDC_ISSUER,
23};
24use mz_auth::Authenticated;
25use mz_ore::secure::{Zeroize, ZeroizeOnDrop};
26use mz_ore::soft_panic_or_log;
27use mz_pgwire_common::{ErrorResponse, Severity};
28use reqwest::Client as HttpClient;
29use serde::{Deserialize, Deserializer, Serialize};
30use tokio_postgres::error::SqlState;
31
32use tracing::{debug, warn};
33use url::Url;
/// Errors that can occur during OIDC authentication.
///
/// Every variant maps to SQLSTATE `INVALID_AUTHORIZATION_SPECIFICATION`
/// (see `OidcError::code`).
#[derive(Debug)]
pub enum OidcError {
    /// No OIDC issuer is configured (the issuer system variable is unset).
    MissingIssuer,
    /// Failed to parse OIDC configuration URL. Holds the configured issuer string.
    InvalidIssuerUrl(String),
    /// The configured audience system variable is not a JSON list of strings.
    AudienceParseError,
    /// Failed to fetch from the identity provider.
    FetchFromProviderFailed {
        /// URL whose request failed.
        url: String,
        /// Description of the transport, HTTP-status, or body-parsing failure.
        error_message: String,
    },
    /// The key ID is missing in the token header.
    MissingKid,
    /// No matching key found in JWKS.
    NoMatchingKey {
        /// Key ID that was found in the JWT header.
        key_id: String,
    },
    /// Configured authentication claim is not found in the JWT.
    NoMatchingAuthenticationClaim {
        /// Name of the claim that was expected.
        authentication_claim: String,
    },
    /// JWT decoding or validation error not covered by a more specific variant.
    Jwt,
    /// The token's authentication claim does not match the expected user.
    WrongUser,
    /// The token's audience did not match any configured audience.
    InvalidAudience {
        /// Audiences that were configured.
        expected_audiences: Vec<String>,
    },
    /// The token's issuer did not match the configured issuer.
    InvalidIssuer {
        /// Issuer that was configured.
        expected_issuer: String,
    },
    /// The token has expired.
    ExpiredSignature,
    /// The role exists but does not have the LOGIN attribute.
    NonLogin,
    /// An unexpected error occurred while checking whether the role may log in.
    LoginCheckError,
}
71
72impl std::fmt::Display for OidcError {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        match self {
75            OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
76            OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
77            OidcError::AudienceParseError => {
78                write!(f, "failed to parse OIDC_AUDIENCE system variable")
79            }
80            OidcError::FetchFromProviderFailed { .. } => {
81                write!(f, "failed to fetch OIDC provider configuration")
82            }
83            OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
84            OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
85            OidcError::NoMatchingAuthenticationClaim { .. } => {
86                write!(f, "no matching authentication claim found in the JWT")
87            }
88            OidcError::Jwt => write!(f, "failed to validate JWT"),
89            OidcError::WrongUser => write!(f, "wrong user"),
90            OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
91            OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
92            OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
93            OidcError::NonLogin => write!(f, "role is not allowed to login"),
94            OidcError::LoginCheckError => write!(f, "unexpected error checking if role can login"),
95        }
96    }
97}
98
99impl std::error::Error for OidcError {}
100
101impl OidcError {
102    pub fn code(&self) -> SqlState {
103        SqlState::INVALID_AUTHORIZATION_SPECIFICATION
104    }
105
106    pub fn detail(&self) -> Option<String> {
107        match self {
108            OidcError::InvalidIssuerUrl(issuer) => {
109                Some(format!("Could not parse \"{issuer}\" as a URL."))
110            }
111            OidcError::FetchFromProviderFailed { url, error_message } => {
112                Some(format!("Fetching \"{url}\" failed. {error_message}"))
113            }
114            OidcError::NoMatchingKey { key_id } => {
115                Some(format!("JWT key ID \"{key_id}\" was not found."))
116            }
117            OidcError::InvalidAudience { expected_audiences } => Some(format!(
118                "Expected one of audiences {:?} in the JWT.",
119                expected_audiences,
120            )),
121            OidcError::InvalidIssuer { expected_issuer } => {
122                Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
123            }
124            OidcError::NoMatchingAuthenticationClaim {
125                authentication_claim,
126            } => Some(format!(
127                "Expected authentication claim \"{authentication_claim}\" in the JWT.",
128            )),
129            OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
130            _ => None,
131        }
132    }
133
134    pub fn hint(&self) -> Option<String> {
135        match self {
136            OidcError::MissingIssuer => {
137                Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
138            }
139            _ => None,
140        }
141    }
142
143    pub fn into_response(self) -> ErrorResponse {
144        ErrorResponse {
145            severity: Severity::Fatal,
146            code: self.code(),
147            message: self.to_string(),
148            detail: self.detail(),
149            hint: self.hint(),
150            position: None,
151        }
152    }
153}
154
155fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
156where
157    D: Deserializer<'de>,
158{
159    #[derive(Deserialize)]
160    #[serde(untagged)]
161    enum StringOrVec {
162        String(String),
163        Vec(Vec<String>),
164    }
165
166    match StringOrVec::deserialize(deserializer)? {
167        StringOrVec::String(s) => Ok(vec![s]),
168        StringOrVec::Vec(v) => Ok(v),
169    }
170}
/// Claims extracted from a validated JWT.
///
/// `iss`, `exp`, `iat`, and `aud` are parsed into typed fields. Every other
/// claim (e.g. `sub`, `email`) is kept verbatim in `unknown_claims`, so the
/// username and group claims can be selected at runtime by name.
///
/// NOTE(review): the derived `Debug` prints every claim value, which may
/// include personal data such as email addresses; avoid logging this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Issuer.
    pub iss: String,
    /// Expiration time (Unix timestamp).
    pub exp: i64,
    /// Issued at time (Unix timestamp). A missing claim deserializes to `None`.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Audience claim (can be single string or array in JWT). A missing claim
    /// deserializes to an empty list.
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
    /// Additional claims from the JWT, captured for flexible username extraction.
    #[serde(flatten)]
    pub unknown_claims: BTreeMap<String, serde_json::Value>,
}
188
189impl Zeroize for OidcClaims {
190    fn zeroize(&mut self) {
191        self.iss.zeroize();
192        self.exp.zeroize();
193        self.iat.zeroize();
194        for s in &mut self.aud {
195            s.zeroize();
196        }
197        self.aud.clear();
198        // serde_json::Value doesn't implement Zeroize; drain entries and
199        // zeroize keys/values before the backing allocations are freed.
200        while let Some((mut k, mut v)) = self.unknown_claims.pop_first() {
201            k.zeroize();
202            zeroize_json_value(&mut v);
203        }
204    }
205}
206
207impl Drop for OidcClaims {
208    fn drop(&mut self) {
209        self.zeroize();
210    }
211}
212
213/// `OidcClaims` implements both `Zeroize` and `Drop` (which calls `zeroize()`),
214/// satisfying the `ZeroizeOnDrop` contract.
215impl ZeroizeOnDrop for OidcClaims {}
216
217fn zeroize_json_value(v: &mut serde_json::Value) {
218    use serde_json::Value;
219    match v {
220        Value::String(s) => s.zeroize(),
221        Value::Array(a) => {
222            for item in a.iter_mut() {
223                zeroize_json_value(item);
224            }
225            a.clear();
226        }
227        Value::Object(map) => {
228            let taken = std::mem::take(map);
229            for (mut k, mut nested) in taken {
230                k.zeroize();
231                zeroize_json_value(&mut nested);
232            }
233        }
234        Value::Number(_) => {
235            *v = Value::Number(serde_json::Number::from(0u8));
236        }
237        Value::Bool(b) => *b = false,
238        Value::Null => {}
239    }
240}
241
242impl OidcClaims {
243    /// Extract the username from the OIDC claims.
244    fn user(&self, authentication_claim: &str) -> Option<&str> {
245        self.unknown_claims
246            .get(authentication_claim)
247            .and_then(|value| value.as_str())
248    }
249
250    /// Extracts group names from the configured JWT claim for group-to-role
251    /// sync. See [`mz_auth::group_claims::extract_groups`] for semantics.
252    pub fn groups(&self, claim_path: &str) -> Option<Vec<String>> {
253        mz_auth::group_claims::extract_groups(&self.unknown_claims, claim_path)
254    }
255}
256
/// Identity extracted from a JWT whose signature and claims passed validation.
///
/// Derives `ZeroizeOnDrop`, so `user` and `groups` are wiped when dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ValidatedClaims {
    /// Value of the configured authentication claim; used as the role name.
    pub user: String,
    /// Groups extracted from the JWT group claim. None if claim absent.
    pub groups: Option<Vec<String>>,
    // Private field: prevents construction outside this module. The intended
    // constructor is `GenericOidcAuthenticatorInner::validate_token`.
    _private: (),
}
265
266/// Wrapper around `jsonwebtoken::DecodingKey` with a redacted `Debug` impl.
267#[derive(Clone)]
268struct OidcDecodingKey(jsonwebtoken::DecodingKey);
269
270impl std::fmt::Debug for OidcDecodingKey {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        f.debug_struct("OidcDecodingKey")
273            .field("key", &"<redacted>")
274            .finish()
275    }
276}
277
278/// OIDC Authenticator that validates JWTs using JWKS.
279///
280/// This implementation pre-fetches JWKS at construction time for synchronous
281/// token validation.
282#[derive(Clone, Debug)]
283pub struct GenericOidcAuthenticator {
284    inner: Arc<GenericOidcAuthenticatorInner>,
285}
286
/// OpenID Connect Discovery document.
/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html>
///
/// Only the fields this module needs are declared. The struct has no
/// `deny_unknown_fields`, so the rest of the document is ignored.
///
/// NOTE(review): the document's `issuer` field is not read, so it is never
/// compared with the configured issuer as OIDC Discovery §4.3 requires.
/// Consider adding and checking it.
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the JWKS endpoint.
    jwks_uri: String,
}
294
295#[derive(Debug)]
296pub struct GenericOidcAuthenticatorInner {
297    adapter_client: AdapterClient,
298    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
299    http_client: HttpClient,
300}
301
302impl GenericOidcAuthenticator {
303    /// Create a new [`GenericOidcAuthenticator`] with an [`AdapterClient`].
304    ///
305    /// The OIDC issuer and audience are fetched from system variables on each
306    /// authentication attempt.
307    pub fn new(adapter_client: AdapterClient) -> Self {
308        let http_client = HttpClient::new();
309
310        Self {
311            inner: Arc::new(GenericOidcAuthenticatorInner {
312                adapter_client,
313                decoding_keys: Mutex::new(BTreeMap::new()),
314                http_client,
315            }),
316        }
317    }
318}
319
320impl GenericOidcAuthenticatorInner {
321    async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
322        let openid_config_url = build_openid_config_url(issuer)?;
323
324        let openid_config_url_str = openid_config_url.to_string();
325
326        // Fetch OpenID configuration to get the JWKS URI
327        let response = self
328            .http_client
329            .get(openid_config_url)
330            .timeout(Duration::from_secs(10))
331            .send()
332            .await
333            .map_err(|e| OidcError::FetchFromProviderFailed {
334                url: openid_config_url_str.clone(),
335                error_message: e.to_string(),
336            })?;
337
338        if !response.status().is_success() {
339            return Err(OidcError::FetchFromProviderFailed {
340                url: openid_config_url_str.clone(),
341                error_message: response
342                    .error_for_status()
343                    .err()
344                    .map(|e| e.to_string())
345                    .unwrap_or_else(|| "Unknown error".to_string()),
346            });
347        }
348
349        let openid_config: OpenIdConfiguration =
350            response
351                .json()
352                .await
353                .map_err(|e| OidcError::FetchFromProviderFailed {
354                    url: openid_config_url_str,
355                    error_message: e.to_string(),
356                })?;
357
358        Ok(openid_config.jwks_uri)
359    }
360
361    /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys.
362    async fn fetch_jwks(
363        &self,
364        issuer: &str,
365    ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
366        let jwks_uri = self.fetch_jwks_uri(issuer).await?;
367        let response = self
368            .http_client
369            .get(&jwks_uri)
370            .timeout(Duration::from_secs(10))
371            .send()
372            .await
373            .map_err(|e| OidcError::FetchFromProviderFailed {
374                url: jwks_uri.clone(),
375                error_message: e.to_string(),
376            })?;
377
378        if !response.status().is_success() {
379            return Err(OidcError::FetchFromProviderFailed {
380                url: jwks_uri.clone(),
381                error_message: response
382                    .error_for_status()
383                    .err()
384                    .map(|e| e.to_string())
385                    .unwrap_or_else(|| "Unknown error".to_string()),
386            });
387        }
388
389        let jwks: JwkSet =
390            response
391                .json()
392                .await
393                .map_err(|e| OidcError::FetchFromProviderFailed {
394                    url: jwks_uri.clone(),
395                    error_message: e.to_string(),
396                })?;
397
398        let mut keys = BTreeMap::new();
399
400        for jwk in jwks.keys {
401            match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
402                Ok(key) => {
403                    if let Some(kid) = jwk.common.key_id {
404                        keys.insert(kid, OidcDecodingKey(key));
405                    }
406                }
407                Err(e) => {
408                    warn!("Failed to parse JWK: {}", e);
409                }
410            }
411        }
412
413        Ok(keys)
414    }
415
416    /// Find a decoding key matching the given key ID.
417    /// If the key is not found, fetch the JWKS and cache the keys.
418    async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
419        // Get the cached decoding key.
420        {
421            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
422
423            if let Some(key) = decoding_keys.get(kid) {
424                return Ok(key.clone());
425            }
426        }
427
428        // If not found, fetch the JWKS and cache the keys.
429        let new_decoding_keys = self.fetch_jwks(issuer).await?;
430
431        let decoding_key = new_decoding_keys.get(kid).cloned();
432
433        {
434            let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
435            *decoding_keys = new_decoding_keys;
436        }
437
438        if let Some(key) = decoding_key {
439            return Ok(key);
440        }
441
442        {
443            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
444            debug!(
445                "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
446            );
447            Err(OidcError::NoMatchingKey {
448                key_id: kid.to_string(),
449            })
450        }
451    }
452
453    pub async fn validate_token(
454        &self,
455        token: &str,
456        expected_user: Option<&str>,
457    ) -> Result<ValidatedClaims, OidcError> {
458        // Fetch current OIDC configuration from system variables
459        let system_vars = self.adapter_client.get_system_vars().await;
460        let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
461            return Err(OidcError::MissingIssuer);
462        };
463
464        let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
465
466        let expected_audiences: Vec<String> = {
467            let audiences: Vec<String> =
468                serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
469                    .map_err(|_| OidcError::AudienceParseError)?;
470
471            if audiences.is_empty() {
472                warn!(
473                    "Audience validation skipped. It is discouraged \
474                    to skip audience validation since it allows anyone \
475                    with a JWT issued by the same issuer to authenticate."
476                );
477            }
478            audiences
479        };
480
481        // Decode header to get key ID (kid) and the
482        // decoding algorithm
483        let header = jsonwebtoken::decode_header(token).map_err(|e| {
484            debug!("Failed to decode JWT header: {:?}", e);
485            OidcError::Jwt
486        })?;
487
488        let kid = header.kid.ok_or(OidcError::MissingKid)?;
489        // Find the matching key from our set of cached keys. If not found,
490        // fetch the JWKS from the provider and cache the keys
491        let decoding_key = self.find_key(&kid, &issuer).await?;
492
493        // Set up audience and issuer validation
494        let mut validation = jsonwebtoken::Validation::new(header.alg);
495        validation.set_issuer(&[&issuer]);
496        if !expected_audiences.is_empty() {
497            validation.set_audience(&expected_audiences);
498        } else {
499            validation.validate_aud = false;
500        }
501
502        // Decode and validate the token
503        let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
504            .map_err(|e| match e.kind() {
505                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
506                    if !expected_audiences.is_empty() {
507                        OidcError::InvalidAudience {
508                            expected_audiences
509                        }
510                    } else {
511                        soft_panic_or_log!(
512                            "received an audience validation error when audience validation is disabled"
513                        );
514                        OidcError::Jwt
515                    }
516                }
517                jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
518                    expected_issuer: issuer.clone(),
519                },
520                jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
521                _ => OidcError::Jwt,
522            })?;
523
524        let user = token_data.claims.user(&authentication_claim).ok_or(
525            OidcError::NoMatchingAuthenticationClaim {
526                authentication_claim,
527            },
528        )?;
529
530        // Optionally validate the expected user
531        if let Some(expected) = expected_user {
532            if user != expected {
533                return Err(OidcError::WrongUser);
534            }
535        }
536
537        // Extract groups from the configured claim name for group-to-role sync.
538        let group_claim = OIDC_GROUP_CLAIM.get(system_vars.dyncfgs());
539        let groups = token_data.claims.groups(&group_claim);
540
541        Ok(ValidatedClaims {
542            user: user.to_string(),
543            groups,
544            _private: (),
545        })
546    }
547
548    /// Checks whether the role has the LOGIN attribute. This is needed otherwise
549    /// a user can authenticate with an OIDC token to a role that isn't recognized
550    /// as a user.
551    async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
552        match self.adapter_client.role_can_login(role_name).await {
553            Ok(()) => Ok(()),
554            Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
555                // Role will be auto-provisioned during startup; allow login.
556                Ok(())
557            }
558            Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
559                Err(OidcError::NonLogin)
560            }
561            Err(e) => {
562                warn!(?e, "unexpected error checking OIDC role login");
563                Err(OidcError::LoginCheckError)
564            }
565        }
566    }
567}
568
569impl GenericOidcAuthenticator {
570    pub async fn authenticate(
571        &self,
572        token: &str,
573        expected_user: Option<&str>,
574    ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
575        let validated_claims = self.inner.validate_token(token, expected_user).await?;
576        self.inner.check_role_login(&validated_claims.user).await?;
577        Ok((validated_claims, Authenticated))
578    }
579}
580
581fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
582    let mut openid_config_url =
583        Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
584    {
585        let mut segments = openid_config_url
586            .path_segments_mut()
587            .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
588        // Remove trailing slash if it exists
589        segments.pop_if_empty();
590        segments.push(".well-known");
591        segments.push("openid-configuration");
592    }
593    Ok(openid_config_url)
594}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a JSON claims object, failing the test on malformed input.
    fn parse_claims(json: &str) -> OidcClaims {
        serde_json::from_str(json).unwrap()
    }

    /// Builds the discovery URL for `issuer` and renders it as a string.
    fn discovery_url(issuer: &str) -> String {
        build_openid_config_url(issuer).unwrap().to_string()
    }

    #[mz_ore::test]
    fn test_aud_single_string() {
        let claims = parse_claims(r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#);
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let claims =
            parse_claims(r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#);
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    #[mz_ore::test]
    fn test_user() {
        let claims = parse_claims(
            r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#,
        );
        let cases = [
            ("sub", Some("user-123")),
            ("email", Some("alice@example.com")),
            ("missing", None),
        ];
        for (claim, expected) in cases {
            assert_eq!(claims.user(claim), expected, "claim {claim}");
        }
    }

    #[mz_ore::test]
    fn test_build_openid_config_url() {
        assert_eq!(
            discovery_url("https://dev-123456.okta.com/oauth2/default"),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        assert_eq!(
            discovery_url("https://dev-123456.okta.com/oauth2/default/"),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn zeroize_clears_validated_claims() {
        use mz_ore::secure::Zeroize;
        let mut claims = ValidatedClaims {
            user: "alice@example.com".to_string(),
            groups: Some(vec!["eng".to_string()]),
            _private: (),
        };
        claims.zeroize();
        assert!(claims.user.is_empty());
    }

    #[mz_ore::test]
    fn oidc_claims_implements_zeroize_on_drop() {
        fn assert_zod<T: ZeroizeOnDrop>() {}
        assert_zod::<OidcClaims>();
        assert_zod::<ValidatedClaims>();
    }

    #[mz_ore::test]
    fn zeroize_clears_oidc_claims() {
        use mz_ore::secure::Zeroize;
        let email = serde_json::Value::String("alice@example.com".to_string());
        let mut claims = OidcClaims {
            iss: "https://issuer.example.com".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            aud: vec!["app1".to_string(), "app2".to_string()],
            unknown_claims: BTreeMap::from([("email".to_string(), email)]),
        };

        claims.zeroize();

        assert!(claims.iss.is_empty());
        assert_eq!(claims.exp, 0);
        assert!(claims.iat.is_none());
        assert!(claims.aud.is_empty());
        assert!(claims.unknown_claims.is_empty());
    }
}