Skip to content

Commit f6c1db6

Browse files
committed
token provider
1 parent 712cafa commit f6c1db6

File tree

13 files changed

+136
-50
lines changed

13 files changed

+136
-50
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,6 @@ __pycache__/
297297

298298
# docker
299299
.docker/
300+
301+
# C# SDK version config
302+
global.json

src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,6 @@ public Task<string> GenerateClientAccessTokenAsync(string hubName = null, IEnume
5959
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
6060
}
6161

62-
public Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null)
63-
{
64-
if (_accessKey is AadAccessKey key)
65-
{
66-
return key.GenerateAadTokenAsync();
67-
}
68-
69-
if (string.IsNullOrEmpty(hubName))
70-
{
71-
throw new ArgumentNullException(nameof(hubName));
72-
}
73-
74-
var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
75-
var claims = userId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, userId) } : null;
76-
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
77-
}
78-
7962
public string GetClientEndpoint(string hubName = null, string originalPath = null, string queryString = null)
8063
{
8164
var queryBuilder = new StringBuilder();
@@ -110,6 +93,24 @@ public string GetServerEndpoint(string hubName)
11093
return $"{_serverEndpoint}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
11194
}
11295

96+
public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
97+
{
98+
if (_accessKey is AadAccessKey aadAccessKey)
99+
{
100+
return new AadTokenProvider(aadAccessKey);
101+
}
102+
else if (_accessKey is not null)
103+
{
104+
var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
105+
var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null;
106+
return new LocalTokenProvider(_accessKey, audience, claims, _algorithm, _accessTokenLifetime);
107+
}
108+
else
109+
{
110+
throw new ArgumentNullException(nameof(AccessKey));
111+
}
112+
}
113+
113114
private string GetPrefixedHubName(string applicationName, string hubName)
114115
{
115116
return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}";
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System;
5+
using System.Threading.Tasks;
6+
7+
namespace Microsoft.Azure.SignalR
8+
{
9+
internal class AadTokenProvider : IAccessTokenProvider
10+
{
11+
private readonly AadAccessKey _accessKey;
12+
13+
public AadTokenProvider(AadAccessKey accessKey)
14+
{
15+
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
16+
}
17+
18+
public Task<string> ProvideAsync() => _accessKey.GenerateAadTokenAsync();
19+
}
20+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Security.Claims;
7+
using System.Threading.Tasks;
8+
9+
namespace Microsoft.Azure.SignalR
10+
{
11+
internal class LocalTokenProvider : IAccessTokenProvider
12+
{
13+
private readonly AccessKey _accessKey;
14+
15+
private readonly AccessTokenAlgorithm _algorithm;
16+
17+
private readonly string _audience;
18+
19+
private readonly TimeSpan _tokenLifetime;
20+
21+
private readonly IEnumerable<Claim> _claims;
22+
23+
public LocalTokenProvider(
24+
AccessKey accessKey,
25+
string audience,
26+
IEnumerable<Claim> claims,
27+
AccessTokenAlgorithm? algorithm = null,
28+
TimeSpan? tokenLifetime = null)
29+
{
30+
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
31+
_algorithm = algorithm ?? AccessTokenAlgorithm.HS256;
32+
_audience = audience;
33+
_claims = claims;
34+
_tokenLifetime = tokenLifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
35+
}
36+
37+
public Task<string> ProvideAsync() => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm);
38+
}
39+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System.Threading.Tasks;
5+
6+
namespace Microsoft.Azure.SignalR
7+
{
8+
public interface IAccessTokenProvider
9+
{
10+
Task<string> ProvideAsync();
11+
}
12+
}

src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceEndpointProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ internal interface IServiceEndpointProvider
1515

1616
string GetClientEndpoint(string hubName, string originalPath, string queryString);
1717

18-
Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null);
18+
IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId);
1919

2020
string GetServerEndpoint(string hubName);
2121

