1use crate::StatusCode;
2use std::borrow::Cow;
3use std::fmt::{Debug, Display};
4mod http_error;
5mod macros;
6use crate::headers::{self, Headers};
7pub use http_error::HttpError;
8
9use self::http_error::get_error_code_from_header;
10
11pub type Result<T> = std::result::Result<T, Error>;
13
/// The general category of an [`Error`].
///
/// Every [`Error`] carries exactly one `ErrorKind`, retrievable with [`Error::kind`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ErrorKind {
    /// An HTTP response that is being reported as an error.
    HttpResponse {
        /// The status code of the response.
        status: StatusCode,
        /// A service-specific error code, if one was found in the response headers or body
        /// (see [`ErrorKind::http_response_from_parts`]).
        error_code: Option<String>,
    },
    /// An I/O failure; `From<std::io::Error>` for [`Error`] produces this kind.
    Io,
    /// A failure converting data between representations; the base64, JSON, UTF-8 and URL
    /// `From` conversions for [`Error`] produce this kind.
    DataConversion,
    /// A credential-related failure.
    /// NOTE(review): no producer is visible in this module — confirm where it is used.
    Credential,
    /// A failure originating in the mock framework (presumably test tooling — confirm).
    MockFramework,
    /// Any failure that does not fit one of the other kinds.
    Other,
}
35
36impl ErrorKind {
37    pub fn into_error(self) -> Error {
38        Error {
39            context: Context::Simple(self),
40        }
41    }
42
43    pub fn http_response(status: StatusCode, error_code: Option<String>) -> Self {
44        Self::HttpResponse { status, error_code }
45    }
46
47    pub fn http_response_from_parts(status: StatusCode, headers: &Headers, body: &[u8]) -> Self {
50        if let Some(header_err_code) = get_error_code_from_header(headers) {
51            Self::HttpResponse {
52                status,
53                error_code: Some(header_err_code),
54            }
55        } else {
56            let (error_code, _) = http_error::get_error_code_message_from_body(
57                body,
58                headers.get_optional_str(&headers::CONTENT_TYPE),
59            );
60            Self::HttpResponse { status, error_code }
61        }
62    }
63}
64
65impl Display for ErrorKind {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            ErrorKind::HttpResponse { status, error_code } => {
69                write!(
70                    f,
71                    "HttpResponse({},{})",
72                    status,
73                    error_code.as_deref().unwrap_or("unknown")
74                )
75            }
76            ErrorKind::Io => write!(f, "Io"),
77            ErrorKind::DataConversion => write!(f, "DataConversion"),
78            ErrorKind::Credential => write!(f, "Credential"),
79            ErrorKind::MockFramework => write!(f, "MockFramework"),
80            ErrorKind::Other => write!(f, "Other"),
81        }
82    }
83}
84
/// The error type used throughout this crate.
///
/// An `Error` always has an [`ErrorKind`]. Depending on how it was built, it may also carry a
/// human-readable message and/or a wrapped source error.
#[derive(Debug)]
pub struct Error {
    // Which pieces of information this error carries (kind only, message, wrapped error, or both).
    context: Context,
}
90
91impl Error {
92    pub fn new<E>(kind: ErrorKind, error: E) -> Self
94    where
95        E: Into<Box<dyn std::error::Error + Send + Sync>>,
96    {
97        Self {
98            context: Context::Custom(Custom {
99                kind,
100                error: error.into(),
101            }),
102        }
103    }
104
105    #[must_use]
108    pub fn full<E, C>(kind: ErrorKind, error: E, message: C) -> Self
109    where
110        E: Into<Box<dyn std::error::Error + Send + Sync>>,
111        C: Into<Cow<'static, str>>,
112    {
113        Self {
114            context: Context::Full(
115                Custom {
116                    kind,
117                    error: error.into(),
118                },
119                message.into(),
120            ),
121        }
122    }
123
124    #[must_use]
126    pub fn message<C>(kind: ErrorKind, message: C) -> Self
127    where
128        C: Into<Cow<'static, str>>,
129    {
130        Self {
131            context: Context::Message {
132                kind,
133                message: message.into(),
134            },
135        }
136    }
137
138    #[must_use]
140    pub fn with_message<F, C>(kind: ErrorKind, message: F) -> Self
141    where
142        Self: Sized,
143        F: FnOnce() -> C,
144        C: Into<Cow<'static, str>>,
145    {
146        Self {
147            context: Context::Message {
148                kind,
149                message: message().into(),
150            },
151        }
152    }
153
154    #[must_use]
156    pub fn context<C>(self, message: C) -> Self
157    where
158        C: Into<Cow<'static, str>>,
159    {
160        Self::full(self.kind().clone(), self, message)
161    }
162
163    #[must_use]
165    pub fn with_context<F, C>(self, f: F) -> Self
166    where
167        F: FnOnce() -> C,
168        C: Into<Cow<'static, str>>,
169    {
170        self.context(f())
171    }
172
173    pub fn kind(&self) -> &ErrorKind {
175        match &self.context {
176            Context::Simple(kind)
177            | Context::Message { kind, .. }
178            | Context::Custom(Custom { kind, .. })
179            | Context::Full(Custom { kind, .. }, _) => kind,
180        }
181    }
182
183    pub fn into_inner(self) -> std::result::Result<Box<dyn std::error::Error + Send + Sync>, Self> {
185        match self.context {
186            Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
187                Ok(error)
188            }
189            _ => Err(self),
190        }
191    }
192
193    pub fn into_downcast<T: std::error::Error + 'static>(self) -> std::result::Result<T, Self> {
197        if self.downcast_ref::<T>().is_none() {
198            return Err(self);
199        }
200        Ok(*self
202            .into_inner()?
203            .downcast()
204            .expect("failed to unwrap downcast"))
205    }
206
207    pub fn get_ref(&self) -> Option<&(dyn std::error::Error + Send + Sync + 'static)> {
209        match &self.context {
210            Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
211                Some(error.as_ref())
212            }
213            _ => None,
214        }
215    }
216
217    pub fn as_http_error(&self) -> Option<&HttpError> {
222        let mut error = self.get_ref()? as &(dyn std::error::Error);
223        loop {
224            match error.downcast_ref::<HttpError>() {
225                Some(e) => return Some(e),
226                None => error = error.source()?,
227            }
228        }
229    }
230
231    pub fn downcast_ref<T: std::error::Error + 'static>(&self) -> Option<&T> {
233        self.get_ref()?.downcast_ref()
234    }
235
236    pub fn get_mut(&mut self) -> Option<&mut (dyn std::error::Error + Send + Sync + 'static)> {
238        match &mut self.context {
239            Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
240                Some(error.as_mut())
241            }
242            _ => None,
243        }
244    }
245
246    pub fn downcast_mut<T: std::error::Error + 'static>(&mut self) -> Option<&mut T> {
248        self.get_mut()?.downcast_mut()
249    }
250}
251
252impl std::error::Error for Error {
253    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
254        match &self.context {
255            Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
256                Some(&**error)
257            }
258            _ => None,
259        }
260    }
261}
262
263impl From<ErrorKind> for Error {
264    fn from(kind: ErrorKind) -> Self {
265        Self {
266            context: Context::Simple(kind),
267        }
268    }
269}
270
271impl From<std::io::Error> for Error {
272    fn from(error: std::io::Error) -> Self {
273        Self::new(ErrorKind::Io, error)
274    }
275}
276
277impl From<base64::DecodeError> for Error {
278    fn from(error: base64::DecodeError) -> Self {
279        Self::new(ErrorKind::DataConversion, error)
280    }
281}
282
283impl From<serde_json::Error> for Error {
284    fn from(error: serde_json::Error) -> Self {
285        Self::new(ErrorKind::DataConversion, error)
286    }
287}
288
289impl From<std::string::FromUtf8Error> for Error {
290    fn from(error: std::string::FromUtf8Error) -> Self {
291        Self::new(ErrorKind::DataConversion, error)
292    }
293}
294
295impl From<std::str::Utf8Error> for Error {
296    fn from(error: std::str::Utf8Error) -> Self {
297        Self::new(ErrorKind::DataConversion, error)
298    }
299}
300
301impl From<url::ParseError> for Error {
302    fn from(error: url::ParseError) -> Self {
303        Self::new(ErrorKind::DataConversion, error)
304    }
305}
306
307impl Display for Error {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        match &self.context {
310            Context::Simple(kind) => write!(f, "{kind}"),
311            Context::Message { message, .. } => write!(f, "{message}"),
312            Context::Custom(Custom { error, .. }) => write!(f, "{error}"),
313            Context::Full(_, message) => {
314                write!(f, "{message}")
315            }
316        }
317    }
318}
319
/// Extension methods that turn a `std::result::Result` with a standard error type into this
/// module's [`Result`], tagging the error with an [`ErrorKind`] and optionally a message.
///
/// The trait is sealed through the private `Sealed` supertrait. It is implemented only for
/// `std::result::Result<T, E>` where `E: std::error::Error + Send + Sync + 'static`.
pub trait ResultExt<T>: private::Sealed {
    /// Converts an `Err` into an [`Error`] of the given `kind` that wraps the original error.
    fn map_kind(self, kind: ErrorKind) -> Result<T>
    where
        Self: Sized;

