Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,53 +253,7 @@ private void OnCheckSocketConnection(object? state)
{
CompatibilityHelpers.Assert(socketAddress != null);

try
{
SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);

// Poll socket to check if it can be read from. Unfortunatly this requires reading pending data.
// The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it.
//
// Available data needs to be read now because the only way to determine whether the connection is closed is to
// get the results of polling after available data is received.
bool hasReadData;
do
{
closeSocket = IsSocketInBadState(socket, socketAddress);
var available = socket.Available;
if (available > 0)
{
hasReadData = true;
var serverDataAvailable = CalculateInitialSocketDataLength(_initialSocketData) + available;
if (serverDataAvailable > MaximumInitialSocketDataSize)
{
// Data sent to the client before a connection is started shouldn't be large.
// Put a maximum limit on the buffer size to prevent an unexpected scenario from consuming too much memory.
throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded.");
}

SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available);

// Data is already available so this won't block.
var buffer = new byte[available];
var readCount = socket.Receive(buffer);

_initialSocketData ??= new List<ReadOnlyMemory<byte>>();
_initialSocketData.Add(buffer.AsMemory(0, readCount));
}
else
{
hasReadData = false;
}
}
while (hasReadData);
}
catch (Exception ex)
{
closeSocket = true;
checkException = ex;
SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
}
closeSocket = CheckSocket(socket, socketAddress, ref _initialSocketData, out checkException);
}
}

Expand Down Expand Up @@ -337,6 +291,68 @@ private void OnCheckSocketConnection(object? state)
}
}

/// <summary>
/// Checks whether the socket is healthy. May read available data into the passed in buffer.
/// Returns true if the socket should be closed.
/// </summary>
private bool CheckSocket(Socket socket, BalancerAddress socketAddress, ref List<ReadOnlyMemory<byte>>? socketData, out Exception? checkException)
{
checkException = null;

try
{
SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);

// Poll socket to check if it can be read from. Unfortunatly this requires reading pending data.
// The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it.
//
// Available data needs to be read now because the only way to determine whether the connection is closed is to
// get the results of polling after available data is received.
bool hasReadData;
do
{
if (IsSocketInBadState(socket, socketAddress))
{
return true;
}

var available = socket.Available;
if (available > 0)
{
hasReadData = true;
var serverDataAvailable = CalculateInitialSocketDataLength(socketData) + available;
if (serverDataAvailable > MaximumInitialSocketDataSize)
{
// Data sent to the client before a connection is started shouldn't be large.
// Put a maximum limit on the buffer size to prevent an unexpected scenario from consuming too much memory.
throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded.");
}

SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available);

// Data is already available so this won't block.
var buffer = new byte[available];
var readCount = socket.Receive(buffer);

socketData ??= new List<ReadOnlyMemory<byte>>();
socketData.Add(buffer.AsMemory(0, readCount));
}
else
{
hasReadData = false;
}
}
while (hasReadData);
return false;
}
catch (Exception ex)
{
checkException = ex;
SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
return true;
}
}

public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, CancellationToken cancellationToken)
{
SocketConnectivitySubchannelTransportLog.CreatingStream(_logger, _subchannel.Id, address);
Expand Down Expand Up @@ -383,7 +399,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _socketIdleTimeout);
closeSocket = true;
}
else if (IsSocketInBadState(socket, address))
else if (CheckSocket(socket, address, ref socketData, out _))
{
SocketConnectivitySubchannelTransportLog.ClosingUnusableSocketOnCreateStream(_logger, _subchannel.Id, address);
closeSocket = true;
Expand Down
25 changes: 18 additions & 7 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ public static EndpointContext<TRequest, TResponse> CreateGrpcEndpoint<TRequest,
HttpProtocols? protocols = null,
bool? isHttps = null,
X509Certificate2? certificate = null,
ILoggerFactory? loggerFactory = null)
ILoggerFactory? loggerFactory = null,
Action<KestrelServerOptions>? configureServer = null)
where TRequest : class, IMessage, new()
where TResponse : class, IMessage, new()
{
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory);
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory, configureServer);
var method = server.DynamicGrpc.AddUnaryMethod(callHandler, methodName);
var url = server.GetUrl(isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2);

