1use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use mz_ore::secure::{Zeroize, ZeroizeOnDrop};
24use mz_ore::soft_panic_or_log;
25use mz_pgwire_common::{ErrorResponse, Severity};
26use reqwest::Client as HttpClient;
27use serde::{Deserialize, Deserializer, Serialize};
28use tokio_postgres::error::SqlState;
29
30use tracing::{debug, warn};
31use url::Url;
/// Reasons an OIDC authentication attempt can fail.
///
/// `Display` gives a short user-facing message; richer context lives in
/// `detail()` / `hint()`.
#[derive(Debug)]
pub enum OidcError {
    /// The OIDC issuer system variable is not set.
    MissingIssuer,
    /// The configured issuer could not be used as a URL; holds the issuer string.
    InvalidIssuerUrl(String),
    /// The audience system variable could not be deserialized as a list of strings.
    AudienceParseError,
    /// Fetching a document from the OIDC provider failed.
    FetchFromProviderFailed {
        /// The URL that was requested.
        url: String,
        /// Description of the transport, status, or parse failure.
        error_message: String,
    },
    /// The JWT header carries no key ID (`kid`).
    MissingKid,
    /// No key with the token's key ID was found in the provider's JWKS.
    NoMatchingKey {
        /// The key ID taken from the JWT header.
        key_id: String,
    },
    /// The token has no string-valued claim with the configured name.
    NoMatchingAuthenticationClaim {
        /// The name of the claim that was looked up.
        authentication_claim: String,
    },
    /// Catch-all for JWT decoding/validation failures without a more specific variant.
    Jwt,
    /// The user extracted from the token differs from the expected user.
    WrongUser,
    /// None of the token's audiences matched the configured audiences.
    InvalidAudience {
        expected_audiences: Vec<String>,
    },
    /// The token's issuer does not match the configured issuer.
    InvalidIssuer {
        expected_issuer: String,
    },
    /// The token has expired.
    ExpiredSignature,
    /// The role is not permitted to log in.
    NonLogin,
    /// An unexpected error occurred while checking whether the role may log in.
    LoginCheckError,
}

