1use anyhow::{anyhow, bail};
11use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
12use serde::{Deserialize, Serialize};
13
// Helpers for creating license keys; only compiled with the `signing` feature.
#[cfg(feature = "signing")]
mod signing;
#[cfg(feature = "signing")]
pub use signing::{get_pubkey_pem, make_license_key};

/// Value the `iss` claim of every license key must have.
const ISSUER: &str = "Materialize, Inc.";
/// PEM-encoded RSA public keys that license keys are verified against.
///
/// They are tried in order, and the first key that accepts a license wins.
const PUBLIC_KEYS: &[&str] = &[include_str!("license_keys/production.pub")];
/// License key IDs (`jti` claims) that are rejected even when every other check passes.
const REVOKED_KEYS: &[&str] = &["eddaf004-dc1e-48cf-9cc1-41d1543d940a"];
26
27#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
28pub enum ExpirationBehavior {
29 Warn,
30 DisableClusterCreation,
31 Disable,
32}
33
/// A license key whose contents have been checked, or a built-in stand-in.
///
/// `validate` produces this from the claims of a signed license JWT. It can
/// also be constructed directly, e.g. via `ValidatedLicenseKey::disabled` or
/// `Default`.
#[derive(Debug, Clone)]
pub struct ValidatedLicenseKey {
    /// License key ID (the JWT `jti` claim).
    pub id: String,
    /// Organization the key was issued to (the JWT `sub` claim).
    pub organization: String,
    /// Environment the key was issued for (the JWT `aud` claim).
    /// Its value is not compared against anything during validation.
    pub environment_id: String,
    /// Expiration time (the JWT `exp` claim; presumably Unix seconds, per JWT NumericDate).
    pub expiration: u64,
    /// Start of the validity period (the JWT `nbf` claim).
    pub not_before: u64,

