Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,6 @@ public Task<string> GenerateClientAccessTokenAsync(string hubName = null, IEnume
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
}

public Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null)
{
if (_accessKey is AadAccessKey key)
{
return key.GenerateAadTokenAsync();
}

if (string.IsNullOrEmpty(hubName))
{
throw new ArgumentNullException(nameof(hubName));
}

var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
var claims = userId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, userId) } : null;
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
}

public string GetClientEndpoint(string hubName = null, string originalPath = null, string queryString = null)
{
var queryBuilder = new StringBuilder();
Expand Down Expand Up @@ -110,6 +93,24 @@ public string GetServerEndpoint(string hubName)
return $"{_serverEndpoint}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
}

public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
{
if (_accessKey is AadAccessKey aadAccessKey)
{
return new AadTokenProvider(aadAccessKey);
}
else if (_accessKey is not null)
{
var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null;
return new LocalTokenProvider(_accessKey, audience, claims, _algorithm, _accessTokenLifetime);
}
else
{
throw new ArgumentNullException(nameof(AccessKey));
}
}

private string GetPrefixedHubName(string applicationName, string hubName)
{
return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}";
Expand Down
20 changes: 20 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Auth/AadTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR
{
internal class AadTokenProvider : IAccessTokenProvider
{
private readonly AadAccessKey _accessKey;

public AadTokenProvider(AadAccessKey accessKey)
{
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
}

public Task<string> ProvideAsync() => _accessKey.GenerateAadTokenAsync();
}
}
39 changes: 39 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR
{
internal class LocalTokenProvider : IAccessTokenProvider
{
private readonly AccessKey _accessKey;

private readonly AccessTokenAlgorithm _algorithm;

private readonly string _audience;

private readonly TimeSpan _tokenLifetime;

private readonly IEnumerable<Claim> _claims;

public LocalTokenProvider(
AccessKey accessKey,
string audience,
IEnumerable<Claim> claims,
AccessTokenAlgorithm algorithm = AccessTokenAlgorithm.HS256,
TimeSpan? tokenLifetime = null)
{
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
_algorithm = algorithm;
_audience = audience;
_claims = claims;
_tokenLifetime = tokenLifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
}

public Task<string> ProvideAsync() => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR
{
internal interface IAccessTokenProvider
{
Task<string> ProvideAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal interface IServiceEndpointProvider

string GetClientEndpoint(string hubName, string originalPath, string queryString);

Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null);
IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId);

string GetServerEndpoint(string hubName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace Microsoft.Azure.SignalR
internal class ConnectionFactory : IConnectionFactory
{
private readonly ILoggerFactory _loggerFactory;

private readonly string _serverId;

public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory loggerFactory)
Expand All @@ -31,7 +32,9 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint hubServiceE
{
var provider = hubServiceEndpoint.Provider;
var hubName = hubServiceEndpoint.Hub;
Task<string> accessTokenProvider() => provider.GenerateServerAccessTokenAsync(hubName, _serverId);

var accessTokenProvider = provider.GetServerAccessTokenProvider(hubName, _serverId);

var url = GetServiceUrl(provider, hubName, connectionId, target);

headers ??= new Dictionary<string, string>();
Expand Down Expand Up @@ -91,6 +94,7 @@ private Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, str
private sealed class GracefulLoggerFactory : ILoggerFactory
{
private readonly ILoggerFactory _inner;

public GracefulLoggerFactory(ILoggerFactory inner)
{
_inner = inner;
Expand All @@ -115,6 +119,7 @@ public void AddProvider(ILoggerProvider provider)
private sealed class GracefulLogger : ILogger
{
private readonly ILogger _inner;

public GracefulLogger(ILogger inner)
{
_inner = inner;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ internal partial class WebSocketsTransport : IDuplexPipe

private readonly WebSocketMessageType _webSocketMessageType = WebSocketMessageType.Binary;
private readonly ClientWebSocket _webSocket;
private readonly Func<Task<string>> _accessTokenProvider;
private readonly IAccessTokenProvider _accessTokenProvider;
private IDuplexPipe _application;
private readonly ILogger _logger;
private readonly TimeSpan _closeTimeout;
Expand All @@ -37,7 +37,7 @@ internal partial class WebSocketsTransport : IDuplexPipe

public PipeWriter Output => _transport.Output;

public WebSocketsTransport(WebSocketConnectionOptions connectionOptions, ILoggerFactory loggerFactory, Func<Task<string>> accessTokenProvider)
public WebSocketsTransport(WebSocketConnectionOptions connectionOptions, ILoggerFactory loggerFactory, IAccessTokenProvider accessTokenProvider)
{
_logger = (loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory))).CreateLogger<WebSocketsTransport>();
_webSocket = new ClientWebSocket();
Expand Down Expand Up @@ -105,7 +105,7 @@ public async Task StartAsync(Uri url, CancellationToken cancellationToken = defa
// We don't need to capture to a local because we never change this delegate.
if (_accessTokenProvider != null)
{
var accessToken = await _accessTokenProvider();
var accessToken = await _accessTokenProvider.ProvideAsync();
if (!string.IsNullOrEmpty(accessToken))
{
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ internal class WebSocketConnectionContext : ConnectionContext

public WebSocketConnectionContext(WebSocketConnectionOptions httpConnectionOptions,
ILoggerFactory loggerFactory,
Func<Task<string>> accessTokenProvider)
IAccessTokenProvider accessTokenProvider)
{
Transport = _websocketTransport = new WebSocketsTransport(httpConnectionOptions, loggerFactory, accessTokenProvider);
ConnectionId = "sc_" + Guid.NewGuid();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,31 @@ public Task<string> GenerateClientAccessTokenAsync(string hubName, IEnumerable<C
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
}

public Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null)
{
if (_accessKey is AadAccessKey key)
{
return key.GenerateAadTokenAsync();
}

if (string.IsNullOrEmpty(hubName))
{
throw new ArgumentNullException(nameof(hubName));
}

var audience = _generator.GetServerAudience(hubName, _appName);
var claims = userId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, userId) } : null;
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
}

public string GetClientEndpoint(string hubName, string originalPath, string queryString)
{
return string.IsNullOrEmpty(hubName)
? throw new ArgumentNullException(nameof(hubName))
: _generator.GetClientEndpoint(hubName, _appName, originalPath, queryString);
}

public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
{
if (_accessKey is AadAccessKey aadAccessKey)
{
return new AadTokenProvider(aadAccessKey);
}
else if (_accessKey is not null)
{
var audience = _generator.GetServerAudience(hubName, _appName);
var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null;
return new LocalTokenProvider(_accessKey, audience, claims, _algorithm, _accessTokenLifetime);
}
else
{
throw new NotSupportedException("Access key cannot be null.");
}
}

public string GetServerEndpoint(string hubName)
{
return string.IsNullOrEmpty(hubName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public async Task TestGenerateServerAccessToken(string connectionString, string
{
var provider = new ServiceEndpointProvider(new ServiceEndpoint(connectionString), new ServiceOptions() { });

var serverToken = await provider.GenerateServerAccessTokenAsync("hub1", "user1");
var serverToken = await provider.GetServerAccessTokenProvider("hub1", "user1").ProvideAsync();

var handler = new JwtSecurityTokenHandler();
var principal = handler.ValidateToken(serverToken, new TokenValidationParameters
Expand All @@ -89,7 +89,7 @@ public async Task TestGenerateServerAccessTokenWIthPrefix(string connectionStrin
{
var provider = new ServiceEndpointProvider(new ServiceEndpoint(connectionString), new ServiceOptions() { ApplicationName = "prefix" });

var serverToken = await provider.GenerateServerAccessTokenAsync("hub1", "user1");
var serverToken = await provider.GetServerAccessTokenProvider("hub1", "user1").ProvideAsync();

var handler = new JwtSecurityTokenHandler();
var principal = handler.ValidateToken(serverToken, new TokenValidationParameters
Expand Down Expand Up @@ -131,7 +131,7 @@ public async Task TestGenerateServerAccessTokenWithSpecifedAlgorithm(AccessToken
{
var connectionString = "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789;Port=8080;Version=1.0";
var provider = new ServiceEndpointProvider(new ServiceEndpoint(connectionString), new ServiceOptions() { AccessTokenAlgorithm = algorithm });
var serverToken = await provider.GenerateServerAccessTokenAsync("hub1", "user1");
var serverToken = await provider.GetServerAccessTokenProvider("hub1", "user1").ProvideAsync();

var handler = new JwtSecurityTokenHandler();
var token = handler.ReadJwtToken(serverToken);
Expand Down Expand Up @@ -162,10 +162,10 @@ public async Task GenerateMultipleAccessTokenShouldBeUnique()
var tokens = new List<string>();

var count = 1000;
for (int i = 0; i < count; i++)
for (var i = 0; i < count; i++)
{
tokens.Add(await sep.GenerateClientAccessTokenAsync());
tokens.Add(await sep.GenerateServerAccessTokenAsync("test1", userId));
tokens.Add(await sep.GetServerAccessTokenProvider("test1", userId).ProvideAsync());
}

var distinct = tokens.Distinct();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public Task<string> GenerateClientAccessTokenAsync(string hubName, IEnumerable<C
throw new NotImplementedException();
}

public Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null)
public IAccessTokenProvider GetServerAccessTokenProvider(string hubName, string serverId)
{
throw new NotImplementedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ internal async Task GenerateMultipleAccessTokenShouldBeUnique()
var sep = new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithPreviewVersion), _optionsWithoutAppName);
var userId = Guid.NewGuid().ToString();
var tokens = new List<string>();
for (int i = 0; i < count; i++)
for (var i = 0; i < count; i++)
{
tokens.Add(await sep.GenerateClientAccessTokenAsync(nameof(TestHub)));
tokens.Add(await sep.GenerateServerAccessTokenAsync(nameof(TestHub), userId));
tokens.Add(await sep.GetServerAccessTokenProvider(nameof(TestHub), userId).ProvideAsync());
}

var distinct = tokens.Distinct();
Expand All @@ -129,7 +129,7 @@ internal async Task GenerateMultipleAccessTokenShouldBeUnique()
internal async Task GenerateServerAccessToken(IServiceEndpointProvider provider)
{
const string userId = "UserA";
var tokenString = await provider.GenerateServerAccessTokenAsync(nameof(TestHub), userId);
var tokenString = await provider.GetServerAccessTokenProvider(nameof(TestHub), userId).ProvideAsync();
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/server/?hub={HubName}",
Expand All @@ -151,7 +151,7 @@ internal async Task GenerateServerAccessToken(IServiceEndpointProvider provider)
internal async Task GenerateServerAccessTokenWithPrefix(IServiceEndpointProvider provider)
{
const string userId = "UserA";
var tokenString = await provider.GenerateServerAccessTokenAsync(nameof(TestHub), userId);
var tokenString = await provider.GetServerAccessTokenProvider(nameof(TestHub), userId).ProvideAsync();
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/server/?hub={AppName}_{HubName}",
Expand Down Expand Up @@ -211,7 +211,7 @@ internal async Task GenerateClientAccessTokenWithPrefix(IServiceEndpointProvider
public async Task GenerateServerAccessTokenWithSpecifedAlgorithm(AccessTokenAlgorithm algorithm)
{
var provider = new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithV1Version), new ServiceOptions() { AccessTokenAlgorithm = algorithm });
var serverToken = await provider.GenerateServerAccessTokenAsync("hub1", "user1");
var serverToken = await provider.GetServerAccessTokenProvider("hub1", "user1").ProvideAsync();

var token = JwtTokenHelper.JwtHandler.ReadJwtToken(serverToken);

Expand Down