    /// Converts an `Err` into an [`Error`] of the given `kind`.
    ///
    /// The new error displays `message` and wraps the original error as its source.
    fn context<C>(self, kind: ErrorKind, message: C) -> Result<T>
    where
        Self: Sized,
        C: Into<Cow<'static, str>>;

    /// Like [`ResultExt::context`], but the message is produced by calling `f`.
    fn with_context<F, C>(self, kind: ErrorKind, f: F) -> Result<T>
    where
        Self: Sized,
        F: FnOnce() -> C,
        C: Into<Cow<'static, str>>;
}
342
// Home of the supertrait that seals `ResultExt`. The module is private, so nothing outside
// this module can name `Sealed` and add new implementations.
mod private {
    /// Marker trait, implemented only for results whose error is a thread-safe standard error.
    pub trait Sealed {}

    impl<T, E: std::error::Error + Send + Sync + 'static> Sealed for std::result::Result<T, E> {}
}
348
349impl<T, E> ResultExt<T> for std::result::Result<T, E>
350where
351    E: std::error::Error + Send + Sync + 'static,
352{
353    fn map_kind(self, kind: ErrorKind) -> Result<T>
354    where
355        Self: Sized,
356    {
357        self.map_err(|e| Error::new(kind, e))
358    }
359
360    fn context<C>(self, kind: ErrorKind, message: C) -> Result<T>
361    where
362        Self: Sized,
363        C: Into<Cow<'static, str>>,
364    {
365        self.map_err(|e| Error {
366            context: Context::Full(
367                Custom {
368                    error: Box::new(e),
369                    kind,
370                },
371                message.into(),
372            ),
373        })
374    }
375
376    fn with_context<F, C>(self, kind: ErrorKind, f: F) -> Result<T>
377    where
378        Self: Sized,
379        F: FnOnce() -> C,
380        C: Into<Cow<'static, str>>,
381    {
382        self.context(kind, f())
383    }
384}
385
/// The internal shape of an [`Error`], recording which pieces of information it carries.
#[derive(Debug)]
enum Context {
    /// Only a kind; displays the kind.
    Simple(ErrorKind),
    /// A kind and a message, with no wrapped error; displays the message.
    Message {
        kind: ErrorKind,
        message: Cow<'static, str>,
    },
    /// A kind and a wrapped error; displays the wrapped error.
    Custom(Custom),
    /// A kind, a wrapped error, and a message; displays the message.
    Full(Custom, Cow<'static, str>),
}
396
/// A kind paired with a wrapped error. The wrapped error is what `source()`, `get_ref()`,
/// `into_inner()` and the downcast helpers on [`Error`] expose.
#[derive(Debug)]
struct Custom {
    kind: ErrorKind,
    error: Box<dyn std::error::Error + Send + Sync>,
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Compile-time guarantee that `Error` can be moved and shared between threads.
    ///
    /// This replaces a never-called, unconditionally recursive helper.
    #[test]
    fn error_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Error>();
    }

