Skip to main content

mz_license_keys/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use anyhow::{anyhow, bail};
11use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "signing")]
15mod signing;
16#[cfg(feature = "signing")]
17pub use signing::{get_pubkey_pem, make_license_key};
18
19const ISSUER: &str = "Materialize, Inc.";
20// list of public keys which are allowed to validate license keys. this is a
21// list to allow for key rotation if necessary.
22const PUBLIC_KEYS: &[&str] = &[include_str!("license_keys/production.pub")];
23// keys which we have issued but need to be revoked before their expiration
24// (due to being accidentally exposed or similar).
25const REVOKED_KEYS: &[&str] = &["eddaf004-dc1e-48cf-9cc1-41d1543d940a"];
26
27#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
28#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
29pub enum ExpirationBehavior {
30    Warn,
31    DisableClusterCreation,
32    Disable,
33}
34
35#[derive(Debug, Clone)]
36pub struct ValidatedLicenseKey {
37    pub id: String,
38    pub organization: String,
39    pub environment_id: String,
40    pub expiration: u64,
41    pub not_before: u64,
42
43    pub max_credit_consumption_rate: f64,
44    pub allow_credit_consumption_override: bool,
45    pub expiration_behavior: ExpirationBehavior,
46    pub expired: bool,
47    /// Optional feature flags / third-party integrations enabled for this key
48    /// (e.g. `"ory"` to permit pulling images through the OCI registry proxy).
49    /// Empty for keys that predate this field.
50    pub entitlements: Vec<String>,
51}
52
53impl ValidatedLicenseKey {
54    pub fn for_tests() -> Self {
55        Self {
56            id: "".to_string(),
57            organization: "".to_string(),
58            environment_id: "".to_string(),
59            expiration: 0,
60            not_before: 0,
61
62            max_credit_consumption_rate: 999999.0,
63            allow_credit_consumption_override: true,
64            expiration_behavior: ExpirationBehavior::Warn,
65            expired: false,
66            entitlements: Vec::new(),
67        }
68    }
69
70    pub fn disabled() -> Self {
71        Self {
72            id: "".to_string(),
73            organization: "".to_string(),
74            environment_id: "".to_string(),
75            expiration: 0,
76            not_before: 0,
77
78            max_credit_consumption_rate: 999999.0,
79            allow_credit_consumption_override: true,
80            expiration_behavior: ExpirationBehavior::Warn,
81            expired: false,
82            entitlements: Vec::new(),
83        }
84    }
85
86    /// Returns true if `entitlement` is present in this key.
87    pub fn has_entitlement(&self, entitlement: &str) -> bool {
88        self.entitlements.iter().any(|e| e == entitlement)
89    }
90
91    pub fn max_credit_consumption_rate(&self) -> Option<f64> {
92        if self.expired
93            && matches!(
94                self.expiration_behavior,
95                ExpirationBehavior::DisableClusterCreation | ExpirationBehavior::Disable
96            )
97        {
98            Some(0.0)
99        } else if self.allow_credit_consumption_override {
100            None
101        } else {
102            Some(self.max_credit_consumption_rate)
103        }
104    }
105}
106
107impl Default for ValidatedLicenseKey {
108    fn default() -> Self {
109        Self {
110            id: "".to_string(),
111            organization: "".to_string(),
112            environment_id: "".to_string(),
113            expiration: 0,
114            not_before: 0,
115
116            max_credit_consumption_rate: 24.0,
117            allow_credit_consumption_override: false,
118            expiration_behavior: ExpirationBehavior::Disable,
119            expired: false,
120            entitlements: Vec::new(),
121        }
122    }
123}
124
125pub fn validate(license_key: &str) -> anyhow::Result<ValidatedLicenseKey> {
126    let mut err = None;
127    for pubkey in PUBLIC_KEYS {
128        match validate_with_pubkey(license_key, pubkey) {
129            Ok(key) => {
130                return Ok(key);
131            }
132            Err(e) => {
133                err = Some(e);
134            }
135        }
136    }
137
138    if let Some(err) = err {
139        Err(err)
140    } else {
141        Err(anyhow!("no public key found"))
142    }
143}
144
145fn validate_with_pubkey(
146    license_key: &str,
147    pubkey_pem: &str,
148) -> anyhow::Result<ValidatedLicenseKey> {
149    // don't just read the version out of the payload before verifying it,
150    // trusting unsigned data to determine how to verify the signature is a
151    // bad idea. instead, just try validating it as each version
152    // independently, and if the signature is valid, only then check to
153    // ensure that the version matches what we validated.
154
155    // try current version first, so we can prefer that for error messages
156    let res = validate_with_pubkey_v1(license_key, pubkey_pem);
157    let err = match res {
158        Ok(key) => return Ok(key),
159        Err(e) => e,
160    };
161
162    let previous_versions: Vec<Box<dyn Fn() -> anyhow::Result<ValidatedLicenseKey>>> = vec![
163        // add to this if/when we add new versions
164        // for example,
165        // Box::new(|| validate_with_pubkey_v1(license_key, pubkey_pem, environment_id)),
166    ];
167    for validator in previous_versions {
168        if let Ok(key) = validator() {
169            return Ok(key);
170        }
171    }
172
173    Err(err)
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
177struct Payload {
178    sub: String,
179    exp: u64,
180    nbf: u64,
181    iss: String,
182    aud: String,
183    iat: u64,
184    jti: String,
185
186    version: u64,
187    max_credit_consumption_rate: f64,
188    #[serde(default, skip_serializing_if = "is_default")]
189    allow_credit_consumption_override: bool,
190    expiration_behavior: ExpirationBehavior,
191    // Defaulted + skipped-when-empty so keys issued before entitlements
192    // existed continue to validate and we don't bloat keys that don't need
193    // any entitlements.
194    #[serde(default, skip_serializing_if = "Vec::is_empty")]
195    entitlements: Vec<String>,
196}
197
198fn validate_with_pubkey_v1(
199    license_key: &str,
200    pubkey_pem: &str,
201) -> anyhow::Result<ValidatedLicenseKey> {
202    let mut validation = Validation::new(Algorithm::PS256);
203    validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
204    validation.set_issuer(&[ISSUER]);
205    validation.validate_exp = true;
206    validation.validate_nbf = true;
207    validation.validate_aud = false;
208
209    let key = DecodingKey::from_rsa_pem(pubkey_pem.as_bytes())?;
210
211    let (jwt, expired): (TokenData<Payload>, _) =
212        jsonwebtoken::decode(license_key, &key, &validation).map_or_else(
213            |e| {
214                if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
215                    validation.validate_exp = false;
216                    Ok((jsonwebtoken::decode(license_key, &key, &validation)?, true))
217                } else {
218                    Err::<_, anyhow::Error>(e.into())
219                }
220            },
221            |jwt| Ok((jwt, false)),
222        )?;
223
224    if jwt.header.typ.as_deref() != Some("JWT") {
225        bail!("invalid jwt header type");
226    }
227
228    if jwt.claims.version != 1 {
229        bail!("invalid license key version");
230    }
231
232    if !(jwt.claims.nbf..=jwt.claims.exp).contains(&jwt.claims.iat) {
233        bail!("invalid issuance time");
234    }
235
236    if REVOKED_KEYS.contains(&jwt.claims.jti.as_str()) {
237        bail!("revoked license key");
238    }
239
240    Ok(ValidatedLicenseKey {
241        id: jwt.claims.jti,
242        organization: jwt.claims.sub,
243        environment_id: jwt.claims.aud,
244        expiration: jwt.claims.exp,
245        not_before: jwt.claims.nbf,
246
247        max_credit_consumption_rate: jwt.claims.max_credit_consumption_rate,
248        allow_credit_consumption_override: jwt.claims.allow_credit_consumption_override,
249        expiration_behavior: jwt.claims.expiration_behavior,
250        expired,
251        entitlements: jwt.claims.entitlements,
252    })
253}
254
/// Returns true if `val` equals its type's [`Default`] value.
///
/// Used as a serde `skip_serializing_if` predicate so that fields holding
/// their default are omitted from serialized license keys. Only `PartialEq`
/// is required (`Eq` already implies it), so types such as `f64` work too.
fn is_default<T: PartialEq + Default>(val: &T) -> bool {
    *val == T::default()
}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a well-formed version 1 payload carrying `entitlements`.
    fn sample_payload(entitlements: Vec<String>) -> Payload {
        Payload {
            sub: "org-1".to_string(),
            exp: 200,
            nbf: 100,
            iss: ISSUER.to_string(),
            aud: "env-1".to_string(),
            iat: 100,
            jti: "jti-1".to_string(),
            version: 1,
            max_credit_consumption_rate: 10.0,
            allow_credit_consumption_override: false,
            expiration_behavior: ExpirationBehavior::Warn,
            entitlements,
        }
    }

    #[mz_ore::test]
    fn entitlements_roundtrip_through_payload() {
        let payload = sample_payload(vec!["ory".to_string(), "foo".to_string()]);
        let json = serde_json::to_string(&payload).unwrap();
        assert!(
            json.contains(r#""entitlements":["ory","foo"]"#),
            "payload should serialize entitlements: {json}"
        );

        let decoded: Payload = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.entitlements, vec!["ory", "foo"]);
    }

    #[mz_ore::test]
    fn empty_entitlements_omitted_from_payload() {
        let payload = sample_payload(Vec::new());
        let json = serde_json::to_string(&payload).unwrap();
        // Skipping the field on empty keeps issued JWTs the same shape they
        // were before entitlements existed, so old + new validators agree.
        assert!(
            !json.contains("entitlements"),
            "empty entitlements should be skipped: {json}"
        );
    }

    #[mz_ore::test]
    fn default_credit_override_omitted_from_payload() {
        let payload = sample_payload(Vec::new());
        let json = serde_json::to_string(&payload).unwrap();
        // `false` is the default, so `is_default` skips the field.
        assert!(
            !json.contains("allow_credit_consumption_override"),
            "default credit override should be skipped: {json}"
        );
    }

    #[mz_ore::test]
    fn legacy_payload_without_entitlements_decodes() {
        // Pre-DEP-130 keys have no `entitlements` field. They must still
        // deserialize, with entitlements defaulting to empty.
        let legacy = serde_json::json!({
            "sub": "org-1",
            "exp": 200,
            "nbf": 100,
            "iss": ISSUER,
            "aud": "env-1",
            "iat": 100,
            "jti": "jti-1",
            "version": 1,
            "max_credit_consumption_rate": 10.0,
            "expiration_behavior": "Warn",
        });
        let decoded: Payload = serde_json::from_value(legacy).unwrap();
        assert!(decoded.entitlements.is_empty());
        assert!(!decoded.allow_credit_consumption_override);
    }

    #[mz_ore::test]
    fn credit_override_removes_limit() {
        let key = ValidatedLicenseKey {
            allow_credit_consumption_override: true,
            ..ValidatedLicenseKey::default()
        };
        assert_eq!(key.max_credit_consumption_rate(), None);
    }

    #[mz_ore::test]
    fn expired_key_with_warn_behavior_keeps_limit() {
        let key = ValidatedLicenseKey {
            expired: true,
            expiration_behavior: ExpirationBehavior::Warn,
            ..ValidatedLicenseKey::default()
        };
        assert_eq!(key.max_credit_consumption_rate(), Some(24.0));
    }

    #[mz_ore::test]
    fn expired_key_with_disabling_behavior_has_zero_limit() {
        for behavior in [
            ExpirationBehavior::DisableClusterCreation,
            ExpirationBehavior::Disable,
        ] {
            // Expiration takes precedence over an allowed override.
            let key = ValidatedLicenseKey {
                expired: true,
                expiration_behavior: behavior,
                allow_credit_consumption_override: true,
                ..ValidatedLicenseKey::default()
            };
            assert_eq!(key.max_credit_consumption_rate(), Some(0.0));
        }
    }

    #[mz_ore::test]
    fn has_entitlement_requires_exact_match() {
        let key = ValidatedLicenseKey {
            entitlements: vec!["ory".to_string()],
            ..ValidatedLicenseKey::default()
        };
        assert!(key.has_entitlement("ory"));
        assert!(!key.has_entitlement("or"));
        assert!(!ValidatedLicenseKey::default().has_entitlement("ory"));
    }
}