1use anyhow::{anyhow, bail};
11use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "signing")]
15mod signing;
16#[cfg(feature = "signing")]
17pub use signing::{get_pubkey_pem, make_license_key};
18
/// Required value of the JWT `iss` claim; tokens from any other issuer are rejected.
const ISSUER: &str = "Materialize, Inc.";
/// PEM-encoded RSA public keys that `validate` tries, in order, when checking a
/// license key's signature.
const PUBLIC_KEYS: &[&str] = &[include_str!("license_keys/production.pub")];
/// `jti` (key id) values of license keys that must be rejected even when their
/// signature and claims are otherwise valid.
const REVOKED_KEYS: &[&str] = &["eddaf004-dc1e-48cf-9cc1-41d1543d940a"];
26
/// What a license key asks to happen once it has expired.
///
/// Within this module, `DisableClusterCreation` and `Disable` both make
/// `ValidatedLicenseKey::max_credit_consumption_rate` return `Some(0.0)` for an
/// expired key, while `Warn` leaves the normal limit in place.
///
/// The variant names are serialized verbatim (no serde renaming), so renaming a
/// variant changes the license key payload format.
// NOTE(review): any further difference between `DisableClusterCreation` and
// `Disable` is enforced by code outside this file — confirm with the callers.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum ExpirationBehavior {
    Warn,
    DisableClusterCreation,
    Disable,
}
34
35#[derive(Debug, Clone)]
36pub struct ValidatedLicenseKey {
37 pub id: String,
38 pub organization: String,
39 pub environment_id: String,
40 pub expiration: u64,
41 pub not_before: u64,
42
43 pub max_credit_consumption_rate: f64,
44 pub allow_credit_consumption_override: bool,
45 pub expiration_behavior: ExpirationBehavior,
46 pub expired: bool,
47}
48
49impl ValidatedLicenseKey {
50 pub fn for_tests() -> Self {
51 Self {
52 id: "".to_string(),
53 organization: "".to_string(),
54 environment_id: "".to_string(),
55 expiration: 0,
56 not_before: 0,
57
58 max_credit_consumption_rate: 999999.0,
59 allow_credit_consumption_override: true,
60 expiration_behavior: ExpirationBehavior::Warn,
61 expired: false,
62 }
63 }
64
65 pub fn disabled() -> Self {
66 Self {
67 id: "".to_string(),
68 organization: "".to_string(),
69 environment_id: "".to_string(),
70 expiration: 0,
71 not_before: 0,
72
73 max_credit_consumption_rate: 999999.0,
74 allow_credit_consumption_override: true,
75 expiration_behavior: ExpirationBehavior::Warn,
76 expired: false,
77 }
78 }
79
80 pub fn max_credit_consumption_rate(&self) -> Option<f64> {
81 if self.expired
82 && matches!(
83 self.expiration_behavior,
84 ExpirationBehavior::DisableClusterCreation | ExpirationBehavior::Disable
85 )
86 {
87 Some(0.0)
88 } else if self.allow_credit_consumption_override {
89 None
90 } else {
91 Some(self.max_credit_consumption_rate)
92 }
93 }
94}
95
96impl Default for ValidatedLicenseKey {
97 fn default() -> Self {
98 Self {
99 id: "".to_string(),
100 organization: "".to_string(),
101 environment_id: "".to_string(),
102 expiration: 0,
103 not_before: 0,
104
105 max_credit_consumption_rate: 24.0,
106 allow_credit_consumption_override: false,
107 expiration_behavior: ExpirationBehavior::Disable,
108 expired: false,
109 }
110 }
111}
112
113pub fn validate(license_key: &str) -> anyhow::Result<ValidatedLicenseKey> {
114 let mut err = None;
115 for pubkey in PUBLIC_KEYS {
116 match validate_with_pubkey(license_key, pubkey) {
117 Ok(key) => {
118 return Ok(key);
119 }
120 Err(e) => {
121 err = Some(e);
122 }
123 }
124 }
125
126 if let Some(err) = err {
127 Err(err)
128 } else {
129 Err(anyhow!("no public key found"))
130 }
131}
132
133fn validate_with_pubkey(
134 license_key: &str,
135 pubkey_pem: &str,
136) -> anyhow::Result<ValidatedLicenseKey> {
137 let res = validate_with_pubkey_v1(license_key, pubkey_pem);
145 let err = match res {
146 Ok(key) => return Ok(key),
147 Err(e) => e,
148 };
149
150 let previous_versions: Vec<Box<dyn Fn() -> anyhow::Result<ValidatedLicenseKey>>> = vec![
151 ];
155 for validator in previous_versions {
156 if let Ok(key) = validator() {
157 return Ok(key);
158 }
159 }
160
161 Err(err)
162}
163
/// Claims carried in the body of a license key JWT.
///
/// The first seven fields are the registered JWT claims; the rest are
/// license-specific. Field names are serialized verbatim by serde, so renaming
/// a field changes the token format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Payload {
    /// Subject; becomes `ValidatedLicenseKey::organization`.
    sub: String,
    /// Expiration time; becomes `ValidatedLicenseKey::expiration`.
    exp: u64,
    /// Not-before time; becomes `ValidatedLicenseKey::not_before`.
    nbf: u64,
    /// Issuer; must equal `ISSUER`.
    iss: String,
    /// Audience; becomes `ValidatedLicenseKey::environment_id`. The JWT audience
    /// check is turned off during validation.
    aud: String,
    /// Issued-at time; must lie within `[nbf, exp]`.
    iat: u64,
    /// Token id; becomes `ValidatedLicenseKey::id` and is checked against `REVOKED_KEYS`.
    jti: String,

    /// Payload format version; the v1 validator accepts only `1`.
    version: u64,
    /// Credit consumption rate limit granted by the key.
    max_credit_consumption_rate: f64,
    /// Defaults to `false` when absent, and is left out of serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_default")]
    allow_credit_consumption_override: bool,
    /// How the key should be treated once it has expired.
    expiration_behavior: ExpirationBehavior,
}
180
181fn validate_with_pubkey_v1(
182 license_key: &str,
183 pubkey_pem: &str,
184) -> anyhow::Result<ValidatedLicenseKey> {
185 let mut validation = Validation::new(Algorithm::PS256);
186 validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
187 validation.set_issuer(&[ISSUER]);
188 validation.validate_exp = true;
189 validation.validate_nbf = true;
190 validation.validate_aud = false;
191
192 let key = DecodingKey::from_rsa_pem(pubkey_pem.as_bytes())?;
193
194 let (jwt, expired): (TokenData<Payload>, _) =
195 jsonwebtoken::decode(license_key, &key, &validation).map_or_else(
196 |e| {
197 if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
198 validation.validate_exp = false;
199 Ok((jsonwebtoken::decode(license_key, &key, &validation)?, true))
200 } else {
201 Err::<_, anyhow::Error>(e.into())
202 }
203 },
204 |jwt| Ok((jwt, false)),
205 )?;
206
207 if jwt.header.typ.as_deref() != Some("JWT") {
208 bail!("invalid jwt header type");
209 }
210
211 if jwt.claims.version != 1 {
212 bail!("invalid license key version");
213 }
214
215 if !(jwt.claims.nbf..=jwt.claims.exp).contains(&jwt.claims.iat) {
216 bail!("invalid issuance time");
217 }
218
219 if REVOKED_KEYS.contains(&jwt.claims.jti.as_str()) {
220 bail!("revoked license key");
221 }
222
223 Ok(ValidatedLicenseKey {
224 id: jwt.claims.jti,
225 organization: jwt.claims.sub,
226 environment_id: jwt.claims.aud,
227 expiration: jwt.claims.exp,
228 not_before: jwt.claims.nbf,
229
230 max_credit_consumption_rate: jwt.claims.max_credit_consumption_rate,
231 allow_credit_consumption_override: jwt.claims.allow_credit_consumption_override,
232 expiration_behavior: jwt.claims.expiration_behavior,
233 expired,
234 })
235}
236
/// Returns `true` when `val` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that default-valued
/// fields are left out of the serialized output.
fn is_default<T>(val: &T) -> bool
where
    T: PartialEq + Eq + Default,
{
    val == &T::default()
}