Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ internal class AadAccessKey : AccessKey

private const string DefaultScope = "https://signalr.azure.com/.default";

private static readonly TokenRequestContext _defaultRequestContext = new TokenRequestContext(new string[] { DefaultScope });

private static readonly TimeSpan AuthorizeInterval = TimeSpan.FromMinutes(AuthorizeIntervalInMinute);

private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { DefaultScope });

private static readonly TimeSpan AuthorizeIntervalWhenFailed = TimeSpan.FromMinutes(5);

private static readonly TimeSpan AuthorizeRetryInterval = TimeSpan.FromSeconds(AuthorizeRetryIntervalInSec);

private static readonly TimeSpan AuthorizeTimeout = TimeSpan.FromSeconds(10);
Expand All @@ -38,12 +40,23 @@ internal class AadAccessKey : AccessKey

private DateTime _lastUpdatedTime = DateTime.MinValue;

public bool Authorized => InitializedTask.IsCompleted && _isAuthorized;
public bool Authorized
{
get => _isAuthorized;
private set
{
_lastUpdatedTime = DateTime.UtcNow;
_isAuthorized = value;
_initializedTcs.TrySetResult(null);
}
}

public TokenCredential TokenCredential { get; }

internal string AuthorizeUrl { get; }

internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(AuthorizeIntervalInMinute * 2);

private Task<object> InitializedTask => _initializedTcs.Task;

public AadAccessKey(Uri uri, TokenCredential credential) : base(uri)
Expand All @@ -59,7 +72,7 @@ public AadAccessKey(Uri uri, TokenCredential credential) : base(uri)

public virtual async Task<string> GenerateAadTokenAsync(CancellationToken ctoken = default)
{
var token = await TokenCredential.GetTokenAsync(_defaultRequestContext, ctoken);
var token = await TokenCredential.GetTokenAsync(DefaultRequestContext, ctoken);
return token.Token;
}

Expand Down Expand Up @@ -93,47 +106,52 @@ public override async Task<string> GenerateAccessTokenAsync(
internal void UpdateAccessKey(string kid, string accessKey)
{
Key = new Tuple<string, string>(kid, accessKey);
_lastUpdatedTime = DateTime.UtcNow;
_isAuthorized = true;
_initializedTcs.TrySetResult(null);
Authorized = true;
}

internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(AuthorizeIntervalInMinute * 2);

internal async Task UpdateAccessKeyAsync()
internal async Task UpdateAccessKeyAsync(CancellationToken token = default)
{
if (DateTime.UtcNow - _lastUpdatedTime < AuthorizeInterval)
var delta = DateTime.UtcNow - _lastUpdatedTime;
if (Authorized && delta < AuthorizeInterval)
{
return;
}
else if (!Authorized && delta < AuthorizeIntervalWhenFailed)
{
return;
}
await AuthorizeWithRetryAsync(token);
}

private async Task AuthorizeWithRetryAsync(CancellationToken token)
{
Exception latest = null;
for (int i = 0; i < AuthorizeMaxRetryTimes; i++)
for (var i = 0; i < AuthorizeMaxRetryTimes; i++)
{
var source = new CancellationTokenSource(AuthorizeTimeout);
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, token);
try
{
await AuthorizeAsync(source.Token);
var aadToken = await GenerateAadTokenAsync(linkedSource.Token);
await AuthorizeWithTokenAsync(aadToken, linkedSource.Token);
return;
}
catch (OperationCanceledException e)
{
latest = e;
break;
}
catch (Exception e)
{
latest = e;
await Task.Delay(AuthorizeRetryInterval);
}
}

_isAuthorized = false;
_initializedTcs.TrySetResult(null);
Authorized = false;
throw latest;
}

private async Task AuthorizeAsync(CancellationToken ctoken = default)
{
var aadToken = await GenerateAadTokenAsync(ctoken);
await AuthorizeWithTokenAsync(aadToken, ctoken);
}

private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken token = default)
{
var api = new RestApiEndpoint(AuthorizeUrl, accessToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ internal class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
private readonly ConcurrentDictionary<ServiceEndpoint, object> _endpoints = new ConcurrentDictionary<ServiceEndpoint, object>(ReferenceEqualityComparer.Instance);

private readonly ILoggerFactory _factory;
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(5));
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));

public AccessKeySynchronizer(
ILoggerFactory loggerFactory
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Reflection;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -27,7 +28,7 @@ public void TestConstructor(string endpoint, string expectedAuthorizeUrl)
[Fact]
public async Task TestUpdateAccessKey()
{
var credential = new EnvironmentCredential();
var credential = new DefaultAzureCredential();
var endpoint = "http://localhost";
var key = new AadAccessKey(new Uri(endpoint), credential);

Expand All @@ -48,10 +49,50 @@ await Assert.ThrowsAsync<TaskCanceledException>(
Assert.NotNull(token);
}

[Theory]
[InlineData(false, 1, true)]
[InlineData(false, 4, true)]
[InlineData(false, 6, false)]
[InlineData(true, 6, true)]
[InlineData(true, 54, true)]
[InlineData(true, 56, false)]
public async Task TestUpdateAccessKeyShouldSkip(bool isAuthorized, int timeElapsed, bool shouldSkip)
{
var key = new AadAccessKey(new Uri("http://localhost"), new DefaultAzureCredential());

var isAuthorizedField = typeof(AadAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance);
isAuthorizedField.SetValue(key, isAuthorized);
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));

var lastUpdatedTime = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed);
var lastUpdatedTimeField = typeof(AadAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance);
lastUpdatedTimeField.SetValue(key, lastUpdatedTime);

var initializedTcsField = typeof(AadAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance);
var initializedTcs = (TaskCompletionSource<object>)initializedTcsField.GetValue(key);

var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));

if (shouldSkip)
{
await key.UpdateAccessKeyAsync(source.Token);
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));
Assert.Equal(lastUpdatedTime, (DateTime)lastUpdatedTimeField.GetValue(key));
Assert.False(initializedTcs.Task.IsCompleted);
}
else
{
await Assert.ThrowsAsync<TaskCanceledException>(async () => await key.UpdateAccessKeyAsync(source.Token));
Assert.False((bool)isAuthorizedField.GetValue(key));
Assert.True(lastUpdatedTime < (DateTime)lastUpdatedTimeField.GetValue(key));
Assert.True(initializedTcs.Task.IsCompleted);
}
}

[Fact]
public async Task TestInitializeFailed()
{
var credential = new EnvironmentCredential();
var credential = new DefaultAzureCredential();
var key = new AadAccessKey(new Uri("http://localhost"), credential);

var audience = "http://localhost/chat";
Expand All @@ -63,8 +104,9 @@ public async Task TestInitializeFailed()
async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm)
);

await Assert.ThrowsAnyAsync<Exception>(
async () => await key.UpdateAccessKeyAsync()
var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));
await Assert.ThrowsAnyAsync<TaskCanceledException>(
async () => await key.UpdateAccessKeyAsync(source.Token)
);

await task;
Expand Down