Skip to main content

mz_authenticator/
oidc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC Authentication for pgwire connections.
11//!
12//! This module provides JWT-based authentication using OpenID Connect (OIDC).
13//! JWTs are validated locally using JWKS fetched from the configured provider.
14
15use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::Client as AdapterClient;
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Deserializer, Serialize};
25
26use tracing::warn;
27use url::Url;
/// Errors that can occur during OIDC authentication.
///
/// Each variant is rendered to a human-readable message by this type's
/// [`std::fmt::Display`] implementation.
#[derive(Debug)]
pub enum OidcError {
    /// The issuer is missing: no OIDC issuer is set in the system variables
    /// at the time a token is validated.
    MissingIssuer,
    /// Failed to parse OIDC configuration URL: the configured issuer could
    /// not be parsed as a URL, or the discovery path could not be joined
    /// onto it.
    InvalidIssuerUrl(url::ParseError),
    /// Failed to fetch OpenID configuration from provider.
    ///
    /// The payload is a description of the failure: the HTTP client error,
    /// the non-success HTTP status, or the response deserialization error.
    OpenIdConfigFetchFailed(String),
    /// Failed to fetch JWKS from provider.
    ///
    /// The payload is a description of the failure: the HTTP client error,
    /// the non-success HTTP status, the response deserialization error, or a
    /// note that the JWKS contained no usable keys.
    JwksFetchFailed(String),
    /// The key ID (`kid`) is missing in the token header.
    MissingKid,
    /// No matching key found in JWKS, even after fetching it from the
    /// provider.
    NoMatchingKey,
    /// JWT validation error from jsonwebtoken, raised either while decoding
    /// the token header or while decoding and validating the token itself.
    Jwt(jsonwebtoken::errors::Error),
    /// User does not match expected value: the username derived from the
    /// token's claims differs from the user the caller expected.
    WrongUser,
}
48
49impl std::fmt::Display for OidcError {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            OidcError::MissingIssuer => write!(f, "missing OIDC issuer URL"),
53            OidcError::InvalidIssuerUrl(e) => {
54                write!(f, "failed to parse OIDC issuer URL: {}", e)
55            }
56            OidcError::OpenIdConfigFetchFailed(e) => {
57                write!(f, "failed to fetch OpenID configuration: {}", e)
58            }
59            OidcError::JwksFetchFailed(e) => write!(f, "failed to fetch JWKS: {}", e),
60            OidcError::MissingKid => write!(f, "missing key ID in token header"),
61            OidcError::NoMatchingKey => write!(f, "no matching key in JWKS"),
62            OidcError::Jwt(e) => write!(f, "JWT error: {}", e),
63            OidcError::WrongUser => write!(f, "user does not match expected value"),
64        }
65    }
66}
67
// `source()` is intentionally left at its default (`None`): the `Display`
// implementation already embeds the text of any wrapped error, so callers that
// print an error together with its cause chain do not see the cause twice.
impl std::error::Error for OidcError {}
69
70fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
71where
72    D: Deserializer<'de>,
73{
74    #[derive(Deserialize)]
75    #[serde(untagged)]
76    enum StringOrVec {
77        String(String),
78        Vec(Vec<String>),
79    }
80
81    match StringOrVec::deserialize(deserializer)? {
82        StringOrVec::String(s) => Ok(vec![s]),
83        StringOrVec::Vec(v) => Ok(v),
84    }
85}
/// Claims extracted from a validated JWT.
///
/// `sub`, `iss` and `exp` have no serde default, so deserialization fails
/// when any of them is absent. The remaining fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Subject (user identifier).
    pub sub: String,
    /// Issuer.
    pub iss: String,
    /// Expiration time (Unix timestamp).
    pub exp: i64,
    /// Issued at time (Unix timestamp). `None` when the claim is absent.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Email claim (commonly used for username). `None` when the claim is
    /// absent.
    #[serde(default)]
    pub email: Option<String>,
    /// Audience claim (can be single string or array in JWT).
    ///
    /// A single string is stored as a one-element vector. The vector is
    /// empty when the claim is absent.
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
}
105
106impl OidcClaims {
107    /// Extract the username to use for the session.
108    ///
109    /// Priority: email > sub
110    // TODO (Oidc): Add a configuration variable to use a different username field.
111    pub fn username(&self) -> &str {
112        self.email.as_deref().unwrap_or(&self.sub)
113    }
114}
115
116#[derive(Clone)]
117struct OidcDecodingKey(jsonwebtoken::DecodingKey);
118
119impl std::fmt::Debug for OidcDecodingKey {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("OidcDecodingKey")
122            .field("key", &"<redacted>")
123            .finish()
124    }
125}
126
/// OIDC Authenticator that validates JWTs using JWKS.
///
/// Decoding keys are fetched lazily: construction performs no network
/// requests, and the provider's JWKS is fetched (and cached by key ID) the
/// first time a token references a key ID that is not already cached.
///
/// Cloning is cheap; all clones share the same inner state through an
/// [`Arc`].
#[derive(Clone, Debug)]
pub struct GenericOidcAuthenticator {
    /// Shared state: adapter client, decoding-key cache and HTTP client.
    inner: Arc<GenericOidcAuthenticatorInner>,
}
135
/// OpenID Connect Discovery document.
/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html>
///
/// Only the field needed by this module is declared; any other fields in
/// the provider's document are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the JWKS endpoint.
    jwks_uri: String,
}
143
/// Shared state behind a [`GenericOidcAuthenticator`].
#[derive(Debug)]
pub struct GenericOidcAuthenticatorInner {
    /// Used to read the current system variables (OIDC issuer and audience)
    /// each time a token is validated.
    adapter_client: AdapterClient,
    /// Cache of decoding keys, keyed by JWK key ID (`kid`).
    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
    /// HTTP client used for the discovery document and JWKS requests.
    http_client: HttpClient,
}
150
151impl GenericOidcAuthenticator {
152    /// Create a new [`GenericOidcAuthenticator`] with an [`AdapterClient`].
153    ///
154    /// The OIDC issuer and audience are fetched from system variables on each
155    /// authentication attempt.
156    pub fn new(adapter_client: AdapterClient) -> Self {
157        let http_client = HttpClient::new();
158
159        Self {
160            inner: Arc::new(GenericOidcAuthenticatorInner {
161                adapter_client,
162                decoding_keys: Mutex::new(BTreeMap::new()),
163                http_client,
164            }),
165        }
166    }
167}
168
169impl GenericOidcAuthenticatorInner {
170    async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
171        let openid_config_url = Url::parse(issuer)
172            .and_then(|url| url.join(".well-known/openid-configuration"))
173            .map_err(OidcError::InvalidIssuerUrl)?;
174
175        // Fetch OpenID configuration to get the JWKS URI
176        let response = self
177            .http_client
178            .get(openid_config_url)
179            .timeout(Duration::from_secs(10))
180            .send()
181            .await
182            .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?;
183
184        if !response.status().is_success() {
185            return Err(OidcError::OpenIdConfigFetchFailed(format!(
186                "HTTP {}",
187                response.status()
188            )));
189        }
190
191        let openid_config: OpenIdConfiguration = response
192            .json()
193            .await
194            .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?;
195
196        Ok(openid_config.jwks_uri)
197    }
198
199    /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys.
200    async fn fetch_jwks(
201        &self,
202        issuer: &str,
203    ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
204        let jwks_uri = self.fetch_jwks_uri(issuer).await?;
205        let response = self
206            .http_client
207            .get(&jwks_uri)
208            .timeout(Duration::from_secs(10))
209            .send()
210            .await
211            .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?;
212
213        if !response.status().is_success() {
214            return Err(OidcError::JwksFetchFailed(format!(
215                "HTTP {}",
216                response.status()
217            )));
218        }
219
220        let jwks: JwkSet = response
221            .json()
222            .await
223            .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?;
224
225        let mut keys = BTreeMap::new();
226        for jwk in jwks.keys {
227            match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
228                Ok(key) => {
229                    if let Some(kid) = jwk.common.key_id {
230                        keys.insert(kid, OidcDecodingKey(key));
231                    }
232                }
233                Err(e) => {
234                    warn!("Failed to parse JWK: {}", e);
235                }
236            }
237        }
238
239        if keys.is_empty() {
240            return Err(OidcError::JwksFetchFailed(
241                "no valid keys found in JWKS".to_string(),
242            ));
243        }
244
245        Ok(keys)
246    }
247
248    /// Find a decoding key matching the given key ID.
249    /// If the key is not found, fetch the JWKS and cache the keys.
250    async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
251        {
252            let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
253
254            if let Some(key) = decoding_keys.get(kid) {
255                return Ok(key.clone());
256            }
257        }
258
259        let new_decoding_keys = self.fetch_jwks(issuer).await?;
260
261        let decoding_key = new_decoding_keys.get(kid).cloned();
262
263        {
264            let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
265            decoding_keys.extend(new_decoding_keys);
266        }
267
268        if let Some(key) = decoding_key {
269            return Ok(key);
270        }
271
272        Err(OidcError::NoMatchingKey)
273    }
274
275    pub async fn validate_token(
276        &self,
277        token: &str,
278        expected_user: Option<&str>,
279    ) -> Result<OidcClaims, OidcError> {
280        // Fetch current OIDC configuration from system variables
281        let system_vars = self.adapter_client.get_system_vars().await;
282        let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
283            return Err(OidcError::MissingIssuer);
284        };
285
286        let audience = {
287            let aud = OIDC_AUDIENCE.get(system_vars.dyncfgs());
288            if aud.is_none() {
289                warn!(
290                    "Audience validation skipped. It is discouraged
291                    to skip audience validation since it allows
292                    anyone with a JWT issued by the same issuer
293                    to authenticate."
294                );
295            }
296            aud
297        };
298
299        // Decode header to get key ID (kid) and algorithm
300        let header = jsonwebtoken::decode_header(token).map_err(OidcError::Jwt)?;
301
302        let kid = header.kid.ok_or(OidcError::MissingKid)?;
303        // Find matching key from cached keys
304        let decoding_key = self.find_key(&kid, &issuer).await?;
305
306        // Set up validation
307        // TODO (Oidc): Make JWT expiration configurable.
308        let mut validation = jsonwebtoken::Validation::new(header.alg);
309        validation.set_issuer(&[&issuer]);
310        if let Some(ref audience) = audience {
311            validation.set_audience(&[audience]);
312        } else {
313            validation.validate_aud = false;
314        }
315
316        // Decode and validate the token
317        let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
318            .map_err(OidcError::Jwt)?;
319
320        // Optionally validate expected user
321        if let Some(expected) = expected_user {
322            if token_data.claims.username() != expected {
323                return Err(OidcError::WrongUser);
324            }
325        }
326
327        Ok(token_data.claims)
328    }
329}
330
331impl GenericOidcAuthenticator {
332    pub async fn authenticate(
333        &self,
334        token: &str,
335        expected_user: Option<&str>,
336    ) -> Result<(OidcClaims, Authenticated), OidcError> {
337        let claims = self.inner.validate_token(token, expected_user).await?;
338        Ok((claims, Authenticated))
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` into [`OidcClaims`], panicking if it is not valid.
    fn parse_claims(json: &str) -> OidcClaims {
        serde_json::from_str(json).unwrap()
    }

    /// A string-valued `aud` claim becomes a one-element vector.
    #[mz_ore::test]
    fn test_aud_single_string() {
        let claims = parse_claims(r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#);
        assert_eq!(claims.aud, ["my-app"]);
    }

    /// An array-valued `aud` claim is kept in order.
    #[mz_ore::test]
    fn test_aud_array() {
        let claims =
            parse_claims(r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#);
        assert_eq!(claims.aud, ["app1", "app2"]);
    }
}