Skip to content

Commit 3c6b848

Browse files
authored
Lazy load temporary AccessKey (#2092)
1 parent 30d6a43 commit 3c6b848

File tree

7 files changed

+155
-91
lines changed

7 files changed

+155
-91
lines changed

src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/AccessKeySynchronizer.cs

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,22 @@
66
using System.Collections.Generic;
77
using System.Linq;
88
using System.Runtime.CompilerServices;
9+
using System.Threading;
910
using System.Threading.Tasks;
1011
using Microsoft.Extensions.Logging;
12+
using Microsoft.Extensions.Logging.Abstractions;
1113

1214
namespace Microsoft.Azure.SignalR;
1315

1416
internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
1517
{
16-
private readonly ConcurrentDictionary<ServiceEndpoint, object> _endpoints = new ConcurrentDictionary<ServiceEndpoint, object>(ReferenceEqualityComparer.Instance);
18+
private readonly ConcurrentDictionary<MicrosoftEntraAccessKey, bool> _keyMap = new(ReferenceEqualityComparer.Instance);
1719

18-
private readonly ILoggerFactory _factory;
20+
private readonly ILogger<AccessKeySynchronizer> _logger;
1921

2022
private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));
2123

22-
internal IEnumerable<MicrosoftEntraAccessKey> AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType<MicrosoftEntraAccessKey>();
24+
internal IEnumerable<MicrosoftEntraAccessKey> InitializedKeyList => _keyMap.Where(x => x.Key.Initialized).Select(x => x.Key);
2325

2426
public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true)
2527
{
@@ -32,65 +34,74 @@ internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start)
3234
{
3335
if (start)
3436
{
35-
_ = UpdateAccessKeyAsync();
37+
_ = UpdateAllAccessKeyAsync();
3638
}
37-
_factory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory));
39+
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<AccessKeySynchronizer>();
3840
}
3941

4042
public void AddServiceEndpoint(ServiceEndpoint endpoint)
4143
{
4244
if (endpoint.AccessKey is MicrosoftEntraAccessKey key)
4345
{
44-
_ = key.UpdateAccessKeyAsync();
46+
_keyMap.TryAdd(key, true);
4547
}
46-
_endpoints.TryAdd(endpoint, null);
4748
}
4849

4950
public void Dispose() => _timer.Stop();
5051

5152
public void UpdateServiceEndpoints(IEnumerable<ServiceEndpoint> endpoints)
5253
{
53-
_endpoints.Clear();
54+
_keyMap.Clear();
5455
foreach (var endpoint in endpoints)
5556
{
5657
AddServiceEndpoint(endpoint);
5758
}
5859
}
5960

60-
internal bool ContainsServiceEndpoint(ServiceEndpoint e) => _endpoints.ContainsKey(e);
61+
/// <summary>
62+
/// Test only
63+
/// </summary>
64+
/// <param name="e"></param>
65+
/// <returns></returns>
66+
internal bool ContainsKey(ServiceEndpoint e) => _keyMap.ContainsKey(e.AccessKey as MicrosoftEntraAccessKey);
6167

62-
internal int ServiceEndpointsCount() => _endpoints.Count;
68+
/// <summary>
69+
/// Test only
70+
/// </summary>
71+
/// <returns></returns>
72+
internal int Count() => _keyMap.Count;
6373

64-
private async Task UpdateAccessKeyAsync()
74+
private async Task UpdateAllAccessKeyAsync()
6575
{
6676
using (_timer)
6777
{
6878
_timer.Start();
6979

7080
while (await _timer)
7181
{
72-
foreach (var key in AccessKeyForMicrosoftEntraList)
82+
foreach (var key in InitializedKeyList)
7383
{
74-
_ = key.UpdateAccessKeyAsync();
84+
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
85+
_ = key.UpdateAccessKeyAsync(source.Token);
7586
}
7687
}
7788
}
7889
}
7990

80-
private sealed class ReferenceEqualityComparer : IEqualityComparer<ServiceEndpoint>
91+
private sealed class ReferenceEqualityComparer : IEqualityComparer<MicrosoftEntraAccessKey>
8192
{
8293
internal static readonly ReferenceEqualityComparer Instance = new ReferenceEqualityComparer();
8394

8495
private ReferenceEqualityComparer()
8596
{
8697
}
8798

88-
public bool Equals(ServiceEndpoint x, ServiceEndpoint y)
99+
public bool Equals(MicrosoftEntraAccessKey x, MicrosoftEntraAccessKey y)
89100
{
90101
return ReferenceEquals(x, y);
91102
}
92103

93-
public int GetHashCode(ServiceEndpoint obj)
104+
public int GetHashCode(MicrosoftEntraAccessKey obj)
94105
{
95106
return RuntimeHelpers.GetHashCode(obj);
96107
}

src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ internal class MicrosoftEntraAccessKey : IAccessKey
2424
{
2525
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);
2626

27+
private const int UpdateTaskIdle = 0;
28+
29+
private const int UpdateTaskRunning = 1;
30+
2731
private const int GetAccessKeyMaxRetryTimes = 3;
2832

2933
private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;
@@ -36,10 +40,12 @@ internal class MicrosoftEntraAccessKey : IAccessKey
3640

3741
private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);
3842

