Skip to content
1 change: 1 addition & 0 deletions CHANGES/3736.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `BaseConnector.close()` a coroutine and wait until the client closes all connections. Drop deprecated "with Connector():" syntax.
9 changes: 8 additions & 1 deletion aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ServerDisconnectedError,
ServerTimeoutError,
)
from .helpers import BaseTimerContext
from .helpers import BaseTimerContext, set_exception, set_result
from .http import HttpResponseParser, RawResponseMessage
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader

Expand Down Expand Up @@ -38,6 +38,8 @@ def __init__(self,
self._read_timeout = None # type: Optional[float]
self._read_timeout_handle = None # type: Optional[asyncio.TimerHandle]

self.closed = self._loop.create_future() # type: asyncio.Future[None]

@property
def upgraded(self) -> bool:
return self._upgraded
Expand Down Expand Up @@ -70,6 +72,11 @@ def is_connected(self) -> bool:
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()

if exc is not None:
set_exception(self.closed, exc)
else:
set_result(self.closed, None)

if self._payload_parser is not None:
with suppress(Exception):
self._payload_parser.feed_eof()
Expand Down
51 changes: 33 additions & 18 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import functools
import logging
import random
import sys
import traceback
Expand Down Expand Up @@ -49,7 +50,6 @@
CeilTimeout,
get_running_loop,
is_ip_address,
noop2,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please delete helpers.noop2 function as well.

sentinel,
)
from .http import RESPONSES
Expand Down Expand Up @@ -185,6 +185,10 @@ def closed(self) -> bool:

class _TransportPlaceholder:
""" placeholder for BaseConnector.connect function """
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
fut = loop.create_future()
fut.set_result(None)
self.closed = fut # type: asyncio.Future[Optional[Exception]] # noqa

def close(self) -> None:
pass
Expand Down Expand Up @@ -264,7 +268,7 @@ def __del__(self, _warnings: Any=warnings) -> None:

conns = [repr(c) for c in self._conns.values()]

self._close()
self._close_immediately()

if PY_36:
kwargs = {'source': self}
Expand All @@ -281,13 +285,11 @@ def __del__(self, _warnings: Any=warnings) -> None:
self._loop.call_exception_handler(context)

def __enter__(self) -> 'BaseConnector':
warnings.warn('"with Connector():" is deprecated, '
'use "async with Connector():" instead',
DeprecationWarning)
return self
raise TypeError('use "async with Connector():" instead')

def __exit__(self, *exc: Any) -> None:
self.close()
# __exit__ should exist in pair with __enter__ but never executed
pass # pragma: no cover

async def __aenter__(self) -> 'BaseConnector':
return self
Expand Down Expand Up @@ -386,20 +388,29 @@ def _cleanup_closed(self) -> None:
self, '_cleanup_closed',
self._cleanup_closed_period, self._loop)

def close(self) -> Awaitable[None]:
async def close(self) -> None:
"""Close all opened transports."""
self._close()
return _DeprecationWaiter(noop2())
waiters = self._close_immediately()
if waiters:
results = await asyncio.gather(*waiters,
loop=self._loop,
return_exceptions=True)
for res in results:
if isinstance(res, Exception):
err_msg = "Error while closing connector: " + repr(res)
logging.error(err_msg)

def _close_immediately(self) -> List['asyncio.Future[None]']:
waiters = [] # type: List['asyncio.Future[None]']

def _close(self) -> None:
if self._closed:
return
return waiters

self._closed = True

try:
if self._loop.is_closed():
return
return waiters

# cancel cleanup task
if self._cleanup_handle:
Expand All @@ -412,14 +423,19 @@ def _close(self) -> None:
for data in self._conns.values():
for proto, t0 in data:
proto.close()
waiters.append(proto.closed)

for proto in self._acquired:
proto.close()
waiters.append(proto.closed)

# TODO (A.Yushovskiy, 24-May-2019) collect transp. closing futures
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

incomplete task ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

work with transports here, especially self._cleanup_closed_transports, seem to be a workaround of the non-closed transports for SSL only:

if (key.is_ssl and
not self._cleanup_closed_disabled):
self._cleanup_closed_transports.append(
transport)

Also, transports in _cleanup_closed_transports are of type asyncio.Transport so we can not modify them to save closing future same way as we did in this PR with ResponseHandler, thus I don't see a way to get the awaitable result of transport.abort().

Good news is that most likely these calls are redundant since we close all protocols (i.e., save them to the list of futures: waiters.append(proto.closed))

for transport in self._cleanup_closed_transports:
if transport is not None:
transport.abort()

return waiters