    /// Credit consumption rate limit granted by the key.
    pub max_credit_consumption_rate: f64,
    /// When `true`, `max_credit_consumption_rate()` reports no limit (`None`),
    /// unless the key is expired with a disabling `expiration_behavior`.
    pub allow_credit_consumption_override: bool,
    /// What the key asks for once it has expired.
    pub expiration_behavior: ExpirationBehavior,
    /// Set when decoding reported the key's `exp` as passed. Such keys are
    /// still returned rather than rejected.
    pub expired: bool,
}
47
48impl ValidatedLicenseKey {
49 pub fn for_tests() -> Self {
50 Self {
51 id: "".to_string(),
52 organization: "".to_string(),
53 environment_id: "".to_string(),
54 expiration: 0,
55 not_before: 0,
56
57 max_credit_consumption_rate: 999999.0,
58 allow_credit_consumption_override: true,
59 expiration_behavior: ExpirationBehavior::Warn,
60 expired: false,
61 }
62 }
63
64 pub fn disabled() -> Self {
65 Self {
66 id: "".to_string(),
67 organization: "".to_string(),
68 environment_id: "".to_string(),
69 expiration: 0,
70 not_before: 0,
71
72 max_credit_consumption_rate: 999999.0,
73 allow_credit_consumption_override: true,
74 expiration_behavior: ExpirationBehavior::Warn,
75 expired: false,
76 }
77 }
78
79 pub fn max_credit_consumption_rate(&self) -> Option<f64> {
80 if self.expired
81 && matches!(
82 self.expiration_behavior,
83 ExpirationBehavior::DisableClusterCreation | ExpirationBehavior::Disable
84 )
85 {
86 Some(0.0)
87 } else if self.allow_credit_consumption_override {
88 None
89 } else {
90 Some(self.max_credit_consumption_rate)
91 }
92 }
93}
94
95impl Default for ValidatedLicenseKey {
96 fn default() -> Self {
97 Self {
98 id: "".to_string(),
99 organization: "".to_string(),
100 environment_id: "".to_string(),
101 expiration: 0,
102 not_before: 0,
103
104 max_credit_consumption_rate: 24.0,
105 allow_credit_consumption_override: false,
106 expiration_behavior: ExpirationBehavior::Disable,
107 expired: false,
108 }
109 }
110}
111
112pub fn validate(license_key: &str) -> anyhow::Result<ValidatedLicenseKey> {
113 let mut err = None;
114 for pubkey in PUBLIC_KEYS {
115 match validate_with_pubkey(license_key, pubkey) {
116 Ok(key) => {
117 return Ok(key);
118 }
119 Err(e) => {
120 err = Some(e);
121 }
122 }
123 }
124
125 if let Some(err) = err {
126 Err(err)
127 } else {
128 Err(anyhow!("no public key found"))
129 }
130}
131
132fn validate_with_pubkey(
133 license_key: &str,
134 pubkey_pem: &str,
135) -> anyhow::Result<ValidatedLicenseKey> {
136 let res = validate_with_pubkey_v1(license_key, pubkey_pem);
144 let err = match res {
145 Ok(key) => return Ok(key),
146 Err(e) => e,
147 };
148
149 let previous_versions: Vec<Box<dyn Fn() -> anyhow::Result<ValidatedLicenseKey>>> = vec![
150 ];
154 for validator in previous_versions {
155 if let Ok(key) = validator() {
156 return Ok(key);
157 }
158 }
159
160 Err(err)
161}
162
/// Claims carried in the body of a license key JWT.
///
/// Field order is left as-is because it determines serialization order.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Payload {
    /// Organization the key was issued to; becomes `ValidatedLicenseKey::organization`.
    sub: String,
    /// Expiration time; becomes `ValidatedLicenseKey::expiration`.
    exp: u64,
    /// Not-before time; becomes `ValidatedLicenseKey::not_before`.
    nbf: u64,
    /// Issuer; validation requires it to equal `ISSUER`.
    iss: String,
    /// Environment ID; becomes `ValidatedLicenseKey::environment_id` (not checked).
    aud: String,
    /// Issued-at time; validation requires `nbf <= iat <= exp`.
    iat: u64,
    /// Unique key ID; rejected if listed in `REVOKED_KEYS`.
    jti: String,

    /// License key format version; `validate_with_pubkey_v1` only accepts `1`.
    version: u64,
    /// Credit consumption rate limit granted by the key.
    max_credit_consumption_rate: f64,
    /// Optional in the token: deserializes as `false` when absent and is
    /// omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_default")]
    allow_credit_consumption_override: bool,
    /// What the key asks for once it has expired.
    expiration_behavior: ExpirationBehavior,
}
179
180fn validate_with_pubkey_v1(
181 license_key: &str,
182 pubkey_pem: &str,
183) -> anyhow::Result<ValidatedLicenseKey> {
184 let mut validation = Validation::new(Algorithm::PS256);
185 validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
186 validation.set_issuer(&[ISSUER]);
187 validation.validate_exp = true;
188 validation.validate_nbf = true;
189 validation.validate_aud = false;
190
191 let key = DecodingKey::from_rsa_pem(pubkey_pem.as_bytes())?;
192
193 let (jwt, expired): (TokenData<Payload>, _) =
194 jsonwebtoken::decode(license_key, &key, &validation).map_or_else(
195 |e| {
196 if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
197 validation.validate_exp = false;
198 Ok((jsonwebtoken::decode(license_key, &key, &validation)?, true))
199 } else {
200 Err::<_, anyhow::Error>(e.into())
201 }
202 },
203 |jwt| Ok((jwt, false)),
204 )?;
205
206 if jwt.header.typ.as_deref() != Some("JWT") {
207 bail!("invalid jwt header type");
208 }
209
210 if jwt.claims.version != 1 {
211 bail!("invalid license key version");
212 }
213
214 if !(jwt.claims.nbf..=jwt.claims.exp).contains(&jwt.claims.iat) {
215 bail!("invalid issuance time");
216 }
217
218 if REVOKED_KEYS.contains(&jwt.claims.jti.as_str()) {
219 bail!("revoked license key");
220 }
221
222 Ok(ValidatedLicenseKey {
223 id: jwt.claims.jti,
224 organization: jwt.claims.sub,
225 environment_id: jwt.claims.aud,
226 expiration: jwt.claims.exp,
227 not_before: jwt.claims.nbf,
228
229 max_credit_consumption_rate: jwt.claims.max_credit_consumption_rate,
230 allow_credit_consumption_override: jwt.claims.allow_credit_consumption_override,
231 expiration_behavior: jwt.claims.expiration_behavior,
232 expired,
233 })
234}
235
/// Reports whether `val` equals its type's [`Default`] value.
///
/// Used as a serde `skip_serializing_if` predicate, so fields that hold their
/// default value are omitted from serialized output.
///
/// Only [`PartialEq`] is needed for the comparison, so the bound is not `Eq`.
/// This lets the predicate also serve types such as `f64`. For those, `-0.0`
/// counts as default because it compares equal to `0.0`, while `NaN` never
/// does.
fn is_default<T: PartialEq + Default>(val: &T) -> bool {
    *val == T::default()
}