@@ -220,6 +220,7 @@ class SSLStream(_Stream):
220
220
def __init__ (
221
221
self , transport_stream , sslcontext , * , max_bytes = 32 * 1024 , ** kwargs ):
222
222
self .transport_stream = transport_stream
223
+ self ._exc = None
223
224
self ._bufsize = max_bytes
224
225
self ._outgoing = _stdlib_ssl .MemoryBIO ()
225
226
self ._incoming = _stdlib_ssl .MemoryBIO ()
@@ -234,7 +235,8 @@ def __init__(
234
235
self ._inner_recv_lock = _sync .Lock ()
235
236
236
237
# These are used to make sure that our caller doesn't attempt to make
237
- # multiple concurrent calls to send_all/wait_send_all_might_not_block or to receive_some.
238
+ # multiple concurrent calls to send_all/wait_send_all_might_not_block
239
+ # or to receive_some.
238
240
self ._outer_send_lock = _UnLock (
239
241
_core .ResourceBusyError ,
240
242
"another task is currently sending data on this SSLStream" )
@@ -263,11 +265,16 @@ def __setattr__(self, name, value):
263
265
def __dir__ (self ):
264
266
return super ().__dir__ () + list (self ._forwarded )
265
267
268
+ def _check_status (self ):
269
+ if self ._exc is not None :
270
+ raise self ._exc
271
+
266
272
# This is probably the single trickiest function in trio. It has lots of
267
273
# comments, though, just make sure to think carefully if you ever have to
268
274
# touch it. The big comment at the top of this file will help explain
269
275
# too.
270
276
async def _retry (self , fn , * args , ignore_want_read = False ):
277
+ print ("doing" , fn )
271
278
await _core .yield_if_cancelled ()
272
279
yielded = False
273
280
try :
@@ -288,11 +295,11 @@ async def _retry(self, fn, *args, ignore_want_read=False):
288
295
# might come in and mess with it while we're suspended), and
289
296
# we don't want to yield *before* starting the operation that
290
297
# will help us make progress, because then someone else might
291
- # come in and
298
+ # come in and leapfrog us.
292
299
293
300
# Call the SSLObject method, and get its result.
294
301
#
295
- # NB: despite what the docs, say SSLWantWriteError can't
302
+ # NB: despite what the docs say, SSLWantWriteError can't
296
303
# happen – "Writes to memory BIOs will always succeed if
297
304
# memory is available: that is their size can grow
298
305
# indefinitely."
@@ -303,14 +310,16 @@ async def _retry(self, fn, *args, ignore_want_read=False):
303
310
ret = fn (* args )
304
311
except _stdlib_ssl .SSLWantReadError :
305
312
want_read = True
306
- except _stdlib_ssl .SSLError as exc :
313
+ except (SSLError , CertificateError ) as exc :
314
+ self ._exc = _streams .BrokenStreamError
307
315
raise _streams .BrokenStreamError from exc
308
316
else :
309
317
finished = True
310
318
if ignore_want_read :
311
319
want_read = False
312
320
finished = True
313
321
to_send = self ._outgoing .read ()
322
+ print (bool (to_send ), want_read )
314
323
315
324
# Outputs from the above code block are:
316
325
#
@@ -373,7 +382,13 @@ async def _retry(self, fn, *args, ignore_want_read=False):
373
382
# NOTE: This relies on the lock being strict FIFO fair!
374
383
async with self ._inner_send_lock :
375
384
yielded = True
376
- await self .transport_stream .send_all (to_send )
385
+ try :
386
+ await self .transport_stream .send_all (to_send )
387
+ except :
388
+ # Some unknown amount of our data got sent, and we
389
+ # don't know how much. This stream is doomed.
390
+ self ._exc = _streams .BrokenStreamError
391
+ raise
377
392
elif want_read :
378
393
# It's possible that someone else is already blocked in
379
394
# transport_stream.receive_some. If so then we want to
@@ -427,22 +442,24 @@ async def do_handshake(self):
427
442
immediately without doing anything (except executing a checkpoint).
428
443
429
444
"""
430
- if self .transport_stream is None :
445
+ try :
446
+ self ._check_status ()
447
+ except :
431
448
await _core .yield_briefly ()
432
- raise _streams . ClosedStreamError
449
+ raise
433
450
await self ._handshook .ensure (checkpoint = True )
434
451
435
452
# Most things work if we don't explicitly force do_handshake to be called
436
- # before calling receive_some or send_all, because openssl will automatically
437
- # perform the handshake on the first SSL_{read,write} call. BUT, allowing
438
- # openssl to do this will disable Python's hostname checking!!! See:
453
+ # before calling receive_some or send_all, because openssl will
454
+ # automatically perform the handshake on the first SSL_{read,write}
455
+ # call. BUT, allowing openssl to do this will disable Python's hostname
456
+ # checking!!! See:
439
457
# https://bugs.python.org/issue30141
440
458
# So we *definitely* have to make sure that do_handshake is called
441
459
# before doing anything else.
442
460
async def receive_some (self , max_bytes ):
443
461
async with self ._outer_recv_lock :
444
- if self .transport_stream is None :
445
- raise _streams .ClosedStreamError
462
+ self ._check_status ()
446
463
await self ._handshook .ensure (checkpoint = False )
447
464
max_bytes = _operator .index (max_bytes )
448
465
if max_bytes < 1 :
@@ -451,8 +468,7 @@ async def receive_some(self, max_bytes):
451
468
452
469
async def send_all (self , data ):
453
470
async with self ._outer_send_lock :
454
- if self .transport_stream is None :
455
- raise _streams .ClosedStreamError
471
+ self ._check_status ()
456
472
await self ._handshook .ensure (checkpoint = False )
457
473
# SSLObject interprets write(b"") as an EOF for some reason, which
458
474
# is not what we want.
@@ -471,65 +487,82 @@ async def send_all(self, data):
471
487
# maybe it's actually better to error out...?
472
488
async def unwrap (self ):
473
489
async with self ._outer_recv_lock , self ._outer_send_lock :
474
- if self .transport_stream is None :
475
- raise _streams .ClosedStreamError
490
+ self ._check_status ()
476
491
await self ._handshook .ensure (checkpoint = False )
477
492
await self ._retry (self ._ssl_object .unwrap )
478
493
transport_stream = self .transport_stream
479
494
self .transport_stream = None
495
+ self ._exc = _streams .ClosedStreamError
480
496
return (transport_stream , self ._incoming .read ())
481
497
482
498
def forceful_close (self ):
483
- if self .transport_stream is not None :
499
+ if self ._exc is not _streams . ClosedStreamError :
484
500
self .transport_stream .forceful_close ()
485
- self .transport_stream = None
501
+ self ._exc = _streams . ClosedStreamError
486
502
487
503
async def graceful_close (self ):
488
- transport_stream = self .transport_stream
489
- if transport_stream is None :
504
+ if self ._exc is _streams .ClosedStreamError :
505
+ await _core .yield_briefly ()
506
+ return
507
+ if self ._exc is _streams .BrokenStreamError :
508
+ self .forceful_close ()
490
509
await _core .yield_briefly ()
491
510
return
492
511
try :
493
- # If we haven't even started the handshake, then we can (and must)
494
- # skip the SSL-level shutdown.
495
- if self ._handshook .started :
496
- # But if the handshake is in progress, wait for it to finish.
497
- await self ._handshook .ensure (checkpoint = False )
498
- # Then we want to call SSL_shutdown *once*, to send a
499
- # close_notify but *not* wait for the response (because we're
500
- # closing the socket anyway, so there's no point in waiting).
501
- # Subtlety: SSLObject.unwrap will immediately call it a second
502
- # time, and the second time will raise SSLWantReadError
503
- # because there hasn't been time for the other side to respond
504
- # yet. (Unless they spontaneously sent a close_notify before
505
- # we called this, and it's either already been processed or
506
- # gets pulled out of the buffer by Python's second call.) So
507
- # the way to do what we want is to to ignore SSLWantReadError
508
- # on this call.
509
- try :
510
- await self ._retry (
511
- self ._ssl_object .unwrap , ignore_want_read = True )
512
- except _streams .BrokenStreamError :
513
- # It's okay if the stream is broken and we can't send our
514
- # goodbye message, because we're cutting off the
515
- # connection anyway...
516
- pass
512
+ await self ._handshook .ensure (checkpoint = False )
513
+ # Here, we call SSL_shutdown *once*, because we want to send a
514
+ # close_notify but *not* wait for the other side to send back a
515
+ # response. In principle it would be more polite to wait for the
516
+ # other side to reply with their own close_notify. However, if
517
+ # they aren't paying attention (e.g., if they're just sending
518
+ # data and not receiving) then we will never notice our
519
+ # close_notify and we'll be waiting forever. Eventually we'll time
520
+ # out (hopefully), but it's still kind of nasty. And we can't
521
+ # require the other side to always be receiving, because (a)
522
+ # backpressure is kind of important, and (b) I bet there are
523
+ # broken TLS implementations out there that don't receive all the
524
+ # time. (Like e.g. anyone using Python ssl in synchronous mode.)
525
+ #
526
+ # The send-then-immediately-close behavior is explicitly allowed
527
+ # by the TLS specs, so we're ok on that.
528
+ #
529
+ # Subtlety: SSLObject.unwrap will immediately call it a second
530
+ # time, and the second time will raise SSLWantReadError because
531
+ # there hasn't been time for the other side to respond
532
+ # yet. (Unless they spontaneously sent a close_notify before we
533
+ # called this, and it's either already been processed or gets
534
+ # pulled out of the buffer by Python's second call.) So the way to
535
+ # do what we want is to ignore SSLWantReadError on this call.
536
+ #
537
+ # Also, because the other side might have already sent
538
+ # close_notify and closed their connection then it's possible that
539
+ # our attempt to send close_notify will raise
540
+ # BrokenStreamError. This is totally legal, and in fact can happen
541
+ # with two well-behaved trio programs talking to each other, so we
542
+ # don't want to raise an error. So we suppress BrokenStreamError
543
+ # here. (This is safe, because literally the only thing this call
544
+ # to _retry will do is send the close_notify alert, so that's
545
+ # surely where the error comes from.)
546
+ try :
547
+ await self ._retry (
548
+ self ._ssl_object .unwrap , ignore_want_read = True )
549
+ except _streams .BrokenStreamError :
550
+ pass
517
551
# Close the underlying stream
518
- await transport_stream .graceful_close ()
552
+ await self . transport_stream .graceful_close ()
519
553
except :
520
- transport_stream .forceful_close ()
554
+ self . transport_stream .forceful_close ()
521
555
raise
522
556
finally :
523
- self .transport_stream = None
557
+ self ._exc = _streams . ClosedStreamError
524
558
525
559
async def wait_send_all_might_not_block (self ):
526
560
# This method's implementation is deceptively simple.
527
561
#
528
562
# First, we take the outer send lock, because of trio's standard
529
563
# semantics that wait_send_all_might_not_block and send_all conflict.
530
564
async with self ._outer_send_lock :
531
- if self .transport_stream is None :
532
- raise _streams .ClosedStreamError
565
+ self ._check_status ()
533
566
# Then we take the inner send lock. We know that no other tasks
534
567
# are calling self.send_all or self.wait_send_all_might_not_block,
535
568
# because we have the outer_send_lock. But! There might be another
0 commit comments