Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ private async Task UpdateAllAccessKeyAsync()
{
foreach (var key in InitializedKeyList)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = key.UpdateAccessKeyAsync(source.Token);
_ = key.UpdateAccessKeyAsync();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55);

private static readonly TimeSpan GetAccessKeyIntervalWhenUnauthorized = TimeSpan.FromMinutes(5);
private static readonly TimeSpan GetAccessKeyIntervalUnavailable = TimeSpan.FromMinutes(5);

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

Expand All @@ -56,6 +56,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable);

public bool Available
{
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
Expand Down Expand Up @@ -121,8 +123,7 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
{
if (!_initializedTcs.Task.IsCompleted)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = UpdateAccessKeyAsync(source.Token);
_ = UpdateAccessKeyAsync();
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");
Expand All @@ -139,14 +140,9 @@ internal void UpdateAccessKey(string kid, string keyStr)
Available = true;
}

internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
internal async Task UpdateAccessKeyAsync()
{
var delta = DateTime.UtcNow - _updateAt;
if (Available && delta < GetAccessKeyInterval)
{
return;
}
else if (!Available && delta < GetAccessKeyIntervalWhenUnauthorized)
if (!NeedRefresh)
{
return;
}
Expand All @@ -158,15 +154,10 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)

for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
if (ctoken.IsCancellationRequested)
{
break;
}

var source = new CancellationTokenSource(GetAccessKeyTimeout);
try
{
await UpdateAccessKeyInternalAsync(source.Token).OrCancelAsync(ctoken);
await UpdateAccessKeyInternalAsync(source.Token);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
return;
}
Expand All @@ -179,7 +170,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
LastException = e;
try
{
await Task.Delay(GetAccessKeyRetryInterval, ctoken); // retry after interval.
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
catch (OperationCanceledException)
{
Expand Down
2 changes: 0 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ public static class Periods

public const int MaxCustomHandshakeTimeout = 30;

public static readonly TimeSpan DefaultUpdateAccessKeyTimeout = TimeSpan.FromMinutes(2);

public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1);

public static readonly TimeSpan DefaultScaleTimeout = TimeSpan.FromMinutes(5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,26 @@
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;
using Moq;
using Xunit;

namespace Microsoft.Azure.SignalR.Common.Tests.Auth;

#nullable enable

[Collection("Auth")]
public class MicrosoftEntraAccessKeyTests
{
private const string DefaultSigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

private static readonly Uri DefaultEndpoint = new("http://localhost");

public enum TokenType
{
Local,

MicrosoftEntra,
}

[Theory]
[InlineData("https://a.bc", "https://a.bc/api/v1/auth/accessKey")]
[InlineData("https://a.bc:80", "https://a.bc:80/api/v1/auth/accessKey")]
Expand All @@ -36,12 +44,7 @@ public void TestExpectedGetAccessKeyUrl(string endpoint, string expectedGetAcces
[Fact]
public async Task TestUpdateAccessKey()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception"));
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential());

var audience = "http://localhost/chat";
var claims = Array.Empty<Claim>();
Expand All @@ -66,28 +69,23 @@ public async Task TestUpdateAccessKey()
[InlineData(true, 121, false, true)] // > 120, should set key unauthorized and log the exception
public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int timeElapsed, bool skip, bool hasException)
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception"));
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential())
{
GetAccessKeyRetryInterval = TimeSpan.Zero
};
var isAuthorizedField = typeof(MicrosoftEntraAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance);
isAuthorizedField.SetValue(key, isAuthorized);
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));
isAuthorizedField?.SetValue(key, isAuthorized);
Assert.Equal(isAuthorized, Assert.IsType<bool>(isAuthorizedField?.GetValue(key)));

var updateAt = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed);
var updateAtField = typeof(MicrosoftEntraAccessKey).GetField("_updateAt", BindingFlags.NonPublic | BindingFlags.Instance);
updateAtField.SetValue(key, updateAt);
updateAtField?.SetValue(key, updateAt);

