22#![ cfg( all( feature = "full" , not( target_os = "wasi" ) , not( miri) ) ) ] // Wasi doesn't support bind
33 // No `socket` on miri.
44
5+ use std:: time:: Duration ;
56use tokio:: io:: { self , AsyncReadExt , AsyncWriteExt } ;
67use tokio:: net:: { TcpListener , TcpStream } ;
8+ use tokio:: sync:: oneshot:: channel;
79use tokio_test:: assert_ok;
810
911#[ tokio:: test]
1012async fn shutdown ( ) {
1113 let srv = assert_ok ! ( TcpListener :: bind( "127.0.0.1:0" ) . await ) ;
1214 let addr = assert_ok ! ( srv. local_addr( ) ) ;
1315
14- tokio:: spawn ( async move {
16+ let handle = tokio:: spawn ( async move {
1517 let mut stream = assert_ok ! ( TcpStream :: connect( & addr) . await ) ;
1618
1719 assert_ok ! ( AsyncWriteExt :: shutdown( & mut stream) . await ) ;
@@ -26,4 +28,55 @@ async fn shutdown() {
2628
2729 let n = assert_ok ! ( io:: copy( & mut rd, & mut wr) . await ) ;
2830 assert_eq ! ( n, 0 ) ;
31+ assert_ok ! ( AsyncWriteExt :: shutdown( & mut stream) . await ) ;
32+ handle. await . unwrap ( )
33+ }
34+
35+ #[ tokio:: test]
36+ async fn shutdown_after_tcp_reset ( ) {
37+ let srv = assert_ok ! ( TcpListener :: bind( "127.0.0.1:0" ) . await ) ;
38+ let addr = assert_ok ! ( srv. local_addr( ) ) ;
39+
40+ let ( connected_tx, connected_rx) = channel ( ) ;
41+ let ( dropped_tx, dropped_rx) = channel ( ) ;
42+
43+ let handle = tokio:: spawn ( async move {
44+ let mut stream = assert_ok ! ( TcpStream :: connect( & addr) . await ) ;
45+ connected_tx. send ( ( ) ) . unwrap ( ) ;
46+
47+ dropped_rx. await . unwrap ( ) ;
48+ assert_ok ! ( AsyncWriteExt :: shutdown( & mut stream) . await ) ;
49+ } ) ;
50+
51+ let ( stream, _) = assert_ok ! ( srv. accept( ) . await ) ;
52+ // By setting linger to 0 we will trigger a TCP reset
53+ stream. set_linger ( Some ( Duration :: new ( 0 , 0 ) ) ) . unwrap ( ) ;
54+ connected_rx. await . unwrap ( ) ;
55+
56+ drop ( stream) ;
57+ dropped_tx. send ( ( ) ) . unwrap ( ) ;
58+
59+ handle. await . unwrap ( ) ;
60+ }
61+
62+ #[ tokio:: test]
63+ async fn shutdown_multiple_calls ( ) {
64+ let srv = assert_ok ! ( TcpListener :: bind( "127.0.0.1:0" ) . await ) ;
65+ let addr = assert_ok ! ( srv. local_addr( ) ) ;
66+
67+ let ( connected_tx, connected_rx) = channel ( ) ;
68+
69+ let handle = tokio:: spawn ( async move {
70+ let mut stream = assert_ok ! ( TcpStream :: connect( & addr) . await ) ;
71+ connected_tx. send ( ( ) ) . unwrap ( ) ;
72+ assert_ok ! ( AsyncWriteExt :: shutdown( & mut stream) . await ) ;
73+ assert_ok ! ( AsyncWriteExt :: shutdown( & mut stream) . await ) ;
74+ assert_ok ! ( AsyncWriteExt :: shutdown( & mut stream) . await ) ;
75+ } ) ;
76+
77+ let ( mut stream, _) = assert_ok ! ( srv. accept( ) . await ) ;
78+ connected_rx. await . unwrap ( ) ;
79+
80+ assert_ok ! ( AsyncWriteExt :: shutdown( & mut stream) . await ) ;
81+ handle. await . unwrap ( ) ;
2982}
0 commit comments