Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.Azure.SignalR;

internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
{
private readonly ConcurrentDictionary<ServiceEndpoint, object> _endpoints = new ConcurrentDictionary<ServiceEndpoint, object>(ReferenceEqualityComparer.Instance);
private readonly ConcurrentDictionary<MicrosoftEntraAccessKey, bool> _keyMap = new(ReferenceEqualityComparer.Instance);

private readonly ILoggerFactory _factory;
private readonly ILogger<AccessKeySynchronizer> _logger;

private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));

internal IEnumerable<MicrosoftEntraAccessKey> AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType<MicrosoftEntraAccessKey>();
internal IEnumerable<MicrosoftEntraAccessKey> InitializedKeyList => _keyMap.Where(x => x.Key.Initialized).Select(x => x.Key);

public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true)
{
Expand All @@ -32,65 +34,74 @@ internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start)
{
if (start)
{
_ = UpdateAccessKeyAsync();
_ = UpdateAllAccessKeyAsync();
}
_factory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory));
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<AccessKeySynchronizer>();
}

public void AddServiceEndpoint(ServiceEndpoint endpoint)
{
if (endpoint.AccessKey is MicrosoftEntraAccessKey key)
{
_ = key.UpdateAccessKeyAsync();
_keyMap.TryAdd(key, true);
}
_endpoints.TryAdd(endpoint, null);
}

public void Dispose() => _timer.Stop();

public void UpdateServiceEndpoints(IEnumerable<ServiceEndpoint> endpoints)
{
_endpoints.Clear();
_keyMap.Clear();
foreach (var endpoint in endpoints)
{
AddServiceEndpoint(endpoint);
}
}

internal bool ContainsServiceEndpoint(ServiceEndpoint e) => _endpoints.ContainsKey(e);
/// <summary>
/// Test only
/// </summary>
/// <param name="e"></param>
/// <returns></returns>
internal bool ContainsKey(ServiceEndpoint e) => _keyMap.ContainsKey(e.AccessKey as MicrosoftEntraAccessKey);

internal int ServiceEndpointsCount() => _endpoints.Count;
/// <summary>
/// Test only
/// </summary>
/// <returns></returns>
internal int Count() => _keyMap.Count;

private async Task UpdateAccessKeyAsync()
private async Task UpdateAllAccessKeyAsync()
{
using (_timer)
{
_timer.Start();

while (await _timer)
{
foreach (var key in AccessKeyForMicrosoftEntraList)
foreach (var key in InitializedKeyList)
{
_ = key.UpdateAccessKeyAsync();
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = key.UpdateAccessKeyAsync(source.Token);
}
}
}
}

private sealed class ReferenceEqualityComparer : IEqualityComparer<ServiceEndpoint>
private sealed class ReferenceEqualityComparer : IEqualityComparer<MicrosoftEntraAccessKey>
{
internal static readonly ReferenceEqualityComparer Instance = new ReferenceEqualityComparer();

private ReferenceEqualityComparer()
{
}

public bool Equals(ServiceEndpoint x, ServiceEndpoint y)
public bool Equals(MicrosoftEntraAccessKey x, MicrosoftEntraAccessKey y)
{
return ReferenceEquals(x, y);
}

public int GetHashCode(ServiceEndpoint obj)
public int GetHashCode(MicrosoftEntraAccessKey obj)
{
return RuntimeHelpers.GetHashCode(obj);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ internal class MicrosoftEntraAccessKey : IAccessKey
{
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const int UpdateTaskIdle = 0;

private const int UpdateTaskRunning = 1;

private const int GetAccessKeyMaxRetryTimes = 3;

private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;
Expand All @@ -36,10 +40,12 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

private readonly TaskCompletionSource<object?> _initializedTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<object?> _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly IHttpClientFactory _httpClientFactory;

private volatile int _updateState = 0;

private volatile bool _isAuthorized = false;

private DateTime _updateAt = DateTime.MinValue;
Expand All @@ -48,6 +54,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private volatile byte[]? _keyBytes;

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool Available
{
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
Expand Down Expand Up @@ -116,6 +124,12 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!_initializedTcs.Task.IsCompleted)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = UpdateAccessKeyAsync(source.Token);
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

return Available
Expand All @@ -142,26 +156,35 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
return;
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
{
return;
}

for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
if (ctoken.IsCancellationRequested)
{
break;
}

var source = new CancellationTokenSource(GetAccessKeyTimeout);
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken);
try
{
await UpdateAccessKeyInternalAsync(linkedSource.Token);
await UpdateAccessKeyInternalAsync(source.Token).OrCancelAsync(ctoken);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
return;
}
catch (OperationCanceledException e)
{
LastException = e;
break;
LastException = e; // retry immediately
}
catch (Exception e)
{
LastException = e;
try
{
await Task.Delay(GetAccessKeyRetryInterval, ctoken);
await Task.Delay(GetAccessKeyRetryInterval, ctoken); // retry after interval.
}
catch (OperationCanceledException)
{
Expand All @@ -175,6 +198,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
Available = false;
}
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
}

private static string GetExceptionMessage(Exception? exception)
Expand Down Expand Up @@ -232,11 +256,11 @@ private async Task UpdateAccessKeyInternalAsync(CancellationToken ctoken)
await ThrowExceptionOnResponseFailureAsync(request, response);
}

private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
private async Task HandleHttpResponseAsync(HttpResponseMessage response)
{
if (response.StatusCode != HttpStatusCode.OK)
{
return false;
return;
}

var content = await response.Content.ReadAsStringAsync();
Expand All @@ -250,8 +274,6 @@ private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
{
throw new AzureSignalRException("Missing required <AccessKey> field.");
}

UpdateAccessKey(obj.KeyId, obj.AccessKey);
return true;
}
}
3 changes: 2 additions & 1 deletion src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ internal static class Constants

