mz_license_keys/
signing.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use anyhow::anyhow;
13use aws_sdk_kms::{
14    primitives::Blob,
15    types::{MessageType, SigningAlgorithmSpec},
16};
17use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
18use jsonwebtoken::{Algorithm, Header};
19use pem::Pem;
20use sha2::{Digest, Sha256};
21use uuid::Uuid;
22
23use crate::{ExpirationBehavior, ISSUER, Payload};
24
/// Value written to the `version` claim of every license key payload built by
/// [`make_license_key`].
const VERSION: u64 = 1;
26
27pub async fn get_pubkey_pem(client: &aws_sdk_kms::Client, key_id: &str) -> anyhow::Result<String> {
28    let pubkey = get_pubkey(client, key_id).await?;
29    let pem = Pem::new("PUBLIC KEY", pubkey);
30    Ok(pem.to_string())
31}
32
33pub async fn make_license_key(
34    client: &aws_sdk_kms::Client,
35    key_id: &str,
36    validity: Duration,
37    organization_id: String,
38    environment_id: String,
39    max_credit_consumption_rate: f64,
40    allow_credit_consumption_override: bool,
41    expiration_behavior: ExpirationBehavior,
42) -> anyhow::Result<String> {
43    let mut headers = Header::new(Algorithm::PS256);
44    headers.typ = Some("JWT".to_string());
45    let headers = URL_SAFE_NO_PAD.encode(serde_json::to_string(&headers).unwrap().as_bytes());
46
47    let now = SystemTime::now();
48    let expiration = now + validity;
49    let payload = Payload {
50        sub: organization_id,
51        exp: format_time(&expiration),
52        nbf: format_time(&now),
53        iss: ISSUER.to_string(),
54        aud: environment_id,
55        iat: format_time(&now),
56        jti: Uuid::new_v4().to_string(),
57        version: VERSION,
58        max_credit_consumption_rate,
59        allow_credit_consumption_override,
60        expiration_behavior,
61    };
62    let payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap().as_bytes());
63
64    let signing_string = format!("{}.{}", headers, payload);
65    let signature = URL_SAFE_NO_PAD.encode(sign(client, key_id, signing_string.as_bytes()).await?);
66
67    Ok(format!("{}.{}", signing_string, signature))
68}
69
70async fn get_pubkey(client: &aws_sdk_kms::Client, key_id: &str) -> anyhow::Result<Vec<u8>> {
71    if let Some(pubkey) = client
72        .get_public_key()
73        .key_id(key_id)
74        .send()
75        .await?
76        .public_key
77    {
78        Ok(pubkey.into_inner())
79    } else {
80        Err(anyhow!("failed to get pubkey"))
81    }
82}
83
84async fn sign(
85    client: &aws_sdk_kms::Client,
86    key_id: &str,
87    message: &[u8],
88) -> anyhow::Result<Vec<u8>> {
89    let mut hasher = Sha256::new();
90    hasher.update(message);
91    let digest = hasher.finalize().to_vec();
92
93    if let Some(sig) = client
94        .sign()
95        .key_id(key_id)
96        .signing_algorithm(SigningAlgorithmSpec::RsassaPssSha256)
97        .message_type(MessageType::Digest)
98        .message(Blob::new(digest))
99        .send()
100        .await?
101        .signature
102    {
103        Ok(sig.into_inner())
104    } else {
105        Err(anyhow!("failed to get signature"))
106    }
107}
108
/// Converts `t` to whole seconds since the UNIX epoch, truncating any
/// sub-second part. The result is used for the `exp`/`nbf`/`iat` claims.
///
/// # Panics
///
/// Panics if `t` is earlier than [`UNIX_EPOCH`]. The callers in this file only
/// pass the current time or a later instant, so reaching this means the system
/// clock is badly misconfigured.
fn format_time(t: &SystemTime) -> u64 {
    t.duration_since(UNIX_EPOCH)
        .expect("license key timestamp is before the UNIX epoch")
        .as_secs()
}