Skip to main content

mz_oidc_mock/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! OIDC mock server for testing.
11//!
12//! This module provides a mock OIDC server that serves JWKS endpoints
13//! for validating JWT tokens in tests.
14
15use std::borrow::Cow;
16use std::future::IntoFuture;
17use std::net::{IpAddr, Ipv4Addr, SocketAddr};
18use std::sync::Arc;
19
20use axum::extract::State;
21use axum::routing::get;
22use axum::{Json, Router};
23use base64::Engine;
24use jsonwebtoken::{EncodingKey, Header, encode};
25use mz_authenticator::OidcClaims;
26use mz_ore::now::NowFn;
27use mz_ore::task::JoinHandle;
28use openssl::pkey::{PKey, Private};
29use openssl::rsa::Rsa;
30use serde::{Deserialize, Serialize};
31use tokio::net::TcpListener;
32
/// JWKS (JSON Web Key Set) response body, served at `/.well-known/jwks.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JwkSet {
    /// The keys in the set. The mock server's JWKS handler always returns
    /// exactly one key: the public half of its signing key.
    pub keys: Vec<Jwk>,
}
38
/// JSON Web Key structure.
///
/// Only the members needed to describe an RSA public signing key are
/// modelled. See [`create_jwk`] for how they are populated. Field order here
/// is the order of the members in the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Jwk {
    /// Key type; `create_jwk` always sets `"RSA"`.
    pub kty: String,
    /// Key ID. The server puts the same value in the `kid` header of the
    /// JWTs it generates.
    pub kid: String,
    /// Intended key use. It is serialized as the JSON member `use`, which
    /// cannot be a Rust field name because `use` is a keyword.
    #[serde(rename = "use")]
    pub key_use: String,
    /// Algorithm the key is intended for; `create_jwk` always sets `"RS256"`.
    pub alg: String,
    /// RSA modulus, encoded as unpadded URL-safe base64.
    pub n: String,
    /// RSA public exponent, encoded as unpadded URL-safe base64.
    pub e: String,
}
50
/// Shared context for the OIDC mock server.
///
/// It is wrapped in an `Arc` and handed to every HTTP handler through axum's
/// `State` extractor. The handlers only read from it.
struct OidcMockContext {
    /// The issuer URL (`http://<bound address>`). It is also the base of the
    /// advertised `jwks_uri`.
    issuer: String,
    /// RSA public key in JWK format, served from the JWKS endpoint.
    jwk: Jwk,
}
58
/// Options for generating JWT tokens with [`OidcMockServer::generate_jwt`].
///
/// Every field defaults to `None`, so `Default::default()` produces a token
/// that uses the server's own issuer and expiry settings.
#[derive(Debug, Clone, Default)]
pub struct GenerateJwtOptions<'a> {
    /// Optional email claim. If None, the claims' `email` field is `None`.
    pub email: Option<&'a str>,
    /// Custom `exp` claim, in seconds (the same unit as the generated `iat`).
    /// If None, uses the current time plus the server's `expires_in_secs`.
    pub exp: Option<i64>,
    /// Custom `iss` claim. If None, uses the server's issuer URL.
    pub issuer: Option<&'a str>,
    /// Audience (`aud`) claim. If None, uses an empty vec.
    pub aud: Option<Vec<String>>,
}
71
/// OIDC mock server for testing.
///
/// Serves a JWKS endpoint and an OpenID configuration endpoint, and can mint
/// RS256-signed JWTs that validate against the served key.
pub struct OidcMockServer {
    /// The issuer URL (`http://<bound address>`). It is the base URL of the
    /// server and the default `iss` claim of generated JWTs.
    pub issuer: String,
    /// Key ID used in JWT headers and in the served JWK.
    pub kid: String,
    /// Encoding key for signing JWTs (for generating test tokens).
    pub encoding_key: EncodingKey,
    /// Function for getting the current time. `generate_jwt` treats its
    /// return value as milliseconds and divides by 1000 to get seconds.
    pub now: NowFn,
    /// How long generated tokens should be valid, in seconds.
    pub expires_in_secs: i64,
    /// Handle to the spawned HTTP server task.
    pub handle: JoinHandle<Result<(), std::io::Error>>,
}
88
89impl OidcMockServer {
90    /// Starts an [`OidcMockServer`].
91    ///
92    /// Must be started from within a [`tokio::runtime::Runtime`].
93    ///
94    /// # Arguments
95    ///
96    /// * `addr` - Optional address to bind to. If None, binds to localhost on a random port.
97    /// * `encoding_key` - PEM-encoded RSA private key string for signing JWTs.
98    /// * `kid` - Key ID to use in JWT headers and JWKS.
99    /// * `now` - Function for getting current time.
100    /// * `expires_in_secs` - How long tokens should be valid.
101    pub async fn start(
102        addr: Option<&SocketAddr>,
103        encoding_key: String,
104        kid: String,
105        now: NowFn,
106        expires_in_secs: i64,
107    ) -> Result<OidcMockServer, anyhow::Error> {
108        // Convert PEM string to key.
109        let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?;
110
111        // Parse the private key PEM to extract RSA components for JWKS.
112        let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?;
113        let rsa = pkey.rsa().expect("pkey should be RSA");
114
115        let addr = match addr {
116            Some(addr) => Cow::Borrowed(addr),
117            None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)),
118        };
119
120        let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| {
121            panic!("error binding to {}: {}", addr, e);
122        });
123        let issuer = format!("http://{}", listener.local_addr().unwrap());
124
125        // Extract RSA public key components from the decoding key
126        // We need to serialize the public key to get n and e values
127        let jwk = create_jwk(&kid, &rsa);
128
129        let context = Arc::new(OidcMockContext {
130            issuer: issuer.clone(),
131            jwk,
132        });
133
134        let router = Router::new()
135            .route("/.well-known/jwks.json", get(handle_jwks))
136            .route(
137                "/.well-known/openid-configuration",
138                get(handle_openid_config),
139            )
140            .with_state(context);
141
142        let server = axum::serve(
143            listener,
144            router.into_make_service_with_connect_info::<SocketAddr>(),
145        );
146        println!("oidc-mock listening...");
147        println!(" HTTP address: {}", issuer);
148        let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future());
149
150        Ok(OidcMockServer {
151            issuer,
152            kid,
153            encoding_key: encoding_key_typed,
154            now,
155            expires_in_secs,
156            handle,
157        })
158    }
159
160    /// Generates a JWT token for testing.
161    ///
162    /// # Arguments
163    ///
164    /// * `sub` - Subject (user identifier).
165    /// * `opts` - Optional JWT generation options. Use `Default::default()` for defaults.
166    pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String {
167        let now_ms = (self.now)();
168        let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64");
169
170        let claims = OidcClaims {
171            sub: sub.to_string(),
172            iss: opts.issuer.unwrap_or(&self.issuer).to_string(),
173            exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs),
174            iat: Some(now_secs),
175            email: opts.email.map(|s| s.to_string()),
176            aud: opts.aud.unwrap_or_default(),
177        };
178
179        let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
180        header.kid = Some(self.kid.clone());
181
182        encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT")
183    }
184
185    /// Returns the JWKS URL for this server.
186    pub fn jwks_url(&self) -> String {
187        format!("{}/.well-known/jwks.json", self.issuer)
188    }
189}
190
191/// Handler for JWKS endpoint.
192async fn handle_jwks(State(context): State<Arc<OidcMockContext>>) -> Json<JwkSet> {
193    Json(JwkSet {
194        keys: vec![context.jwk.clone()],
195    })
196}
197
/// OpenID Configuration response served at `/.well-known/openid-configuration`.
///
/// Only the members this mock needs are included. Field order is the
/// serialized JSON member order.
#[derive(Serialize)]
struct OpenIdConfiguration {
    /// The issuer URL (`http://<bound address>`).
    issuer: String,
    /// Absolute URL of the JWKS endpoint (`<issuer>/.well-known/jwks.json`).
    jwks_uri: String,
}
204
205/// Handler for OpenID Configuration endpoint.
206async fn handle_openid_config(
207    State(context): State<Arc<OidcMockContext>>,
208) -> Json<OpenIdConfiguration> {
209    Json(OpenIdConfiguration {
210        issuer: context.issuer.clone(),
211        jwks_uri: format!("{}/.well-known/jwks.json", context.issuer),
212    })
213}
214
215/// Creates a JWK from RSA key components.
216fn create_jwk(kid: &str, rsa: &Rsa<Private>) -> Jwk {
217    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
218    let n = rsa.n().to_vec();
219    let e = rsa.e().to_vec();
220
221    Jwk {
222        kty: "RSA".to_string(),
223        kid: kid.to_string(),
224        key_use: "sig".to_string(),
225        alg: "RS256".to_string(),
226        n: engine.encode(n),
227        e: engine.encode(e),
228    }
229}