public const string AsrsDefaultScope = "https://signalr.azure.com/.default";


public const int DefaultCloseTimeoutMilliseconds = 10000;

public static class Keys
Expand All @@ -45,6 +44,8 @@ public static class Periods

public const int MaxCustomHandshakeTimeout = 30;

public static readonly TimeSpan DefaultUpdateAccessKeyTimeout = TimeSpan.FromMinutes(2);

public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1);

public static readonly TimeSpan DefaultScaleTimeout = TimeSpan.FromMinutes(5);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Azure.Identity;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Logging.Abstractions;
using Xunit;

namespace Microsoft.Azure.SignalR.Common.Tests.Auth;

public class AccessKeySynchronizerFacts
{
[Fact]
public void AddAndRemoveServiceEndpointsTest()
{
var synchronizer = GetInstanceForTest();

var credential = new DefaultAzureCredential();
var endpoint1 = new TestServiceEndpoint(credential);
var endpoint2 = new TestServiceEndpoint(credential);

Assert.Equal(0, synchronizer.Count());
synchronizer.UpdateServiceEndpoints([endpoint1]);
Assert.Equal(1, synchronizer.Count());
synchronizer.UpdateServiceEndpoints([endpoint1, endpoint2]);
Assert.Empty(synchronizer.InitializedKeyList);

Assert.Equal(2, synchronizer.Count());
Assert.True(synchronizer.ContainsKey(endpoint1));
Assert.True(synchronizer.ContainsKey(endpoint2));

synchronizer.UpdateServiceEndpoints([endpoint2]);
Assert.Equal(1, synchronizer.Count());
synchronizer.UpdateServiceEndpoints([]);
Assert.Equal(0, synchronizer.Count());
Assert.Empty(synchronizer.InitializedKeyList);
}

private static AccessKeySynchronizer GetInstanceForTest()
{
return new AccessKeySynchronizer(NullLoggerFactory.Instance, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig
.ThrowsAsync(e);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
{
GetAccessKeyRetryInterval = TimeSpan.Zero
GetAccessKeyRetryInterval = TimeSpan.Zero,
};

var audience = "http://localhost/chat";
Expand Down Expand Up @@ -210,6 +210,52 @@ public async Task TestUpdateAccessKeySendRequest(string expectedKeyStr)
Assert.Equal(expectedKeyStr, Encoding.UTF8.GetString(key.KeyBytes));
}

[Fact]
public async Task TestLazyLoadAccessKey()
{
var expectedKeyStr = DefaultSigningKey;
var expectedKid = "foo";
var text = "{" + string.Format("\"AccessKey\": \"{0}\", \"KeyId\": \"{1}\"", expectedKeyStr, expectedKid) + "}";
var httpClientFactory = new TestHttpClientFactory(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = TextHttpContent.From(text),
});

var credential = new TestTokenCredential(TokenType.MicrosoftEntra);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, credential, httpClientFactory: httpClientFactory);

Assert.False(key.Initialized);

var token = await key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);
Assert.NotNull(token);

Assert.True(key.Initialized);
}

[Fact]
public async Task TestLazyLoadAccessKeyFailed()
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new Exception());
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
{
GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1),
};

Assert.False(key.Initialized);

var task1 = key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);
var task2 = key.UpdateAccessKeyAsync();
Assert.True(task2.IsCompleted); // another task is in progress.

await Assert.ThrowsAsync<AzureSignalRAccessTokenNotAuthorizedException>(async () => await task1);

Assert.True(key.Initialized);
}

[Theory]
[InlineData(TokenType.Local)]
[InlineData(TokenType.MicrosoftEntra)]
Expand All @@ -226,7 +272,10 @@ public async Task ThrowUnauthorizedExceptionTest(TokenType tokenType)
endpoint,
new TestTokenCredential(tokenType),
httpClientFactory: new TestHttpClientFactory(message)
);
)
{
GetAccessKeyRetryInterval = TimeSpan.Zero
};

await key.UpdateAccessKeyAsync();

Expand Down
Loading
Loading