1use anyhow::{anyhow, bail};
11use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "signing")]
15mod signing;
16#[cfg(feature = "signing")]
17pub use signing::{get_pubkey_pem, make_license_key};
18
/// Issuer (`iss` claim) that every accepted license key must carry.
const ISSUER: &str = "Materialize, Inc.";
/// Audience (`aud` claim) that is accepted in addition to the caller's
/// environment id, i.e. a license issued for this audience is not tied to a
/// single environment.
const ANY_ENVIRONMENT_AUD: &str = "00000000-0000-0000-0000-000000000000";
/// PEM-encoded RSA public keys that license signatures are verified against.
/// `validate` tries them in order and accepts the first that validates.
const PUBLIC_KEYS: &[&str] = &[include_str!("license_keys/production.pub")];
/// Token ids (`jti` claim) of license keys that are rejected even when their
/// signature and all other claims are valid.
const REVOKED_KEYS: &[&str] = &["eddaf004-dc1e-48cf-9cc1-41d1543d940a"];
31
32#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
33pub enum ExpirationBehavior {
34 Warn,
35 DisableClusterCreation,
36 Disable,
37}
38
/// The limits carried by a license key whose signature and claims have been
/// checked (or one of the built-in fallbacks such as `Default`).
///
/// Use [`ValidatedLicenseKey::max_credit_consumption_rate`] to obtain the
/// effective cap; it combines all of the fields below.
#[derive(Debug, Clone, Copy)]
pub struct ValidatedLicenseKey {
    /// Credit consumption cap granted by the license (units defined by the
    /// caller — TODO confirm). Ignored when an override is allowed.
    pub max_credit_consumption_rate: f64,
    /// When true, no cap is reported (unless the license has expired with a
    /// disabling `expiration_behavior`).
    pub allow_credit_consumption_override: bool,
    /// What to do once the license has expired.
    pub expiration_behavior: ExpirationBehavior,
    /// True when the token's `exp` time had already passed at validation time.
    pub expired: bool,
}
46
47impl ValidatedLicenseKey {
48 pub fn for_tests() -> Self {
49 Self {
50 max_credit_consumption_rate: 999999.0,
51 allow_credit_consumption_override: true,
52 expiration_behavior: ExpirationBehavior::Warn,
53 expired: false,
54 }
55 }
56
57 pub fn disabled() -> Self {
59 Self {
60 max_credit_consumption_rate: 999999.0,
61 allow_credit_consumption_override: true,
62 expiration_behavior: ExpirationBehavior::Warn,
63 expired: false,
64 }
65 }
66
67 pub fn max_credit_consumption_rate(&self) -> Option<f64> {
68 if self.expired
69 && matches!(
70 self.expiration_behavior,
71 ExpirationBehavior::DisableClusterCreation | ExpirationBehavior::Disable
72 )
73 {
74 Some(0.0)
75 } else if self.allow_credit_consumption_override {
76 None
77 } else {
78 Some(self.max_credit_consumption_rate)
79 }
80 }
81}
82
83impl Default for ValidatedLicenseKey {
84 fn default() -> Self {
85 Self {
87 max_credit_consumption_rate: 24.0,
88 allow_credit_consumption_override: false,
89 expiration_behavior: ExpirationBehavior::Disable,
90 expired: false,
91 }
92 }
93}
94
95pub fn validate(license_key: &str, environment_id: &str) -> anyhow::Result<ValidatedLicenseKey> {
96 let mut err = None;
97 for pubkey in PUBLIC_KEYS {
98 match validate_with_pubkey(license_key, pubkey, environment_id) {
99 Ok(key) => {
100 return Ok(key);
101 }
102 Err(e) => {
103 err = Some(e);
104 }
105 }
106 }
107
108 if let Some(err) = err {
109 Err(err)
110 } else {
111 Err(anyhow!("no public key found"))
112 }
113}
114
115fn validate_with_pubkey(
116 license_key: &str,
117 pubkey_pem: &str,
118 environment_id: &str,
119) -> anyhow::Result<ValidatedLicenseKey> {
120 let res = validate_with_pubkey_v1(license_key, pubkey_pem, environment_id);
128 let err = match res {
129 Ok(key) => return Ok(key),
130 Err(e) => e,
131 };
132
133 let previous_versions: Vec<Box<dyn Fn() -> anyhow::Result<ValidatedLicenseKey>>> = vec![
134 ];
138 for validator in previous_versions {
139 if let Ok(key) = validator() {
140 return Ok(key);
141 }
142 }
143
144 Err(err)
145}
146
/// Claims carried by a version-1 license key JWT.
///
/// Field order and serde attributes define the serialized form, so they
/// should not be changed casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Payload {
    /// Subject (`sub`); required to be present, not otherwise inspected here.
    sub: String,
    /// Expiration time (`exp`); once passed, the license is reported as expired.
    exp: u64,
    /// Not-before time (`nbf`); the token is rejected before this time.
    nbf: u64,
    /// Issuer (`iss`); must equal `ISSUER`.
    iss: String,
    /// Audience (`aud`); must be the environment id or `ANY_ENVIRONMENT_AUD`.
    aud: String,
    /// Issued-at time (`iat`); must lie within `nbf..=exp`.
    iat: u64,
    /// Token id (`jti`); checked against `REVOKED_KEYS`.
    jti: String,

    /// License format version; must be 1 for v1 validation.
    version: u64,
    /// Credit consumption cap granted by the license.
    max_credit_consumption_rate: f64,
    /// Whether the cap may be overridden. Treated as `false` when absent and
    /// omitted from the serialized payload when `false`.
    #[serde(default, skip_serializing_if = "is_default")]
    allow_credit_consumption_override: bool,
    /// What to do once the license expires.
    expiration_behavior: ExpirationBehavior,
}
163
164fn validate_with_pubkey_v1(
165 license_key: &str,
166 pubkey_pem: &str,
167 environment_id: &str,
168) -> anyhow::Result<ValidatedLicenseKey> {
169 let mut validation = Validation::new(Algorithm::PS256);
170 validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
171 validation.set_audience(&[environment_id, ANY_ENVIRONMENT_AUD]);
172 validation.set_issuer(&[ISSUER]);
173 validation.validate_exp = true;
174 validation.validate_nbf = true;
175 validation.validate_aud = true;
176
177 let key = DecodingKey::from_rsa_pem(pubkey_pem.as_bytes())?;
178
179 let (jwt, expired): (TokenData<Payload>, _) =
180 jsonwebtoken::decode(license_key, &key, &validation).map_or_else(
181 |e| {
182 if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
183 validation.validate_exp = false;
184 Ok((jsonwebtoken::decode(license_key, &key, &validation)?, true))
185 } else {
186 Err::<_, anyhow::Error>(e.into())
187 }
188 },
189 |jwt| Ok((jwt, false)),
190 )?;
191
192 if jwt.header.typ.as_deref() != Some("JWT") {
193 bail!("invalid jwt header type");
194 }
195
196 if jwt.claims.version != 1 {
197 bail!("invalid license key version");
198 }
199
200 if !(jwt.claims.nbf..=jwt.claims.exp).contains(&jwt.claims.iat) {
201 bail!("invalid issuance time");
202 }
203
204 if REVOKED_KEYS.contains(&jwt.claims.jti.as_str()) {
205 bail!("revoked license key");
206 }
207
208 Ok(ValidatedLicenseKey {
209 max_credit_consumption_rate: jwt.claims.max_credit_consumption_rate,
210 allow_credit_consumption_override: jwt.claims.allow_credit_consumption_override,
211 expiration_behavior: jwt.claims.expiration_behavior,
212 expired,
213 })
214}
215
/// Reports whether `val` equals its type's `Default` value.
///
/// Used as a serde `skip_serializing_if` predicate, so fields still holding
/// their default (e.g. `false`) are omitted from the serialized output.
///
/// Only `PartialEq` is required: the former extra `Eq` bound was redundant,
/// since `Eq` implies `PartialEq`. It also needlessly excluded types such as
/// `f64`. Every type accepted before is still accepted.
fn is_default<T: PartialEq + Default>(val: &T) -> bool {
    *val == T::default()
}