Skip to content
51 changes: 51 additions & 0 deletions aiohttp/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def set_parser(self, parser):
import asyncio
import asyncio.streams
import inspect
import socket
from . import errors
from .streams import FlowControlDataQueue, EofStream

Expand All @@ -67,6 +68,13 @@ def set_parser(self, parser):

DEFAULT_LIMIT = 2 ** 16

if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH
else: # pragma: no cover
CORK = None


class StreamParser:
"""StreamParser manages incoming bytes stream and protocol parsers.
Expand Down Expand Up @@ -224,6 +232,49 @@ def __init__(self, transport, protocol, reader, loop):
self._protocol = protocol
self._reader = reader
self._loop = loop
self._tcp_nodelay = False
self._tcp_cork = False
self._socket = transport.get_extra_info('socket')

@property
def tcp_nodelay(self):
return self._tcp_nodelay

def set_tcp_nodelay(self, value):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why not setter fortcp_nodelay property?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about property's setter but found that syscall worth explicit function call.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, makes sense.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if value not in (1,0):
    raise ValueError('.....')

value = bool(value)
if self._tcp_nodelay == value:
return
self._tcp_nodelay = value
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return
if self._tcp_cork:
self._tcp_cork = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)

@property
def tcp_cork(self):
return self._tcp_cork

def set_tcp_cork(self, value):
value = bool(value)
if self._tcp_cork == value:
return
self._tcp_cork = value
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return
if self._tcp_nodelay:
self._socket.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY,
False)
self._tcp_nodelay = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value)


class StreamProtocol(asyncio.streams.FlowControlMixin, asyncio.Protocol):
Expand Down
36 changes: 36 additions & 0 deletions aiohttp/web_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def __init__(self, *, status=200, reason=None, headers=None):
self._req = None
self._resp_impl = None
self._eof_sent = False
self._tcp_nodelay = True
self._tcp_cork = False

if headers is not None:
self._headers.extend(headers)
Expand Down Expand Up @@ -604,6 +606,36 @@ def last_modified(self, value):
elif isinstance(value, str):
self.headers[hdrs.LAST_MODIFIED] = value

@property
def tcp_nodelay(self):
return self._tcp_nodelay

def set_tcp_nodelay(self, value):
value = bool(value)
self._tcp_nodelay = value
if value:
self._tcp_cork = False
if self._resp_impl is None:
return
if value:
self._resp_impl.transport.set_tcp_cork(False)
self._resp_impl.transport.set_tcp_nodelay(value)

@property
def tcp_cork(self):
return self._tcp_cork

def set_tcp_cork(self, value):
value = bool(value)
self._tcp_cork = value
if value:
self._tcp_nodelay = False
if self._resp_impl is None:
return
if value:
self._resp_impl.transport.set_tcp_nodelay(False)
self._resp_impl.transport.set_tcp_cork(value)

def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
params = '; '.join("%s=%s" % i for i in self._content_dict.items())
if params:
Expand Down Expand Up @@ -669,6 +701,8 @@ def _start(self, request):
request.version,
not keep_alive,
self._reason)
resp_impl.transport.set_tcp_nodelay(self._tcp_nodelay)
resp_impl.transport.set_tcp_cork(self._tcp_cork)

self._copy_cookies()

Expand Down Expand Up @@ -736,6 +770,7 @@ def __init__(self, *, body=None, status=200,
reason=None, text=None, headers=None, content_type=None,
charset=None):
super().__init__(status=status, reason=reason, headers=headers)
self.set_tcp_cork(True)

if body is not None and text is not None:
raise ValueError("body and text are not allowed together.")
Expand Down Expand Up @@ -815,6 +850,7 @@ def write_eof(self):
body = self._body
if body is not None:
self.write(body)
self.set_tcp_nodelay(True)
yield from super().write_eof()


Expand Down
3 changes: 3 additions & 0 deletions aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,14 @@ def handle(self, request):
file_size = st.st_size

resp.content_length = file_size
resp.set_tcp_cork(True)
yield from resp.prepare(request)

with open(filepath, 'rb') as f:
yield from self._sendfile(request, resp, f, file_size)

resp.set_tcp_nodelay(True)

return resp

def __repr__(self):
Expand Down
217 changes: 217 additions & 0 deletions tests/test_stream_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import pytest
import socket
from aiohttp.parsers import StreamWriter, CORK
from unittest import mock


# nodelay

def test_nodelay_default(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
assert not writer.tcp_nodelay
assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


def test_set_nodelay_no_change(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_nodelay(False)
assert not writer.tcp_nodelay
assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


def test_set_nodelay_enable(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_nodelay(True)
assert writer.tcp_nodelay
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


def test_set_nodelay_enable_and_disable(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_nodelay(True)
writer.set_tcp_nodelay(False)
assert not writer.tcp_nodelay
assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


def test_set_nodelay_enable_ipv6(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_nodelay(True)
assert writer.tcp_nodelay
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason="requires unix sockets")
def test_set_nodelay_enable_unix(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_nodelay(True)
assert writer.tcp_nodelay


def test_set_nodelay_enable_no_socket(loop):
transport = mock.Mock()
transport.get_extra_info.return_value = None
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_nodelay(True)
assert writer.tcp_nodelay
assert writer._socket is None


# cork

@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_cork_default(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
assert not writer.tcp_cork
assert not s.getsockopt(socket.IPPROTO_TCP, CORK)


@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_cork_no_change(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_cork(False)
assert not writer.tcp_cork
assert not s.getsockopt(socket.IPPROTO_TCP, CORK)


@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_cork_enable(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_cork(True)
assert writer.tcp_cork
assert s.getsockopt(socket.IPPROTO_TCP, CORK)


@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_cork_enable_and_disable(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_cork(True)
writer.set_tcp_cork(False)
assert not writer.tcp_cork
assert not s.getsockopt(socket.IPPROTO_TCP, CORK)


@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_cork_enable_ipv6(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_cork(True)
assert writer.tcp_cork
assert s.getsockopt(socket.IPPROTO_TCP, CORK)


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason="requires unix sockets")
@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_cork_enable_unix(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_cork(True)
assert writer.tcp_cork


@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_cork_enable_no_socket(loop):
transport = mock.Mock()
transport.get_extra_info.return_value = None
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_cork(True)
assert writer.tcp_cork
assert writer._socket is None


# cork and nodelay interference

@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_enabling_cork_disables_nodelay(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_nodelay(True)
writer.set_tcp_cork(True)
assert not writer.tcp_nodelay
assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
assert writer.tcp_cork
assert s.getsockopt(socket.IPPROTO_TCP, CORK)


@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
def test_set_enabling_nodelay_disables_cork(loop):
transport = mock.Mock()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
transport.get_extra_info.return_value = s
proto = mock.Mock()
reader = mock.Mock()
writer = StreamWriter(transport, proto, reader, loop)
writer.set_tcp_cork(True)
writer.set_tcp_nodelay(True)
assert writer.tcp_nodelay
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
assert not writer.tcp_cork
assert not s.getsockopt(socket.IPPROTO_TCP, CORK)
Loading