1use crate::StatusCode;
2use std::borrow::Cow;
3use std::fmt::{Debug, Display};
4mod http_error;
5mod macros;
6use crate::headers::{self, Headers};
7pub use http_error::HttpError;
8
9use self::http_error::get_error_code_from_header;
10
/// Shorthand for `std::result::Result` with this module's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
13
/// The category of an [`Error`], retrievable with [`Error::kind`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ErrorKind {
    /// An HTTP response indicated failure.
    HttpResponse {
        /// The HTTP status code of the response.
        status: StatusCode,
        /// A service-specific error code, if one could be found in the response.
        error_code: Option<String>,
    },
    /// An I/O failure; `std::io::Error` converts into this kind.
    Io,
    /// A failure converting data; base64, JSON, UTF-8 and URL parse errors convert into this kind.
    DataConversion,
    /// A credential-related failure (NOTE(review): exact scope not visible in this module).
    Credential,
    /// A failure raised by the mock framework (NOTE(review): presumably test tooling — confirm).
    MockFramework,
    /// Any failure that does not fit the other kinds.
    Other,
}
35
36impl ErrorKind {
37 pub fn into_error(self) -> Error {
38 Error {
39 context: Context::Simple(self),
40 }
41 }
42
43 pub fn http_response(status: StatusCode, error_code: Option<String>) -> Self {
44 Self::HttpResponse { status, error_code }
45 }
46
47 pub fn http_response_from_parts(status: StatusCode, headers: &Headers, body: &[u8]) -> Self {
50 if let Some(header_err_code) = get_error_code_from_header(headers) {
51 Self::HttpResponse {
52 status,
53 error_code: Some(header_err_code),
54 }
55 } else {
56 let (error_code, _) = http_error::get_error_code_message_from_body(
57 body,
58 headers.get_optional_str(&headers::CONTENT_TYPE),
59 );
60 Self::HttpResponse { status, error_code }
61 }
62 }
63}
64
65impl Display for ErrorKind {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 ErrorKind::HttpResponse { status, error_code } => {
69 write!(
70 f,
71 "HttpResponse({},{})",
72 status,
73 error_code.as_deref().unwrap_or("unknown")
74 )
75 }
76 ErrorKind::Io => write!(f, "Io"),
77 ErrorKind::DataConversion => write!(f, "DataConversion"),
78 ErrorKind::Credential => write!(f, "Credential"),
79 ErrorKind::MockFramework => write!(f, "MockFramework"),
80 ErrorKind::Other => write!(f, "Other"),
81 }
82 }
83}
84
/// The error type used by this module's [`Result`] alias.
///
/// Every `Error` has an [`ErrorKind`] (see [`Error::kind`]) and may additionally carry a
/// message and/or a wrapped source error.
#[derive(Debug)]
pub struct Error {
    // Which of the four shapes (kind only / kind + message / kind + source /
    // kind + source + message) this error has.
    context: Context,
}
90
91impl Error {
92 pub fn new<E>(kind: ErrorKind, error: E) -> Self
94 where
95 E: Into<Box<dyn std::error::Error + Send + Sync>>,
96 {
97 Self {
98 context: Context::Custom(Custom {
99 kind,
100 error: error.into(),
101 }),
102 }
103 }
104
105 #[must_use]
108 pub fn full<E, C>(kind: ErrorKind, error: E, message: C) -> Self
109 where
110 E: Into<Box<dyn std::error::Error + Send + Sync>>,
111 C: Into<Cow<'static, str>>,
112 {
113 Self {
114 context: Context::Full(
115 Custom {
116 kind,
117 error: error.into(),
118 },
119 message.into(),
120 ),
121 }
122 }
123
124 #[must_use]
126 pub fn message<C>(kind: ErrorKind, message: C) -> Self
127 where
128 C: Into<Cow<'static, str>>,
129 {
130 Self {
131 context: Context::Message {
132 kind,
133 message: message.into(),
134 },
135 }
136 }
137
138 #[must_use]
140 pub fn with_message<F, C>(kind: ErrorKind, message: F) -> Self
141 where
142 Self: Sized,
143 F: FnOnce() -> C,
144 C: Into<Cow<'static, str>>,
145 {
146 Self {
147 context: Context::Message {
148 kind,
149 message: message().into(),
150 },
151 }
152 }
153
154 #[must_use]
156 pub fn context<C>(self, message: C) -> Self
157 where
158 C: Into<Cow<'static, str>>,
159 {
160 Self::full(self.kind().clone(), self, message)
161 }
162
163 #[must_use]
165 pub fn with_context<F, C>(self, f: F) -> Self
166 where
167 F: FnOnce() -> C,
168 C: Into<Cow<'static, str>>,
169 {
170 self.context(f())
171 }
172
173 pub fn kind(&self) -> &ErrorKind {
175 match &self.context {
176 Context::Simple(kind)
177 | Context::Message { kind, .. }
178 | Context::Custom(Custom { kind, .. })
179 | Context::Full(Custom { kind, .. }, _) => kind,
180 }
181 }
182
183 pub fn into_inner(self) -> std::result::Result<Box<dyn std::error::Error + Send + Sync>, Self> {
185 match self.context {
186 Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
187 Ok(error)
188 }
189 _ => Err(self),
190 }
191 }
192
193 pub fn into_downcast<T: std::error::Error + 'static>(self) -> std::result::Result<T, Self> {
197 if self.downcast_ref::<T>().is_none() {
198 return Err(self);
199 }
200 Ok(*self
202 .into_inner()?
203 .downcast()
204 .expect("failed to unwrap downcast"))
205 }
206
207 pub fn get_ref(&self) -> Option<&(dyn std::error::Error + Send + Sync + 'static)> {
209 match &self.context {
210 Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
211 Some(error.as_ref())
212 }
213 _ => None,
214 }
215 }
216
217 pub fn as_http_error(&self) -> Option<&HttpError> {
222 let mut error = self.get_ref()? as &(dyn std::error::Error);
223 loop {
224 match error.downcast_ref::<HttpError>() {
225 Some(e) => return Some(e),
226 None => error = error.source()?,
227 }
228 }
229 }
230
231 pub fn downcast_ref<T: std::error::Error + 'static>(&self) -> Option<&T> {
233 self.get_ref()?.downcast_ref()
234 }
235
236 pub fn get_mut(&mut self) -> Option<&mut (dyn std::error::Error + Send + Sync + 'static)> {
238 match &mut self.context {
239 Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
240 Some(error.as_mut())
241 }
242 _ => None,
243 }
244 }
245
246 pub fn downcast_mut<T: std::error::Error + 'static>(&mut self) -> Option<&mut T> {
248 self.get_mut()?.downcast_mut()
249 }
250}
251
252impl std::error::Error for Error {
253 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
254 match &self.context {
255 Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
256 Some(&**error)
257 }
258 _ => None,
259 }
260 }
261}
262
263impl From<ErrorKind> for Error {
264 fn from(kind: ErrorKind) -> Self {
265 Self {
266 context: Context::Simple(kind),
267 }
268 }
269}
270
271impl From<std::io::Error> for Error {
272 fn from(error: std::io::Error) -> Self {
273 Self::new(ErrorKind::Io, error)
274 }
275}
276
277impl From<base64::DecodeError> for Error {
278 fn from(error: base64::DecodeError) -> Self {
279 Self::new(ErrorKind::DataConversion, error)
280 }
281}
282
283impl From<serde_json::Error> for Error {
284 fn from(error: serde_json::Error) -> Self {
285 Self::new(ErrorKind::DataConversion, error)
286 }
287}
288
289impl From<std::string::FromUtf8Error> for Error {
290 fn from(error: std::string::FromUtf8Error) -> Self {
291 Self::new(ErrorKind::DataConversion, error)
292 }
293}
294
295impl From<std::str::Utf8Error> for Error {
296 fn from(error: std::str::Utf8Error) -> Self {
297 Self::new(ErrorKind::DataConversion, error)
298 }
299}
300
301impl From<url::ParseError> for Error {
302 fn from(error: url::ParseError) -> Self {
303 Self::new(ErrorKind::DataConversion, error)
304 }
305}
306
307impl Display for Error {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match &self.context {
310 Context::Simple(kind) => write!(f, "{kind}"),
311 Context::Message { message, .. } => write!(f, "{message}"),
312 Context::Custom(Custom { error, .. }) => write!(f, "{error}"),
313 Context::Full(_, message) => {
314 write!(f, "{message}")
315 }
316 }
317 }
318}
319
/// Extension methods that convert a `std::result::Result` with a thread-safe, `'static`
/// error type into this module's [`Result`].
///
/// The trait is sealed: its `Sealed` supertrait lives in a private module, so it cannot
/// be implemented outside this module.
pub trait ResultExt<T>: private::Sealed {
    /// Converts the error, if any, into an [`Error`] of the given `kind` that wraps it.
    fn map_kind(self, kind: ErrorKind) -> Result<T>
    where
        Self: Sized;

