Skip to main content

mz_license_keys/
signing.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use anyhow::anyhow;
13use aws_lc_rs::digest;
14use aws_sdk_kms::{
15    primitives::Blob,
16    types::{MessageType, SigningAlgorithmSpec},
17};
18use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
19use jsonwebtoken::{Algorithm, Header};
20use pem::Pem;
21use uuid::Uuid;
22
23use crate::{ExpirationBehavior, ISSUER, Payload};
24
/// Payload format version written into the `version` claim of every license key.
const VERSION: u64 = 1;
26
27pub async fn get_pubkey_pem(client: &aws_sdk_kms::Client, key_id: &str) -> anyhow::Result<String> {
28    let pubkey = get_pubkey(client, key_id).await?;
29    let pem = Pem::new("PUBLIC KEY", pubkey);
30    Ok(pem.to_string())
31}
32
33#[allow(clippy::too_many_arguments)]
34pub async fn make_license_key(
35    client: &aws_sdk_kms::Client,
36    key_id: &str,
37    validity: Duration,
38    organization_id: String,
39    environment_id: String,
40    max_credit_consumption_rate: f64,
41    allow_credit_consumption_override: bool,
42    expiration_behavior: ExpirationBehavior,
43    entitlements: Vec<String>,
44) -> anyhow::Result<String> {
45    let mut headers = Header::new(Algorithm::PS256);
46    headers.typ = Some("JWT".to_string());
47    let headers = URL_SAFE_NO_PAD.encode(serde_json::to_string(&headers).unwrap().as_bytes());
48
49    let now = SystemTime::now();
50    let expiration = now + validity;
51    let payload = Payload {
52        sub: organization_id,
53        exp: format_time(&expiration),
54        nbf: format_time(&now),
55        iss: ISSUER.to_string(),
56        aud: environment_id,
57        iat: format_time(&now),
58        jti: Uuid::new_v4().to_string(),
59        version: VERSION,
60        max_credit_consumption_rate,
61        allow_credit_consumption_override,
62        expiration_behavior,
63        entitlements,
64    };
65    let payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap().as_bytes());
66
67    let signing_string = format!("{}.{}", headers, payload);
68    let signature = URL_SAFE_NO_PAD.encode(sign(client, key_id, signing_string.as_bytes()).await?);
69
70    Ok(format!("{}.{}", signing_string, signature))
71}
72
73async fn get_pubkey(client: &aws_sdk_kms::Client, key_id: &str) -> anyhow::Result<Vec<u8>> {
74    if let Some(pubkey) = client
75        .get_public_key()
76        .key_id(key_id)
77        .send()
78        .await?
79        .public_key
80    {
81        Ok(pubkey.into_inner())
82    } else {
83        Err(anyhow!("failed to get pubkey"))
84    }
85}
86
87async fn sign(
88    client: &aws_sdk_kms::Client,
89    key_id: &str,
90    message: &[u8],
91) -> anyhow::Result<Vec<u8>> {
92    let hash = digest::digest(&digest::SHA256, message);
93    let digest = hash.as_ref().to_vec();
94
95    if let Some(sig) = client
96        .sign()
97        .key_id(key_id)
98        .signing_algorithm(SigningAlgorithmSpec::RsassaPssSha256)
99        .message_type(MessageType::Digest)
100        .message(Blob::new(digest))
101        .send()
102        .await?
103        .signature
104    {
105        Ok(sig.into_inner())
106    } else {
107        Err(anyhow!("failed to get signature"))
108    }
109}
110
/// Converts `t` into whole seconds since the UNIX epoch, the form used for the
/// `exp`, `nbf` and `iat` claims. Sub-second precision is truncated.
///
/// # Panics
///
/// Panics if `t` is earlier than the UNIX epoch, which indicates a badly
/// misconfigured system clock rather than a recoverable condition.
fn format_time(t: &SystemTime) -> u64 {
    t.duration_since(UNIX_EPOCH)
        .expect("license key timestamps must not precede the UNIX epoch")
        .as_secs()
}