Skip to content

Commit aa1e305

Browse files
authored
Support for partial messages for stateful reconnect. (#1930)
support for partial messages for stateful reconnect.
1 parent af65763 commit aa1e305

File tree

3 files changed

+279
-7
lines changed

3 files changed

+279
-7
lines changed

src/Microsoft.Azure.SignalR/Microsoft.Azure.SignalR.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
<ItemGroup>
1919
<Compile Include="..\Common\MemoryBufferWriter.cs" Link="Internals\MemoryBufferWriter.cs" />
20+
<Compile Include="..\Microsoft.Azure.SignalR.Emulator\Common\ExactSizeMemoryPool.cs" Link="Internals\ExactSizeMemoryPool.cs" />
2021
</ItemGroup>
2122

2223
<!-- Copy the referenced Microsoft.Azure.SignalR.Common.dll into the package https://github.com/nuget/home/issues/3891#issuecomment-397481361 -->

src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
5+
using System.Buffers;
56
using System.Collections.Concurrent;
67
using System.Collections.Generic;
78
using System.Linq;
@@ -49,6 +50,10 @@ internal partial class ServiceConnection : ServiceConnectionBase
4950

5051
private readonly AckHandler _ackHandler;
5152

53+
// Performance: Do not use ConcurrentDictionary. There is no multi-threading scenario here, all operations are in the same logical thread.
54+
private readonly Dictionary<string, List<IMemoryOwner<byte>>> _bufferingMessages =
55+
new Dictionary<string, List<IMemoryOwner<byte>>>(StringComparer.Ordinal);
56+
5257
public Action<HttpContext> ConfigureContext { get; set; }
5358

5459
public ServiceConnection(IServiceProtocol serviceProtocol,
@@ -99,6 +104,8 @@ protected override Task CleanupClientConnections(string fromInstanceId = null)
99104
continue;
100105
}
101106

107+
// make sure there is no await operation before _bufferingMessages.
108+
_bufferingMessages.Remove(connection.Key);
102109
// We should not wait until all the clients' lifetime ends to restart another service connection
103110
_ = PerformDisconnectAsyncCore(connection.Key);
104111
}
@@ -157,6 +164,8 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
157164
protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeConnectionMessage)
158165
{
159166
var connectionId = closeConnectionMessage.ConnectionId;
167+
// make sure there is no await operation before _bufferingMessages.
168+
_bufferingMessages.Remove(connectionId);
160169
if (_clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var context))
161170
{
162171
if (closeConnectionMessage.Headers.TryGetValue(Constants.AsrsMigrateTo, out var to))
@@ -186,9 +195,45 @@ protected override async Task OnClientMessageAsync(ConnectionDataMessage connect
186195
{
187196
try
188197
{
189-
var payload = connectionDataMessage.Payload;
190-
Log.WriteMessageToApplication(Logger, payload.Length, connectionDataMessage.ConnectionId);
191-
await connection.WriteMessageAsync(payload);
198+
if (connectionDataMessage.IsPartial)
199+
{
200+
var owner = ExactSizeMemoryPool.Shared.Rent((int)connectionDataMessage.Payload.Length);
201+
connectionDataMessage.Payload.CopyTo(owner.Memory.Span);
202+
// make sure there is no await operation before _bufferingMessages.
203+
if (!_bufferingMessages.TryGetValue(connectionDataMessage.ConnectionId, out var list))
204+
{
205+
list = new List<IMemoryOwner<byte>>();
206+
_bufferingMessages[connectionDataMessage.ConnectionId] = list;
207+
}
208+
list.Add(owner);
209+
}
210+
else
211+
{
212+
// make sure there is no await operation before _bufferingMessages.
213+
if (_bufferingMessages.TryGetValue(connectionDataMessage.ConnectionId, out var list))
214+
{
215+
_bufferingMessages.Remove(connectionDataMessage.ConnectionId);
216+
long length = 0;
217+
foreach (var owner in list)
218+
{
219+
using (owner)
220+
{
221+
await connection.WriteMessageAsync(new ReadOnlySequence<byte>(owner.Memory));
222+
length += owner.Memory.Length;
223+
}
224+
}
225+
var payload = connectionDataMessage.Payload;
226+
length += payload.Length;
227+
Log.WriteMessageToApplication(Logger, length, connectionDataMessage.ConnectionId);
228+
await connection.WriteMessageAsync(payload);
229+
}
230+
else
231+
{
232+
var payload = connectionDataMessage.Payload;
233+
Log.WriteMessageToApplication(Logger, payload.Length, connectionDataMessage.ConnectionId);
234+
await connection.WriteMessageAsync(payload);
235+
}
236+
}
192237
}
193238
catch (Exception ex)
194239
{
@@ -211,6 +256,7 @@ protected override Task DispatchMessageAsync(ServiceMessage message)
211256
ServiceMappingMessage serviceMappingMessage => OnServiceMappingAsync(serviceMappingMessage),
212257
ClientCompletionMessage clientCompletionMessage => OnClientCompletionAsync(clientCompletionMessage),
213258
ErrorCompletionMessage errorCompletionMessage => OnErrorCompletionAsync(errorCompletionMessage),
259+
ConnectionReconnectMessage connectionReconnectMessage => OnConnectionReconnectAsync(connectionReconnectMessage),
214260
_ => base.DispatchMessageAsync(message)
215261
};
216262
}
@@ -510,5 +556,12 @@ private Task OnErrorCompletionAsync(ErrorCompletionMessage errorCompletionMessag
510556
_clientInvocationManager.Caller.TryCompleteResult(errorCompletionMessage.ConnectionId, errorCompletionMessage);
511557
return Task.CompletedTask;
512558
}
559+
560+
private Task OnConnectionReconnectAsync(ConnectionReconnectMessage connectionReconnectMessage)
561+
{
562+
// make sure there is no await operation before _bufferingMessages.
563+
_bufferingMessages.Remove(connectionReconnectMessage.ConnectionId);
564+
return Task.CompletedTask;
565+
}
513566
}
514567
}

