Skip to content

Commit cc68054

Browse files
committed
Add EventLoop.create_connection impl
1 parent 1bd0c96 commit cc68054

File tree

5 files changed

+259
-13
lines changed

5 files changed

+259
-13
lines changed

rloop/_rloop.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Any
1+
from typing import Any, Callable, Tuple, TypeVar
22
from weakref import WeakSet
33

44
__version__: str
55

6+
T = TypeVar('T')
7+
68
class CBHandle:
79
def cancel(self): ...
810
def cancelled(self) -> bool: ...
@@ -45,6 +47,8 @@ class EventLoop:
4547
def _sig_clear(self): ...
4648
def _ssock_set(self, fd): ...
4749
def _ssock_del(self, fd): ...
50+
def _tcp_conn(self, sock, protocol_factory: Callable[[], T]) -> Tuple[Any, T]: ...
51+
def _tcp_server(self, socks, rsocks, protocol_factory, backlog) -> Server: ...
4852
def call_soon(self, callback, *args, context=None) -> CBHandle: ...
4953
def call_soon_threadsafe(self, callback, *args, context=None) -> CBHandle: ...
5054

rloop/loop.py

Lines changed: 159 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from asyncio.coroutines import iscoroutine as _iscoroutine, iscoroutinefunction as _iscoroutinefunction
1111
from asyncio.events import _get_running_loop, _set_running_loop
1212
from asyncio.futures import Future as _Future, isfuture as _isfuture, wrap_future as _wrap_future
13+
from asyncio.staggered import staggered_race as _staggered_race
1314
from asyncio.tasks import Task as _Task, ensure_future as _ensure_future, gather as _gather
1415
from concurrent.futures import ThreadPoolExecutor
1516
from contextvars import copy_context as _copy_context
@@ -27,7 +28,7 @@
2728
_SubProcessTransport,
2829
_ThreadedChildWatcher,
2930
)
30-
from .utils import _can_use_pidfd, _HAS_IPv6, _ipaddr_info, _noop, _set_reuseport
31+
from .utils import _can_use_pidfd, _HAS_IPv6, _interleave_addrinfos, _ipaddr_info, _noop, _set_reuseport
3132

3233

