1use std::borrow::Cow;
16use std::collections::BTreeMap;
17use std::future::IntoFuture;
18use std::net::{IpAddr, Ipv4Addr, SocketAddr};
19use std::sync::Arc;
20
21use axum::extract::State;
22use axum::routing::get;
23use axum::{Json, Router};
24use base64::Engine;
25use jsonwebtoken::{EncodingKey, Header, encode};
26use mz_ore::now::NowFn;
27use mz_ore::task::JoinHandle;
28use openssl::pkey::{PKey, Private};
29use openssl::rsa::Rsa;
30use serde::{Deserialize, Serialize};
31use tokio::net::TcpListener;
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35struct JwkSet {
36 pub keys: Vec<Jwk>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41struct Jwk {
42 pub kty: String,
43 pub kid: String,
44 #[serde(rename = "use")]
45 pub key_use: String,
46 pub alg: String,
47 pub n: String,
48 pub e: String,
49}
50
51struct OidcMockContext {
53 issuer: String,
55 jwk: Jwk,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(untagged)]
64pub enum AudClaim {
65 Single(String),
66 Multiple(Vec<String>),
67}
68
69impl Default for AudClaim {
70 fn default() -> Self {
71 AudClaim::Multiple(vec![])
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77struct MockClaims {
78 iss: String,
79 exp: i64,
80 #[serde(skip_serializing_if = "Option::is_none")]
81 iat: Option<i64>,
82 aud: AudClaim,
83 #[serde(flatten)]
84 unknown_claims: BTreeMap<String, serde_json::Value>,
85}
86
87#[derive(Debug, Clone, Default)]
89pub struct GenerateJwtOptions<'a> {
90 pub email: Option<&'a str>,
92 pub exp: Option<i64>,
94 pub issuer: Option<&'a str>,
96 pub aud: Option<AudClaim>,
99 pub extra_claims: Option<BTreeMap<String, serde_json::Value>>,
101}
102
103pub struct OidcMockServer {
105 pub issuer: String,
108 pub kid: String,
110 pub encoding_key: EncodingKey,
112 pub now: NowFn,
114 pub expires_in_secs: i64,
116 pub handle: JoinHandle<Result<(), std::io::Error>>,
118}
119
120impl OidcMockServer {
121 pub async fn start(
133 addr: Option<&SocketAddr>,
134 encoding_key: String,
135 kid: String,
136 now: NowFn,
137 expires_in_secs: i64,
138 ) -> Result<OidcMockServer, anyhow::Error> {
139 let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?;
141
142 let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?;
144 let rsa = pkey.rsa().expect("pkey should be RSA");
145
146 let addr = match addr {
147 Some(addr) => Cow::Borrowed(addr),
148 None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
149 };
150
151 let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
152 panic!("error binding to {}: {}", addr, e);
153 });
154 let issuer = format!("http://{}", listener.local_addr().unwrap());
155
156 let jwk = create_jwk(&kid, &rsa);
159
160 let context = Arc::new(OidcMockContext {
161 issuer: issuer.clone(),
162 jwk,
163 });
164
165 let router = Router::new()
166 .route("/.well-known/jwks.json", get(handle_jwks))
167 .route(
168 "/.well-known/openid-configuration",
169 get(handle_openid_config),
170 )
171 .with_state(context);
172
173 let server = axum::serve(
174 listener,
175 router.into_make_service_with_connect_info::<SocketAddr>(),
176 );
177 println!("oidc-mock listening...");
178 println!(" HTTP address: {}", issuer);
179 let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future());
180
181 Ok(OidcMockServer {
182 issuer,
183 kid,
184 encoding_key: encoding_key_typed,
185 now,
186 expires_in_secs,
187 handle,
188 })
189 }
190
191 pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String {
198 let now_ms = (self.now)();
199 let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64");
200 let sub_claim_map = BTreeMap::from([("sub".to_string(), sub.to_string())]);
201
202 let mut unknown_claims: BTreeMap<String, serde_json::Value> = sub_claim_map
203 .into_iter()
204 .map(|(k, v)| (k, serde_json::Value::String(v)))
205 .collect();
206
207 if let Some(extra) = opts.extra_claims {
208 unknown_claims.extend(extra);
209 }
210
211 let claims = MockClaims {
212 iss: opts.issuer.unwrap_or(&self.issuer).to_string(),
213 exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs),
214 iat: Some(now_secs),
215 aud: opts.aud.unwrap_or_default(),
216 unknown_claims,
217 };
218
219 let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
220 header.kid = Some(self.kid.clone());
221
222 encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT")
223 }
224
225 pub fn jwks_url(&self) -> String {
227 format!("{}/.well-known/jwks.json", self.issuer)
228 }
229}
230
231async fn handle_jwks(State(context): State<Arc<OidcMockContext>>) -> Json<JwkSet> {
233 Json(JwkSet {
234 keys: vec![context.jwk.clone()],
235 })
236}
237
238#[derive(Serialize)]
240struct OpenIdConfiguration {
241 issuer: String,
242 jwks_uri: String,
243}
244
245async fn handle_openid_config(
247 State(context): State<Arc<OidcMockContext>>,
248) -> Json<OpenIdConfiguration> {
249 Json(OpenIdConfiguration {
250 issuer: context.issuer.clone(),
251 jwks_uri: format!("{}/.well-known/jwks.json", context.issuer),
252 })
253}
254
255fn create_jwk(kid: &str, rsa: &Rsa<Private>) -> Jwk {
257 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
258 let n = rsa.n().to_vec();
259 let e = rsa.e().to_vec();
260
261 Jwk {
262 kty: "RSA".to_string(),
263 kid: kid.to_string(),
264 key_use: "sig".to_string(),
265 alg: "RS256".to_string(),
266 n: engine.encode(n),
267 e: engine.encode(e),
268 }
269}