Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/11150.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed leak of ``aiodns.DNSResolver`` when :py:class:`~aiohttp.TCPConnector` is closed and no resolver was passed when creating the connector -- by :user:`Tasssadar`.

This was a regression introduced in version 3.12.0 (:pr:`10897`).
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ Vladimir Shulyak
Vladimir Zakharov
Vladyslav Bohaichuk
Vladyslav Bondar
Vojtěch Boček
W. Trevor King
Wei Lin
Weiwei Wang
Expand Down
18 changes: 15 additions & 3 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,9 +926,14 @@ def __init__(
)

self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

self._resolver: AbstractResolver
if resolver is None:
resolver = DefaultResolver(loop=self._loop)
self._resolver = resolver
self._resolver = DefaultResolver(loop=self._loop)
self._resolver_owner = True
else:
self._resolver = resolver
self._resolver_owner = False

self._use_dns_cache = use_dns_cache
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
Expand Down Expand Up @@ -956,6 +961,12 @@ def _close(self) -> List[Awaitable[object]]:

return waiters

async def close(self) -> None:
"""Close all opened transports."""
if self._resolver_owner:
await self._resolver.close()
await super().close()

@property
def family(self) -> int:
"""Socket family like AF_INET."""
Expand Down Expand Up @@ -1709,7 +1720,8 @@ def __init__(
loop=loop,
)
if not isinstance(
self._loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
self._loop,
asyncio.ProactorEventLoop, # type: ignore[attr-defined]
):
raise RuntimeError(
"Named Pipes only available in proactor loop under windows"
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,10 @@ def release_resolver(
loop: The event loop the resolver was using.
"""
# Remove client from its loop's tracking
if loop not in self._loop_data:
current_loop_data = self._loop_data.get(loop)
if current_loop_data is None:
return
resolver, client_set = self._loop_data[loop]
resolver, client_set = current_loop_data
client_set.discard(client)
# If no more clients for this loop, cancel and remove its resolver
if not client_set:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,7 @@ async def test_tcp_connector_dns_cache_not_expired(loop, dns_response) -> None:
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
m_resolver().resolve.return_value = dns_response()
m_resolver().close = mock.AsyncMock()
await conn._resolve_host("localhost", 8080)
await conn._resolve_host("localhost", 8080)
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
Expand All @@ -1281,6 +1282,7 @@ async def test_tcp_connector_dns_cache_forever(loop, dns_response) -> None:
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
m_resolver().resolve.return_value = dns_response()
m_resolver().close = mock.AsyncMock()
await conn._resolve_host("localhost", 8080)
await conn._resolve_host("localhost", 8080)
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
Expand All @@ -1292,6 +1294,7 @@ async def test_tcp_connector_use_dns_cache_disabled(loop, dns_response) -> None:
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
m_resolver().resolve.side_effect = [dns_response(), dns_response()]
m_resolver().close = mock.AsyncMock()
await conn._resolve_host("localhost", 8080)
await conn._resolve_host("localhost", 8080)
m_resolver().resolve.assert_has_calls(
Expand All @@ -1308,6 +1311,7 @@ async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None:
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
m_resolver().resolve.return_value = dns_response()
m_resolver().close = mock.AsyncMock()
loop.create_task(conn._resolve_host("localhost", 8080))
loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
Expand All @@ -1322,6 +1326,7 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
e = Exception()
m_resolver().resolve.side_effect = e
m_resolver().close = mock.AsyncMock()
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
Expand All @@ -1341,6 +1346,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
m_resolver().resolve.return_value = dns_response()
m_resolver().close = mock.AsyncMock()
loop.create_task(conn._resolve_host("localhost", 8080))
f = loop.create_task(conn._resolve_host("localhost", 8080))

Expand Down Expand Up @@ -1384,6 +1390,7 @@ def exception_handler(loop, context):
use_dns_cache=False,
)
m_resolver().resolve.return_value = dns_response_error()
m_resolver().close = mock.AsyncMock()
f = loop.create_task(conn._create_direct_connection(req, [], ClientTimeout(0)))

await asyncio.sleep(0)
Expand Down Expand Up @@ -1419,6 +1426,7 @@ async def test_tcp_connector_dns_tracing(loop, dns_response) -> None:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)

m_resolver().resolve.return_value = dns_response()
m_resolver().close = mock.AsyncMock()

await conn._resolve_host("localhost", 8080, traces=traces)
on_dns_resolvehost_start.assert_called_once_with(
Expand Down Expand Up @@ -1460,6 +1468,7 @@ async def test_tcp_connector_dns_tracing_cache_disabled(loop, dns_response) -> N
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)

m_resolver().resolve.side_effect = [dns_response(), dns_response()]
m_resolver().close = mock.AsyncMock()

await conn._resolve_host("localhost", 8080, traces=traces)

Expand Down Expand Up @@ -1514,6 +1523,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
m_resolver().resolve.return_value = dns_response()
m_resolver().close = mock.AsyncMock()
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
await asyncio.sleep(0)
Expand All @@ -1528,6 +1538,14 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
await conn.close()


async def test_tcp_connector_close_resolver() -> None:
m_resolver = mock.AsyncMock()
with mock.patch("aiohttp.connector.DefaultResolver", return_value=m_resolver):
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
await conn.close()
m_resolver.close.assert_awaited_once()


async def test_dns_error(loop) -> None:
connector = aiohttp.TCPConnector(loop=loop)
connector._resolve_host = mock.AsyncMock(
Expand Down Expand Up @@ -3691,6 +3709,7 @@ async def resolve_response() -> List[ResolveResult]:

with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
m_resolver().resolve.return_value = resolve_response()
m_resolver().close = mock.AsyncMock()

connector = TCPConnector()
traces = [DummyTracer()]
Expand Down
Loading