1use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::Client as AdapterClient;
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use mz_ore::soft_panic_or_log;
24use mz_pgwire_common::{ErrorResponse, Severity};
25use reqwest::Client as HttpClient;
26use serde::{Deserialize, Deserializer, Serialize};
27use tokio_postgres::error::SqlState;
28
29use tracing::{debug, warn};
30use url::Url;
/// Errors produced while validating an OIDC-issued JWT.
///
/// The `Display` text is deliberately short; variant payloads carry the
/// specifics that are surfaced separately through `OidcError::detail`.
#[derive(std::fmt::Debug)]
pub enum OidcError {
    /// No OIDC issuer is configured.
    MissingIssuer,
    /// The configured issuer (the payload) could not be turned into a
    /// discovery URL.
    InvalidIssuerUrl(String),
    /// The configured audience value could not be parsed as a list of strings.
    AudienceParseError,
    /// An HTTP request to the OIDC provider failed, returned a non-success
    /// status, or returned a body that could not be decoded.
    FetchFromProviderFailed {
        /// The URL that was requested.
        url: String,
        /// Human-readable description of the underlying failure.
        error_message: String,
    },
    /// The JWT header does not carry a key ID (`kid`).
    MissingKid,
    /// The provider's key set has no key for the JWT's key ID.
    NoMatchingKey {
        /// The key ID taken from the JWT header.
        key_id: String,
    },
    /// The JWT has no string-valued claim under the configured
    /// authentication claim name.
    NoMatchingAuthenticationClaim {
        /// Name of the claim that was looked up.
        authentication_claim: String,
    },
    /// The JWT failed validation for a reason that has no dedicated variant.
    Jwt,
    /// The user named by the JWT differs from the user the caller expected.
    WrongUser,
    /// The JWT's audience matches none of the expected audiences.
    InvalidAudience {
        /// The audiences that would have been accepted.
        expected_audiences: Vec<String>,
    },
    /// The JWT's issuer does not match the configured issuer.
    InvalidIssuer {
        /// The issuer that would have been accepted.
        expected_issuer: String,
    },
    /// The JWT has expired.
    ExpiredSignature,
}
65
66impl std::fmt::Display for OidcError {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 match self {
69 OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
70 OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
71 OidcError::AudienceParseError => {
72 write!(f, "failed to parse OIDC_AUDIENCE system variable")
73 }
74 OidcError::FetchFromProviderFailed { .. } => {
75 write!(f, "failed to fetch OIDC provider configuration")
76 }
77 OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
78 OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
79 OidcError::NoMatchingAuthenticationClaim { .. } => {
80 write!(f, "no matching authentication claim found in the JWT")
81 }
82 OidcError::Jwt => write!(f, "failed to validate JWT"),
83 OidcError::WrongUser => write!(f, "wrong user"),
84 OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
85 OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
86 OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
87 }
88 }
89}
90
// Marker impl: lets `OidcError` be used wherever a `std::error::Error` is
// expected. No source error is exposed; the trait's default methods are used.
impl std::error::Error for OidcError {}
92
93impl OidcError {
94 pub fn code(&self) -> SqlState {
95 SqlState::INVALID_AUTHORIZATION_SPECIFICATION
96 }
97
98 pub fn detail(&self) -> Option<String> {
99 match self {
100 OidcError::InvalidIssuerUrl(issuer) => {
101 Some(format!("Could not parse \"{issuer}\" as a URL."))
102 }
103 OidcError::FetchFromProviderFailed { url, error_message } => {
104 Some(format!("Fetching \"{url}\" failed. {error_message}"))
105 }
106 OidcError::NoMatchingKey { key_id } => {
107 Some(format!("JWT key ID \"{key_id}\" was not found."))
108 }
109 OidcError::InvalidAudience { expected_audiences } => Some(format!(
110 "Expected one of audiences {:?} in the JWT.",
111 expected_audiences,
112 )),
113 OidcError::InvalidIssuer { expected_issuer } => {
114 Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
115 }
116 OidcError::NoMatchingAuthenticationClaim {
117 authentication_claim,
118 } => Some(format!(
119 "Expected authentication claim \"{authentication_claim}\" in the JWT.",
120 )),
121 _ => None,
122 }
123 }
124
125 pub fn hint(&self) -> Option<String> {
126 match self {
127 OidcError::MissingIssuer => {
128 Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
129 }
130 _ => None,
131 }
132 }
133
134 pub fn into_response(self) -> ErrorResponse {
135 ErrorResponse {
136 severity: Severity::Fatal,
137 code: self.code(),
138 message: self.to_string(),
139 detail: self.detail(),
140 hint: self.hint(),
141 position: None,
142 }
143 }
144}
145
146fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
147where
148 D: Deserializer<'de>,
149{
150 #[derive(Deserialize)]
151 #[serde(untagged)]
152 enum StringOrVec {
153 String(String),
154 Vec(Vec<String>),
155 }
156
157 match StringOrVec::deserialize(deserializer)? {
158 StringOrVec::String(s) => Ok(vec![s]),
159 StringOrVec::Vec(v) => Ok(v),
160 }
161}
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct OidcClaims {
165 pub iss: String,
167 pub exp: i64,
169 #[serde(default)]
171 pub iat: Option<i64>,
172 #[serde(default, deserialize_with = "deserialize_string_or_vec")]
174 pub aud: Vec<String>,
175 #[serde(flatten)]
177 pub unknown_claims: BTreeMap<String, serde_json::Value>,
178}
179
180impl OidcClaims {
181 fn user(&self, authentication_claim: &str) -> Option<&str> {
183 self.unknown_claims
184 .get(authentication_claim)
185 .and_then(|value| value.as_str())
186 }
187}
188
/// The outcome of a successful token validation.
pub struct ValidatedClaims {
    /// The user name taken from the token's authentication claim.
    pub user: String,
    // Private unit field: prevents code outside this module from constructing
    // a `ValidatedClaims` with a struct literal.
    _private: (),
}
194
195#[derive(Clone)]
196struct OidcDecodingKey(jsonwebtoken::DecodingKey);
197
198impl std::fmt::Debug for OidcDecodingKey {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 f.debug_struct("OidcDecodingKey")
201 .field("key", &"<redacted>")
202 .finish()
203 }
204}
205
206#[derive(Clone, Debug)]
211pub struct GenericOidcAuthenticator {
212 inner: Arc<GenericOidcAuthenticatorInner>,
213}
214
215#[derive(Debug, Deserialize)]
218struct OpenIdConfiguration {
219 jwks_uri: String,
221}
222
223#[derive(Debug)]
224pub struct GenericOidcAuthenticatorInner {
225 adapter_client: AdapterClient,
226 decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
227 http_client: HttpClient,
228}
229
230impl GenericOidcAuthenticator {
231 pub fn new(adapter_client: AdapterClient) -> Self {
236 let http_client = HttpClient::new();
237
238 Self {
239 inner: Arc::new(GenericOidcAuthenticatorInner {
240 adapter_client,
241 decoding_keys: Mutex::new(BTreeMap::new()),
242 http_client,
243 }),
244 }
245 }
246}
247
248impl GenericOidcAuthenticatorInner {
249 async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
250 let openid_config_url = build_openid_config_url(issuer)?;
251
252 let openid_config_url_str = openid_config_url.to_string();
253
254 let response = self
256 .http_client
257 .get(openid_config_url)
258 .timeout(Duration::from_secs(10))
259 .send()
260 .await
261 .map_err(|e| OidcError::FetchFromProviderFailed {
262 url: openid_config_url_str.clone(),
263 error_message: e.to_string(),
264 })?;
265
266 if !response.status().is_success() {
267 return Err(OidcError::FetchFromProviderFailed {
268 url: openid_config_url_str.clone(),
269 error_message: response
270 .error_for_status()
271 .err()
272 .map(|e| e.to_string())
273 .unwrap_or_else(|| "Unknown error".to_string()),
274 });
275 }
276
277 let openid_config: OpenIdConfiguration =
278 response
279 .json()
280 .await
281 .map_err(|e| OidcError::FetchFromProviderFailed {
282 url: openid_config_url_str,
283 error_message: e.to_string(),
284 })?;
285
286 Ok(openid_config.jwks_uri)
287 }
288
289 async fn fetch_jwks(
291 &self,
292 issuer: &str,
293 ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
294 let jwks_uri = self.fetch_jwks_uri(issuer).await?;
295 let response = self
296 .http_client
297 .get(&jwks_uri)
298 .timeout(Duration::from_secs(10))
299 .send()
300 .await
301 .map_err(|e| OidcError::FetchFromProviderFailed {
302 url: jwks_uri.clone(),
303 error_message: e.to_string(),
304 })?;
305
306 if !response.status().is_success() {
307 return Err(OidcError::FetchFromProviderFailed {
308 url: jwks_uri.clone(),
309 error_message: response
310 .error_for_status()
311 .err()
312 .map(|e| e.to_string())
313 .unwrap_or_else(|| "Unknown error".to_string()),
314 });
315 }
316
317 let jwks: JwkSet =
318 response
319 .json()
320 .await
321 .map_err(|e| OidcError::FetchFromProviderFailed {
322 url: jwks_uri.clone(),
323 error_message: e.to_string(),
324 })?;
325
326 let mut keys = BTreeMap::new();
327
328 for jwk in jwks.keys {
329 match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
330 Ok(key) => {
331 if let Some(kid) = jwk.common.key_id {
332 keys.insert(kid, OidcDecodingKey(key));
333 }
334 }
335 Err(e) => {
336 warn!("Failed to parse JWK: {}", e);
337 }
338 }
339 }
340
341 Ok(keys)
342 }
343
344 async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
347 {
349 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
350
351 if let Some(key) = decoding_keys.get(kid) {
352 return Ok(key.clone());
353 }
354 }
355
356 let new_decoding_keys = self.fetch_jwks(issuer).await?;
358
359 let decoding_key = new_decoding_keys.get(kid).cloned();
360
361 {
362 let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
363 *decoding_keys = new_decoding_keys;
364 }
365
366 if let Some(key) = decoding_key {
367 return Ok(key);
368 }
369
370 {
371 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
372 debug!(
373 "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
374 );
375 Err(OidcError::NoMatchingKey {
376 key_id: kid.to_string(),
377 })
378 }
379 }
380
381 pub async fn validate_token(
382 &self,
383 token: &str,
384 expected_user: Option<&str>,
385 ) -> Result<ValidatedClaims, OidcError> {
386 let system_vars = self.adapter_client.get_system_vars().await;
388 let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
389 return Err(OidcError::MissingIssuer);
390 };
391
392 let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
393
394 let expected_audiences: Vec<String> = {
395 let audiences: Vec<String> =
396 serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
397 .map_err(|_| OidcError::AudienceParseError)?;
398
399 if audiences.is_empty() {
400 warn!(
401 "Audience validation skipped. It is discouraged \
402 to skip audience validation since it allows anyone \
403 with a JWT issued by the same issuer to authenticate."
404 );
405 }
406 audiences
407 };
408
409 let header = jsonwebtoken::decode_header(token).map_err(|e| {
412 debug!("Failed to decode JWT header: {:?}", e);
413 OidcError::Jwt
414 })?;
415
416 let kid = header.kid.ok_or(OidcError::MissingKid)?;
417 let decoding_key = self.find_key(&kid, &issuer).await?;
420
421 let mut validation = jsonwebtoken::Validation::new(header.alg);
423 validation.set_issuer(&[&issuer]);
424 if !expected_audiences.is_empty() {
425 validation.set_audience(&expected_audiences);
426 } else {
427 validation.validate_aud = false;
428 }
429
430 let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
432 .map_err(|e| match e.kind() {
433 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
434 if !expected_audiences.is_empty() {
435 OidcError::InvalidAudience {
436 expected_audiences
437 }
438 } else {
439 soft_panic_or_log!(
440 "received an audience validation error when audience validation is disabled"
441 );
442 OidcError::Jwt
443 }
444 }
445 jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
446 expected_issuer: issuer.clone(),
447 },
448 jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
449 _ => OidcError::Jwt,
450 })?;
451
452 let user = token_data.claims.user(&authentication_claim).ok_or(
453 OidcError::NoMatchingAuthenticationClaim {
454 authentication_claim,
455 },
456 )?;
457
458 if let Some(expected) = expected_user {
460 if user != expected {
461 return Err(OidcError::WrongUser);
462 }
463 }
464
465 Ok(ValidatedClaims {
466 user: user.to_string(),
467 _private: (),
468 })
469 }
470}
471
472impl GenericOidcAuthenticator {
473 pub async fn authenticate(
474 &self,
475 token: &str,
476 expected_user: Option<&str>,
477 ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
478 let validated_claims = self.inner.validate_token(token, expected_user).await?;
479 Ok((validated_claims, Authenticated))
480 }
481}
482
483fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
484 let mut openid_config_url =
485 Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
486 {
487 let mut segments = openid_config_url
488 .path_segments_mut()
489 .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
490 segments.pop_if_empty();
492 segments.push(".well-known");
493 segments.push("openid-configuration");
494 }
495 Ok(openid_config_url)
496}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` into [`OidcClaims`], panicking if it does not deserialize.
    fn parse_claims(json: &str) -> OidcClaims {
        serde_json::from_str(json).unwrap()
    }

    /// Builds the discovery URL for `issuer` and renders it as a string.
    fn discovery_url(issuer: &str) -> String {
        build_openid_config_url(issuer).unwrap().to_string()
    }

    /// A string-valued `aud` claim becomes a one-element list.
    #[mz_ore::test]
    fn test_aud_single_string() {
        let claims = parse_claims(r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#);
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    /// An array-valued `aud` claim is kept as-is.
    #[mz_ore::test]
    fn test_aud_array() {
        let claims =
            parse_claims(r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#);
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// `user` resolves any string claim by name and yields `None` otherwise.
    #[mz_ore::test]
    fn test_user() {
        let claims = parse_claims(
            r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#,
        );

        let lookups = [
            ("sub", Some("user-123")),
            ("email", Some("alice@example.com")),
            ("missing", None),
        ];
        for (claim, expected) in lookups {
            assert_eq!(claims.user(claim), expected);
        }
    }

    /// The well-known path is appended to an issuer without a trailing slash.
    #[mz_ore::test]
    fn test_build_openid_config_url() {
        assert_eq!(
            discovery_url("https://dev-123456.okta.com/oauth2/default"),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    /// A trailing slash on the issuer does not produce a doubled slash.
    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        assert_eq!(
            discovery_url("https://dev-123456.okta.com/oauth2/default/"),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }
}