    #[derive(thiserror::Error, Debug)]
    enum IntermediateError {
        #[error("second error")]
        Io(#[from] std::io::Error),
    }

    /// Builds a nested error for the tests.
    ///
    /// The outer `Error` wraps an `io::Error`, which wraps an `IntermediateError`, which in
    /// turn wraps an `io::Error` whose message is "third error".
    fn create_error() -> Error {
        let inner = io::Error::new(io::ErrorKind::BrokenPipe, "third error");
        let inner: IntermediateError = inner.into();
        let inner = io::Error::new(io::ErrorKind::ConnectionAborted, inner);

        Error::new(ErrorKind::Io, inner)
    }

    #[test]
    fn errors_display_properly() {
        let error = create_error();

        // Walk the `source()` chain, recording each cause's display text.
        let mut error: &dyn std::error::Error = &error;
        let display = format!("{error}");
        let mut errors = vec![];
        while let Some(cause) = error.source() {
            errors.push(format!("{cause}"));
            error = cause;
        }

        assert_eq!(display, "second error");
        assert_eq!(errors.join(","), "second error,third error");

        let inner = io::Error::new(io::ErrorKind::BrokenPipe, "third error");
        let error: Result<()> = std::result::Result::<(), std::io::Error>::Err(inner)
            .context(ErrorKind::Io, "oh no broken pipe!");
        assert_eq!(format!("{}", error.unwrap_err()), "oh no broken pipe!");
    }

