use crate::http::extensions::Extensions;
use crate::http::Headers;
use crate::http::HttpError;
use aws_smithy_types::body::SdkBody;
use http as http0;
use http0::uri::PathAndQuery;
use http0::Method;
use std::borrow::Cow;
/// The pieces of a [`Request`] produced by [`Request::into_parts`].
///
/// Only the URI, headers, and body are carried over; the request's method and
/// extensions are not part of this struct.
#[non_exhaustive]
pub struct RequestParts<B = SdkBody> {
/// The request URI.
pub uri: Uri,
/// The request headers.
pub headers: Headers,
/// The request body.
pub body: B,
}
/// An HTTP request that can be converted to and from http 0.2 and http 1.x requests
/// (behind the `http-02x` / `http-1x` features).
///
/// The method is stored as an http 0.2 `Method`; the URI keeps both its string form and
/// whichever parsed representation produced it.
#[derive(Debug)]
pub struct Request<B = SdkBody> {
body: B,
uri: Uri,
method: Method,
// Not preserved by `try_clone` (which starts from `Extensions::new()`) nor by `into_parts`.
extensions: Extensions,
headers: Headers,
}
/// A request URI, kept both as a string and in parsed form.
///
/// `as_string` is the text returned by [`Request::uri`]. It is not always the same as
/// re-serializing `parsed`. A URI built from a string keeps that exact text. A URI
/// converted from an `http` crate `Uri` (or rebuilt by `set_endpoint`) uses that type's
/// `to_string()` output, which may normalize it (e.g. by adding a trailing `/`).
#[derive(Debug, Clone)]
pub struct Uri {
as_string: String,
parsed: ParsedUri,
}
/// A URI parsed by one of the two supported `http` crate versions.
#[derive(Debug, Clone)]
enum ParsedUri {
// Parsed by http 0.2 (imported here as `http0`).
H0(http0::Uri),
// Parsed by http 1.x (`http1`).
H1(http1::Uri),
}
impl ParsedUri {
fn path_and_query(&self) -> &str {
match &self {
ParsedUri::H0(u) => u.path_and_query().map(|pq| pq.as_str()).unwrap_or(""),
ParsedUri::H1(u) => u.path_and_query().map(|pq| pq.as_str()).unwrap_or(""),
}
}
fn path(&self) -> &str {
match &self {
ParsedUri::H0(u) => u.path(),
ParsedUri::H1(u) => u.path(),
}
}
fn query(&self) -> Option<&str> {
match &self {
ParsedUri::H0(u) => u.query(),
ParsedUri::H1(u) => u.query(),
}
}
}
impl Uri {
    /// Re-targets this URI at `endpoint`.
    ///
    /// The scheme and authority come from `endpoint`. If `endpoint` has a path, that path
    /// is prefixed onto this URI's existing path and query (see `merge_paths`). A query
    /// string in `endpoint` itself is ignored, with a warning.
    ///
    /// # Errors
    /// Returns an error if `endpoint` is not a valid URI, has no authority or no scheme,
    /// or if the merged URI cannot be built.
    pub fn set_endpoint(&mut self, endpoint: &str) -> Result<(), HttpError> {
        let endpoint: http0::Uri = endpoint.parse().map_err(HttpError::invalid_uri)?;
        let endpoint = endpoint.into_parts();
        let authority = endpoint
            .authority
            .ok_or_else(|| HttpError::new("endpoint must contain authority"))?;
        let scheme = endpoint
            .scheme
            .ok_or_else(|| HttpError::new("endpoint must have scheme"))?;
        let new_uri = http0::Uri::builder()
            .authority(authority)
            .scheme(scheme)
            .path_and_query(merge_paths(endpoint.path_and_query, &self.parsed).as_ref())
            .build()
            .map_err(HttpError::new)?;
        // The rebuilt URI is always stored as an http 0.2 URI, whatever it was before.
        self.as_string = new_uri.to_string();
        self.parsed = ParsedUri::H0(new_uri);
        Ok(())
    }

