Skip to content

Commit 50beeaf

Browse files
authored
Add support for custom status code in TimeoutLayer (#599)
1 parent 35740de commit 50beeaf

File tree

3 files changed

+142
-17
lines changed

3 files changed

+142
-17
lines changed

examples/axum-key-value-store/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ fn app() -> Router {
7676
)
7777
.sensitive_response_headers(sensitive_headers)
7878
// Set a timeout
79-
.layer(TimeoutLayer::new(Duration::from_secs(10)))
79+
.layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(10)))
8080
// Compress responses
8181
.compression()
8282
// Set a `Content-Type` if there isn't one already.

tower-http/src/timeout/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
//! Middleware that applies a timeout to requests.
22
//!
3-
//! If the request does not complete within the specified timeout it will be aborted and a `408
4-
//! Request Timeout` response will be sent.
3+
//! If the request does not complete within the specified timeout, it will be aborted and a
4+
//! response with an empty body and a custom status code will be returned.
55
//!
66
//! # Differences from `tower::timeout`
77
//!
88
//! tower's [`Timeout`](tower::timeout::Timeout) middleware uses an error to signal timeout, i.e.
99
//! it changes the error type to [`BoxError`](tower::BoxError). For HTTP services that is rarely
1010
//! what you want as returning errors will terminate the connection without sending a response.
1111
//!
12-
//! This middleware won't change the error type and instead return a `408 Request Timeout`
13-
//! response. That means if your service's error type is [`Infallible`] it will still be
14-
//! [`Infallible`] after applying this middleware.
12+
//! This middleware won't change the error type and instead returns a response with an empty body
13+
//! and the specified status code. That means if your service's error type is [`Infallible`], it will
14+
//! still be [`Infallible`] after applying this middleware.
1515
//!
1616
//! # Example
1717
//!
1818
//! ```
19-
//! use http::{Request, Response};
19+
//! use http::{Request, Response, StatusCode};
2020
//! use http_body_util::Full;
2121
//! use bytes::Bytes;
2222
//! use std::{convert::Infallible, time::Duration};
@@ -31,8 +31,8 @@
3131
//! # #[tokio::main]
3232
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
3333
//! let svc = ServiceBuilder::new()
34-
//! // Timeout requests after 30 seconds
35-
//! .layer(TimeoutLayer::new(Duration::from_secs(30)))
34+
//! // Timeout requests after 30 seconds with the specified status code
35+
//! .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(30)))
3636
//! .service_fn(handle);
3737
//! # Ok(())
3838
//! # }

tower-http/src/timeout/service.rs

Lines changed: 133 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,81 @@ use tower_service::Service;
1717
#[derive(Debug, Clone, Copy)]
1818
pub struct TimeoutLayer {
1919
timeout: Duration,
20+
status_code: StatusCode,
2021
}
2122

2223
impl TimeoutLayer {
2324
/// Creates a new [`TimeoutLayer`].
25+
///
26+
/// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
27+
/// To customize the response status code, use the `with_status_code` method.
28+
#[deprecated(since = "0.6.7", note = "Use `TimeoutLayer::with_status_code` instead")]
2429
pub fn new(timeout: Duration) -> Self {
25-
TimeoutLayer { timeout }
30+
Self::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
31+
}
32+
33+
/// Creates a new [`TimeoutLayer`] with the specified status code for the timeout response.
34+
pub fn with_status_code(status_code: StatusCode, timeout: Duration) -> Self {
35+
Self {
36+
timeout,
37+
status_code,
38+
}
2639
}
2740
}
2841

2942
impl<S> Layer<S> for TimeoutLayer {
3043
type Service = Timeout<S>;
3144

3245
fn layer(&self, inner: S) -> Self::Service {
33-
Timeout::new(inner, self.timeout)
46+
Timeout::with_status_code(inner, self.status_code, self.timeout)
3447
}
3548
}
3649

3750
/// Middleware which apply a timeout to requests.
3851
///
39-
/// If the request does not complete within the specified timeout it will be aborted and a `408
40-
/// Request Timeout` response will be sent.
41-
///
4252
/// See the [module docs](super) for an example.
4353
#[derive(Debug, Clone, Copy)]
4454
pub struct Timeout<S> {
4555
inner: S,
4656
timeout: Duration,
57+
status_code: StatusCode,
4758
}
4859

