// reqwest_middleware/client.rs

1use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
2use reqwest::multipart::Form;
3use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
4use serde::Serialize;
5use std::convert::TryFrom;
6use std::fmt::{self, Display};
7use std::sync::Arc;
8use task_local_extensions::Extensions;
9
10use crate::error::Result;
11use crate::middleware::{Middleware, Next};
12use crate::RequestInitialiser;
13
14/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
15///
16/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
17pub struct ClientBuilder {
18    client: Client,
19    middleware_stack: Vec<Arc<dyn Middleware>>,
20    initialiser_stack: Vec<Arc<dyn RequestInitialiser>>,
21}
22
23impl ClientBuilder {
24    pub fn new(client: Client) -> Self {
25        ClientBuilder {
26            client,
27            middleware_stack: Vec::new(),
28            initialiser_stack: Vec::new(),
29        }
30    }
31
32    /// Convenience method to attach middleware.
33    ///
34    /// If you need to keep a reference to the middleware after attaching, use [`with_arc`].
35    ///
36    /// [`with_arc`]: Self::with_arc
37    pub fn with<M>(self, middleware: M) -> Self
38    where
39        M: Middleware,
40    {
41        self.with_arc(Arc::new(middleware))
42    }
43
44    /// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`.
45    ///
46    /// [`with`]: Self::with
47    pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self {
48        self.middleware_stack.push(middleware);
49        self
50    }
51
52    /// Convenience method to attach a request initialiser.
53    ///
54    /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`].
55    ///
56    /// [`with_arc_init`]: Self::with_arc_init
57    pub fn with_init<I>(self, initialiser: I) -> Self
58    where
59        I: RequestInitialiser,
60    {
61        self.with_arc_init(Arc::new(initialiser))
62    }
63
64    /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
65    ///
66    /// [`with_init`]: Self::with_init
67    pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
68        self.initialiser_stack.push(initialiser);
69        self
70    }
71
72    /// Returns a `ClientWithMiddleware` using this builder configuration.
73    pub fn build(self) -> ClientWithMiddleware {
74        ClientWithMiddleware {
75            inner: self.client,
76            middleware_stack: self.middleware_stack.into_boxed_slice(),
77            initialiser_stack: self.initialiser_stack.into_boxed_slice(),
78        }
79    }
80}
81
82/// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every
83/// request.
84#[derive(Clone)]
85pub struct ClientWithMiddleware {
86    inner: reqwest::Client,
87    middleware_stack: Box<[Arc<dyn Middleware>]>,
88    initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
89}
90
91impl ClientWithMiddleware {
92    /// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances.
93    pub fn new<T>(client: Client, middleware_stack: T) -> Self
94    where
95        T: Into<Box<[Arc<dyn Middleware>]>>,
96    {
97        ClientWithMiddleware {
98            inner: client,
99            middleware_stack: middleware_stack.into(),
100            // TODO(conradludgate) - allow downstream code to control this manually if desired
101            initialiser_stack: Box::new([]),
102        }
103    }
104
105    /// See [`Client::get`]
106    pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
107        self.request(Method::GET, url)
108    }
109
110    /// See [`Client::post`]
111    pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
112        self.request(Method::POST, url)
113    }
114
115    /// See [`Client::put`]
116    pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder {
117        self.request(Method::PUT, url)
118    }
119
120    /// See [`Client::patch`]
121    pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder {
122        self.request(Method::PATCH, url)
123    }
124
125    /// See [`Client::delete`]
126    pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder {
127        self.request(Method::DELETE, url)
128    }
129
130    /// See [`Client::head`]
131    pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
132        self.request(Method::HEAD, url)
133    }
134
135    /// See [`Client::request`]
136    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
137        let req = RequestBuilder {
138            inner: self.inner.request(method, url),
139            client: self.clone(),
140            extensions: Extensions::new(),
141        };
142        self.initialiser_stack
143            .iter()
144            .fold(req, |req, i| i.init(req))
145    }
146
147    /// See [`Client::execute`]
148    pub async fn execute(&self, req: Request) -> Result<Response> {
149        let mut ext = Extensions::new();
150        self.execute_with_extensions(req, &mut ext).await
151    }
152
153    /// Executes a request with initial [`Extensions`].
154    pub async fn execute_with_extensions(
155        &self,
156        req: Request,
157        ext: &mut Extensions,
158    ) -> Result<Response> {
159        let next = Next::new(&self.inner, &self.middleware_stack);
160        next.run(req, ext).await
161    }
162}
163
164/// Create a `ClientWithMiddleware` without any middleware.
165impl From<Client> for ClientWithMiddleware {
166    fn from(client: Client) -> Self {
167        ClientWithMiddleware {
168            inner: client,
169            middleware_stack: Box::new([]),
170            initialiser_stack: Box::new([]),
171        }
172    }
173}
174
175impl fmt::Debug for ClientWithMiddleware {
176    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
177        // skipping middleware_stack field for now
178        f.debug_struct("ClientWithMiddleware")
179            .field("inner", &self.inner)
180            .finish_non_exhaustive()
181    }
182}
183
184/// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API.
185#[must_use = "RequestBuilder does nothing until you 'send' it"]
186pub struct RequestBuilder {
187    inner: reqwest::RequestBuilder,
188    client: ClientWithMiddleware,
189    extensions: Extensions,
190}
191
192impl RequestBuilder {
193    pub fn header<K, V>(self, key: K, value: V) -> Self
194    where
195        HeaderName: TryFrom<K>,
196        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
197        HeaderValue: TryFrom<V>,
198        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
199    {
200        RequestBuilder {
201            inner: self.inner.header(key, value),
202            ..self
203        }
204    }
205
206    pub fn headers(self, headers: HeaderMap) -> Self {
207        RequestBuilder {
208            inner: self.inner.headers(headers),
209            ..self
210        }
211    }
212
213    #[cfg(not(target_arch = "wasm32"))]
214    pub fn version(self, version: reqwest::Version) -> Self {
215        RequestBuilder {
216            inner: self.inner.version(version),
217            ..self
218        }
219    }
220
221    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
222    where
223        U: Display,
224        P: Display,
225    {
226        RequestBuilder {
227            inner: self.inner.basic_auth(username, password),
228            ..self
229        }
230    }
231
232    pub fn bearer_auth<T>(self, token: T) -> Self
233    where
234        T: Display,
235    {
236        RequestBuilder {
237            inner: self.inner.bearer_auth(token),
238            ..self
239        }
240    }
241
242    pub fn body<T: Into<Body>>(self, body: T) -> Self {
243        RequestBuilder {
244            inner: self.inner.body(body),
245            ..self
246        }
247    }
248
249    #[cfg(not(target_arch = "wasm32"))]
250    pub fn timeout(self, timeout: std::time::Duration) -> Self {
251        RequestBuilder {
252            inner: self.inner.timeout(timeout),
253            ..self
254        }
255    }
256
257    pub fn multipart(self, multipart: Form) -> Self {
258        RequestBuilder {
259            inner: self.inner.multipart(multipart),
260            ..self
261        }
262    }
263
264    pub fn query<T: Serialize + ?Sized>(self, query: &T) -> Self {
265        RequestBuilder {
266            inner: self.inner.query(query),
267            ..self
268        }
269    }
270
271    pub fn form<T: Serialize + ?Sized>(self, form: &T) -> Self {
272        RequestBuilder {
273            inner: self.inner.form(form),
274            ..self
275        }
276    }
277
278    pub fn json<T: Serialize + ?Sized>(self, json: &T) -> Self {
279        RequestBuilder {
280            inner: self.inner.json(json),
281            ..self
282        }
283    }
284
285    pub fn build(self) -> reqwest::Result<Request> {
286        self.inner.build()
287    }
288
289    /// Inserts the extension into this request builder
290    pub fn with_extension<T: Send + Sync + 'static>(mut self, extension: T) -> Self {
291        self.extensions.insert(extension);
292        self
293    }
294
295    /// Returns a mutable reference to the internal set of extensions for this request
296    pub fn extensions(&mut self) -> &mut Extensions {
297        &mut self.extensions
298    }
299
300    pub async fn send(self) -> Result<Response> {
301        let Self {
302            inner,
303            client,
304            mut extensions,
305        } = self;
306        let req = inner.build()?;
307        client.execute_with_extensions(req, &mut extensions).await
308    }
309
310    /// Attempt to clone the RequestBuilder.
311    ///
312    /// `None` is returned if the RequestBuilder can not be cloned,
313    /// i.e. if the request body is a stream.
314    ///
315    /// # Extensions
316    /// Note that extensions are not preserved through cloning.
317    pub fn try_clone(&self) -> Option<Self> {
318        self.inner.try_clone().map(|inner| RequestBuilder {
319            inner,
320            client: self.client.clone(),
321            extensions: Extensions::new(),
322        })
323    }
324}
325
326impl fmt::Debug for RequestBuilder {
327    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
328        // skipping middleware_stack field for now
329        f.debug_struct("RequestBuilder")
330            .field("inner", &self.inner)
331            .finish_non_exhaustive()
332    }
333}