    /// Returns the path component of this URI.
    pub fn path(&self) -> &str {
        self.parsed.path()
    }

    /// Returns the query component of this URI, if any.
    pub fn query(&self) -> Option<&str> {
        self.parsed.query()
    }

    /// Wraps an http 0.2 URI, recording its `to_string()` form.
    fn from_http0x_uri(uri: http0::Uri) -> Self {
        Self {
            as_string: uri.to_string(),
            parsed: ParsedUri::H0(uri),
        }
    }

    /// Wraps an http 1.x URI, recording its `to_string()` form.
    // Only called from the `http-1x`-gated conversion, so it is unused when that feature is off.
    #[allow(dead_code)]
    fn from_http1x_uri(uri: http1::Uri) -> Self {
        Self {
            as_string: uri.to_string(),
            parsed: ParsedUri::H1(uri),
        }
    }

    /// Converts this URI into an http 0.2 URI.
    ///
    /// An http 1.x URI is converted by re-parsing its string form with http 0.2.
    ///
    /// # Panics
    /// Panics if http 0.2 rejects the string form of an http 1.x URI.
    // Only called from the `http-02x`-gated conversion, so it is unused when that feature is off.
    #[allow(dead_code)]
    fn into_h0(self) -> http0::Uri {
        match self.parsed {
            ParsedUri::H0(uri) => uri,
            // NOTE(review): assumes the http 0.2 and http 1.x URI grammars agree, so text that
            // http 1.x produced is accepted by http 0.2 — confirm if either crate changes.
            ParsedUri::H1(_uri) => self
                .as_string
                .parse()
                .expect("a URI accepted by http 1.x must also parse with http 0.2"),
        }
    }
}
/// Combines an endpoint's path with a request's path and query.
///
/// With no endpoint path (or an empty one) the request's path and query are returned
/// borrowed and unchanged. Otherwise the endpoint path, minus one trailing `/`, is joined
/// with `/` to the request's path and query, minus one leading `/`. A query on the endpoint
/// is logged and dropped.
fn merge_paths(endpoint_path: Option<PathAndQuery>, uri: &ParsedUri) -> Cow<'_, str> {
    let request_pq = uri.path_and_query();
    let endpoint_pq = match endpoint_path {
        Some(pq) => pq,
        None => return Cow::Borrowed(request_pq),
    };

    if let Some(dropped_query) = endpoint_pq.query() {
        tracing::warn!(query = %dropped_query, "query specified in endpoint will be ignored during endpoint resolution");
    }

    let prefix = endpoint_pq.path();
    if prefix.is_empty() {
        return Cow::Borrowed(request_pq);
    }

    let prefix = prefix.strip_suffix('/').unwrap_or(prefix);
    let suffix = request_pq.strip_prefix('/').unwrap_or(request_pq);
    Cow::Owned(format!("{}/{}", prefix, suffix))
}
impl TryFrom<String> for Uri {
type Error = HttpError;
fn try_from(value: String) -> Result<Self, Self::Error> {
let parsed = ParsedUri::H0(value.parse().map_err(HttpError::invalid_uri)?);
Ok(Uri {
as_string: value,
parsed,
})
}
}
impl<'a> TryFrom<&'a str> for Uri {
type Error = HttpError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Self::try_from(value.to_string())
}
}
#[cfg(feature = "http-02x")]
impl From<http0::Uri> for Uri {
    /// Wraps an http 0.2 URI, recording its `to_string()` form.
    fn from(uri: http0::Uri) -> Self {
        Self::from_http0x_uri(uri)
    }
}
#[cfg(feature = "http-02x")]
impl<B> TryInto<http0::Request<B>> for Request<B> {
    type Error = HttpError;

    /// Delegates to [`Request::try_into_http02x`].
    fn try_into(self) -> Result<http0::Request<B>, Self::Error> {
        Request::try_into_http02x(self)
    }
}
#[cfg(feature = "http-1x")]
impl<B> TryInto<http1::Request<B>> for Request<B> {
    type Error = HttpError;