    /// Converts the error, if any, into an [`Error`] of the given `kind` that wraps it
    /// and displays as `message`.
    fn context<C>(self, kind: ErrorKind, message: C) -> Result<T>
    where
        Self: Sized,
        C: Into<Cow<'static, str>>;

    /// Like [`ResultExt::context`], but the message is produced by calling `f`.
    fn with_context<F, C>(self, kind: ErrorKind, f: F) -> Result<T>
    where
        Self: Sized,
        F: FnOnce() -> C,
        C: Into<Cow<'static, str>>;
}
342
// Private home of the `Sealed` supertrait; because this module is private, nothing
// outside the enclosing module can name `Sealed` and therefore implement `ResultExt`.
mod private {
    /// Marker trait implemented only for `std::result::Result<T, E>` whose error type is
    /// `std::error::Error + Send + Sync + 'static`.
    pub trait Sealed {}

    impl<T, E> Sealed for std::result::Result<T, E> where E: std::error::Error + Send + Sync + 'static {}
}
348
349impl<T, E> ResultExt<T> for std::result::Result<T, E>
350where
351 E: std::error::Error + Send + Sync + 'static,
352{
353 fn map_kind(self, kind: ErrorKind) -> Result<T>
354 where
355 Self: Sized,
356 {
357 self.map_err(|e| Error::new(kind, e))
358 }
359
360 fn context<C>(self, kind: ErrorKind, message: C) -> Result<T>
361 where
362 Self: Sized,
363 C: Into<Cow<'static, str>>,
364 {
365 self.map_err(|e| Error {
366 context: Context::Full(
367 Custom {
368 error: Box::new(e),
369 kind,
370 },
371 message.into(),
372 ),
373 })
374 }
375
376 fn with_context<F, C>(self, kind: ErrorKind, f: F) -> Result<T>
377 where
378 Self: Sized,
379 F: FnOnce() -> C,
380 C: Into<Cow<'static, str>>,
381 {
382 self.context(kind, f())
383 }
384}
385
/// The internal shape of an [`Error`].
#[derive(Debug)]
enum Context {
    /// Only a kind; the error displays as the kind.
    Simple(ErrorKind),
    /// A kind plus a message and no source error; the error displays as the message.
    Message {
        kind: ErrorKind,
        message: Cow<'static, str>,
    },
    /// A kind plus a wrapped source error; the error displays as the wrapped error.
    Custom(Custom),
    /// A kind, a wrapped source error and a message; the error displays as the message.
    Full(Custom, Cow<'static, str>),
}
396
/// A wrapped source error together with the kind assigned to it.
#[derive(Debug)]
struct Custom {
    kind: ErrorKind,
    // Thread-safe so that `Error` itself stays `Send`/`Sync`-compatible.
    error: Box<dyn std::error::Error + Send + Sync>,
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Compile-time guarantee that `Error` is `Send`: this test stops building if it
    /// ever loses that property.
    #[test]
    fn error_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Error>();
    }

