|
10 | 10 | from asyncio.coroutines import iscoroutine as _iscoroutine, iscoroutinefunction as _iscoroutinefunction
|
11 | 11 | from asyncio.events import _get_running_loop, _set_running_loop
|
12 | 12 | from asyncio.futures import Future as _Future, isfuture as _isfuture, wrap_future as _wrap_future
|
| 13 | +from asyncio.staggered import staggered_race as _staggered_race |
13 | 14 | from asyncio.tasks import Task as _Task, ensure_future as _ensure_future, gather as _gather
|
14 | 15 | from concurrent.futures import ThreadPoolExecutor
|
15 | 16 | from contextvars import copy_context as _copy_context
|
|
27 | 28 | _SubProcessTransport,
|
28 | 29 | _ThreadedChildWatcher,
|
29 | 30 | )
|
30 |
| -from .utils import _can_use_pidfd, _HAS_IPv6, _ipaddr_info, _noop, _set_reuseport |
| 31 | +from .utils import _can_use_pidfd, _HAS_IPv6, _interleave_addrinfos, _ipaddr_info, _noop, _set_reuseport |
31 | 32 |
|
32 | 33 |
|
33 | 34 | class RLoop(__BaseLoop, __asyncio.AbstractEventLoop):
|
@@ -312,8 +313,164 @@ async def create_connection(
|
312 | 313 | ssl_shutdown_timeout=None,
|
313 | 314 | happy_eyeballs_delay=None,
|
314 | 315 | interleave=None,
|
| 316 | + all_errors=False, |
315 | 317 | ):
|
316 |
| - raise NotImplementedError |
| 318 | + # TODO |
| 319 | + if ssl: |
| 320 | + raise NotImplementedError |
| 321 | + |
| 322 | + if server_hostname is not None and not ssl: |
| 323 | + raise ValueError('server_hostname is only meaningful with ssl') |
| 324 | + |
| 325 | + if server_hostname is None and ssl: |
| 326 | + if not host: |
| 327 | + raise ValueError('You must set server_hostname when using ssl without a host') |
| 328 | + server_hostname = host |
| 329 | + |
| 330 | + if ssl_handshake_timeout is not None and not ssl: |
| 331 | + raise ValueError('ssl_handshake_timeout is only meaningful with ssl') |
| 332 | + |
| 333 | + if ssl_shutdown_timeout is not None and not ssl: |
| 334 | + raise ValueError('ssl_shutdown_timeout is only meaningful with ssl') |
| 335 | + |
| 336 | + # TODO |
| 337 | + # if sock is not None: |
| 338 | + # _check_ssl_socket(sock) |
| 339 | + |
| 340 | + if happy_eyeballs_delay is not None and interleave is None: |
| 341 | + # If using happy eyeballs, default to interleave addresses by family |
| 342 | + interleave = 1 |
| 343 | + |
| 344 | + if host is not None or port is not None: |
| 345 | + if sock is not None: |
| 346 | + raise ValueError('host/port and sock can not be specified at the same time') |
| 347 | + |
| 348 | + infos = await self._ensure_resolved( |
| 349 | + (host, port), family=family, type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self |
| 350 | + ) |
| 351 | + if not infos: |
| 352 | + raise OSError('getaddrinfo() returned empty list') |
| 353 | + |
| 354 | + if local_addr is not None: |
| 355 | + laddr_infos = await self._ensure_resolved( |
| 356 | + local_addr, family=family, type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self |
| 357 | + ) |
| 358 | + if not laddr_infos: |
| 359 | + raise OSError('getaddrinfo() returned empty list') |
| 360 | + else: |
| 361 | + laddr_infos = None |
| 362 | + |
| 363 | + if interleave: |
| 364 | + infos = _interleave_addrinfos(infos, interleave) |
| 365 | + |
| 366 | + exceptions = [] |
| 367 | + if happy_eyeballs_delay is None: |
| 368 | + # not using happy eyeballs |
| 369 | + for addrinfo in infos: |
| 370 | + try: |
| 371 | + sock = await self._connect_sock(exceptions, addrinfo, laddr_infos) |
| 372 | + break |
| 373 | + except OSError: |
| 374 | + continue |
| 375 | + else: # using happy eyeballs |
| 376 | + sock = ( |
| 377 | + await _staggered_race( |
| 378 | + ( |
| 379 | + # can't use functools.partial as it keeps a reference |
| 380 | + # to exceptions |
| 381 | + lambda addrinfo=addrinfo: self._connect_sock(exceptions, addrinfo, laddr_infos) |
| 382 | + for addrinfo in infos |
| 383 | + ), |
| 384 | + happy_eyeballs_delay, |
| 385 | + loop=self, |
| 386 | + ) |
| 387 | + )[0] # can't use sock, _, _ as it keeks a reference to exceptions |
| 388 | + |
| 389 | + if sock is None: |
| 390 | + exceptions = [exc for sub in exceptions for exc in sub] |
| 391 | + try: |
| 392 | + if all_errors: |
| 393 | + raise ExceptionGroup('create_connection failed', exceptions) |
| 394 | + if len(exceptions) == 1: |
| 395 | + raise exceptions[0] |
| 396 | + else: |
| 397 | + # If they all have the same str(), raise one. |
| 398 | + model = str(exceptions[0]) |
| 399 | + if all(str(exc) == model for exc in exceptions): |
| 400 | + raise exceptions[0] |
| 401 | + # Raise a combined exception so the user can see all |
| 402 | + # the various error messages. |
| 403 | + raise OSError('Multiple exceptions: {}'.format(', '.join(str(exc) for exc in exceptions))) |
| 404 | + finally: |
| 405 | + exceptions = None |
| 406 | + |
| 407 | + else: |
| 408 | + if sock is None: |
| 409 | + raise ValueError('host and port was not specified and no sock specified') |
| 410 | + if sock.type != socket.SOCK_STREAM: |
| 411 | + # We allow AF_INET, AF_INET6, AF_UNIX as long as they |
| 412 | + # are SOCK_STREAM. |
| 413 | + # We support passing AF_UNIX sockets even though we have |
| 414 | + # a dedicated API for that: create_unix_connection. |
| 415 | + # Disallowing AF_UNIX in this method, breaks backwards |
| 416 | + # compatibility. |
| 417 | + raise ValueError(f'A Stream Socket was expected, got {sock!r}') |
| 418 | + |
| 419 | + sock.setblocking(False) |
| 420 | + rsock = (sock.fileno(), sock.family) |
| 421 | + sock.detach() |
| 422 | + |
| 423 | + # TODO: ssl |
| 424 | + transport, protocol = self._tcp_conn(rsock, protocol_factory) |
| 425 | + # transport, protocol = await self._create_connection_transport( |
| 426 | + # sock, |
| 427 | + # protocol_factory, |
| 428 | + # ssl, |
| 429 | + # server_hostname, |
| 430 | + # ssl_handshake_timeout=ssl_handshake_timeout, |
| 431 | + # ssl_shutdown_timeout=ssl_shutdown_timeout, |
| 432 | + # ) |
| 433 | + |
| 434 | + return transport, protocol |
| 435 | + |
| 436 | + async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None): |
| 437 | + my_exceptions = [] |
| 438 | + exceptions.append(my_exceptions) |
| 439 | + family, type_, proto, _, address = addr_info |
| 440 | + sock = None |
| 441 | + try: |
| 442 | + sock = socket.socket(family=family, type=type_, proto=proto) |
| 443 | + sock.setblocking(False) |
| 444 | + if local_addr_infos is not None: |
| 445 | + for lfamily, _, _, _, laddr in local_addr_infos: |
| 446 | + # skip local addresses of different family |
| 447 | + if lfamily != family: |
| 448 | + continue |
| 449 | + try: |
| 450 | + sock.bind(laddr) |
| 451 | + break |
| 452 | + except OSError as exc: |
| 453 | + msg = f'error while attempting to bind on address {laddr!r}: {str(exc).lower()}' |
| 454 | + exc = OSError(exc.errno, msg) |
| 455 | + my_exceptions.append(exc) |
| 456 | + else: # all bind attempts failed |
| 457 | + if my_exceptions: |
| 458 | + raise my_exceptions.pop() |
| 459 | + else: |
| 460 | + raise OSError(f'no matching local address with {family=} found') |
| 461 | + await self.sock_connect(sock, address) |
| 462 | + return sock |
| 463 | + except OSError as exc: |
| 464 | + my_exceptions.append(exc) |
| 465 | + if sock is not None: |
| 466 | + sock.close() |
| 467 | + raise |
| 468 | + except: |
| 469 | + if sock is not None: |
| 470 | + sock.close() |
| 471 | + raise |
| 472 | + finally: |
| 473 | + exceptions = my_exceptions = None |
317 | 474 |
|
318 | 475 | async def create_server(
|
319 | 476 | self,
|
|
0 commit comments