use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::multipart::Form;
use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize;
use std::convert::TryFrom;
use std::fmt::{self, Display};
use std::sync::Arc;
use task_local_extensions::Extensions;
use crate::error::Result;
use crate::middleware::{Middleware, Next};
use crate::RequestInitialiser;
/// Builder that accumulates middleware and request initialisers around a plain
/// `reqwest::Client`.
///
/// Call `build` to freeze the collected stacks into a [`ClientWithMiddleware`].
// NOTE: field order is intentionally left unchanged — struct fields are dropped in
// declaration order, so reordering them would alter drop order.
pub struct ClientBuilder {
/// The wrapped reqwest client; becomes `ClientWithMiddleware::inner` on `build`.
client: Client,
/// Middleware in the order they were registered via `with` / `with_arc`.
middleware_stack: Vec<Arc<dyn Middleware>>,
/// Request initialisers in the order they were registered via `with_init` /
/// `with_arc_init`.
initialiser_stack: Vec<Arc<dyn RequestInitialiser>>,
}
impl ClientBuilder {
    /// Starts a builder around `client` with empty middleware and initialiser stacks.
    pub fn new(client: Client) -> Self {
        Self {
            client,
            middleware_stack: Vec::new(),
            initialiser_stack: Vec::new(),
        }
    }

    /// Appends `middleware` to the end of the middleware stack.
    ///
    /// The value is moved into an `Arc`; use [`ClientBuilder::with_arc`] to share an
    /// instance that is already reference-counted.
    pub fn with<M: Middleware>(self, middleware: M) -> Self {
        let shared: Arc<dyn Middleware> = Arc::new(middleware);
        self.with_arc(shared)
    }

    /// Appends an already shared middleware to the end of the middleware stack.
    pub fn with_arc(self, middleware: Arc<dyn Middleware>) -> Self {
        let mut builder = self;
        builder.middleware_stack.push(middleware);
        builder
    }

    /// Appends `initialiser` to the end of the initialiser stack.
    ///
    /// Initialisers are applied, in registration order, to every `RequestBuilder`
    /// created by `ClientWithMiddleware::request`.
    pub fn with_init<I: RequestInitialiser>(self, initialiser: I) -> Self {
        let shared: Arc<dyn RequestInitialiser> = Arc::new(initialiser);
        self.with_arc_init(shared)
    }

    /// Appends an already shared initialiser to the end of the initialiser stack.
    pub fn with_arc_init(self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
        let mut builder = self;
        builder.initialiser_stack.push(initialiser);
        builder
    }

    /// Consumes the builder and produces a [`ClientWithMiddleware`].
    ///
    /// Both stacks are converted into boxed slices, preserving their order.
    pub fn build(self) -> ClientWithMiddleware {
        let Self {
            client,
            middleware_stack,
            initialiser_stack,
        } = self;
        ClientWithMiddleware {
            inner: client,
            middleware_stack: middleware_stack.into(),
            initialiser_stack: initialiser_stack.into(),
        }
    }
}
/// A `reqwest::Client` wrapper that sends every executed request through a
/// middleware stack.
///
/// Instances are normally created with [`ClientBuilder`] or `From<Client>`.
// NOTE: field order is intentionally left unchanged — struct fields are dropped in
// declaration order.
// NOTE(review): `Clone` copies both boxed slices (a new allocation per non-empty
// stack; the middleware themselves are shared via `Arc`). `request()` clones the
// client for every `RequestBuilder`, so this cost is paid per request.
#[derive(Clone)]
pub struct ClientWithMiddleware {
/// The wrapped reqwest client handed to the middleware chain (`Next::new`).
inner: reqwest::Client,
/// Middleware run by `execute_with_extensions`, in stored order.
middleware_stack: Box<[Arc<dyn Middleware>]>,
/// Initialisers applied, in stored order, by `request` to each new `RequestBuilder`.
initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
}
impl ClientWithMiddleware {
    /// Wraps `client` with the given middleware stack.
    ///
    /// No request initialisers are installed; use [`ClientBuilder`] for that.
    pub fn new<T: Into<Box<[Arc<dyn Middleware>]>>>(client: Client, middleware_stack: T) -> Self {
        let middleware: Box<[Arc<dyn Middleware>]> = middleware_stack.into();
        Self {
            inner: client,
            middleware_stack: middleware,
            initialiser_stack: Box::new([]),
        }
    }

