Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Microsoft.Azure.SignalR.Protocols/.editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

[*.{cs,vb}]

# IDE0060: Remove unused parameter
# Disable for protocol project for now
dotnet_diagnostic.IDE0060.severity = silent
27 changes: 20 additions & 7 deletions src/Microsoft.Azure.SignalR.Protocols/MessagePackUtils.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Security.Claims;

using MessagePack;
Expand Down Expand Up @@ -125,7 +126,7 @@ internal static int ReadInt32(ref MessagePackReader reader, string field)

internal static string ReadStringNotNull(ref MessagePackReader reader, string field)
{
string? result = null;
string? result;
try
{
result = reader.ReadString();
Expand Down Expand Up @@ -156,21 +157,33 @@ internal static string ReadStringNotNull(ref MessagePackReader reader, string fi
}
}

internal static string[] ReadStringArray(ref MessagePackReader reader, string field)
internal static string[] ReadStringArrayExcludeNull(ref MessagePackReader reader, string field)
{
var arrayLength = ReadArrayLength(ref reader, field);
if (arrayLength > 0)
{
var array = new string[arrayLength];
var count = 0;
for (int i = 0; i < arrayLength; i++)
{
var fieldName = $"{field}[{i}]";
array[i] = ReadStringNotNull(ref reader, fieldName);
var val = ReadString(ref reader, fieldName);
if (val != null)
{
array[count] = val;
count++;
}
}

return array;
if (arrayLength == count)
{
return array;
}
else
{
return array.Take(count).ToArray();
}
}

return [];
}

Expand Down Expand Up @@ -222,7 +235,7 @@ internal static long ReadMapLength(ref MessagePackReader reader, string field)
}
}

