Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal interface ISubchannelTransport : IDisposable
TransportStatus TransportStatus { get; }

ValueTask<Stream> GetStreamAsync(DnsEndPoint endPoint, CancellationToken cancellationToken);
ValueTask<ConnectResult> TryConnectAsync(ConnectContext context);
ValueTask<ConnectResult> TryConnectAsync(ConnectContext context, int attempt);

void Disconnect();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void Disconnect()
_subchannel.UpdateConnectivityState(ConnectivityState.Idle, "Disconnected.");
}

public ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
public ValueTask<ConnectResult> TryConnectAsync(ConnectContext context, int attempt)
{
Debug.Assert(_subchannel._addresses.Count == 1);
Debug.Assert(CurrentEndPoint == null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ private void DisconnectUnsynchronized()
_currentEndPoint = null;
}

public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context, int attempt)
{
Debug.Assert(CurrentEndPoint == null);

Expand Down
36 changes: 32 additions & 4 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ public void Dispose()
/// <param name="addresses"></param>
public void UpdateAddresses(IReadOnlyList<BalancerAddress> addresses)
{
SubchannelLog.AddressesUpdated(_logger, Id, addresses);

var requireReconnect = false;
lock (Lock)
{
Expand Down Expand Up @@ -278,6 +280,8 @@ private void CancelInProgressConnect()
_connectContext.CancelConnect();
_connectContext.Dispose();
}

_delayInterruptTcs?.TrySetResult(null);
}
}

Expand Down Expand Up @@ -313,7 +317,7 @@ private async Task ConnectTransportAsync()
}
}

switch (await _transport.TryConnectAsync(connectContext).ConfigureAwait(false))
switch (await _transport.TryConnectAsync(connectContext, attempt).ConfigureAwait(false))
{
case ConnectResult.Success:
return;
Expand Down Expand Up @@ -345,17 +349,21 @@ private async Task ConnectTransportAsync()
{
// Task.Delay won. Check CTS to see if it won because of cancellation.
delayCts.Token.ThrowIfCancellationRequested();
SubchannelLog.ConnectBackoffComplete(_logger, Id);
}
else
{
SubchannelLog.ConnectBackoffInterrupted(_logger, Id);

// Delay interrupt was triggered. Reset back-off.
backoffPolicy = _manager.BackoffPolicyFactory.Create();

// Cancel the Task.Delay that's no longer needed.
// https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/519ef7d231c01116f02bc04354816a735f2a36b6/AsyncGuidance.md#using-a-timeout
delayCts.Cancel();

// Check to connect context token to see if the delay was interrupted because of a connect cancellation.
connectContext.CancellationToken.ThrowIfCancellationRequested();

// Delay interrupt was triggered. Reset back-off.
backoffPolicy = _manager.BackoffPolicyFactory.Create();
}
}
}
Expand Down Expand Up @@ -532,6 +540,12 @@ internal static class SubchannelLog
private static readonly Action<ILogger, string, Exception?> _cancelingConnect =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(17, "CancelingConnect"), "Subchannel id '{SubchannelId}' canceling connect.");

private static readonly Action<ILogger, string, Exception?> _connectBackoffComplete =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(18, "ConnectBackoffComplete"), "Subchannel id '{SubchannelId}' connect backoff complete.");

private static readonly Action<ILogger, string, string, Exception?> _addressesUpdated =
LoggerMessage.Define<string, string>(LogLevel.Trace, new EventId(19, "AddressesUpdated"), "Subchannel id '{SubchannelId}' updated with addresses: {Addresses}");

public static void SubchannelCreated(ILogger logger, string subchannelId, IReadOnlyList<BalancerAddress> addresses)
{
if (logger.IsEnabled(LogLevel.Debug))
Expand Down Expand Up @@ -620,5 +634,19 @@ public static void CancelingConnect(ILogger logger, string subchannelId)
{
_cancelingConnect(logger, subchannelId, null);
}

public static void ConnectBackoffComplete(ILogger logger, string subchannelId)
{
_connectBackoffComplete(logger, subchannelId, null);
}

public static void AddressesUpdated(ILogger logger, string subchannelId, IReadOnlyList<BalancerAddress> addresses)
{
if (logger.IsEnabled(LogLevel.Trace))
{
var addressesText = string.Join(", ", addresses.Select(a => a.EndPoint.Host + ":" + a.EndPoint.Port));
_addressesUpdated(logger, subchannelId, addressesText, null);
}
}
}
#endif
6 changes: 3 additions & 3 deletions test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public async Task PickAsync_ErrorConnectingToSubchannel_ThrowsError()
new BalancerAddress("localhost", 80)
});

