Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased

- **added:** Implement `OptionalFromRequest` for `Json` ([#3142])
- **added:** Implement `OptionalFromRequest` for `Extension` ([#3157])

[#3142]: https://github.com/tokio-rs/axum/pull/3142
[#3157]: https://github.com/tokio-rs/axum/pull/3157

# 0.8.0

Expand Down
106 changes: 93 additions & 13 deletions axum/src/extension.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::{extract::rejection::*, response::IntoResponseParts};
use axum_core::extract::OptionalFromRequestParts;
use axum_core::{
extract::FromRequestParts,
response::{IntoResponse, Response, ResponseParts},
};
use http::{request::Parts, Request};
use http::{request::Parts, Extensions, Request};
use std::{
convert::Infallible,
task::{Context, Poll},
Expand Down Expand Up @@ -43,7 +44,8 @@ use tower_service::Service;
/// ```
///
/// If the extension is missing it will reject the request with a `500 Internal
/// Server Error` response.
/// Server Error` response. Alternatively, you can use `Option<Extension<T>>` to
/// make the extension extractor optional.
///
/// # As response
///
Expand All @@ -69,6 +71,15 @@ use tower_service::Service;
#[must_use]
pub struct Extension<T>(pub T);

impl<T> Extension<T>
where
T: Clone + Send + Sync + 'static,
{
fn from_extensions(extensions: &Extensions) -> Option<Self> {
extensions.get::<T>().cloned().map(Extension)
}
}

impl<T, S> FromRequestParts<S> for Extension<T>
where
T: Clone + Send + Sync + 'static,
Expand All @@ -77,17 +88,27 @@ where
type Rejection = ExtensionRejection;

async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let value = req
.extensions
.get::<T>()
.ok_or_else(|| {
MissingExtension::from_err(format!(
"Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
std::any::type_name::<T>()
))
}).cloned()?;

Ok(Extension(value))
Ok(Self::from_extensions(&req.extensions).ok_or_else(|| {
MissingExtension::from_err(format!(
"Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
std::any::type_name::<T>()
))
})?)
}
}

impl<T, S> OptionalFromRequestParts<S> for Extension<T>
where
T: Clone + Send + Sync + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(
req: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
Ok(Self::from_extensions(&req.extensions))
}
}

Expand Down Expand Up @@ -161,3 +182,62 @@ where
self.inner.call(req)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::routing::get;
use crate::test_helpers::TestClient;
use crate::Router;
use http::StatusCode;

#[derive(Clone)]
struct Foo(String);

#[derive(Clone)]
struct Bar(String);

#[crate::test]
async fn extension_extractor() {
async fn requires_foo(Extension(foo): Extension<Foo>) -> String {
foo.0
}

async fn optional_foo(extension: Option<Extension<Foo>>) -> String {
extension.map(|foo| foo.0 .0).unwrap_or("none".to_owned())
}

async fn requires_bar(Extension(bar): Extension<Bar>) -> String {
bar.0
}

async fn optional_bar(extension: Option<Extension<Bar>>) -> String {
extension.map(|bar| bar.0 .0).unwrap_or("none".to_owned())
}

let app = Router::new()
.route("/requires_foo", get(requires_foo))
.route("/optional_foo", get(optional_foo))
.route("/requires_bar", get(requires_bar))
.route("/optional_bar", get(optional_bar))
.layer(Extension(Foo("foo".to_owned())));

let client = TestClient::new(app);

let response = client.get("/requires_foo").await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.text().await, "foo");

let response = client.get("/optional_foo").await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.text().await, "foo");

let response = client.get("/requires_bar").await;
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.text().await, "Missing request extension: Extension of type `axum::extension::tests::Bar` was not found. Perhaps you forgot to add it? See `axum::Extension`.");

let response = client.get("/optional_bar").await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.text().await, "none");
}
}