    /// Middle link of the error chain built by `create_error`.
    #[derive(thiserror::Error, Debug)]
    enum IntermediateError {
        #[error("second error")]
        Io(#[from] std::io::Error),
    }

    /// Builds an `Error` of kind `Io` that wraps an `io::Error` displaying "second error",
    /// whose source chain ends in an `io::Error` displaying "third error".
    fn create_error() -> Error {
        let inner = io::Error::new(io::ErrorKind::BrokenPipe, "third error");
        let inner: IntermediateError = inner.into();
        let inner = io::Error::new(io::ErrorKind::ConnectionAborted, inner);

        Error::new(ErrorKind::Io, inner)
    }

    #[test]
    fn errors_display_properly() {
        let error = create_error();

        // Walk the `source` chain, collecting each link's display text.
        let mut error: &dyn std::error::Error = &error;
        let display = format!("{error}");
        let mut errors = vec![];
        while let Some(cause) = error.source() {
            errors.push(format!("{cause}"));
            error = cause;
        }

        assert_eq!(display, "second error");
        assert_eq!(errors.join(","), "second error,third error");

        // `ResultExt::context` replaces the displayed text with the given message.
        let inner = io::Error::new(io::ErrorKind::BrokenPipe, "third error");
        let error: Result<()> = std::result::Result::<(), std::io::Error>::Err(inner)
            .context(ErrorKind::Io, "oh no broken pipe!");
        assert_eq!(format!("{}", error.unwrap_err()), "oh no broken pipe!");
    }