3334
class RLoop(__BaseLoop, __asyncio.AbstractEventLoop):
@@ -312,8 +313,164 @@ async def create_connection(
312313
ssl_shutdown_timeout=None,
313314
happy_eyeballs_delay=None,
314315
interleave=None,
316+
all_errors=False,
315317
):
316-
raise NotImplementedError
318+
# TODO
319+
if ssl:
320+
raise NotImplementedError
321+
322+
if server_hostname is not None and not ssl:
323+
raise ValueError('server_hostname is only meaningful with ssl')
324+
325+
if server_hostname is None and ssl:
326+
if not host:
327+
raise ValueError('You must set server_hostname when using ssl without a host')
328+
server_hostname = host
329+
330+
if ssl_handshake_timeout is not None and not ssl:
331+
raise ValueError('ssl_handshake_timeout is only meaningful with ssl')
332+
333+
if ssl_shutdown_timeout is not None and not ssl:
334+
raise ValueError('ssl_shutdown_timeout is only meaningful with ssl')
335+
336+
# TODO
337+
# if sock is not None:
338+
# _check_ssl_socket(sock)
339+
340+
if happy_eyeballs_delay is not None and interleave is None:
341+
# If using happy eyeballs, default to interleave addresses by family
342+
interleave = 1
343+
344+
if host is not None or port is not None:
345+
if sock is not None:
346+
raise ValueError('host/port and sock can not be specified at the same time')
347+
348+
infos = await self._ensure_resolved(
349+
(host, port), family=family, type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self
350+
)
351+
if not infos:
352+
raise OSError('getaddrinfo() returned empty list')
353+
354+
if local_addr is not None:
355+
laddr_infos = await self._ensure_resolved(
356+
local_addr, family=family, type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self
357+
)
358+
if not laddr_infos:
359+
raise OSError('getaddrinfo() returned empty list')
360+
else:
361+
laddr_infos = None
362+
363+
if interleave:
364+
infos = _interleave_addrinfos(infos, interleave)
365+
366+
exceptions = []
367+
if happy_eyeballs_delay is None:
368+
# not using happy eyeballs
369+
for addrinfo in infos:
370+
try:
371+
sock = await self._connect_sock(exceptions, addrinfo, laddr_infos)
372+
break
373+
except OSError:
374+
continue
375+
else: # using happy eyeballs
376+
sock = (
377+
await _staggered_race(
378+
(
379+
# can't use functools.partial as it keeps a reference
380+
# to exceptions
381+
lambda addrinfo=addrinfo: self._connect_sock(exceptions, addrinfo, laddr_infos)
382+
for addrinfo in infos
383+
),
384+
happy_eyeballs_delay,
385+
loop=self,
386+
)
387+
)[0] # can't use sock, _, _ as it keeks a reference to exceptions
388+
389+
if sock is None:
390+
exceptions = [exc for sub in exceptions for exc in sub]
391+
try:
392+
if all_errors:
393+
raise ExceptionGroup('create_connection failed', exceptions)
394+
if len(exceptions) == 1:
395+
raise exceptions[0]
396+
else:
397+
# If they all have the same str(), raise one.
398+
model = str(exceptions[0])
399+
if all(str(exc) == model for exc in exceptions):
400+
raise exceptions[0]
401+
# Raise a combined exception so the user can see all
402+
# the various error messages.
403+
raise OSError('Multiple exceptions: {}'.format(', '.join(str(exc) for exc in exceptions)))
404+
finally:
405+
exceptions = None
406+
407+
else:
408+
if sock is None:
409+
raise ValueError('host and port was not specified and no sock specified')
410+
if sock.type != socket.SOCK_STREAM:
411+
# We allow AF_INET, AF_INET6, AF_UNIX as long as they
412+
# are SOCK_STREAM.
413+
# We support passing AF_UNIX sockets even though we have
414+
# a dedicated API for that: create_unix_connection.
415+
# Disallowing AF_UNIX in this method, breaks backwards
416+
# compatibility.
417+
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
418+
419+
sock.setblocking(False)
420+
rsock = (sock.fileno(), sock.family)
421+
sock.detach()
422+
423+
# TODO: ssl
424+
transport, protocol = self._tcp_conn(rsock, protocol_factory)
425+
# transport, protocol = await self._create_connection_transport(
426+
# sock,
427+
# protocol_factory,
428+
# ssl,
429+
# server_hostname,
430+
# ssl_handshake_timeout=ssl_handshake_timeout,
431+
# ssl_shutdown_timeout=ssl_shutdown_timeout,
432+
# )
433+
434+
return transport, protocol
435+
436+
async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None):
437+
my_exceptions = []
438+
exceptions.append(my_exceptions)
439+
family, type_, proto, _, address = addr_info
440+
sock = None
441+
try:
442+
sock = socket.socket(family=family, type=type_, proto=proto)
443+
sock.setblocking(False)
444+
if local_addr_infos is not None:
445+
for lfamily, _, _, _, laddr in local_addr_infos:
446+
# skip local addresses of different family
447+
if lfamily != family:
448+
continue
449+
try:
450+
sock.bind(laddr)
451+
break
452+
except OSError as exc:
453+
msg = f'error while attempting to bind on address {laddr!r}: {str(exc).lower()}'
454+
exc = OSError(exc.errno, msg)
455+
my_exceptions.append(exc)
456+
else: # all bind attempts failed
457+
if my_exceptions:
458+
raise my_exceptions.pop()
459+
else:
460+
raise OSError(f'no matching local address with {family=} found')
461+
await self.sock_connect(sock, address)
462+
return sock
463+
except OSError as exc:
464+
my_exceptions.append(exc)
465+
if sock is not None:
466+
sock.close()
467+
raise
468+
except:
469+
if sock is not None:
470+
sock.close()
471+
raise
472+
finally:
473+
exceptions = my_exceptions = None
317474

