Skip to content

Commit 8dc02bb

Browse files
authored
reduce retry interval when auth failed (#1451)
1 parent 1c4a20b commit 8dc02bb

File tree

3 files changed

+86
-26
lines changed

3 files changed

+86
-26
lines changed

src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ internal class AadAccessKey : AccessKey
2424

2525
private const string DefaultScope = "https://signalr.azure.com/.default";
2626

27-
private static readonly TokenRequestContext _defaultRequestContext = new TokenRequestContext(new string[] { DefaultScope });
28-
2927
private static readonly TimeSpan AuthorizeInterval = TimeSpan.FromMinutes(AuthorizeIntervalInMinute);
3028

29+
private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { DefaultScope });
30+
31+
private static readonly TimeSpan AuthorizeIntervalWhenFailed = TimeSpan.FromMinutes(5);
32+
3133
private static readonly TimeSpan AuthorizeRetryInterval = TimeSpan.FromSeconds(AuthorizeRetryIntervalInSec);
3234

3335
private static readonly TimeSpan AuthorizeTimeout = TimeSpan.FromSeconds(10);
@@ -38,12 +40,23 @@ internal class AadAccessKey : AccessKey
3840

3941
private DateTime _lastUpdatedTime = DateTime.MinValue;
4042

41-
public bool Authorized => InitializedTask.IsCompleted && _isAuthorized;
43+
public bool Authorized
44+
{
45+
get => _isAuthorized;
46+
private set
47+
{
48+
_lastUpdatedTime = DateTime.UtcNow;
49+
_isAuthorized = value;
50+
_initializedTcs.TrySetResult(null);
51+
}
52+
}
4253

4354
public TokenCredential TokenCredential { get; }
4455

4556
internal string AuthorizeUrl { get; }
4657

58+
internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(AuthorizeIntervalInMinute * 2);
59+
4760
private Task<object> InitializedTask => _initializedTcs.Task;
4861

4962
public AadAccessKey(Uri uri, TokenCredential credential) : base(uri)
@@ -59,7 +72,7 @@ public AadAccessKey(Uri uri, TokenCredential credential) : base(uri)
5972

6073
public virtual async Task<string> GenerateAadTokenAsync(CancellationToken ctoken = default)
6174
{
62-
var token = await TokenCredential.GetTokenAsync(_defaultRequestContext, ctoken);
75+
var token = await TokenCredential.GetTokenAsync(DefaultRequestContext, ctoken);
6376
return token.Token;
6477
}
6578

@@ -93,47 +106,52 @@ public override async Task<string> GenerateAccessTokenAsync(
93106
internal void UpdateAccessKey(string kid, string accessKey)
94107
{
95108
Key = new Tuple<string, string>(kid, accessKey);
96-
_lastUpdatedTime = DateTime.UtcNow;
97-
_isAuthorized = true;
98-
_initializedTcs.TrySetResult(null);
109+
Authorized = true;
99110
}
100111

101-
internal bool HasExpired => DateTime.UtcNow - _lastUpdatedTime > TimeSpan.FromMinutes(AuthorizeIntervalInMinute * 2);
102-
103-
internal async Task UpdateAccessKeyAsync()
112+
internal async Task UpdateAccessKeyAsync(CancellationToken token = default)
104113
{
105-
if (DateTime.UtcNow - _lastUpdatedTime < AuthorizeInterval)
114+
var delta = DateTime.UtcNow - _lastUpdatedTime;
115+
if (Authorized && delta < AuthorizeInterval)
106116
{
107117
return;
108118
}
119+
else if (!Authorized && delta < AuthorizeIntervalWhenFailed)
120+
{
121+
return;
122+
}
123+
await AuthorizeWithRetryAsync(token);
124+
}
109125

126+
private async Task AuthorizeWithRetryAsync(CancellationToken token)
127+
{
110128
Exception latest = null;
111-
for (int i = 0; i < AuthorizeMaxRetryTimes; i++)
129+
for (var i = 0; i < AuthorizeMaxRetryTimes; i++)
112130
{
113131
var source = new CancellationTokenSource(AuthorizeTimeout);
132+
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, token);
114133
try
115134
{
116-
await AuthorizeAsync(source.Token);
135+
var aadToken = await GenerateAadTokenAsync(linkedSource.Token);
136+
await AuthorizeWithTokenAsync(aadToken, linkedSource.Token);
117137
return;
118138
}
139+
catch (OperationCanceledException e)
140+
{
141+
latest = e;
142+
break;
143+
}
119144
catch (Exception e)
120145
{
121146
latest = e;
122147
await Task.Delay(AuthorizeRetryInterval);
123148
}
124149
}
125150

126-
_isAuthorized = false;
127-
_initializedTcs.TrySetResult(null);
151+
Authorized = false;
128152
throw latest;
129153
}
130154

