Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ internal static class Constants
public const string AsrsDefaultScope = "https://signalr.azure.com/.default";


public const int DefaultCloseTimeoutMilliseconds = 30000;
public const int DefaultCloseTimeoutMilliseconds = 10000;

public static class Keys
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static async Task OrCancelAsync(this Task task, CancellationToken token,
{
// make sure the task throws exception if any
await anyTask;
tcs.TrySetCanceled();
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ private static class Log
private static readonly Action<ILogger, string, Exception> _connectedStarting =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(6, "ConnectedStarting"), "Connection {TransportConnectionId} started.");

private static readonly Action<ILogger, string, Exception> _detectedLongRunningApplicationTask =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(7, "DetectedLongRunningApplicationTask"), "The connection {TransportConnectionId} has a long running application logic that prevents the connection from complete.");
private static readonly Action<ILogger, string, int, Exception> _detectedLongRunningApplicationTask =
LoggerMessage.Define<string, int>(LogLevel.Warning, new EventId(7, "DetectedLongRunningApplicationTask"), "The connection {TransportConnectionId} has a long running application logic that prevents the connection from complete after {TimeoutMilliseconds} milliseconds.");

private static readonly Action<ILogger, string, Exception> _outgoingTaskPaused =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(8, "OutgoingTaskPaused"), "Outgoing messages for connection {connectionId} have been paused.");
Expand Down Expand Up @@ -71,9 +71,9 @@ public static void ConnectedStarting(ILogger logger, string connectionId)
_connectedStarting(logger, connectionId, null);
}

public static void DetectedLongRunningApplicationTask(ILogger logger, string connectionId)
public static void DetectedLongRunningApplicationTask(ILogger logger, string connectionId, int timeoutMilliseconds)
{
_detectedLongRunningApplicationTask(logger, connectionId, null);
_detectedLongRunningApplicationTask(logger, connectionId, timeoutMilliseconds, null);
}

public static void OutgoingTaskPaused(ILogger logger, string connectionId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ internal partial class ClientConnectionContext : ConnectionContext,

private readonly CancellationTokenSource _abortOutgoingCts = new CancellationTokenSource();

private readonly CancellationTokenSource _connectionClosedCts = new CancellationTokenSource();

private readonly object _heartbeatLock = new object();

private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1, 1);
Expand All @@ -97,6 +99,10 @@ internal partial class ClientConnectionContext : ConnectionContext,

private long _receivedBytes;

#if !NETSTANDARD
public override CancellationToken ConnectionClosed => _connectionClosedCts.Token;
#endif

public override string ConnectionId { get; set; }

public string InstanceId { get; }
Expand Down Expand Up @@ -136,8 +142,6 @@ public bool AbortOnClose

public ILogger<ServiceConnection> Logger { get; init; } = NullLogger<ServiceConnection>.Instance;

private Task DelayTask => Task.Delay(_closeTimeOutMilliseconds);

private CancellationToken OutgoingAborted => _abortOutgoingCts.Token;

