Skip to main content

mz_oidc_mock/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC mock server for testing.
11//!
12//! This module provides a mock OIDC server that serves JWKS endpoints
13//! for validating JWT tokens in tests.
14
15use std::borrow::Cow;
16use std::collections::BTreeMap;
17use std::future::IntoFuture;
18use std::net::{IpAddr, Ipv4Addr, SocketAddr};
19use std::sync::Arc;
20
21use axum::extract::State;
22use axum::routing::get;
23use axum::{Json, Router};
24use base64::Engine;
25use jsonwebtoken::{EncodingKey, Header, encode};
26use mz_ore::now::NowFn;
27use mz_ore::task::JoinHandle;
28use openssl::pkey::{PKey, Private};
29use openssl::rsa::Rsa;
30use serde::{Deserialize, Serialize};
31use tokio::net::TcpListener;
32
33/// JWKS response structure.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35struct JwkSet {
36    pub keys: Vec<Jwk>,
37}
38
39/// JSON Web Key structure.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41struct Jwk {
42    pub kty: String,
43    pub kid: String,
44    #[serde(rename = "use")]
45    pub key_use: String,
46    pub alg: String,
47    pub n: String,
48    pub e: String,
49}
50
51/// Shared context for the OIDC mock server.
52struct OidcMockContext {
53    /// The issuer URL (base URL of this server).
54    issuer: String,
55    /// RSA public key in JWK format.
56    jwk: Jwk,
57}
58
59/// Audience claim value: either a single string or a list of strings.
60///
61/// Serializes as a JSON string when `Single`, and as a JSON array when `Multiple`.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(untagged)]
64pub enum AudClaim {
65    Single(String),
66    Multiple(Vec<String>),
67}
68
69impl Default for AudClaim {
70    fn default() -> Self {
71        AudClaim::Multiple(vec![])
72    }
73}
74
75/// Claims struct used for JWT encoding in the mock server.
76#[derive(Debug, Clone, Serialize, Deserialize)]
77struct MockClaims {
78    iss: String,
79    exp: i64,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    iat: Option<i64>,
82    aud: AudClaim,
83    #[serde(flatten)]
84    unknown_claims: BTreeMap<String, serde_json::Value>,
85}
86
87/// Options for generating JWT tokens.
88#[derive(Debug, Clone, Default)]
89pub struct GenerateJwtOptions<'a> {
90    /// Optional email claim.
91    pub email: Option<&'a str>,
92    /// Custom expiration time. If None, uses server's default expires_in_secs.
93    pub exp: Option<i64>,
94    /// Custom issuer. If None, uses server's issuer.
95    pub issuer: Option<&'a str>,
96    /// Audience claim. If None, uses empty array.
97    /// Use `AudClaim::Single` for a single string or `AudClaim::Multiple` for an array.
98    pub aud: Option<AudClaim>,
99    /// Additional claims.
100    pub unknown_claims: Option<BTreeMap<String, String>>,
101}
102
103/// OIDC mock server for testing.
104pub struct OidcMockServer {
105    /// The issuer URL. Used as the base URL of the server
106    /// and as the issuer for JWT iss claim.
107    pub issuer: String,
108    /// Key ID used in JWT headers.
109    pub kid: String,
110    /// Encoding key for signing JWTs (for generating test tokens).
111    pub encoding_key: EncodingKey,
112    /// Function for getting current time.
113    pub now: NowFn,
114    /// How long tokens should be valid (in seconds).
115    pub expires_in_secs: i64,
116    /// Handle to the server task.
117    pub handle: JoinHandle<Result<(), std::io::Error>>,
118}
119
120impl OidcMockServer {
121    /// Starts an [`OidcMockServer`].
122    ///
123    /// Must be started from within a [`tokio::runtime::Runtime`].
124    ///
125    /// # Arguments
126    ///
127    /// * `addr` - Optional address to bind to. If None, binds to localhost on a random port.
128    /// * `encoding_key` - PEM-encoded RSA private key string for signing JWTs.
129    /// * `kid` - Key ID to use in JWT headers and JWKS.
130    /// * `now` - Function for getting current time.
131    /// * `expires_in_secs` - How long tokens should be valid.
132    pub async fn start(
133        addr: Option<&SocketAddr>,
134        encoding_key: String,
135        kid: String,
136        now: NowFn,
137        expires_in_secs: i64,
138    ) -> Result<OidcMockServer, anyhow::Error> {
139        // Convert PEM string to key.
140        let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?;
141
142        // Parse the private key PEM to extract RSA components for JWKS.
143        let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?;
144        let rsa = pkey.rsa().expect("pkey should be RSA");
145
146        let addr = match addr {
147            Some(addr) => Cow::Borrowed(addr),
148            None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
149        };
150
151        let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
152            panic!("error binding to {}: {}", addr, e);
153        });
154        let issuer = format!("http://{}", listener.local_addr().unwrap());
155
156        // Extract RSA public key components from the decoding key
157        // We need to serialize the public key to get n and e values
158        let jwk = create_jwk(&kid, &rsa);
159
160        let context = Arc::new(OidcMockContext {
161            issuer: issuer.clone(),
162            jwk,
163        });
164
165        let router = Router::new()
166            .route("/.well-known/jwks.json", get(handle_jwks))
167            .route(
168                "/.well-known/openid-configuration",
169                get(handle_openid_config),
170            )
171            .with_state(context);
172
173        let server = axum::serve(
174            listener,
175            router.into_make_service_with_connect_info::<SocketAddr>(),
176        );
177        println!("oidc-mock listening...");
178        println!(" HTTP address: {}", issuer);
179        let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future());
180
181        Ok(OidcMockServer {
182            issuer,
183            kid,
184            encoding_key: encoding_key_typed,
185            now,
186            expires_in_secs,
187            handle,
188        })
189    }
190
191    /// Generates a JWT token for testing.
192    ///
193    /// # Arguments
194    ///
195    /// * `sub` - Subject (user identifier).
196    /// * `opts` - Optional JWT generation options. Use `Default::default()` for defaults.
197    pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String {
198        let now_ms = (self.now)();
199        let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64");
200        let sub_claim_map = BTreeMap::from([("sub".to_string(), sub.to_string())]);
201
202        let unknown_claims = opts
203            .unknown_claims
204            .unwrap_or_default()
205            .into_iter()
206            .chain(sub_claim_map)
207            .map(|(k, v)| (k, serde_json::Value::String(v.to_string())))
208            .collect();
209
210        let claims = MockClaims {
211            iss: opts.issuer.unwrap_or(&self.issuer).to_string(),
212            exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs),
213            iat: Some(now_secs),
214            aud: opts.aud.unwrap_or_default(),
215            unknown_claims,
216        };
217
218        let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
219        header.kid = Some(self.kid.clone());
220
221        encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT")
222    }
223
224    /// Returns the JWKS URL for this server.
225    pub fn jwks_url(&self) -> String {
226        format!("{}/.well-known/jwks.json", self.issuer)
227    }
228}
229
230/// Handler for JWKS endpoint.
231async fn handle_jwks(State(context): State<Arc<OidcMockContext>>) -> Json<JwkSet> {
232    Json(JwkSet {
233        keys: vec![context.jwk.clone()],
234    })
235}
236
237/// OpenID Configuration response.
238#[derive(Serialize)]
239struct OpenIdConfiguration {
240    issuer: String,
241    jwks_uri: String,
242}
243
244/// Handler for OpenID Configuration endpoint.
245async fn handle_openid_config(
246    State(context): State<Arc<OidcMockContext>>,
247) -> Json<OpenIdConfiguration> {
248    Json(OpenIdConfiguration {
249        issuer: context.issuer.clone(),
250        jwks_uri: format!("{}/.well-known/jwks.json", context.issuer),
251    })
252}
253
254/// Creates a JWK from RSA key components.
255fn create_jwk(kid: &str, rsa: &Rsa<Private>) -> Jwk {
256    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
257    let n = rsa.n().to_vec();
258    let e = rsa.e().to_vec();
259
260    Jwk {
261        kty: "RSA".to_string(),
262        kid: kid.to_string(),
263        key_use: "sig".to_string(),
264        alg: "RS256".to_string(),
265        n: engine.encode(n),
266        e: engine.encode(e),
267    }
268}