Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import h2.exceptions
import h2.settings

from .._exceptions import ConnectionNotAvailable, RemoteProtocolError
from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore
from .._trace import Trace
Expand Down Expand Up @@ -56,6 +60,7 @@ def __init__(
self._events: typing.Dict[int, h2.events.Event] = {}
self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._connection_error_event: typing.Optional[h2.events.Event] = None

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand Down Expand Up @@ -114,11 +119,28 @@ async def handle_async_request(self, request: Request) -> Response:
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={"stream_id": stream_id, "http_version": b"HTTP/2"},
)
except Exception: # noqa: PIE786
except Exception as exc: # noqa: PIE786
kwargs = {"stream_id": stream_id}
async with Trace("http2.response_closed", request, kwargs):
await self._response_closed(stream_id=stream_id)
raise

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
# closed frame has been seen by the state machine.
#
# This happens when one stream is reading, and encounters
# a GOAWAY event. Other flows of control may then raise
# a protocol error at any point they interact with the 'h2_state'.
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_error_event:
raise RemoteProtocolError(self._connection_error_event)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover

raise exc

async def _send_connection_init(self, request: Request) -> None:
"""
Expand Down Expand Up @@ -235,10 +257,17 @@ async def _receive_stream_event(
) -> h2.events.Event:
while not self._events.get(stream_id):
await self._receive_events(request, stream_id)
return self._events[stream_id].pop(0)
event = self._events[stream_id].pop(0)
# The StreamReset event applies to a single stream.
if hasattr(event, "error_code"):
raise RemoteProtocolError(event)
return event

async def _receive_events(self, request: Request, stream_id: int = None) -> None:
async with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)

# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
# check *within* the atomic read lock. Though it also need to be optional,
Expand All @@ -250,7 +279,10 @@ async def _receive_events(self, request: Request, stream_id: int = None) -> None
for event in events:
event_stream_id = getattr(event, "stream_id", 0)

if hasattr(event, "error_code"):
# The ConnectionTerminatedEvent applies to the entire connection,
# and should be saved so it can be raised on all streams.
if hasattr(event, "error_code") and event_stream_id == 0:
self._connection_error_event = event
raise RemoteProtocolError(event)

if event_stream_id in self._events:
Expand Down
42 changes: 37 additions & 5 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import h2.exceptions
import h2.settings

from .._exceptions import ConnectionNotAvailable, RemoteProtocolError
from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import Lock, Semaphore
from .._trace import Trace
Expand Down Expand Up @@ -56,6 +60,7 @@ def __init__(
self._events: typing.Dict[int, h2.events.Event] = {}
self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._connection_error_event: typing.Optional[h2.events.Event] = None

def handle_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand Down Expand Up @@ -114,11 +119,28 @@ def handle_request(self, request: Request) -> Response:
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={"stream_id": stream_id, "http_version": b"HTTP/2"},
)
except Exception: # noqa: PIE786
except Exception as exc: # noqa: PIE786
kwargs = {"stream_id": stream_id}
with Trace("http2.response_closed", request, kwargs):
self._response_closed(stream_id=stream_id)
raise

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
# closed frame has been seen by the state machine.
#
# This happens when one stream is reading, and encounters
# a GOAWAY event. Other flows of control may then raise
# a protocol error at any point they interact with the 'h2_state'.
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_error_event:
raise RemoteProtocolError(self._connection_error_event)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover

raise exc

def _send_connection_init(self, request: Request) -> None:
"""
Expand Down Expand Up @@ -235,10 +257,17 @@ def _receive_stream_event(
) -> h2.events.Event:
while not self._events.get(stream_id):
self._receive_events(request, stream_id)
return self._events[stream_id].pop(0)
event = self._events[stream_id].pop(0)
# The StreamReset event applies to a single stream.
if hasattr(event, "error_code"):
raise RemoteProtocolError(event)
return event

def _receive_events(self, request: Request, stream_id: int = None) -> None:
with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)

# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
# check *within* the atomic read lock. Though it also need to be optional,
Expand All @@ -250,7 +279,10 @@ def _receive_events(self, request: Request, stream_id: int = None) -> None:
for event in events:
event_stream_id = getattr(event, "stream_id", 0)