4960
impl<S> Timeout<S> {
5061
/// Creates a new [`Timeout`].
62+
///
63+
/// By default, it will return a `408 Request Timeout` response if the request does not complete within the specified timeout.
64+
/// To customize the response status code, use the `with_status_code` method.
65+
#[deprecated(since = "0.6.7", note = "Use `Timeout::with_status_code` instead")]
5166
pub fn new(inner: S, timeout: Duration) -> Self {
52-
Self { inner, timeout }
67+
Self::with_status_code(inner, StatusCode::REQUEST_TIMEOUT, timeout)
68+
}
69+
70+
/// Creates a new [`Timeout`] with the specified status code for the timeout response.
71+
pub fn with_status_code(inner: S, status_code: StatusCode, timeout: Duration) -> Self {
72+
Self {
73+
inner,
74+
timeout,
75+
status_code,
76+
}
5377
}
5478

5579
define_inner_service_accessors!();
5680

5781
/// Returns a new [`Layer`] that wraps services with a `Timeout` middleware.
5882
///
5983
/// [`Layer`]: tower_layer::Layer
84+
#[deprecated(
85+
since = "0.6.7",
86+
note = "Use `Timeout::layer_with_status_code` instead"
87+
)]
6088
pub fn layer(timeout: Duration) -> TimeoutLayer {
61-
TimeoutLayer::new(timeout)
89+
TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout)
90+
}
91+
92+
/// Returns a new [`Layer`] that wraps services with a `Timeout` middleware with the specified status code.
93+
pub fn layer_with_status_code(status_code: StatusCode, timeout: Duration) -> TimeoutLayer {
94+
TimeoutLayer::with_status_code(status_code, timeout)
6295
}
6396
}
6497

@@ -81,6 +114,7 @@ where
81114
ResponseFuture {
82115
inner: self.inner.call(req),
83116
sleep,
117+
status_code: self.status_code,
84118
}
85119
}
86120
}
@@ -92,6 +126,7 @@ pin_project! {
92126
inner: F,
93127
#[pin]
94128
sleep: Sleep,
129+
status_code: StatusCode,
95130
}
96131
}
97132

@@ -107,7 +142,7 @@ where
107142

108143
if this.sleep.poll(cx).is_ready() {
109144
let mut res = Response::new(B::default());
110-
*res.status_mut() = StatusCode::REQUEST_TIMEOUT;
145+
*res.status_mut() = *this.status_code;
111146
return Poll::Ready(Ok(res));
112147
}
113148

@@ -269,3 +304,93 @@ where
269304
Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body))))
270305
}
271306
}
307+
308+
#[cfg(test)]
309+
mod tests {
310+
use super::*;
311+
use crate::test_helpers::Body;
312+
use http::{Request, Response, StatusCode};
313+
use std::time::Duration;
314+
use tower::{BoxError, ServiceBuilder, ServiceExt};
315+
316+
#[tokio::test]
317+
async fn request_completes_within_timeout() {
318+
let mut service = ServiceBuilder::new()
319+
.layer(TimeoutLayer::with_status_code(
320+
StatusCode::GATEWAY_TIMEOUT,
321+
Duration::from_secs(1),
322+
))
323+
.service_fn(fast_handler);
324+
325+
let request = Request::get("/").body(Body::empty()).unwrap();
326+
let res = service.ready().await.unwrap().call(request).await.unwrap();
327+
328+
assert_eq!(res.status(), StatusCode::OK);
329+
}
330+
331+
#[tokio::test]
332+
async fn timeout_middleware_with_custom_status_code() {
333+
let timeout_service = Timeout::with_status_code(
334+
tower::service_fn(slow_handler),
335+
StatusCode::REQUEST_TIMEOUT,
336+
Duration::from_millis(10),
337+
);
338+
339+
let mut service = ServiceBuilder::new().service(timeout_service);
340+
341+
let request = Request::get("/").body(Body::empty()).unwrap();
342+
let res = service.ready().await.unwrap().call(request).await.unwrap();
343+
344+
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
345+
}
346+
347+
#[tokio::test]
348+
async fn timeout_response_has_empty_body() {
349+
let mut service = ServiceBuilder::new()
350+
.layer(TimeoutLayer::with_status_code(
351+
StatusCode::GATEWAY_TIMEOUT,
352+
Duration::from_millis(10),
353+
))
354+
.service_fn(slow_handler);
355+
356+
let request = Request::get("/").body(Body::empty()).unwrap();
357+
let res = service.ready().await.unwrap().call(request).await.unwrap();
358+
359+
assert_eq!(res.status(), StatusCode::GATEWAY_TIMEOUT);
360+
361+
// Verify the body is empty (default)
362+
use http_body_util::BodyExt;
363+
let body = res.into_body();
364+
let bytes = body.collect().await.unwrap().to_bytes();
365+
assert!(bytes.is_empty());
366+
}
367+
368+
#[tokio::test]
369+
async fn deprecated_new_method_compatibility() {
370+
#[allow(deprecated)]
371+
let layer = TimeoutLayer::new(Duration::from_millis(10));
372+
373+
let mut service = ServiceBuilder::new().layer(layer).service_fn(slow_handler);
374+
375+
let request = Request::get("/").body(Body::empty()).unwrap();
376+
let res = service.ready().await.unwrap().call(request).await.unwrap();
377+
378+
// Should use default 408 status code
379+
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
380+
}
381+
382+
async fn slow_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
383+
tokio::time::sleep(Duration::from_secs(10)).await;
384+
Ok(Response::builder()
385+
.status(StatusCode::OK)
386+
.body(Body::empty())
387+
.unwrap())
388+
}
389+
390+
async fn fast_handler(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
391+
Ok(Response::builder()
392+
.status(StatusCode::OK)
393+
.body(Body::empty())
394+
.unwrap())
395+
}
396+
}

0 commit comments

Comments
 (0)