Skip to content

Commit b07f6a2

Browse files
committed
token provider
1 parent 712cafa commit b07f6a2

File tree

17 files changed

+174
-69
lines changed

17 files changed

+174
-69
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,6 @@ __pycache__/
297297

298298
# docker
299299
.docker/
300+
301+
# C# SDK version config
302+
global.json

src/Microsoft.Azure.SignalR.AspNet/EndpointProvider/ServiceEndpointProvider.cs

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
using System.Security.Claims;
88
using System.Text;
99
using System.Threading.Tasks;
10+
using Microsoft.Azure.SignalR.Common.Endpoints;
11+
using Microsoft.Azure.SignalR.Common.Interfaces;
1012

1113
namespace Microsoft.Azure.SignalR.AspNet
1214
{
@@ -59,23 +61,6 @@ public Task<string> GenerateClientAccessTokenAsync(string hubName = null, IEnume
5961
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
6062
}
6163

62-
public Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null)
63-
{
64-
if (_accessKey is AadAccessKey key)
65-
{
66-
return key.GenerateAadTokenAsync();
67-
}
68-
69-
if (string.IsNullOrEmpty(hubName))
70-
{
71-
throw new ArgumentNullException(nameof(hubName));
72-
}
73-
74-
var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
75-
var claims = userId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, userId) } : null;
76-
return _accessKey.GenerateAccessTokenAsync(audience, claims, lifetime ?? _accessTokenLifetime, _algorithm);
77-
}
78-
7964
public string GetClientEndpoint(string hubName = null, string originalPath = null, string queryString = null)
8065
{
8166
var queryBuilder = new StringBuilder();
@@ -110,6 +95,24 @@ public string GetServerEndpoint(string hubName)
11095
return $"{_serverEndpoint}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
11196
}
11297

98+
public IAccessTokenProvider GetAccessTokenProvider(string hubName, string serverId)
99+
{
100+
if (_accessKey is AadAccessKey aadAccessKey)
101+
{
102+
return new AadTokenProvider(aadAccessKey);
103+
}
104+
else if (_accessKey is not null)
105+
{
106+
var audience = $"{_audienceBaseUrl}{ServerPath}/?hub={GetPrefixedHubName(_appName, hubName)}";
107+
var claims = serverId != null ? new[] { new Claim(ClaimTypes.NameIdentifier, serverId) } : null;
108+
return new LocalTokenProvider(_accessKey, audience, claims, _algorithm, _accessTokenLifetime);
109+
}
110+
else
111+
{
112+
throw new ArgumentNullException(nameof(AccessKey));
113+
}
114+
}
115+
113116
private string GetPrefixedHubName(string applicationName, string hubName)
114117
{
115118
return string.IsNullOrEmpty(applicationName) ? hubName.ToLower() : $"{applicationName.ToLower()}_{hubName.ToLower()}";
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System;
5+
using System.Threading.Tasks;
6+
using Microsoft.Azure.SignalR.Common.Interfaces;
7+
8+
namespace Microsoft.Azure.SignalR.Common.Endpoints
9+
{
10+
internal class AadTokenProvider : IAccessTokenProvider
11+
{
12+
private readonly AadAccessKey _accessKey;
13+
14+
public AadTokenProvider(AadAccessKey accessKey)
15+
{
16+
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
17+
}
18+
19+
public Task<string> GenerateTokenAsync() => _accessKey.GenerateAadTokenAsync();
20+
}
21+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Security.Claims;
7+
using System.Threading.Tasks;
8+
using Microsoft.Azure.SignalR.Common.Interfaces;
9+
10+
namespace Microsoft.Azure.SignalR.Common.Endpoints
11+
{
12+
internal class LocalTokenProvider : IAccessTokenProvider
13+
{
14+
private readonly AccessKey _accessKey;
15+
16+
private readonly AccessTokenAlgorithm _algorithm;
17+
18+
private readonly string _audience;
19+
20+
private readonly TimeSpan _tokenLifetime;
21+
22+
private readonly IEnumerable<Claim> _claims;
23+
24+
public LocalTokenProvider(
25+
AccessKey accessKey,
26+
string audience,
27+
IEnumerable<Claim> claims,
28+
AccessTokenAlgorithm? algorithm = null,
29+
TimeSpan? tokenLifetime = null)
30+
{
31+
_accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey));
32+
_algorithm = algorithm ?? AccessTokenAlgorithm.HS256;
33+
_audience = audience;
34+
_claims = claims;
35+
_tokenLifetime = tokenLifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
36+
}
37+
38+
public Task<string> GenerateTokenAsync() => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm);
39+
}
40+
}

src/Microsoft.Azure.SignalR.Common/Exceptions/AzureSignalRUnauthorizedException.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
5-
using System.Runtime.Serialization;
5+
using Microsoft.Azure.SignalR.Common.Endpoints;
6+
using Microsoft.Azure.SignalR.Common.Interfaces;
67

