Skip to main content

mz_authenticator/
oidc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC Authentication for pgwire connections.
11//!
12//! This module provides JWT-based authentication using OpenID Connect (OIDC).
13//! JWTs are validated locally using JWKS fetched from the configured provider.
14
15use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use mz_ore::secure::{Zeroize, ZeroizeOnDrop};
24use mz_ore::soft_panic_or_log;
25use mz_pgwire_common::{ErrorResponse, Severity};
26use reqwest::Client as HttpClient;
27use serde::{Deserialize, Deserializer, Serialize};
28use tokio_postgres::error::SqlState;
29
30use tracing::{debug, warn};
31use url::Url;
/// Errors that can occur during OIDC authentication.
///
/// Every variant is reported to clients with SQLSTATE
/// `INVALID_AUTHORIZATION_SPECIFICATION` (see [`OidcError::code`]).
#[derive(Debug)]
pub enum OidcError {
    /// The OIDC issuer system variable is not set.
    MissingIssuer,
    /// Failed to parse OIDC configuration URL. Carries the configured issuer.
    InvalidIssuerUrl(String),
    /// The `OIDC_AUDIENCE` system variable could not be parsed as a list of
    /// strings.
    AudienceParseError,
    /// Failed to fetch from the identity provider.
    FetchFromProviderFailed {
        /// URL that was requested.
        url: String,
        /// Description of the transport, HTTP status, or decoding failure.
        error_message: String,
    },
    /// The key ID is missing in the token header.
    MissingKid,
    /// No matching key found in JWKS.
    NoMatchingKey {
        /// Key ID that was found in the JWT header.
        key_id: String,
    },
    /// Configured authentication claim is not found in the JWT.
    NoMatchingAuthenticationClaim {
        /// Name of the claim that was expected to hold the username.
        authentication_claim: String,
    },
    /// JWT validation error not covered by a more specific variant
    /// (malformed header, bad signature, unsupported algorithm, ...).
    Jwt,
    /// The username in the token differs from the expected user supplied by
    /// the caller.
    WrongUser,
    /// The token's audience matches none of the configured audiences.
    InvalidAudience {
        /// Audiences the token was required to contain one of.
        expected_audiences: Vec<String>,
    },
    /// The token's issuer does not match the configured issuer.
    InvalidIssuer {
        /// Issuer the token was required to have.
        expected_issuer: String,
    },
    /// The token has expired.
    ExpiredSignature,
    /// The role exists but does not have the LOGIN attribute.
    NonLogin,
    /// An unexpected error occurred while checking whether the role can log in.
    LoginCheckError,
}
69
70impl std::fmt::Display for OidcError {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        match self {
73            OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
74            OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
75            OidcError::AudienceParseError => {
76                write!(f, "failed to parse OIDC_AUDIENCE system variable")
77            }
78            OidcError::FetchFromProviderFailed { .. } => {
79                write!(f, "failed to fetch OIDC provider configuration")
80            }
81            OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
82            OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
83            OidcError::NoMatchingAuthenticationClaim { .. } => {
84                write!(f, "no matching authentication claim found in the JWT")
85            }
86            OidcError::Jwt => write!(f, "failed to validate JWT"),
87            OidcError::WrongUser => write!(f, "wrong user"),
88            OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
89            OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
90            OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
91            OidcError::NonLogin => write!(f, "role is not allowed to login"),
92            OidcError::LoginCheckError => write!(f, "unexpected error checking if role can login"),
93        }
94    }
95}
96
97impl std::error::Error for OidcError {}
98
99impl OidcError {
100    pub fn code(&self) -> SqlState {
101        SqlState::INVALID_AUTHORIZATION_SPECIFICATION
102    }
103
104    pub fn detail(&self) -> Option<String> {
105        match self {
106            OidcError::InvalidIssuerUrl(issuer) => {
107                Some(format!("Could not parse \"{issuer}\" as a URL."))
108            }
109            OidcError::FetchFromProviderFailed { url, error_message } => {
110                Some(format!("Fetching \"{url}\" failed. {error_message}"))
111            }
112            OidcError::NoMatchingKey { key_id } => {
113                Some(format!("JWT key ID \"{key_id}\" was not found."))
114            }
115            OidcError::InvalidAudience { expected_audiences } => Some(format!(
116                "Expected one of audiences {:?} in the JWT.",
117                expected_audiences,
118            )),
119            OidcError::InvalidIssuer { expected_issuer } => {
120                Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
121            }
122            OidcError::NoMatchingAuthenticationClaim {
123                authentication_claim,
124            } => Some(format!(
125                "Expected authentication claim \"{authentication_claim}\" in the JWT.",
126            )),
127            OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
128            _ => None,
129        }
130    }
131
132    pub fn hint(&self) -> Option<String> {
133        match self {
134            OidcError::MissingIssuer => {
135                Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
136            }
137            _ => None,
138        }
139    }
140
141    pub fn into_response(self) -> ErrorResponse {
142        ErrorResponse {
143            severity: Severity::Fatal,
144            code: self.code(),
145            message: self.to_string(),
146            detail: self.detail(),
147            hint: self.hint(),
148            position: None,
149        }
150    }
151}
152
153fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
154where
155    D: Deserializer<'de>,
156{
157    #[derive(Deserialize)]
158    #[serde(untagged)]
159    enum StringOrVec {
160        String(String),
161        Vec(Vec<String>),
162    }
163
164    match StringOrVec::deserialize(deserializer)? {
165        StringOrVec::String(s) => Ok(vec![s]),
166        StringOrVec::Vec(v) => Ok(v),
167    }
168}
/// Claims extracted from a validated JWT.
///
/// Only `iss`, `exp`, `iat`, and `aud` are modeled as fields. Every other claim,
/// including the one configured as the authentication claim, is collected into
/// `unknown_claims`. The contents are zeroized when the value is dropped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Issuer.
    pub iss: String,
    /// Expiration time (Unix timestamp).
    pub exp: i64,
    /// Issued at time (Unix timestamp); `None` when the claim is absent.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Audience claim (can be single string or array in JWT); empty when the
    /// claim is absent.
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
    /// Additional claims from the JWT, captured for flexible username extraction.
    #[serde(flatten)]
    pub unknown_claims: BTreeMap<String, serde_json::Value>,
}
186
187impl Zeroize for OidcClaims {
188    fn zeroize(&mut self) {
189        self.iss.zeroize();
190        self.exp.zeroize();
191        self.iat.zeroize();
192        for s in &mut self.aud {
193            s.zeroize();
194        }
195        self.aud.clear();
196        // serde_json::Value doesn't implement Zeroize; drain entries and
197        // zeroize keys/values before the backing allocations are freed.
198        while let Some((mut k, mut v)) = self.unknown_claims.pop_first() {
199            k.zeroize();
200            zeroize_json_value(&mut v);
201        }
202    }
203}
204
205impl Drop for OidcClaims {
206    fn drop(&mut self) {
207        self.zeroize();
208    }
209}
210
211/// `OidcClaims` implements both `Zeroize` and `Drop` (which calls `zeroize()`),
212/// satisfying the `ZeroizeOnDrop` contract.
213impl ZeroizeOnDrop for OidcClaims {}
214
215fn zeroize_json_value(v: &mut serde_json::Value) {
216    use serde_json::Value;
217    match v {
218        Value::String(s) => s.zeroize(),
219        Value::Array(a) => {
220            for item in a.iter_mut() {
221                zeroize_json_value(item);
222            }
223            a.clear();
224        }
225        Value::Object(map) => {
226            let taken = std::mem::take(map);
227            for (mut k, mut nested) in taken {
228                k.zeroize();
229                zeroize_json_value(&mut nested);
230            }
231        }
232        Value::Number(_) => {
233            *v = Value::Number(serde_json::Number::from(0u8));
234        }
235        Value::Bool(b) => *b = false,
236        Value::Null => {}
237    }
238}
239
240impl OidcClaims {
241    /// Extract the username from the OIDC claims.
242    fn user(&self, authentication_claim: &str) -> Option<&str> {
243        self.unknown_claims
244            .get(authentication_claim)
245            .and_then(|value| value.as_str())
246    }
247}
248
/// Result of successful token validation.
///
/// The private `_private` field means this can only be constructed inside this
/// module. The user is zeroized when the value is dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ValidatedClaims {
    /// Username read from the configured authentication claim.
    pub user: String,
    // Prevent construction outside of `GenericOidcAuthenticatorInner::validate_token`.
    _private: (),
}
255
256/// Wrapper around `jsonwebtoken::DecodingKey` with a redacted `Debug` impl.
257#[derive(Clone)]
258struct OidcDecodingKey(jsonwebtoken::DecodingKey);
259
260impl std::fmt::Debug for OidcDecodingKey {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        f.debug_struct("OidcDecodingKey")
263            .field("key", &"<redacted>")
264            .finish()
265    }
266}
267
/// OIDC Authenticator that validates JWTs using JWKS.
///
/// Signing keys are fetched lazily. Construction performs no network I/O. The
/// issuer's JWKS is downloaded the first time a token presents a key ID that is
/// not cached, and each such refetch replaces the cached key set. Validation is
/// asynchronous because it may need to contact the identity provider.
///
/// Cloning is cheap; clones share the same key cache.
#[derive(Clone, Debug)]
pub struct GenericOidcAuthenticator {
    inner: Arc<GenericOidcAuthenticatorInner>,
}
276
/// OpenID Connect Discovery document.
/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html>
///
/// Only the fields this module needs are declared. Other fields in the
/// document are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the JWKS endpoint.
    jwks_uri: String,
}
284
/// Shared state behind [`GenericOidcAuthenticator`].
#[derive(Debug)]
pub struct GenericOidcAuthenticatorInner {
    /// Used to read the OIDC system variables and to check role login attributes.
    adapter_client: AdapterClient,
    /// Decoding keys from the most recently fetched JWKS, keyed by key ID.
    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
    /// Client for OpenID discovery and JWKS requests.
    http_client: HttpClient,
}
291
292impl GenericOidcAuthenticator {
293    /// Create a new [`GenericOidcAuthenticator`] with an [`AdapterClient`].
294    ///
295    /// The OIDC issuer and audience are fetched from system variables on each
296    /// authentication attempt.
297    pub fn new(adapter_client: AdapterClient) -> Self {
298        let http_client = HttpClient::new();
299
300        Self {
301            inner: Arc::new(GenericOidcAuthenticatorInner {
302                adapter_client,
303                decoding_keys: Mutex::new(BTreeMap::new()),
304                http_client,
305            }),
306        }
307    }
308}
309
310impl GenericOidcAuthenticatorInner {
311    async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
312        let openid_config_url = build_openid_config_url(issuer)?;
313
314        let openid_config_url_str = openid_config_url.to_string();
315
316        // Fetch OpenID configuration to get the JWKS URI
317        let response = self
318            .http_client
319            .get(openid_config_url)
320            .timeout(Duration::from_secs(10))
321            .send()
322            .await
323            .map_err(|e| OidcError::FetchFromProviderFailed {
324                url: openid_config_url_str.clone(),
325                error_message: e.to_string(),
326            })?;
327
328        if !response.status().is_success() {
329            return Err(OidcError::FetchFromProviderFailed {
330                url: openid_config_url_str.clone(),
331                error_message: response
332                    .error_for_status()
333                    .err()
334                    .map(|e| e.to_string())
335                    .unwrap_or_else(|| "Unknown error".to_string()),
336            });
337        }
338
339        let openid_config: OpenIdConfiguration =
340            response
341                .json()
342                .await
343                .map_err(|e| OidcError::FetchFromProviderFailed {
344                    url: openid_config_url_str,
345                    error_message: e.to_string(),
346                })?;
347
348        Ok(openid_config.jwks_uri)
349    }
350
351    /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys.
352    async fn fetch_jwks(
353        &self,
354        issuer: &str,
355    ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
356        let jwks_uri = self.fetch_jwks_uri(issuer).await?;
357        let response = self
358            .http_client
359            .get(&jwks_uri)
360            .timeout(Duration::from_secs(10))
361            .send()
362            .await
363            .map_err(|e| OidcError::FetchFromProviderFailed {
364                url: jwks_uri.clone(),
365                error_message: e.to_string(),
366            })?;
367
368        if !response.status().is_success() {
369            return Err(OidcError::FetchFromProviderFailed {
370                url: jwks_uri.clone(),
371                error_message: response
372                    .error_for_status()
373                    .err()
374                    .map(|e| e.to_string())
375                    .unwrap_or_else(|| "Unknown error".to_string()),
376            });
377        }
378
379        let jwks: JwkSet =
380            response
381                .json()
382                .await
383                .map_err(|e| OidcError::FetchFromProviderFailed {
384                    url: jwks_uri.clone(),
385                    error_message: e.to_string(),
386                })?;
387
388        let mut keys = BTreeMap::new();
389
390        for jwk in jwks.keys {
391            match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
392                Ok(key) => {
393                    if let Some(kid) = jwk.common.key_id {
394                        keys.insert(kid, OidcDecodingKey(key));
395                    }
396                }
397                Err(e) => {
398                    warn!("Failed to parse JWK: {}", e);
399                }
400            }
401        }
402
403        Ok(keys)
404    }
405
406    /// Find a decoding key matching the given key ID.
407    /// If the key is not found, fetch the JWKS and cache the keys.
408    async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
409        // Get the cached decoding key.
410        {
411            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
412
413            if let Some(key) = decoding_keys.get(kid) {
414                return Ok(key.clone());
415            }
416        }
417
418        // If not found, fetch the JWKS and cache the keys.
419        let new_decoding_keys = self.fetch_jwks(issuer).await?;
420
421        let decoding_key = new_decoding_keys.get(kid).cloned();
422
423        {
424            let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
425            *decoding_keys = new_decoding_keys;
426        }
427
428        if let Some(key) = decoding_key {
429            return Ok(key);
430        }
431
432        {
433            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
434            debug!(
435                "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
436            );
437            Err(OidcError::NoMatchingKey {
438                key_id: kid.to_string(),
439            })
440        }
441    }
442
443    pub async fn validate_token(
444        &self,
445        token: &str,
446        expected_user: Option<&str>,
447    ) -> Result<ValidatedClaims, OidcError> {
448        // Fetch current OIDC configuration from system variables
449        let system_vars = self.adapter_client.get_system_vars().await;
450        let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
451            return Err(OidcError::MissingIssuer);
452        };
453
454        let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
455
456        let expected_audiences: Vec<String> = {
457            let audiences: Vec<String> =
458                serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
459                    .map_err(|_| OidcError::AudienceParseError)?;
460
461            if audiences.is_empty() {
462                warn!(
463                    "Audience validation skipped. It is discouraged \
464                    to skip audience validation since it allows anyone \
465                    with a JWT issued by the same issuer to authenticate."
466                );
467            }
468            audiences
469        };
470
471        // Decode header to get key ID (kid) and the
472        // decoding algorithm
473        let header = jsonwebtoken::decode_header(token).map_err(|e| {
474            debug!("Failed to decode JWT header: {:?}", e);
475            OidcError::Jwt
476        })?;
477
478        let kid = header.kid.ok_or(OidcError::MissingKid)?;
479        // Find the matching key from our set of cached keys. If not found,
480        // fetch the JWKS from the provider and cache the keys
481        let decoding_key = self.find_key(&kid, &issuer).await?;
482
483        // Set up audience and issuer validation
484        let mut validation = jsonwebtoken::Validation::new(header.alg);
485        validation.set_issuer(&[&issuer]);
486        if !expected_audiences.is_empty() {
487            validation.set_audience(&expected_audiences);
488        } else {
489            validation.validate_aud = false;
490        }
491
492        // Decode and validate the token
493        let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
494            .map_err(|e| match e.kind() {
495                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
496                    if !expected_audiences.is_empty() {
497                        OidcError::InvalidAudience {
498                            expected_audiences
499                        }
500                    } else {
501                        soft_panic_or_log!(
502                            "received an audience validation error when audience validation is disabled"
503                        );
504                        OidcError::Jwt
505                    }
506                }
507                jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
508                    expected_issuer: issuer.clone(),
509                },
510                jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
511                _ => OidcError::Jwt,
512            })?;
513
514        let user = token_data.claims.user(&authentication_claim).ok_or(
515            OidcError::NoMatchingAuthenticationClaim {
516                authentication_claim,
517            },
518        )?;
519
520        // Optionally validate the expected user
521        if let Some(expected) = expected_user {
522            if user != expected {
523                return Err(OidcError::WrongUser);
524            }
525        }
526
527        Ok(ValidatedClaims {
528            user: user.to_string(),
529            _private: (),
530        })
531    }
532
533    /// Checks whether the role has the LOGIN attribute. This is needed otherwise
534    /// a user can authenticate with an OIDC token to a role that isn't recognized
535    /// as a user.
536    async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
537        match self.adapter_client.role_can_login(role_name).await {
538            Ok(()) => Ok(()),
539            Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
540                // Role will be auto-provisioned during startup; allow login.
541                Ok(())
542            }
543            Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
544                Err(OidcError::NonLogin)
545            }
546            Err(e) => {
547                warn!(?e, "unexpected error checking OIDC role login");
548                Err(OidcError::LoginCheckError)
549            }
550        }
551    }
552}
553
554impl GenericOidcAuthenticator {
555    pub async fn authenticate(
556        &self,
557        token: &str,
558        expected_user: Option<&str>,
559    ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
560        let validated_claims = self.inner.validate_token(token, expected_user).await?;
561        self.inner.check_role_login(&validated_claims.user).await?;
562        Ok((validated_claims, Authenticated))
563    }
564}
565
566fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
567    let mut openid_config_url =
568        Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
569    {
570        let mut segments = openid_config_url
571            .path_segments_mut()
572            .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
573        // Remove trailing slash if it exists
574        segments.pop_if_empty();
575        segments.push(".well-known");
576        segments.push("openid-configuration");
577    }
578    Ok(openid_config_url)
579}
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_aud_single_string() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// A missing `aud` claim falls back to an empty audience list.
    #[mz_ore::test]
    fn test_aud_missing() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_empty());
        assert!(claims.iat.is_none());
    }

    #[mz_ore::test]
    fn test_user() {
        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("sub"), Some("user-123"));
        assert_eq!(claims.user("email"), Some("alice@example.com"));
        assert_eq!(claims.user("missing"), None);
    }

    /// Only string-valued claims can supply a username.
    #[mz_ore::test]
    fn test_user_non_string_claim() {
        let json = r#"{"iss":"issuer","exp":1234,"groups":["admins"],"age":42,"admin":true}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("groups"), None);
        assert_eq!(claims.user("age"), None);
        assert_eq!(claims.user("admin"), None);
    }

    #[mz_ore::test]
    fn test_build_openid_config_url() {
        let issuer = "https://dev-123456.okta.com/oauth2/default";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        let issuer = "https://dev-123456.okta.com/oauth2/default/";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    /// An issuer without a path still yields a single `/.well-known/...` suffix.
    #[mz_ore::test]
    fn test_build_openid_config_url_no_path() {
        let url = build_openid_config_url("https://accounts.example.com").unwrap();
        assert_eq!(
            url.to_string(),
            "https://accounts.example.com/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_invalid() {
        assert!(matches!(
            build_openid_config_url("not a url"),
            Err(OidcError::InvalidIssuerUrl(_))
        ));
        // Cannot-be-a-base URLs have no path segments to extend.
        assert!(matches!(
            build_openid_config_url("mailto:admin@example.com"),
            Err(OidcError::InvalidIssuerUrl(_))
        ));
    }

    #[mz_ore::test]
    fn zeroize_clears_validated_claims() {
        use mz_ore::secure::Zeroize;
        let mut claims = ValidatedClaims {
            user: "alice@example.com".to_string(),
            _private: (),
        };
        claims.zeroize();
        assert!(claims.user.is_empty());
    }

    #[mz_ore::test]
    fn oidc_claims_implements_zeroize_on_drop() {
        fn assert_zod<T: ZeroizeOnDrop>() {}
        assert_zod::<OidcClaims>();
        assert_zod::<ValidatedClaims>();
    }

    #[mz_ore::test]
    fn zeroize_clears_oidc_claims() {
        use mz_ore::secure::Zeroize;
        let mut claims = OidcClaims {
            iss: "https://issuer.example.com".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            aud: vec!["app1".to_string(), "app2".to_string()],
            unknown_claims: BTreeMap::from([(
                "email".to_string(),
                serde_json::Value::String("alice@example.com".to_string()),
            )]),
        };
        claims.zeroize();
        assert!(claims.iss.is_empty());
        assert_eq!(claims.exp, 0);
        assert!(claims.iat.is_none());
        assert!(claims.aud.is_empty());
        assert!(claims.unknown_claims.is_empty());
    }

    /// Nested containers are scrubbed recursively and the top level is emptied.
    #[mz_ore::test]
    fn zeroize_clears_nested_json() {
        let mut value = serde_json::json!({
            "groups": ["admins", {"name": "ops"}],
            "age": 42,
            "verified": true,
        });
        zeroize_json_value(&mut value);
        assert_eq!(value, serde_json::json!({}));
    }
}