Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using Microsoft.AspNetCore.Connections;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

namespace Microsoft.Azure.SignalR.AspNet
{
Expand Down Expand Up @@ -71,7 +72,7 @@ protected override Task OnConnectedAsync(OpenConnectionMessage openConnectionMes
{
// Create empty transport with only channel for async processing messages
var connectionId = openConnectionMessage.ConnectionId;
var clientContext = new ClientContext(connectionId);
var clientContext = new ClientContext(connectionId, GetInstanceId(openConnectionMessage.Headers));

if (_clientConnectionManager.TryAdd(connectionId, this))
{
Expand Down Expand Up @@ -278,14 +279,48 @@ private static string GetString(ReadOnlySequence<byte> buffer)

return Encoding.UTF8.GetString(buffer.ToArray());
}


protected override Task DisconnectClientConnections(string instanceId)
{
if (_clientConnections.Count() == 0 || string.IsNullOrEmpty(instanceId))
{
return Task.CompletedTask;
}
try
{
// Connected service instance is down, close client gracefully
var connections = _clientConnections.Where(c => c.Value.InstanceId == instanceId);
foreach (var connection in connections)
{
var serviceMessage = new CloseConnectionMessage(connection.Value.ConnectionId, errorMessage: "Service transient error");
_ = PerformDisconnectCore(connection.Value.ConnectionId, true, true);
_ = WriteAsync(serviceMessage);
}
}
catch (Exception ex)
{
Log.FailedToCleanupConnections(Logger, ex);
}
return Task.CompletedTask;
}

private string GetInstanceId(IDictionary<string, StringValues> header)
{
if (header.ContainsKey(Constants.AsrsInstanceId))
{
return header[Constants.AsrsInstanceId];
}
return null;
}

private sealed class ClientContext
{
private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();

public ClientContext(string connectionId)
public ClientContext(string connectionId, string instanceId = null)
{
ConnectionId = connectionId;
InstanceId = instanceId;
var channel = Channel.CreateUnbounded<ServiceMessage>();
Input = channel.Reader;
Output = channel.Writer;
Expand All @@ -302,6 +337,8 @@ public void CancelPendingRead()

public string ConnectionId { get; }

public string InstanceId { get; }

public ChannelReader<ServiceMessage> Input { get; }

public ChannelWriter<ServiceMessage> Output { get; }
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ internal static class Constants
public const string ApplicationNameDefaultKey = "Azure:SignalR:ApplicationName";

public const string AsrsUserAgent = "Asrs-User-Agent";
public const string AsrsInstanceId = "Asrs-Instance-Id";

public const string AzureSignalREnabledKey = "Azure:SignalR:Enabled";

Expand Down Expand Up @@ -58,5 +59,11 @@ public static class Config
{
public static readonly string ConnectionStringKey = "Azure:SignalR:ConnectionString";
}

public static class ServicePingMessageKey
{
public const string RebalanceKey = "target";
public const string OfflineKey = "offline";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ public virtual async Task WriteAsync(ServiceMessage serviceMessage)
}
}

protected abstract Task DisconnectClientConnections(string instanceId);

protected abstract Task<ConnectionContext> CreateConnection(string target = null);

protected abstract Task DisposeConnection();
Expand Down Expand Up @@ -211,6 +213,10 @@ protected Task OnServiceErrorAsync(ServiceErrorMessage serviceErrorMessage)

protected Task OnPingMessageAsync(PingMessage pingMessage)
{
if (pingMessage.TryGetValue(Constants.ServicePingMessageKey.OfflineKey, out var instanceId) && !string.IsNullOrEmpty(instanceId))
{
return DisconnectClientConnections(instanceId);
}
return _serviceMessageHandler.HandlePingAsync(pingMessage);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ namespace Microsoft.Azure.SignalR
{
internal class StrongServiceConnectionContainer : ServiceConnectionContainerBase
{
private const string PingTargetKey = "target";
private readonly List<IServiceConnection> _onDemandServiceConnections;

// The lock is only used to lock the on-demand part
Expand All @@ -26,19 +25,11 @@ public StrongServiceConnectionContainer(IServiceConnectionFactory serviceConnect

public override Task HandlePingAsync(PingMessage pingMessage)
{
int index = 0;
while (index < pingMessage.Messages.Length - 1)
if (pingMessage.TryGetValue(Constants.ServicePingMessageKey.RebalanceKey, out string target) && !string.IsNullOrEmpty(target))
{
if (pingMessage.Messages[index] == PingTargetKey &&
!string.IsNullOrEmpty(pingMessage.Messages[index + 1]))
{
var connection = CreateOnDemandServiceConnection();
return StartCoreAsync(connection, pingMessage.Messages[index + 1]);
}

index += 2;
var connection = CreateOnDemandServiceConnection();
return StartCoreAsync(connection, target);
}

return Task.CompletedTask;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public static class ServicePingMessageExtension
{
public static bool TryGetServiceStatusPingMessage(this ServicePingMessage message, out ServiceStatusPingMessage serviceStatus)
{
if (TryGetValue(message, ServiceStatusPingMessage.Key, out var value))
if (message.TryGetValue(ServiceStatusPingMessage.Key, out var value))
{
serviceStatus = new ServiceStatusPingMessage(value);
return true;
Expand All @@ -19,7 +19,7 @@ public static bool TryGetServiceStatusPingMessage(this ServicePingMessage messag
return false;
}

private static bool TryGetValue(ServicePingMessage pingMessage, string key, out string value)
public static bool TryGetValue(this ServicePingMessage pingMessage, string key, out string value)
{
if (pingMessage == null)
{
Expand Down
40 changes: 37 additions & 3 deletions src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

namespace Microsoft.Azure.SignalR
{
Expand Down Expand Up @@ -82,6 +83,30 @@ protected override async Task CleanupConnections()
}
}

protected override Task DisconnectClientConnections(string instanceId)
{
if (string.IsNullOrEmpty(instanceId) || _connectionIds.Count == 0)
{
return Task.CompletedTask;
}
try
{
// Connected service instance is down, close client gracefully
var connections = _connectionIds.Where(c => c.Value == instanceId);
foreach (var connection in connections)
{
var serviceMessage = new CloseConnectionMessage(connection.Key, errorMessage: "Service transient error.");
_ = PerformDisconnectAsyncCore(connection.Key, false);
_ = WriteAsync(serviceMessage);
}
}
catch (Exception ex)
{
Log.FailedToCleanupConnections(Logger, ex);
}
return Task.CompletedTask;
}

protected override ReadOnlyMemory<byte> GetPingMessage()
{
_pingMessages[1] = _clientConnectionManager.ClientConnections.Count.ToString();
Expand Down Expand Up @@ -144,16 +169,16 @@ private async Task ProcessOutgoingMessagesAsync(ServiceConnectionContext connect
}
}

private void AddClientConnection(ServiceConnectionContext connection)
private void AddClientConnection(ServiceConnectionContext connection, string instanceId)
{
_clientConnectionManager.AddClientConnection(connection);
_connectionIds.TryAdd(connection.ConnectionId, connection.ConnectionId);
_connectionIds.TryAdd(connection.ConnectionId, instanceId);
}

protected override Task OnConnectedAsync(OpenConnectionMessage message)
{
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext);
AddClientConnection(connection);
AddClientConnection(connection, GetInstanceId(message.Headers));
Log.ConnectedStarting(Logger, connection.ConnectionId);

// Execute the application code
Expand Down Expand Up @@ -286,5 +311,14 @@ protected override async Task OnMessageAsync(ConnectionDataMessage connectionDat
Log.ReceivedMessageForNonExistentConnection(Logger, connectionDataMessage.ConnectionId);
}
}

private string GetInstanceId(IDictionary<string, StringValues> header)
{
if (header.ContainsKey(Constants.AsrsInstanceId))
{
return header[Constants.AsrsInstanceId];
}
return string.Empty;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ public async Task TestContainerWithTwoEndpointWithOneOfflineSucceedsWithWarning(
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, logChecker: logs =>
{
var warns = logs.Where(s => s.Write.LogLevel == LogLevel.Warning).ToList();
Assert.Equal(1, warns.Count);
Assert.Single(warns);
Assert.Contains(warns, s => s.Write.Message.Contains("Message JoinGroupWithAckMessage is not sent to endpoint (Primary)http://url1 because all connections to this endpoint are offline."));
return true;
}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Extensions.Primitives;
using Xunit;
using Xunit.Abstractions;

Expand Down Expand Up @@ -512,6 +513,59 @@ public async Task ServiceConnectionWithTransportLayerClosedShouldCleanupEndlessI
}
}

[Fact]
public async Task ServiceConnectionWithOfflinePingWillTriggerDisconnectClients()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug))
{
var hubConfig = Utility.GetTestHubConfig(loggerFactory);
var atm = new AzureTransportManager(hubConfig.Resolver);
hubConfig.Resolver.Register(typeof(ITransportManager), () => atm);

var clientConnectionManager = new TestClientConnectionManager();
using (var proxy = new TestServiceConnectionProxy(clientConnectionManager, loggerFactory: loggerFactory))
{
// prepare 2 clients with different instancesId connected
var instanceId1 = Guid.NewGuid().ToString();
var connectionId1 = Guid.NewGuid().ToString("N");
var header1 = new Dictionary<string, StringValues>() { {Constants.AsrsInstanceId, instanceId1 } };
var instanceId2 = Guid.NewGuid().ToString();
var connectionId2 = Guid.NewGuid().ToString("N");
var header2 = new Dictionary<string, StringValues>() { { Constants.AsrsInstanceId, instanceId2 } };

// start the server connection
await proxy.StartServiceAsync().OrTimeout();

// Application layer sends OpenConnectionMessage
var openConnectionMessage = new OpenConnectionMessage(connectionId1, new Claim[0], header1, "?transport=webSockets");
var task = clientConnectionManager.WaitForClientConnectAsync(connectionId1);
await proxy.WriteMessageAsync(openConnectionMessage);
await task.OrTimeout();
openConnectionMessage = new OpenConnectionMessage(connectionId2, new Claim[0], header2, "?transport=webSockets");
task = clientConnectionManager.WaitForClientConnectAsync(connectionId2);
await proxy.WriteMessageAsync(openConnectionMessage);
await task.OrTimeout();

// 2 clients are connected
clientConnectionManager.CurrentTransports.TryGetValue(connectionId1, out var transport1);
Assert.NotNull(transport1);
clientConnectionManager.CurrentTransports.TryGetValue(connectionId2, out var transport2);
Assert.NotNull(transport2);

// server received offline ping on instance1 and trigger cleanup related client1
await proxy.WriteMessageAsync(new PingMessage()
{
Messages = new[] { "offline", instanceId1 }
});

// Validate client1 disconnect and client2 is still connected
await transport1.WaitOnDisconnected().OrTimeout();
Assert.Single(clientConnectionManager.CurrentTransports);
Assert.Equal(connectionId2, clientConnectionManager.CurrentTransports.FirstOrDefault().Key);
}
}
}

private sealed class HubResponseItem
{
public string H { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,10 @@ public override Task WriteAsync(ServiceMessage serviceMessage)

return Task.CompletedTask;
}

protected override Task DisconnectClientConnections(string instanceId)
{
return Task.CompletedTask;
}
}
}
47 changes: 47 additions & 0 deletions test/Microsoft.Azure.SignalR.Tests/ServiceConnectionFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Net;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Protocol;
Expand Down Expand Up @@ -451,6 +452,52 @@ await proxy.WriteMessageAsync(new PingMessage()
await serverTask3.OrTimeout();
}

/// <summary>
/// If receive offline ping message and will trigger close connection
/// connection to default.
/// </summary>
/// <returns></returns>
[Fact]
public async Task CloseClientsWhenReceiveInstanceOfflinePing()
{
// Prepare clients
var instanceId1 = Guid.NewGuid().ToString();
var connectionId1 = Guid.NewGuid().ToString("N");
var header1 = new HeaderDictionary() { { Constants.AsrsInstanceId, instanceId1 } };

var instanceId2 = Guid.NewGuid().ToString();
var connectionId2 = Guid.NewGuid().ToString("N");
var header2 = new HeaderDictionary() { { Constants.AsrsInstanceId, instanceId2 } };

var proxy = new ServiceConnectionProxy();

var connectionTask = proxy.WaitForServerConnectionAsync(1);
_ = proxy.StartAsync();

await connectionTask.OrTimeout();

// 2 client connects with different instanceIds
await proxy.WriteMessageAsync(new OpenConnectionMessage(connectionId1, null, header1, null));
connectionTask = proxy.WaitForConnectionAsync(connectionId1);
await connectionTask.OrTimeout();

await proxy.WriteMessageAsync(new OpenConnectionMessage(connectionId2, null, header2, null));
connectionTask = proxy.WaitForConnectionAsync(connectionId2);
await connectionTask.OrTimeout();

// Server received instance offline ping on instanceId1 and trigger cleanup related client1
await proxy.WriteMessageAsync(new PingMessage()
{
Messages = new[] { "offline", instanceId1 }
});
var messageTask = proxy.WaitForApplicationMessageAsync(typeof(CloseConnectionMessage));
await messageTask.OrTimeout();

// Check client2 is still connected
Assert.Single(proxy.ClientConnections);
Assert.Equal(connectionId2, proxy.ClientConnections.FirstOrDefault().Key);
}

private static void AssertTimeout(params Task[] task)
{
Assert.False(Task.WaitAll(task, DefaultTimeoutInMilliSeconds));
Expand Down