1use anyhow::{anyhow, bail};
11use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "signing")]
15mod signing;
16#[cfg(feature = "signing")]
17pub use signing::{get_pubkey_pem, make_license_key};
18
/// Required value of the `iss` claim in every license key.
const ISSUER: &str = "Materialize, Inc.";
/// PEM-encoded RSA public keys accepted for verifying license key signatures.
///
/// [`validate`] tries them in order and returns the first successful validation.
const PUBLIC_KEYS: &[&str] = &[include_str!("license_keys/production.pub")];
/// `jti` values of license keys that are rejected even when otherwise valid.
const REVOKED_KEYS: &[&str] = &["eddaf004-dc1e-48cf-9cc1-41d1543d940a"];
26
/// How a license key is treated once it has expired.
///
/// Within this module, an expired key with `DisableClusterCreation` or `Disable` makes
/// [`ValidatedLicenseKey::max_credit_consumption_rate`] report `Some(0.0)`; an expired key
/// with `Warn` does not.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum ExpirationBehavior {
    /// Expiry does not zero the credit consumption rate.
    Warn,
    /// Expiry zeroes the credit consumption rate.
    DisableClusterCreation,
    /// Expiry zeroes the credit consumption rate.
    Disable,
}
34
/// A license key whose signature and claims have been checked.
///
/// [`validate`] builds this from the claims of a verified token. The constructors
/// [`ValidatedLicenseKey::disabled`] and [`ValidatedLicenseKey::for_tests`], and the
/// [`Default`] impl, produce values that do not come from a token.
#[derive(Debug, Clone)]
pub struct ValidatedLicenseKey {
    /// Unique identifier of the key (the token's `jti` claim).
    pub id: String,
    /// Organization the key was issued to (the token's `sub` claim).
    pub organization: String,
    /// Environment the key was issued for (the token's `aud` claim).
    pub environment_id: String,
    /// Expiration time (the token's `exp` claim).
    pub expiration: u64,
    /// Start of validity (the token's `nbf` claim).
    pub not_before: u64,

