Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/Microsoft.Azure.SignalR.Common/Auth/AuthUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ internal static class AuthUtility

private static readonly SignalRJwtSecurityTokenHandler JwtTokenHandler = new SignalRJwtSecurityTokenHandler();

public static string GenerateJwtBearer(byte[] keyBytes,
string? kid = null,
string? issuer = null,
string? audience = null,
IEnumerable<Claim>? claims = null,
DateTime? expires = null,
DateTime? issuedAt = null,
DateTime? notBefore = null,
AccessTokenAlgorithm algorithm = AccessTokenAlgorithm.HS256)
public static string GenerateJwtToken(byte[] keyBytes,
string? kid = null,
string? issuer = null,
string? audience = null,
IEnumerable<Claim>? claims = null,
DateTime? expires = null,
DateTime? issuedAt = null,
DateTime? notBefore = null,
AccessTokenAlgorithm algorithm = AccessTokenAlgorithm.HS256)
{
var subject = claims == null ? null : new ClaimsIdentity(claims);

Expand All @@ -52,9 +52,10 @@ public static string GenerateAccessToken(byte[] keyBytes,
{
var expire = DateTime.UtcNow.Add(lifetime);

var jwtToken = GenerateJwtBearer(
var jwtToken = GenerateJwtToken(
keyBytes,
kid,
issuer: Constants.AsrsTokenIssuer,
audience: audience,
claims: claims,
expires: expire,
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ internal static class Constants

public const string AsrsIsDiagnosticClient = "Asrs-Is-Diagnostic-Client";

public const string AsrsTokenIssuer = "azure-signalr";

public const int DefaultCloseTimeoutMilliseconds = 30000;

public static class Keys
Expand Down
26 changes: 26 additions & 0 deletions test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.Azure.SignalR.Common.Tests.Auth;

public class AccessKeyTests
{
private const string Audience = "https://localhost/aspnetclient?hub=testhub";

private const string SigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

private const string Endpoint = "https://localhost:443";

[Fact]
public async Task TestGenerateAccessToken()
{
var accessKey = new AccessKey(new Uri(Endpoint), SigningKey);
var token = await accessKey.GenerateAccessTokenAsync(Audience, [], TimeSpan.FromHours(1), AccessTokenAlgorithm.HS256);
Assert.True(TokenUtilities.TryParseIssuer(token, out var iss));
Assert.Equal(Constants.AsrsTokenIssuer, iss);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ public void TestAccessTokenTooLongThrowsException()
Assert.Equal("AccessToken must not be longer than 4K.", exception.Message);
}

[Fact]
public void TestAccessTokenHasIssuer()
{
var accessKey = new AccessKey(new Uri("http://localhost:443"), SigningKey);
var token = AuthUtility.GenerateAccessToken(accessKey.KeyBytes,
accessKey.Kid,
Audience,
[],
DefaultLifetime,
AccessTokenAlgorithm.HS256);

Assert.True(TokenUtilities.TryParseIssuer(token, out var iss));
Assert.Equal(Constants.AsrsTokenIssuer, iss);
}

[Theory]
[InlineData(null)]
[InlineData("")]
Expand All @@ -49,7 +64,7 @@ public void TestAccessTokenTooLongThrowsException()
public void TestTryParseIssuer(string? issuer)
{
var accessKey = new AccessKey(new Uri("http://localhost:443"), SigningKey);
var token = AuthUtility.GenerateJwtBearer(accessKey.KeyBytes, issuer: issuer);
var token = AuthUtility.GenerateJwtToken(accessKey.KeyBytes, issuer: issuer);

if (string.IsNullOrEmpty(issuer))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public async Task Call_NegotiateAsync_After_WithEndpoints(ServiceTransportType s
Assert.Equal(ClientEndpointUtils.GetExpectedClientEndpoint(Hub, null, randomEndpoint.Endpoint), negotiationResponse.Url);
var tokenString = negotiationResponse.AccessToken;
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);
var expectedToken = JwtTokenHelper.GenerateJwtBearer(
var expectedToken = JwtTokenHelper.GenerateJwtToken(
ClientEndpointUtils.GetExpectedClientEndpoint(Hub, null, randomEndpoint.Endpoint),
ClaimsUtility.BuildJwtClaims(null, null, null),
token.ValidTo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public async Task GenerateClientEndpoint(string userId, Claim[] claims, string a
var tokenString = negotiationResponse.AccessToken;
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

var expectedToken = JwtTokenHelper.GenerateJwtBearer(ClientEndpointUtils.GetExpectedClientEndpoint(HubName, appName, endpoints[i].Endpoint), ClaimsUtility.BuildJwtClaims(null, userId, () => claims), token.ValidTo, token.ValidFrom, token.ValidFrom, endpoints[i].AccessKey);
var expectedToken = JwtTokenHelper.GenerateJwtToken(ClientEndpointUtils.GetExpectedClientEndpoint(HubName, appName, endpoints[i].Endpoint), ClaimsUtility.BuildJwtClaims(null, userId, () => claims), token.ValidTo, token.ValidFrom, token.ValidFrom, endpoints[i].AccessKey);

Assert.Equal(ClientEndpointUtils.GetExpectedClientEndpoint(HubName, appName, endpoints[i].Endpoint), negotiationResponse.Url);
Assert.Equal(expectedToken, tokenString);
Expand Down
34 changes: 15 additions & 19 deletions test/Microsoft.Azure.SignalR.Tests.Common/JwtTokenHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public static string GenerateExpectedAccessToken(JwtSecurityToken token, string
claims.AddRange(customClaims.ToList());
}

var tokenString = GenerateJwtBearer(
var tokenString = GenerateJwtToken(
audience, claims,
token.ValidTo,
token.ValidFrom,
Expand All @@ -48,16 +48,14 @@ public static string GenerateExpectedAccessToken(JwtSecurityToken token, string
return GenerateExpectedAccessToken(token, audience, new AccessKey(new Uri(TestEndpoint), key), customClaims: customClaims);
}

public static string GenerateJwtBearer(
string audience,
IEnumerable<Claim> subject,
DateTime expires,
DateTime notBefore,
DateTime issueAt,
IAccessKey signingKey
)
public static string GenerateJwtToken(string audience,
IEnumerable<Claim> subject,
DateTime expires,
DateTime notBefore,
DateTime issueAt,
IAccessKey signingKey)
{
return AuthUtility.GenerateJwtBearer(
return AuthUtility.GenerateJwtToken(
signingKey.KeyBytes,
signingKey.Kid,
issuer: null,
Expand All @@ -69,15 +67,13 @@ IAccessKey signingKey
);
}

public static string GenerateJwtBearer(
string audience,
IEnumerable<Claim> subject,
DateTime expires,
DateTime notBefore,
DateTime issueAt,
string signingKey
)
public static string GenerateJwtToken(string audience,
IEnumerable<Claim> subject,
DateTime expires,
DateTime notBefore,
DateTime issueAt,
string signingKey)
{
return GenerateJwtBearer(audience, subject, expires, notBefore, issueAt, new AccessKey(new Uri(TestEndpoint), signingKey));
return GenerateJwtToken(audience, subject, expires, notBefore, issueAt, new AccessKey(new Uri(TestEndpoint), signingKey));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ internal async Task GenerateServerAccessToken(IServiceEndpointProvider provider)
var tokenString = await provider.GetServerAccessTokenProvider(nameof(TestHub), userId).ProvideAsync();
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/server/?hub={HubName}",
var expectedTokenString = JwtTokenHelper.GenerateJwtToken($"{Endpoint}/server/?hub={HubName}",
new[]
{
new Claim(ClaimTypes.NameIdentifier, userId)
Expand All @@ -154,7 +154,7 @@ internal async Task GenerateServerAccessTokenWithPrefix(IServiceEndpointProvider
var tokenString = await provider.GetServerAccessTokenProvider(nameof(TestHub), userId).ProvideAsync();
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/server/?hub={AppName}_{HubName}",
var expectedTokenString = JwtTokenHelper.GenerateJwtToken($"{Endpoint}/server/?hub={AppName}_{HubName}",
new[]
{
new Claim(ClaimTypes.NameIdentifier, userId)
Expand All @@ -176,7 +176,7 @@ internal async Task GenerateClientAccessToken(IServiceEndpointProvider provider)
var tokenString = await provider.GenerateClientAccessTokenAsync(HubName);
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/client/?hub={HubName}",
var expectedTokenString = JwtTokenHelper.GenerateJwtToken($"{Endpoint}/client/?hub={HubName}",
null,
token.ValidTo,
token.ValidFrom,
Expand All @@ -194,7 +194,7 @@ internal async Task GenerateClientAccessTokenWithPrefix(IServiceEndpointProvider
var tokenString = await provider.GenerateClientAccessTokenAsync(HubName);
var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/client/?hub={AppName}_{HubName}",
var expectedTokenString = JwtTokenHelper.GenerateJwtToken($"{Endpoint}/client/?hub={AppName}_{HubName}",
null,
token.ValidTo,
token.ValidFrom,
Expand Down
Loading