Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageId>Microsoft.IdentityModel.LoggingExtensions</PackageId>
<PackageTags>.NET;Windows;Authentication;Identity;Extensions;Logging</PackageTags>
<TargetFrameworks>netstandard2.0</TargetFrameworks>
<TargetFrameworks>netstandard2.0;net9.0</TargetFrameworks>
<Nullable>enable</Nullable>
</PropertyGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.IdentityModel.TestUtils;
using Microsoft.IdentityModel.Tokens.Json.Tests;
using Microsoft.IdentityModel.Tokens;
using Microsoft.IdentityModel.Tokens.Json.Tests;
using Xunit;

namespace Microsoft.IdentityModel.JsonWebTokens.Tests
Expand Down Expand Up @@ -100,6 +102,44 @@ public void GetClaimAsType(JsonClaimSetTheoryData theoryData)
TestUtilities.AssertFailIfErrors(context);
}

// Tests a JsonClaimSet, to ensure the same List object is returned for concurrent calls to the Claims member.
[Fact]
public async Task ValidJsonClaimSet_ConcurrencyTest()
{
// Arrange
var numThreads = 10;
var barrier = new Barrier(numThreads);
var jsonClaims = new Dictionary<string, object>
{
{ "claim1", "value1" },
{ "claim2", "value2" }
};
var jsonClaimSet = new JsonClaimSet(jsonClaims);
List<Claim>[] allClaims = new List<Claim>[numThreads];
Task[] tasks = new Task[numThreads];

for (var i = 0; i < numThreads; i++)
{
var index = i;
tasks[i] = (Task.Run(() =>
{
barrier.SignalAndWait();
allClaims[index] = jsonClaimSet.Claims("claim1");
}));
}

// Act
await Task.WhenAll(tasks);

// Assert
Assert.All(allClaims, claims => Assert.NotNull(claims));
var firstClaims = allClaims[0];
for (var i = 1; i < numThreads; i++)
{
Assert.Same(firstClaims, allClaims[i]);
}
}

public static TheoryData<JsonClaimSetTheoryData> GetClaimAsTypeTheoryData()
{
var theoryData = new TheoryData<JsonClaimSetTheoryData>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.IdentityModel.TestUtils;
using Xunit;

Expand Down Expand Up @@ -280,6 +283,39 @@ public void UnwrapParameterCheck(KeyWrapTheoryData theoryData)
TestUtilities.AssertFailIfErrors(context);
}

// Tests that concurrent calls to WrapKey and UnwrapKey do not run into encrypt/decrypt lock contention issues or other race conditions.
[Fact]
public void WrapAndUnwrapKey_ConcurrencyTest()
{
// Arrange
var numThreads = 10;
var wrapBarrier = new Barrier(numThreads);
var unwrapBarrier = new Barrier(numThreads);
var barrierTimeoutInMs = 5000;
var key = new SymmetricSecurityKey(new byte[32]);
var provider = new SymmetricKeyWrapProvider(key, SecurityAlgorithms.Aes256KW);
var tasks = new List<Task>(numThreads);

// Act and Assert
for (int i = 0; i < numThreads; i++)
{
tasks.Add(Task.Run(() =>
{
// Wait for all threads to be ready before checking the WrapKey locks
var keyBytes = new byte[32];
wrapBarrier.SignalAndWait(barrierTimeoutInMs);
var wrappedKey = provider.WrapKey(keyBytes);
Assert.NotNull(wrappedKey);

// Wait for all threads to be ready before checking the UnwrapKey locks
unwrapBarrier.SignalAndWait(barrierTimeoutInMs);
var unwrappedKey = provider.UnwrapKey(wrappedKey);
Assert.NotNull(unwrappedKey);
}));
}
Task.WhenAll(tasks);
}

public static TheoryData<KeyWrapTheoryData> UnwrapTheoryData()
{
var theoryData = new TheoryData<KeyWrapTheoryData>();
Expand Down
35 changes: 35 additions & 0 deletions test/Microsoft.IdentityModel.Tokens.Tests/SecurityKeyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT License.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.IdentityModel.TestUtils;
using Xunit;

Expand Down Expand Up @@ -210,6 +212,39 @@ public void CanComputeJwkThumbprint()
Assert.False(new CustomSecurityKey().CanComputeJwkThumbprint(), "CustomSecurityKey shouldn't be able to compute JWK thumbprint if CanComputeJwkThumbprint() is not overriden.");
}

// Tests a SecurityKey object, to ensure the InternalId is set exactly once when faced with concurrent calls.
[Fact]
public async Task InternalId_ConcurrencyTest()
{
// Arrange
var numTasks = 10;
var barrier = new Barrier(numTasks);
var key = new CustomSecurityKey();
string[] internalIds = new string[numTasks];
Task[] tasks = new Task[numTasks];

for (int i = 0; i < numTasks; i++)
{
var index = i;
tasks[i] = Task.Run(() =>
{
barrier.SignalAndWait();
internalIds[index] = key.InternalId;
});
}

// Act
await Task.WhenAll(tasks);

// Assert
Assert.All(internalIds, id => Assert.NotNull(id));
var firstId = internalIds[0];
for (int i = 1; i < numTasks; i++)
{
Assert.Same(firstId, internalIds[i]);
}
}

public class SecurityKeyTheoryData : TheoryDataBase
{
public SecurityKey SecurityKey { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.TestUtils;
using Xunit;
Expand Down Expand Up @@ -44,5 +46,29 @@ public void GetSets()

TestUtilities.AssertFailIfErrors("TokenValidationResultTests.GetSets", context.Errors);
}

// Ensure setting the ClaimsIdentity object simultaneously doesn't cause lock contention or other concurrency issues.
[Fact]
public async Task ClaimsIdentity_ConcurrencyTest()
{
// Arrange
var numThreads = 10;
var barrier = new Barrier(numThreads);
var result = new TokenValidationResult();
var claimsIdentity = new CaseSensitiveClaimsIdentity(Default.PayloadClaims);
Task[] tasks = new Task[numThreads];

for (int i = 0; i < numThreads; i++)
{
tasks[i] = Task.Run(() =>
{
barrier.SignalAndWait();
result.ClaimsIdentity = claimsIdentity;
});
}

// Act and implicit Assert as any exception will cause the test to fail
await Task.WhenAll(tasks);
}
}
}
Loading