1use anyhow::{anyhow, bail};
11use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "signing")]
15mod signing;
16#[cfg(feature = "signing")]
17pub use signing::{get_pubkey_pem, make_license_key};
18
/// Value every license key must carry in its `iss` claim.
const ISSUER: &str = "Materialize, Inc.";
/// Wildcard audience: a key whose `aud` is this value is accepted by every environment, in
/// addition to keys whose `aud` is the validating environment's own ID.
const ANY_ENVIRONMENT_AUD: &str = "00000000-0000-0000-0000-000000000000";
/// PEM-encoded RSA public keys that license keys may be signed with. `validate` tries them in
/// order and accepts the first one that validates the key.
const PUBLIC_KEYS: &[&str] = &[include_str!("license_keys/production.pub")];
/// License key IDs (`jti` claims) that are rejected even when the key is otherwise valid.
const REVOKED_KEYS: &[&str] = &["eddaf004-dc1e-48cf-9cc1-41d1543d940a"];
31
32#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
33pub enum ExpirationBehavior {
34 Warn,
35 DisableClusterCreation,
36 Disable,
37}
38
/// The decoded contents of a license key.
///
/// [`validate`] builds one from a verified token. [`ValidatedLicenseKey::for_tests`],
/// [`ValidatedLicenseKey::disabled`] and the [`Default`] impl build placeholder keys with
/// empty identifying fields.
#[derive(Debug, Clone)]
pub struct ValidatedLicenseKey {
    /// Key ID, taken from the token's `jti` claim.
    pub id: String,
    /// Organization the key was issued to, taken from the token's `sub` claim.
    pub organization: String,
    /// The token's `aud` claim: either the environment ID the key was issued for, or the
    /// all-zero wildcard audience that every environment accepts.
    pub environment_id: String,
    /// The token's `exp` claim (a JWT NumericDate, i.e. seconds since the Unix epoch).
    pub expiration: u64,
    /// The token's `nbf` claim (a JWT NumericDate).
    pub not_before: u64,

