Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.Azure.SignalR/Microsoft.Azure.SignalR.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

<ItemGroup>
<Compile Include="..\Common\MemoryBufferWriter.cs" Link="Internals\MemoryBufferWriter.cs" />
<Compile Include="..\Microsoft.Azure.SignalR.Emulator\Common\ExactSizeMemoryPool.cs" Link="Internals\ExactSizeMemoryPool.cs" />
</ItemGroup>

<!-- Copy the referenced Microsoft.Azure.SignalR.Common.dll into the package https://github.com/nuget/home/issues/3891#issuecomment-397481361 -->
Expand Down
59 changes: 56 additions & 3 deletions src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
Expand Down Expand Up @@ -49,6 +50,10 @@ internal partial class ServiceConnection : ServiceConnectionBase

private readonly AckHandler _ackHandler;

// Performance: Do not use ConcurrentDictionary. There is no multi-threading scenario here, all operations are in the same logical thread.
private readonly Dictionary<string, List<IMemoryOwner<byte>>> _bufferingMessages =
new Dictionary<string, List<IMemoryOwner<byte>>>(StringComparer.Ordinal);

public Action<HttpContext> ConfigureContext { get; set; }

public ServiceConnection(IServiceProtocol serviceProtocol,
Expand Down Expand Up @@ -99,6 +104,8 @@ protected override Task CleanupClientConnections(string fromInstanceId = null)
continue;
}

// make sure there is no await operation before _bufferingMessages.
_bufferingMessages.Remove(connection.Key);
// We should not wait until all the clients' lifetime ends to restart another service connection
_ = PerformDisconnectAsyncCore(connection.Key);
}
Expand Down Expand Up @@ -157,6 +164,8 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeConnectionMessage)
{
var connectionId = closeConnectionMessage.ConnectionId;
// make sure there is no await operation before _bufferingMessages.
_bufferingMessages.Remove(connectionId);
if (_clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var context))
{
if (closeConnectionMessage.Headers.TryGetValue(Constants.AsrsMigrateTo, out var to))
Expand Down Expand Up @@ -186,9 +195,45 @@ protected override async Task OnClientMessageAsync(ConnectionDataMessage connect
{
try
{
var payload = connectionDataMessage.Payload;
Log.WriteMessageToApplication(Logger, payload.Length, connectionDataMessage.ConnectionId);
await connection.WriteMessageAsync(payload);
if (connectionDataMessage.IsPartial)
{
var owner = ExactSizeMemoryPool.Shared.Rent((int)connectionDataMessage.Payload.Length);
connectionDataMessage.Payload.CopyTo(owner.Memory.Span);
// make sure there is no await operation before _bufferingMessages.
if (!_bufferingMessages.TryGetValue(connectionDataMessage.ConnectionId, out var list))
{
list = new List<IMemoryOwner<byte>>();
_bufferingMessages[connectionDataMessage.ConnectionId] = list;
}
list.Add(owner);
}
else
{
// make sure there is no await operation before _bufferingMessages.
if (_bufferingMessages.TryGetValue(connectionDataMessage.ConnectionId, out var list))
{
_bufferingMessages.Remove(connectionDataMessage.ConnectionId);
long length = 0;
foreach (var owner in list)
{
using (owner)
{
await connection.WriteMessageAsync(new ReadOnlySequence<byte>(owner.Memory));
length += owner.Memory.Length;
}
}
var payload = connectionDataMessage.Payload;
length += payload.Length;
Log.WriteMessageToApplication(Logger, length, connectionDataMessage.ConnectionId);
await connection.WriteMessageAsync(payload);
}
else
{
var payload = connectionDataMessage.Payload;
Log.WriteMessageToApplication(Logger, payload.Length, connectionDataMessage.ConnectionId);
await connection.WriteMessageAsync(payload);
}
}
}
catch (Exception ex)
{
Expand All @@ -211,6 +256,7 @@ protected override Task DispatchMessageAsync(ServiceMessage message)
ServiceMappingMessage serviceMappingMessage => OnServiceMappingAsync(serviceMappingMessage),
ClientCompletionMessage clientCompletionMessage => OnClientCompletionAsync(clientCompletionMessage),
ErrorCompletionMessage errorCompletionMessage => OnErrorCompletionAsync(errorCompletionMessage),
ConnectionReconnectMessage connectionReconnectMessage => OnConnectionReconnectAsync(connectionReconnectMessage),
_ => base.DispatchMessageAsync(message)
};
}
Expand Down Expand Up @@ -510,5 +556,12 @@ private Task OnErrorCompletionAsync(ErrorCompletionMessage errorCompletionMessag
_clientInvocationManager.Caller.TryCompleteResult(errorCompletionMessage.ConnectionId, errorCompletionMessage);
return Task.CompletedTask;
}

private Task OnConnectionReconnectAsync(ConnectionReconnectMessage connectionReconnectMessage)
{
// make sure there is no await operation before _bufferingMessages.
_bufferingMessages.Remove(connectionReconnectMessage.ConnectionId);
return Task.CompletedTask;
}
}
}
226 changes: 222 additions & 4 deletions test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ await transportConnection.Application.Output.WriteAsync(
}

