Skip to content

Commit 2b30fc0

Browse files
committed
reduce retry interval when auth failed
1 parent f34b323 commit 2b30fc0

File tree

3 files changed

+83
-25
lines changed

3 files changed

+83
-25
lines changed

src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ namespace Microsoft.Azure.SignalR
1616
{
1717
internal class AadAccessKey : AccessKey
1818
{
19-
internal const int AuthorizeIntervalInMinute = 55;
2019
internal const int AuthorizeMaxRetryTimes = 3;
2120
internal const int AuthorizeRetryIntervalInSec = 3;
2221

2322
private const string DefaultScope = "https://signalr.azure.com/.default";
2423

25-
private static readonly TokenRequestContext _defaultRequestContext = new TokenRequestContext(new string[] { DefaultScope });
26-
private static readonly TimeSpan AuthorizeInterval = TimeSpan.FromMinutes(AuthorizeIntervalInMinute);
24+
private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { DefaultScope });
25+
private static readonly TimeSpan AuthorizeInterval = TimeSpan.FromMinutes(55);
26+
private static readonly TimeSpan AuthorizeIntervalWhenFailed = TimeSpan.FromMinutes(5);
2727
private static readonly TimeSpan AuthorizeRetryInterval = TimeSpan.FromSeconds(AuthorizeRetryIntervalInSec);
2828
private static readonly TimeSpan AuthorizeTimeout = TimeSpan.FromSeconds(10);
2929

@@ -33,7 +33,16 @@ internal class AadAccessKey : AccessKey
3333

3434
private DateTime _lastUpdatedTime = DateTime.MinValue;
3535

36-
public bool Authorized => InitializedTask.IsCompleted && _isAuthorized;
36+
public bool Authorized
37+
{
38+
get => _isAuthorized;
39+
private set
40+
{
41+
_lastUpdatedTime = DateTime.UtcNow;
42+
_isAuthorized = value;
43+
_initializedTcs.TrySetResult(null);
44+
}
45+
}
3746

3847
public TokenCredential TokenCredential { get; }
3948

@@ -54,7 +63,7 @@ public AadAccessKey(Uri uri, TokenCredential credential): base(uri)
5463

5564
public virtual async Task<string> GenerateAadTokenAsync(CancellationToken ctoken = default)
5665
{
57-
var token = await TokenCredential.GetTokenAsync(_defaultRequestContext, ctoken);
66+
var token = await TokenCredential.GetTokenAsync(DefaultRequestContext, ctoken);
5867
return token.Token;
5968
}
6069

@@ -88,45 +97,52 @@ public override async Task<string> GenerateAccessTokenAsync(
8897
internal void UpdateAccessKey(string kid, string accessKey)
8998
{
9099
Key = new Tuple<string, string>(kid, accessKey);
91-
_lastUpdatedTime = DateTime.UtcNow;
92-
_isAuthorized = true;
93-
_initializedTcs.TrySetResult(null);
100+
Authorized = true;
94101
}
95102

96-
internal async Task UpdateAccessKeyAsync()
103+
internal async Task UpdateAccessKeyAsync(CancellationToken token = default)
97104
{
98-
if (DateTime.UtcNow - _lastUpdatedTime < AuthorizeInterval)
105+
var delta = DateTime.UtcNow - _lastUpdatedTime;
106+
if (Authorized && delta < AuthorizeInterval)
99107
{
100108
return;
101109
}
110+
else if (!Authorized && delta < AuthorizeIntervalWhenFailed)
111+
{
112+
return;
113+
}
114+
await AuthorizeWithRetryAsync(token);
115+
}
102116

117+
private async Task AuthorizeWithRetryAsync(CancellationToken token)
118+
{
103119
Exception latest = null;
104-
for (int i = 0; i < AuthorizeMaxRetryTimes; i++)
120+
for (var i = 0; i < AuthorizeMaxRetryTimes; i++)
105121
{
106122
var source = new CancellationTokenSource(AuthorizeTimeout);
123+
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, token);
107124
try
108125
{
109-
await AuthorizeAsync(source.Token);
126+
var aadToken = await GenerateAadTokenAsync(linkedSource.Token);
127+
await AuthorizeWithTokenAsync(aadToken, linkedSource.Token);
110128
return;
111129
}
130+
catch (OperationCanceledException e)
131+
{
132+
latest = e;
133+
break;
134+
}
112135
catch (Exception e)
113136
{
114137
latest = e;
115138
await Task.Delay(AuthorizeRetryInterval);
116139
}
117140
}
118141

119-
_isAuthorized = false;
120-
_initializedTcs.TrySetResult(null);
142+
Authorized = false;
121143
throw latest;
122144
}
123145

124-
private async Task AuthorizeAsync(CancellationToken ctoken = default)
125-
{
126-
var aadToken = await GenerateAadTokenAsync(ctoken);
127-
await AuthorizeWithTokenAsync(aadToken, ctoken);
128-
}
129-
130146
private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken token = default)
131147
{
132148
var api = new RestApiEndpoint(AuthorizeUrl, accessToken);

src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ internal class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
1616
private readonly ConcurrentDictionary<ServiceEndpoint, object> _endpoints = new ConcurrentDictionary<ServiceEndpoint, object>(ReferenceEqualityComparer.Instance);
1717

1818
private readonly ILoggerFactory _factory;
19-
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(5));
19+
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));
2020

