Skip to content

Commit 54fed80

Browse files
Fix race condition when cancelling pending HTTP connection attempts (#110764)
Co-authored-by: Miha Zupan <[email protected]>
1 parent 08b542e commit 54fed80

File tree

5 files changed

+126
-39
lines changed

5 files changed

+126
-39
lines changed

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionPool/HttpConnectionPool.Http1.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,7 @@ private async Task InjectNewHttp11ConnectionAsync(RequestQueue<HttpConnection>.Q
259259
HttpConnection? connection = null;
260260
Exception? connectionException = null;
261261

262-
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource();
263-
waiter.ConnectionCancellationTokenSource = cts;
262+
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(waiter);
264263
try
265264
{
266265
connection = await CreateHttp11ConnectionAsync(queueItem.Request, true, cts.Token).ConfigureAwait(false);

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionPool/HttpConnectionPool.Http2.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,7 @@ private async Task InjectNewHttp2ConnectionAsync(RequestQueue<Http2Connection?>.
181181
Exception? connectionException = null;
182182
HttpConnectionWaiter<Http2Connection?> waiter = queueItem.Waiter;
183183

184-
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource();
185-
waiter.ConnectionCancellationTokenSource = cts;
184+
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(waiter);
186185
try
187186
{
188187
(Stream stream, TransportContext? transportContext, Activity? activity, IPEndPoint? remoteEndPoint) = await ConnectAsync(queueItem.Request, true, cts.Token).ConfigureAwait(false);

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionPool/HttpConnectionPool.Http3.cs

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -75,52 +75,60 @@ internal sealed partial class HttpConnectionPool
7575
// Loop in case we get a 421 and need to send the request to a different authority.
7676
while (true)
7777
{
78-
if (!TryGetHttp3Authority(request, out HttpAuthority? authority, out Exception? reasonException))
78+
HttpConnectionWaiter<Http3Connection?>? http3ConnectionWaiter = null;
79+
try
7980
{
80-
if (reasonException is null)
81+
if (!TryGetHttp3Authority(request, out HttpAuthority? authority, out Exception? reasonException))
8182
{
82-
return null;
83+
if (reasonException is null)
84+
{
85+
return null;
86+
}
87+
ThrowGetVersionException(request, 3, reasonException);
8388
}
84-
ThrowGetVersionException(request, 3, reasonException);
85-
}
8689

87-
long queueStartingTimestamp = HttpTelemetry.Log.IsEnabled() || Settings._metrics!.RequestsQueueDuration.Enabled ? Stopwatch.GetTimestamp() : 0;
88-
Activity? waitForConnectionActivity = ConnectionSetupDistributedTracing.StartWaitForConnectionActivity(authority);
90+
long queueStartingTimestamp = HttpTelemetry.Log.IsEnabled() || Settings._metrics!.RequestsQueueDuration.Enabled ? Stopwatch.GetTimestamp() : 0;
91+
Activity? waitForConnectionActivity = ConnectionSetupDistributedTracing.StartWaitForConnectionActivity(authority);
8992

90-
if (!TryGetPooledHttp3Connection(request, out Http3Connection? connection, out HttpConnectionWaiter<Http3Connection?>? http3ConnectionWaiter))
91-
{
92-
try
93+
if (!TryGetPooledHttp3Connection(request, out Http3Connection? connection, out http3ConnectionWaiter))
9394
{
94-
connection = await http3ConnectionWaiter.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
95+
try
96+
{
97+
connection = await http3ConnectionWaiter.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
98+
}
99+
catch (Exception ex)
100+
{
101+
ConnectionSetupDistributedTracing.ReportError(waitForConnectionActivity, ex);
102+
waitForConnectionActivity?.Stop();
103+
throw;
104+
}
95105
}
96-
catch (Exception ex)
106+
107+
// Request cannot be sent over H/3 connection, try downgrade or report failure.
108+
// Note that if there's an H/3 suitable origin authority but is unavailable or blocked via Alt-Svc, exception is thrown instead.
109+
if (connection is null)
97110
{
98-
ConnectionSetupDistributedTracing.ReportError(waitForConnectionActivity, ex);
99-
waitForConnectionActivity?.Stop();
100-
throw;
111+
return null;
101112
}
102-
}
103113

104-
// Request cannot be sent over H/3 connection, try downgrade or report failure.
105-
// Note that if there's an H/3 suitable origin authority but is unavailable or blocked via Alt-Svc, exception is thrown instead.
106-
if (connection is null)
107-
{
108-
return null;
109-
}
114+
HttpResponseMessage response = await connection.SendAsync(request, queueStartingTimestamp, waitForConnectionActivity, cancellationToken).ConfigureAwait(false);
110115

111-
HttpResponseMessage response = await connection.SendAsync(request, queueStartingTimestamp, waitForConnectionActivity, cancellationToken).ConfigureAwait(false);
116+
// If an Alt-Svc authority returns 421, it means it can't actually handle the request.
117+
// An authority is supposed to be able to handle ALL requests to the origin, so this is a server bug.
118+
// In this case, we blocklist the authority and retry the request at the origin.
119+
if (response.StatusCode == HttpStatusCode.MisdirectedRequest && connection.Authority != _originAuthority)
120+
{
121+
response.Dispose();
122+
BlocklistAuthority(connection.Authority);
123+
continue;
124+
}
112125

113-
// If an Alt-Svc authority returns 421, it means it can't actually handle the request.
114-
// An authority is supposed to be able to handle ALL requests to the origin, so this is a server bug.
115-
// In this case, we blocklist the authority and retry the request at the origin.
116-
if (response.StatusCode == HttpStatusCode.MisdirectedRequest && connection.Authority != _originAuthority)
126+
return response;
127+
}
128+
finally
117129
{
118-
response.Dispose();
119-
BlocklistAuthority(connection.Authority);
120-
continue;
130+
http3ConnectionWaiter?.CancelIfNecessary(this, cancellationToken.IsCancellationRequested);
121131
}
122-
123-
return response;
124132
}
125133
}
126134

@@ -253,8 +261,7 @@ private async Task InjectNewHttp3ConnectionAsync(RequestQueue<Http3Connection?>.
253261
HttpAuthority? authority = null;
254262
HttpConnectionWaiter<Http3Connection?> waiter = queueItem.Waiter;
255263

256-
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource();
257-
waiter.ConnectionCancellationTokenSource = cts;
264+
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(waiter);
258265
Activity? connectionSetupActivity = null;
259266
try
260267
{

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionPool/HttpConnectionPool.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,27 @@ private async ValueTask<Stream> EstablishSocksTunnel(HttpRequestMessage request,
827827
return stream;
828828
}
829829

830-
private CancellationTokenSource GetConnectTimeoutCancellationTokenSource() => new CancellationTokenSource(Settings._connectTimeout);
830+
private CancellationTokenSource GetConnectTimeoutCancellationTokenSource<T>(HttpConnectionWaiter<T> waiter)
831+
where T : HttpConnectionBase?
832+
{
833+
var cts = new CancellationTokenSource(Settings._connectTimeout);
834+
835+
lock (waiter)
836+
{
837+
waiter.ConnectionCancellationTokenSource = cts;
838+
839+
// The initiating request for this connection attempt may complete concurrently at any time.
840+
// If it completed before we've set the CTS, CancelIfNecessary would no-op.
841+
// Check it again now that we're holding the lock and ensure we always set a timeout.
842+
if (waiter.Task.IsCompleted)
843+
{
844+
waiter.CancelIfNecessary(this, requestCancelled: waiter.Task.IsCanceled);
845+
waiter.ConnectionCancellationTokenSource = null;
846+
}
847+
}
848+
849+
return cts;
850+
}
831851

832852
private static Exception CreateConnectTimeoutException(OperationCanceledException oce)
833853
{

src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,68 @@ await RemoteExecutor.Invoke(static async (versionString, timoutStr) =>
393393
}, UseVersion.ToString(), timeout.ToString()).DisposeAsync();
394394
}
395395

396+
[OuterLoop("We wait for PendingConnectionTimeout which defaults to 5 seconds.")]
397+
[Fact]
398+
public async Task PendingConnectionTimeout_SignalsAllConnectionAttempts()
399+
{
400+
if (UseVersion == HttpVersion.Version30)
401+
{
402+
// HTTP3 does not support ConnectCallback
403+
return;
404+
}
405+
406+
int pendingConnectionAttempts = 0;
407+
bool connectionAttemptTimedOut = false;
408+
409+
using var handler = new SocketsHttpHandler
410+
{
411+
ConnectCallback = async (context, cancellation) =>
412+
{
413+
Interlocked.Increment(ref pendingConnectionAttempts);
414+
try
415+
{
416+
await Assert.ThrowsAsync<TaskCanceledException>(() => Task.Delay(-1, cancellation)).WaitAsync(TestHelper.PassingTestTimeout);
417+
cancellation.ThrowIfCancellationRequested();
418+
throw new UnreachableException();
419+
}
420+
catch (TimeoutException)
421+
{
422+
connectionAttemptTimedOut = true;
423+
throw;
424+
}
425+
finally
426+
{
427+
Interlocked.Decrement(ref pendingConnectionAttempts);
428+
}
429+
}
430+
};
431+
432+
using HttpClient client = CreateHttpClient(handler);
433+
client.Timeout = TimeSpan.FromSeconds(2);
434+
435+
// Many of these requests should trigger new connection attempts, and all of those should eventually be cleaned up.
436+
await Parallel.ForAsync(0, 100, async (_, _) =>
437+
{
438+
await Assert.ThrowsAnyAsync<TaskCanceledException>(() => client.GetAsync("https://dummy"));
439+
});
440+
441+
Stopwatch stopwatch = Stopwatch.StartNew();
442+
443+
while (Volatile.Read(ref pendingConnectionAttempts) > 0)
444+
{
445+
Assert.False(connectionAttemptTimedOut);
446+
447+
if (stopwatch.Elapsed > 2 * TestHelper.PassingTestTimeout)
448+
{
449+
Assert.Fail("Connection attempts took too long to get cleaned up");
450+
}
451+
452+
await Task.Delay(100);
453+
}
454+
455+
Assert.False(connectionAttemptTimedOut);
456+
}
457+
396458
private sealed class SetTcsContent : StreamContent
397459
{
398460
private readonly TaskCompletionSource<bool> _tcs;

0 commit comments

Comments
 (0)