318475
async def create_server(
319476
self,

rloop/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import collections
2+
import itertools
13
import os
24
import socket
35

@@ -79,6 +81,23 @@ def _ipaddr_info(host, port, family, type, proto, flowinfo=0, scopeid=0):
7981
return None
8082

8183

84+
def _interleave_addrinfos(addrinfos, first_address_family_count=1):
85+
addrinfos_by_family = collections.OrderedDict()
86+
for addr in addrinfos:
87+
family = addr[0]
88+
if family not in addrinfos_by_family:
89+
addrinfos_by_family[family] = []
90+
addrinfos_by_family[family].append(addr)
91+
addrinfos_lists = list(addrinfos_by_family.values())
92+
93+
reordered = []
94+
if first_address_family_count > 1:
95+
reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
96+
del addrinfos_lists[0][: first_address_family_count - 1]
97+
reordered.extend(a for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists)) if a is not None)
98+
return reordered
99+
100+
82101
def _set_reuseport(sock):
83102
if not hasattr(socket, 'SO_REUSEPORT'):
84103
raise ValueError('reuse_port not supported by socket module')

src/event_loop.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::{
1818
log::{log_exc_to_py_ctx, LogExc},
1919
py::{copy_context, weakset},
2020
server::Server,
21-
tcp::{TCPReadHandle, TCPServer, TCPServerRef, TCPStream, TCPWriteHandle},
21+
tcp::{PyTCPTransport, TCPReadHandle, TCPServer, TCPServerRef, TCPStream, TCPWriteHandle},
2222
time::Timer,
2323
};
2424

@@ -413,10 +413,12 @@ impl EventLoop {
413413
pub(crate) fn tcp_stream_close(&self, fd: usize) {
414414
// println!("tcp_stream_close {:?}", fd);
415415
if let Some((_, stream)) = self.tcp_streams.remove(&fd) {
416-
self.tcp_lstreams.alter(&stream.lfd, |_, mut v| {
417-
v.remove(&fd);
418-
v
419-
});
416+
if let Some(lfd) = &stream.lfd {
417+
self.tcp_lstreams.alter(lfd, |_, mut v| {
418+
v.remove(&fd);
419+
v
420+
});
421+
}
420422
}
421423
}
422424

@@ -988,6 +990,22 @@ impl EventLoop {
988990
})
989991
}
990992