Expand Down Expand Up @@ -88,7 +89,13 @@ public void Dispose()
}
}

public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? protocols = null, bool? isHttps = null, X509Certificate2? certificate = null, ILoggerFactory? loggerFactory = null)
public static GrpcTestFixture<Startup> CreateServer(
int port,
HttpProtocols? protocols = null,
bool? isHttps = null,
X509Certificate2? certificate = null,
ILoggerFactory? loggerFactory = null,
Action<KestrelServerOptions>? configureServer = null)
{
var endpointName = isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2;

Expand All @@ -102,6 +109,8 @@ public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? pro
},
(options, urls) =>
{
configureServer?.Invoke(options);

urls[endpointName] = isHttps.GetValueOrDefault(false)
? $"https://127.0.0.1:{port}"
: $"http://127.0.0.1:{port}";
Expand Down Expand Up @@ -136,13 +145,14 @@ public static Task<GrpcChannel> CreateChannel(
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
TimeSpan? connectionIdleTimeout = null,
TimeSpan? socketPingInterval = null)
{
var resolver = new TestResolver();
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
resolver.UpdateAddresses(e);

return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout);
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout, socketPingInterval);
}

public static async Task<GrpcChannel> CreateChannel(
Expand All @@ -154,12 +164,13 @@ public static async Task<GrpcChannel> CreateChannel(
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
TimeSpan? connectionIdleTimeout = null,
TimeSpan? socketPingInterval = null)
{
var services = new ServiceCollection();
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(socketPingInterval ?? TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());

var serviceConfig = new ServiceConfig();
Expand Down
51 changes: 50 additions & 1 deletion test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
}

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod), loggerFactory: LoggerFactory);

var connectionIdleTimeout = TimeSpan.FromSeconds(1);
var channel = await BalancerHelpers.CreateChannel(
Expand All @@ -180,6 +180,55 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
AssertHasLog(LogLevel.Trace, "ConnectingOnCreateStream", "Subchannel id '1' doesn't have a connected socket available. Connecting new stream socket for 127.0.0.1:50051.");
}

[Test]
public async Task Active_UnaryCall_ServerCloseOnKeepAlive_SocketRecreatedOnRequest()
{
// Ignore errors
SetExpectedErrorsFilter(writeContext =>
{
return true;
});

Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
{
return Task.FromResult(new HelloReply { Message = request.Name });
}

// In this test the client connects to the server, and the server then closes it after keep-alive is triggered.
// The client then starts a gRPC call to the server. The client should discard the closed socket and create a new one.

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(
50051,
UnaryMethod,
nameof(UnaryMethod),
loggerFactory: LoggerFactory,
configureServer: o => o.Limits.KeepAliveTimeout = TimeSpan.FromSeconds(1));

// Don't timeout the socket or ping it from the client.
var channel = await BalancerHelpers.CreateChannel(
LoggerFactory,
new RoundRobinConfig(),
new[] { endpoint.Address },
connectionIdleTimeout: TimeSpan.FromMinutes(30),
socketPingInterval: TimeSpan.FromMinutes(30)).DefaultTimeout();

Logger.LogInformation("Connecting channel.");
await channel.ConnectAsync();

// Fails when this test is run with debugging. Kestrel doesn't trigger keepalive timeout if debugging is enabled.
await TestHelpers.AssertIsTrueRetryAsync(() =>
{
return Logs.Any(l => l.LoggerName.StartsWith("Microsoft.AspNetCore.Server.Kestrel") && l.EventId.Name == "ConnectionStop");
}, "Wait for server to close connection.");

var client = TestClientFactory.Create(channel, endpoint.Method);
var response = await client.UnaryCall(new HelloRequest { Name = "Test!" }).ResponseAsync.DefaultTimeout();

// Assert
Assert.AreEqual("Test!", response.Message);
}

[Test]
public async Task Active_UnaryCall_MultipleStreams_UnavailableAddress_FallbackToWorkingAddress()
{
Expand Down