1use std::borrow::Cow;
16use std::collections::BTreeMap;
17use std::future::IntoFuture;
18use std::net::{IpAddr, Ipv4Addr, SocketAddr};
19use std::sync::Arc;
20
21use axum::extract::State;
22use axum::routing::get;
23use axum::{Json, Router};
24use base64::Engine;
25use jsonwebtoken::{EncodingKey, Header, encode};
26use mz_ore::now::NowFn;
27use mz_ore::task::JoinHandle;
28use openssl::pkey::{PKey, Private};
29use openssl::rsa::Rsa;
30use serde::{Deserialize, Serialize};
31use tokio::net::TcpListener;
32
/// A JSON Web Key Set, as served from `/.well-known/jwks.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwkSet {
    /// The keys in the set. `handle_jwks` always serves exactly one key.
    pub keys: Vec<Jwk>,
}
38
/// A single JSON Web Key describing an RSA public key.
///
/// Field names (after serde renames) and their declaration order define the
/// serialized JSON, so keep them stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Jwk {
    /// Key type; `create_jwk` always sets `"RSA"`.
    pub kty: String,
    /// Key ID. `OidcMockServer` uses the same value as the `kid` header of the
    /// tokens it generates.
    pub kid: String,
    /// Intended key use (`"sig"` from `create_jwk`). Serialized as `use`,
    /// which cannot be a Rust field name.
    #[serde(rename = "use")]
    pub key_use: String,
    /// Signing algorithm; `create_jwk` always sets `"RS256"`.
    pub alg: String,
    /// RSA modulus, base64url-encoded without padding.
    pub n: String,
    /// RSA public exponent, base64url-encoded without padding.
    pub e: String,
}
50
51struct OidcMockContext {
53 issuer: String,
55 jwk: Jwk,
57}
58
/// The `aud` (audience) claim of a JWT, which may be either a single string or
/// an array of strings.
///
/// Because the enum is `#[serde(untagged)]`, each variant serializes as its
/// bare inner value, with no variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AudClaim {
    /// A single audience, serialized as a JSON string.
    Single(String),
    /// Zero or more audiences, serialized as a JSON array of strings.
    Multiple(Vec<String>),
}
68
69impl Default for AudClaim {
70 fn default() -> Self {
71 AudClaim::Multiple(vec![])
72 }
73}
74
/// Claims encoded into the tokens produced by `OidcMockServer::generate_jwt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MockClaims {
    /// Issuer (`iss`).
    iss: String,
    /// Expiration time (`exp`), in seconds.
    exp: i64,
    /// Issued-at time (`iat`), in seconds. Omitted from the token when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    iat: Option<i64>,
    /// Audience (`aud`).
    aud: AudClaim,
    /// Any additional claims (e.g. `sub`), flattened into the top-level JSON
    /// object alongside the named fields above.
    #[serde(flatten)]
    unknown_claims: BTreeMap<String, serde_json::Value>,
}
86
/// Optional overrides for `OidcMockServer::generate_jwt`.
///
/// Every field defaults to `None`. See `generate_jwt` for the value it uses in
/// that case.
#[derive(Debug, Clone, Default)]
pub struct GenerateJwtOptions<'a> {
    /// Optional email address; see `OidcMockServer::generate_jwt` for how it
    /// is used.
    pub email: Option<&'a str>,
    /// Overrides the `exp` claim (in seconds).
    pub exp: Option<i64>,
    /// Overrides the `iss` claim.
    pub issuer: Option<&'a str>,
    /// Overrides the `aud` claim.
    pub aud: Option<AudClaim>,
    /// Extra claims to add to the token. Each value is encoded as a JSON string.
    pub unknown_claims: Option<BTreeMap<String, String>>,
}
102
/// A running mock OpenID Connect provider. It serves a JWKS and a discovery
/// document over HTTP, and holds the private key needed to sign tokens that
/// verify against that JWKS.
pub struct OidcMockServer {
    /// Issuer URL, `http://<bound address>`.
    pub issuer: String,
    /// Key ID advertised in the JWKS and set in the header of generated tokens.
    pub kid: String,
    /// Private key used to sign generated tokens.
    pub encoding_key: EncodingKey,
    /// Clock used to compute `iat` and the default `exp`. Its value is treated
    /// as milliseconds.
    pub now: NowFn,
    /// Default token lifetime in seconds, used when no `exp` override is given.
    pub expires_in_secs: i64,
    /// Handle to the spawned HTTP server task.
    pub handle: JoinHandle<Result<(), std::io::Error>>,
}
119
120impl OidcMockServer {
121 pub async fn start(
133 addr: Option<&SocketAddr>,
134 encoding_key: String,
135 kid: String,
136 now: NowFn,
137 expires_in_secs: i64,
138 ) -> Result<OidcMockServer, anyhow::Error> {
139 let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?;
141
142 let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?;
144 let rsa = pkey.rsa().expect("pkey should be RSA");
145
146 let addr = match addr {
147 Some(addr) => Cow::Borrowed(addr),
148 None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
149 };
150
151 let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
152 panic!("error binding to {}: {}", addr, e);
153 });
154 let issuer = format!("http://{}", listener.local_addr().unwrap());
155
156 let jwk = create_jwk(&kid, &rsa);
159
160 let context = Arc::new(OidcMockContext {
161 issuer: issuer.clone(),
162 jwk,
163 });
164
165 let router = Router::new()
166 .route("/.well-known/jwks.json", get(handle_jwks))
167 .route(
168 "/.well-known/openid-configuration",
169 get(handle_openid_config),
170 )
171 .with_state(context);
172
173 let server = axum::serve(
174 listener,
175 router.into_make_service_with_connect_info::<SocketAddr>(),
176 );
177 println!("oidc-mock listening...");
178 println!(" HTTP address: {}", issuer);
179 let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future());
180
181 Ok(OidcMockServer {
182 issuer,
183 kid,
184 encoding_key: encoding_key_typed,
185 now,
186 expires_in_secs,
187 handle,
188 })
189 }
190
191 pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String {
198 let now_ms = (self.now)();
199 let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64");
200 let sub_claim_map = BTreeMap::from([("sub".to_string(), sub.to_string())]);
201
202 let unknown_claims = opts
203 .unknown_claims
204 .unwrap_or_default()
205 .into_iter()
206 .chain(sub_claim_map)
207 .map(|(k, v)| (k, serde_json::Value::String(v.to_string())))
208 .collect();
209
210 let claims = MockClaims {
211 iss: opts.issuer.unwrap_or(&self.issuer).to_string(),
212 exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs),
213 iat: Some(now_secs),
214 aud: opts.aud.unwrap_or_default(),
215 unknown_claims,
216 };
217
218 let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
219 header.kid = Some(self.kid.clone());
220
221 encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT")
222 }
223
224 pub fn jwks_url(&self) -> String {
226 format!("{}/.well-known/jwks.json", self.issuer)
227 }
228}
229
230async fn handle_jwks(State(context): State<Arc<OidcMockContext>>) -> Json<JwkSet> {
232 Json(JwkSet {
233 keys: vec![context.jwk.clone()],
234 })
235}
236
/// Minimal OpenID Connect discovery document, served from
/// `/.well-known/openid-configuration`.
#[derive(Serialize)]
struct OpenIdConfiguration {
    /// Issuer identifier (`http://<bound address>`).
    issuer: String,
    /// URL of the JWKS endpoint.
    jwks_uri: String,
}
243
244async fn handle_openid_config(
246 State(context): State<Arc<OidcMockContext>>,
247) -> Json<OpenIdConfiguration> {
248 Json(OpenIdConfiguration {
249 issuer: context.issuer.clone(),
250 jwks_uri: format!("{}/.well-known/jwks.json", context.issuer),
251 })
252}
253
254fn create_jwk(kid: &str, rsa: &Rsa<Private>) -> Jwk {
256 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
257 let n = rsa.n().to_vec();
258 let e = rsa.e().to_vec();
259
260 Jwk {
261 kty: "RSA".to_string(),
262 kid: kid.to_string(),
263 key_use: "sig".to_string(),
264 alg: "RS256".to_string(),
265 n: engine.encode(n),
266 e: engine.encode(e),
267 }
268}