test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs

Lines changed: 222 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ await transportConnection.Application.Output.WriteAsync(
253253
}
254254

255255
[Fact]
256-
public async Task ClientConnectionOutgoingAbortCanEndLifeTime()
256+
public async Task TestClientConnectionOutgoingAbortCanEndLifeTime()
257257
{
258258
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
259259
logChecker: logs =>
@@ -310,7 +310,7 @@ await transportConnection.Application.Output.WriteAsync(
310310
}
311311

312312
[Fact]
313-
public async Task ClientConnectionContextAbortCanSendOutCloseMessage()
313+
public async Task TestClientConnectionContextAbortCanSendOutCloseMessage()
314314
{
315315
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
316316
logChecker: logs =>
@@ -377,7 +377,7 @@ await transportConnection.Application.Output.WriteAsync(
377377
}
378378

379379
[Fact]
380-
public async Task ClientConnectionWithDiagnosticClientTagTest()
380+
public async Task TestClientConnectionWithDiagnosticClientTagTest()
381381
{
382382
using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug))
383383
{
@@ -434,7 +434,7 @@ await transportConnection.Application.Output.WriteAsync(
434434
}
435435

436436
[Fact]
437-
public async Task ClientConnectionLastWillCanSendOut()
437+
public async Task TestClientConnectionLastWillCanSendOut()
438438
{
439439
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
440440
logChecker: logs =>
@@ -489,6 +489,166 @@ await transportConnection.Application.Output.WriteAsync(
489489
}
490490
}
491491

492+
[Fact]
493+
public async Task TestPartialMessagesShouldFlushCorrectly()
494+
{
495+
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
496+
logChecker: logs =>
497+
{
498+
Assert.Empty(logs);
499+
return true;
500+
}))
501+
{
502+
var ccm = new TestClientConnectionManager();
503+
var ccf = new ClientConnectionFactory();
504+
var protocol = new ServiceProtocol();
505+
TestConnection transportConnection = null;
506+
var connectionFactory = new TestConnectionFactory(conn =>
507+
{
508+
transportConnection = conn;
509+
return Task.CompletedTask;
510+
});
511+
var services = new ServiceCollection();
512+
513+
var connectionHandler = new TextContentConnectionHandler();
514+
services.AddSingleton(connectionHandler);
515+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
516+
builder.UseConnectionHandler<TextContentConnectionHandler>();
517+
ConnectionDelegate handler = builder.Build();
518+
var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf,
519+
"serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500);
520+
521+
var connectionTask = connection.StartAsync();
522+
523+
// completed handshake
524+
await connection.ConnectionInitializedTask.OrTimeout();
525+
Assert.Equal(ServiceConnectionStatus.Connected, connection.Status);
526+
var clientConnectionId = Guid.NewGuid().ToString();
527+
528+
var waitClientTask = ccm.WaitForClientConnectionAsync(clientConnectionId);
529+
await transportConnection.Application.Output.WriteAsync(
530+
protocol.GetMessageBytes(new OpenConnectionMessage(clientConnectionId, new Claim[] { })));
531+
532+
var clientConnection = await waitClientTask.OrTimeout();
533+
534+
var enumerator = connectionHandler.EnumerateContent().GetAsyncEnumerator();
535+
536+
// for normal message, it should flush immediately.
537+
await transportConnection.Application.Output.WriteAsync(
538+
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"protocol\":\"json\",\"version\":1}\u001e"))));
539+
await enumerator.MoveNextAsync();
540+
Assert.Equal("{\"protocol\":\"json\",\"version\":1}\u001e", enumerator.Current);
541+
542+
// for partial message, it should wait for next message to complete.
543+
await transportConnection.Application.Output.WriteAsync(
544+
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"type\":1,")) { IsPartial = true }));
545+
var delay = Task.Delay(100);
546+
var moveNextTask = enumerator.MoveNextAsync().AsTask();
547+
Assert.Same(delay, await Task.WhenAny(delay, moveNextTask));
548+
549+
// when next message comes, it should complete the previous partial message.
550+
await transportConnection.Application.Output.WriteAsync(
551+
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("\"target\":\"method\"}\u001e"))));
552+
await moveNextTask;
553+
if (enumerator.Current == "{\"type\":1,")
554+
{
555+
Assert.Equal("{\"type\":1,", enumerator.Current);
556+
await enumerator.MoveNextAsync();
557+
Assert.Equal("\"target\":\"method\"}\u001e", enumerator.Current);
558+
}
559+
else
560+
{
561+
// maybe merged into one message.
562+
Assert.Equal("{\"type\":1,\"target\":\"method\"}\u001e", enumerator.Current);
563+
}
564+
565+
// complete reading to end the connection
566+
transportConnection.Application.Output.Complete();
567+
568+
await clientConnection.LifetimeTask.OrTimeout();
569+
570+
// 1s for application task to timeout
571+
await connectionTask.OrTimeout(1000);
572+
Assert.Equal(ServiceConnectionStatus.Disconnected, connection.Status);
573+
Assert.Empty(ccm.ClientConnections);
574+
}
575+
}
576+
577+
[Fact]
578+
public async Task TestPartialMessagesShouldBeRemovedWhenReconnected()
579+
{
580+
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
581+
logChecker: logs =>
582+
{
583+
Assert.Empty(logs);
584+
return true;
585+
}))
586+
{
587+
var ccm = new TestClientConnectionManager();
588+
var ccf = new ClientConnectionFactory();
589+
var protocol = new ServiceProtocol();
590+
TestConnection transportConnection = null;
591+
var connectionFactory = new TestConnectionFactory(conn =>
592+
{
593+
transportConnection = conn;
594+
return Task.CompletedTask;
595+
});
596+
var services = new ServiceCollection();
597+
598+
var connectionHandler = new TextContentConnectionHandler();
599+
services.AddSingleton(connectionHandler);
600+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
601+
builder.UseConnectionHandler<TextContentConnectionHandler>();
602+
ConnectionDelegate handler = builder.Build();
603+
var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf,
604+
"serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500);
605+
606+
var connectionTask = connection.StartAsync();
607+
608+
// completed handshake
609+
await connection.ConnectionInitializedTask.OrTimeout();
610+
Assert.Equal(ServiceConnectionStatus.Connected, connection.Status);
611+
var clientConnectionId = Guid.NewGuid().ToString();
612+
613+
var waitClientTask = ccm.WaitForClientConnectionAsync(clientConnectionId);
614+
await transportConnection.Application.Output.WriteAsync(
615+
protocol.GetMessageBytes(new OpenConnectionMessage(clientConnectionId, new Claim[] { })));
616+
617+
var clientConnection = await waitClientTask.OrTimeout();
618+
619+
var enumerator = connectionHandler.EnumerateContent().GetAsyncEnumerator();
620+
621+
// for partial message, it should wait for next message to complete.
622+
await transportConnection.Application.Output.WriteAsync(
623+
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"type\":1,")) { IsPartial = true }));
624+
var delay = Task.Delay(100);
625+
var moveNextTask = enumerator.MoveNextAsync().AsTask();
626+
Assert.Same(delay, await Task.WhenAny(delay, moveNextTask));
627+
628+
// for reconnect message, it should remove all partial messages for the connection.
629+
await transportConnection.Application.Output.WriteAsync(
630+
protocol.GetMessageBytes(new ConnectionReconnectMessage(clientConnectionId)));
631+
delay = Task.Delay(100);
632+
Assert.Same(delay, await Task.WhenAny(delay, moveNextTask));
633+
634+
// when next message comes, there is no partial message.
635+
await transportConnection.Application.Output.WriteAsync(
636+
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"type\":1,\"target\":\"method\"}\u001e"))));
637+
await moveNextTask;
638+
Assert.Equal("{\"type\":1,\"target\":\"method\"}\u001e", enumerator.Current);
639+
640+
// complete reading to end the connection
641+
transportConnection.Application.Output.Complete();
642+
643+
await clientConnection.LifetimeTask.OrTimeout();
644+
645+
// 1s for application task to timeout
646+
await connectionTask.OrTimeout(1000);
647+
Assert.Equal(ServiceConnectionStatus.Disconnected, connection.Status);
648+
Assert.Empty(ccm.ClientConnections);
649+
}
650+
}
651+
492652
private sealed class TestConnectionHandler : ConnectionHandler
493653
{
494654
private TaskCompletionSource<object> _startedTcs = new TaskCompletionSource<object>();
@@ -587,6 +747,64 @@ public override Task OnConnectedAsync(ConnectionContext connection)
587747
}
588748
}
589749