    /// Starts a `GET` request; shorthand for `request(Method::GET, url)`.
    pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
        let method = Method::GET;
        self.request(method, url)
    }

    /// Starts a `POST` request; shorthand for `request(Method::POST, url)`.
    pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
        let method = Method::POST;
        self.request(method, url)
    }

    /// Starts a `PUT` request; shorthand for `request(Method::PUT, url)`.
    pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder {
        let method = Method::PUT;
        self.request(method, url)
    }

    /// Starts a `PATCH` request; shorthand for `request(Method::PATCH, url)`.
    pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder {
        let method = Method::PATCH;
        self.request(method, url)
    }

    /// Starts a `DELETE` request; shorthand for `request(Method::DELETE, url)`.
    pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder {
        let method = Method::DELETE;
        self.request(method, url)
    }

    /// Starts a `HEAD` request; shorthand for `request(Method::HEAD, url)`.
    pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
        let method = Method::HEAD;
        self.request(method, url)
    }

    /// Starts a request with an arbitrary `method` and `url`.
    ///
    /// The returned builder holds a clone of this client and an empty `Extensions`
    /// map. It has already been passed through every registered initialiser, in
    /// stored order.
    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
        let mut builder = RequestBuilder {
            inner: self.inner.request(method, url),
            client: self.clone(),
            extensions: Extensions::new(),
        };
        // Each initialiser consumes the builder and hands back a (possibly modified) one.
        for initialiser in self.initialiser_stack.iter() {
            builder = initialiser.init(builder);
        }
        builder
    }

    /// Runs an already-built `req` through the middleware chain with a fresh, empty
    /// `Extensions` map.
    pub async fn execute(&self, req: Request) -> Result<Response> {
        let mut fresh_extensions = Extensions::new();
        self.execute_with_extensions(req, &mut fresh_extensions).await
    }

    /// Runs `req` through the middleware chain using caller-supplied extensions.
    ///
    /// `ext` is borrowed mutably, so anything middleware store in it is still
    /// visible to the caller once this returns.
    pub async fn execute_with_extensions(
        &self,
        req: Request,
        ext: &mut Extensions,
    ) -> Result<Response> {
        let chain = Next::new(&self.inner, &self.middleware_stack);
        chain.run(req, ext).await
    }
}
impl From<Client> for ClientWithMiddleware {
fn from(client: Client) -> Self {
ClientWithMiddleware {
inner: client,
middleware_stack: Box::new([]),
initialiser_stack: Box::new([]),
}
}
}
/// Debug output shows only the inner reqwest client.
///
/// The middleware and initialiser stacks are omitted, which is signalled by the
/// trailing `..` from `finish_non_exhaustive`.
impl fmt::Debug for ClientWithMiddleware {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = formatter.debug_struct("ClientWithMiddleware");
        out.field("inner", &self.inner);
        out.finish_non_exhaustive()
    }
}
/// A request under construction, bound to the `ClientWithMiddleware` that created it.
///
/// Wraps a `reqwest::RequestBuilder` together with:
/// - the owning client, so `send` can run the middleware chain;
/// - an `Extensions` map that is handed to that chain on `send`.
// NOTE: field order is intentionally left unchanged — struct fields are dropped in
// declaration order.
#[must_use = "RequestBuilder does nothing until you 'send' it"]
pub struct RequestBuilder {
/// The underlying reqwest builder that every setter forwards to.
inner: reqwest::RequestBuilder,
/// Client whose middleware stack `send` executes the built request through.
client: ClientWithMiddleware,
/// Per-request extensions passed mutably to the middleware chain by `send`.
extensions: Extensions,
}
impl RequestBuilder {
    /// Applies `f` to the wrapped `reqwest::RequestBuilder`.
    ///
    /// The client and extensions are left untouched. Every forwarding setter below
    /// is expressed through this helper.
    fn map_inner<F>(self, f: F) -> Self
    where
        F: FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
    {
        let RequestBuilder {
            inner,
            client,
            extensions,
        } = self;
        RequestBuilder {
            inner: f(inner),
            client,
            extensions,
        }
    }

