Skip to main content

mz_license_keys/
signing.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use anyhow::anyhow;
13use aws_lc_rs::digest;
14use aws_sdk_kms::{
15    primitives::Blob,
16    types::{MessageType, SigningAlgorithmSpec},
17};
18use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
19use jsonwebtoken::{Algorithm, Header};
20use pem::Pem;
21use uuid::Uuid;
22
23use crate::{ExpirationBehavior, ISSUER, Payload};
24
/// Payload format version; stamped into the `version` field of every
/// [`Payload`] minted by [`make_license_key`].
const VERSION: u64 = 1;
26
27pub async fn get_pubkey_pem(client: &aws_sdk_kms::Client, key_id: &str) -> anyhow::Result<String> {
28    let pubkey = get_pubkey(client, key_id).await?;
29    let pem = Pem::new("PUBLIC KEY", pubkey);
30    Ok(pem.to_string())
31}
32
33pub async fn make_license_key(
34    client: &aws_sdk_kms::Client,
35    key_id: &str,
36    validity: Duration,
37    organization_id: String,
38    environment_id: String,
39    max_credit_consumption_rate: f64,
40    allow_credit_consumption_override: bool,
41    expiration_behavior: ExpirationBehavior,
42) -> anyhow::Result<String> {
43    let mut headers = Header::new(Algorithm::PS256);
44    headers.typ = Some("JWT".to_string());
45    let headers = URL_SAFE_NO_PAD.encode(serde_json::to_string(&headers).unwrap().as_bytes());
46
47    let now = SystemTime::now();
48    let expiration = now + validity;
49    let payload = Payload {
50        sub: organization_id,
51        exp: format_time(&expiration),
52        nbf: format_time(&now),
53        iss: ISSUER.to_string(),
54        aud: environment_id,
55        iat: format_time(&now),
56        jti: Uuid::new_v4().to_string(),
57        version: VERSION,
58        max_credit_consumption_rate,
59        allow_credit_consumption_override,
60        expiration_behavior,
61    };
62    let payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap().as_bytes());
63
64    let signing_string = format!("{}.{}", headers, payload);
65    let signature = URL_SAFE_NO_PAD.encode(sign(client, key_id, signing_string.as_bytes()).await?);
66
67    Ok(format!("{}.{}", signing_string, signature))
68}
69
70async fn get_pubkey(client: &aws_sdk_kms::Client, key_id: &str) -> anyhow::Result<Vec<u8>> {
71    if let Some(pubkey) = client
72        .get_public_key()
73        .key_id(key_id)
74        .send()
75        .await?
76        .public_key
77    {
78        Ok(pubkey.into_inner())
79    } else {
80        Err(anyhow!("failed to get pubkey"))
81    }
82}
83
84async fn sign(
85    client: &aws_sdk_kms::Client,
86    key_id: &str,
87    message: &[u8],
88) -> anyhow::Result<Vec<u8>> {
89    let hash = digest::digest(&digest::SHA256, message);
90    let digest = hash.as_ref().to_vec();
91
92    if let Some(sig) = client
93        .sign()
94        .key_id(key_id)
95        .signing_algorithm(SigningAlgorithmSpec::RsassaPssSha256)
96        .message_type(MessageType::Digest)
97        .message(Blob::new(digest))
98        .send()
99        .await?
100        .signature
101    {
102        Ok(sig.into_inner())
103    } else {
104        Err(anyhow!("failed to get signature"))
105    }
106}
107
/// Converts `t` to whole seconds since the Unix epoch, truncating any
/// sub-second part. This is the integer-seconds form used for the JWT
/// `exp`/`nbf`/`iat` claims.
///
/// # Panics
///
/// Panics if `t` is earlier than [`UNIX_EPOCH`]. The callers in this file pass
/// `SystemTime::now()` or a later time, so this fires only on a badly
/// misconfigured system clock.
fn format_time(t: &SystemTime) -> u64 {
    t.duration_since(UNIX_EPOCH)
        .expect("system time is before the Unix epoch")
        .as_secs()
}