Skip to main content

mz_authenticator/
oidc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC Authentication for pgwire connections.
11//!
12//! This module provides JWT-based authentication using OpenID Connect (OIDC).
13//! JWTs are validated locally using JWKS fetched from the configured provider.
14
15use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use mz_ore::soft_panic_or_log;
24use mz_pgwire_common::{ErrorResponse, Severity};
25use reqwest::Client as HttpClient;
26use serde::{Deserialize, Deserializer, Serialize};
27use tokio_postgres::error::SqlState;
28
29use tracing::{debug, warn};
30use url::Url;
/// Errors that can occur during OIDC authentication.
///
/// Every variant is reported to clients with the same SQLSTATE
/// (`INVALID_AUTHORIZATION_SPECIFICATION`, see `OidcError::code`).
#[derive(Debug)]
pub enum OidcError {
    /// The OIDC issuer system variable is not set.
    MissingIssuer,
    /// Failed to parse OIDC configuration URL.
    ///
    /// Also returned when the issuer parses but cannot have path segments
    /// appended to it. Carries the configured issuer string.
    InvalidIssuerUrl(String),
    /// The `OIDC_AUDIENCE` system variable could not be deserialized as a list
    /// of strings.
    AudienceParseError,
    /// Failed to fetch from the identity provider.
    FetchFromProviderFailed {
        /// URL that was being fetched.
        url: String,
        /// Transport, HTTP-status, or body-parsing error text.
        error_message: String,
    },
    /// The key ID is missing in the token header.
    MissingKid,
    /// No matching key found in JWKS.
    NoMatchingKey {
        /// Key ID that was found in the JWT header.
        key_id: String,
    },
    /// Configured authentication claim is not found in the JWT.
    ///
    /// Also returned when the claim is present but is not a JSON string.
    NoMatchingAuthenticationClaim {
        /// Name of the configured authentication claim.
        authentication_claim: String,
    },
    /// JWT validation error: the header could not be decoded, or validation
    /// failed for a reason without a more specific variant.
    Jwt,
    /// The username extracted from the token differs from the caller-supplied
    /// expected user.
    WrongUser,
    /// The token's audience matched none of the configured audiences.
    InvalidAudience {
        /// Audiences configured via the `OIDC_AUDIENCE` system variable.
        expected_audiences: Vec<String>,
    },
    /// The token's issuer did not match the configured issuer.
    InvalidIssuer {
        /// Issuer configured via the OIDC issuer system variable.
        expected_issuer: String,
    },
    /// The token's signature has expired, as reported by `jsonwebtoken`.
    ExpiredSignature,
    /// The role exists but does not have the LOGIN attribute.
    NonLogin,
    /// Checking the role's LOGIN attribute failed unexpectedly. The underlying
    /// error is logged, not returned.
    LoginCheckError,
}
68
69impl std::fmt::Display for OidcError {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
73            OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
74            OidcError::AudienceParseError => {
75                write!(f, "failed to parse OIDC_AUDIENCE system variable")
76            }
77            OidcError::FetchFromProviderFailed { .. } => {
78                write!(f, "failed to fetch OIDC provider configuration")
79            }
80            OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
81            OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
82            OidcError::NoMatchingAuthenticationClaim { .. } => {
83                write!(f, "no matching authentication claim found in the JWT")
84            }
85            OidcError::Jwt => write!(f, "failed to validate JWT"),
86            OidcError::WrongUser => write!(f, "wrong user"),
87            OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
88            OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
89            OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
90            OidcError::NonLogin => write!(f, "role is not allowed to login"),
91            OidcError::LoginCheckError => write!(f, "unexpected error checking if role can login"),
92        }
93    }
94}
95
96impl std::error::Error for OidcError {}
97
98impl OidcError {
99    pub fn code(&self) -> SqlState {
100        SqlState::INVALID_AUTHORIZATION_SPECIFICATION
101    }
102
103    pub fn detail(&self) -> Option<String> {
104        match self {
105            OidcError::InvalidIssuerUrl(issuer) => {
106                Some(format!("Could not parse \"{issuer}\" as a URL."))
107            }
108            OidcError::FetchFromProviderFailed { url, error_message } => {
109                Some(format!("Fetching \"{url}\" failed. {error_message}"))
110            }
111            OidcError::NoMatchingKey { key_id } => {
112                Some(format!("JWT key ID \"{key_id}\" was not found."))
113            }
114            OidcError::InvalidAudience { expected_audiences } => Some(format!(
115                "Expected one of audiences {:?} in the JWT.",
116                expected_audiences,
117            )),
118            OidcError::InvalidIssuer { expected_issuer } => {
119                Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
120            }
121            OidcError::NoMatchingAuthenticationClaim {
122                authentication_claim,
123            } => Some(format!(
124                "Expected authentication claim \"{authentication_claim}\" in the JWT.",
125            )),
126            OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
127            _ => None,
128        }
129    }
130
131    pub fn hint(&self) -> Option<String> {
132        match self {
133            OidcError::MissingIssuer => {
134                Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
135            }
136            _ => None,
137        }
138    }
139
140    pub fn into_response(self) -> ErrorResponse {
141        ErrorResponse {
142            severity: Severity::Fatal,
143            code: self.code(),
144            message: self.to_string(),
145            detail: self.detail(),
146            hint: self.hint(),
147            position: None,
148        }
149    }
150}
151
152fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
153where
154    D: Deserializer<'de>,
155{
156    #[derive(Deserialize)]
157    #[serde(untagged)]
158    enum StringOrVec {
159        String(String),
160        Vec(Vec<String>),
161    }
162
163    match StringOrVec::deserialize(deserializer)? {
164        StringOrVec::String(s) => Ok(vec![s]),
165        StringOrVec::Vec(v) => Ok(v),
166    }
167}
/// Claims extracted from a validated JWT.
///
/// `iss` and `exp` are required: deserialization fails without them. `iat`
/// and `aud` fall back to their defaults (`None` / empty) when absent. Every
/// other claim (e.g. `sub`, `email`) lands in `unknown_claims`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Issuer.
    pub iss: String,
    /// Expiration time (Unix timestamp).
    pub exp: i64,
    /// Issued at time (Unix timestamp).
    #[serde(default)]
    pub iat: Option<i64>,
    /// Audience claim (can be single string or array in JWT).
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
    /// Additional claims from the JWT, captured for flexible username extraction.
    ///
    /// Because of `flatten`, this holds only claims not matched by the named
    /// fields above; `iss`, `exp`, `iat` and `aud` never appear here.
    #[serde(flatten)]
    pub unknown_claims: BTreeMap<String, serde_json::Value>,
}
185
186impl OidcClaims {
187    /// Extract the username from the OIDC claims.
188    fn user(&self, authentication_claim: &str) -> Option<&str> {
189        self.unknown_claims
190            .get(authentication_claim)
191            .and_then(|value| value.as_str())
192    }
193}
194
/// Outcome of successfully validating an OIDC token.
pub struct ValidatedClaims {
    /// Username read from the configured authentication claim. When the caller
    /// passed an expected user, this equals it.
    pub user: String,
    // Private field: code outside this module cannot construct a
    // `ValidatedClaims`. Within this file it is only built at the end of
    // `GenericOidcAuthenticatorInner::validate_token`.
    _private: (),
}
200
201#[derive(Clone)]
202struct OidcDecodingKey(jsonwebtoken::DecodingKey);
203
204impl std::fmt::Debug for OidcDecodingKey {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        f.debug_struct("OidcDecodingKey")
207            .field("key", &"<redacted>")
208            .finish()
209    }
210}
211
/// OIDC Authenticator that validates JWTs using JWKS.
///
/// Nothing is fetched at construction time. Decoding keys are fetched from the
/// issuer's JWKS endpoint when a token's key ID is not already cached, so
/// validation is asynchronous. Clones share the same inner state (including
/// the key cache) through an `Arc`.
#[derive(Clone, Debug)]
pub struct GenericOidcAuthenticator {
    inner: Arc<GenericOidcAuthenticatorInner>,
}
220
/// OpenID Connect Discovery document.
/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html>
///
/// Only the field used here is modeled. Other fields of the discovery document
/// are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the JWKS endpoint.
    jwks_uri: String,
}
228
229#[derive(Debug)]
230pub struct GenericOidcAuthenticatorInner {
231    adapter_client: AdapterClient,
232    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
233    http_client: HttpClient,
234}
235
236impl GenericOidcAuthenticator {
237    /// Create a new [`GenericOidcAuthenticator`] with an [`AdapterClient`].
238    ///
239    /// The OIDC issuer and audience are fetched from system variables on each
240    /// authentication attempt.
241    pub fn new(adapter_client: AdapterClient) -> Self {
242        let http_client = HttpClient::new();
243
244        Self {
245            inner: Arc::new(GenericOidcAuthenticatorInner {
246                adapter_client,
247                decoding_keys: Mutex::new(BTreeMap::new()),
248                http_client,
249            }),
250        }
251    }
252}
253
254impl GenericOidcAuthenticatorInner {
255    async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
256        let openid_config_url = build_openid_config_url(issuer)?;
257
258        let openid_config_url_str = openid_config_url.to_string();
259
260        // Fetch OpenID configuration to get the JWKS URI
261        let response = self
262            .http_client
263            .get(openid_config_url)
264            .timeout(Duration::from_secs(10))
265            .send()
266            .await
267            .map_err(|e| OidcError::FetchFromProviderFailed {
268                url: openid_config_url_str.clone(),
269                error_message: e.to_string(),
270            })?;
271
272        if !response.status().is_success() {
273            return Err(OidcError::FetchFromProviderFailed {
274                url: openid_config_url_str.clone(),
275                error_message: response
276                    .error_for_status()
277                    .err()
278                    .map(|e| e.to_string())
279                    .unwrap_or_else(|| "Unknown error".to_string()),
280            });
281        }
282
283        let openid_config: OpenIdConfiguration =
284            response
285                .json()
286                .await
287                .map_err(|e| OidcError::FetchFromProviderFailed {
288                    url: openid_config_url_str,
289                    error_message: e.to_string(),
290                })?;
291
292        Ok(openid_config.jwks_uri)
293    }
294
295    /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys.
296    async fn fetch_jwks(
297        &self,
298        issuer: &str,
299    ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
300        let jwks_uri = self.fetch_jwks_uri(issuer).await?;
301        let response = self
302            .http_client
303            .get(&jwks_uri)
304            .timeout(Duration::from_secs(10))
305            .send()
306            .await
307            .map_err(|e| OidcError::FetchFromProviderFailed {
308                url: jwks_uri.clone(),
309                error_message: e.to_string(),
310            })?;
311
312        if !response.status().is_success() {
313            return Err(OidcError::FetchFromProviderFailed {
314                url: jwks_uri.clone(),
315                error_message: response
316                    .error_for_status()
317                    .err()
318                    .map(|e| e.to_string())
319                    .unwrap_or_else(|| "Unknown error".to_string()),
320            });
321        }
322
323        let jwks: JwkSet =
324            response
325                .json()
326                .await
327                .map_err(|e| OidcError::FetchFromProviderFailed {
328                    url: jwks_uri.clone(),
329                    error_message: e.to_string(),
330                })?;
331
332        let mut keys = BTreeMap::new();
333
334        for jwk in jwks.keys {
335            match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
336                Ok(key) => {
337                    if let Some(kid) = jwk.common.key_id {
338                        keys.insert(kid, OidcDecodingKey(key));
339                    }
340                }
341                Err(e) => {
342                    warn!("Failed to parse JWK: {}", e);
343                }
344            }
345        }
346
347        Ok(keys)
348    }
349
350    /// Find a decoding key matching the given key ID.
351    /// If the key is not found, fetch the JWKS and cache the keys.
352    async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
353        // Get the cached decoding key.
354        {
355            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
356
357            if let Some(key) = decoding_keys.get(kid) {
358                return Ok(key.clone());
359            }
360        }
361
362        // If not found, fetch the JWKS and cache the keys.
363        let new_decoding_keys = self.fetch_jwks(issuer).await?;
364
365        let decoding_key = new_decoding_keys.get(kid).cloned();
366
367        {
368            let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
369            *decoding_keys = new_decoding_keys;
370        }
371
372        if let Some(key) = decoding_key {
373            return Ok(key);
374        }
375
376        {
377            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
378            debug!(
379                "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
380            );
381            Err(OidcError::NoMatchingKey {
382                key_id: kid.to_string(),
383            })
384        }
385    }
386
387    pub async fn validate_token(
388        &self,
389        token: &str,
390        expected_user: Option<&str>,
391    ) -> Result<ValidatedClaims, OidcError> {
392        // Fetch current OIDC configuration from system variables
393        let system_vars = self.adapter_client.get_system_vars().await;
394        let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
395            return Err(OidcError::MissingIssuer);
396        };
397
398        let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
399
400        let expected_audiences: Vec<String> = {
401            let audiences: Vec<String> =
402                serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
403                    .map_err(|_| OidcError::AudienceParseError)?;
404
405            if audiences.is_empty() {
406                warn!(
407                    "Audience validation skipped. It is discouraged \
408                    to skip audience validation since it allows anyone \
409                    with a JWT issued by the same issuer to authenticate."
410                );
411            }
412            audiences
413        };
414
415        // Decode header to get key ID (kid) and the
416        // decoding algorithm
417        let header = jsonwebtoken::decode_header(token).map_err(|e| {
418            debug!("Failed to decode JWT header: {:?}", e);
419            OidcError::Jwt
420        })?;
421
422        let kid = header.kid.ok_or(OidcError::MissingKid)?;
423        // Find the matching key from our set of cached keys. If not found,
424        // fetch the JWKS from the provider and cache the keys
425        let decoding_key = self.find_key(&kid, &issuer).await?;
426
427        // Set up audience and issuer validation
428        let mut validation = jsonwebtoken::Validation::new(header.alg);
429        validation.set_issuer(&[&issuer]);
430        if !expected_audiences.is_empty() {
431            validation.set_audience(&expected_audiences);
432        } else {
433            validation.validate_aud = false;
434        }
435
436        // Decode and validate the token
437        let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
438            .map_err(|e| match e.kind() {
439                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
440                    if !expected_audiences.is_empty() {
441                        OidcError::InvalidAudience {
442                            expected_audiences
443                        }
444                    } else {
445                        soft_panic_or_log!(
446                            "received an audience validation error when audience validation is disabled"
447                        );
448                        OidcError::Jwt
449                    }
450                }
451                jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
452                    expected_issuer: issuer.clone(),
453                },
454                jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
455                _ => OidcError::Jwt,
456            })?;
457
458        let user = token_data.claims.user(&authentication_claim).ok_or(
459            OidcError::NoMatchingAuthenticationClaim {
460                authentication_claim,
461            },
462        )?;
463
464        // Optionally validate the expected user
465        if let Some(expected) = expected_user {
466            if user != expected {
467                return Err(OidcError::WrongUser);
468            }
469        }
470
471        Ok(ValidatedClaims {
472            user: user.to_string(),
473            _private: (),
474        })
475    }
476
477    /// Checks whether the role has the LOGIN attribute. This is needed otherwise
478    /// a user can authenticate with an OIDC token to a role that isn't recognized
479    /// as a user.
480    async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
481        match self.adapter_client.role_can_login(role_name).await {
482            Ok(()) => Ok(()),
483            Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
484                // Role will be auto-provisioned during startup; allow login.
485                Ok(())
486            }
487            Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
488                Err(OidcError::NonLogin)
489            }
490            Err(e) => {
491                warn!(?e, "unexpected error checking OIDC role login");
492                Err(OidcError::LoginCheckError)
493            }
494        }
495    }
496}
497
498impl GenericOidcAuthenticator {
499    pub async fn authenticate(
500        &self,
501        token: &str,
502        expected_user: Option<&str>,
503    ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
504        let validated_claims = self.inner.validate_token(token, expected_user).await?;
505        self.inner.check_role_login(&validated_claims.user).await?;
506        Ok((validated_claims, Authenticated))
507    }
508}
509
510fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
511    let mut openid_config_url =
512        Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
513    {
514        let mut segments = openid_config_url
515            .path_segments_mut()
516            .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
517        // Remove trailing slash if it exists
518        segments.pop_if_empty();
519        segments.push(".well-known");
520        segments.push("openid-configuration");
521    }
522    Ok(openid_config_url)
523}
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_aud_single_string() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// A token without `aud` must still deserialize, with an empty audience list.
    #[mz_ore::test]
    fn test_aud_missing_defaults_to_empty() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_empty());
        assert_eq!(claims.iat, None);
    }

    #[mz_ore::test]
    fn test_user() {
        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("sub"), Some("user-123"));
        assert_eq!(claims.user("email"), Some("alice@example.com"));
        assert_eq!(claims.user("missing"), None);
    }

    /// Non-string claims must never be interpreted as a username.
    #[mz_ore::test]
    fn test_user_non_string_claim() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"groups":["admins"],"uid":42}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("groups"), None);
        assert_eq!(claims.user("uid"), None);
    }

    #[mz_ore::test]
    fn test_build_openid_config_url() {
        let issuer = "https://dev-123456.okta.com/oauth2/default";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        let issuer = "https://dev-123456.okta.com/oauth2/default/";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    /// Unparseable and cannot-be-a-base issuers are both rejected, and the
    /// error carries the original issuer string.
    #[mz_ore::test]
    fn test_build_openid_config_url_invalid_issuer() {
        for issuer in ["not a url", "mailto:admin@example.com"] {
            match build_openid_config_url(issuer) {
                Err(OidcError::InvalidIssuerUrl(reported)) => assert_eq!(reported, issuer),
                other => panic!("expected InvalidIssuerUrl for {issuer:?}, got {other:?}"),
            }
        }
    }

    #[mz_ore::test]
    fn test_missing_issuer_response() {
        let response = OidcError::MissingIssuer.into_response();
        assert_eq!(response.message, "OIDC issuer is not configured");
        assert_eq!(response.detail, None);
        assert_eq!(
            response.hint.as_deref(),
            Some("Configure the OIDC issuer using the oidc_issuer system variable.")
        );
    }
}