    /// Adds a single header; forwards to `reqwest::RequestBuilder::header`.
    ///
    /// Per reqwest, an invalid name or value surfaces as an error when the request
    /// is built (`build` / `send`).
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        HeaderValue: TryFrom<V>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        self.map_inner(|inner| inner.header(key, value))
    }

    /// Merges a whole `HeaderMap` into the request.
    pub fn headers(self, headers: HeaderMap) -> Self {
        self.map_inner(|inner| inner.headers(headers))
    }

    /// Sets the HTTP version (not available on `wasm32`).
    #[cfg(not(target_arch = "wasm32"))]
    pub fn version(self, version: reqwest::Version) -> Self {
        self.map_inner(|inner| inner.version(version))
    }

    /// Sets HTTP Basic authentication; the password is optional.
    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
    where
        U: Display,
        P: Display,
    {
        self.map_inner(|inner| inner.basic_auth(username, password))
    }

    /// Sets bearer-token authentication.
    pub fn bearer_auth<T>(self, token: T) -> Self
    where
        T: Display,
    {
        self.map_inner(|inner| inner.bearer_auth(token))
    }

    /// Sets the request body.
    pub fn body<T: Into<Body>>(self, body: T) -> Self {
        self.map_inner(|inner| inner.body(body))
    }

    /// Sets a per-request timeout (not available on `wasm32`).
    #[cfg(not(target_arch = "wasm32"))]
    pub fn timeout(self, timeout: std::time::Duration) -> Self {
        self.map_inner(|inner| inner.timeout(timeout))
    }

    /// Sets a multipart form body.
    pub fn multipart(self, multipart: Form) -> Self {
        self.map_inner(|inner| inner.multipart(multipart))
    }

    /// Serializes `query` into the URL's query string.
    pub fn query<T: Serialize + ?Sized>(self, query: &T) -> Self {
        self.map_inner(|inner| inner.query(query))
    }

    /// Serializes `form` as a URL-encoded form body.
    pub fn form<T: Serialize + ?Sized>(self, form: &T) -> Self {
        self.map_inner(|inner| inner.form(form))
    }

    /// Serializes `json` as a JSON body.
    pub fn json<T: Serialize + ?Sized>(self, json: &T) -> Self {
        self.map_inner(|inner| inner.json(json))
    }

    /// Builds the underlying `reqwest::Request` without sending it.
    ///
    /// The client and any extensions collected so far are discarded.
    pub fn build(self) -> reqwest::Result<Request> {
        let RequestBuilder { inner, .. } = self;
        inner.build()
    }

    /// Stores `extension` in this request's extensions map and returns the builder.
    pub fn with_extension<T: Send + Sync + 'static>(self, extension: T) -> Self {
        let mut builder = self;
        builder.extensions.insert(extension);
        builder
    }

    /// Gives mutable access to this request's extensions map.
    pub fn extensions(&mut self) -> &mut Extensions {
        &mut self.extensions
    }

    /// Builds the request and runs it through the client's middleware chain.
    ///
    /// The builder's extensions are passed to the chain.
    ///
    /// # Errors
    /// A failure to build the request is propagated with `?`, converted into the
    /// crate's error type. Any error from the middleware chain is returned as-is.
    pub async fn send(self) -> Result<Response> {
        let RequestBuilder {
            inner,
            client,
            mut extensions,
        } = self;
        let request = inner.build()?;
        client
            .execute_with_extensions(request, &mut extensions)
            .await
    }

    /// Attempts to duplicate this builder.
    ///
    /// Returns `None` when `reqwest::RequestBuilder::try_clone` does (per reqwest,
    /// e.g. when the body is a stream).
    ///
    /// NOTE: the clone starts with an EMPTY extensions map — extensions set on
    /// `self` are not carried over.
    pub fn try_clone(&self) -> Option<Self> {
        let inner = self.inner.try_clone()?;
        Some(RequestBuilder {
            inner,
            client: self.client.clone(),
            extensions: Extensions::new(),
        })
    }
}
/// Debug output shows only the wrapped reqwest builder.
///
/// The client and extensions are omitted, which is signalled by the trailing `..`
/// from `finish_non_exhaustive`.
impl fmt::Debug for RequestBuilder {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = formatter.debug_struct("RequestBuilder");
        out.field("inner", &self.inner);
        out.finish_non_exhaustive()
    }
}