2121
public AccessKeySynchronizer(
2222
ILoggerFactory loggerFactory

test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Reflection;
23
using System.Security.Claims;
34
using System.Threading;
45
using System.Threading.Tasks;
@@ -27,7 +28,7 @@ public void TestConstructor(string endpoint, string expectedAuthorizeUrl)
2728
[Fact]
2829
public async Task TestUpdateAccessKey()
2930
{
30-
var credential = new EnvironmentCredential();
31+
var credential = new DefaultAzureCredential();
3132
var endpoint = "http://localhost";
3233
var key = new AadAccessKey(new Uri(endpoint), credential);
3334

@@ -48,10 +49,50 @@ await Assert.ThrowsAsync<TaskCanceledException>(
4849
Assert.NotNull(token);
4950
}
5051

52+
[Theory]
53+
[InlineData(false, 1, true)]
54+
[InlineData(false, 4, true)]
55+
[InlineData(false, 6, false)]
56+
[InlineData(true, 6, true)]
57+
[InlineData(true, 54, true)]
58+
[InlineData(true, 56, false)]
59+
public async Task TestUpdateAccessKeyShouldSkip(bool isAuthorized, int timeElapsed, bool shouldSkip)
60+
{
61+
var key = new AadAccessKey(new Uri("http://localhost"), new DefaultAzureCredential());
62+
63+
var isAuthorizedField = typeof(AadAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance);
64+
isAuthorizedField.SetValue(key, isAuthorized);
65+
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));
66+
67+
var lastUpdatedTime = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed);
68+
var lastUpdatedTimeField = typeof(AadAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance);
69+
lastUpdatedTimeField.SetValue(key, lastUpdatedTime);
70+
71+
var initializedTcsField = typeof(AadAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance);
72+
var initializedTcs = (TaskCompletionSource<object>)initializedTcsField.GetValue(key);
73+
74+
var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));
75+
76+
if (shouldSkip)
77+
{
78+
await key.UpdateAccessKeyAsync(source.Token);
79+
Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key));
80+
Assert.Equal(lastUpdatedTime, (DateTime)lastUpdatedTimeField.GetValue(key));
81+
Assert.False(initializedTcs.Task.IsCompleted);
82+
}
83+
else
84+
{
85+
await Assert.ThrowsAsync<TaskCanceledException>(async () => await key.UpdateAccessKeyAsync(source.Token));
86+
Assert.False((bool)isAuthorizedField.GetValue(key));
87+
Assert.True(lastUpdatedTime < (DateTime)lastUpdatedTimeField.GetValue(key));
88+
Assert.True(initializedTcs.Task.IsCompleted);
89+
}
90+
}
91+
5192
[Fact]
5293
public async Task TestInitializeFailed()
5394
{
54-
var credential = new EnvironmentCredential();
95+
var credential = new DefaultAzureCredential();
5596
var key = new AadAccessKey(new Uri("http://localhost"), credential);
5697

5798
var audience = "http://localhost/chat";
@@ -63,8 +104,9 @@ public async Task TestInitializeFailed()
63104
async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm)
64105
);
65106

66-
await Assert.ThrowsAnyAsync<Exception>(
67-
async () => await key.UpdateAccessKeyAsync()
107+
var source = new CancellationTokenSource(TimeSpan.FromSeconds(1));
108+
await Assert.ThrowsAnyAsync<TaskCanceledException>(
109+
async () => await key.UpdateAccessKeyAsync(source.Token)
68110
);
69111

70112
await task;

0 commit comments

Comments
 (0)