Skip to main content

mz_oidc_mock/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC mock server for testing.
11//!
12//! This module provides a mock OIDC server that serves JWKS endpoints
13//! for validating JWT tokens in tests.
14
15use std::borrow::Cow;
16use std::collections::BTreeMap;
17use std::future::IntoFuture;
18use std::net::{IpAddr, Ipv4Addr, SocketAddr};
19use std::sync::Arc;
20
21use axum::extract::State;
22use axum::routing::get;
23use axum::{Json, Router};
24use base64::Engine;
25use jsonwebtoken::{EncodingKey, Header, encode};
26use mz_ore::now::NowFn;
27use mz_ore::task::JoinHandle;
28use openssl::pkey::{PKey, Private};
29use openssl::rsa::Rsa;
30use serde::{Deserialize, Serialize};
31use tokio::net::TcpListener;
32
/// JWKS response structure.
///
/// Serialized as the JSON body of the `/.well-known/jwks.json` route. The
/// mock server always publishes exactly one key in it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwkSet {
    /// The keys published by the server.
    pub keys: Vec<Jwk>,
}
38
/// JSON Web Key structure.
///
/// Describes the public half of the RSA signing key as it appears in the
/// JWKS document.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Jwk {
    /// Key type; `create_jwk` sets this to `"RSA"`.
    pub kty: String,
    /// Key ID; the same value is placed in the `kid` header of generated JWTs.
    pub kid: String,
    /// Intended use of the key. Serialized under the JSON name `use`, which
    /// cannot be a Rust field name because `use` is a keyword.
    #[serde(rename = "use")]
    pub key_use: String,
    /// Signature algorithm the key is used with; `create_jwk` sets `"RS256"`.
    pub alg: String,
    /// RSA modulus, base64url-encoded without padding.
    pub n: String,
    /// RSA public exponent, base64url-encoded without padding.
    pub e: String,
}
50
/// Shared context for the OIDC mock server.
///
/// Built once in `OidcMockServer::start` and handed to the route handlers
/// behind an `Arc` as axum state.
struct OidcMockContext {
    /// The issuer URL (base URL of this server).
    issuer: String,
    /// RSA public key in JWK format.
    jwk: Jwk,
}
58
/// Audience claim value: either a single string or a list of strings.
///
/// Serializes as a JSON string when `Single`, and as a JSON array when `Multiple`.
/// The enum is `untagged`, so no variant name appears in the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AudClaim {
    /// A single audience, e.g. `"aud": "client-id"`.
    Single(String),
    /// A list of audiences, e.g. `"aud": ["a", "b"]`.
    Multiple(Vec<String>),
}
68
69impl Default for AudClaim {
70    fn default() -> Self {
71        AudClaim::Multiple(vec![])
72    }
73}
74
/// Claims struct used for JWT encoding in the mock server.
///
/// The named fields are the claims the mock always controls; everything else
/// (including `sub`) travels in `unknown_claims`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MockClaims {
    /// Issuer (`iss`) claim.
    iss: String,
    /// Expiration time (`exp`) claim, in seconds.
    exp: i64,
    /// Issued-at (`iat`) claim, in seconds. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    iat: Option<i64>,
    /// Audience (`aud`) claim.
    aud: AudClaim,
    /// All remaining claims, flattened into the top level of the claims
    /// object instead of being nested under an `unknown_claims` key.
    #[serde(flatten)]
    unknown_claims: BTreeMap<String, serde_json::Value>,
}
86
/// Options for generating JWT tokens.
///
/// Every field is optional; `GenerateJwtOptions::default()` leaves them all
/// unset so that the server's defaults apply.
#[derive(Debug, Clone, Default)]
pub struct GenerateJwtOptions<'a> {
    /// Optional email claim.
    pub email: Option<&'a str>,
    /// Custom expiration time (the `exp` claim value, in seconds). If None,
    /// the current time plus the server's `expires_in_secs` is used.
    pub exp: Option<i64>,
    /// Custom issuer. If None, uses server's issuer.
    pub issuer: Option<&'a str>,
    /// Audience claim. If None, uses empty array.
    /// Use `AudClaim::Single` for a single string or `AudClaim::Multiple` for an array.
    pub aud: Option<AudClaim>,
    /// Additional claims as arbitrary JSON values (e.g., arrays for group claims).
    ///
    /// These are merged into the token's top-level claims after `sub` has
    /// been set, so an entry keyed `sub` replaces the subject.
    pub extra_claims: Option<BTreeMap<String, serde_json::Value>>,
}
102
/// OIDC mock server for testing.
///
/// Returned by [`OidcMockServer::start`]. It holds what is needed to mint
/// test tokens, together with the handle of the spawned HTTP server task.
pub struct OidcMockServer {
    /// The issuer URL. Used as the base URL of the server
    /// and as the issuer for JWT iss claim.
    ///
    /// Has the form `http://<bound socket address>`.
    pub issuer: String,
    /// Key ID used in JWT headers.
    pub kid: String,
    /// Encoding key for signing JWTs (for generating test tokens).
    pub encoding_key: EncodingKey,
    /// Function for getting current time.
    ///
    /// `generate_jwt` divides its result by 1000 to obtain seconds, i.e. it
    /// is treated as a millisecond timestamp.
    pub now: NowFn,
    /// How long tokens should be valid (in seconds).
    pub expires_in_secs: i64,
    /// Handle to the server task.
    pub handle: JoinHandle<Result<(), std::io::Error>>,
}
119
120impl OidcMockServer {
121    /// Starts an [`OidcMockServer`].
122    ///
123    /// Must be started from within a [`tokio::runtime::Runtime`].
124    ///
125    /// # Arguments
126    ///
127    /// * `addr` - Optional address to bind to. If None, binds to localhost on a random port.
128    /// * `encoding_key` - PEM-encoded RSA private key string for signing JWTs.
129    /// * `kid` - Key ID to use in JWT headers and JWKS.
130    /// * `now` - Function for getting current time.
131    /// * `expires_in_secs` - How long tokens should be valid.
132    pub async fn start(
133        addr: Option<&SocketAddr>,
134        encoding_key: String,
135        kid: String,
136        now: NowFn,
137        expires_in_secs: i64,
138    ) -> Result<OidcMockServer, anyhow::Error> {
139        // Convert PEM string to key.
140        let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?;
141
142        // Parse the private key PEM to extract RSA components for JWKS.
143        let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?;
144        let rsa = pkey.rsa().expect("pkey should be RSA");
145
146        let addr = match addr {
147            Some(addr) => Cow::Borrowed(addr),
148            None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
149        };
150
151        let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
152            panic!("error binding to {}: {}", addr, e);
153        });
154        let issuer = format!("http://{}", listener.local_addr().unwrap());
155
156        // Extract RSA public key components from the decoding key
157        // We need to serialize the public key to get n and e values
158        let jwk = create_jwk(&kid, &rsa);
159
160        let context = Arc::new(OidcMockContext {
161            issuer: issuer.clone(),
162            jwk,
163        });
164
165        let router = Router::new()
166            .route("/.well-known/jwks.json", get(handle_jwks))
167            .route(
168                "/.well-known/openid-configuration",
169                get(handle_openid_config),
170            )
171            .with_state(context);
172
173        let server = axum::serve(
174            listener,
175            router.into_make_service_with_connect_info::<SocketAddr>(),
176        );
177        println!("oidc-mock listening...");
178        println!(" HTTP address: {}", issuer);
179        let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future());
180
181        Ok(OidcMockServer {
182            issuer,
183            kid,
184            encoding_key: encoding_key_typed,
185            now,
186            expires_in_secs,
187            handle,
188        })
189    }
190
191    /// Generates a JWT token for testing.
192    ///
193    /// # Arguments
194    ///
195    /// * `sub` - Subject (user identifier).
196    /// * `opts` - Optional JWT generation options. Use `Default::default()` for defaults.
197    pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String {
198        let now_ms = (self.now)();
199        let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64");
200        let sub_claim_map = BTreeMap::from([("sub".to_string(), sub.to_string())]);
201
202        let mut unknown_claims: BTreeMap<String, serde_json::Value> = sub_claim_map
203            .into_iter()
204            .map(|(k, v)| (k, serde_json::Value::String(v)))
205            .collect();
206
207        if let Some(extra) = opts.extra_claims {
208            unknown_claims.extend(extra);
209        }
210
211        let claims = MockClaims {
212            iss: opts.issuer.unwrap_or(&self.issuer).to_string(),
213            exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs),
214            iat: Some(now_secs),
215            aud: opts.aud.unwrap_or_default(),
216            unknown_claims,
217        };
218
219        let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
220        header.kid = Some(self.kid.clone());
221
222        encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT")
223    }
224
225    /// Returns the JWKS URL for this server.
226    pub fn jwks_url(&self) -> String {
227        format!("{}/.well-known/jwks.json", self.issuer)
228    }
229}
230
231/// Handler for JWKS endpoint.
232async fn handle_jwks(State(context): State<Arc<OidcMockContext>>) -> Json<JwkSet> {
233    Json(JwkSet {
234        keys: vec![context.jwk.clone()],
235    })
236}
237
/// OpenID Configuration response.
///
/// Contains only the `issuer` and `jwks_uri` fields.
#[derive(Serialize)]
struct OpenIdConfiguration {
    /// Issuer URL of this server.
    issuer: String,
    /// URL of this server's JWKS endpoint.
    jwks_uri: String,
}
244
245/// Handler for OpenID Configuration endpoint.
246async fn handle_openid_config(
247    State(context): State<Arc<OidcMockContext>>,
248) -> Json<OpenIdConfiguration> {
249    Json(OpenIdConfiguration {
250        issuer: context.issuer.clone(),
251        jwks_uri: format!("{}/.well-known/jwks.json", context.issuer),
252    })
253}
254
255/// Creates a JWK from RSA key components.
256fn create_jwk(kid: &str, rsa: &Rsa<Private>) -> Jwk {
257    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
258    let n = rsa.n().to_vec();
259    let e = rsa.e().to_vec();
260
261    Jwk {
262        kty: "RSA".to_string(),
263        kid: kid.to_string(),
264        key_use: "sig".to_string(),
265        alg: "RS256".to_string(),
266        n: engine.encode(n),
267        e: engine.encode(e),
268    }
269}