131-
private async Task AuthorizeAsync(CancellationToken ctoken = default)
132-
{
133-
var aadToken = await GenerateAadTokenAsync(ctoken);
134-
await AuthorizeWithTokenAsync(aadToken, ctoken);
135-
}
136-
137155
private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken token = default)
138156
{
139157
var api = new RestApiEndpoint(AuthorizeUrl, accessToken);

src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ internal class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
1616
private readonly ConcurrentDictionary<ServiceEndpoint, object> _endpoints = new ConcurrentDictionary<ServiceEndpoint, object>(ReferenceEqualityComparer.Instance);
1717

1818
private readonly ILoggerFactory _factory;
19-
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(5));
19+
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));
2020

2121
public AccessKeySynchronizer(
2222
ILoggerFactory loggerFactory

test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Reflection;
23
using System.Security.Claims;
34
using System.Threading;
45
using System.Threading.Tasks;
@@ -27,7 +28,7 @@ public void TestConstructor(string endpoint, string expectedAuthorizeUrl)
2728
[Fact]
2829
public async Task TestUpdateAccessKey()
2930
{
30-
var credential = new EnvironmentCredential();
31+
var credential = new DefaultAzureCredential();
3132
var endpoint = "http://localhost";
3233
var key = new AadAccessKey(new Uri(endpoint), credential);
3334

@@ -48,10 +49,50 @@ await Assert.ThrowsAsync<TaskCanceledException>(
4849
Assert.NotNull(token);
4950
}
5051

52+
[Theory]
53+
[InlineData(false, 1, true)]
54+
[InlineData(false, 4, true)]
55+
[InlineData(false, 6, false)]
56+
[InlineData(true, 6, true)]
57+
[InlineData(true, 54, true)]
58+
[InlineData(true, 56, false)]
59+
public async Task TestUpdateAccessKeyShouldSkip(bool isAuthorized, int timeElapsed, bool shouldSkip)
60+
{
61+
var key = new AadAccessKey(new Uri("http://localhost"), new DefaultAzureCredential());
62+
63+
var isAuthorizedField = typeof(AadAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance);
64+
isAuthorizedField.SetValue(key, isAuthorized);
65+
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));
66+
67+
var lastUpdatedTime = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed);
68+
var lastUpdatedTimeField = typeof(AadAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance);
69+
lastUpdatedTimeField.SetValue(key, lastUpdatedTime);
70+
71+
var initializedTcsField = typeof(AadAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance);
72+
var initializedTcs = (TaskCompletionSource<object>)initializedTcsField.GetValue(key);
73+
74+
var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));
75+
76+
if (shouldSkip)
77+
{
78+
await key.UpdateAccessKeyAsync(source.Token);
79+
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));
80+
Assert.Equal(lastUpdatedTime, (DateTime)lastUpdatedTimeField.GetValue(key));
81+
Assert.False(initializedTcs.Task.IsCompleted);
82+
}
83+
else
84+
{
85+
await Assert.ThrowsAsync<TaskCanceledException>(async () => await key.UpdateAccessKeyAsync(source.Token));
86+
Assert.False((bool)isAuthorizedField.GetValue(key));
87+
Assert.True(lastUpdatedTime < (DateTime)lastUpdatedTimeField.GetValue(key));
88+
Assert.True(initializedTcs.Task.IsCompleted);
89+
}
90+
}
91+
5192
[Fact]
5293
public async Task TestInitializeFailed()
5394
{
54-
var credential = new EnvironmentCredential();
95+
var credential = new DefaultAzureCredential();
5596
var key = new AadAccessKey(new Uri("http://localhost"), credential);
5697

5798
var audience = "http://localhost/chat";
@@ -63,8 +104,9 @@ public async Task TestInitializeFailed()
63104
async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm)
64105
);
65106

66-
await Assert.ThrowsAnyAsync<Exception>(
67-
async () => await key.UpdateAccessKeyAsync()
107+
var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));
108+
await Assert.ThrowsAnyAsync<TaskCanceledException>(
109+
async () => await key.UpdateAccessKeyAsync(source.Token)
68110
);
69111

70112
await task;

0 commit comments

Comments
 (0)