azure_storage/authorization/
authorization_policy.rs
use crate::{clients::ServiceType, StorageCredentials, StorageCredentialsInner};
use azure_core::{
auth::Secret,
error::{ErrorKind, ResultExt},
headers::*,
hmac::hmac_sha256,
Context, Method, Policy, PolicyResult, Request, Url,
};
use std::{borrow::Cow, ops::Deref, sync::Arc};
use tracing::trace;
/// Scope passed to `get_token` when the policy's credentials are
/// `StorageCredentialsInner::TokenCredential`; the resulting token is sent as a
/// `Bearer` authorization header.
const STORAGE_TOKEN_SCOPE: &str = "https://storage.azure.com/.default";
/// Pipeline [`Policy`] that attaches storage credentials (Shared Key signature, SAS
/// query parameters or a bearer token) to each outgoing request before handing it to
/// the next policy.
#[derive(Clone, Debug)]
pub struct AuthorizationPolicy {
    /// Credentials applied to every request; read (behind a lock) on each `send`.
    credentials: StorageCredentials,
}

impl AuthorizationPolicy {
    /// Creates a policy that authorizes requests with `credentials`.
    pub(crate) fn new(credentials: StorageCredentials) -> Self {
        AuthorizationPolicy { credentials }
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl Policy for AuthorizationPolicy {
    /// Applies the configured credentials to `request` and forwards it to `next[0]`.
    ///
    /// * `Key`: adds a `SharedKey` `Authorization` header, unless the URL already
    ///   carries a `sig` query parameter.
    /// * `SASToken`: appends the SAS query pairs, unless a `sig` is already present,
    ///   so re-sending the same request does not duplicate them.
    /// * `BearerToken` / `TokenCredential`: sets a `Bearer` `Authorization` header.
    /// * `Anonymous`: leaves the request unchanged.
    ///
    /// # Panics
    /// If `next` is empty, or if a `Key` credential is used and no `ServiceType` is
    /// stored in `ctx`.
    ///
    /// # Errors
    /// Propagates signing failures and token-acquisition failures (as
    /// `ErrorKind::Credential`), plus any error from the rest of the pipeline.
    async fn send(
        &self,
        ctx: &Context,
        request: &mut Request,
        next: &[Arc<dyn Policy>],
    ) -> PolicyResult {
        // NOTE(review): this logs `self`, including the credentials, at trace level;
        // confirm the credentials' Debug output redacts secrets (e.g. a SAS `sig`).
        trace!("called AuthorizationPolicy::send. self == {:#?}", self);
        assert!(
            !next.is_empty(),
            "Authorization policies cannot be the last policy of a pipeline"
        );

        // The read guard is scoped to this block so it is dropped before the request
        // moves on to the remaining policies.
        {
            let guard = self.credentials.0.read().await;
            // A `sig` query parameter means the URL is already SAS-authorized.
            let already_signed = request.url().query_pairs().any(|(name, _)| name == "sig");

            match guard.deref() {
                StorageCredentialsInner::Key(account, key) if !already_signed => {
                    let service_type = *ctx
                        .get()
                        .expect("ServiceType must be in the Context at this point");
                    let signature = generate_authorization(
                        request.headers(),
                        request.url(),
                        *request.method(),
                        account,
                        key,
                        service_type,
                    )?;
                    request.insert_header(AUTHORIZATION, signature);
                }
                StorageCredentialsInner::SASToken(query_pairs) if !already_signed => {
                    request
                        .url_mut()
                        .query_pairs_mut()
                        .extend_pairs(query_pairs);
                }
                StorageCredentialsInner::BearerToken(token) => {
                    request.insert_header(AUTHORIZATION, format!("Bearer {}", token.secret()));
                }
                StorageCredentialsInner::TokenCredential(token_credential) => {
                    let bearer_token = token_credential
                        .get_token(&[STORAGE_TOKEN_SCOPE])
                        .await
                        .context(ErrorKind::Credential, "failed to get bearer token")?;
                    request.insert_header(
                        AUTHORIZATION,
                        format!("Bearer {}", bearer_token.token.secret()),
                    );
                }
                // Already-signed Key/SAS requests and anonymous access need no changes.
                StorageCredentialsInner::Key(..)
                | StorageCredentialsInner::SASToken(..)
                | StorageCredentialsInner::Anonymous => {}
            }
        }

        next[0].send(ctx, request, &next[1..]).await
    }
}
fn generate_authorization(
h: &Headers,
u: &Url,
method: Method,
account: &str,
key: &Secret,
service_type: ServiceType,
) -> azure_core::Result<String> {
let str_to_sign = string_to_sign(h, u, method, account, service_type);
let auth = hmac_sha256(&str_to_sign, key).context(
azure_core::error::ErrorKind::Credential,
"failed to sign the hmac",
)?;
Ok(format!("SharedKey {account}:{auth}"))
}
/// Returns the value of header `key`, or an empty string when it is absent; absent
/// headers contribute an empty line to the string-to-sign.
fn add_if_exists<'a>(h: &'a Headers, key: &HeaderName) -> &'a str {
    match h.get_optional_str(key) {
        Some(value) => value,
        None => "",
    }
}
#[allow(unknown_lints)]
/// Builds the Shared Key string-to-sign for a request.
///
/// * Table service: `VERB\nContent-MD5\nContent-Type\nx-ms-date\nCanonicalizedResource`.
/// * Other services: the verb and eleven standard headers, each followed by `\n`, then
///   the canonicalized `x-ms-*` headers and the canonicalized resource. A
///   `Content-Length` of `0` is signed as an empty string.
fn string_to_sign(
    h: &Headers,
    u: &Url,
    method: Method,
    account: &str,
    service_type: ServiceType,
) -> String {
    let verb: &str = method.as_ref();

    if let ServiceType::Table = service_type {
        let resource = canonicalized_resource_table(account, u);
        return [
            verb,
            add_if_exists(h, &CONTENT_MD5),
            add_if_exists(h, &CONTENT_TYPE),
            add_if_exists(h, &MS_DATE),
            resource.as_str(),
        ]
        .join("\n");
    }

    let content_length = match h.get_optional_str(&CONTENT_LENGTH) {
        Some("0") | None => "",
        Some(length) => length,
    };
    let standard_headers = [
        verb,
        add_if_exists(h, &CONTENT_ENCODING),
        add_if_exists(h, &CONTENT_LANGUAGE),
        content_length,
        add_if_exists(h, &CONTENT_MD5),
        add_if_exists(h, &CONTENT_TYPE),
        add_if_exists(h, &DATE),
        add_if_exists(h, &IF_MODIFIED_SINCE),
        add_if_exists(h, &IF_MATCH),
        add_if_exists(h, &IF_NONE_MATCH),
        add_if_exists(h, &IF_UNMODIFIED_SINCE),
        add_if_exists(h, &RANGE),
    ];

    let mut signed = String::new();
    for value in standard_headers {
        signed.push_str(value);
        signed.push('\n');
    }
    // The canonicalized header block already ends each entry with its own newline.
    signed.push_str(&canonicalize_header(h));
    signed.push_str(&canonicalized_resource(account, u));
    signed
}
/// Builds the `CanonicalizedHeaders` part of the Shared Key string-to-sign.
///
/// Every header whose name starts with `x-ms-` is emitted as `name:value\n`, ordered
/// by header name. Returns an empty string when there are no such headers.
///
/// NOTE(review): the prefix test is case-sensitive, so it assumes header names are
/// stored lowercase — confirm against `HeaderName`'s normalization.
fn canonicalize_header(headers: &Headers) -> String {
    // Take each value straight from the iterator instead of looking it up again.
    let mut ms_headers: Vec<(&str, &str)> = headers
        .iter()
        .map(|(name, value)| (name.as_str(), value.as_str()))
        // The spec selects headers beginning with `x-ms-` (with the dash).
        .filter(|(name, _)| name.starts_with("x-ms-"))
        .collect();
    ms_headers.sort_unstable_by_key(|&(name, _)| name);

    // Append in place; rebuilding the string with `format!` each time was quadratic.
    let mut result = String::new();
    for (name, value) in ms_headers {
        result.push_str(name);
        result.push(':');
        result.push_str(value);
        result.push('\n');
    }
    result
}
fn canonicalized_resource_table(account: &str, u: &Url) -> String {
format!("/{}{}", account, u.path())
}
fn canonicalized_resource(account: &str, uri: &Url) -> String {
let mut can_res: String = String::new();
can_res += "/";
can_res += account;
for p in uri.path_segments().into_iter().flatten() {
can_res.push('/');
can_res.push_str(p);
}
can_res += "\n";
let query_pairs = uri.query_pairs();
{
let mut qps: Vec<String> = Vec::new();
for (q, _) in query_pairs {
if !(qps.iter().any(|x| x == &*q)) {
qps.push(q.into_owned());
}
}
qps.sort();
for qparam in qps {
let ret = lexy_sort(query_pairs, &qparam);
can_res = can_res + &qparam.to_lowercase() + ":";
for (i, item) in ret.iter().enumerate() {
if i > 0 {
can_res += ",";
}
can_res += item;
}
can_res += "\n";
}
};
can_res[0..can_res.len() - 1].to_owned()
}
/// Collects every value whose key equals `query_param` (exact, case-sensitive match)
/// and returns those values sorted lexicographically. Keys that never match yield an
/// empty vector.
fn lexy_sort<'a>(
    vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = vec
        .filter_map(|(name, value)| (name == query_param).then_some(value))
        .collect();
    matching.sort_unstable();
    matching
}
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::{BytesStream, Response};

    /// Terminal mock policy: asserts that the outgoing URL carries exactly one `sig`
    /// query parameter, then answers with an empty `202 Accepted` response.
    #[derive(Debug, Clone)]
    struct AssertSigHeaderUniqueMockPolicy;

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl Policy for AssertSigHeaderUniqueMockPolicy {
        async fn send(
            &self,
            _ctx: &Context,
            request: &mut Request,
            _next: &[Arc<dyn Policy>],
        ) -> PolicyResult {
            let sig_count = request
                .url()
                .query_pairs()
                .filter(|(name, _)| *name == "sig")
                .count();
            assert_eq!(sig_count, 1);
            Ok(Response::new(
                azure_core::StatusCode::Accepted,
                Headers::new(),
                Box::pin(BytesStream::new(vec![])),
            ))
        }
    }

    const SAMPLE_SAS_TOKEN: &str = "sp=r&st=1970-01-01T00:00:00Z&se=1970-01-01T00:00:00Z&spr=https&sv=1970-01-01&sr=c&sig=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";

    /// Builds a SAS-token authorization policy and a bare `GET https://example.com`
    /// request for it to authorize.
    fn sas_policy_and_request() -> (AuthorizationPolicy, Request) {
        let credentials = StorageCredentials::sas_token(SAMPLE_SAS_TOKEN).unwrap();
        let request = Request::new(Url::parse("https://example.com").unwrap(), Method::Get);
        (AuthorizationPolicy::new(credentials), request)
    }

    #[tokio::test]
    async fn authorization_policy_applies_sas_token() {
        let (auth_policy, mut request) = sas_policy_and_request();
        let next: [Arc<dyn Policy>; 1] = [Arc::new(AssertSigHeaderUniqueMockPolicy)];
        auth_policy
            .send(&Context::default(), &mut request, &next)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn authorization_policy_with_sas_token_does_not_apply_twice() {
        let ctx = Context::default();
        let (auth_policy, mut request) = sas_policy_and_request();
        let next: [Arc<dyn Policy>; 1] = [Arc::new(AssertSigHeaderUniqueMockPolicy)];
        // Sending the same request twice (as may happen on a retry) must still leave
        // exactly one `sig` parameter on the URL.
        for _ in 0..2 {
            auth_policy.send(&ctx, &mut request, &next).await.unwrap();
        }
    }
}