public ClientConnectionContext(OpenConnectionMessage serviceMessage,
Expand Down Expand Up @@ -330,6 +334,7 @@ internal async Task ProcessOutgoingMessagesAsync(SignalRProtocol.IHubProtocol pr
var next = buffer;
while (!buffer.IsEmpty && protocol.TryParseMessage(ref next, FakeInvocationBinder.Instance, out var message))
{
// we still want messages to successfully going out when application completes
if (!await _pauseHandler.WaitAsync(StaticRandom.Next(500, 1500), OutgoingAborted))
{
Log.OutgoingTaskPaused(Logger, ConnectionId);
Expand Down Expand Up @@ -373,6 +378,10 @@ internal async Task ProcessOutgoingMessagesAsync(SignalRProtocol.IHubProtocol pr
Application.Input.AdvanceTo(buffer.Start, buffer.End);
}
}
catch (OperationCanceledException)
{
// cancelled
}
catch (ForwardMessageException)
{
// do nothing.
Expand Down Expand Up @@ -401,11 +410,6 @@ internal async Task ProcessApplicationAsync(ConnectionDelegate connectionDelegat
// application task can end when exception, or Context.Abort() from hub
await connectionDelegate(this);
}
catch (ObjectDisposedException)
{
// When the application shuts down and disposes IServiceProvider, HubConnectionHandler.RunHubAsync is still running and runs into _dispatcher.OnDisconnectedAsync
// no need to throw the error out
}
catch (Exception ex)
{
// Capture the exception to communicate it to the transport (this isn't strictly required)
Expand Down Expand Up @@ -434,14 +438,20 @@ internal async Task PerformDisconnectAsync()
// Wait for the connection's lifetime task to end
// Wait on the application task to complete
// We wait gracefully here to be consistent with self-host SignalR
await Task.WhenAny(LifetimeTask, DelayTask);

if (!LifetimeTask.IsCompleted)
using var cts = new CancellationTokenSource(_closeTimeOutMilliseconds);
try
{
Log.DetectedLongRunningApplicationTask(Logger, ConnectionId);
await LifetimeTask.OrCancelAsync(cts.Token);
cts.Cancel();
}
catch (OperationCanceledException)
{

await LifetimeTask;
Log.DetectedLongRunningApplicationTask(Logger, ConnectionId, _closeTimeOutMilliseconds);
#if !NETSTANDARD
_connectionClosedCts.Cancel();
#endif
}
}

internal async Task ProcessConnectionDataMessageAsync(ConnectionDataMessage connectionDataMessage)
Expand Down
32 changes: 22 additions & 10 deletions src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,14 @@ public void Start(ConnectionDelegate connectionDelegate, Action<HttpContext> con
public async Task ShutdownAsync()
{
var options = _options.GracefulShutdown;
if (options.Mode == GracefulShutdownMode.Off)
{
return;
}

try
{
var source = new CancellationTokenSource(_options.GracefulShutdown.Timeout);
var source = new CancellationTokenSource(options.Timeout);
await GracefulShutdownAsync(options, source.Token).OrCancelAsync(source.Token);
}
catch (OperationCanceledException)
{
Log.GracefulShutdownTimeoutExceeded(_logger, _hubName, Convert.ToInt32(_options.GracefulShutdown.Timeout.TotalMilliseconds));
Log.GracefulShutdownTimeoutExceeded(_logger, _hubName, Convert.ToInt32(options.Timeout.TotalMilliseconds));
}

Log.StoppingServer(_logger, _hubName);
Expand All @@ -158,8 +153,17 @@ private async Task GracefulShutdownAsync(GracefulShutdownOptions options, Cancel
Log.SettingServerOffline(_logger, _hubName);
await _serviceConnectionManager.OfflineAsync(options.Mode, cancellationToken);

Log.TriggeringShutdownHooks(_logger, _hubName);
await options.OnShutdown(Context);
if (options.Mode == GracefulShutdownMode.Off)
{
// By default we directly close all the client connections
Log.CloseClientConnections(_logger, _hubName);
await _serviceConnectionManager.CloseClientConnections(cancellationToken);
}
else
{
Log.TriggeringShutdownHooks(_logger, _hubName);
await options.OnShutdown(Context);
}

Log.WaitingClientConnectionsToClose(_logger, _hubName);
await _clientConnectionManager.WhenAllCompleted();
Expand Down Expand Up @@ -217,11 +221,14 @@ private static class Log
LoggerMessage.Define<string>(LogLevel.Information, new EventId(4, "TriggeringShutdownHooks"), "[{hubName}] Triggering shutdown hooks...");

private static readonly Action<ILogger, string, Exception> _waitingClientConnectionsToClose =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(5, "WaitingClientConnectionsToClose"), "[{hubName}] Waiting client connections to close...");
LoggerMessage.Define<string>(LogLevel.Information, new EventId(5, "WaitingClientConnectionsToClose"), "[{hubName}] Waiting for client connections to close...");

private static readonly Action<ILogger, string, Exception> _stoppingServer =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(6, "StoppingServer"), "[{hubName}] Stopping the hub server...");

private static readonly Action<ILogger, string, Exception> _closeClientConnections =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(7, "CloseClientConnections"), "[{hubName}] Closing client connections...");

public static void StartingConnection(ILogger logger, string name, int connectionNumber)
{
_startingConnection(logger, name, connectionNumber, null);
Expand All @@ -247,6 +254,11 @@ public static void WaitingClientConnectionsToClose(ILogger logger, string hubNam
_waitingClientConnectionsToClose(logger, hubName, null);
}

public static void CloseClientConnections(ILogger logger, string hubName)
{
_closeClientConnections(logger, hubName, null);
}

public static void StoppingServer(ILogger logger, string hubName)
{
_stoppingServer(logger, hubName, null);
Expand Down
25 changes: 22 additions & 3 deletions src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,29 @@ protected override Task DisposeConnection(ConnectionContext connection)
return _connectionFactory.DisposeAsync(connection);
}

public override Task CloseClientConnections(CancellationToken token)
public override async Task CloseClientConnections(CancellationToken token)
{
return Task.WhenAll(_clientConnectionManager.ClientConnections.Select(c => ((ClientConnectionContext)c).PerformDisconnectAsync()));
var tasks = new List<Task>();
foreach (var entity in _connectionIds)
{
if (_clientConnectionManager.TryGetClientConnection(entity.Key, out var c) && c is ClientConnectionContext connection)
{
// batch remove 100 connections once
if (tasks.Count % 100 == 0)
{
await Task.Delay(1);
await Task.WhenAll(tasks);
tasks.Clear();
}
// Add a little delay to avoid too much pressure closing all the clients at one time
tasks.Add(connection.PerformDisconnectAsync());
}
}

if (tasks.Count > 0)
{
await Task.WhenAll(tasks);
}
}

protected override Task CleanupClientConnections(string fromInstanceId = null)
Expand Down Expand Up @@ -284,7 +304,6 @@ private async Task ProcessClientConnectionAsync(ClientConnectionContext connecti
// app task completes connection.Transport.Output, which will completes connection.Application.Input and ends the transport
// Transports are written by us and are well behaved, wait for them to drain
connection.CancelOutgoing(true);

// transport never throws
await transport;
}
Expand Down
Loading
Loading