    #[test]
    fn downcasting_works() {
        let error = &create_error() as &dyn std::error::Error;
        assert!(error.is::<Error>());
        let downcasted = error
            .source()
            .unwrap()
            .downcast_ref::<std::io::Error>()
            .unwrap();
        assert_eq!(format!("{downcasted}"), "second error");
    }

    #[test]
    fn turn_into_inner_error() {
        let error = create_error();
        let inner = error.into_inner().unwrap();
        let inner = inner.downcast_ref::<std::io::Error>().unwrap();
        assert_eq!(format!("{inner}"), "second error");

        let error = create_error();
        let inner = error.get_ref().unwrap();
        let inner = inner.downcast_ref::<std::io::Error>().unwrap();
        assert_eq!(format!("{inner}"), "second error");

        let mut error = create_error();
        let inner = error.get_mut().unwrap();
        let inner = inner.downcast_ref::<std::io::Error>().unwrap();
        assert_eq!(format!("{inner}"), "second error");
    }

    #[test]
    fn matching_against_http_error() {
        // No header and no code in the body: no error code.
        let kind =
            ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &Headers::new(), b"{}");

        assert!(matches!(
            kind,
            ErrorKind::HttpResponse {
                status: StatusCode::ImATeapot,
                error_code: None
            }
        ));

        // Code taken from the JSON body.
        let kind = ErrorKind::http_response_from_parts(
            StatusCode::ImATeapot,
            &Headers::new(),
            br#"{"error": {"code":"teepot"}}"#,
        );

        assert!(matches!(
            kind,
            ErrorKind::HttpResponse {
                status: StatusCode::ImATeapot,
                error_code
            }
            if error_code.as_deref() == Some("teepot")
        ));

        // Code taken from the error-code header.
        let mut headers = Headers::new();
        headers.insert(headers::ERROR_CODE, "teapot");
        let kind = ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &headers, br#"{}"#);

        assert!(matches!(
            kind,
            ErrorKind::HttpResponse {
                status: StatusCode::ImATeapot,
                error_code
            }
            if error_code.as_deref() == Some("teapot")
        ));
    }

    #[test]
    fn set_result_kind() {
        let result = std::result::Result::<(), _>::Err(create_error());
        let result = result.map_kind(ErrorKind::Io);
        assert_eq!(&ErrorKind::Io, result.unwrap_err().kind());
    }
}