    /// Credit-consumption cap recorded in the key. Prefer the
    /// [`ValidatedLicenseKey::max_credit_consumption_rate`] method, which also accounts for
    /// expiration and overrides.
    pub max_credit_consumption_rate: f64,
    /// Whether the cap may be overridden. When `true`, the effective cap is `None` unless the
    /// key is expired with a disabling `expiration_behavior`.
    pub allow_credit_consumption_override: bool,
    /// How the key asks to be treated once expired.
    pub expiration_behavior: ExpirationBehavior,
    /// `true` if the token's `exp` had already passed when it was validated. Such a key is
    /// still returned rather than rejected, so `expiration_behavior` can take effect.
    pub expired: bool,
}
52
53impl ValidatedLicenseKey {
54 pub fn for_tests() -> Self {
55 Self {
56 id: "".to_string(),
57 organization: "".to_string(),
58 environment_id: "".to_string(),
59 expiration: 0,
60 not_before: 0,
61
62 max_credit_consumption_rate: 999999.0,
63 allow_credit_consumption_override: true,
64 expiration_behavior: ExpirationBehavior::Warn,
65 expired: false,
66 }
67 }
68
69 pub fn disabled() -> Self {
70 Self {
71 id: "".to_string(),
72 organization: "".to_string(),
73 environment_id: "".to_string(),
74 expiration: 0,
75 not_before: 0,
76
77 max_credit_consumption_rate: 999999.0,
78 allow_credit_consumption_override: true,
79 expiration_behavior: ExpirationBehavior::Warn,
80 expired: false,
81 }
82 }
83
84 pub fn max_credit_consumption_rate(&self) -> Option<f64> {
85 if self.expired
86 && matches!(
87 self.expiration_behavior,
88 ExpirationBehavior::DisableClusterCreation | ExpirationBehavior::Disable
89 )
90 {
91 Some(0.0)
92 } else if self.allow_credit_consumption_override {
93 None
94 } else {
95 Some(self.max_credit_consumption_rate)
96 }
97 }
98}
99
100impl Default for ValidatedLicenseKey {
101 fn default() -> Self {
102 Self {
103 id: "".to_string(),
104 organization: "".to_string(),
105 environment_id: "".to_string(),
106 expiration: 0,
107 not_before: 0,
108
109 max_credit_consumption_rate: 24.0,
110 allow_credit_consumption_override: false,
111 expiration_behavior: ExpirationBehavior::Disable,
112 expired: false,
113 }
114 }
115}
116
117pub fn validate(license_key: &str, environment_id: &str) -> anyhow::Result<ValidatedLicenseKey> {
118 let mut err = None;
119 for pubkey in PUBLIC_KEYS {
120 match validate_with_pubkey(license_key, pubkey, environment_id) {
121 Ok(key) => {
122 return Ok(key);
123 }
124 Err(e) => {
125 err = Some(e);
126 }
127 }
128 }
129
130 if let Some(err) = err {
131 Err(err)
132 } else {
133 Err(anyhow!("no public key found"))
134 }
135}
136
137fn validate_with_pubkey(
138 license_key: &str,
139 pubkey_pem: &str,
140 environment_id: &str,
141) -> anyhow::Result<ValidatedLicenseKey> {
142 let res = validate_with_pubkey_v1(license_key, pubkey_pem, environment_id);
150 let err = match res {
151 Ok(key) => return Ok(key),
152 Err(e) => e,
153 };
154
155 let previous_versions: Vec<Box<dyn Fn() -> anyhow::Result<ValidatedLicenseKey>>> = vec![
156 ];
160 for validator in previous_versions {
161 if let Ok(key) = validator() {
162 return Ok(key);
163 }
164 }
165
166 Err(err)
167}
168
/// Claims of a version-1 license key JWT, as decoded by `validate_with_pubkey_v1`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Payload {
    /// Subject; becomes `ValidatedLicenseKey::organization`.
    sub: String,
    /// Expiration time. A passed `exp` does not reject the key; it marks the result `expired`.
    exp: u64,
    /// Not-before time; checked during JWT validation.
    nbf: u64,
    /// Issuer; must equal `ISSUER`.
    iss: String,
    /// Audience: the target environment ID, or `ANY_ENVIRONMENT_AUD`.
    aud: String,
    /// Issued-at time; must lie within `nbf..=exp`.
    iat: u64,
    /// Key ID; the key is rejected if this appears in `REVOKED_KEYS`.
    jti: String,

    /// Payload format version; `validate_with_pubkey_v1` accepts only `1`.
    version: u64,
    /// Credit-consumption cap granted by the key.
    max_credit_consumption_rate: f64,
    /// Whether the cap may be overridden. Optional when deserializing (defaults to `false`),
    /// and omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_default")]
    allow_credit_consumption_override: bool,
    /// How the key asks to be treated once expired.
    expiration_behavior: ExpirationBehavior,
}
185
186fn validate_with_pubkey_v1(
187 license_key: &str,
188 pubkey_pem: &str,
189 environment_id: &str,
190) -> anyhow::Result<ValidatedLicenseKey> {
191 let mut validation = Validation::new(Algorithm::PS256);
192 validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
193 validation.set_audience(&[environment_id, ANY_ENVIRONMENT_AUD]);
194 validation.set_issuer(&[ISSUER]);
195 validation.validate_exp = true;
196 validation.validate_nbf = true;
197 validation.validate_aud = true;
198
199 let key = DecodingKey::from_rsa_pem(pubkey_pem.as_bytes())?;
200
201 let (jwt, expired): (TokenData<Payload>, _) =
202 jsonwebtoken::decode(license_key, &key, &validation).map_or_else(
203 |e| {
204 if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
205 validation.validate_exp = false;
206 Ok((jsonwebtoken::decode(license_key, &key, &validation)?, true))
207 } else {
208 Err::<_, anyhow::Error>(e.into())
209 }
210 },
211 |jwt| Ok((jwt, false)),
212 )?;
213
214 if jwt.header.typ.as_deref() != Some("JWT") {
215 bail!("invalid jwt header type");
216 }
217
218 if jwt.claims.version != 1 {
219 bail!("invalid license key version");
220 }
221
222 if !(jwt.claims.nbf..=jwt.claims.exp).contains(&jwt.claims.iat) {
223 bail!("invalid issuance time");
224 }
225
226 if REVOKED_KEYS.contains(&jwt.claims.jti.as_str()) {
227 bail!("revoked license key");
228 }
229
230 Ok(ValidatedLicenseKey {
231 id: jwt.claims.jti,
232 organization: jwt.claims.sub,
233 environment_id: jwt.claims.aud,
234 expiration: jwt.claims.exp,
235 not_before: jwt.claims.nbf,
236
237 max_credit_consumption_rate: jwt.claims.max_credit_consumption_rate,
238 allow_credit_consumption_override: jwt.claims.allow_credit_consumption_override,
239 expiration_behavior: jwt.claims.expiration_behavior,
240 expired,
241 })
242}
243
/// Returns `true` if `val` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that fields left at their default value
/// are omitted from serialized payloads. Only `==` is needed, so the bound is `PartialEq`;
/// this also lets floating-point fields (e.g. `f64`) use the predicate.
fn is_default<T: PartialEq + Default>(val: &T) -> bool {
    *val == T::default()
}