var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
return Task.FromException<TryConnectResult>(new Exception("Test error!"));
});
Expand Down Expand Up @@ -357,7 +357,7 @@ public async Task UpdateAddresses_ConnectIsInProgress_InProgressConnectIsCancele

var syncPoint = new SyncPoint(runContinuationsAsynchronously: true);

var transportFactory = new TestSubchannelTransportFactory(async (s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create(async (s, c) =>
{
c.Register(state => ((SyncPoint)state!).Continue(), syncPoint);

Expand Down Expand Up @@ -548,7 +548,7 @@ public async Task PickAsync_ExecutionContext_DoesNotCaptureAsyncLocalsInConnect(

var callbackAsyncLocalValues = new List<object>();

var transportFactory = new TestSubchannelTransportFactory((subchannel, cancellationToken) =>
var transportFactory = TestSubchannelTransportFactory.Create((subchannel, cancellationToken) =>
{
callbackAsyncLocalValues.Add(asyncLocal.Value);
if (callbackAsyncLocalValues.Count >= 2)
Expand Down
20 changes: 10 additions & 10 deletions test/Grpc.Net.Client.Tests/Balancer/PickFirstBalancerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public async Task ChangeAddresses_HasReadySubchannel_OldSubchannelShutdown()
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));

var subChannelConnections = new List<Subchannel>();
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
lock (subChannelConnections)
{
Expand Down Expand Up @@ -176,7 +176,7 @@ public async Task ResolverError_HasFailedSubchannel_SubchannelShutdown()
new BalancerAddress("localhost", 80)
});

var transportFactory = new TestSubchannelTransportFactory((s, c) => Task.FromResult(new TryConnectResult(ConnectivityState.TransientFailure)));
var transportFactory = TestSubchannelTransportFactory.Create((s, c) => Task.FromResult(new TryConnectResult(ConnectivityState.TransientFailure)));
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<ISubchannelTransportFactory>(transportFactory);
var serviceProvider = services.BuildServiceProvider();
Expand Down Expand Up @@ -234,7 +234,7 @@ public async Task RequestConnection_InitialConnectionFails_ExponentialBackoff()
var connectivityState = ConnectivityState.TransientFailure;

services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(async (s, c) =>
services.AddSingleton<ISubchannelTransportFactory>(TestSubchannelTransportFactory.Create(async (s, c) =>
{
await syncPoint.WaitToContinue();
return new TryConnectResult(connectivityState);
Expand Down Expand Up @@ -290,7 +290,7 @@ public async Task RequestConnection_InitialConnectionEnds_EntersIdleState()
});

var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;
return Task.FromResult(new TryConnectResult(ConnectivityState.Ready));
Expand Down Expand Up @@ -340,7 +340,7 @@ public async Task RequestConnection_IdleConnectionConnectAsync_StateToReady()
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;
return Task.FromResult(new TryConnectResult(ConnectivityState.Ready));
Expand Down Expand Up @@ -385,7 +385,7 @@ public async Task RequestConnection_IdleConnectionPick_StateToReady()
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;
return Task.FromResult(new TryConnectResult(ConnectivityState.Ready));
Expand Down Expand Up @@ -448,7 +448,7 @@ public async Task UnaryCall_TransportConnecting_OnePickStarted()

var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;

Expand Down Expand Up @@ -510,7 +510,7 @@ public async Task UnaryCall_TransportConnecting_ErrorAfterTransientFailure()

var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;

Expand Down Expand Up @@ -576,7 +576,7 @@ public async Task DeadlineExceeded_MultipleCalls_CallsWaitForDeadline()
var resolver = new TestResolver();
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
return Task.FromResult(new TryConnectResult(ConnectivityState.Connecting));
});
Expand Down Expand Up @@ -640,7 +640,7 @@ public async Task ConnectTimeout_MultipleCalls_AttemptReconnect()
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var tryConnectTcs = new TaskCompletionSource<TryConnectResult>(TaskCreationOptions.RunContinuationsAsynchronously);
var transportFactory = new TestSubchannelTransportFactory((s, ct) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, ct) =>
{
return tryConnectTcs.Task;
});
Expand Down
132 changes: 131 additions & 1 deletion test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#if SUPPORT_LOAD_BALANCING
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Threading;
Expand Down Expand Up @@ -298,7 +299,7 @@ public async Task Resolver_ServiceConfigInResult()
var currentConnectivityState = ConnectivityState.Ready;

services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(async (s, c) =>
services.AddSingleton<ISubchannelTransportFactory>(TestSubchannelTransportFactory.Create(async (s, i, c) =>
{
await syncPoint.WaitToContinue();
return new TryConnectResult(currentConnectivityState);
Expand Down Expand Up @@ -551,5 +552,134 @@ public async Task ResolveServiceConfig_ErrorOnSecondResolve_PickSuccess()
var balancer = (ChildHandlerLoadBalancer)channel.ConnectionManager._balancer!;
return (T?)balancer._current?.LoadBalancer;
}

internal class TestBackoffPolicyFactory : IBackoffPolicyFactory
{
private readonly TimeSpan _backoff;

public TestBackoffPolicyFactory() : this(TimeSpan.FromSeconds(20))
{
}

public TestBackoffPolicyFactory(TimeSpan backoff)
{
_backoff = backoff;
}

public IBackoffPolicy Create()
{
return new TestBackoffPolicy(_backoff);
}

private class TestBackoffPolicy : IBackoffPolicy
{
private readonly TimeSpan _backoff;

public TestBackoffPolicy(TimeSpan backoff)
{
_backoff = backoff;
}

public TimeSpan NextBackoff()
{
return _backoff;
}
}
}

[Test]
public async Task Resolver_UpdateResultsAfterPreviousConnect_InterruptConnect()
{
// Arrange
var services = new ServiceCollection();

// add logger
services.AddNUnitLogger();
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger<ResolverTests>();

// add resolver and balancer
var resolver = new TestResolver(loggerFactory);
var result = ResolverResult.ForResult(new List<BalancerAddress> { new BalancerAddress("localhost", 80) }, serviceConfig: null, serviceConfigStatus: null);
resolver.UpdateResult(result);

services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IBackoffPolicyFactory>(new TestBackoffPolicyFactory(TimeSpan.FromSeconds(0.2)));

var tryConnectData = new List<(IReadOnlyList<BalancerAddress> BalancerAddresses, int Attempt, bool IsCancellationRequested)>();

var tryConnectCount = 0;
services.AddSingleton<ISubchannelTransportFactory>(
TestSubchannelTransportFactory.Create((subchannel, attempt, cancellationToken) =>
{
var addresses = subchannel.GetAddresses();
var isCancellationRequested = cancellationToken.IsCancellationRequested;
ConnectivityState state;

var i = Interlocked.Increment(ref tryConnectCount);
if (i == 1)
{
state = ConnectivityState.Ready;
}
else
{
state = attempt >= 2 ? ConnectivityState.Ready : ConnectivityState.TransientFailure;
}

logger.LogInformation("TryConnect attempt {Attempt} to addresses {Addresses}. State: {ConnectivityState}, IsCancellationRequested: {IsCancellationRequested}", attempt, string.Join(", ", addresses), state, isCancellationRequested);

lock (tryConnectData)
{
tryConnectData.Add((addresses, attempt, isCancellationRequested));
}

return Task.FromResult(new TryConnectResult(state));
}));

var channelOptions = new GrpcChannelOptions
{
Credentials = ChannelCredentials.Insecure,
ServiceProvider = services.BuildServiceProvider(),
};

// Act
var channel = GrpcChannel.ForAddress("test:///test_addr", channelOptions);

logger.LogInformation("Client connecting.");
await channel.ConnectionManager.ConnectAsync(waitForReady: true, CancellationToken.None);

logger.LogInformation("Client updating resolver.");
result = ResolverResult.ForResult(new List<BalancerAddress> { new BalancerAddress("localhost", 81) }, serviceConfig: null, serviceConfigStatus: null);
resolver.UpdateResult(result);

logger.LogInformation("Client picking.");
await ExceptionAssert.ThrowsAsync<RpcException>(async () => await channel.ConnectionManager.PickAsync(
new PickContext(),
waitForReady: false,
CancellationToken.None));

logger.LogInformation("Client updating Resolver.");
result = ResolverResult.ForResult(new List<BalancerAddress> { new BalancerAddress("localhost", 82) }, serviceConfig: null, serviceConfigStatus: null);
resolver.UpdateResult(result);

logger.LogInformation("Client picking and waiting for ready.");
await channel.ConnectionManager.PickAsync(
new PickContext(),
waitForReady: true,
CancellationToken.None);

// Assert
logger.LogInformation("TryConnectData count: {Count}", tryConnectData.Count);
foreach (var data in tryConnectData)
{
logger.LogInformation("Attempt: {Attempt}, BalancerAddresses: {BalancerAddresses}, IsCancellationRequested: {IsCancellationRequested}", data.Attempt, string.Join(", ", data.BalancerAddresses), data.IsCancellationRequested);
}

var duplicate = tryConnectData.GroupBy(d => new { Address = d.BalancerAddresses.Single(), d.Attempt }).FirstOrDefault(g => g.Count() >= 2);
if (duplicate != null)
{
Assert.Fail($"Duplicate attempts to address. Count: {duplicate.Count()}, Address: {duplicate.Key.Address}");
}
}
}
#endif
Loading