750+
private sealed class TextContentConnectionHandler : ConnectionHandler
751+
{
752+
private TaskCompletionSource<object> _startedTcs = new TaskCompletionSource<object>();
753+
private LinkedList<TaskCompletionSource<string>> _content = new LinkedList<TaskCompletionSource<string>>();
754+
755+
public TextContentConnectionHandler()
756+
{
757+
_content.AddFirst(new TaskCompletionSource<string>());
758+
}
759+
760+
public Task Started => _startedTcs.Task;
761+
762+
public override async Task OnConnectedAsync(ConnectionContext connection)
763+
{
764+
_startedTcs.TrySetResult(null);
765+
766+
while (true)
767+
{
768+
var result = await connection.Transport.Input.ReadAsync();
769+
770+
try
771+
{
772+
if (!result.Buffer.IsEmpty)
773+
{
774+
var last = _content.Last.Value;
775+
_content.AddLast(new TaskCompletionSource<string>());
776+
var text = Encoding.UTF8.GetString(result.Buffer);
777+
last.TrySetResult(text);
778+
}
779+
780+
if (result.IsCompleted)
781+
{
782+
_content.Last.Value.TrySetResult(null);
783+
break;
784+
}
785+
}
786+
finally
787+
{
788+
connection.Transport.Input.AdvanceTo(result.Buffer.End);
789+
}
790+
}
791+
}
792+
793+
public async IAsyncEnumerable<string> EnumerateContent()
794+
{
795+
while (true)
796+
{
797+
var result = await _content.First.Value.Task;
798+
_content.RemoveFirst();
799+
if (result == null)
800+
{
801+
yield break;
802+
}
803+
yield return result;
804+
}
805+
}
806+
}
807+
590808
internal sealed class TestClientConnectionManager : IClientConnectionManager
591809
{
592810
private readonly ClientConnectionManager _ccm = new ClientConnectionManager();

0 commit comments

Comments
 (0)