1use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{
22 OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_GROUP_CLAIM, OIDC_ISSUER,
23};
24use mz_auth::Authenticated;
25use mz_ore::secure::{Zeroize, ZeroizeOnDrop};
26use mz_ore::soft_panic_or_log;
27use mz_pgwire_common::{ErrorResponse, Severity};
28use reqwest::Client as HttpClient;
29use serde::{Deserialize, Deserializer, Serialize};
30use tokio_postgres::error::SqlState;
31
32use tracing::{debug, warn};
33use url::Url;
/// Errors that can occur while authenticating a connection with an OIDC-issued JWT.
///
/// Every variant is reported with SQLSTATE `INVALID_AUTHORIZATION_SPECIFICATION`
/// (see `OidcError::code`).
#[derive(Debug)]
pub enum OidcError {
    /// The `oidc_issuer` system variable is not set.
    MissingIssuer,
    /// The configured issuer could not be turned into an OpenID configuration URL.
    /// Holds the offending issuer string.
    InvalidIssuerUrl(String),
    /// The `OIDC_AUDIENCE` system variable could not be parsed as a list of strings.
    AudienceParseError,
    /// An HTTP request to the OIDC provider failed or returned an unusable response.
    FetchFromProviderFailed {
        /// The URL that was requested.
        url: String,
        /// A description of the failure.
        error_message: String,
    },
    /// The JWT header has no `kid` (key ID).
    MissingKid,
    /// No key in the provider's JWKS has the token's key ID, even after refetching.
    NoMatchingKey {
        /// The key ID taken from the JWT header.
        key_id: String,
    },
    /// The JWT does not contain the configured authentication claim as a string.
    NoMatchingAuthenticationClaim {
        /// Name of the claim that was expected.
        authentication_claim: String,
    },
    /// The JWT could not be decoded or failed validation for a reason not covered
    /// by a more specific variant.
    Jwt,
    /// The user named in the JWT differs from the expected user passed by the caller.
    WrongUser,
    /// The JWT's `aud` claim matched none of the configured audiences.
    InvalidAudience {
        /// The audiences that were configured.
        expected_audiences: Vec<String>,
    },
    /// The JWT's `iss` claim did not match the configured issuer.
    InvalidIssuer {
        /// The configured issuer.
        expected_issuer: String,
    },
    /// The JWT has expired.
    ExpiredSignature,
    /// The role exists but lacks the `LOGIN` attribute.
    NonLogin,
    /// An unexpected adapter error occurred while checking whether the role can log in.
    LoginCheckError,
}
71
72impl std::fmt::Display for OidcError {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 match self {
75 OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
76 OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
77 OidcError::AudienceParseError => {
78 write!(f, "failed to parse OIDC_AUDIENCE system variable")
79 }
80 OidcError::FetchFromProviderFailed { .. } => {
81 write!(f, "failed to fetch OIDC provider configuration")
82 }
83 OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
84 OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
85 OidcError::NoMatchingAuthenticationClaim { .. } => {
86 write!(f, "no matching authentication claim found in the JWT")
87 }
88 OidcError::Jwt => write!(f, "failed to validate JWT"),
89 OidcError::WrongUser => write!(f, "wrong user"),
90 OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
91 OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
92 OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
93 OidcError::NonLogin => write!(f, "role is not allowed to login"),
94 OidcError::LoginCheckError => write!(f, "unexpected error checking if role can login"),
95 }
96 }
97}
98
99impl std::error::Error for OidcError {}
100
101impl OidcError {
102 pub fn code(&self) -> SqlState {
103 SqlState::INVALID_AUTHORIZATION_SPECIFICATION
104 }
105
106 pub fn detail(&self) -> Option<String> {
107 match self {
108 OidcError::InvalidIssuerUrl(issuer) => {
109 Some(format!("Could not parse \"{issuer}\" as a URL."))
110 }
111 OidcError::FetchFromProviderFailed { url, error_message } => {
112 Some(format!("Fetching \"{url}\" failed. {error_message}"))
113 }
114 OidcError::NoMatchingKey { key_id } => {
115 Some(format!("JWT key ID \"{key_id}\" was not found."))
116 }
117 OidcError::InvalidAudience { expected_audiences } => Some(format!(
118 "Expected one of audiences {:?} in the JWT.",
119 expected_audiences,
120 )),
121 OidcError::InvalidIssuer { expected_issuer } => {
122 Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
123 }
124 OidcError::NoMatchingAuthenticationClaim {
125 authentication_claim,
126 } => Some(format!(
127 "Expected authentication claim \"{authentication_claim}\" in the JWT.",
128 )),
129 OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
130 _ => None,
131 }
132 }
133
134 pub fn hint(&self) -> Option<String> {
135 match self {
136 OidcError::MissingIssuer => {
137 Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
138 }
139 _ => None,
140 }
141 }
142
143 pub fn into_response(self) -> ErrorResponse {
144 ErrorResponse {
145 severity: Severity::Fatal,
146 code: self.code(),
147 message: self.to_string(),
148 detail: self.detail(),
149 hint: self.hint(),
150 position: None,
151 }
152 }
153}
154
155fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
156where
157 D: Deserializer<'de>,
158{
159 #[derive(Deserialize)]
160 #[serde(untagged)]
161 enum StringOrVec {
162 String(String),
163 Vec(Vec<String>),
164 }
165
166 match StringOrVec::deserialize(deserializer)? {
167 StringOrVec::String(s) => Ok(vec![s]),
168 StringOrVec::Vec(v) => Ok(v),
169 }
170}
/// Claims decoded from an OIDC JWT.
///
/// Only `iss`, `exp`, `iat` and `aud` are modeled explicitly. Every other claim,
/// including the one used to identify the user, is collected into `unknown_claims`.
/// All contents are wiped on drop through the `Zeroize` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Expiration time (`exp`); presumably seconds since the Unix epoch per RFC 7519.
    pub exp: i64,
    /// Issued-at time (`iat`), if present.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Audience (`aud`). Accepts a single string or an array; empty when absent.
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
    /// All remaining claims, keyed by claim name.
    #[serde(flatten)]
    pub unknown_claims: BTreeMap<String, serde_json::Value>,
}
188
189impl Zeroize for OidcClaims {
190 fn zeroize(&mut self) {
191 self.iss.zeroize();
192 self.exp.zeroize();
193 self.iat.zeroize();
194 for s in &mut self.aud {
195 s.zeroize();
196 }
197 self.aud.clear();
198 while let Some((mut k, mut v)) = self.unknown_claims.pop_first() {
201 k.zeroize();
202 zeroize_json_value(&mut v);
203 }
204 }
205}
206
207impl Drop for OidcClaims {
208 fn drop(&mut self) {
209 self.zeroize();
210 }
211}
212
213impl ZeroizeOnDrop for OidcClaims {}
216
217fn zeroize_json_value(v: &mut serde_json::Value) {
218 use serde_json::Value;
219 match v {
220 Value::String(s) => s.zeroize(),
221 Value::Array(a) => {
222 for item in a.iter_mut() {
223 zeroize_json_value(item);
224 }
225 a.clear();
226 }
227 Value::Object(map) => {
228 let taken = std::mem::take(map);
229 for (mut k, mut nested) in taken {
230 k.zeroize();
231 zeroize_json_value(&mut nested);
232 }
233 }
234 Value::Number(_) => {
235 *v = Value::Number(serde_json::Number::from(0u8));
236 }
237 Value::Bool(b) => *b = false,
238 Value::Null => {}
239 }
240}
241
242impl OidcClaims {
243 fn user(&self, authentication_claim: &str) -> Option<&str> {
245 self.unknown_claims
246 .get(authentication_claim)
247 .and_then(|value| value.as_str())
248 }
249
250 pub fn groups(&self, claim_path: &str) -> Option<Vec<String>> {
253 mz_auth::group_claims::extract_groups(&self.unknown_claims, claim_path)
254 }
255}
256
/// The result of successfully validating an OIDC token.
///
/// Zeroized on drop. The private unit field prevents struct-literal construction
/// outside this module.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ValidatedClaims {
    /// User name taken from the configured authentication claim.
    pub user: String,
    /// Group names taken from the configured group claim, if present.
    pub groups: Option<Vec<String>>,
    _private: (),
}
265
266#[derive(Clone)]
268struct OidcDecodingKey(jsonwebtoken::DecodingKey);
269
270impl std::fmt::Debug for OidcDecodingKey {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 f.debug_struct("OidcDecodingKey")
273 .field("key", &"<redacted>")
274 .finish()
275 }
276}
277
/// Authenticates connections using JWTs issued by a generic OIDC provider.
///
/// Cloning is cheap because all state lives behind a shared `Arc`.
#[derive(Clone, Debug)]
pub struct GenericOidcAuthenticator {
    inner: Arc<GenericOidcAuthenticatorInner>,
}
286
/// The subset of the provider's OpenID configuration document
/// (`.well-known/openid-configuration`) that this module reads.
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the provider's JSON Web Key Set.
    jwks_uri: String,
}
294
295#[derive(Debug)]
296pub struct GenericOidcAuthenticatorInner {
297 adapter_client: AdapterClient,
298 decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
299 http_client: HttpClient,
300}
301
302impl GenericOidcAuthenticator {
303 pub fn new(adapter_client: AdapterClient) -> Self {
308 let http_client = HttpClient::new();
309
310 Self {
311 inner: Arc::new(GenericOidcAuthenticatorInner {
312 adapter_client,
313 decoding_keys: Mutex::new(BTreeMap::new()),
314 http_client,
315 }),
316 }
317 }
318}
319
320impl GenericOidcAuthenticatorInner {
321 async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
322 let openid_config_url = build_openid_config_url(issuer)?;
323
324 let openid_config_url_str = openid_config_url.to_string();
325
326 let response = self
328 .http_client
329 .get(openid_config_url)
330 .timeout(Duration::from_secs(10))
331 .send()
332 .await
333 .map_err(|e| OidcError::FetchFromProviderFailed {
334 url: openid_config_url_str.clone(),
335 error_message: e.to_string(),
336 })?;
337
338 if !response.status().is_success() {
339 return Err(OidcError::FetchFromProviderFailed {
340 url: openid_config_url_str.clone(),
341 error_message: response
342 .error_for_status()
343 .err()
344 .map(|e| e.to_string())
345 .unwrap_or_else(|| "Unknown error".to_string()),
346 });
347 }
348
349 let openid_config: OpenIdConfiguration =
350 response
351 .json()
352 .await
353 .map_err(|e| OidcError::FetchFromProviderFailed {
354 url: openid_config_url_str,
355 error_message: e.to_string(),
356 })?;
357
358 Ok(openid_config.jwks_uri)
359 }
360
361 async fn fetch_jwks(
363 &self,
364 issuer: &str,
365 ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
366 let jwks_uri = self.fetch_jwks_uri(issuer).await?;
367 let response = self
368 .http_client
369 .get(&jwks_uri)
370 .timeout(Duration::from_secs(10))
371 .send()
372 .await
373 .map_err(|e| OidcError::FetchFromProviderFailed {
374 url: jwks_uri.clone(),
375 error_message: e.to_string(),
376 })?;
377
378 if !response.status().is_success() {
379 return Err(OidcError::FetchFromProviderFailed {
380 url: jwks_uri.clone(),
381 error_message: response
382 .error_for_status()
383 .err()
384 .map(|e| e.to_string())
385 .unwrap_or_else(|| "Unknown error".to_string()),
386 });
387 }
388
389 let jwks: JwkSet =
390 response
391 .json()
392 .await
393 .map_err(|e| OidcError::FetchFromProviderFailed {
394 url: jwks_uri.clone(),
395 error_message: e.to_string(),
396 })?;
397
398 let mut keys = BTreeMap::new();
399
400 for jwk in jwks.keys {
401 match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
402 Ok(key) => {
403 if let Some(kid) = jwk.common.key_id {
404 keys.insert(kid, OidcDecodingKey(key));
405 }
406 }
407 Err(e) => {
408 warn!("Failed to parse JWK: {}", e);
409 }
410 }
411 }
412
413 Ok(keys)
414 }
415
416 async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
419 {
421 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
422
423 if let Some(key) = decoding_keys.get(kid) {
424 return Ok(key.clone());
425 }
426 }
427
428 let new_decoding_keys = self.fetch_jwks(issuer).await?;
430
431 let decoding_key = new_decoding_keys.get(kid).cloned();
432
433 {
434 let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
435 *decoding_keys = new_decoding_keys;
436 }
437
438 if let Some(key) = decoding_key {
439 return Ok(key);
440 }
441
442 {
443 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
444 debug!(
445 "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
446 );
447 Err(OidcError::NoMatchingKey {
448 key_id: kid.to_string(),
449 })
450 }
451 }
452
453 pub async fn validate_token(
454 &self,
455 token: &str,
456 expected_user: Option<&str>,
457 ) -> Result<ValidatedClaims, OidcError> {
458 let system_vars = self.adapter_client.get_system_vars().await;
460 let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
461 return Err(OidcError::MissingIssuer);
462 };
463
464 let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
465
466 let expected_audiences: Vec<String> = {
467 let audiences: Vec<String> =
468 serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
469 .map_err(|_| OidcError::AudienceParseError)?;
470
471 if audiences.is_empty() {
472 warn!(
473 "Audience validation skipped. It is discouraged \
474 to skip audience validation since it allows anyone \
475 with a JWT issued by the same issuer to authenticate."
476 );
477 }
478 audiences
479 };
480
481 let header = jsonwebtoken::decode_header(token).map_err(|e| {
484 debug!("Failed to decode JWT header: {:?}", e);
485 OidcError::Jwt
486 })?;
487
488 let kid = header.kid.ok_or(OidcError::MissingKid)?;
489 let decoding_key = self.find_key(&kid, &issuer).await?;
492
493 let mut validation = jsonwebtoken::Validation::new(header.alg);
495 validation.set_issuer(&[&issuer]);
496 if !expected_audiences.is_empty() {
497 validation.set_audience(&expected_audiences);
498 } else {
499 validation.validate_aud = false;
500 }
501
502 let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
504 .map_err(|e| match e.kind() {
505 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
506 if !expected_audiences.is_empty() {
507 OidcError::InvalidAudience {
508 expected_audiences
509 }
510 } else {
511 soft_panic_or_log!(
512 "received an audience validation error when audience validation is disabled"
513 );
514 OidcError::Jwt
515 }
516 }
517 jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
518 expected_issuer: issuer.clone(),
519 },
520 jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
521 _ => OidcError::Jwt,
522 })?;
523
524 let user = token_data.claims.user(&authentication_claim).ok_or(
525 OidcError::NoMatchingAuthenticationClaim {
526 authentication_claim,
527 },
528 )?;
529
530 if let Some(expected) = expected_user {
532 if user != expected {
533 return Err(OidcError::WrongUser);
534 }
535 }
536
537 let group_claim = OIDC_GROUP_CLAIM.get(system_vars.dyncfgs());
539 let groups = token_data.claims.groups(&group_claim);
540
541 Ok(ValidatedClaims {
542 user: user.to_string(),
543 groups,
544 _private: (),
545 })
546 }
547
548 async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
552 match self.adapter_client.role_can_login(role_name).await {
553 Ok(()) => Ok(()),
554 Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
555 Ok(())
557 }
558 Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
559 Err(OidcError::NonLogin)
560 }
561 Err(e) => {
562 warn!(?e, "unexpected error checking OIDC role login");
563 Err(OidcError::LoginCheckError)
564 }
565 }
566 }
567}
568
569impl GenericOidcAuthenticator {
570 pub async fn authenticate(
571 &self,
572 token: &str,
573 expected_user: Option<&str>,
574 ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
575 let validated_claims = self.inner.validate_token(token, expected_user).await?;
576 self.inner.check_role_login(&validated_claims.user).await?;
577 Ok((validated_claims, Authenticated))
578 }
579}
580
581fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
582 let mut openid_config_url =
583 Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
584 {
585 let mut segments = openid_config_url
586 .path_segments_mut()
587 .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
588 segments.pop_if_empty();
590 segments.push(".well-known");
591 segments.push("openid-configuration");
592 }
593 Ok(openid_config_url)
594}
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_aud_single_string() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// A token without `aud` or `iat` still parses; both fall back to their defaults.
    #[mz_ore::test]
    fn test_aud_missing_defaults_to_empty() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_empty());
        assert!(claims.iat.is_none());
    }

    #[mz_ore::test]
    fn test_user() {
        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("sub"), Some("user-123"));
        assert_eq!(claims.user("email"), Some("alice@example.com"));
        assert_eq!(claims.user("missing"), None);
    }

    /// Only string-valued claims can identify a user.
    #[mz_ore::test]
    fn test_user_ignores_non_string_claims() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"uid":42}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("uid"), None);
    }

    #[mz_ore::test]
    fn test_build_openid_config_url() {
        let issuer = "https://dev-123456.okta.com/oauth2/default";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        let issuer = "https://dev-123456.okta.com/oauth2/default/";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    /// An issuer with no path gets the discovery document at the root.
    #[mz_ore::test]
    fn test_build_openid_config_url_bare_host() {
        let url = build_openid_config_url("https://issuer.example.com").unwrap();
        assert_eq!(
            url.to_string(),
            "https://issuer.example.com/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_invalid() {
        assert!(matches!(
            build_openid_config_url("not a url"),
            Err(OidcError::InvalidIssuerUrl(_))
        ));
    }

    #[mz_ore::test]
    fn zeroize_clears_validated_claims() {
        use mz_ore::secure::Zeroize;
        let mut claims = ValidatedClaims {
            user: "alice@example.com".to_string(),
            groups: Some(vec!["eng".to_string()]),
            _private: (),
        };
        claims.zeroize();
        assert!(claims.user.is_empty());
    }

    #[mz_ore::test]
    fn oidc_claims_implements_zeroize_on_drop() {
        fn assert_zod<T: ZeroizeOnDrop>() {}
        assert_zod::<OidcClaims>();
        assert_zod::<ValidatedClaims>();
    }

    #[mz_ore::test]
    fn zeroize_clears_oidc_claims() {
        use mz_ore::secure::Zeroize;
        let mut claims = OidcClaims {
            iss: "https://issuer.example.com".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            aud: vec!["app1".to_string(), "app2".to_string()],
            unknown_claims: BTreeMap::from([(
                "email".to_string(),
                serde_json::Value::String("alice@example.com".to_string()),
            )]),
        };
        claims.zeroize();
        assert!(claims.iss.is_empty());
        assert_eq!(claims.exp, 0);
        assert!(claims.iat.is_none());
        assert!(claims.aud.is_empty());
        assert!(claims.unknown_claims.is_empty());
    }

    /// Every JSON variant is wiped. Containers end up empty, scalars are reset and
    /// `null` is unchanged.
    #[mz_ore::test]
    fn zeroize_json_value_wipes_nested_values() {
        let mut object = serde_json::json!({"a": ["x", {"b": "y"}], "n": 5, "t": true});
        zeroize_json_value(&mut object);
        assert_eq!(object, serde_json::json!({}));

        let mut array = serde_json::json!(["secret", 7, false]);
        zeroize_json_value(&mut array);
        assert_eq!(array, serde_json::json!([]));

        for (mut input, expected) in [
            (serde_json::json!("secret"), serde_json::json!("")),
            (serde_json::json!(42), serde_json::json!(0)),
            (serde_json::json!(true), serde_json::json!(false)),
            (serde_json::Value::Null, serde_json::Value::Null),
        ] {
            zeroize_json_value(&mut input);
            assert_eq!(input, expected);
        }
    }
}