1use std::borrow::Cow;
16use std::future::IntoFuture;
17use std::net::{IpAddr, Ipv4Addr, SocketAddr};
18use std::sync::Arc;
19
20use axum::extract::State;
21use axum::routing::get;
22use axum::{Json, Router};
23use base64::Engine;
24use jsonwebtoken::{EncodingKey, Header, encode};
25use mz_authenticator::OidcClaims;
26use mz_ore::now::NowFn;
27use mz_ore::task::JoinHandle;
28use openssl::pkey::{PKey, Private};
29use openssl::rsa::Rsa;
30use serde::{Deserialize, Serialize};
31use tokio::net::TcpListener;
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35struct JwkSet {
36 pub keys: Vec<Jwk>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41struct Jwk {
42 pub kty: String,
43 pub kid: String,
44 #[serde(rename = "use")]
45 pub key_use: String,
46 pub alg: String,
47 pub n: String,
48 pub e: String,
49}
50
51struct OidcMockContext {
53 issuer: String,
55 jwk: Jwk,
57}
58
/// Optional overrides for the claims produced by [`OidcMockServer::generate_jwt`].
///
/// Every field defaults to `None`, which means "use the server's default".
#[derive(Clone, Debug, Default)]
pub struct GenerateJwtOptions<'a> {
    /// Value of the `email` claim; left unset when `None`.
    pub email: Option<&'a str>,
    /// Explicit `exp` claim in seconds since the Unix epoch; when `None`, the current
    /// time plus the server's configured lifetime is used.
    pub exp: Option<i64>,
    /// Explicit `iss` claim; when `None`, the server's own issuer URL is used.
    pub issuer: Option<&'a str>,
    /// Values of the `aud` claim; when `None`, an empty list is used.
    pub aud: Option<Vec<String>>,
}
71
72pub struct OidcMockServer {
74 pub issuer: String,
77 pub kid: String,
79 pub encoding_key: EncodingKey,
81 pub now: NowFn,
83 pub expires_in_secs: i64,
85 pub handle: JoinHandle<Result<(), std::io::Error>>,
87}
88
89impl OidcMockServer {
90 pub async fn start(
102 addr: Option<&SocketAddr>,
103 encoding_key: String,
104 kid: String,
105 now: NowFn,
106 expires_in_secs: i64,
107 ) -> Result<OidcMockServer, anyhow::Error> {
108 let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?;
110
111 let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?;
113 let rsa = pkey.rsa().expect("pkey should be RSA");
114
115 let addr = match addr {
116 Some(addr) => Cow::Borrowed(addr),
117 None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
118 };
119
120 let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
121 panic!("error binding to {}: {}", addr, e);
122 });
123 let issuer = format!("http://{}", listener.local_addr().unwrap());
124
125 let jwk = create_jwk(&kid, &rsa);
128
129 let context = Arc::new(OidcMockContext {
130 issuer: issuer.clone(),
131 jwk,
132 });
133
134 let router = Router::new()
135 .route("/.well-known/jwks.json", get(handle_jwks))
136 .route(
137 "/.well-known/openid-configuration",
138 get(handle_openid_config),
139 )
140 .with_state(context);
141
142 let server = axum::serve(
143 listener,
144 router.into_make_service_with_connect_info::<SocketAddr>(),
145 );
146 println!("oidc-mock listening...");
147 println!(" HTTP address: {}", issuer);
148 let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future());
149
150 Ok(OidcMockServer {
151 issuer,
152 kid,
153 encoding_key: encoding_key_typed,
154 now,
155 expires_in_secs,
156 handle,
157 })
158 }
159
160 pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String {
167 let now_ms = (self.now)();
168 let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64");
169
170 let claims = OidcClaims {
171 sub: sub.to_string(),
172 iss: opts.issuer.unwrap_or(&self.issuer).to_string(),
173 exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs),
174 iat: Some(now_secs),
175 email: opts.email.map(|s| s.to_string()),
176 aud: opts.aud.unwrap_or_default(),
177 };
178
179 let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
180 header.kid = Some(self.kid.clone());
181
182 encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT")
183 }
184
185 pub fn jwks_url(&self) -> String {
187 format!("{}/.well-known/jwks.json", self.issuer)
188 }
189}
190
191async fn handle_jwks(State(context): State<Arc<OidcMockContext>>) -> Json<JwkSet> {
193 Json(JwkSet {
194 keys: vec![context.jwk.clone()],
195 })
196}
197
198#[derive(Serialize)]
200struct OpenIdConfiguration {
201 issuer: String,
202 jwks_uri: String,
203}
204
205async fn handle_openid_config(
207 State(context): State<Arc<OidcMockContext>>,
208) -> Json<OpenIdConfiguration> {
209 Json(OpenIdConfiguration {
210 issuer: context.issuer.clone(),
211 jwks_uri: format!("{}/.well-known/jwks.json", context.issuer),
212 })
213}
214
215fn create_jwk(kid: &str, rsa: &Rsa<Private>) -> Jwk {
217 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
218 let n = rsa.n().to_vec();
219 let e = rsa.e().to_vec();
220
221 Jwk {
222 kty: "RSA".to_string(),
223 kid: kid.to_string(),
224 key_use: "sig".to_string(),
225 alg: "RS256".to_string(),
226 n: engine.encode(n),
227 e: engine.encode(e),
228 }
229}