Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 353 additions & 0 deletions tower-http/src/auth/async_require_authorization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
//! Authorize requests using the [`Authorization`] header asynchronously.
//!
//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
//!
//! # Example
//!
//! ```
//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
//! use hyper::{Request, Response, Body, Error};
//! use http::{StatusCode, header::AUTHORIZATION};
//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
//! use futures_util::future::BoxFuture;
//!
//! #[derive(Clone, Copy)]
//! struct MyAuth;
//!
//! impl AsyncAuthorizeRequest for MyAuth {
//! type Output = UserId;
//! type Future = BoxFuture<'static, Option<UserId>>;
//! type ResponseBody = Body;
//!
//! fn authorize<B>(&mut self, request: &Request<B>) -> Self::Future {
//! Box::pin(async {
//! // ...
//! # None
//! })
//! }
//!
//! fn on_authorized<B>(&mut self, request: &mut Request<B>, user_id: UserId) {
//! // Set `user_id` as a request extension so it can be accessed by other
//! // services down the stack.
//! request.extensions_mut().insert(user_id);
//! }
//!
//! fn unauthorized_response<B>(&mut self, request: &Request<B>) -> Response<Body> {
//! Response::builder()
//! .status(StatusCode::UNAUTHORIZED)
//! .body(Body::empty())
//! .unwrap()
//! }
//! }
//!
//! #[derive(Debug)]
//! struct UserId(String);
//!
//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Error> {
//! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the
//! // request was authorized and `UserId` will be present.
//! let user_id = request
//! .extensions()
//! .get::<UserId>()
//! .expect("UserId will be there if request was authorized");
//!
//! println!("request from {:?}", user_id);
//!
//! Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let service = ServiceBuilder::new()
//! // Authorize requests using `MyAuth`
//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
//! .service_fn(handle);
//! # Ok(())
//! # }
//! ```

use futures_core::ready;
use http::{Request, Response};
use http_body::Body;
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;

/// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the
/// [`Authorization`] header.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorizationLayer<T> {
auth: T,
}

impl<T> AsyncRequireAuthorizationLayer<T>
where
T: AsyncAuthorizeRequest,
{
/// Authorize requests using a custom scheme.
pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
Self { auth }
}
}

impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
where
T: Clone + AsyncAuthorizeRequest,
{
type Service = AsyncRequireAuthorization<S, T>;

fn layer(&self, inner: S) -> Self::Service {
AsyncRequireAuthorization::new(inner, self.auth.clone())
}
}

/// Middleware that authorizes all requests using the [`Authorization`] header.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
inner: S,
auth: T,
}

impl<S, T> AsyncRequireAuthorization<S, T> {
define_inner_service_accessors!();
}

impl<S, T> AsyncRequireAuthorization<S, T>
where
T: AsyncAuthorizeRequest,
{
/// Authorize requests using a custom scheme.
///
/// The `Authorization` header is required to have the value provided.
pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
Self { inner, auth }
}
}

impl<ReqBody, ResBody, S, T> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
ResBody: Default,
T: AsyncAuthorizeRequest<ResponseBody = ResBody> + Clone,
{
type Response = Response<ResBody>;
type Error = S::Error;
type Future = ResponseFuture<T, S, ReqBody>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let auth = self.auth.clone();
let inner = self.inner.clone();
let authorize = self.auth.authorize(&req);

ResponseFuture {
auth,
state: State::Authorize {
authorize,
req: Some(req),
},
service: inner,
}
}
}

#[pin_project(project = StateProj)]
enum State<A, ReqBody, SFut> {
Authorize {
#[pin]
authorize: A,
req: Option<Request<ReqBody>>,
},
Authorized {
#[pin]
fut: SFut,
},
}

/// Response future for [`AsyncRequireAuthorization`].
#[pin_project]
pub struct ResponseFuture<Auth, S, ReqBody>
where
Auth: AsyncAuthorizeRequest,
S: Service<Request<ReqBody>>,
{
auth: Auth,
#[pin]
state: State<Auth::Future, ReqBody, S::Future>,
service: S,
}

