Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,7 @@ def __init__(self, manager, app, router, **kwargs):
self._manager = manager
self._app = app
self._router = router
self._middlewares = app.middlewares

def connection_made(self, transport):
super().connection_made(transport)
Expand All @@ -1188,7 +1189,8 @@ def connection_lost(self, exc):
def handle_request(self, message, payload):
now = self._loop.time()

request = Request(self._app, message, payload,
app = self._app
request = Request(app, message, payload,
self.transport, self.writer, self.keep_alive_timeout)
try:
match_info = yield from self._router.resolve(request)
Expand All @@ -1198,7 +1200,10 @@ def handle_request(self, message, payload):
request._match_info = match_info
handler = match_info.handler

for factory in reversed(self._middlewares):
handler = yield from factory(app, handler)
resp = yield from handler(request)

if not isinstance(resp, StreamResponse):
raise RuntimeError(
("Handler should return response instance, got {!r}")
Expand Down Expand Up @@ -1273,8 +1278,8 @@ def __call__(self):
class Application(dict):

def __init__(self, *, logger=web_logger, loop=None,
router=None, handler_factory=RequestHandlerFactory, **kwargs):
# TODO: explicitly accept *debug* param
router=None, handler_factory=RequestHandlerFactory,
middlewares=(), **kwargs):
if loop is None:
loop = asyncio.get_event_loop()
if router is None:
Expand All @@ -1288,6 +1293,9 @@ def __init__(self, *, logger=web_logger, loop=None,
self.logger = logger

self.update(**kwargs)
for factory in middlewares:
assert asyncio.iscoroutinefunction(factory), factory
self._middlewares = tuple(middlewares)

@property
def router(self):
Expand All @@ -1297,6 +1305,10 @@ def router(self):
def loop(self):
return self._loop

@property
def middlewares(self):
return self._middlewares

def make_handler(self, **kwargs):
return self._handler_factory(
self, self.router, loop=self.loop, **kwargs)
Expand Down
118 changes: 118 additions & 0 deletions tests/test_web_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import asyncio
import socket
import unittest
from aiohttp import web, request


class TestWebFunctional(unittest.TestCase):

def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()

def find_unused_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]
s.close()
return port

@asyncio.coroutine
def create_server(self, method, path, handler, *middlewares):
app = web.Application(loop=self.loop, middlewares=middlewares)
app.router.add_route(method, path, handler)

port = self.find_unused_port()
srv = yield from self.loop.create_server(
app.make_handler(debug=True), '127.0.0.1', port)
url = "http://127.0.0.1:{}".format(port) + path
self.addCleanup(srv.close)
return app, srv, url

def test_middleware_modifies_response(self):

@asyncio.coroutine
def handler(request):
return web.Response(body=b'OK')

@asyncio.coroutine
def middleware_factory(app, handler):
def middleware(request):
resp = yield from handler(request)
self.assertEqual(200, resp.status)
resp.set_status(201)
resp.text = resp.text + '[MIDDLEWARE]'
return resp
return middleware

@asyncio.coroutine
def go():
_, _, url = yield from self.create_server('GET', '/', handler,
middleware_factory)
resp = yield from request('GET', url, loop=self.loop)
self.assertEqual(201, resp.status)
txt = yield from resp.text()
self.assertEqual('OK[MIDDLEWARE]', txt)

self.loop.run_until_complete(go())

def test_middleware_handles_exception(self):

@asyncio.coroutine
def handler(request):
raise RuntimeError('Error text')

@asyncio.coroutine
def middleware_factory(app, handler):
def middleware(request):
with self.assertRaises(RuntimeError) as ctx:
yield from handler(request)
return web.Response(status=501,
text=str(ctx.exception) + '[MIDDLEWARE]')

return middleware

@asyncio.coroutine
def go():
_, _, url = yield from self.create_server('GET', '/', handler,
middleware_factory)
resp = yield from request('GET', url, loop=self.loop)
self.assertEqual(501, resp.status)
txt = yield from resp.text()
self.assertEqual('Error text[MIDDLEWARE]', txt)

self.loop.run_until_complete(go())

def test_middleware_chain(self):

@asyncio.coroutine
def handler(request):
return web.Response(text='OK')

def make_factory(num):

@asyncio.coroutine
def factory(app, handler):

def middleware(request):
resp = yield from handler(request)
resp.text = resp.text + '[{}]'.format(num)
return resp

return middleware
return factory

@asyncio.coroutine
def go():
_, _, url = yield from self.create_server('GET', '/', handler,
make_factory(1),
make_factory(2))
resp = yield from request('GET', url, loop=self.loop)
self.assertEqual(200, resp.status)
txt = yield from resp.text()
self.assertEqual('OK[2][1]', txt)

self.loop.run_until_complete(go())