39-
private readonly TaskCompletionSource<object?> _initializedTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
43+
private readonly TaskCompletionSource<object?> _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
4044

4145
private readonly IHttpClientFactory _httpClientFactory;
4246

47+
private volatile int _updateState = 0;
48+
4349
private volatile bool _isAuthorized = false;
4450

4551
private DateTime _updateAt = DateTime.MinValue;
@@ -48,6 +54,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey
4854

4955
private volatile byte[]? _keyBytes;
5056

57+
public bool Initialized => _initializedTcs.Task.IsCompleted;
58+
5159
public bool Available
5260
{
5361
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
@@ -116,6 +124,12 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
116124
AccessTokenAlgorithm algorithm,
117125
CancellationToken ctoken = default)
118126
{
127+
if (!_initializedTcs.Task.IsCompleted)
128+
{
129+
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
130+
_ = UpdateAccessKeyAsync(source.Token);
131+
}
132+
119133
await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");
120134

121135
return Available
@@ -142,26 +156,35 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
142156
return;
143157
}
144158

159+
if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
160+
{
161+
return;
162+
}
163+
145164
for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
146165
{
166+
if (ctoken.IsCancellationRequested)
167+
{
168+
break;
169+
}
170+
147171
var source = new CancellationTokenSource(GetAccessKeyTimeout);
148-
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken);
149172
try
150173
{
151-
await UpdateAccessKeyInternalAsync(linkedSource.Token);
174+
await UpdateAccessKeyInternalAsync(source.Token).OrCancelAsync(ctoken);
175+
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
152176
return;
153177
}
154178
catch (OperationCanceledException e)
155179
{
156-
LastException = e;
157-
break;
180+
LastException = e; // retry immediately
158181
}
159182
catch (Exception e)
160183
{
161184
LastException = e;
162185
try
163186
{
164-
await Task.Delay(GetAccessKeyRetryInterval, ctoken);
187+
await Task.Delay(GetAccessKeyRetryInterval, ctoken); // retry after interval.
165188
}
166189
catch (OperationCanceledException)
167190
{
@@ -175,6 +198,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
175198
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
176199
Available = false;
177200
}
201+
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
178202
}
179203

180204
private static string GetExceptionMessage(Exception? exception)
@@ -232,11 +256,11 @@ private async Task UpdateAccessKeyInternalAsync(CancellationToken ctoken)
232256
await ThrowExceptionOnResponseFailureAsync(request, response);
233257
}
234258

235-
private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
259+
private async Task HandleHttpResponseAsync(HttpResponseMessage response)
236260
{
237261
if (response.StatusCode != HttpStatusCode.OK)
238262
{
239-
return false;
263+
return;
240264
}
241265

242266
var content = await response.Content.ReadAsStringAsync();
@@ -250,8 +274,6 @@ private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
250274
{
251275
throw new AzureSignalRException("Missing required <AccessKey> field.");
252276
}
253-
254277
UpdateAccessKey(obj.KeyId, obj.AccessKey);
255-
return true;
256278
}
257279
}

src/Microsoft.Azure.SignalR.Common/Constants.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ internal static class Constants
2121

2222
public const string AsrsDefaultScope = "https://signalr.azure.com/.default";
2323

24-
2524
public const int DefaultCloseTimeoutMilliseconds = 10000;
2625

2726
public static class Keys
@@ -45,6 +44,8 @@ public static class Periods
4544

4645
public const int MaxCustomHandshakeTimeout = 30;
4746

47+
public static readonly TimeSpan DefaultUpdateAccessKeyTimeout = TimeSpan.FromMinutes(2);
48+
4849
public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1);
4950