if hasattr(event, "error_code"):
# The ConnectionTerminatedEvent applies to the entire connection,
# and should be saved so it can be raised on all streams.
if hasattr(event, "error_code") and event_stream_id == 0:
self._connection_error_event = event
raise RemoteProtocolError(event)

if event_stream_id in self._events:
Expand Down
69 changes: 66 additions & 3 deletions tests/_async/test_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ async def test_http2_connection_with_remote_protocol_error():


@pytest.mark.anyio
async def test_http2_connection_with_stream_cancelled():
async def test_http2_connection_with_rst_stream():
"""
If a remote protocol error occurs, then no response will be returned,
and the connection will not be reusable.
If a stream reset occurs, then no response will be returned,
but the connection will remain reusable for other requests.
"""
origin = Origin(b"https", b"example.com", 443)
stream = AsyncMockStream(
Expand All @@ -117,13 +117,76 @@ async def test_http2_connection_with_stream_cancelled():
),
flags=["END_HEADERS"],
).serialize(),
# Stream is closed midway through the first response...
hyperframe.frame.RstStreamFrame(stream_id=1, error_code=8).serialize(),
# ...Which doesn't prevent the second response.
hyperframe.frame.HeadersFrame(
stream_id=3,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=3, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
b"",
]
)
async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
with pytest.raises(RemoteProtocolError):
await conn.request("GET", "https://example.com/")
response = await conn.request("GET", "https://example.com/")
assert response.status == 200


@pytest.mark.anyio
async def test_http2_connection_with_goaway():
"""
If a stream reset occurs, then no response will be returned,
but the connection will remain reusable for other requests.
"""
origin = Origin(b"https", b"example.com", 443)
stream = AsyncMockStream(
[
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
# Connection is closed midway through the first response...
hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(),
# ...We'll never get to this second response.
hyperframe.frame.HeadersFrame(
stream_id=3,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=3, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
b"",
]
)
async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
with pytest.raises(RemoteProtocolError):
await conn.request("GET", "https://example.com/")
with pytest.raises(RemoteProtocolError):
await conn.request("GET", "https://example.com/")


@pytest.mark.anyio
Expand Down
69 changes: 66 additions & 3 deletions tests/_sync/test_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def test_http2_connection_with_remote_protocol_error():



def test_http2_connection_with_stream_cancelled():
def test_http2_connection_with_rst_stream():
"""
If a remote protocol error occurs, then no response will be returned,
and the connection will not be reusable.
If a stream reset occurs, then no response will be returned,
but the connection will remain reusable for other requests.
"""
origin = Origin(b"https", b"example.com", 443)
stream = MockStream(
Expand All @@ -117,13 +117,76 @@ def test_http2_connection_with_stream_cancelled():
),
flags=["END_HEADERS"],
).serialize(),
# Stream is closed midway through the first response...
hyperframe.frame.RstStreamFrame(stream_id=1, error_code=8).serialize(),
# ...Which doesn't prevent the second response.
hyperframe.frame.HeadersFrame(
stream_id=3,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=3, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
b"",
]
)
with HTTP2Connection(origin=origin, stream=stream) as conn:
with pytest.raises(RemoteProtocolError):
conn.request("GET", "https://example.com/")
response = conn.request("GET", "https://example.com/")
assert response.status == 200



def test_http2_connection_with_goaway():
"""
If a stream reset occurs, then no response will be returned,
but the connection will remain reusable for other requests.
"""
origin = Origin(b"https", b"example.com", 443)
stream = MockStream(
[
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
# Connection is closed midway through the first response...
hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(),
# ...We'll never get to this second response.
hyperframe.frame.HeadersFrame(
stream_id=3,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(
stream_id=3, data=b"Hello, world!", flags=["END_STREAM"]
).serialize(),
b"",
]
)
with HTTP2Connection(origin=origin, stream=stream) as conn:
with pytest.raises(RemoteProtocolError):
conn.request("GET", "https://example.com/")
with pytest.raises(RemoteProtocolError):
conn.request("GET", "https://example.com/")



Expand Down