    /// Credit consumption rate limit granted by the key.
    pub max_credit_consumption_rate: f64,
    /// Whether the limit may be overridden. When `true`,
    /// [`ValidatedLicenseKey::max_credit_consumption_rate`] returns `None` unless the key
    /// has expired with a zeroing [`ExpirationBehavior`].
    pub allow_credit_consumption_override: bool,
    /// How the key is treated once expired.
    pub expiration_behavior: ExpirationBehavior,
    /// Whether signature validation reported the key as expired.
    pub expired: bool,
    /// Entitlements granted by the key; see [`ValidatedLicenseKey::has_entitlement`].
    pub entitlements: Vec<String>,
}
52
53impl ValidatedLicenseKey {
54 pub fn for_tests() -> Self {
55 Self {
56 id: "".to_string(),
57 organization: "".to_string(),
58 environment_id: "".to_string(),
59 expiration: 0,
60 not_before: 0,
61
62 max_credit_consumption_rate: 999999.0,
63 allow_credit_consumption_override: true,
64 expiration_behavior: ExpirationBehavior::Warn,
65 expired: false,
66 entitlements: Vec::new(),
67 }
68 }
69
70 pub fn disabled() -> Self {
71 Self {
72 id: "".to_string(),
73 organization: "".to_string(),
74 environment_id: "".to_string(),
75 expiration: 0,
76 not_before: 0,
77
78 max_credit_consumption_rate: 999999.0,
79 allow_credit_consumption_override: true,
80 expiration_behavior: ExpirationBehavior::Warn,
81 expired: false,
82 entitlements: Vec::new(),
83 }
84 }
85
86 pub fn has_entitlement(&self, entitlement: &str) -> bool {
88 self.entitlements.iter().any(|e| e == entitlement)
89 }
90
91 pub fn max_credit_consumption_rate(&self) -> Option<f64> {
92 if self.expired
93 && matches!(
94 self.expiration_behavior,
95 ExpirationBehavior::DisableClusterCreation | ExpirationBehavior::Disable
96 )
97 {
98 Some(0.0)
99 } else if self.allow_credit_consumption_override {
100 None
101 } else {
102 Some(self.max_credit_consumption_rate)
103 }
104 }
105}
106
107impl Default for ValidatedLicenseKey {
108 fn default() -> Self {
109 Self {
110 id: "".to_string(),
111 organization: "".to_string(),
112 environment_id: "".to_string(),
113 expiration: 0,
114 not_before: 0,
115
116 max_credit_consumption_rate: 24.0,
117 allow_credit_consumption_override: false,
118 expiration_behavior: ExpirationBehavior::Disable,
119 expired: false,
120 entitlements: Vec::new(),
121 }
122 }
123}
124
125pub fn validate(license_key: &str) -> anyhow::Result<ValidatedLicenseKey> {
126 let mut err = None;
127 for pubkey in PUBLIC_KEYS {
128 match validate_with_pubkey(license_key, pubkey) {
129 Ok(key) => {
130 return Ok(key);
131 }
132 Err(e) => {
133 err = Some(e);
134 }
135 }
136 }
137
138 if let Some(err) = err {
139 Err(err)
140 } else {
141 Err(anyhow!("no public key found"))
142 }
143}
144
145fn validate_with_pubkey(
146 license_key: &str,
147 pubkey_pem: &str,
148) -> anyhow::Result<ValidatedLicenseKey> {
149 let res = validate_with_pubkey_v1(license_key, pubkey_pem);
157 let err = match res {
158 Ok(key) => return Ok(key),
159 Err(e) => e,
160 };
161
162 let previous_versions: Vec<Box<dyn Fn() -> anyhow::Result<ValidatedLicenseKey>>> = vec![
163 ];
167 for validator in previous_versions {
168 if let Ok(key) = validator() {
169 return Ok(key);
170 }
171 }
172
173 Err(err)
174}
175
/// Claims carried by a version-1 license key JWT.
///
/// Field order and the serde attributes below determine the serialized form of the
/// payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Payload {
    /// Organization the key is issued to; surfaced as `ValidatedLicenseKey::organization`.
    sub: String,
    /// Expiration time. Expired keys still validate but are flagged as expired.
    exp: u64,
    /// Not-before time; validated during decoding.
    nbf: u64,
    /// Issuer; must equal [`ISSUER`].
    iss: String,
    /// Environment the key is issued for. It is not validated during decoding and is
    /// surfaced as `ValidatedLicenseKey::environment_id`.
    aud: String,
    /// Issued-at time; must lie within `nbf..=exp`.
    iat: u64,
    /// Unique key id; keys listed in [`REVOKED_KEYS`] are rejected.
    jti: String,

    /// Payload format version; the v1 validator accepts only `1`.
    version: u64,
    /// Credit consumption rate limit granted by the key.
    max_credit_consumption_rate: f64,
    /// Whether the rate limit may be overridden. Omitted when serializing if `false`, and
    /// read as `false` when absent.
    #[serde(default, skip_serializing_if = "is_default")]
    allow_credit_consumption_override: bool,
    /// How the key is treated once expired.
    expiration_behavior: ExpirationBehavior,
    /// Entitlements granted by the key. Omitted when serializing if empty, and read as
    /// empty when absent, so payloads without this field still decode.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    entitlements: Vec<String>,
}
197
198fn validate_with_pubkey_v1(
199 license_key: &str,
200 pubkey_pem: &str,
201) -> anyhow::Result<ValidatedLicenseKey> {
202 let mut validation = Validation::new(Algorithm::PS256);
203 validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
204 validation.set_issuer(&[ISSUER]);
205 validation.validate_exp = true;
206 validation.validate_nbf = true;
207 validation.validate_aud = false;
208
209 let key = DecodingKey::from_rsa_pem(pubkey_pem.as_bytes())?;
210
211 let (jwt, expired): (TokenData<Payload>, _) =
212 jsonwebtoken::decode(license_key, &key, &validation).map_or_else(
213 |e| {
214 if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
215 validation.validate_exp = false;
216 Ok((jsonwebtoken::decode(license_key, &key, &validation)?, true))
217 } else {
218 Err::<_, anyhow::Error>(e.into())
219 }
220 },
221 |jwt| Ok((jwt, false)),
222 )?;
223
224 if jwt.header.typ.as_deref() != Some("JWT") {
225 bail!("invalid jwt header type");
226 }
227
228 if jwt.claims.version != 1 {
229 bail!("invalid license key version");
230 }
231
232 if !(jwt.claims.nbf..=jwt.claims.exp).contains(&jwt.claims.iat) {
233 bail!("invalid issuance time");
234 }
235
236 if REVOKED_KEYS.contains(&jwt.claims.jti.as_str()) {
237 bail!("revoked license key");
238 }
239
240 Ok(ValidatedLicenseKey {
241 id: jwt.claims.jti,
242 organization: jwt.claims.sub,
243 environment_id: jwt.claims.aud,
244 expiration: jwt.claims.exp,
245 not_before: jwt.claims.nbf,
246
247 max_credit_consumption_rate: jwt.claims.max_credit_consumption_rate,
248 allow_credit_consumption_override: jwt.claims.allow_credit_consumption_override,
249 expiration_behavior: jwt.claims.expiration_behavior,
250 expired,
251 entitlements: jwt.claims.entitlements,
252 })
253}
254
/// Returns `true` if `val` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that fields holding their default
/// value are omitted from the serialized payload. Only `PartialEq` is required, so it also
/// works for types without a total equality, such as `f64`.
fn is_default<T: PartialEq + Default>(val: &T) -> bool {
    *val == T::default()
}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a version-1 payload with fixed claims and the given entitlements.
    fn sample_payload(entitlements: Vec<String>) -> Payload {
        Payload {
            sub: "org-1".to_string(),
            exp: 200,
            nbf: 100,
            iss: ISSUER.to_string(),
            aud: "env-1".to_string(),
            iat: 100,
            jti: "jti-1".to_string(),
            version: 1,
            max_credit_consumption_rate: 10.0,
            allow_credit_consumption_override: false,
            expiration_behavior: ExpirationBehavior::Warn,
            entitlements,
        }
    }

    #[mz_ore::test]
    fn entitlements_roundtrip_through_payload() {
        let payload = sample_payload(vec!["ory".to_string(), "foo".to_string()]);
        let json = serde_json::to_string(&payload).unwrap();
        assert!(
            json.contains(r#""entitlements":["ory","foo"]"#),
            "payload should serialize entitlements: {json}"
        );

        let decoded: Payload = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.entitlements, vec!["ory", "foo"]);
    }

    #[mz_ore::test]
    fn empty_entitlements_omitted_from_payload() {
        let payload = sample_payload(Vec::new());
        let json = serde_json::to_string(&payload).unwrap();
        // `skip_serializing_if = "Vec::is_empty"` should drop the field entirely.
        assert!(
            !json.contains("entitlements"),
            "empty entitlements should be skipped: {json}"
        );
    }

    #[mz_ore::test]
    fn legacy_payload_without_entitlements_decodes() {
        // A payload with no `entitlements` key must still decode, yielding an empty list.
        let legacy = serde_json::json!({
            "sub": "org-1",
            "exp": 200,
            "nbf": 100,
            "iss": ISSUER,
            "aud": "env-1",
            "iat": 100,
            "jti": "jti-1",
            "version": 1,
            "max_credit_consumption_rate": 10.0,
            "expiration_behavior": "Warn",
        });
        let decoded: Payload = serde_json::from_value(legacy).unwrap();
        assert!(decoded.entitlements.is_empty());
    }

    #[mz_ore::test]
    fn has_entitlement_matches_exact_names() {
        let key = ValidatedLicenseKey {
            entitlements: vec!["ory".to_string()],
            ..Default::default()
        };
        assert!(key.has_entitlement("ory"));
        assert!(!key.has_entitlement("Ory"));
        assert!(!key.has_entitlement("foo"));
        assert!(!ValidatedLicenseKey::default().has_entitlement("ory"));
    }

    #[mz_ore::test]
    fn max_credit_consumption_rate_honors_expiration_and_override() {
        let mut key = ValidatedLicenseKey {
            max_credit_consumption_rate: 10.0,
            allow_credit_consumption_override: false,
            expiration_behavior: ExpirationBehavior::Warn,
            ..Default::default()
        };
        assert_eq!(key.max_credit_consumption_rate(), Some(10.0));

        key.allow_credit_consumption_override = true;
        assert_eq!(key.max_credit_consumption_rate(), None);

        // `Warn` does not zero the rate, even once expired.
        key.expired = true;
        assert_eq!(key.max_credit_consumption_rate(), None);

        // The zeroing behaviors win over the override flag once expired.
        key.expiration_behavior = ExpirationBehavior::DisableClusterCreation;
        assert_eq!(key.max_credit_consumption_rate(), Some(0.0));
        key.expiration_behavior = ExpirationBehavior::Disable;
        assert_eq!(key.max_credit_consumption_rate(), Some(0.0));

        // ...but have no effect while the key has not expired.
        key.expired = false;
        key.allow_credit_consumption_override = false;
        assert_eq!(key.max_credit_consumption_rate(), Some(10.0));
    }
}