impl std::fmt::Display for OidcError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant maps to a fixed message; variant payloads are surfaced via `detail()`.
        let message = match self {
            Self::MissingIssuer => "OIDC issuer is not configured",
            Self::InvalidIssuerUrl(_) => "invalid OIDC issuer URL",
            Self::AudienceParseError => "failed to parse OIDC_AUDIENCE system variable",
            Self::FetchFromProviderFailed { .. } => "failed to fetch OIDC provider configuration",
            Self::MissingKid => "missing key ID in JWT header",
            Self::NoMatchingKey { .. } => "no matching key found in the JWKS",
            Self::NoMatchingAuthenticationClaim { .. } => {
                "no matching authentication claim found in the JWT"
            }
            Self::Jwt => "failed to validate JWT",
            Self::WrongUser => "wrong user",
            Self::InvalidAudience { .. } => "invalid audience",
            Self::InvalidIssuer { .. } => "invalid issuer",
            Self::ExpiredSignature => "authentication credentials have expired",
            Self::NonLogin => "role is not allowed to login",
            Self::LoginCheckError => "unexpected error checking if role can login",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OidcError {}
98
99impl OidcError {
100 pub fn code(&self) -> SqlState {
101 SqlState::INVALID_AUTHORIZATION_SPECIFICATION
102 }
103
104 pub fn detail(&self) -> Option<String> {
105 match self {
106 OidcError::InvalidIssuerUrl(issuer) => {
107 Some(format!("Could not parse \"{issuer}\" as a URL."))
108 }
109 OidcError::FetchFromProviderFailed { url, error_message } => {
110 Some(format!("Fetching \"{url}\" failed. {error_message}"))
111 }
112 OidcError::NoMatchingKey { key_id } => {
113 Some(format!("JWT key ID \"{key_id}\" was not found."))
114 }
115 OidcError::InvalidAudience { expected_audiences } => Some(format!(
116 "Expected one of audiences {:?} in the JWT.",
117 expected_audiences,
118 )),
119 OidcError::InvalidIssuer { expected_issuer } => {
120 Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
121 }
122 OidcError::NoMatchingAuthenticationClaim {
123 authentication_claim,
124 } => Some(format!(
125 "Expected authentication claim \"{authentication_claim}\" in the JWT.",
126 )),
127 OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
128 _ => None,
129 }
130 }
131
132 pub fn hint(&self) -> Option<String> {
133 match self {
134 OidcError::MissingIssuer => {
135 Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
136 }
137 _ => None,
138 }
139 }
140
141 pub fn into_response(self) -> ErrorResponse {
142 ErrorResponse {
143 severity: Severity::Fatal,
144 code: self.code(),
145 message: self.to_string(),
146 detail: self.detail(),
147 hint: self.hint(),
148 position: None,
149 }
150 }
151}
152
153fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
154where
155 D: Deserializer<'de>,
156{
157 #[derive(Deserialize)]
158 #[serde(untagged)]
159 enum StringOrVec {
160 String(String),
161 Vec(Vec<String>),
162 }
163
164 match StringOrVec::deserialize(deserializer)? {
165 StringOrVec::String(s) => Ok(vec![s]),
166 StringOrVec::Vec(v) => Ok(v),
167 }
168}
/// Claims decoded from an OIDC JWT.
///
/// The registered claims `iss`, `exp`, `iat`, and `aud` have dedicated fields; every
/// other claim (for example `sub` or `email`) is collected into `unknown_claims`.
/// The whole struct is zeroized when dropped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Expiration time (`exp`).
    pub exp: i64,
    /// Issued-at time (`iat`); `None` when the claim is absent.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Audiences (`aud`). The claim may be a single string or an array; an absent
    /// claim yields an empty vector.
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
    /// All claims not captured by the fields above, keyed by claim name.
    #[serde(flatten)]
    pub unknown_claims: BTreeMap<String, serde_json::Value>,
}
186
187impl Zeroize for OidcClaims {
188 fn zeroize(&mut self) {
189 self.iss.zeroize();
190 self.exp.zeroize();
191 self.iat.zeroize();
192 for s in &mut self.aud {
193 s.zeroize();
194 }
195 self.aud.clear();
196 while let Some((mut k, mut v)) = self.unknown_claims.pop_first() {
199 k.zeroize();
200 zeroize_json_value(&mut v);
201 }
202 }
203}
204
205impl Drop for OidcClaims {
206 fn drop(&mut self) {
207 self.zeroize();
208 }
209}
210
211impl ZeroizeOnDrop for OidcClaims {}
214
215fn zeroize_json_value(v: &mut serde_json::Value) {
216 use serde_json::Value;
217 match v {
218 Value::String(s) => s.zeroize(),
219 Value::Array(a) => {
220 for item in a.iter_mut() {
221 zeroize_json_value(item);
222 }
223 a.clear();
224 }
225 Value::Object(map) => {
226 let taken = std::mem::take(map);
227 for (mut k, mut nested) in taken {
228 k.zeroize();
229 zeroize_json_value(&mut nested);
230 }
231 }
232 Value::Number(_) => {
233 *v = Value::Number(serde_json::Number::from(0u8));
234 }
235 Value::Bool(b) => *b = false,
236 Value::Null => {}
237 }
238}
239
240impl OidcClaims {
241 fn user(&self, authentication_claim: &str) -> Option<&str> {
243 self.unknown_claims
244 .get(authentication_claim)
245 .and_then(|value| value.as_str())
246 }
247}
248
/// The result of successfully validating an OIDC token.
///
/// Zeroized on drop. The private unit field means only this module can construct it.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ValidatedClaims {
    /// The user name read from the configured authentication claim.
    pub user: String,
    _private: (),
}
255
256#[derive(Clone)]
258struct OidcDecodingKey(jsonwebtoken::DecodingKey);
259
260impl std::fmt::Debug for OidcDecodingKey {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("OidcDecodingKey")
263 .field("key", &"<redacted>")
264 .finish()
265 }
266}
267
/// Validates OIDC JWTs and checks that the authenticated role may log in.
///
/// Cloning is cheap: all clones share the same state (including the JWKS key cache)
/// through an `Arc`.
#[derive(Clone, Debug)]
pub struct GenericOidcAuthenticator {
    inner: Arc<GenericOidcAuthenticatorInner>,
}
276
/// The subset of the provider's discovery document
/// (`.well-known/openid-configuration`) that this module reads.
///
/// Other fields in the document are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the provider's JSON Web Key Set.
    jwks_uri: String,
}
284
285#[derive(Debug)]
286pub struct GenericOidcAuthenticatorInner {
287 adapter_client: AdapterClient,
288 decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
289 http_client: HttpClient,
290}
291
292impl GenericOidcAuthenticator {
293 pub fn new(adapter_client: AdapterClient) -> Self {
298 let http_client = HttpClient::new();
299
300 Self {
301 inner: Arc::new(GenericOidcAuthenticatorInner {
302 adapter_client,
303 decoding_keys: Mutex::new(BTreeMap::new()),
304 http_client,
305 }),
306 }
307 }
308}
309
310impl GenericOidcAuthenticatorInner {
311 async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
312 let openid_config_url = build_openid_config_url(issuer)?;
313
314 let openid_config_url_str = openid_config_url.to_string();
315
316 let response = self
318 .http_client
319 .get(openid_config_url)
320 .timeout(Duration::from_secs(10))
321 .send()
322 .await
323 .map_err(|e| OidcError::FetchFromProviderFailed {
324 url: openid_config_url_str.clone(),
325 error_message: e.to_string(),
326 })?;
327
328 if !response.status().is_success() {
329 return Err(OidcError::FetchFromProviderFailed {
330 url: openid_config_url_str.clone(),
331 error_message: response
332 .error_for_status()
333 .err()
334 .map(|e| e.to_string())
335 .unwrap_or_else(|| "Unknown error".to_string()),
336 });
337 }
338
339 let openid_config: OpenIdConfiguration =
340 response
341 .json()
342 .await
343 .map_err(|e| OidcError::FetchFromProviderFailed {
344 url: openid_config_url_str,
345 error_message: e.to_string(),
346 })?;
347
348 Ok(openid_config.jwks_uri)
349 }
350
351 async fn fetch_jwks(
353 &self,
354 issuer: &str,
355 ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
356 let jwks_uri = self.fetch_jwks_uri(issuer).await?;
357 let response = self
358 .http_client
359 .get(&jwks_uri)
360 .timeout(Duration::from_secs(10))
361 .send()
362 .await
363 .map_err(|e| OidcError::FetchFromProviderFailed {
364 url: jwks_uri.clone(),
365 error_message: e.to_string(),
366 })?;
367
368 if !response.status().is_success() {
369 return Err(OidcError::FetchFromProviderFailed {
370 url: jwks_uri.clone(),
371 error_message: response
372 .error_for_status()
373 .err()
374 .map(|e| e.to_string())
375 .unwrap_or_else(|| "Unknown error".to_string()),
376 });
377 }
378
379 let jwks: JwkSet =
380 response
381 .json()
382 .await
383 .map_err(|e| OidcError::FetchFromProviderFailed {
384 url: jwks_uri.clone(),
385 error_message: e.to_string(),
386 })?;
387
388 let mut keys = BTreeMap::new();
389
390 for jwk in jwks.keys {
391 match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
392 Ok(key) => {
393 if let Some(kid) = jwk.common.key_id {
394 keys.insert(kid, OidcDecodingKey(key));
395 }
396 }
397 Err(e) => {
398 warn!("Failed to parse JWK: {}", e);
399 }
400 }
401 }
402
403 Ok(keys)
404 }
405
406 async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
409 {
411 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
412
413 if let Some(key) = decoding_keys.get(kid) {
414 return Ok(key.clone());
415 }
416 }
417
418 let new_decoding_keys = self.fetch_jwks(issuer).await?;
420
421 let decoding_key = new_decoding_keys.get(kid).cloned();
422
423 {
424 let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
425 *decoding_keys = new_decoding_keys;
426 }
427
428 if let Some(key) = decoding_key {
429 return Ok(key);
430 }
431
432 {
433 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
434 debug!(
435 "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
436 );
437 Err(OidcError::NoMatchingKey {
438 key_id: kid.to_string(),
439 })
440 }
441 }
442
443 pub async fn validate_token(
444 &self,
445 token: &str,
446 expected_user: Option<&str>,
447 ) -> Result<ValidatedClaims, OidcError> {
448 let system_vars = self.adapter_client.get_system_vars().await;
450 let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
451 return Err(OidcError::MissingIssuer);
452 };
453
454 let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
455
456 let expected_audiences: Vec<String> = {
457 let audiences: Vec<String> =
458 serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
459 .map_err(|_| OidcError::AudienceParseError)?;
460
461 if audiences.is_empty() {
462 warn!(
463 "Audience validation skipped. It is discouraged \
464 to skip audience validation since it allows anyone \
465 with a JWT issued by the same issuer to authenticate."
466 );
467 }
468 audiences
469 };
470
471 let header = jsonwebtoken::decode_header(token).map_err(|e| {
474 debug!("Failed to decode JWT header: {:?}", e);
475 OidcError::Jwt
476 })?;
477
478 let kid = header.kid.ok_or(OidcError::MissingKid)?;
479 let decoding_key = self.find_key(&kid, &issuer).await?;
482
483 let mut validation = jsonwebtoken::Validation::new(header.alg);
485 validation.set_issuer(&[&issuer]);
486 if !expected_audiences.is_empty() {
487 validation.set_audience(&expected_audiences);
488 } else {
489 validation.validate_aud = false;
490 }
491
492 let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
494 .map_err(|e| match e.kind() {
495 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
496 if !expected_audiences.is_empty() {
497 OidcError::InvalidAudience {
498 expected_audiences
499 }
500 } else {
501 soft_panic_or_log!(
502 "received an audience validation error when audience validation is disabled"
503 );
504 OidcError::Jwt
505 }
506 }
507 jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
508 expected_issuer: issuer.clone(),
509 },
510 jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
511 _ => OidcError::Jwt,
512 })?;
513
514 let user = token_data.claims.user(&authentication_claim).ok_or(
515 OidcError::NoMatchingAuthenticationClaim {
516 authentication_claim,
517 },
518 )?;
519
520 if let Some(expected) = expected_user {
522 if user != expected {
523 return Err(OidcError::WrongUser);
524 }
525 }
526
527 Ok(ValidatedClaims {
528 user: user.to_string(),
529 _private: (),
530 })
531 }
532
533 async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
537 match self.adapter_client.role_can_login(role_name).await {
538 Ok(()) => Ok(()),
539 Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
540 Ok(())
542 }
543 Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
544 Err(OidcError::NonLogin)
545 }
546 Err(e) => {
547 warn!(?e, "unexpected error checking OIDC role login");
548 Err(OidcError::LoginCheckError)
549 }
550 }
551 }
552}
553
554impl GenericOidcAuthenticator {
555 pub async fn authenticate(
556 &self,
557 token: &str,
558 expected_user: Option<&str>,
559 ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
560 let validated_claims = self.inner.validate_token(token, expected_user).await?;
561 self.inner.check_role_login(&validated_claims.user).await?;
562 Ok((validated_claims, Authenticated))
563 }
564}
565
566fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
567 let mut openid_config_url =
568 Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
569 {
570 let mut segments = openid_config_url
571 .path_segments_mut()
572 .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
573 segments.pop_if_empty();
575 segments.push(".well-known");
576 segments.push("openid-configuration");
577 }
578 Ok(openid_config_url)
579}
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_aud_single_string() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// A token without `aud` (or `iat`) still deserializes, with empty defaults.
    #[mz_ore::test]
    fn test_aud_missing_defaults_to_empty() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_empty());
        assert_eq!(claims.iat, None);
    }

    #[mz_ore::test]
    fn test_user() {
        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("sub"), Some("user-123"));
        assert_eq!(claims.user("email"), Some("alice@example.com"));
        assert_eq!(claims.user("missing"), None);
    }

    /// Only string-valued, non-registered claims can identify the user.
    #[mz_ore::test]
    fn test_user_ignores_non_string_and_registered_claims() {
        let json = r#"{"iss":"issuer","exp":1234,"groups":["admin"],"uid":42}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("groups"), None);
        assert_eq!(claims.user("uid"), None);
        // `iss` is captured by a dedicated field, so it is not in `unknown_claims`.
        assert_eq!(claims.user("iss"), None);
    }

    #[mz_ore::test]
    fn test_build_openid_config_url() {
        let issuer = "https://dev-123456.okta.com/oauth2/default";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        let issuer = "https://dev-123456.okta.com/oauth2/default/";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    /// An issuer with no path must not produce a double slash.
    #[mz_ore::test]
    fn test_build_openid_config_url_root_issuer() {
        let url = build_openid_config_url("https://accounts.example.com").unwrap();
        assert_eq!(
            url.as_str(),
            "https://accounts.example.com/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_invalid() {
        assert!(matches!(
            build_openid_config_url("not a url"),
            Err(OidcError::InvalidIssuerUrl(issuer)) if issuer == "not a url"
        ));
        // Parses as a URL, but cannot carry path segments.
        assert!(matches!(
            build_openid_config_url("mailto:admin@example.com"),
            Err(OidcError::InvalidIssuerUrl(_))
        ));
    }

    #[mz_ore::test]
    fn zeroize_clears_validated_claims() {
        use mz_ore::secure::Zeroize;
        let mut claims = ValidatedClaims {
            user: "alice@example.com".to_string(),
            _private: (),
        };
        claims.zeroize();
        assert!(claims.user.is_empty());
    }

    #[mz_ore::test]
    fn oidc_claims_implements_zeroize_on_drop() {
        fn assert_zod<T: ZeroizeOnDrop>() {}
        assert_zod::<OidcClaims>();
        assert_zod::<ValidatedClaims>();
    }

    #[mz_ore::test]
    fn zeroize_clears_oidc_claims() {
        use mz_ore::secure::Zeroize;
        let mut claims = OidcClaims {
            iss: "https://issuer.example.com".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            aud: vec!["app1".to_string(), "app2".to_string()],
            unknown_claims: BTreeMap::from([(
                "email".to_string(),
                serde_json::Value::String("alice@example.com".to_string()),
            )]),
        };
        claims.zeroize();
        assert!(claims.iss.is_empty());
        assert_eq!(claims.exp, 0);
        assert!(claims.iat.is_none());
        assert!(claims.aud.is_empty());
        assert!(claims.unknown_claims.is_empty());
    }

    /// Every JSON variant is scrubbed, including values nested in arrays and objects.
    #[mz_ore::test]
    fn zeroize_json_value_scrubs_all_variants() {
        let mut object = serde_json::json!({
            "email": "alice@example.com",
            "groups": ["admin", {"nested": "secret"}]
        });
        zeroize_json_value(&mut object);
        assert_eq!(object, serde_json::json!({}));

        let mut array = serde_json::json!(["a", 1, true, null]);
        zeroize_json_value(&mut array);
        assert_eq!(array, serde_json::json!([]));

        let mut text = serde_json::json!("secret");
        zeroize_json_value(&mut text);
        assert_eq!(text, serde_json::json!(""));

        let mut number = serde_json::json!(12345);
        zeroize_json_value(&mut number);
        assert_eq!(number, serde_json::json!(0));

        let mut flag = serde_json::json!(true);
        zeroize_json_value(&mut flag);
        assert_eq!(flag, serde_json::json!(false));

        let mut null = serde_json::Value::Null;
        zeroize_json_value(&mut null);
        assert_eq!(null, serde_json::Value::Null);
    }
}