Skip to main content

mz_authenticator/
oidc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! OIDC authentication for pgwire connections.
//!
//! This module provides JWT-based authentication using OpenID Connect (OIDC).
//! JWTs are validated locally against decoding keys taken from the JWKS
//! published by the configured issuer; the keys are fetched on demand and cached.
14
15use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::Client as AdapterClient;
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use mz_ore::soft_panic_or_log;
24use mz_pgwire_common::{ErrorResponse, Severity};
25use reqwest::Client as HttpClient;
26use serde::{Deserialize, Deserializer, Serialize};
27use tokio_postgres::error::SqlState;
28
29use tracing::{debug, warn};
30use url::Url;
/// Errors that can occur during OIDC authentication.
#[derive(Debug)]
pub enum OidcError {
    /// The OIDC issuer system variable is not set.
    MissingIssuer,
    /// The configured issuer is not a usable URL; carries the issuer string.
    InvalidIssuerUrl(String),
    /// The OIDC audience system variable could not be parsed as a list of strings.
    AudienceParseError,
    /// Failed to fetch from the identity provider.
    FetchFromProviderFailed {
        /// URL that was being fetched.
        url: String,
        /// Description of the underlying failure.
        error_message: String,
    },
    /// The key ID is missing in the token header.
    MissingKid,
    /// No matching key found in JWKS.
    NoMatchingKey {
        /// Key ID that was found in the JWT header.
        key_id: String,
    },
    /// Configured authentication claim is not found in the JWT.
    NoMatchingAuthenticationClaim {
        /// Name of the claim that was expected to hold the user name.
        authentication_claim: String,
    },
    /// The JWT could not be decoded or validated, for a reason not covered by
    /// a more specific variant.
    Jwt,
    /// The user named by the token differs from the expected user supplied by
    /// the caller.
    WrongUser,
    /// The token's audience does not match any configured audience.
    InvalidAudience {
        /// Audiences accepted by the current configuration.
        expected_audiences: Vec<String>,
    },
    /// The token's issuer does not match the configured issuer.
    InvalidIssuer {
        /// Issuer required by the current configuration.
        expected_issuer: String,
    },
    /// The token has expired according to its `exp` claim.
    ExpiredSignature,
}
65
66impl std::fmt::Display for OidcError {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        match self {
69            OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
70            OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
71            OidcError::AudienceParseError => {
72                write!(f, "failed to parse OIDC_AUDIENCE system variable")
73            }
74            OidcError::FetchFromProviderFailed { .. } => {
75                write!(f, "failed to fetch OIDC provider configuration")
76            }
77            OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
78            OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
79            OidcError::NoMatchingAuthenticationClaim { .. } => {
80                write!(f, "no matching authentication claim found in the JWT")
81            }
82            OidcError::Jwt => write!(f, "failed to validate JWT"),
83            OidcError::WrongUser => write!(f, "wrong user"),
84            OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
85            OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
86            OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
87        }
88    }
89}
90
91impl std::error::Error for OidcError {}
92
93impl OidcError {
94    pub fn code(&self) -> SqlState {
95        SqlState::INVALID_AUTHORIZATION_SPECIFICATION
96    }
97
98    pub fn detail(&self) -> Option<String> {
99        match self {
100            OidcError::InvalidIssuerUrl(issuer) => {
101                Some(format!("Could not parse \"{issuer}\" as a URL."))
102            }
103            OidcError::FetchFromProviderFailed { url, error_message } => {
104                Some(format!("Fetching \"{url}\" failed. {error_message}"))
105            }
106            OidcError::NoMatchingKey { key_id } => {
107                Some(format!("JWT key ID \"{key_id}\" was not found."))
108            }
109            OidcError::InvalidAudience { expected_audiences } => Some(format!(
110                "Expected one of audiences {:?} in the JWT.",
111                expected_audiences,
112            )),
113            OidcError::InvalidIssuer { expected_issuer } => {
114                Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
115            }
116            OidcError::NoMatchingAuthenticationClaim {
117                authentication_claim,
118            } => Some(format!(
119                "Expected authentication claim \"{authentication_claim}\" in the JWT.",
120            )),
121            _ => None,
122        }
123    }
124
125    pub fn hint(&self) -> Option<String> {
126        match self {
127            OidcError::MissingIssuer => {
128                Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
129            }
130            _ => None,
131        }
132    }
133
134    pub fn into_response(self) -> ErrorResponse {
135        ErrorResponse {
136            severity: Severity::Fatal,
137            code: self.code(),
138            message: self.to_string(),
139            detail: self.detail(),
140            hint: self.hint(),
141            position: None,
142        }
143    }
144}
145
146fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
147where
148    D: Deserializer<'de>,
149{
150    #[derive(Deserialize)]
151    #[serde(untagged)]
152    enum StringOrVec {
153        String(String),
154        Vec(Vec<String>),
155    }
156
157    match StringOrVec::deserialize(deserializer)? {
158        StringOrVec::String(s) => Ok(vec![s]),
159        StringOrVec::Vec(v) => Ok(v),
160    }
161}
/// Claims extracted from a validated JWT.
///
/// `iss`, `exp`, `iat` and `aud` have dedicated fields; the remaining claims of
/// the payload (e.g. `sub`, `email`) are collected into `unknown_claims`.
/// NOTE(review): because of `#[serde(flatten)]`, the claims matched by the
/// named fields are presumably not duplicated into `unknown_claims`, so they
/// cannot serve as the authentication claim — confirm if that matters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Issuer.
    pub iss: String,
    /// Expiration time (Unix timestamp).
    pub exp: i64,
    /// Issued at time (Unix timestamp); `None` when the claim is absent.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Audience claim (can be single string or array in JWT); empty when the
    /// claim is absent.
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
    /// Additional claims from the JWT, captured for flexible username extraction.
    #[serde(flatten)]
    pub unknown_claims: BTreeMap<String, serde_json::Value>,
}
179
180impl OidcClaims {
181    /// Extract the username from the OIDC claims.
182    fn user(&self, authentication_claim: &str) -> Option<&str> {
183        self.unknown_claims
184            .get(authentication_claim)
185            .and_then(|value| value.as_str())
186    }
187}
188
/// Result of a successful OIDC token validation.
pub struct ValidatedClaims {
    /// User name read from the configured authentication claim of the JWT.
    pub user: String,
    // Private field so that code outside this module cannot construct a
    // `ValidatedClaims`; inside this module it is only built by
    // `GenericOidcAuthenticatorInner::validate_token`.
    _private: (),
}
194
195#[derive(Clone)]
196struct OidcDecodingKey(jsonwebtoken::DecodingKey);
197
198impl std::fmt::Debug for OidcDecodingKey {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("OidcDecodingKey")
201            .field("key", &"<redacted>")
202            .finish()
203    }
204}
205
206/// OIDC Authenticator that validates JWTs using JWKS.
207///
208/// This implementation pre-fetches JWKS at construction time for synchronous
209/// token validation.
210#[derive(Clone, Debug)]
211pub struct GenericOidcAuthenticator {
212    inner: Arc<GenericOidcAuthenticatorInner>,
213}
214
/// OpenID Connect Discovery document.
///
/// Only the fields this module reads are declared. The struct does not use
/// `deny_unknown_fields`, so other fields of the document are ignored.
/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html>
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the JWKS endpoint.
    jwks_uri: String,
}
222
223#[derive(Debug)]
224pub struct GenericOidcAuthenticatorInner {
225    adapter_client: AdapterClient,
226    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
227    http_client: HttpClient,
228}
229
230impl GenericOidcAuthenticator {
231    /// Create a new [`GenericOidcAuthenticator`] with an [`AdapterClient`].
232    ///
233    /// The OIDC issuer and audience are fetched from system variables on each
234    /// authentication attempt.
235    pub fn new(adapter_client: AdapterClient) -> Self {
236        let http_client = HttpClient::new();
237
238        Self {
239            inner: Arc::new(GenericOidcAuthenticatorInner {
240                adapter_client,
241                decoding_keys: Mutex::new(BTreeMap::new()),
242                http_client,
243            }),
244        }
245    }
246}
247
248impl GenericOidcAuthenticatorInner {
249    async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
250        let openid_config_url = build_openid_config_url(issuer)?;
251
252        let openid_config_url_str = openid_config_url.to_string();
253
254        // Fetch OpenID configuration to get the JWKS URI
255        let response = self
256            .http_client
257            .get(openid_config_url)
258            .timeout(Duration::from_secs(10))
259            .send()
260            .await
261            .map_err(|e| OidcError::FetchFromProviderFailed {
262                url: openid_config_url_str.clone(),
263                error_message: e.to_string(),
264            })?;
265
266        if !response.status().is_success() {
267            return Err(OidcError::FetchFromProviderFailed {
268                url: openid_config_url_str.clone(),
269                error_message: response
270                    .error_for_status()
271                    .err()
272                    .map(|e| e.to_string())
273                    .unwrap_or_else(|| "Unknown error".to_string()),
274            });
275        }
276
277        let openid_config: OpenIdConfiguration =
278            response
279                .json()
280                .await
281                .map_err(|e| OidcError::FetchFromProviderFailed {
282                    url: openid_config_url_str,
283                    error_message: e.to_string(),
284                })?;
285
286        Ok(openid_config.jwks_uri)
287    }
288
289    /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys.
290    async fn fetch_jwks(
291        &self,
292        issuer: &str,
293    ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
294        let jwks_uri = self.fetch_jwks_uri(issuer).await?;
295        let response = self
296            .http_client
297            .get(&jwks_uri)
298            .timeout(Duration::from_secs(10))
299            .send()
300            .await
301            .map_err(|e| OidcError::FetchFromProviderFailed {
302                url: jwks_uri.clone(),
303                error_message: e.to_string(),
304            })?;
305
306        if !response.status().is_success() {
307            return Err(OidcError::FetchFromProviderFailed {
308                url: jwks_uri.clone(),
309                error_message: response
310                    .error_for_status()
311                    .err()
312                    .map(|e| e.to_string())
313                    .unwrap_or_else(|| "Unknown error".to_string()),
314            });
315        }
316
317        let jwks: JwkSet =
318            response
319                .json()
320                .await
321                .map_err(|e| OidcError::FetchFromProviderFailed {
322                    url: jwks_uri.clone(),
323                    error_message: e.to_string(),
324                })?;
325
326        let mut keys = BTreeMap::new();
327
328        for jwk in jwks.keys {
329            match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
330                Ok(key) => {
331                    if let Some(kid) = jwk.common.key_id {
332                        keys.insert(kid, OidcDecodingKey(key));
333                    }
334                }
335                Err(e) => {
336                    warn!("Failed to parse JWK: {}", e);
337                }
338            }
339        }
340
341        Ok(keys)
342    }
343
344    /// Find a decoding key matching the given key ID.
345    /// If the key is not found, fetch the JWKS and cache the keys.
346    async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
347        // Get the cached decoding key.
348        {
349            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
350
351            if let Some(key) = decoding_keys.get(kid) {
352                return Ok(key.clone());
353            }
354        }
355
356        // If not found, fetch the JWKS and cache the keys.
357        let new_decoding_keys = self.fetch_jwks(issuer).await?;
358
359        let decoding_key = new_decoding_keys.get(kid).cloned();
360
361        {
362            let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
363            *decoding_keys = new_decoding_keys;
364        }
365
366        if let Some(key) = decoding_key {
367            return Ok(key);
368        }
369
370        {
371            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
372            debug!(
373                "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
374            );
375            Err(OidcError::NoMatchingKey {
376                key_id: kid.to_string(),
377            })
378        }
379    }
380
381    pub async fn validate_token(
382        &self,
383        token: &str,
384        expected_user: Option<&str>,
385    ) -> Result<ValidatedClaims, OidcError> {
386        // Fetch current OIDC configuration from system variables
387        let system_vars = self.adapter_client.get_system_vars().await;
388        let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
389            return Err(OidcError::MissingIssuer);
390        };
391
392        let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
393
394        let expected_audiences: Vec<String> = {
395            let audiences: Vec<String> =
396                serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
397                    .map_err(|_| OidcError::AudienceParseError)?;
398
399            if audiences.is_empty() {
400                warn!(
401                    "Audience validation skipped. It is discouraged \
402                    to skip audience validation since it allows anyone \
403                    with a JWT issued by the same issuer to authenticate."
404                );
405            }
406            audiences
407        };
408
409        // Decode header to get key ID (kid) and the
410        // decoding algorithm
411        let header = jsonwebtoken::decode_header(token).map_err(|e| {
412            debug!("Failed to decode JWT header: {:?}", e);
413            OidcError::Jwt
414        })?;
415
416        let kid = header.kid.ok_or(OidcError::MissingKid)?;
417        // Find the matching key from our set of cached keys. If not found,
418        // fetch the JWKS from the provider and cache the keys
419        let decoding_key = self.find_key(&kid, &issuer).await?;
420
421        // Set up audience and issuer validation
422        let mut validation = jsonwebtoken::Validation::new(header.alg);
423        validation.set_issuer(&[&issuer]);
424        if !expected_audiences.is_empty() {
425            validation.set_audience(&expected_audiences);
426        } else {
427            validation.validate_aud = false;
428        }
429
430        // Decode and validate the token
431        let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
432            .map_err(|e| match e.kind() {
433                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
434                    if !expected_audiences.is_empty() {
435                        OidcError::InvalidAudience {
436                            expected_audiences
437                        }
438                    } else {
439                        soft_panic_or_log!(
440                            "received an audience validation error when audience validation is disabled"
441                        );
442                        OidcError::Jwt
443                    }
444                }
445                jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
446                    expected_issuer: issuer.clone(),
447                },
448                jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
449                _ => OidcError::Jwt,
450            })?;
451
452        let user = token_data.claims.user(&authentication_claim).ok_or(
453            OidcError::NoMatchingAuthenticationClaim {
454                authentication_claim,
455            },
456        )?;
457
458        // Optionally validate the expected user
459        if let Some(expected) = expected_user {
460            if user != expected {
461                return Err(OidcError::WrongUser);
462            }
463        }
464
465        Ok(ValidatedClaims {
466            user: user.to_string(),
467            _private: (),
468        })
469    }
470}
471
472impl GenericOidcAuthenticator {
473    pub async fn authenticate(
474        &self,
475        token: &str,
476        expected_user: Option<&str>,
477    ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
478        let validated_claims = self.inner.validate_token(token, expected_user).await?;
479        Ok((validated_claims, Authenticated))
480    }
481}
482
483fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
484    let mut openid_config_url =
485        Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
486    {
487        let mut segments = openid_config_url
488            .path_segments_mut()
489            .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
490        // Remove trailing slash if it exists
491        segments.pop_if_empty();
492        segments.push(".well-known");
493        segments.push("openid-configuration");
494    }
495    Ok(openid_config_url)
496}
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_aud_single_string() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// A missing `aud` (and `iat`) must not fail deserialization.
    #[mz_ore::test]
    fn test_aud_missing_defaults_to_empty() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_empty());
        assert_eq!(claims.iat, None);
    }

    #[mz_ore::test]
    fn test_user() {
        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com","groups":["admins"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("sub"), Some("user-123"));
        assert_eq!(claims.user("email"), Some("alice@example.com"));
        assert_eq!(claims.user("missing"), None);
        // Only string-valued claims can name a user.
        assert_eq!(claims.user("groups"), None);
    }

    #[mz_ore::test]
    fn test_build_openid_config_url() {
        let issuer = "https://dev-123456.okta.com/oauth2/default";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        let issuer = "https://dev-123456.okta.com/oauth2/default/";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_without_path() {
        for issuer in ["https://example.com", "https://example.com/"] {
            let url = build_openid_config_url(issuer).unwrap();
            assert_eq!(
                url.as_str(),
                "https://example.com/.well-known/openid-configuration"
            );
        }
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_invalid() {
        // Neither a relative string nor a cannot-be-a-base URL has a usable path.
        for issuer in ["not a url", "mailto:someone@example.com"] {
            assert!(matches!(
                build_openid_config_url(issuer),
                Err(OidcError::InvalidIssuerUrl(s)) if s == issuer
            ));
        }
    }
}