use crate::SdkError;
use aws_smithy_async::future::timeout::Timeout;
use aws_smithy_async::rt::sleep::{AsyncSleep, Sleep};
use aws_smithy_http::operation::Operation;
use aws_smithy_types::timeout::OperationTimeoutConfig;
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::Layer;
/// Error placed inside a timeout `SdkError` when a request exceeds one of its timeouts.
///
/// `Display` renders as `"<kind> occurred after <duration>"`, e.g.
/// `"operation timeout (all attempts including retries) occurred after 250ms"`.
#[derive(Debug)]
struct RequestTimeoutError {
    /// Human-readable label of the timeout that elapsed.
    kind: &'static str,
    /// The configured duration that was exceeded.
    duration: Duration,
}

impl RequestTimeoutError {
    /// Builds the error from the timeout's label and its configured duration.
    fn new(kind: &'static str, duration: Duration) -> Self {
        RequestTimeoutError { kind, duration }
    }
}

impl std::fmt::Display for RequestTimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RequestTimeoutError { kind, duration } = self;
        // The duration is rendered with its Debug form (e.g. `250ms`, `1.5s`).
        write!(f, "{} occurred after {:?}", kind, duration)
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl std::error::Error for RequestTimeoutError {}
/// Parameters describing a single timeout to enforce around a service call.
///
/// Consumed by `TimeoutServiceFuture::new`, which creates a sleep of `duration` using
/// `async_sleep` and, if it fires, reports a `RequestTimeoutError` labelled with `kind`.
#[derive(Clone, Debug)]
pub struct TimeoutServiceParams {
/// How long to wait before the call is reported as timed out.
duration: Duration,
/// Label for this timeout; included in the `RequestTimeoutError` message.
kind: &'static str,
/// Sleep implementation used to create the timer for `duration`.
async_sleep: Arc<dyn AsyncSleep>,
}
/// Timeout settings for a client: one timeout covering the whole operation and one
/// covering each individual attempt.
///
/// A `None` field means that timeout is not enforced. Both fields are `None` (the
/// `Default`) when no `AsyncSleep` implementation was supplied to `ClientTimeoutParams::new`.
#[derive(Clone, Debug, Default)]
pub(crate) struct ClientTimeoutParams {
/// Timeout for the entire operation, across all attempts including retries.
pub(crate) operation_timeout: Option<TimeoutServiceParams>,
/// Timeout for a single attempt.
pub(crate) operation_attempt_timeout: Option<TimeoutServiceParams>,
}
impl ClientTimeoutParams {
pub(crate) fn new(
timeout_config: &OperationTimeoutConfig,
async_sleep: Option<Arc<dyn AsyncSleep>>,
) -> Self {
if let Some(async_sleep) = async_sleep {
Self {
operation_timeout: timeout_config.operation_timeout().map(|duration| {
TimeoutServiceParams {
duration,
kind: "operation timeout (all attempts including retries)",
async_sleep: async_sleep.clone(),
}
}),
operation_attempt_timeout: timeout_config.operation_attempt_timeout().map(
|duration| TimeoutServiceParams {
duration,
kind: "operation attempt timeout (single attempt)",
async_sleep: async_sleep.clone(),
},
),
}
} else {
Default::default()
}
}
}
/// Tower service that wraps `inner` and bounds each call with the timeout in `params`.
///
/// When `params` is `None`, calls are forwarded to `inner` with no timeout applied.
#[derive(Clone, Debug)]
pub struct TimeoutService<S> {
    inner: S,
    params: Option<TimeoutServiceParams>,
}

impl<S> TimeoutService<S> {
    /// Wraps `inner`, enforcing the timeout described by `params` (none if `None`).
    pub fn new(inner: S, params: Option<TimeoutServiceParams>) -> Self {
        TimeoutService { inner, params }
    }