src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactory.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace Microsoft.Azure.SignalR
1414
internal class ConnectionFactory : IConnectionFactory
1515
{
1616
private readonly ILoggerFactory _loggerFactory;
17+
1718
private readonly string _serverId;
1819

1920
public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory loggerFactory)
@@ -31,7 +32,9 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint hubServiceE
3132
{
3233
var provider = hubServiceEndpoint.Provider;
3334
var hubName = hubServiceEndpoint.Hub;
34-
Task<string> accessTokenProvider() => provider.GenerateServerAccessTokenAsync(hubName, _serverId);
35+
36+
var accessTokenProvider = provider.GetServerAccessTokenProvider(hubName, _serverId);
37+
3538
var url = GetServiceUrl(provider, hubName, connectionId, target);
3639

3740
headers ??= new Dictionary<string, string>();
@@ -91,6 +94,7 @@ private Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, str
9194
private sealed class GracefulLoggerFactory : ILoggerFactory
9295
{
9396
private readonly ILoggerFactory _inner;
97+
9498
public GracefulLoggerFactory(ILoggerFactory inner)
9599
{
96100
_inner = inner;
@@ -115,6 +119,7 @@ public void AddProvider(ILoggerProvider provider)
115119
private sealed class GracefulLogger : ILogger
116120
{
117121
private readonly ILogger _inner;
122+
118123
public GracefulLogger(ILogger inner)
119124
{
120125
_inner = inner;

src/Microsoft.Azure.SignalR.Common/ServiceConnections/Internal/WebSocketsTransport.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ internal partial class WebSocketsTransport : IDuplexPipe
2323

2424
private readonly WebSocketMessageType _webSocketMessageType = WebSocketMessageType.Binary;
2525
private readonly ClientWebSocket _webSocket;
26-
private readonly Func<Task<string>> _accessTokenProvider;
26+
private readonly IAccessTokenProvider _accessTokenProvider;
2727
private IDuplexPipe _application;
2828
private readonly ILogger _logger;
2929
private readonly TimeSpan _closeTimeout;
@@ -37,7 +37,7 @@ internal partial class WebSocketsTransport : IDuplexPipe
3737

3838
public PipeWriter Output => _transport.Output;
3939

40-
public WebSocketsTransport(WebSocketConnectionOptions connectionOptions, ILoggerFactory loggerFactory, Func<Task<string>> accessTokenProvider)
40+
public WebSocketsTransport(WebSocketConnectionOptions connectionOptions, ILoggerFactory loggerFactory, IAccessTokenProvider accessTokenProvider)
4141
{
4242
_logger = (loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory))).CreateLogger<WebSocketsTransport>();
4343
_webSocket = new ClientWebSocket();
@@ -105,7 +105,7 @@ public async Task StartAsync(Uri url, CancellationToken cancellationToken = defa
105105
// We don't need to capture to a local because we never change this delegate.
106106
if (_accessTokenProvider != null)
107107
{
108-
var accessToken = await _accessTokenProvider();
108+
var accessToken = await _accessTokenProvider.ProvideAsync();
109109
if (!string.IsNullOrEmpty(accessToken))
110110
{
111111
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");

src/Microsoft.Azure.SignalR.Common/ServiceConnections/WebSocketConnectionContext.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ internal class WebSocketConnectionContext : ConnectionContext
3333

3434
public WebSocketConnectionContext(WebSocketConnectionOptions httpConnectionOptions,
3535
ILoggerFactory loggerFactory,
36-
Func<Task<string>> accessTokenProvider)
36+
IAccessTokenProvider accessTokenProvider)
3737
{
3838
Transport = _websocketTransport = new WebSocketsTransport(httpConnectionOptions, loggerFactory, accessTokenProvider);
3939
ConnectionId = "sc_" + Guid.NewGuid();

src/Microsoft.Azure.SignalR/EndpointProvider/ServiceEndpointProvider.cs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,30 +52,31 @@ public Task<string> GenerateClientAccessTokenAsync(string hubName, IEnumerable<C
5252
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
5353
}
5454

55-
public Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null)
56-
{
57-
if (_accessKey is AadAccessKey key)
58-
{
59-
return key.GenerateAadTokenAsync();
60-
}
61-
62-
if (string.IsNullOrEmpty(hubName))
63-
{
64-
throw new ArgumentNullException(nameof(hubName));
65-
}
66-
67-
var audience = _generator.GetServerAudience(hubName, _appName);
68-
var claims = userId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, userId) } : null;
69-
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
70-
}
71-
7255
public string GetClientEndpoint(string hubName, string originalPath, string queryString)
7356
{
7457
return string.IsNullOrEmpty(hubName)
7558
? throw new ArgumentNullException(nameof(hubName))
7659
: _generator.GetClientEndpoint(hubName, _appName, originalPath, queryString);
7760
}
7861

62+
public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
63+
{
64+
if (_accessKey is AadAccessKey aadAccessKey)
65+
{
66+
return new AadTokenProvider(aadAccessKey);
67+
}
68+
else if (_accessKey is not null)
69+
{
70+
var audience = _generator.GetServerAudience(hubName, _appName);
71+
var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null;
72+
return new LocalTokenProvider(_accessKey, audience, claims, _algorithm, _accessTokenLifetime);
73+
}
74+
else
75+
{
76+
throw new NotSupportedException("Unsupported auth type.");
77+
}
78+
}
79+
7980
public string GetServerEndpoint(string hubName)
8081
{
8182
return string.IsNullOrEmpty(hubName)

0 commit comments

Comments
 (0)