finally:
self._conns.clear()
self._acquired.clear()
Expand Down Expand Up @@ -510,7 +526,8 @@ async def connect(self, req: 'ClientRequest',

proto = self._get(key)
if proto is None:
placeholder = cast(ResponseHandler, _TransportPlaceholder())
placeholder = cast(ResponseHandler,
_TransportPlaceholder(self._loop))
self._acquired.add(placeholder)
self._acquired_per_host[key].add(placeholder)

Expand Down Expand Up @@ -741,12 +758,10 @@ def __init__(self, *,
self._family = family
self._local_addr = local_addr

def close(self) -> Awaitable[None]:
"""Close all ongoing DNS calls."""
def _close_immediately(self) -> List['asyncio.Future[None]']:
for ev in self._throttle_dns_events.values():
ev.cancel()

return super().close()
return super()._close_immediately()

@property
def family(self) -> int:
Expand Down
7 changes: 1 addition & 6 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,7 @@ def all_tasks(
coroutines._DEBUG = False # type: ignore


@asyncio.coroutine
def noop(*args, **kwargs): # type: ignore
return # type: ignore


async def noop2(*args: Any, **kwargs: Any) -> None:
async def noop(*args: Any, **kwargs: Any) -> None:
return


Expand Down
2 changes: 1 addition & 1 deletion docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ BaseConnector

.. comethod:: close()

Close all opened connections.
Close all open connections (and await them to close).

.. comethod:: connect(request)

Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from hashlib import md5, sha256
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import mock
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -73,6 +74,17 @@ def tls_certificate_fingerprint_sha256(tls_certificate_pem_bytes):
return sha256(tls_cert_der).digest()


@pytest.fixture
def create_mocked_conn(loop):
def _proto_factory(conn_closing_result=None, **kwargs):
proto = mock.Mock(**kwargs)
proto.closed = loop.create_future()
proto.closed.set_result(conn_closing_result)
return proto

yield _proto_factory


@pytest.fixture
def unix_sockname(tmp_path, tmp_path_factory):
"""Generate an fs path to the UNIX domain socket for testing.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2348,6 +2348,13 @@ async def handler(request):
assert resp.status == 200


async def test_aiohttp_request_ctx_manager_not_found() -> None:

with pytest.raises(aiohttp.ClientConnectionError):
async with aiohttp.request('GET', 'http://wrong-dns-name.com'):
assert False, "never executed" # pragma: no cover


async def test_aiohttp_request_coroutine(aiohttp_server) -> None:
async def handler(request):
return web.Response()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@ def test_terminate_without_writer(loop) -> None:
assert req._writer is None


async def test_custom_req_rep(loop) -> None:
async def test_custom_req_rep(loop, create_mocked_conn) -> None:
conn = None

class CustomResponse(ClientResponse):
Expand Down Expand Up @@ -1125,7 +1125,7 @@ async def send(self, conn):

async def create_connection(req, traces, timeout):
assert isinstance(req, CustomRequest)
return mock.Mock()
return create_mocked_conn()
connector = BaseConnector(loop=loop)
connector._create_connection = create_connection

Expand Down
16 changes: 8 additions & 8 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@


@pytest.fixture
def connector(loop):
def connector(loop, create_mocked_conn):
async def make_conn():
return BaseConnector(loop=loop)
conn = loop.run_until_complete(make_conn())
proto = mock.Mock()
proto = create_mocked_conn()
conn._conns['a'] = [(proto, 123)]
yield conn
conn.close()
Expand Down Expand Up @@ -327,10 +327,10 @@ async def test_request_closed_session(session) -> None:
await session.request('get', '/')


def test_close_flag_for_closed_connector(session) -> None:
async def test_close_flag_for_closed_connector(session) -> None:
conn = session.connector
assert not session.closed
conn.close()
await conn.close()
assert session.closed


Expand Down Expand Up @@ -395,7 +395,7 @@ async def test_borrow_connector_loop(connector, create_session, loop) -> None:
await session.close()


async def test_reraise_os_error(create_session) -> None:
async def test_reraise_os_error(create_session, create_mocked_conn) -> None:
err = OSError(1, "permission error")
req = mock.Mock()
req_factory = mock.Mock(return_value=req)
Expand All @@ -404,7 +404,7 @@ async def test_reraise_os_error(create_session) -> None:

async def create_connection(req, traces, timeout):
# return self.transport, self.protocol
return mock.Mock()
return create_mocked_conn()
session._connector._create_connection = create_connection
session._connector._release = mock.Mock()

Expand All @@ -415,7 +415,7 @@ async def create_connection(req, traces, timeout):
assert e.strerror == err.strerror


async def test_close_conn_on_error(create_session) -> None:
async def test_close_conn_on_error(create_session, create_mocked_conn) -> None:
class UnexpectedException(BaseException):
pass

Expand All @@ -435,7 +435,7 @@ async def connect(req, traces, timeout):

async def create_connection(req, traces, timeout):
# return self.transport, self.protocol
conn = mock.Mock()
conn = create_mocked_conn()
return conn

session._connector.connect = connect
Expand Down
Loading