Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/ServiceConnectionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ internal abstract class ServiceConnectionBase : IServiceConnection
private bool _isStopped;
private long _lastReceiveTimestamp;
protected ConnectionContext _connection;
protected string ErrorMessage;

public Task WaitForConnectionStart => _serviceConnectionStartTcs.Task;

Expand Down Expand Up @@ -85,6 +86,13 @@ public async virtual Task WriteAsync(ServiceMessage serviceMessage)
// The lock is per serviceConnection
await _serviceConnectionLock.WaitAsync();

var errorMessage = ErrorMessage;
if (!string.IsNullOrEmpty(errorMessage))
{
_serviceConnectionLock.Release();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use a local errorMessage=ErrorMessage for throw, ErrorMessage can be reset in between if and throw

throw new InvalidOperationException(errorMessage);
}

if (_connection == null)
{
_serviceConnectionLock.Release();
Expand Down Expand Up @@ -119,6 +127,21 @@ public async virtual Task WriteAsync(ServiceMessage serviceMessage)

protected abstract Task OnMessageAsync(ConnectionDataMessage connectionDataMessage);

protected Task OnServiceErrorAsync(ServiceErrorMessage serviceErrorMessage)
{
if (!string.IsNullOrEmpty(serviceErrorMessage.ErrorMessage))
{
// When receives service error message, we suppose server -> service connection doesn't work,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidfowl After discussed offline, we renamed the message name and changed to log rather than throw exceptions in ProcessIncomingAsync

// and set ErrorMessage to prevent sending message from server to service
// But messages in the pipe from service -> server should be processed as usual. Just log without
// throw exception here.
ErrorMessage = serviceErrorMessage.ErrorMessage;
Log.ReceivedServiceErrorMessage(_logger, _connectionId, serviceErrorMessage.ErrorMessage);
}

return Task.CompletedTask;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CompletedTask [](start = 24, length = 13)

shall we close the server connection with this message?

}

private async Task<bool> StartAsyncCore()
{
// Always try until connected
Expand All @@ -130,6 +153,7 @@ private async Task<bool> StartAsyncCore()
try
{
_connection = await CreateConnection();
ErrorMessage = null;

if (await HandshakeAsync())
{
Expand Down Expand Up @@ -343,6 +367,8 @@ private Task DispatchMessageAsync(ServiceMessage message)
return OnDisconnectedAsync(closeConnectionMessage);
case ConnectionDataMessage connectionDataMessage:
return OnMessageAsync(connectionDataMessage);
case ServiceErrorMessage serviceErrorMessage:
return OnServiceErrorAsync(serviceErrorMessage);
case PingMessage _:
// ignore ping
break;
Expand Down Expand Up @@ -514,6 +540,9 @@ private static class Log
private static readonly Action<ILogger, Exception> _failedSendingPing =
LoggerMessage.Define(LogLevel.Warning, new EventId(26, "FailedSendingPing"), "Failed sending a ping message to service.");

private static readonly Action<ILogger, string, string, Exception> _receivedServiceErrorMessage =
LoggerMessage.Define<string, string>(LogLevel.Warning, new EventId(27, "ReceivedServiceErrorMessage"), "Connection {ServiceConnectionId} received error message from service: {Error}");

public static void FailedToWrite(ILogger logger, Exception exception)
{
_failedToWrite(logger, exception);
Expand Down Expand Up @@ -643,6 +672,11 @@ public static void FailedSendingPing(ILogger logger, Exception exception)
{
_failedSendingPing(logger, exception);
}

public static void ReceivedServiceErrorMessage(ILogger logger, string connectionId, string errorMessage)
{
_receivedServiceErrorMessage(logger, connectionId, errorMessage, null);
}
}
}
}
20 changes: 20 additions & 0 deletions src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,24 @@ public class PingMessage : ServiceMessage
/// </summary>
public static PingMessage Instance = new PingMessage();
}

/// <summary>
/// A service error message
/// </summary>
public class ServiceErrorMessage : ServiceMessage
{
/// <summary>
/// Gets or sets the error message
/// </summary>
public string ErrorMessage { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="ServiceErrorMessage"/> class.
/// </summary>
/// <param name="errorMessage">An error message.</param>
public ServiceErrorMessage(string errorMessage)
{
ErrorMessage = errorMessage;
}
}
}
19 changes: 19 additions & 0 deletions src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ private static ServiceMessage ParseMessage(byte[] input, int startOffset)
return CreateGroupBroadcastDataMessage(input, ref startOffset);
case ServiceProtocolConstants.MultiGroupBroadcastDataMessageType:
return CreateMultiGroupBroadcastDataMessage(input, ref startOffset);
case ServiceProtocolConstants.ServiceErrorMessageType:
return CreateServiceErrorMessage(input, ref startOffset);
default:
// Future protocol changes can add message types, old clients can ignore them
return null;
Expand Down Expand Up @@ -190,6 +192,9 @@ private static void WriteMessageCore(ServiceMessage message, Stream packer)
case MultiGroupBroadcastDataMessage multiGroupBroadcastDataMessage:
WriteMultiGroupBroadcastDataMessage(multiGroupBroadcastDataMessage, packer);
break;
case ServiceErrorMessage serviceErrorMessage:
WriteServiceErrorMessage(serviceErrorMessage, packer);
break;
default:
throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
}
Expand Down Expand Up @@ -357,6 +362,13 @@ private static void WriteMultiGroupBroadcastDataMessage(MultiGroupBroadcastDataM
WritePayloads(message.Payloads, packer);
}

