1use std::collections::BTreeMap;
16use std::sync::{Arc, Mutex};
17use std::time::Duration;
18
19use jsonwebtoken::jwk::JwkSet;
20use mz_adapter::Client as AdapterClient;
21use mz_adapter_types::dyncfgs::{OIDC_AUDIENCE, OIDC_ISSUER};
22use mz_auth::Authenticated;
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Deserializer, Serialize};
25
26use tracing::warn;
27use url::Url;
28#[derive(Debug)]
30pub enum OidcError {
31 MissingIssuer,
33 InvalidIssuerUrl(url::ParseError),
35 OpenIdConfigFetchFailed(String),
37 JwksFetchFailed(String),
39 MissingKid,
41 NoMatchingKey,
43 Jwt(jsonwebtoken::errors::Error),
45 WrongUser,
47}
48
49impl std::fmt::Display for OidcError {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 OidcError::MissingIssuer => write!(f, "missing OIDC issuer URL"),
53 OidcError::InvalidIssuerUrl(e) => {
54 write!(f, "failed to parse OIDC issuer URL: {}", e)
55 }
56 OidcError::OpenIdConfigFetchFailed(e) => {
57 write!(f, "failed to fetch OpenID configuration: {}", e)
58 }
59 OidcError::JwksFetchFailed(e) => write!(f, "failed to fetch JWKS: {}", e),
60 OidcError::MissingKid => write!(f, "missing key ID in token header"),
61 OidcError::NoMatchingKey => write!(f, "no matching key in JWKS"),
62 OidcError::Jwt(e) => write!(f, "JWT error: {}", e),
63 OidcError::WrongUser => write!(f, "user does not match expected value"),
64 }
65 }
66}
67
68impl std::error::Error for OidcError {}
69
70fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
71where
72 D: Deserializer<'de>,
73{
74 #[derive(Deserialize)]
75 #[serde(untagged)]
76 enum StringOrVec {
77 String(String),
78 Vec(Vec<String>),
79 }
80
81 match StringOrVec::deserialize(deserializer)? {
82 StringOrVec::String(s) => Ok(vec![s]),
83 StringOrVec::Vec(v) => Ok(v),
84 }
85}
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct OidcClaims {
89 pub sub: String,
91 pub iss: String,
93 pub exp: i64,
95 #[serde(default)]
97 pub iat: Option<i64>,
98 #[serde(default)]
100 pub email: Option<String>,
101 #[serde(default, deserialize_with = "deserialize_string_or_vec")]
103 pub aud: Vec<String>,
104}
105
106impl OidcClaims {
107 pub fn username(&self) -> &str {
112 self.email.as_deref().unwrap_or(&self.sub)
113 }
114}
115
116#[derive(Clone)]
117struct OidcDecodingKey(jsonwebtoken::DecodingKey);
118
119impl std::fmt::Debug for OidcDecodingKey {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("OidcDecodingKey")
122 .field("key", &"<redacted>")
123 .finish()
124 }
125}
126
127#[derive(Clone, Debug)]
132pub struct GenericOidcAuthenticator {
133 inner: Arc<GenericOidcAuthenticatorInner>,
134}
135
136#[derive(Debug, Deserialize)]
139struct OpenIdConfiguration {
140 jwks_uri: String,
142}
143
144#[derive(Debug)]
145pub struct GenericOidcAuthenticatorInner {
146 adapter_client: AdapterClient,
147 decoding_keys: Mutex<BTreeMap<String, OidcDecodingKey>>,
148 http_client: HttpClient,
149}
150
151impl GenericOidcAuthenticator {
152 pub fn new(adapter_client: AdapterClient) -> Self {
157 let http_client = HttpClient::new();
158
159 Self {
160 inner: Arc::new(GenericOidcAuthenticatorInner {
161 adapter_client,
162 decoding_keys: Mutex::new(BTreeMap::new()),
163 http_client,
164 }),
165 }
166 }
167}
168
169impl GenericOidcAuthenticatorInner {
170 async fn fetch_jwks_uri(&self, issuer: &str) -> Result<String, OidcError> {
171 let openid_config_url = Url::parse(issuer)
172 .and_then(|url| url.join(".well-known/openid-configuration"))
173 .map_err(OidcError::InvalidIssuerUrl)?;
174
175 let response = self
177 .http_client
178 .get(openid_config_url)
179 .timeout(Duration::from_secs(10))
180 .send()
181 .await
182 .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?;
183
184 if !response.status().is_success() {
185 return Err(OidcError::OpenIdConfigFetchFailed(format!(
186 "HTTP {}",
187 response.status()
188 )));
189 }
190
191 let openid_config: OpenIdConfiguration = response
192 .json()
193 .await
194 .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?;
195
196 Ok(openid_config.jwks_uri)
197 }
198
199 async fn fetch_jwks(
201 &self,
202 issuer: &str,
203 ) -> Result<BTreeMap<String, OidcDecodingKey>, OidcError> {
204 let jwks_uri = self.fetch_jwks_uri(issuer).await?;
205 let response = self
206 .http_client
207 .get(&jwks_uri)
208 .timeout(Duration::from_secs(10))
209 .send()
210 .await
211 .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?;
212
213 if !response.status().is_success() {
214 return Err(OidcError::JwksFetchFailed(format!(
215 "HTTP {}",
216 response.status()
217 )));
218 }
219
220 let jwks: JwkSet = response
221 .json()
222 .await
223 .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?;
224
225 let mut keys = BTreeMap::new();
226 for jwk in jwks.keys {
227 match jsonwebtoken::DecodingKey::from_jwk(&jwk) {
228 Ok(key) => {
229 if let Some(kid) = jwk.common.key_id {
230 keys.insert(kid, OidcDecodingKey(key));
231 }
232 }
233 Err(e) => {
234 warn!("Failed to parse JWK: {}", e);
235 }
236 }
237 }
238
239 if keys.is_empty() {
240 return Err(OidcError::JwksFetchFailed(
241 "no valid keys found in JWKS".to_string(),
242 ));
243 }
244
245 Ok(keys)
246 }
247
248 async fn find_key(&self, kid: &str, issuer: &str) -> Result<OidcDecodingKey, OidcError> {
251 {
252 let decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
253
254 if let Some(key) = decoding_keys.get(kid) {
255 return Ok(key.clone());
256 }
257 }
258
259 let new_decoding_keys = self.fetch_jwks(issuer).await?;
260
261 let decoding_key = new_decoding_keys.get(kid).cloned();
262
263 {
264 let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned");
265 decoding_keys.extend(new_decoding_keys);
266 }
267
268 if let Some(key) = decoding_key {
269 return Ok(key);
270 }
271
272 Err(OidcError::NoMatchingKey)
273 }
274
275 pub async fn validate_token(
276 &self,
277 token: &str,
278 expected_user: Option<&str>,
279 ) -> Result<OidcClaims, OidcError> {
280 let system_vars = self.adapter_client.get_system_vars().await;
282 let Some(issuer) = OIDC_ISSUER.get(system_vars.dyncfgs()) else {
283 return Err(OidcError::MissingIssuer);
284 };
285
286 let audience = {
287 let aud = OIDC_AUDIENCE.get(system_vars.dyncfgs());
288 if aud.is_none() {
289 warn!(
290 "Audience validation skipped. It is discouraged
291 to skip audience validation since it allows
292 anyone with a JWT issued by the same issuer
293 to authenticate."
294 );
295 }
296 aud
297 };
298
299 let header = jsonwebtoken::decode_header(token).map_err(OidcError::Jwt)?;
301
302 let kid = header.kid.ok_or(OidcError::MissingKid)?;
303 let decoding_key = self.find_key(&kid, &issuer).await?;
305
306 let mut validation = jsonwebtoken::Validation::new(header.alg);
309 validation.set_issuer(&[&issuer]);
310 if let Some(ref audience) = audience {
311 validation.set_audience(&[audience]);
312 } else {
313 validation.validate_aud = false;
314 }
315
316 let token_data = jsonwebtoken::decode::<OidcClaims>(token, &(decoding_key.0), &validation)
318 .map_err(OidcError::Jwt)?;
319
320 if let Some(expected) = expected_user {
322 if token_data.claims.username() != expected {
323 return Err(OidcError::WrongUser);
324 }
325 }
326
327 Ok(token_data.claims)
328 }
329}
330
331impl GenericOidcAuthenticator {
332 pub async fn authenticate(
333 &self,
334 token: &str,
335 expected_user: Option<&str>,
336 ) -> Result<(OidcClaims, Authenticated), OidcError> {
337 let claims = self.inner.validate_token(token, expected_user).await?;
338 Ok((claims, Authenticated))
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_aud_single_string() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["my-app"]);
    }

    #[mz_ore::test]
    fn test_aud_array() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, vec!["app1", "app2"]);
    }

    /// `aud` is `#[serde(default)]`, so a token without it yields no audiences.
    #[mz_ore::test]
    fn test_aud_missing_defaults_to_empty() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_empty());
    }

    #[mz_ore::test]
    fn test_username_prefers_email() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234,"email":"user@example.com"}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.username(), "user@example.com");
    }

    #[mz_ore::test]
    fn test_username_falls_back_to_sub() {
        let json = r#"{"sub":"user","iss":"issuer","exp":1234}"#;
        let claims: OidcClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.username(), "user");
    }
}