993+
fn _tcp_conn(
994+
pyself: Py<Self>,
995+
py: Python,
996+
sock: (i32, i32),
997+
protocol_factory: PyObject,
998+
) -> PyResult<(Py<PyTCPTransport>, PyObject)> {
999+
let rself = pyself.get();
1000+
let stream = TCPStream::from_py(py, &pyself, sock, protocol_factory);
1001+
let transport = stream.pytransport.clone_ref(py);
1002+
let fd = transport.get().fd;
1003+
let proto = PyTCPTransport::attach(&transport, py)?;
1004+
rself.tcp_streams.insert(fd, stream);
1005+
rself.tcp_stream_add(fd, Interest::READABLE);
1006+
Ok((transport, proto))
1007+
}
1008+
9911009
fn _tcp_server(
9921010
pyself: Py<Self>,
9931011
py: Python,
@@ -1020,6 +1038,7 @@ impl EventLoop {
10201038
fn _run(&self, py: Python) -> PyResult<()> {
10211039
let mut state = EventLoopRunState {
10221040
events: event::Events::with_capacity(128),
1041+
#[allow(clippy::large_stack_arrays)]
10231042
read_buf: [0; 262_144].into(),
10241043
tick_last: 0,
10251044
};

src/tcp.rs

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl TCPServerRef {
129129
);
130130

131131
(
132-
TCPStream::new(
132+
TCPStream::from_listener(
133133
self.fd,
134134
stream,
135135
pytransport.into(),
@@ -143,7 +143,7 @@ impl TCPServerRef {
143143
}
144144

145145
pub(crate) struct TCPStream {
146-
pub lfd: usize,
146+
pub lfd: Option<usize>,
147147
pub io: TcpStream,
148148
pub pytransport: Arc<Py<PyTCPTransport>>,
149149
read_buffered: bool,
@@ -153,16 +153,16 @@ pub(crate) struct TCPStream {
153153
}
154154

155155
impl TCPStream {
156-
fn new(
157-
lfd: usize,
156+
fn from_listener(
157+
fd: usize,
158158
stream: TcpStream,
159159
pytransport: Arc<Py<PyTCPTransport>>,
160160
read_buffered: bool,
161161
pym_recv_data: Arc<PyObject>,
162162
pym_buf_get: PyObject,
163163
) -> Self {
164164
Self {
165-
lfd,
165+
lfd: Some(fd),
166166
io: stream,
167167
pytransport,
168168
read_buffered,
@@ -171,6 +171,45 @@ impl TCPStream {
171171
pym_buf_get,
172172
}
173173
}
174+
175+
pub(crate) fn from_py(py: Python, pyloop: &Py<EventLoop>, pysock: (i32, i32), proto_factory: PyObject) -> Self {
176+
let sock = unsafe { socket2::Socket::from_raw_fd(pysock.0) };
177+
_ = sock.set_nonblocking(true);
178+
let stdl: std::net::TcpStream = sock.into();
179+
let stream = TcpStream::from_std(stdl);
180+
// let stream = TcpStream::from_raw_fd(rsock);
181+
182+
let proto = proto_factory.bind(py).call0().unwrap();
183+
let mut buffered_proto = false;
184+
let pym_recv_data: PyObject;
185+
let pym_buf_get: PyObject;
186+
if proto.is_instance(asyncio_proto_buf(py).unwrap()).unwrap() {
187+
buffered_proto = true;
188+
pym_recv_data = proto.getattr(pyo3::intern!(py, "buffer_updated")).unwrap().unbind();
189+
pym_buf_get = proto.getattr(pyo3::intern!(py, "get_buffer")).unwrap().unbind();
190+
} else {
191+
pym_recv_data = proto.getattr(pyo3::intern!(py, "data_received")).unwrap().unbind();
192+
pym_buf_get = py.None();
193+
}
194+
let pyproto = proto.unbind();
195+
let pytransport = PyTCPTransport::new(
196+
py,
197+
stream.as_raw_fd() as usize,
198+
pysock.1,
199+
pyloop.clone_ref(py),
200+
pyproto.clone_ref(py),
201+
);
202+
203+
Self {
204+
lfd: None,
205+
io: stream,
206+
pytransport: pytransport.into(),
207+
read_buffered: buffered_proto,
208+
write_buffer: VecDeque::new(),
209+
pym_recv_data: pym_recv_data.into(),
210+
pym_buf_get,
211+
}
212+
}
174213
}
175214

176215
#[pyclass(frozen)]
@@ -217,6 +256,14 @@ impl PyTCPTransport {
217256
.unwrap()
218257
}
219258

259+
pub(crate) fn attach(pyself: &Py<Self>, py: Python) -> PyResult<PyObject> {
260+
let rself = pyself.get();
261+
rself
262+
.proto
263+
.call_method1(py, pyo3::intern!(py, "connection_made"), (pyself.clone_ref(py),))?;
264+
Ok(rself.proto.clone_ref(py))
265+
}
266+
220267
#[inline]
221268
fn write_buf_size_decr(pyself: &Py<Self>, py: Python, val: usize) {
222269
// println!("tcp write_buf_size_decr {:?}", val);

0 commit comments

Comments
 (0)