78
namespace Microsoft.Azure.SignalR.Common
89
{
@@ -11,16 +12,23 @@ public class AzureSignalRUnauthorizedException : AzureSignalRException
1112
{
1213
private const string ErrorMessage = "Authorization failed. If you were using AccessKey, please check connection string and see if the AccessKey is correct. If you were using Azure Active Directory, please note that the role assignments will take up to 30 minutes to take effect if it was added recently.";
1314

15+
private const string ErrorMessageLocal = "Authorization failed. Please check your connection string and see if the AccessKey is correct.";
16+
17+
private const string ErrorMessageAad = "Authorization failed. Please check your role assignments and see if you have the `Signal App Server` or `SignalR Service Owner` role. Please also note that the role assignments will take up to 30 minutes to take effect.";
18+
1419
public AzureSignalRUnauthorizedException(string requestUri, Exception innerException) : base(string.IsNullOrEmpty(requestUri) ? ErrorMessage : $"{ErrorMessage} Request Uri: {requestUri}", innerException)
1520
{
1621
}
1722

18-
internal AzureSignalRUnauthorizedException(Exception innerException) : base(ErrorMessage, innerException)
23+
private AzureSignalRUnauthorizedException(Exception ex, string message) : base(message, ex)
1924
{
2025
}
2126

22-
protected AzureSignalRUnauthorizedException(SerializationInfo info, StreamingContext context) : base(info, context)
27+
internal static AzureSignalRUnauthorizedException Wraps(IAccessTokenProvider provider, Exception innerException) => provider switch
2328
{
24-
}
29+
LocalTokenProvider => new AzureSignalRUnauthorizedException(innerException, ErrorMessageLocal),
30+
AadTokenProvider => new AzureSignalRUnauthorizedException(innerException, ErrorMessageAad),
31+
_ => throw new NotSupportedException(),
32+
};
2533
}
2634
}

src/Microsoft.Azure.SignalR.Common/Exceptions/ExceptionExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
using System.Net;
66
using System.Net.Http;
77
using System.Net.WebSockets;
8-
8+
using Microsoft.Azure.SignalR.Common.Interfaces;
99
using Microsoft.Rest;
1010

1111
namespace Microsoft.Azure.SignalR.Common
@@ -38,14 +38,14 @@ internal static Exception WrapAsAzureSignalRException(this Exception e, Uri base
3838
}
3939
}
4040

41-
internal static Exception WrapAsAzureSignalRException(this Exception e)
41+
internal static Exception WrapAsAzureSignalRException(this Exception e, IAccessTokenProvider provider)
4242
{
4343
switch (e)
4444
{
4545
case WebSocketException webSocketException:
4646
if (e.Message.StartsWith("The server returned status code \"401\""))
4747
{
48-
return new AzureSignalRUnauthorizedException(webSocketException);
48+
return AzureSignalRUnauthorizedException.Wraps(provider, webSocketException);
4949
}
5050
return e;
5151
default: return e;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System.Threading.Tasks;
5+
6+
namespace Microsoft.Azure.SignalR.Common.Interfaces
7+
{
8+
public interface IAccessTokenProvider
9+
{
10+
Task<string> GenerateTokenAsync();
11+
}
12+
}

src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceEndpointProvider.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Net;
77
using System.Security.Claims;
88
using System.Threading.Tasks;
9+
using Microsoft.Azure.SignalR.Common.Interfaces;
910

1011
namespace Microsoft.Azure.SignalR
1112
{
@@ -15,7 +16,7 @@ internal interface IServiceEndpointProvider
1516

1617
string GetClientEndpoint(string hubName, string originalPath, string queryString);
1718

18-
Task<string> GenerateServerAccessTokenAsync(string hubName, string userId, TimeSpan? lifetime = null);
19+
IAccessTokenProvider GetAccessTokenProvider(string hubName, string serverId);
1920

2021
string GetServerEndpoint(string hubName);
2122

src/Microsoft.Azure.SignalR.Common/RestClients/SignalRServiceRestClient.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ public partial class SignalRServiceRestClient
1515
{
1616
private readonly string _userAgent;
1717

18-
public SignalRServiceRestClient(string userAgent, ServiceClientCredentials credentials, HttpClient httpClient, bool disposeHttpClient) : this(credentials, httpClient, disposeHttpClient)
18+
public SignalRServiceRestClient(string userAgent,
19+
ServiceClientCredentials credentials,
20+
HttpClient httpClient,
21+
bool disposeHttpClient) : this(credentials, httpClient, disposeHttpClient)
1922
{
2023
if (string.IsNullOrWhiteSpace(userAgent))
2124
{

src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactory.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace Microsoft.Azure.SignalR
1414
internal class ConnectionFactory : IConnectionFactory
1515
{
1616
private readonly ILoggerFactory _loggerFactory;
17+
1718
private readonly string _serverId;
1819

1920
public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory loggerFactory)
@@ -31,7 +32,9 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint hubServiceE
3132
{
3233
var provider = hubServiceEndpoint.Provider;
3334
var hubName = hubServiceEndpoint.Hub;
34-
Task<string> accessTokenProvider() => provider.GenerateServerAccessTokenAsync(hubName, _serverId);
35+
36+
var accessTokenProvider = provider.GetAccessTokenProvider(hubName, _serverId);
37+
3538
var url = GetServiceUrl(provider, hubName, connectionId, target);
3639

3740
headers ??= new Dictionary<string, string>();
@@ -91,6 +94,7 @@ private Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, str
9194
private sealed class GracefulLoggerFactory : ILoggerFactory
9295
{
9396
private readonly ILoggerFactory _inner;
97+
9498
public GracefulLoggerFactory(ILoggerFactory inner)
9599
{
96100
_inner = inner;
@@ -115,6 +119,7 @@ public void AddProvider(ILoggerProvider provider)
115119
private sealed class GracefulLogger : ILogger
116120
{
117121
private readonly ILogger _inner;
122+
118123
public GracefulLogger(ILogger inner)
119124
{
120125
_inner = inner;

0 commit comments

Comments
 (0)