var initializedTcsField = typeof(MicrosoftEntraAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance);
var initializedTcs = (TaskCompletionSource<object>)initializedTcsField.GetValue(key);
var initializedTcs = Assert.IsType<TaskCompletionSource<object>>(initializedTcsField?.GetValue(key));

await key.UpdateAccessKeyAsync().OrTimeout(TimeSpan.FromSeconds(30));
var actualUpdateAt = Assert.IsType<DateTime>(updateAtField.GetValue(key));
var actualUpdateAt = Assert.IsType<DateTime>(updateAtField?.GetValue(key));

Assert.Equal(skip && isAuthorized, Assert.IsType<bool>(isAuthorizedField.GetValue(key)));

Expand Down Expand Up @@ -115,12 +113,7 @@ public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int time
[Fact]
public async Task TestInitializeFailed()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception"));
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential())
{
GetAccessKeyRetryInterval = TimeSpan.Zero
};
Expand All @@ -141,8 +134,7 @@ public async Task TestInitializeFailed()
[Fact]
public async Task TestNotInitialized()
{
var mockCredential = new Mock<TokenCredential>();
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential());

var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));
var exception = await Assert.ThrowsAsync<TaskCanceledException>(
Expand All @@ -155,12 +147,7 @@ public async Task TestNotInitialized()
[ClassData(typeof(NotAuthorizedTestData))]
public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSignalRException e, string expectedErrorMessage)
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(e);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential() { Exception = e })
{
GetAccessKeyRetryInterval = TimeSpan.Zero,
};
Expand All @@ -177,7 +164,7 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig
);
Assert.Same(exception.InnerException, e);
Assert.Same(exception.InnerException, key.LastException);
Assert.StartsWith($"TokenCredentialProxy is not available for signing client tokens", exception.Message);
Assert.StartsWith($"{nameof(TestTokenCredential)} is not available for signing client tokens", exception.Message);
Assert.Contains(expectedErrorMessage, exception.Message);

var (kid, accessKey) = ("foo", DefaultSigningKey);
Expand Down Expand Up @@ -232,12 +219,7 @@ public async Task TestLazyLoadAccessKey()
[Fact]
public async Task TestLazyLoadAccessKeyFailed()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new Exception());
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, new TestTokenCredential())
{
GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1),
};
Expand Down Expand Up @@ -355,7 +337,7 @@ private sealed class TestHttpClient(HttpResponseMessage message) : HttpClient
{
public override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
Assert.Equal("Bearer", request.Headers.Authorization.Scheme);
Assert.Equal("Bearer", request.Headers?.Authorization?.Scheme);
return Task.FromResult(message);
}
}
Expand All @@ -364,11 +346,14 @@ private sealed class TextHttpContent : HttpContent
{
private readonly string _content;

private TextHttpContent(string content) => _content = content;
private TextHttpContent(string content)
{
_content = content;
}

internal static HttpContent From(string content) => new TextHttpContent(content);

protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context)
{
return stream.WriteAsync(Encoding.UTF8.GetBytes(_content)).AsTask();
}
Expand All @@ -380,21 +365,22 @@ protected override bool TryComputeLength(out long length)
}
}

public enum TokenType
private sealed class TestTokenCredential(TokenType? tokenType = null) : TokenCredential
{
Local,
MicrosoftEntra,
}
public Exception? Exception { get; init; }

private sealed class TestTokenCredential(TokenType tokenType) : TokenCredential
{
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
if (Exception != null)
{
throw Exception;
}

var issuer = tokenType switch
{
TokenType.Local => Constants.AsrsTokenIssuer,
TokenType.MicrosoftEntra => "microsoft.com",
_ => throw new NotImplementedException(),
_ => throw new InvalidOperationException(),
};
var key = new AccessKey(DefaultSigningKey);
var token = AuthUtility.GenerateJwtToken(key.KeyBytes, issuer: issuer);
Expand Down
Loading