[Fact]
public async Task ClientConnectionOutgoingAbortCanEndLifeTime()
public async Task TestClientConnectionOutgoingAbortCanEndLifeTime()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
logChecker: logs =>
Expand Down Expand Up @@ -310,7 +310,7 @@ await transportConnection.Application.Output.WriteAsync(
}

[Fact]
public async Task ClientConnectionContextAbortCanSendOutCloseMessage()
public async Task TestClientConnectionContextAbortCanSendOutCloseMessage()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
logChecker: logs =>
Expand Down Expand Up @@ -377,7 +377,7 @@ await transportConnection.Application.Output.WriteAsync(
}

[Fact]
public async Task ClientConnectionWithDiagnosticClientTagTest()
public async Task TestClientConnectionWithDiagnosticClientTagTest()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug))
{
Expand Down Expand Up @@ -434,7 +434,7 @@ await transportConnection.Application.Output.WriteAsync(
}

[Fact]
public async Task ClientConnectionLastWillCanSendOut()
public async Task TestClientConnectionLastWillCanSendOut()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
logChecker: logs =>
Expand Down Expand Up @@ -489,6 +489,166 @@ await transportConnection.Application.Output.WriteAsync(
}
}

[Fact]
public async Task TestPartialMessagesShouldFlushCorrectly()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
logChecker: logs =>
{
Assert.Empty(logs);
return true;
}))
{
var ccm = new TestClientConnectionManager();
var ccf = new ClientConnectionFactory();
var protocol = new ServiceProtocol();
TestConnection transportConnection = null;
var connectionFactory = new TestConnectionFactory(conn =>
{
transportConnection = conn;
return Task.CompletedTask;
});
var services = new ServiceCollection();

var connectionHandler = new TextContentConnectionHandler();
services.AddSingleton(connectionHandler);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TextContentConnectionHandler>();
ConnectionDelegate handler = builder.Build();
var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf,
"serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500);

var connectionTask = connection.StartAsync();

// completed handshake
await connection.ConnectionInitializedTask.OrTimeout();
Assert.Equal(ServiceConnectionStatus.Connected, connection.Status);
var clientConnectionId = Guid.NewGuid().ToString();

var waitClientTask = ccm.WaitForClientConnectionAsync(clientConnectionId);
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new OpenConnectionMessage(clientConnectionId, new Claim[] { })));

var clientConnection = await waitClientTask.OrTimeout();

var enumerator = connectionHandler.EnumerateContent().GetAsyncEnumerator();

// for normal message, it should flush immediately.
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"protocol\":\"json\",\"version\":1}\u001e"))));
await enumerator.MoveNextAsync();
Assert.Equal("{\"protocol\":\"json\",\"version\":1}\u001e", enumerator.Current);

// for partial message, it should wait for next message to complete.
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"type\":1,")) { IsPartial = true }));
var delay = Task.Delay(100);
var moveNextTask = enumerator.MoveNextAsync().AsTask();
Assert.Same(delay, await Task.WhenAny(delay, moveNextTask));

// when next message comes, it should complete the previous partial message.
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("\"target\":\"method\"}\u001e"))));
await moveNextTask;
if (enumerator.Current == "{\"type\":1,")
{
Assert.Equal("{\"type\":1,", enumerator.Current);
await enumerator.MoveNextAsync();
Assert.Equal("\"target\":\"method\"}\u001e", enumerator.Current);
}
else
{
// maybe merged into one message.
Assert.Equal("{\"type\":1,\"target\":\"method\"}\u001e", enumerator.Current);
}

// complete reading to end the connection
transportConnection.Application.Output.Complete();

await clientConnection.LifetimeTask.OrTimeout();

