Skip to content

Commit 5032c59

Browse files
authored
Add ServerCloseMessage to protocol (#236)
* Add server close protocol * Change close message to error message * Add tests * Fix case and message problem * Throw exception in WriteAsync and add UT * Keep writespace contract * Rename to ServiceErrorMessage * Rename protected property * Update error log and use local variable to prevent multiple thread issues
1 parent 25bcc7b commit 5032c59

File tree

7 files changed

+105
-0
lines changed

7 files changed

+105
-0
lines changed

src/Microsoft.Azure.SignalR.Common/ServiceConnectionBase.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ internal abstract class ServiceConnectionBase : IServiceConnection
4141
private bool _isStopped;
4242
private long _lastReceiveTimestamp;
4343
protected ConnectionContext _connection;
44+
protected string ErrorMessage;
4445

4546
public Task WaitForConnectionStart => _serviceConnectionStartTcs.Task;
4647

@@ -85,6 +86,13 @@ public async virtual Task WriteAsync(ServiceMessage serviceMessage)
8586
// The lock is per serviceConnection
8687
await _serviceConnectionLock.WaitAsync();
8788

89+
var errorMessage = ErrorMessage;
90+
if (!string.IsNullOrEmpty(errorMessage))
91+
{
92+
_serviceConnectionLock.Release();
93+
throw new InvalidOperationException(errorMessage);
94+
}
95+
8896
if (_connection == null)
8997
{
9098
_serviceConnectionLock.Release();
@@ -119,6 +127,21 @@ public async virtual Task WriteAsync(ServiceMessage serviceMessage)
119127

120128
protected abstract Task OnMessageAsync(ConnectionDataMessage connectionDataMessage);
121129

130+
protected Task OnServiceErrorAsync(ServiceErrorMessage serviceErrorMessage)
131+
{
132+
if (!string.IsNullOrEmpty(serviceErrorMessage.ErrorMessage))
133+
{
134+
// When receives service error message, we suppose server -> service connection doesn't work,
135+
// and set ErrorMessage to prevent sending message from server to service
136+
// But messages in the pipe from service -> server should be processed as usual. Just log without
137+
// throw exception here.
138+
ErrorMessage = serviceErrorMessage.ErrorMessage;
139+
Log.ReceivedServiceErrorMessage(_logger, _connectionId, serviceErrorMessage.ErrorMessage);
140+
}
141+
142+
return Task.CompletedTask;
143+
}
144+
122145
private async Task<bool> StartAsyncCore()
123146
{
124147
// Always try until connected
@@ -130,6 +153,7 @@ private async Task<bool> StartAsyncCore()
130153
try
131154
{
132155
_connection = await CreateConnection();
156+
ErrorMessage = null;
133157

134158
if (await HandshakeAsync())
135159
{
@@ -343,6 +367,8 @@ private Task DispatchMessageAsync(ServiceMessage message)
343367
return OnDisconnectedAsync(closeConnectionMessage);
344368
case ConnectionDataMessage connectionDataMessage:
345369
return OnMessageAsync(connectionDataMessage);
370+
case ServiceErrorMessage serviceErrorMessage:
371+
return OnServiceErrorAsync(serviceErrorMessage);
346372
case PingMessage _:
347373
// ignore ping
348374
break;
@@ -514,6 +540,9 @@ private static class Log
514540
private static readonly Action<ILogger, Exception> _failedSendingPing =
515541
LoggerMessage.Define(LogLevel.Warning, new EventId(26, "FailedSendingPing"), "Failed sending a ping message to service.");
516542

543+
private static readonly Action<ILogger, string, string, Exception> _receivedServiceErrorMessage =
544+
LoggerMessage.Define<string, string>(LogLevel.Warning, new EventId(27, "ReceivedServiceErrorMessage"), "Connection {ServiceConnectionId} received error message from service: {Error}");
545+
517546
public static void FailedToWrite(ILogger logger, Exception exception)
518547
{
519548
_failedToWrite(logger, exception);
@@ -643,6 +672,11 @@ public static void FailedSendingPing(ILogger logger, Exception exception)
643672
{
644673
_failedSendingPing(logger, exception);
645674
}
675+
676+
public static void ReceivedServiceErrorMessage(ILogger logger, string connectionId, string errorMessage)
677+
{
678+
_receivedServiceErrorMessage(logger, connectionId, errorMessage, null);
679+
}
646680
}
647681
}
648682
}

src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,24 @@ public class PingMessage : ServiceMessage
6767
/// </summary>
6868
public static PingMessage Instance = new PingMessage();
6969
}
70+
71+
/// <summary>
72+
/// A service error message
73+
/// </summary>
74+
public class ServiceErrorMessage : ServiceMessage
75+
{
76+
/// <summary>
77+
/// Gets or sets the error message
78+
/// </summary>
79+
public string ErrorMessage { get; set; }
80+
81+
/// <summary>
82+
/// Initializes a new instance of the <see cref="ServiceErrorMessage"/> class.
83+
/// </summary>
84+
/// <param name="errorMessage">An error message.</param>
85+
public ServiceErrorMessage(string errorMessage)
86+
{
87+
ErrorMessage = errorMessage;
88+
}
89+
}
7090
}

src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ private static ServiceMessage ParseMessage(byte[] input, int startOffset)
8989
return CreateGroupBroadcastDataMessage(input, ref startOffset);
9090
case ServiceProtocolConstants.MultiGroupBroadcastDataMessageType:
9191
return CreateMultiGroupBroadcastDataMessage(input, ref startOffset);
92+
case ServiceProtocolConstants.ServiceErrorMessageType:
93+
return CreateServiceErrorMessage(input, ref startOffset);
9294
default:
9395
// Future protocol changes can add message types, old clients can ignore them
9496
return null;
@@ -190,6 +192,9 @@ private static void WriteMessageCore(ServiceMessage message, Stream packer)
190192
case MultiGroupBroadcastDataMessage multiGroupBroadcastDataMessage:
191193
WriteMultiGroupBroadcastDataMessage(multiGroupBroadcastDataMessage, packer);
192194
break;
195+
case ServiceErrorMessage serviceErrorMessage:
196+
WriteServiceErrorMessage(serviceErrorMessage, packer);
197+
break;
193198
default:
194199
throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
195200
}
@@ -357,6 +362,13 @@ private static void WriteMultiGroupBroadcastDataMessage(MultiGroupBroadcastDataM
357362
WritePayloads(message.Payloads, packer);
358363
}
359364

