Skip to content

Commit 8b0a252

Browse files
committed
Lots more TLS tests and graceful shutdown fixes
1 parent ff80cf1 commit 8b0a252

File tree

4 files changed

+262
-149
lines changed

4 files changed

+262
-149
lines changed

trio/ssl.py

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ class SSLStream(_Stream):
220220
def __init__(
221221
self, transport_stream, sslcontext, *, max_bytes=32 * 1024, **kwargs):
222222
self.transport_stream = transport_stream
223+
self._exc = None
223224
self._bufsize = max_bytes
224225
self._outgoing = _stdlib_ssl.MemoryBIO()
225226
self._incoming = _stdlib_ssl.MemoryBIO()
@@ -234,7 +235,8 @@ def __init__(
234235
self._inner_recv_lock = _sync.Lock()
235236

236237
# These are used to make sure that our caller doesn't attempt to make
237-
# multiple concurrent calls to send_all/wait_send_all_might_not_block or to receive_some.
238+
# multiple concurrent calls to send_all/wait_send_all_might_not_block
239+
# or to receive_some.
238240
self._outer_send_lock = _UnLock(
239241
_core.ResourceBusyError,
240242
"another task is currently sending data on this SSLStream")
@@ -263,11 +265,16 @@ def __setattr__(self, name, value):
263265
def __dir__(self):
264266
return super().__dir__() + list(self._forwarded)
265267

268+
def _check_status(self):
269+
if self._exc is not None:
270+
raise self._exc
271+
266272
# This is probably the single trickiest function in trio. It has lots of
267273
# comments, though, just make sure to think carefully if you ever have to
268274
# touch it. The big comment at the top of this file will help explain
269275
# too.
270276
async def _retry(self, fn, *args, ignore_want_read=False):
277+
print("doing", fn)
271278
await _core.yield_if_cancelled()
272279
yielded = False
273280
try:
@@ -288,11 +295,11 @@ async def _retry(self, fn, *args, ignore_want_read=False):
288295
# might come in and mess with it while we're suspended), and
289296
# we don't want to yield *before* starting the operation that
290297
# will help us make progress, because then someone else might
291-
# come in and
298+
# come in and leapfrog us.
292299

293300
# Call the SSLObject method, and get its result.
294301
#
295-
# NB: despite what the docs, say SSLWantWriteError can't
302+
# NB: despite what the docs say, SSLWantWriteError can't
296303
# happen – "Writes to memory BIOs will always succeed if
297304
# memory is available: that is their size can grow
298305
# indefinitely."
@@ -303,14 +310,16 @@ async def _retry(self, fn, *args, ignore_want_read=False):
303310
ret = fn(*args)
304311
except _stdlib_ssl.SSLWantReadError:
305312
want_read = True
306-
except _stdlib_ssl.SSLError as exc:
313+
except (SSLError, CertificateError) as exc:
314+
self._exc = _streams.BrokenStreamError
307315
raise _streams.BrokenStreamError from exc
308316
else:
309317
finished = True
310318
if ignore_want_read:
311319
want_read = False
312320
finished = True
313321
to_send = self._outgoing.read()
322+
print(bool(to_send), want_read)
314323

315324
# Outputs from the above code block are:
316325
#
@@ -373,7 +382,13 @@ async def _retry(self, fn, *args, ignore_want_read=False):
373382
# NOTE: This relies on the lock being strict FIFO fair!
374383
async with self._inner_send_lock:
375384
yielded = True
376-
await self.transport_stream.send_all(to_send)
385+
try:
386+
await self.transport_stream.send_all(to_send)
387+
except:
388+
# Some unknown amount of our data got sent, and we
389+
# don't know how much. This stream is doomed.
390+
self._exc = _streams.BrokenStreamError
391+
raise
377392
elif want_read:
378393
# It's possible that someone else is already blocked in
379394
# transport_stream.receive_some. If so then we want to
@@ -427,22 +442,24 @@ async def do_handshake(self):
427442
immediately without doing anything (except executing a checkpoint).
428443
429444
"""
430-
if self.transport_stream is None:
445+
try:
446+
self._check_status()
447+
except:
431448
await _core.yield_briefly()
432-
raise _streams.ClosedStreamError
449+
raise
433450
await self._handshook.ensure(checkpoint=True)
434451

435452
# Most things work if we don't explicitly force do_handshake to be called
436-
# before calling receive_some or send_all, because openssl will automatically
437-
# perform the handshake on the first SSL_{read,write} call. BUT, allowing
438-
# openssl to do this will disable Python's hostname checking!!! See:
453+
# before calling receive_some or send_all, because openssl will
454+
# automatically perform the handshake on the first SSL_{read,write}
455+
# call. BUT, allowing openssl to do this will disable Python's hostname
456+
# checking!!! See:
439457
# https://bugs.python.org/issue30141
440458
# So we *definitely* have to make sure that do_handshake is called
441459
# before doing anything else.
442460
async def receive_some(self, max_bytes):
443461
async with self._outer_recv_lock:
444-
if self.transport_stream is None:
445-
raise _streams.ClosedStreamError
462+
self._check_status()
446463
await self._handshook.ensure(checkpoint=False)
447464
max_bytes = _operator.index(max_bytes)
448465
if max_bytes < 1:
@@ -451,8 +468,7 @@ async def receive_some(self, max_bytes):
451468

452469
async def send_all(self, data):
453470
async with self._outer_send_lock:
454-
if self.transport_stream is None:
455-
raise _streams.ClosedStreamError
471+
self._check_status()
456472
await self._handshook.ensure(checkpoint=False)
457473
# SSLObject interprets write(b"") as an EOF for some reason, which
458474
# is not what we want.
@@ -471,65 +487,82 @@ async def send_all(self, data):
471487
# maybe it's actually better to error out...?
472488
async def unwrap(self):
473489
async with self._outer_recv_lock, self._outer_send_lock:
474-
if self.transport_stream is None:
475-
raise _streams.ClosedStreamError
490+
self._check_status()
476491
await self._handshook.ensure(checkpoint=False)
477492
await self._retry(self._ssl_object.unwrap)
478493
transport_stream = self.transport_stream
479494
self.transport_stream = None
495+
self._exc = _streams.ClosedStreamError
480496
return (transport_stream, self._incoming.read())
481497

482498
def forceful_close(self):
483-
if self.transport_stream is not None:
499+
if self._exc is not _streams.ClosedStreamError:
484500
self.transport_stream.forceful_close()
485-
self.transport_stream = None
501+
self._exc = _streams.ClosedStreamError
486502

487503
async def graceful_close(self):
488-
transport_stream = self.transport_stream
489-
if transport_stream is None:
504+
if self._exc is _streams.ClosedStreamError:
505+
await _core.yield_briefly()
506+
return
507+
if self._exc is _streams.BrokenStreamError:
508+
self.forceful_close()
490509
await _core.yield_briefly()
491510
return
492511
try:
493-
# If we haven't even started the handshake, then we can (and must)
494-
# skip the SSL-level shutdown.
495-
if self._handshook.started:
496-
# But if the handshake is in progress, wait for it to finish.
497-
await self._handshook.ensure(checkpoint=False)
498-
# Then we want to call SSL_shutdown *once*, to send a
499-
# close_notify but *not* wait for the response (because we're
500-
# closing the socket anyway, so there's no point in waiting).
501-
# Subtlety: SSLObject.unwrap will immediately call it a second
502-
# time, and the second time will raise SSLWantReadError
503-
# because there hasn't been time for the other side to respond
504-
# yet. (Unless they spontaneously sent a close_notify before
505-
# we called this, and it's either already been processed or
506-
# gets pulled out of the buffer by Python's second call.) So
507-
# the way to do what we want is to to ignore SSLWantReadError
508-
# on this call.
509-
try:
510-
await self._retry(
511-
self._ssl_object.unwrap, ignore_want_read=True)
512-
except _streams.BrokenStreamError:
513-
# It's okay if the stream is broken and we can't send our
514-
# goodbye message, because we're cutting off the
515-
# connection anyway...
516-
pass
512+
await self._handshook.ensure(checkpoint=False)
513+
# Here, we call SSL_shutdown *once*, because we want to send a
514+
# close_notify but *not* wait for the other side to send back a
515+
# response. In principle it would be more polite to wait for the
516+
# other side to reply with their own close_notify. However, if
517+
# they aren't paying attention (e.g., if they're just sending
518+
# data and not receiving) then we will never notice our
519+
# close_notify and we'll be waiting forever. Eventually we'll time
520+
# out (hopefully), but it's still kind of nasty. And we can't
521+
# require the other side to always be receiving, because (a)
522+
# backpressure is kind of important, and (b) I bet there are
523+
# broken TLS implementations out there that don't receive all the
524+
# time. (Like e.g. anyone using Python ssl in synchronous mode.)
525+
#
526+
# The send-then-immediately-close behavior is explicitly allowed
527+
# by the TLS specs, so we're ok on that.
528+
#
529+
# Subtlety: SSLObject.unwrap will immediately call it a second
530+
# time, and the second time will raise SSLWantReadError because
531+
# there hasn't been time for the other side to respond
532+
# yet. (Unless they spontaneously sent a close_notify before we
533+
# called this, and it's either already been processed or gets
534+
# pulled out of the buffer by Python's second call.) So the way to
535+
# do what we want is to ignore SSLWantReadError on this call.
536+
#
537+
# Also, because the other side might have already sent
538+
# close_notify and closed their connection then it's possible that
539+
# our attempt to send close_notify will raise
540+
# BrokenStreamError. This is totally legal, and in fact can happen
541+
# with two well-behaved trio programs talking to each other, so we
542+
# don't want to raise an error. So we suppress BrokenStreamError
543+
# here. (This is safe, because literally the only thing this call
544+
# to _retry will do is send the close_notify alert, so that's
545+
# surely where the error comes from.)
546+
try:
547+
await self._retry(
548+
self._ssl_object.unwrap, ignore_want_read=True)
549+
except _streams.BrokenStreamError:
550+
pass
517551
# Close the underlying stream
518-
await transport_stream.graceful_close()
552+
await self.transport_stream.graceful_close()
519553
except:
520-
transport_stream.forceful_close()
554+
self.transport_stream.forceful_close()
521555
raise
522556
finally:
523-
self.transport_stream = None
557+
self._exc = _streams.ClosedStreamError
524558

525559
async def wait_send_all_might_not_block(self):
526560
# This method's implementation is deceptively simple.
527561
#
528562
# First, we take the outer send lock, because of trio's standard
529563
# semantics that wait_send_all_might_not_block and send_all conflict.
530564
async with self._outer_send_lock:
531-
if self.transport_stream is None:
532-
raise _streams.ClosedStreamError
565+
self._check_status()
533566
# Then we take the inner send lock. We know that no other tasks
534567
# are calling self.send_all or self.wait_send_all_might_not_block,
535568
# because we have the outer_send_lock. But! There might be another

trio/testing.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -506,13 +506,16 @@ async def simple_check_wait_send_all_might_not_block(scope):
506506
nursery.cancel_scope)
507507
nursery.spawn(do_receive_some, 1)
508508

509-
# closing the r side
510-
await do_graceful_close(r)
509+
# closing the r side leads to BrokenStreamError on the s side
510+
# (eventually)
511+
async def expect_broken_stream_on_send():
512+
with _assert_raises(_streams.BrokenStreamError):
513+
while True:
514+
await do_send_all(b"x" * 100)
511515

512-
# ...leads to BrokenStreamError on the s side (eventually)
513-
with _assert_raises(_streams.BrokenStreamError):
514-
while True:
515-
await do_send_all(b"x" * 100)
516+
async with _core.open_nursery() as nursery:
517+
nursery.spawn(expect_broken_stream_on_send)
518+
nursery.spawn(do_graceful_close, r)
516519

517520
# once detected, the stream stays broken
518521
with _assert_raises(_streams.BrokenStreamError):
@@ -555,6 +558,7 @@ async def receive_send_then_close():
555558
await wait_all_tasks_blocked()
556559
await checked_receive_1(b"y")
557560
await checked_receive_1(b"")
561+
await do_graceful_close(r)
558562

559563
async with _core.open_nursery() as nursery:
560564
nursery.spawn(send_then_close)
@@ -621,8 +625,9 @@ async def expect_cancelled(afn, *args):
621625
nursery.spawn(expect_cancelled, do_send_all, b"x")
622626
nursery.spawn(expect_cancelled, do_receive_some, 1)
623627

624-
await do_graceful_close(s)
625-
await do_graceful_close(r)
628+
async with _core.open_nursery() as nursery:
629+
nursery.spawn(do_graceful_close, s)
630+
nursery.spawn(do_graceful_close, r)
626631

627632
# check wait_send_all_might_not_block, if we can
628633
if clogged_stream_maker is not None:
@@ -739,9 +744,13 @@ async def receiver(s, data, seed):
739744
nursery.spawn(receiver, s1, test_data[::-1], 2)
740745
nursery.spawn(receiver, s2, test_data, 3)
741746

742-
await s1.graceful_close()
743-
assert await s2.receive_some(10) == b""
744-
await s2.graceful_close()
747+
async def expect_receive_some_empty():
748+
assert await s2.receive_some(10) == b""
749+
await s2.graceful_close()
750+
751+
async with _core.open_nursery() as nursery:
752+
nursery.spawn(expect_receive_some_empty)
753+
nursery.spawn(s1.graceful_close)
745754

746755

747756
async def check_half_closeable_stream(stream_maker, clogged_stream_maker):
@@ -968,19 +977,23 @@ class MemoryReceiveStream(_abc.ReceiveStream):
968977
Args:
969978
receive_some_hook: An async function, or None. Called from
970979
:meth:`receive_some`. Can do whatever you like.
980+
close_hook: A synchronous function, or None. Called from
981+
:meth:`forceful_close`. Can do whatever you like.
971982
972983
.. attribute:: receive_some_hook
984+
close_hook
973985
974-
The :attr:`receive_some_hook` is also exposed as an attribute on the
975-
object, and you can change it at any time.
986+
Both hooks are also exposed as attributes on the object, and you can
987+
change them at any time.
976988
977989
"""
978-
def __init__(self, receive_some_hook=None):
990+
def __init__(self, receive_some_hook=None, close_hook=None):
979991
self._lock = _util.UnLock(
980992
_core.ResourceBusyError, "another task is using this stream")
981993
self._incoming = _UnboundedByteQueue()
982994
self._closed = False
983995
self.receive_some_hook = receive_some_hook
996+
self.close_hook = close_hook
984997

985998
async def receive_some(self, max_bytes):
986999
"""Calls the :attr:`receive_some_hook` (if any), and then retrieves
@@ -1012,7 +1025,8 @@ def forceful_close(self):
10121025
except _core.WouldBlock:
10131026
pass
10141027
self._incoming.close()
1015-
self._closed = True
1028+
if self.close_hook is not None:
1029+
self.close_hook()
10161030

10171031
def put_data(self, data):
10181032
"""Appends the given data to the internal buffer.

0 commit comments

Comments
 (0)