    #[test]
    fn downcasting_works() {
        let error = &create_error() as &dyn std::error::Error;
        assert!(error.is::<Error>());
        let downcasted = error
            .source()
            .unwrap()
            .downcast_ref::<std::io::Error>()
            .unwrap();
        assert_eq!(format!("{downcasted}"), "second error");
    }

    #[test]
    fn turn_into_inner_error() {
        let error = create_error();
        let inner = error.into_inner().unwrap();
        let inner = inner.downcast_ref::<std::io::Error>().unwrap();
        assert_eq!(format!("{inner}"), "second error");

        let error = create_error();
        let inner = error.get_ref().unwrap();
        let inner = inner.downcast_ref::<std::io::Error>().unwrap();
        assert_eq!(format!("{inner}"), "second error");

        let mut error = create_error();
        let inner = error.get_mut().unwrap();
        let inner = inner.downcast_ref::<std::io::Error>().unwrap();
        assert_eq!(format!("{inner}"), "second error");
    }

    #[test]
    fn matching_against_http_error() {
        // No header code and an empty JSON body: no error code.
        let kind =
            ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &Headers::new(), b"{}");

        assert!(matches!(
            kind,
            ErrorKind::HttpResponse {
                status: StatusCode::ImATeapot,
                error_code: None
            }
        ));

        // Code taken from the body when the headers have none.
        let kind = ErrorKind::http_response_from_parts(
            StatusCode::ImATeapot,
            &Headers::new(),
            br#"{"error": {"code":"teepot"}}"#,
        );

        assert!(matches!(
            kind,
            ErrorKind::HttpResponse {
                status: StatusCode::ImATeapot,
                error_code
            }
            if error_code.as_deref() == Some("teepot")
        ));

        // Code taken from the error-code header.
        let mut headers = Headers::new();
        headers.insert(headers::ERROR_CODE, "teapot");
        let kind = ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &headers, br#"{}"#);

        assert!(matches!(
            kind,
            ErrorKind::HttpResponse {
                status: StatusCode::ImATeapot,
                error_code
            }
            if error_code.as_deref() == Some("teapot")
        ));
    }

    #[test]
    fn set_result_kind() {
        let result = std::result::Result::<(), _>::Err(create_error());
        let result = result.map_kind(ErrorKind::Io);
        assert_eq!(&ErrorKind::Io, result.unwrap_err().kind());
    }
}