mz_license_keys/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use anyhow::{anyhow, bail};
11use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "signing")]
15mod signing;
16#[cfg(feature = "signing")]
17pub use signing::{get_pubkey_pem, make_license_key};
18
19const ISSUER: &str = "Materialize, Inc.";
20// this will be used specifically by cloud to avoid needing to issue separate
21// license keys for each environment when it comes up - just being able to
22// share a single license key that allows all environments and never expires
23// will be much simpler to maintain
24const ANY_ENVIRONMENT_AUD: &str = "00000000-0000-0000-0000-000000000000";
25// list of public keys which are allowed to validate license keys. this is a
26// list to allow for key rotation if necessary.
27const PUBLIC_KEYS: &[&str] = &[include_str!("license_keys/production.pub")];
28// keys which we have issued but need to be revoked before their expiration
29// (due to being accidentally exposed or similar).
30const REVOKED_KEYS: &[&str] = &["eddaf004-dc1e-48cf-9cc1-41d1543d940a"];
31
32#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
33pub enum ExpirationBehavior {
34    Warn,
35    DisableClusterCreation,
36    Disable,
37}
38
39#[derive(Debug, Clone)]
40pub struct ValidatedLicenseKey {
41    pub id: String,
42    pub organization: String,
43    pub environment_id: String,
44    pub expiration: u64,
45    pub not_before: u64,
46
47    pub max_credit_consumption_rate: f64,
48    pub allow_credit_consumption_override: bool,
49    pub expiration_behavior: ExpirationBehavior,
50    pub expired: bool,
51}
52
53impl ValidatedLicenseKey {
54    pub fn for_tests() -> Self {
55        Self {
56            id: "".to_string(),
57            organization: "".to_string(),
58            environment_id: "".to_string(),
59            expiration: 0,
60            not_before: 0,
61
62            max_credit_consumption_rate: 999999.0,
63            allow_credit_consumption_override: true,
64            expiration_behavior: ExpirationBehavior::Warn,
65            expired: false,
66        }
67    }
68
69    // TODO: temporary until we get the rest of the infrastructure in place
70    pub fn disabled() -> Self {
71        Self {
72            id: "".to_string(),
73            organization: "".to_string(),
74            environment_id: "".to_string(),
75            expiration: 0,
76            not_before: 0,
77
78            max_credit_consumption_rate: 999999.0,
79            allow_credit_consumption_override: true,
80            expiration_behavior: ExpirationBehavior::Warn,
81            expired: false,
82        }
83    }
84
85    pub fn max_credit_consumption_rate(&self) -> Option<f64> {
86        if self.expired
87            && matches!(
88                self.expiration_behavior,
89                ExpirationBehavior::DisableClusterCreation | ExpirationBehavior::Disable
90            )
91        {
92            Some(0.0)
93        } else if self.allow_credit_consumption_override {
94            None
95        } else {
96            Some(self.max_credit_consumption_rate)
97        }
98    }
99}
100
101impl Default for ValidatedLicenseKey {
102    fn default() -> Self {
103        // this is used for the emulator if no license key is provided
104        Self {
105            id: "".to_string(),
106            organization: "".to_string(),
107            environment_id: "".to_string(),
108            expiration: 0,
109            not_before: 0,
110
111            max_credit_consumption_rate: 24.0,
112            allow_credit_consumption_override: false,
113            expiration_behavior: ExpirationBehavior::Disable,
114            expired: false,
115        }
116    }
117}
118
119pub fn validate(license_key: &str, environment_id: &str) -> anyhow::Result<ValidatedLicenseKey> {
120    let mut err = None;
121    for pubkey in PUBLIC_KEYS {
122        match validate_with_pubkey(license_key, pubkey, environment_id) {
123            Ok(key) => {
124                return Ok(key);
125            }
126            Err(e) => {
127                err = Some(e);
128            }
129        }
130    }
131
132    if let Some(err) = err {
133        Err(err)
134    } else {
135        Err(anyhow!("no public key found"))
136    }
137}
138
139fn validate_with_pubkey(
140    license_key: &str,
141    pubkey_pem: &str,
142    environment_id: &str,
143) -> anyhow::Result<ValidatedLicenseKey> {
144    // don't just read the version out of the payload before verifying it,
145    // trusting unsigned data to determine how to verify the signature is a
146    // bad idea. instead, just try validating it as each version
147    // independently, and if the signature is valid, only then check to
148    // ensure that the version matches what we validated.
149
150    // try current version first, so we can prefer that for error messages
151    let res = validate_with_pubkey_v1(license_key, pubkey_pem, environment_id);
152    let err = match res {
153        Ok(key) => return Ok(key),
154        Err(e) => e,
155    };
156
157    let previous_versions: Vec<Box<dyn Fn() -> anyhow::Result<ValidatedLicenseKey>>> = vec![
158        // add to this if/when we add new versions
159        // for example,
160        // Box::new(|| validate_with_pubkey_v1(license_key, pubkey_pem, environment_id)),
161    ];
162    for validator in previous_versions {
163        if let Ok(key) = validator() {
164            return Ok(key);
165        }
166    }
167
168    Err(err)
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
172struct Payload {
173    sub: String,
174    exp: u64,
175    nbf: u64,
176    iss: String,
177    aud: String,
178    iat: u64,
179    jti: String,
180
181    version: u64,
182    max_credit_consumption_rate: f64,
183    #[serde(default, skip_serializing_if = "is_default")]
184    allow_credit_consumption_override: bool,
185    expiration_behavior: ExpirationBehavior,
186}
187
188fn validate_with_pubkey_v1(
189    license_key: &str,
190    pubkey_pem: &str,
191    environment_id: &str,
192) -> anyhow::Result<ValidatedLicenseKey> {
193    let mut validation = Validation::new(Algorithm::PS256);
194    validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
195    validation.set_audience(&[environment_id, ANY_ENVIRONMENT_AUD]);
196    validation.set_issuer(&[ISSUER]);
197    validation.validate_exp = true;
198    validation.validate_nbf = true;
199    validation.validate_aud = true;
200
201    let key = DecodingKey::from_rsa_pem(pubkey_pem.as_bytes())?;
202
203    let (jwt, expired): (TokenData<Payload>, _) =
204        jsonwebtoken::decode(license_key, &key, &validation).map_or_else(
205            |e| {
206                if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
207                    validation.validate_exp = false;
208                    Ok((jsonwebtoken::decode(license_key, &key, &validation)?, true))
209                } else {
210                    Err::<_, anyhow::Error>(e.into())
211                }
212            },
213            |jwt| Ok((jwt, false)),
214        )?;
215
216    if jwt.header.typ.as_deref() != Some("JWT") {
217        bail!("invalid jwt header type");
218    }
219
220    if jwt.claims.version != 1 {
221        bail!("invalid license key version");
222    }
223
224    if !(jwt.claims.nbf..=jwt.claims.exp).contains(&jwt.claims.iat) {
225        bail!("invalid issuance time");
226    }
227
228    if REVOKED_KEYS.contains(&jwt.claims.jti.as_str()) {
229        bail!("revoked license key");
230    }
231
232    Ok(ValidatedLicenseKey {
233        id: jwt.claims.jti,
234        organization: jwt.claims.sub,
235        environment_id: jwt.claims.aud,
236        expiration: jwt.claims.exp,
237        not_before: jwt.claims.nbf,
238
239        max_credit_consumption_rate: jwt.claims.max_credit_consumption_rate,
240        allow_credit_consumption_override: jwt.claims.allow_credit_consumption_override,
241        expiration_behavior: jwt.claims.expiration_behavior,
242        expired,
243    })
244}
245
/// Returns `true` if `val` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that fields holding their
/// default value are omitted from the serialized payload. Only `PartialEq` is
/// needed for the comparison. The former `Eq` bound was unnecessary and excluded
/// types without total equality such as `f64`.
fn is_default<T: PartialEq + Default>(val: &T) -> bool {
    *val == T::default()
}