    /// Delegates to [`Request::try_into_http1x`].
    fn try_into(self) -> Result<http1::Request<B>, Self::Error> {
        Request::try_into_http1x(self)
    }
}
impl<B> Request<B> {
    /// Converts this request into an http 0.2 request.
    ///
    /// # Errors
    /// Returns an error if the extensions cannot be converted into http 0.2 extensions.
    #[cfg(feature = "http-02x")]
    pub fn try_into_http02x(self) -> Result<http0::Request<B>, HttpError> {
        let mut req = http::Request::builder()
            .uri(self.uri.into_h0())
            .method(self.method)
            .body(self.body)
            // URI and method are passed as already-parsed http 0.2 values, so the builder
            // has no string input left to reject.
            .expect("known valid");
        *req.headers_mut() = self.headers.http0_headermap();
        *req.extensions_mut() = self.extensions.try_into()?;
        Ok(req)
    }

    /// Converts this request into an http 1.x request.
    ///
    /// # Errors
    /// Returns an error if the extensions cannot be converted into http 1.x extensions.
    ///
    /// # Panics
    /// Panics if the http 1.x builder rejects the stored URI string or method name.
    #[cfg(feature = "http-1x")]
    pub fn try_into_http1x(self) -> Result<http1::Request<B>, HttpError> {
        let mut req = http1::Request::builder()
            .uri(self.uri.as_string)
            .method(self.method.as_str())
            .body(self.body)
            .expect("known valid");
        *req.headers_mut() = self.headers.http1_headermap();
        *req.extensions_mut() = self.extensions.try_into()?;
        Ok(req)
    }

    /// Transforms the body with `f`, keeping the URI, method, extensions, and headers.
    ///
    /// `f` is invoked exactly once, so any `FnOnce` closure is accepted (including closures
    /// that move captured values out).
    pub fn map<U>(self, f: impl FnOnce(B) -> U) -> Request<U> {
        Request {
            body: f(self.body),
            uri: self.uri,
            method: self.method,
            extensions: self.extensions,
            headers: self.headers,
        }
    }

    /// Creates a `GET` request to `/` with the given body and default headers and extensions.
    pub fn new(body: B) -> Self {
        Self {
            body,
            uri: Uri::from_http0x_uri(http0::Uri::from_static("/")),
            method: Method::GET,
            extensions: Default::default(),
            headers: Default::default(),
        }
    }

    /// Splits the request into its URI, headers, and body.
    ///
    /// The method and extensions are discarded.
    pub fn into_parts(self) -> RequestParts<B> {
        RequestParts {
            uri: self.uri,
            headers: self.headers,
            body: self.body,
        }
    }

    /// Returns the request headers.
    pub fn headers(&self) -> &Headers {
        &self.headers
    }

    /// Returns the request headers mutably.
    pub fn headers_mut(&mut self) -> &mut Headers {
        &mut self.headers
    }

    /// Returns the request body.
    pub fn body(&self) -> &B {
        &self.body
    }

    /// Returns the request body mutably.
    pub fn body_mut(&mut self) -> &mut B {
        &mut self.body
    }

    /// Consumes the request and returns its body.
    pub fn into_body(self) -> B {
        self.body
    }

    /// Returns the HTTP method as a string (e.g. `"GET"`).
    pub fn method(&self) -> &str {
        self.method.as_str()
    }

    /// Returns the URI as a string.
    ///
    /// This is the stored string form of the [`Uri`]. When the URI was set from a string,
    /// this is exactly that string, which may differ in normalization from the parsed form.
    pub fn uri(&self) -> &str {
        &self.uri.as_string
    }

    /// Returns the URI mutably, e.g. to call [`Uri::set_endpoint`].
    pub fn uri_mut(&mut self) -> &mut Uri {
        &mut self.uri
    }

