1use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::{AdapterError, AuthenticationError, Client as AdapterClient};
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_AUTHENTICATION_CLAIM, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use mz_ore::soft_panic_or_log;
24use mz_pgwire_common::{ErrorResponse, Severity};
25use reqwest::Client as HttpClient;
26use serde::{Deserialize, Deserializer, Serialize};
27use tokio_postgres::error::SqlState;
28
29use tracing::{debug, warn};
30use url::Url;
/// Errors produced while authenticating a client with an OIDC-issued JWT.
///
/// Every variant is reported to the client with SQLSTATE
/// `INVALID_AUTHORIZATION_SPECIFICATION` (see `OidcError::code`).
#[derive(Debug)]
pub enum OidcError {
    /// The `OIDC_ISSUER` system variable is not set.
    MissingIssuer,
    /// The configured issuer (carried here) could not be turned into a discovery URL.
    InvalidIssuerUrl(String),
    /// The `OIDC_AUDIENCE` system variable did not deserialize as a list of strings.
    AudienceParseError,
    /// An HTTP request to the provider (discovery document or JWKS) failed.
    FetchFromProviderFailed {
        /// The URL that was being fetched.
        url: String,
        /// Human-readable description of the transport, status, or decoding failure.
        error_message: String,
    },
    /// The JWT header has no `kid` (key ID), so no verification key can be selected.
    MissingKid,
    /// The JWT's key ID is absent from the provider's JWKS, even after a refresh.
    NoMatchingKey {
        /// Key ID taken from the JWT header.
        key_id: String,
    },
    /// The configured authentication claim is missing from the JWT or is not a string.
    NoMatchingAuthenticationClaim {
        /// Name of the claim that was looked up.
        authentication_claim: String,
    },
    /// Generic JWT decoding/validation failure not covered by a more specific variant.
    Jwt,
    /// The user named in the token differs from the user the client asked to connect as.
    WrongUser,
    /// The token's `aud` claim matched none of the configured audiences.
    InvalidAudience {
        /// Audiences configured via `OIDC_AUDIENCE`.
        expected_audiences: Vec<String>,
    },
    /// The token's `iss` claim did not match the configured issuer.
    InvalidIssuer {
        /// Issuer configured via `OIDC_ISSUER`.
        expected_issuer: String,
    },
    /// The token's `exp` claim is in the past.
    ExpiredSignature,
    /// The role exists but lacks the LOGIN attribute.
    NonLogin,
    /// Checking whether the role may log in failed for an unexpected reason.
    LoginCheckError,
}
68
69impl std::fmt::Display for OidcError {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 OidcError::MissingIssuer => write!(f, "OIDC issuer is not configured"),
73 OidcError::InvalidIssuerUrl(_) => write!(f, "invalid OIDC issuer URL"),
74 OidcError::AudienceParseError => {
75 write!(f, "failed to parse OIDC_AUDIENCE system variable")
76 }
77 OidcError::FetchFromProviderFailed { .. } => {
78 write!(f, "failed to fetch OIDC provider configuration")
79 }
80 OidcError::MissingKid => write!(f, "missing key ID in JWT header"),
81 OidcError::NoMatchingKey { .. } => write!(f, "no matching key found in the JWKS"),
82 OidcError::NoMatchingAuthenticationClaim { .. } => {
83 write!(f, "no matching authentication claim found in the JWT")
84 }
85 OidcError::Jwt => write!(f, "failed to validate JWT"),
86 OidcError::WrongUser => write!(f, "wrong user"),
87 OidcError::InvalidAudience { .. } => write!(f, "invalid audience"),
88 OidcError::InvalidIssuer { .. } => write!(f, "invalid issuer"),
89 OidcError::ExpiredSignature => write!(f, "authentication credentials have expired"),
90 OidcError::NonLogin => write!(f, "role is not allowed to login"),
91 OidcError::LoginCheckError => write!(f, "unexpected error checking if role can login"),
92 }
93 }
94}
95
96impl std::error::Error for OidcError {}
97
98impl OidcError {
99 pub fn code(&self) -> SqlState {
100 SqlState::INVALID_AUTHORIZATION_SPECIFICATION
101 }
102
103 pub fn detail(&self) -> Option<String> {
104 match self {
105 OidcError::InvalidIssuerUrl(issuer) => {
106 Some(format!("Could not parse \"{issuer}\" as a URL."))
107 }
108 OidcError::FetchFromProviderFailed { url, error_message } => {
109 Some(format!("Fetching \"{url}\" failed. {error_message}"))
110 }
111 OidcError::NoMatchingKey { key_id } => {
112 Some(format!("JWT key ID \"{key_id}\" was not found."))
113 }
114 OidcError::InvalidAudience { expected_audiences } => Some(format!(
115 "Expected one of audiences {:?} in the JWT.",
116 expected_audiences,
117 )),
118 OidcError::InvalidIssuer { expected_issuer } => {
119 Some(format!("Expected issuer \"{expected_issuer}\" in the JWT.",))
120 }
121 OidcError::NoMatchingAuthenticationClaim {
122 authentication_claim,
123 } => Some(format!(
124 "Expected authentication claim \"{authentication_claim}\" in the JWT.",
125 )),
126 OidcError::NonLogin => Some("The role does not have the LOGIN attribute.".into()),
127 _ => None,
128 }
129 }
130
131 pub fn hint(&self) -> Option<String> {
132 match self {
133 OidcError::MissingIssuer => {
134 Some("Configure the OIDC issuer using the oidc_issuer system variable.".into())
135 }
136 _ => None,
137 }
138 }
139
140 pub fn into_response(self) -> ErrorResponse {
141 ErrorResponse {
142 severity: Severity::Fatal,
143 code: self.code(),
144 message: self.to_string(),
145 detail: self.detail(),
146 hint: self.hint(),
147 position: None,
148 }
149 }
150}
151
152fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
153where
154 D: Deserializer<'de>,
155{
156 #[derive(Deserialize)]
157 #[serde(untagged)]
158 enum StringOrVec {
159 String(String),
160 Vec(Vec<String>),
161 }
162
163 match StringOrVec::deserialize(deserializer)? {
164 StringOrVec::String(s) => Ok(vec![s]),
165 StringOrVec::Vec(v) => Ok(v),
166 }
167}
/// Claims decoded from an OIDC JWT.
///
/// Only `iss`, `exp`, `iat`, and `aud` are modeled as named fields; every other
/// claim (including `sub`) is collected into `unknown_claims` by `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcClaims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Expiration time (`exp`); presumably seconds since the Unix epoch per RFC 7519.
    pub exp: i64,
    /// Issued-at time (`iat`); `None` when the claim is absent.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Audience(s) (`aud`); accepts a string or an array, and is empty when absent.
    #[serde(default, deserialize_with = "deserialize_string_or_vec")]
    pub aud: Vec<String>,
    /// All claims not captured by a named field above, keyed by claim name.
    #[serde(flatten)]
    pub unknown_claims: BTreeMap<String, serde_json::Value>,
}
185
186impl OidcClaims {
187 fn user(&self, authentication_claim: &str) -> Option<&str> {
189 self.unknown_claims
190 .get(authentication_claim)
191 .and_then(|value| value.as_str())
192 }
193}
194
/// Outcome of a successful token validation.
///
/// The private `_private` field prevents construction outside this module; in
/// this file it is only built at the end of `validate_token`, after signature
/// and claim checks have passed.
pub struct ValidatedClaims {
    /// User name read from the configured authentication claim.
    pub user: String,
    _private: (),
}
200
201#[derive(Clone)]
202struct OidcDecodingKey(jsonwebtoken::DecodingKey);
203
204impl std::fmt::Debug for OidcDecodingKey {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("OidcDecodingKey")
207 .field("key", &"<redacted>")
208 .finish()
209 }
210}
211
/// Authenticates clients that present a JWT issued by a generic OIDC provider.
///
/// Cloning is cheap: all clones share one `GenericOidcAuthenticatorInner`
/// (and therefore one decoding-key cache and HTTP client) through an `Arc`.
#[derive(Clone, Debug)]
pub struct GenericOidcAuthenticator {
    inner: Arc<GenericOidcAuthenticatorInner>,
}
220
/// The subset of the provider's OpenID discovery document that is used here.
///
/// Other fields in the document are ignored during deserialization (no
/// `deny_unknown_fields`).
#[derive(Debug, Deserialize)]
struct OpenIdConfiguration {
    /// URL of the provider's JSON Web Key Set.
    jwks_uri: String,
}
228
/// Shared state behind a `GenericOidcAuthenticator`.
#[derive(Debug)]
pub struct GenericOidcAuthenticatorInner {
    /// Used to read OIDC system variables and to check role LOGIN privileges.
    adapter_client: AdapterClient,
    /// Cache of verification keys indexed by JWK key ID (`kid`); replaced
    /// wholesale whenever the JWKS is refetched.
    decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
    /// HTTP client used for discovery and JWKS requests.
    http_client: HttpClient,
}
235
236impl GenericOidcAuthenticator {
237 pub fn new(adapter_client: AdapterClient) -> Self {
242 let http_client = HttpClient::new();
243
244 Self {
245 inner: Arc::new(GenericOidcAuthenticatorInner {
246 adapter_client,
247 decoding_keys: Mutex::new(BTreeMap::new()),
248 http_client,
249 }),
250 }
251 }
252}
253
254impl GenericOidcAuthenticatorInner {
255 async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
256 let openid_config_url = build_openid_config_url(issuer)?;
257
258 let openid_config_url_str = openid_config_url.to_string();
259
260 let response = self
262 .http_client
263 .get(openid_config_url)
264 .timeout(Duration::from_secs(10))
265 .send()
266 .await
267 .map_err(|e| OidcError::FetchFromProviderFailed {
268 url: openid_config_url_str.clone(),
269 error_message: e.to_string(),
270 })?;
271
272 if !response.status().is_success() {
273 return Err(OidcError::FetchFromProviderFailed {
274 url: openid_config_url_str.clone(),
275 error_message: response
276 .error_for_status()
277 .err()
278 .map(|e| e.to_string())
279 .unwrap_or_else(|| "Unknown error".to_string()),
280 });
281 }
282
283 let openid_config: OpenIdConfiguration =
284 response
285 .json()
286 .await
287 .map_err(|e| OidcError::FetchFromProviderFailed {
288 url: openid_config_url_str,
289 error_message: e.to_string(),
290 })?;
291
292 Ok(openid_config.jwks_uri)
293 }
294
295 async fn fetch_jwks(
297 &self,
298 issuer: &str,
299 ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
300 let jwks_uri = self.fetch_jwks_uri(issuer).await?;
301 let response = self
302 .http_client
303 .get(&jwks_uri)
304 .timeout(Duration::from_secs(10))
305 .send()
306 .await
307 .map_err(|e| OidcError::FetchFromProviderFailed {
308 url: jwks_uri.clone(),
309 error_message: e.to_string(),
310 })?;
311
312 if !response.status().is_success() {
313 return Err(OidcError::FetchFromProviderFailed {
314 url: jwks_uri.clone(),
315 error_message: response
316 .error_for_status()
317 .err()
318 .map(|e| e.to_string())
319 .unwrap_or_else(|| "Unknown error".to_string()),
320 });
321 }
322
323 let jwks: JwkSet =
324 response
325 .json()
326 .await
327 .map_err(|e| OidcError::FetchFromProviderFailed {
328 url: jwks_uri.clone(),
329 error_message: e.to_string(),
330 })?;
331
332 let mut keys = BTreeMap::new();
333
334 for jwk in jwks.keys {
335 match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
336 Ok(key) => {
337 if let Some(kid) = jwk.common.key_id {
338 keys.insert(kid, OidcDecodingKey(key));
339 }
340 }
341 Err(e) => {
342 warn!("Failed to parse JWK: {}", e);
343 }
344 }
345 }
346
347 Ok(keys)
348 }
349
350 async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
353 {
355 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
356
357 if let Some(key) = decoding_keys.get(kid) {
358 return Ok(key.clone());
359 }
360 }
361
362 let new_decoding_keys = self.fetch_jwks(issuer).await?;
364
365 let decoding_key = new_decoding_keys.get(kid).cloned();
366
367 {
368 let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
369 *decoding_keys = new_decoding_keys;
370 }
371
372 if let Some(key) = decoding_key {
373 return Ok(key);
374 }
375
376 {
377 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
378 debug!(
379 "No matching key found in JWKS for key ID: {kid}. Available keys: {decoding_keys:?}."
380 );
381 Err(OidcError::NoMatchingKey {
382 key_id: kid.to_string(),
383 })
384 }
385 }
386
387 pub async fn validate_token(
388 &self,
389 token: &str,
390 expected_user: Option<&str>,
391 ) -> Result<ValidatedClaims, OidcError> {
392 let system_vars = self.adapter_client.get_system_vars().await;
394 let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
395 return Err(OidcError::MissingIssuer);
396 };
397
398 let authentication_claim = OIDC_AUTHENTICATION_CLAIM.get(system_vars.dyncfgs());
399
400 let expected_audiences: Vec<String> = {
401 let audiences: Vec<String> =
402 serde_json::from_value(OIDC_AUDIENCE.get(system_vars.dyncfgs()))
403 .map_err(|_| OidcError::AudienceParseError)?;
404
405 if audiences.is_empty() {
406 warn!(
407 "Audience validation skipped. It is discouraged \
408 to skip audience validation since it allows anyone \
409 with a JWT issued by the same issuer to authenticate."
410 );
411 }
412 audiences
413 };
414
415 let header = jsonwebtoken::decode_header(token).map_err(|e| {
418 debug!("Failed to decode JWT header: {:?}", e);
419 OidcError::Jwt
420 })?;
421
422 let kid = header.kid.ok_or(OidcError::MissingKid)?;
423 let decoding_key = self.find_key(&kid, &issuer).await?;
426
427 let mut validation = jsonwebtoken::Validation::new(header.alg);
429 validation.set_issuer(&[&issuer]);
430 if !expected_audiences.is_empty() {
431 validation.set_audience(&expected_audiences);
432 } else {
433 validation.validate_aud = false;
434 }
435
436 let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
438 .map_err(|e| match e.kind() {
439 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
440 if !expected_audiences.is_empty() {
441 OidcError::InvalidAudience {
442 expected_audiences
443 }
444 } else {
445 soft_panic_or_log!(
446 "received an audience validation error when audience validation is disabled"
447 );
448 OidcError::Jwt
449 }
450 }
451 jsonwebtoken::errors::ErrorKind::InvalidIssuer => OidcError::InvalidIssuer {
452 expected_issuer: issuer.clone(),
453 },
454 jsonwebtoken::errors::ErrorKind::ExpiredSignature => OidcError::ExpiredSignature,
455 _ => OidcError::Jwt,
456 })?;
457
458 let user = token_data.claims.user(&authentication_claim).ok_or(
459 OidcError::NoMatchingAuthenticationClaim {
460 authentication_claim,
461 },
462 )?;
463
464 if let Some(expected) = expected_user {
466 if user != expected {
467 return Err(OidcError::WrongUser);
468 }
469 }
470
471 Ok(ValidatedClaims {
472 user: user.to_string(),
473 _private: (),
474 })
475 }
476
477 async fn check_role_login(&self, role_name: &str) -> Result<(), OidcError> {
481 match self.adapter_client.role_can_login(role_name).await {
482 Ok(()) => Ok(()),
483 Err(AdapterError::AuthenticationError(AuthenticationError::RoleNotFound)) => {
484 Ok(())
486 }
487 Err(AdapterError::AuthenticationError(AuthenticationError::NonLogin)) => {
488 Err(OidcError::NonLogin)
489 }
490 Err(e) => {
491 warn!(?e, "unexpected error checking OIDC role login");
492 Err(OidcError::LoginCheckError)
493 }
494 }
495 }
496}
497
498impl GenericOidcAuthenticator {
499 pub async fn authenticate(
500 &self,
501 token: &str,
502 expected_user: Option<&str>,
503 ) -> Result<(ValidatedClaims, Authenticated), OidcError> {
504 let validated_claims = self.inner.validate_token(token, expected_user).await?;
505 self.inner.check_role_login(&validated_claims.user).await?;
506 Ok((validated_claims, Authenticated))
507 }
508}
509
510fn build_openid_config_url(issuer: &str) -> Result<Url, OidcError> {
511 let mut openid_config_url =
512 Url::parse(issuer).map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
513 {
514 let mut segments = openid_config_url
515 .path_segments_mut()
516 .map_err(|_| OidcError::InvalidIssuerUrl(issuer.to_string()))?;
517 segments.pop_if_empty();
519 segments.push(".well-known");
520 segments.push("openid-configuration");
521 }
522 Ok(openid_config_url)
523}
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_aud_single_string() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// A token without `aud` or `iat` still deserializes, with empty/absent defaults.
    #[mz_ore::test]
    fn test_aud_missing_defaults_to_empty() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_empty());
        assert_eq!(claims.iat, None);
    }

    #[mz_ore::test]
    fn test_user() {
        let json = r#"{"sub":"user-123","iss":"issuer","exp":1234,"aud":["app"],"email":"alice@example.com"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("sub"), Some("user-123"));
        assert_eq!(claims.user("email"), Some("alice@example.com"));
        assert_eq!(claims.user("missing"), None);
    }

    /// Only string-valued claims can identify a user.
    #[mz_ore::test]
    fn test_user_non_string_claim() {
        let json = r#"{"sub":123,"iss":"issuer","exp":1234,"groups":["a","b"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.user("sub"), None);
        assert_eq!(claims.user("groups"), None);
    }

    #[mz_ore::test]
    fn test_build_openid_config_url() {
        let issuer = "https://dev-123456.okta.com/oauth2/default";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_trailing_slash() {
        let issuer = "https://dev-123456.okta.com/oauth2/default/";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/oauth2/default/.well-known/openid-configuration"
        );
    }

    /// An issuer with no path must not produce a double slash.
    #[mz_ore::test]
    fn test_build_openid_config_url_no_path() {
        let issuer = "https://dev-123456.okta.com";
        let url = build_openid_config_url(issuer).unwrap();
        assert_eq!(
            url.to_string(),
            "https://dev-123456.okta.com/.well-known/openid-configuration"
        );
    }

    #[mz_ore::test]
    fn test_build_openid_config_url_invalid() {
        assert!(matches!(
            build_openid_config_url("not a url"),
            Err(OidcError::InvalidIssuerUrl(_))
        ));
        // Parses, but cannot carry path segments.
        assert!(matches!(
            build_openid_config_url("mailto:alice@example.com"),
            Err(OidcError::InvalidIssuerUrl(_))
        ));
    }

    #[mz_ore::test]
    fn test_error_detail() {
        let err = OidcError::NoMatchingKey {
            key_id: "abc".to_string(),
        };
        assert_eq!(
            err.detail().as_deref(),
            Some("JWT key ID \"abc\" was not found.")
        );
        assert_eq!(OidcError::Jwt.detail(), None);
        assert!(OidcError::MissingIssuer.hint().is_some());
    }
}