private static void WriteServiceErrorMessage(ServiceErrorMessage message, Stream packer)
{
MessagePackBinary.WriteArrayHeader(packer,2);
MessagePackBinary.WriteInt32(packer, ServiceProtocolConstants.ServiceErrorMessageType);
MessagePackBinary.WriteString(packer, message.ErrorMessage);
}

private static void WriteStringArray(IReadOnlyList<string> array, Stream packer)
{
if (array?.Count > 0)
Expand Down Expand Up @@ -527,6 +539,13 @@ private static MultiGroupBroadcastDataMessage CreateMultiGroupBroadcastDataMessa
return new MultiGroupBroadcastDataMessage(groupList, payloads);
}

private static ServiceErrorMessage CreateServiceErrorMessage(byte[] input, ref int offset)
{
var errorMessage = ReadString(input, ref offset, "errorMessage");

return new ServiceErrorMessage(errorMessage);
}

private static Claim[] ReadClaims(byte[] input, ref int offset)
{
var claimCount = ReadMapLength(input, ref offset, "claims");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ public static class ServiceProtocolConstants
public const int LeaveGroupMessageType = 12;
public const int GroupBroadcastDataMessageType = 13;
public const int MultiGroupBroadcastDataMessageType = 14;
public const int ServiceErrorMessageType = 15;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public bool Equals(ServiceMessage x, ServiceMessage y)
case MultiGroupBroadcastDataMessage multiGroupBroadcastDataMessage:
return MultiGroupBroadcastDataMessagesEqual(multiGroupBroadcastDataMessage,
(MultiGroupBroadcastDataMessage)y);
case ServiceErrorMessage serviceErrorMessage:
return ServiceErrorMessageEqual(serviceErrorMessage, (ServiceErrorMessage)y);
default:
throw new InvalidOperationException($"Unknown message type: {x.GetType().FullName}");
}
Expand Down Expand Up @@ -135,6 +137,11 @@ private bool MultiGroupBroadcastDataMessagesEqual(MultiGroupBroadcastDataMessage
return SequenceEqual(x.GroupList, y.GroupList) && PayloadsEqual(x.Payloads, y.Payloads);
}

private bool ServiceErrorMessageEqual(ServiceErrorMessage x, ServiceErrorMessage y)
{
return StringEqual(x.ErrorMessage, y.ErrorMessage);
}

private static bool StringEqual(string x, string y)
{
return string.Equals(x, y, StringComparison.Ordinal);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ public static IEnumerable<object[]> TestDataNames
["messagepack"] = new byte[] {7, 8, 1, 2, 3, 4, 5, 6}
}),
binary: "kw6Spmdyb3VwNKZncm91cDWCpGpzb27ECAECAwQFBgcIq21lc3NhZ2VwYWNrxAgHCAECAwQFBg=="),
new ProtocolTestData(
name: "ServiceError",
message: new ServiceErrorMessage("Maximum message count limit reached: 100000"),
binary: "kg/ZK01heGltdW0gbWVzc2FnZSBjb3VudCBsaW1pdCByZWFjaGVkOiAxMDAwMDA="),
}.ToDictionary(t => t.Name);

[Theory]
Expand Down
20 changes: 20 additions & 0 deletions test/Microsoft.Azure.SignalR.Tests/ServiceConnectionFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,26 @@ public async Task ClosingConnectionSendsCloseMessage()
proxy.Stop();
}

[Fact]
public async Task ThrowingExceptionAfterServiceErrorMessage()
{
var proxy = new ServiceConnectionProxy();

var serverTask = proxy.WaitForServerConnectionAsync(1);
_ = proxy.StartAsync();
await serverTask.OrTimeout();

string errorMessage = "Maximum message count limit reached: 100000";

await proxy.WriteMessageAsync(new ServiceErrorMessage(errorMessage));
await Task.Delay(200);

var serviceConnection = proxy.ServiceConnection;
var excption = await Assert.ThrowsAsync<InvalidOperationException>(() =>
serviceConnection.WriteAsync(new ConnectionDataMessage("1", null)));
Assert.Equal(errorMessage, excption.Message);
}

[Fact]
public async Task WritingMessagesFromConnectionGetsSentAsConnectionData()
{
Expand Down