    /// Replaces the URI with `uri` converted into a [`Uri`].
    ///
    /// # Errors
    /// Returns the conversion error; the current URI is left unchanged in that case.
    pub fn set_uri<U>(&mut self, uri: U) -> Result<(), U::Error>
    where
        U: TryInto<Uri>,
    {
        let uri = uri.try_into()?;
        self.uri = uri;
        Ok(())
    }

    /// Inserts `extension` into this request's extensions.
    pub fn add_extension<T: Send + Sync + Clone + 'static>(&mut self, extension: T) {
        // `extension` is already owned, so move it in directly; cloning first only made a
        // copy that was immediately dropped.
        self.extensions.insert(extension);
    }
}
impl Request<SdkBody> {
pub fn try_clone(&self) -> Option<Self> {
let body = self.body().try_clone()?;
Some(Self {
body,
uri: self.uri.clone(),
method: self.method.clone(),
extensions: Extensions::new(),
headers: self.headers.clone(),
})
}
pub fn take_body(&mut self) -> SdkBody {
std::mem::replace(self.body_mut(), SdkBody::taken())
}
pub fn empty() -> Self {
Self::new(SdkBody::empty())
}
pub fn get(uri: impl AsRef<str>) -> Result<Self, HttpError> {
let mut req = Self::new(SdkBody::empty());
req.set_uri(uri.as_ref())?;
Ok(req)
}
}
#[cfg(feature = "http-02x")]
impl<B> TryFrom<http0::Request<B>> for Request<B> {
    type Error = HttpError;

    /// Converts an http 0.2 request, failing if its headers cannot be converted.
    fn try_from(value: http0::Request<B>) -> Result<Self, Self::Error> {
        let (head, body) = value.into_parts();
        let headers = Headers::try_from(head.headers)?;
        let request = Self {
            body,
            headers,
            method: head.method,
            uri: Uri::from(head.uri),
            extensions: head.extensions.into(),
        };
        Ok(request)
    }
}
#[cfg(feature = "http-1x")]
impl<B> TryFrom<http1::Request<B>> for Request<B> {
    type Error = HttpError;

    /// Converts an http 1.x request, failing if its headers cannot be converted.
    fn try_from(value: http1::Request<B>) -> Result<Self, Self::Error> {
        let (head, body) = value.into_parts();
        let headers = Headers::try_from(head.headers)?;
        // The stored method is an http 0.2 `Method`, so rebuild it from the http 1.x method's
        // textual form.
        let method = Method::from_bytes(head.method.as_str().as_bytes()).expect("valid");
        Ok(Self {
            uri: Uri::from_http1x_uri(head.uri),
            method,
            body,
            extensions: head.extensions.into(),
            headers,
        })
    }
}
// These tests round-trip between http 0.2 and http 1.x, so they need both features enabled.
#[cfg(all(test, feature = "http-02x", feature = "http-1x"))]
mod test {
    use super::*;
    use aws_smithy_types::body::SdkBody;
    use http::header::{AUTHORIZATION, CONTENT_LENGTH};
    use http::Uri;

    #[test]
    fn non_ascii_requests() {
        let request = http::Request::builder()
            .header("k", "😹")
            .body(SdkBody::empty())
            .unwrap();
        let request: Request = request
            .try_into()
            .expect("failed to convert a non-string header");
        assert_eq!(request.headers().get("k"), Some("😹"))
    }

    #[test]
    fn request_can_be_created() {
        let req = http::Request::builder()
            .uri("http://foo.com")
            .body(SdkBody::from("hello"))
            .unwrap();
        let mut req = super::Request::try_from(req).unwrap();
        req.headers_mut().insert("a", "b");
        assert_eq!(req.headers().get("a").unwrap(), "b");
        req.headers_mut().append("a", "c");
        assert_eq!(req.headers().get("a").unwrap(), "b");
        let http0 = req.try_into_http02x().unwrap();
        assert_eq!(http0.uri(), "http://foo.com");
    }