impl<Auth, S, ReqBody, B> Future for ResponseFuture<Auth, S, ReqBody>
where
Auth: AsyncAuthorizeRequest<ResponseBody = B>,
S: Service<Request<ReqBody>, Response = Response<B>>,
{
type Output = Result<Response<B>, S::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

loop {
match this.state.as_mut().project() {
StateProj::Authorize { authorize, req } => {
let auth = ready!(authorize.poll(cx));
let mut req = req.take().expect("future polled after completion");
match auth {
Some(output) => {
this.auth.on_authorized(&mut req, output);
let fut = this.service.call(req);
this.state.set(State::Authorized { fut })
}
None => {
let res = this.auth.unauthorized_response(&req);
return Poll::Ready(Ok(res));
}
};
}
StateProj::Authorized { fut } => {
return fut.poll(cx);
}
}
}
}
}

/// Trait for authorizing requests.
pub trait AsyncAuthorizeRequest {
/// The output type of doing the authorization.
///
/// Use `()` if authorization doesn't produce any meaningful output.
type Output;

/// The Future type returned by `authorize`
type Future: Future<Output = Option<Self::Output>>;

/// The body type used for responses to unauthorized requests.
type ResponseBody: Body;

/// Authorize the request.
///
/// If the future resolves to `Some(_)` then the request is allowed through, otherwise not.
fn authorize<B>(&mut self, request: &Request<B>) -> Self::Future;

/// Callback for when a request has been successfully authorized.
///
/// For example this allows you to save `Self::Output` in a [request extension][] to make it
/// available to services further down the stack. This could for example be the "claims" for a
/// valid [JWT].
///
/// Defaults to doing nothing.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [request extension]: https://docs.rs/http/latest/http/struct.Extensions.html
/// [JWT]: https://jwt.io
#[inline]
fn on_authorized<B>(&mut self, _request: &mut Request<B>, _output: Self::Output) {}

/// Create the response for an unauthorized request.
fn unauthorized_response<B>(&mut self, request: &Request<B>) -> Response<Self::ResponseBody>;
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
use futures_util::future::BoxFuture;
use http::{header, StatusCode};
use hyper::Body;
use tower::{BoxError, ServiceBuilder, ServiceExt};

#[derive(Clone, Copy)]
struct MyAuth;

impl AsyncAuthorizeRequest for MyAuth {
type Output = UserId;
type Future = BoxFuture<'static, Option<UserId>>;
type ResponseBody = Body;

fn authorize<B>(&mut self, request: &Request<B>) -> Self::Future {
let authorized = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|it| it.to_str().ok())
.and_then(|it| it.strip_prefix("Bearer "))
.map(|it| it == "69420")
.unwrap_or(false);

Box::pin(async move {
if authorized {
Some(UserId(String::from("6969")))
} else {
None
}
})
}

fn on_authorized<B>(&mut self, request: &mut Request<B>, user_id: UserId) {
request.extensions_mut().insert(user_id);
}

fn unauthorized_response<B>(&mut self, _request: &Request<B>) -> Response<Body> {
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap()
}
}

#[derive(Debug)]
struct UserId(String);

#[tokio::test]
async fn require_async_auth_works() {
let mut service = ServiceBuilder::new()
.layer(AsyncRequireAuthorizationLayer::new(MyAuth))
.service_fn(echo);

let request = Request::get("/")
.header(header::AUTHORIZATION, "Bearer 69420")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn require_async_auth_401() {
let mut service = ServiceBuilder::new()
.layer(AsyncRequireAuthorizationLayer::new(MyAuth))
.service_fn(echo);

let request = Request::get("/")
.header(header::AUTHORIZATION, "Bearer deez")
.body(Body::empty())
.unwrap();

let res = service.ready().await.unwrap().call(request).await.unwrap();

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}
}
4 changes: 4 additions & 0 deletions tower-http/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
//! Authorization related middleware.

pub mod add_authorization;
pub mod async_require_authorization;
pub mod require_authorization;

#[doc(inline)]
pub use self::{
add_authorization::{AddAuthorization, AddAuthorizationLayer},
async_require_authorization::{
AsyncAuthorizeRequest, AsyncRequireAuthorization, AsyncRequireAuthorizationLayer,
},
require_authorization::{AuthorizeRequest, RequireAuthorization, RequireAuthorizationLayer},
};