365+
private static void WriteServiceErrorMessage(ServiceErrorMessage message, Stream packer)
366+
{
367+
MessagePackBinary.WriteArrayHeader(packer,2);
368+
MessagePackBinary.WriteInt32(packer, ServiceProtocolConstants.ServiceErrorMessageType);
369+
MessagePackBinary.WriteString(packer, message.ErrorMessage);
370+
}
371+
360372
private static void WriteStringArray(IReadOnlyList<string> array, Stream packer)
361373
{
362374
if (array?.Count > 0)
@@ -527,6 +539,13 @@ private static MultiGroupBroadcastDataMessage CreateMultiGroupBroadcastDataMessa
527539
return new MultiGroupBroadcastDataMessage(groupList, payloads);
528540
}
529541

542+
private static ServiceErrorMessage CreateServiceErrorMessage(byte[] input, ref int offset)
543+
{
544+
var errorMessage = ReadString(input, ref offset, "errorMessage");
545+
546+
return new ServiceErrorMessage(errorMessage);
547+
}
548+
530549
private static Claim[] ReadClaims(byte[] input, ref int offset)
531550
{
532551
var claimCount = ReadMapLength(input, ref offset, "claims");

src/Microsoft.Azure.SignalR.Protocols/ServiceProtocolConstants.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ public static class ServiceProtocolConstants
1919
public const int LeaveGroupMessageType = 12;
2020
public const int GroupBroadcastDataMessageType = 13;
2121
public const int MultiGroupBroadcastDataMessageType = 14;
22+
public const int ServiceErrorMessageType = 15;
2223
}
2324
}

test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceMessageEqualityComparer.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ public bool Equals(ServiceMessage x, ServiceMessage y)
5353
case MultiGroupBroadcastDataMessage multiGroupBroadcastDataMessage:
5454
return MultiGroupBroadcastDataMessagesEqual(multiGroupBroadcastDataMessage,
5555
(MultiGroupBroadcastDataMessage)y);
56+
case ServiceErrorMessage serviceErrorMessage:
57+
return ServiceErrorMessageEqual(serviceErrorMessage, (ServiceErrorMessage)y);
5658
default:
5759
throw new InvalidOperationException($"Unknown message type: {x.GetType().FullName}");
5860
}
@@ -135,6 +137,11 @@ private bool MultiGroupBroadcastDataMessagesEqual(MultiGroupBroadcastDataMessage
135137
return SequenceEqual(x.GroupList, y.GroupList) && PayloadsEqual(x.Payloads, y.Payloads);
136138
}
137139

140+
private bool ServiceErrorMessageEqual(ServiceErrorMessage x, ServiceErrorMessage y)
141+
{
142+
return StringEqual(x.ErrorMessage, y.ErrorMessage);
143+
}
144+
138145
private static bool StringEqual(string x, string y)
139146
{
140147
return string.Equals(x, y, StringComparison.Ordinal);

test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ public static IEnumerable<object[]> TestDataNames
159159
["messagepack"] = new byte[] {7, 8, 1, 2, 3, 4, 5, 6}
160160
}),
161161
binary: "kw6Spmdyb3VwNKZncm91cDWCpGpzb27ECAECAwQFBgcIq21lc3NhZ2VwYWNrxAgHCAECAwQFBg=="),
162+
new ProtocolTestData(
163+
name: "ServiceError",
164+
message: new ServiceErrorMessage("Maximum message count limit reached: 100000"),
165+
binary: "kg/ZK01heGltdW0gbWVzc2FnZSBjb3VudCBsaW1pdCByZWFjaGVkOiAxMDAwMDA="),
162166
}.ToDictionary(t => t.Name);
163167

164168
[Theory]

test/Microsoft.Azure.SignalR.Tests/ServiceConnectionFacts.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,26 @@ public async Task ClosingConnectionSendsCloseMessage()
140140
proxy.Stop();
141141
}
142142

143+
[Fact]
144+
public async Task ThrowingExceptionAfterServiceErrorMessage()
145+
{
146+
var proxy = new ServiceConnectionProxy();
147+
148+
var serverTask = proxy.WaitForServerConnectionAsync(1);
149+
_ = proxy.StartAsync();
150+
await serverTask.OrTimeout();
151+
152+
string errorMessage = "Maximum message count limit reached: 100000";
153+
154+
await proxy.WriteMessageAsync(new ServiceErrorMessage(errorMessage));
155+
await Task.Delay(200);
156+
157+
var serviceConnection = proxy.ServiceConnection;
158+
var excption = await Assert.ThrowsAsync<InvalidOperationException>(() =>
159+
serviceConnection.WriteAsync(new ConnectionDataMessage("1", null)));
160+
Assert.Equal(errorMessage, excption.Message);
161+
}
162+
143163
[Fact]
144164
public async Task WritingMessagesFromConnectionGetsSentAsConnectionData()
145165
{

0 commit comments

Comments
 (0)