    #[test]
    fn uri_mutations() {
        let req = http::Request::builder()
            .uri("http://foo.com")
            .body(SdkBody::from("hello"))
            .unwrap();
        let mut req = super::Request::try_from(req).unwrap();
        assert_eq!(req.uri(), "http://foo.com/");
        req.set_uri("http://bar.com").unwrap();
        assert_eq!(req.uri(), "http://bar.com");
        let http0 = req.try_into_http02x().unwrap();
        assert_eq!(http0.uri(), "http://bar.com");
    }

    #[test]
    #[should_panic]
    fn header_panics() {
        let req = http::Request::builder()
            .uri("http://foo.com")
            .body(SdkBody::from("hello"))
            .unwrap();
        let mut req = super::Request::try_from(req).unwrap();
        let _ = req
            .headers_mut()
            .try_insert("a\nb", "a\nb")
            .expect_err("invalid header");
        let _ = req.headers_mut().insert("a\nb", "a\nb");
    }

    #[test]
    fn try_clone_clones_all_data() {
        let request = ::http::Request::builder()
            .uri(Uri::from_static("https://www.amazon.com"))
            .method("POST")
            .header(CONTENT_LENGTH, 456)
            .header(AUTHORIZATION, "Token: hello")
            .body(SdkBody::from("hello world!"))
            .expect("valid request");
        let request: super::Request = request.try_into().unwrap();
        let cloned = request.try_clone().expect("request is cloneable");
        assert_eq!("https://www.amazon.com/", cloned.uri());
        assert_eq!("POST", cloned.method());
        assert_eq!(2, cloned.headers().len());
        assert_eq!("Token: hello", cloned.headers().get(AUTHORIZATION).unwrap());
        assert_eq!("456", cloned.headers().get(CONTENT_LENGTH).unwrap());
        assert_eq!("hello world!".as_bytes(), cloned.body().bytes().unwrap());
    }

    #[test]
    fn valid_round_trips() {
        let request = || {
            http::Request::builder()
                .uri(Uri::from_static("https://www.amazon.com"))
                .method("POST")
                .header(CONTENT_LENGTH, 456)
                .header(AUTHORIZATION, "Token: hello")
                .header("multi", "v1")
                .header("multi", "v2")
                .body(SdkBody::from("hello world!"))
                .expect("valid request")
        };
        check_roundtrip(request);
    }

    /// Asserts that two http requests agree on URI, headers, method, body bytes, and
    /// extension count, labelling each failure with the field that differs.
    macro_rules! req_eq {
        ($a: expr, $b: expr) => {{
            assert_eq!($a.uri(), $b.uri(), "uri mismatch");
            assert_eq!($a.headers(), $b.headers(), "header mismatch");
            assert_eq!($a.method(), $b.method(), "method mismatch");
            assert_eq!($a.body().bytes(), $b.body().bytes(), "data mismatch");
            assert_eq!(
                $a.extensions().len(),
                $b.extensions().len(),
                "extensions size mismatch"
            );
        }};
    }

    /// Converts `req()` through http 1.x and back to http 0.2. At each hop it checks that an
    /// added extension survives, then removes it. Finally it compares the result with a
    /// fresh `req()`.
    #[track_caller]
    fn check_roundtrip(req: impl Fn() -> http0::Request<SdkBody>) {
        let mut container = super::Request::try_from(req()).unwrap();
        container.add_extension(5_u32);
        let mut h1 = container
            .try_into_http1x()
            .expect("failed converting to http1x");
        assert_eq!(h1.extensions().get::<u32>(), Some(&5));
        h1.extensions_mut().remove::<u32>();
        let mut container = super::Request::try_from(h1).expect("failed converting from http1x");
        container.add_extension(5_u32);
        let mut h0 = container
            .try_into_http02x()
            .expect("failed converting back to http0x");
        assert_eq!(h0.extensions().get::<u32>(), Some(&5));
        h0.extensions_mut().remove::<u32>();
        req_eq!(h0, req());
    }
}