    /// Wraps `inner` without any timeout; equivalent to `TimeoutService::new(inner, None)`.
    pub fn no_timeout(inner: S) -> Self {
        Self::new(inner, None)
    }
}
#[non_exhaustive]
#[derive(Debug)]
pub struct TimeoutLayer(Option<TimeoutServiceParams>);
impl TimeoutLayer {
pub fn new(params: Option<TimeoutServiceParams>) -> Self {
TimeoutLayer(params)
}
}
impl<S> Layer<S> for TimeoutLayer {
type Service = TimeoutService<S>;
fn layer(&self, inner: S) -> Self::Service {
TimeoutService {
inner,
params: self.0.clone(),
}
}
}
pin_project! {
/// Future returned by [`TimeoutService`]'s `call`.
///
/// Either pairs the inner future with a `Sleep` inside a `Timeout` (variant `Timeout`),
/// or passes the inner future through unchanged (variant `NoTimeout`).
#[non_exhaustive]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[allow(missing_docs)]
#[project = TimeoutServiceFutureProj]
pub enum TimeoutServiceFuture<F> {
// Inner future wrapped with a sleep. `kind` and `duration` are kept only so the
// timeout error can say which timeout elapsed and after how long.
Timeout {
#[pin]
future: Timeout<F, Sleep>,
kind: &'static str,
duration: Duration,
},
// No timeout configured: the inner future is polled directly.
NoTimeout {
#[pin]
future: F
}
}
}
impl<F> TimeoutServiceFuture<F> {
pub fn new(future: F, params: &TimeoutServiceParams) -> Self {
Self::Timeout {
future: Timeout::new(future, params.async_sleep.sleep(params.duration)),
kind: params.kind,
duration: params.duration,
}
}
pub fn no_timeout(future: F) -> Self {
Self::NoTimeout { future }
}
}
impl<InnerFuture, T, E> Future for TimeoutServiceFuture<InnerFuture>
where
    InnerFuture: Future<Output = Result<T, SdkError<E>>>,
{
    type Output = Result<T, SdkError<E>>;

    /// Drives the wrapped future.
    ///
    /// For the `Timeout` variant, the inner result (success or error) is passed through
    /// untouched. An error from the `Timeout` wrapper itself (its sleep finished first) is
    /// replaced by `SdkError::timeout_error` carrying a `RequestTimeoutError` that names
    /// the timeout's `kind` and `duration`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.project() {
            TimeoutServiceFutureProj::NoTimeout { future } => future.poll(cx),
            TimeoutServiceFutureProj::Timeout {
                future,
                kind,
                duration,
            } => {
                // Copy the (Copy) metadata out of the projection so the closure owns it.
                let (kind, duration) = (*kind, *duration);
                future.poll(cx).map(|outcome| {
                    outcome.unwrap_or_else(|_elapsed| {
                        Err(SdkError::timeout_error(RequestTimeoutError::new(
                            kind, duration,
                        )))
                    })
                })
            }
        }
    }
}
impl<H, R, InnerService, E> tower::Service<Operation<H, R>> for TimeoutService<InnerService>
where
InnerService: tower::Service<Operation<H, R>, Error = SdkError<E>>,
{
type Response = InnerService::Response;
type Error = SdkError<E>;
type Future = TimeoutServiceFuture<InnerService::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Operation<H, R>) -> Self::Future {
let future = self.inner.call(req);
if let Some(params) = &self.params {
Self::Future::new(future, params)
} else {
Self::Future::no_timeout(future)
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::never::NeverService;
    use crate::{SdkError, TimeoutLayer};
    use aws_smithy_async::assert_elapsed;
    use aws_smithy_async::rt::sleep::{AsyncSleep, TokioSleep};
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::operation::{Operation, Request};
    use aws_smithy_types::timeout::TimeoutConfig;
    use std::sync::Arc;
    use std::time::Duration;
    use tower::{Service, ServiceBuilder, ServiceExt};

    /// A call to a service that never responds must be cut off by the operation timeout,
    /// producing a timeout `SdkError` that wraps a `RequestTimeoutError`.
    #[tokio::test]
    async fn test_timeout_service_ends_request_that_never_completes() {
        let quarter_second = Duration::from_secs_f32(0.25);
        let operation = Operation::new(Request::new(http::Request::new(SdkBody::empty())), ());

        let config = OperationTimeoutConfig::from(
            TimeoutConfig::builder()
                .operation_timeout(quarter_second)
                .build(),
        );
        let sleeper: Arc<dyn AsyncSleep> = Arc::new(TokioSleep::new());
        let params = ClientTimeoutParams::new(&config, Some(sleeper));

        let unresponsive: NeverService<_, (), _> = NeverService::new();
        let mut service = ServiceBuilder::new()
            .layer(TimeoutLayer::new(params.operation_timeout))
            .service(unresponsive);

        // Capture the start before pausing the clock, as the elapsed check measures from here.
        let started = tokio::time::Instant::now();
        tokio::time::pause();
        let err: SdkError<Box<dyn std::error::Error + 'static>> = service
            .ready()
            .await
            .unwrap()
            .call(operation)
            .await
            .unwrap_err();
        assert_eq!(format!("{:?}", err), "TimeoutError(TimeoutError { source: RequestTimeoutError { kind: \"operation timeout (all attempts including retries)\", duration: 250ms } })");
        assert_elapsed!(started, quarter_second);
    }
}