// 1s for application task to timeout
await connectionTask.OrTimeout(1000);
Assert.Equal(ServiceConnectionStatus.Disconnected, connection.Status);
Assert.Empty(ccm.ClientConnections);
}
}

[Fact]
public async Task TestPartialMessagesShouldBeRemovedWhenReconnected()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: c => true,
logChecker: logs =>
{
Assert.Empty(logs);
return true;
}))
{
var ccm = new TestClientConnectionManager();
var ccf = new ClientConnectionFactory();
var protocol = new ServiceProtocol();
TestConnection transportConnection = null;
var connectionFactory = new TestConnectionFactory(conn =>
{
transportConnection = conn;
return Task.CompletedTask;
});
var services = new ServiceCollection();

var connectionHandler = new TextContentConnectionHandler();
services.AddSingleton(connectionHandler);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TextContentConnectionHandler>();
ConnectionDelegate handler = builder.Build();
var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf,
"serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500);

var connectionTask = connection.StartAsync();

// completed handshake
await connection.ConnectionInitializedTask.OrTimeout();
Assert.Equal(ServiceConnectionStatus.Connected, connection.Status);
var clientConnectionId = Guid.NewGuid().ToString();

var waitClientTask = ccm.WaitForClientConnectionAsync(clientConnectionId);
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new OpenConnectionMessage(clientConnectionId, new Claim[] { })));

var clientConnection = await waitClientTask.OrTimeout();

var enumerator = connectionHandler.EnumerateContent().GetAsyncEnumerator();

// for partial message, it should wait for next message to complete.
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"type\":1,")) { IsPartial = true }));
var delay = Task.Delay(100);
var moveNextTask = enumerator.MoveNextAsync().AsTask();
Assert.Same(delay, await Task.WhenAny(delay, moveNextTask));

// for reconnect message, it should remove all partial messages for the connection.
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new ConnectionReconnectMessage(clientConnectionId)));
delay = Task.Delay(100);
Assert.Same(delay, await Task.WhenAny(delay, moveNextTask));

// when next message comes, there is no partial message.
await transportConnection.Application.Output.WriteAsync(
protocol.GetMessageBytes(new ConnectionDataMessage(clientConnectionId, Encoding.UTF8.GetBytes("{\"type\":1,\"target\":\"method\"}\u001e"))));
await moveNextTask;
Assert.Equal("{\"type\":1,\"target\":\"method\"}\u001e", enumerator.Current);

// complete reading to end the connection
transportConnection.Application.Output.Complete();

await clientConnection.LifetimeTask.OrTimeout();

// 1s for application task to timeout
await connectionTask.OrTimeout(1000);
Assert.Equal(ServiceConnectionStatus.Disconnected, connection.Status);
Assert.Empty(ccm.ClientConnections);
}
}

private sealed class TestConnectionHandler : ConnectionHandler
{
private TaskCompletionSource<object> _startedTcs = new TaskCompletionSource<object>();
Expand Down Expand Up @@ -587,6 +747,64 @@ public override Task OnConnectedAsync(ConnectionContext connection)
}
}

private sealed class TextContentConnectionHandler : ConnectionHandler
{
private TaskCompletionSource<object> _startedTcs = new TaskCompletionSource<object>();
private LinkedList<TaskCompletionSource<string>> _content = new LinkedList<TaskCompletionSource<string>>();

public TextContentConnectionHandler()
{
_content.AddFirst(new TaskCompletionSource<string>());
}

public Task Started => _startedTcs.Task;

public override async Task OnConnectedAsync(ConnectionContext connection)
{
_startedTcs.TrySetResult(null);

while (true)
{
var result = await connection.Transport.Input.ReadAsync();

try
{
if (!result.Buffer.IsEmpty)
{
var last = _content.Last.Value;
_content.AddLast(new TaskCompletionSource<string>());
var text = Encoding.UTF8.GetString(result.Buffer);
last.TrySetResult(text);
}

if (result.IsCompleted)
{
_content.Last.Value.TrySetResult(null);
break;
}
}
finally
{
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
}

public async IAsyncEnumerable<string> EnumerateContent()
{
while (true)
{
var result = await _content.First.Value.Task;
_content.RemoveFirst();
if (result == null)
{
yield break;
}
yield return result;
}
}
}

internal sealed class TestClientConnectionManager : IClientConnectionManager
{
private readonly ClientConnectionManager _ccm = new ClientConnectionManager();
Expand Down