Merged

Changes from 2 commits
57 changes: 29 additions & 28 deletions aiohttp/http_parser.py
@@ -394,6 +394,7 @@
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
headers_parser=self._headers_parser,
)
if not payload_parser.done:
self._payload_parser = payload_parser
@@ -412,6 +413,7 @@
compression=msg.compression,
auto_decompress=self._auto_decompress,
lax=self.lax,
headers_parser=self._headers_parser,
)
elif not empty_body and length is None and self.read_until_eof:
payload = StreamReader(
@@ -430,6 +432,7 @@
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
headers_parser=self._headers_parser,
)
if not payload_parser.done:
self._payload_parser = payload_parser
@@ -758,6 +761,8 @@
response_with_body: bool = True,
auto_decompress: bool = True,
lax: bool = False,
*,
headers_parser: HeadersParser,
) -> None:
self._length = 0
self._type = ParseState.PARSE_UNTIL_EOF
@@ -766,6 +771,9 @@
self._chunk_tail = b""
self._auto_decompress = auto_decompress
self._lax = lax
self._headers_parser = headers_parser
# HeadersParser expects the status/request line first, so it skips the first line.
self._trailer_lines: list[bytes] = [b""]
self.done = False

# payload decompression wrapper
@@ -854,7 +862,7 @@

chunk = chunk[pos + len(SEP) :]
if size == 0: # eof marker
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
self._chunk = ChunkState.PARSE_TRAILERS
if self._lax and chunk.startswith(b"\r"):
chunk = chunk[1:]
else:
@@ -888,38 +896,31 @@
self._chunk_tail = chunk
return False, b""

# if stream does not contain trailer, after 0\r\n
# we should get another \r\n otherwise
# trailers needs to be skipped until \r\n\r\n
if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
head = chunk[: len(SEP)]
if head == SEP:
# end of stream
self.payload.feed_eof()
return True, chunk[len(SEP) :]
# Both CR and LF, or only LF may not be received yet. It is
# expected that CRLF or LF will be shown at the very first
# byte next time, otherwise trailers should come. The last
# CRLF which marks the end of response might not be
# contained in the same TCP segment which delivered the
# size indicator.
if not head:
return False, b""
if head == SEP[:1]:
self._chunk_tail = head
return False, b""
self._chunk = ChunkState.PARSE_TRAILERS

# read and discard trailer up to the CRLF terminator
if self._chunk == ChunkState.PARSE_TRAILERS:
pos = chunk.find(SEP)
if pos >= 0:
chunk = chunk[pos + len(SEP) :]
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
else:
if pos < 0: # No line found
self._chunk_tail = chunk
return False, b""

line = chunk[:pos]
chunk = chunk[pos + len(SEP) :]
if SEP == b"\n": # For lax response parsing
line = line.rstrip(b"\r")
self._trailer_lines.append(line)

# \r\n\r\n found, end of stream
if self._trailer_lines[-1] == b"":
# Headers and trailers are defined the same way,
# so we reuse the HeadersParser here.
try:
trailers, raw_trailers = self._headers_parser.parse_headers(
self._trailer_lines
)
finally:
self._trailer_lines.clear()
self.payload.feed_eof()
return True, chunk

# Read all bytes until eof
elif self._type == ParseState.PARSE_UNTIL_EOF:
self.payload.feed_data(chunk)
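A note on the trailer handling above: rather than skipping the trailer section, the parser now collects each trailer line into `_trailer_lines` and, once the blank line that ends the block arrives, hands the whole list to the same `HeadersParser` used for the message headers. The sketch below shows that reuse in isolation, using only the no-argument constructor and the `parse_headers()` call that appear in this diff; the exact types of the returned values are not shown here, so treat the access pattern as approximate.

```python
from aiohttp.http_parser import HeadersParser

# Trailer block of a chunked message, split on CRLF with the separators
# stripped. The leading b"" stands in for the request/status line that
# HeadersParser skips (the same placeholder seeded into _trailer_lines),
# and the final b"" is the blank line that terminates the trailer block.
trailer_lines = [
    b"",
    b"test: trailer",
    b"second: test trailer",
    b"",
]

trailers, raw_trailers = HeadersParser().parse_headers(trailer_lines)

# trailers behaves like the case-insensitive multidict returned for headers;
# raw_trailers keeps the raw (name, value) byte pairs.
print(trailers["test"], trailers["second"])
```

This is the same call made in the `PARSE_TRAILERS` branch above; a malformed line such as `b"bad\ntrailer"` makes the parse raise `InvalidHeader`, which is what the new rejection test below relies on.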
76 changes: 58 additions & 18 deletions tests/test_http_parser.py
@@ -18,6 +18,7 @@
from aiohttp.helpers import NO_EXTENSIONS
from aiohttp.http_parser import (
DeflateBuffer,
HeadersParser,
HttpParser,
HttpPayloadParser,
HttpRequestParser,
@@ -1354,6 +1355,25 @@ def test_parse_chunked_payload_chunk_extension(parser: HttpRequestParser) -> Non
assert payload.is_eof()


async def test_request_chunked_with_trailer(parser: HttpRequestParser) -> None:
text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\ntest: trailer\r\nsecond: test trailer\r\n\r\n"
messages, upgraded, tail = parser.feed_data(text)
assert not tail
msg, payload = messages[0]
assert await payload.read() == b"test"

# TODO: Add assertion of trailers when API added.


async def test_request_chunked_reject_bad_trailer(parser: HttpRequestParser) -> None:
text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nbad\ntrailer\r\n\r\n"
messages, upgraded, tail = parser.feed_data(text)
assert not tail
msg, payload = messages[0]
with pytest.raises(http_exceptions.InvalidHeader, match=r"b'bad\\ntrailer'"):
await payload.read()


def test_parse_no_length_or_te_on_post(
loop: asyncio.AbstractEventLoop,
protocol: BaseProtocol,
@@ -1684,7 +1704,7 @@ def test_parse_bad_method_for_c_parser_raises(
class TestParsePayload:
async def test_parse_eof_payload(self, protocol: BaseProtocol) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out)
p = HttpPayloadParser(out, headers_parser=HeadersParser())
p.feed_data(b"data")
p.feed_eof()

@@ -1694,7 +1714,7 @@ async def test_parse_eof_payload(self, protocol: BaseProtocol) -> None:
async def test_parse_length_payload_eof(self, protocol: BaseProtocol) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())

p = HttpPayloadParser(out, length=4)
p = HttpPayloadParser(out, length=4, headers_parser=HeadersParser())
p.feed_data(b"da")

with pytest.raises(http_exceptions.ContentLengthError):
@@ -1704,7 +1724,7 @@ async def test_parse_chunked_payload_size_error(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, chunked=True)
p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
with pytest.raises(http_exceptions.TransferEncodingError):
p.feed_data(b"blah\r\n")
assert isinstance(out.exception(), http_exceptions.TransferEncodingError)
@@ -1713,7 +1733,7 @@ async def test_parse_chunked_payload_split_end(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, chunked=True)
p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
p.feed_data(b"4\r\nasdf\r\n0\r\n")
p.feed_data(b"\r\n")

@@ -1724,7 +1744,7 @@ async def test_parse_chunked_payload_split_end2(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, chunked=True)
p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
p.feed_data(b"4\r\nasdf\r\n0\r\n\r")
p.feed_data(b"\n")

@@ -1735,7 +1755,7 @@ async def test_parse_chunked_payload_split_end_trailers(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, chunked=True)
p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
p.feed_data(b"4\r\nasdf\r\n0\r\n")
p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n")
p.feed_data(b"\r\n")
@@ -1747,7 +1767,7 @@ async def test_parse_chunked_payload_split_end_trailers2(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, chunked=True)
p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
p.feed_data(b"4\r\nasdf\r\n0\r\n")
p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r")
p.feed_data(b"\n")
@@ -1759,7 +1779,7 @@ async def test_parse_chunked_payload_split_end_trailers3(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, chunked=True)
p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
p.feed_data(b"4\r\nasdf\r\n0\r\nContent-MD5: ")
p.feed_data(b"912ec803b2ce49e4a541068d495ab570\r\n\r\n")

@@ -1770,7 +1790,7 @@ async def test_parse_chunked_payload_split_end_trailers4(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, chunked=True)
p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
p.feed_data(b"4\r\nasdf\r\n0\r\nC")
p.feed_data(b"ontent-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r\n")

@@ -1779,7 +1799,7 @@ async def test_parse_chunked_payload_split_end_trailers4(

async def test_http_payload_parser_length(self, protocol: BaseProtocol) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, length=2)
p = HttpPayloadParser(out, length=2, headers_parser=HeadersParser())
eof, tail = p.feed_data(b"1245")
assert eof

@@ -1792,7 +1812,9 @@ async def test_http_payload_parser_deflate(self, protocol: BaseProtocol) -> None

length = len(COMPRESSED)
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, length=length, compression="deflate")
p = HttpPayloadParser(
out, length=length, compression="deflate", headers_parser=HeadersParser()
)
p.feed_data(COMPRESSED)
assert b"data" == out._buffer[0]
assert out.is_eof()
@@ -1806,7 +1828,9 @@ async def test_http_payload_parser_deflate_no_hdrs(

length = len(COMPRESSED)
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, length=length, compression="deflate")
p = HttpPayloadParser(
out, length=length, compression="deflate", headers_parser=HeadersParser()
)
p.feed_data(COMPRESSED)
assert b"data" == out._buffer[0]
assert out.is_eof()
@@ -1819,7 +1843,9 @@ async def test_http_payload_parser_deflate_light(

length = len(COMPRESSED)
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, length=length, compression="deflate")
p = HttpPayloadParser(
out, length=length, compression="deflate", headers_parser=HeadersParser()
)
p.feed_data(COMPRESSED)

assert b"data" == out._buffer[0]
@@ -1829,7 +1855,9 @@ async def test_http_payload_parser_deflate_split(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, compression="deflate")
p = HttpPayloadParser(
out, compression="deflate", headers_parser=HeadersParser()
)
# Feeding one correct byte should be enough to choose exact
# deflate decompressor
p.feed_data(b"x")
@@ -1841,7 +1869,9 @@ async def test_http_payload_parser_deflate_split_err(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, compression="deflate")
p = HttpPayloadParser(
out, compression="deflate", headers_parser=HeadersParser()
)
# Feeding one wrong byte should be enough to choose exact
# deflate decompressor
p.feed_data(b"K")
@@ -1853,15 +1883,20 @@ async def test_http_payload_parser_length_zero(
self, protocol: BaseProtocol
) -> None:
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, length=0)
p = HttpPayloadParser(out, length=0, headers_parser=HeadersParser())
assert p.done
assert out.is_eof()

@pytest.mark.skipif(brotli is None, reason="brotli is not installed")
async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None:
compressed = brotli.compress(b"brotli data")
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, length=len(compressed), compression="br")
p = HttpPayloadParser(
out,
length=len(compressed),
compression="br",
headers_parser=HeadersParser(),
)
p.feed_data(compressed)
assert b"brotli data" == out._buffer[0]
assert out.is_eof()
Expand All @@ -1870,7 +1905,12 @@ async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None:
async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None:
compressed = zstandard.compress(b"zstd data")
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
p = HttpPayloadParser(out, length=len(compressed), compression="zstd")
p = HttpPayloadParser(
out,
length=len(compressed),
compression="zstd",
headers_parser=HeadersParser(),
)
p.feed_data(compressed)
assert b"zstd data" == out._buffer[0]
assert out.is_eof()
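For quick reference, here is the construction pattern the updated tests share, condensed into one self-contained coroutine. This is only a sketch of the test setup, not new API: the `protocol` argument corresponds to the `BaseProtocol` fixture used throughout the test module, and the function name is illustrative.

```python
import asyncio

import aiohttp
from aiohttp.base_protocol import BaseProtocol
from aiohttp.http_parser import HeadersParser, HttpPayloadParser


async def read_chunked_with_trailers(protocol: BaseProtocol) -> bytes:
    # Output stream plus a chunked payload parser, constructed with the
    # keyword-only headers_parser argument introduced in this change.
    out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())

    # One 4-byte chunk, the 0-size eof chunk, a trailer line, and the blank
    # line that ends the trailer block, split across two feeds.
    p.feed_data(b"4\r\nasdf\r\n0\r\n")
    p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r\n")

    assert out.is_eof()
    return await out.read()  # b"asdf"
```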