internal static long ReadArrayLength(ref MessagePackReader reader, string field)
internal static int ReadArrayLength(ref MessagePackReader reader, string field)
{
try
{
Expand Down
18 changes: 9 additions & 9 deletions src/Microsoft.Azure.SignalR.Protocols/ServiceProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ private static CloseConnectionsWithAckMessage CreateCloseConnectionsWithAckMessa
{
var reason = ReadString(ref reader, "reason");
var ackId = ReadInt32(ref reader, "ackId");
var excluded = ReadStringArray(ref reader, "excluded");
var excluded = ReadStringArrayExcludeNull(ref reader, "excluded");

var result = new CloseConnectionsWithAckMessage(ackId)
{
Expand All @@ -992,7 +992,7 @@ private static CloseUserConnectionsWithAckMessage CreateCloseUserConnectionsWith
var userId = ReadStringNotNull(ref reader, "userId");
var reason = ReadString(ref reader, "reason");
var ackId = ReadInt32(ref reader, "ackId");
var excluded = ReadStringArray(ref reader, "excluded");
var excluded = ReadStringArrayExcludeNull(ref reader, "excluded");

var result = new CloseUserConnectionsWithAckMessage(userId, ackId)
{
Expand All @@ -1011,7 +1011,7 @@ private static CloseGroupConnectionsWithAckMessage CreateCloseGroupConnectionsWi
var group = ReadStringNotNull(ref reader, "group");
var reason = ReadString(ref reader, "reason");
var ackId = ReadInt32(ref reader, "ackId");
var excluded = ReadStringArray(ref reader, "excluded");
var excluded = ReadStringArrayExcludeNull(ref reader, "excluded");

var result = new CloseGroupConnectionsWithAckMessage(group, ackId)
{
Expand Down Expand Up @@ -1049,7 +1049,7 @@ private static ConnectionReconnectMessage CreateConnectionReconnectMessage(ref M

private static MultiConnectionDataMessage CreateMultiConnectionDataMessage(ref MessagePackReader reader, int arrayLength)
{
var connectionList = ReadStringArray(ref reader, "connectionList");
var connectionList = ReadStringArrayExcludeNull(ref reader, "connectionList");
var payloads = ReadPayloads(ref reader);

var result = new MultiConnectionDataMessage(connectionList, payloads);
Expand All @@ -1075,7 +1075,7 @@ private static ServiceMessage CreateUserDataMessage(ref MessagePackReader reader

private static MultiUserDataMessage CreateMultiUserDataMessage(ref MessagePackReader reader, int arrayLength)
{
var userList = ReadStringArray(ref reader, "userList");
var userList = ReadStringArrayExcludeNull(ref reader, "userList");
var payloads = ReadPayloads(ref reader);

var result = new MultiUserDataMessage(userList, payloads);
Expand All @@ -1088,7 +1088,7 @@ private static MultiUserDataMessage CreateMultiUserDataMessage(ref MessagePackRe

private static BroadcastDataMessage CreateBroadcastDataMessage(ref MessagePackReader reader, int arrayLength)
{
var excludedList = ReadStringArray(ref reader, "excludedList");
var excludedList = ReadStringArrayExcludeNull(ref reader, "excludedList");
var payloads = ReadPayloads(ref reader);

var result = new BroadcastDataMessage(excludedList, payloads);
Expand Down Expand Up @@ -1176,7 +1176,7 @@ private static UserLeaveGroupWithAckMessage CreateUserLeaveGroupWithAckMessage(r
private static GroupBroadcastDataMessage CreateGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength)
{
var groupName = ReadStringNotNull(ref reader, "groupName");
var excludedList = ReadStringArray(ref reader, "excludedList");
var excludedList = ReadStringArrayExcludeNull(ref reader, "excludedList");
var payloads = ReadPayloads(ref reader);

var result = new GroupBroadcastDataMessage(groupName, excludedList, payloads);
Expand All @@ -1187,7 +1187,7 @@ private static GroupBroadcastDataMessage CreateGroupBroadcastDataMessage(ref Mes

if (arrayLength >= 7)
{
result.ExcludedUserList = ReadStringArray(ref reader, "excludedUserList");
result.ExcludedUserList = ReadStringArrayExcludeNull(ref reader, "excludedUserList");
result.CallerUserId = ReadString(ref reader, "callerUserId");
}

Expand All @@ -1196,7 +1196,7 @@ private static GroupBroadcastDataMessage CreateGroupBroadcastDataMessage(ref Mes

private static MultiGroupBroadcastDataMessage CreateMultiGroupBroadcastDataMessage(ref MessagePackReader reader, int arrayLength)
{
var groupList = ReadStringArray(ref reader, "groupList");
var groupList = ReadStringArrayExcludeNull(ref reader, "groupList");
var payloads = ReadPayloads(ref reader);

var result = new MultiGroupBroadcastDataMessage(groupList, payloads);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
Expand Down Expand Up @@ -171,20 +170,20 @@ private static bool CloseConnectionWithAckMessagesEqual(CloseConnectionWithAckMe

private static bool CloseConnectionsWithAckMessagesEqual(CloseConnectionsWithAckMessage x, CloseConnectionsWithAckMessage y)
{
return SequenceEqual(x.ExcludedList, y.ExcludedList) && StringEqual(x.Reason, y.Reason)
return SequenceEqualSkipNull(x.ExcludedList, y.ExcludedList) && StringEqual(x.Reason, y.Reason)
&& x.TracingId == y.TracingId && x.AckId == y.AckId;
}
#pragma warning restore CS0618 // Type or member is obsolete

private static bool CloseUserConnectionsWithAckMessagesEqual(CloseUserConnectionsWithAckMessage x, CloseUserConnectionsWithAckMessage y)
{
return SequenceEqual(x.ExcludedList, y.ExcludedList) && StringEqual(x.Reason, y.Reason) && StringEqual(x.UserId, y.UserId)
return SequenceEqualSkipNull(x.ExcludedList, y.ExcludedList) && StringEqual(x.Reason, y.Reason) && StringEqual(x.UserId, y.UserId)
&& x.TracingId == y.TracingId && x.AckId == y.AckId;
}

private static bool CloseGroupConnectionsWithAckMessagesEqual(CloseGroupConnectionsWithAckMessage x, CloseGroupConnectionsWithAckMessage y)
{
return SequenceEqual(x.ExcludedList, y.ExcludedList) && StringEqual(x.Reason, y.Reason) && StringEqual(x.GroupName, y.GroupName)
return SequenceEqualSkipNull(x.ExcludedList, y.ExcludedList) && StringEqual(x.Reason, y.Reason) && StringEqual(x.GroupName, y.GroupName)
&& x.TracingId == y.TracingId && x.AckId == y.AckId;
}

Expand All @@ -204,7 +203,7 @@ private static bool ConnectionReconnectMessagesEqual(ConnectionReconnectMessage

private static bool MultiConnectionDataMessagesEqual(MultiConnectionDataMessage x, MultiConnectionDataMessage y)
{
return SequenceEqual(x.ConnectionList, y.ConnectionList) &&
return SequenceEqualSkipNull(x.ConnectionList, y.ConnectionList) &&
PayloadsEqual(x.Payloads, y.Payloads) &&
x.TracingId == y.TracingId;
}
Expand All @@ -218,14 +217,14 @@ private static bool UserDataMessagesEqual(UserDataMessage x, UserDataMessage y)

private static bool MultiUserDataMessagesEqual(MultiUserDataMessage x, MultiUserDataMessage y)
{
return SequenceEqual(x.UserList, y.UserList) &&
return SequenceEqualSkipNull(x.UserList, y.UserList) &&
PayloadsEqual(x.Payloads, y.Payloads) &&
x.TracingId == y.TracingId;
}

private static bool BroadcastDataMessagesEqual(BroadcastDataMessage x, BroadcastDataMessage y)
{
return SequenceEqual(x.ExcludedList, y.ExcludedList) &&
return SequenceEqualSkipNull(x.ExcludedList, y.ExcludedList) &&
PayloadsEqual(x.Payloads, y.Payloads) &&
x.TracingId == y.TracingId;
}
Expand Down Expand Up @@ -280,16 +279,16 @@ private static bool GroupBroadcastDataMessagesEqual(GroupBroadcastDataMessage x,
{
return StringEqual(x.GroupName, y.GroupName) &&
StringEqual(x.CallerUserId, y.CallerUserId) &&
SequenceEqual(x.ExcludedList, y.ExcludedList) &&
SequenceEqual(x.ExcludedUserList, y.ExcludedUserList) &&
SequenceEqualSkipNull(x.ExcludedList, y.ExcludedList) &&
SequenceEqualSkipNull(x.ExcludedUserList, y.ExcludedUserList) &&
PayloadsEqual(x.Payloads, y.Payloads) &&
x.TracingId == y.TracingId;
}

private static bool MultiGroupBroadcastDataMessagesEqual(MultiGroupBroadcastDataMessage x,
MultiGroupBroadcastDataMessage y)
{
return SequenceEqual(x.GroupList, y.GroupList) &&
return SequenceEqualSkipNull(x.GroupList, y.GroupList) &&
PayloadsEqual(x.Payloads, y.Payloads) &&
x.TracingId == y.TracingId;
}
Expand Down Expand Up @@ -487,33 +486,34 @@ private static bool HeadersEqual(IDictionary<string, StringValues> x, IDictionar
return true;
}

private static bool SequenceEqual(object left, object right)
private static bool SequenceEqualSkipNull<T>(IEnumerable<T> leftEnumerable, IEnumerable<T> rightEnumerable)
{
if (left == null && right == null)
if (leftEnumerable == null && rightEnumerable == null)
{
return true;
}

var leftEnumerable = left as IEnumerable;
var rightEnumerable = right as IEnumerable;
if (leftEnumerable == null || rightEnumerable == null)
{
return false;
}

var leftEnumerator = leftEnumerable.GetEnumerator();
var rightEnumerator = rightEnumerable.GetEnumerator();
var leftMoved = leftEnumerator.MoveNext();
var rightMoved = rightEnumerator.MoveNext();
for (; leftMoved && rightMoved; leftMoved = leftEnumerator.MoveNext(), rightMoved = rightEnumerator.MoveNext())
return leftEnumerable.Where(x => x != null).SequenceEqual(rightEnumerable.Where(x => x != null));
}

private static bool SequenceEqual<T>(IEnumerable<T> leftEnumerable, IEnumerable<T> rightEnumerable)
{
if (leftEnumerable == null && rightEnumerable == null)
{
if (!Equals(leftEnumerator.Current, rightEnumerator.Current))
{
return false;
}
return true;
}

if (leftEnumerable == null || rightEnumerable == null)
{
return false;
}

return !leftMoved && !rightMoved;
return leftEnumerable.SequenceEqual(rightEnumerable);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ public static IEnumerable<object[]> TestParseOldData
name: "Ping+",
message: new PingMessage { Messages = new string[] { "a", "b" } },
binary: "kwOhYaFi"),
new ProtocolTestData(
name: "Ping+_Nullable",
message: new PingMessage { Messages = new string[] { "a", null } },
binary: "kwOhYcA="),
new ProtocolTestData(
name: "OpenConnection",
message: new OpenConnectionMessage("conn1", null),
Expand Down Expand Up @@ -408,6 +412,14 @@ public static IEnumerable<object[]> TestParseOldData
["messagepack"] = new byte[] {3, 4, 5, 6, 7, 1, 2}
}),
binary: "lAeSpWNvbm42pWNvbm43gqRqc29uxAcCAwQFBgcBq21lc3NhZ2VwYWNrxAcDBAUGBwECgA=="),
new ProtocolTestData(
name: "MultiConnectionData_Nullable",
message: new MultiConnectionDataMessage(new [] {"conn6", null}, new Dictionary<string, ReadOnlyMemory<byte>>
{
["json"] = new byte[] {2, 3, 4, 5, 6, 7, 1},
["messagepack"] = new byte[] {3, 4, 5, 6, 7, 1, 2}
}),
binary: "lAeSpWNvbm42wIKkanNvbsQHAgMEBQYHAattZXNzYWdlcGFja8QHAwQFBgcBAoA="),
new ProtocolTestData(
name: "UserData",
message: new UserDataMessage("user1",
Expand All @@ -426,6 +438,15 @@ public static IEnumerable<object[]> TestParseOldData
["messagepack"] = new byte[] {7, 1, 2, 3, 4, 5, 6}
}),
binary: "lAmSpXVzZXIxpXVzZXIygqRqc29uxAcGBwECAwQFq21lc3NhZ2VwYWNrxAcHAQIDBAUGgA=="),
new ProtocolTestData(
name: "MultiUserData_Nullable",
message: new MultiUserDataMessage(new [] {"user1", null},
new Dictionary<string, ReadOnlyMemory<byte>>
{
["json"] = new byte[] {6, 7, 1, 2, 3, 4, 5},
["messagepack"] = new byte[] {7, 1, 2, 3, 4, 5, 6}
}),
binary: "lAmSpXVzZXIxwIKkanNvbsQHBgcBAgMEBattZXNzYWdlcGFja8QHBwECAwQFBoA="),
new ProtocolTestData(
name: "Broadcast",
message: new BroadcastDataMessage(new Dictionary<string, ReadOnlyMemory<byte>>
Expand All @@ -443,6 +464,15 @@ public static IEnumerable<object[]> TestParseOldData
["messagepack"] = new byte[] {7, 1, 2, 3, 4, 5, 6}
}),
binary: "lAqTpWNvbm43pWNvbm44pWNvbm45gqRqc29uxAcGBwECAwQFq21lc3NhZ2VwYWNrxAcHAQIDBAUGgA=="),
new ProtocolTestData(
name: "BroadcastExcept_Nullable",
message: new BroadcastDataMessage(new[] {"conn7", null},
new Dictionary<string, ReadOnlyMemory<byte>>
{
["json"] = new byte[] {6, 7, 1, 2, 3, 4, 5},
["messagepack"] = new byte[] {7, 1, 2, 3, 4, 5, 6}
}),
binary: "lAqSpWNvbm43wIKkanNvbsQHBgcBAgMEBattZXNzYWdlcGFja8QHBwECAwQFBoA="),
new ProtocolTestData(
name: "JoinGroup",
message: new JoinGroupMessage("conn10", "group1"),
Expand Down Expand Up @@ -481,6 +511,15 @@ public static IEnumerable<object[]> TestParseOldData
["messagepack"] = new byte[] {7, 1, 2, 3, 4, 5, 6}
}),
binary: "lw2mZ3JvdXAzkqZjb25uMTKmY29ubjEzgqRqc29uxAcGBwECAwQFq21lc3NhZ2VwYWNrxAcHAQIDBAUGgJDA"),
new ProtocolTestData(
name: "GroupBroadcastExcept_Nullable",
message: new GroupBroadcastDataMessage("group3", new [] {"conn12", null},
new Dictionary<string, ReadOnlyMemory<byte>>
{
["json"] = new byte[] {6, 7, 1, 2, 3, 4, 5},
["messagepack"] = new byte[] {7, 1, 2, 3, 4, 5, 6}
}),
binary: "lw2mZ3JvdXAzkqZjb25uMTLAgqRqc29uxAcGBwECAwQFq21lc3NhZ2VwYWNrxAcHAQIDBAUGgJDA"),
new ProtocolTestData(
name: "GroupBroadcastExceptUser",
message: new GroupBroadcastDataMessage("group3", new [] {"conn12", "conn13"},
Expand All @@ -494,6 +533,19 @@ public static IEnumerable<object[]> TestParseOldData
CallerUserId = "user3"
},
binary: "lw2mZ3JvdXAzkqZjb25uMTKmY29ubjEzgqRqc29uxAcGBwECAwQFq21lc3NhZ2VwYWNrxAcHAQIDBAUGgJKldXNlcjGldXNlcjKldXNlcjM="),
new ProtocolTestData(
name: "GroupBroadcastExceptUser_Nullable",
message: new GroupBroadcastDataMessage("group3", new [] {"conn12", null},
new Dictionary<string, ReadOnlyMemory<byte>>
{
["json"] = new byte[] {6, 7, 1, 2, 3, 4, 5},
["messagepack"] = new byte[] {7, 1, 2, 3, 4, 5, 6}
})
{
ExcludedUserList = new [] {"user1", null},
CallerUserId = "user3"
},
binary: "lw2mZ3JvdXAzkqZjb25uMTLAgqRqc29uxAcGBwECAwQFq21lc3NhZ2VwYWNrxAcHAQIDBAUGgJKldXNlcjHApXVzZXIz"),
new ProtocolTestData(
name: "GroupBroadcastWithTracingId",
message: new GroupBroadcastDataMessage("group3",
Expand All @@ -512,6 +564,15 @@ public static IEnumerable<object[]> TestParseOldData
["messagepack"] = new byte[] {7, 8, 1, 2, 3, 4, 5, 6}
}),
binary: "lA6Spmdyb3VwNKZncm91cDWCpGpzb27ECAECAwQFBgcIq21lc3NhZ2VwYWNrxAgHCAECAwQFBoA="),
new ProtocolTestData(
name: "MultiGroupBroadcast_Nullable",
message: new MultiGroupBroadcastDataMessage(new [] {"group4", null},
new Dictionary<string, ReadOnlyMemory<byte>>
{
["json"] = new byte[] {1, 2, 3, 4, 5, 6, 7, 8},
["messagepack"] = new byte[] {7, 8, 1, 2, 3, 4, 5, 6}
}),
binary: "lA6Spmdyb3VwNMCCpGpzb27ECAECAwQFBgcIq21lc3NhZ2VwYWNrxAgHCAECAwQFBoA="),
new ProtocolTestData(
name: "ServiceError",
message: new ServiceErrorMessage("Maximum message count limit reached: 100000"),
Expand Down