5051
public static readonly TimeSpan DefaultScaleTimeout = TimeSpan.FromMinutes(5);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using Azure.Identity;
5+
using Microsoft.Azure.SignalR.Tests.Common;
6+
using Microsoft.Extensions.Logging.Abstractions;
7+
using Xunit;
8+
9+
namespace Microsoft.Azure.SignalR.Common.Tests.Auth;
10+
11+
public class AccessKeySynchronizerFacts
12+
{
13+
[Fact]
14+
public void AddAndRemoveServiceEndpointsTest()
15+
{
16+
var synchronizer = GetInstanceForTest();
17+
18+
var credential = new DefaultAzureCredential();
19+
var endpoint1 = new TestServiceEndpoint(credential);
20+
var endpoint2 = new TestServiceEndpoint(credential);
21+
22+
Assert.Equal(0, synchronizer.Count());
23+
synchronizer.UpdateServiceEndpoints([endpoint1]);
24+
Assert.Equal(1, synchronizer.Count());
25+
synchronizer.UpdateServiceEndpoints([endpoint1, endpoint2]);
26+
Assert.Empty(synchronizer.InitializedKeyList);
27+
28+
Assert.Equal(2, synchronizer.Count());
29+
Assert.True(synchronizer.ContainsKey(endpoint1));
30+
Assert.True(synchronizer.ContainsKey(endpoint2));
31+
32+
synchronizer.UpdateServiceEndpoints([endpoint2]);
33+
Assert.Equal(1, synchronizer.Count());
34+
synchronizer.UpdateServiceEndpoints([]);
35+
Assert.Equal(0, synchronizer.Count());
36+
Assert.Empty(synchronizer.InitializedKeyList);
37+
}
38+
39+
private static AccessKeySynchronizer GetInstanceForTest()
40+
{
41+
return new AccessKeySynchronizer(NullLoggerFactory.Instance, false);
42+
}
43+
}

test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraAccessKeyTests.cs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSig
165165
.ThrowsAsync(e);
166166
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
167167
{
168-
GetAccessKeyRetryInterval = TimeSpan.Zero
168+
GetAccessKeyRetryInterval = TimeSpan.Zero,
169169
};
170170

171171
var audience = "http://localhost/chat";
@@ -210,6 +210,52 @@ public async Task TestUpdateAccessKeySendRequest(string expectedKeyStr)
210210
Assert.Equal(expectedKeyStr, Encoding.UTF8.GetString(key.KeyBytes));
211211
}
212212

213+
[Fact]
214+
public async Task TestLazyLoadAccessKey()
215+
{
216+
var expectedKeyStr = DefaultSigningKey;
217+
var expectedKid = "foo";
218+
var text = "{" + string.Format("\"AccessKey\": \"{0}\", \"KeyId\": \"{1}\"", expectedKeyStr, expectedKid) + "}";
219+
var httpClientFactory = new TestHttpClientFactory(new HttpResponseMessage(HttpStatusCode.OK)
220+
{
221+
Content = TextHttpContent.From(text),
222+
});
223+
224+
var credential = new TestTokenCredential(TokenType.MicrosoftEntra);
225+
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, credential, httpClientFactory: httpClientFactory);
226+
227+
Assert.False(key.Initialized);
228+
229+
var token = await key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);
230+
Assert.NotNull(token);
231+
232+
Assert.True(key.Initialized);
233+
}
234+
235+
[Fact]
236+
public async Task TestLazyLoadAccessKeyFailed()
237+
{
238+
var mockCredential = new Mock<TokenCredential>();
239+
mockCredential.Setup(credential => credential.GetTokenAsync(
240+
It.IsAny<TokenRequestContext>(),
241+
It.IsAny<CancellationToken>()))
242+
.ThrowsAsync(new Exception());
243+
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object)
244+
{
245+
GetAccessKeyRetryInterval = TimeSpan.FromSeconds(1),
246+
};
247+
248+
Assert.False(key.Initialized);
249+
250+
var task1 = key.GenerateAccessTokenAsync("https://localhost", [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);
251+
var task2 = key.UpdateAccessKeyAsync();
252+
Assert.True(task2.IsCompleted); // another task is in progress.
253+
254+
await Assert.ThrowsAsync<AzureSignalRAccessTokenNotAuthorizedException>(async () => await task1);
255+
256+
Assert.True(key.Initialized);
257+
}
258+
213259
[Theory]
214260
[InlineData(TokenType.Local)]
215261
[InlineData(TokenType.MicrosoftEntra)]
@@ -226,7 +272,10 @@ public async Task ThrowUnauthorizedExceptionTest(TokenType tokenType)
226272
endpoint,
227273
new TestTokenCredential(tokenType),
228274
httpClientFactory: new TestHttpClientFactory(message)
229-
);
275+
)
276+
{
277+
GetAccessKeyRetryInterval = TimeSpan.Zero
278+
};
230279

231280
await key.UpdateAccessKeyAsync();
232281

0 commit comments

Comments
 (0)