Skip to content

Commit 621886c

Browse files
authored
Add authtypes (azure, azure.msi, azure.app) (#1616)
1 parent de6672f commit 621886c

File tree

2 files changed

+120
-38
lines changed

2 files changed

+120
-38
lines changed

src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs

Lines changed: 98 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ internal static class ConnectionStringParser
1616

1717
private const string ClientCertProperty = "clientCert";
1818

19-
private const string ClientEndpointProperty = "ClientEndpoint";
19+
private const string ClientEndpointProperty = "clientEndpoint";
2020

2121
private const string ClientIdProperty = "clientId";
2222

@@ -35,58 +35,57 @@ internal static class ConnectionStringParser
3535

3636
private const string TenantIdProperty = "tenantId";
3737

38+
private const string TypeAzure = "azure";
39+
40+
private const string TypeAzureAD = "aad";
41+
42+
private const string TypeAzureApp = "azure.app";
43+
44+
private const string TypeAzureMsi = "azure.msi";
45+
3846
private const string ValidVersionRegex = "^" + SupportedVersion + @"\.\d+(?:[\w-.]+)?$";
3947

4048
private const string VersionProperty = "version";
4149

42-
private static readonly string InvalidPortValue = $"Invalid value for {PortProperty} property.";
50+
private static readonly string InvalidClientEndpointProperty = $"Invalid value for {ClientEndpointProperty} property, it must be a valid URI.";
51+
52+
private static readonly string InvalidEndpointProperty = $"Invalid value for {EndpointProperty} property, it must be a valid URI.";
53+
54+
private static readonly string InvalidPortValue = $"Invalid value for {PortProperty} property, it must be an positive integer between (0, 65536)";
4355

4456
private static readonly char[] KeyValueSeparator = { '=' };
4557

4658
private static readonly string MissingAccessKeyProperty =
4759
$"{AccessKeyProperty} is required.";
4860

61+
private static readonly string MissingClientIdProperty =
62+
$"Connection string missing required properties {ClientIdProperty}.";
63+
4964
private static readonly string MissingClientSecretProperty =
5065
$"Connection string missing required properties {ClientSecretProperty} or {ClientCertProperty}.";
5166

5267
private static readonly string MissingEndpointProperty =
5368
$"Connection string missing required properties {EndpointProperty}.";
5469

70+
private static readonly string MissingTenantIdProperty =
71+
$"Connection string missing required properties {TenantIdProperty}.";
72+
5573
private static readonly char[] PropertySeparator = { ';' };
5674

5775
internal static ParsedConnectionString Parse(string connectionString)
5876
{
59-
var properties = connectionString.Split(PropertySeparator, StringSplitOptions.RemoveEmptyEntries);
60-
if (properties.Length < 2)
61-
{
62-
throw new ArgumentException(MissingEndpointProperty, nameof(connectionString));
63-
}
64-
65-
var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
66-
foreach (var property in properties)
67-
{
68-
var kvp = property.Split(KeyValueSeparator, 2);
69-
if (kvp.Length != 2) continue;
70-
71-
var key = kvp[0].Trim();
72-
if (dict.ContainsKey(key))
73-
{
74-
throw new ArgumentException($"Duplicate properties found in connection string: {key}.");
75-
}
76-
77-
dict.Add(key, kvp[1].Trim());
78-
}
77+
var dict = ToDictionary(connectionString);
7978

8079
// parse and validate endpoint.
8180
if (!dict.TryGetValue(EndpointProperty, out var endpoint))
8281
{
83-
throw new ArgumentException(MissingEndpointProperty, nameof(connectionString));
82+
throw new ArgumentException(MissingEndpointProperty, nameof(endpoint));
8483
}
8584
endpoint = endpoint.TrimEnd('/');
8685

8786
if (!TryGetEndpointUri(endpoint, out var endpointUri))
8887
{
89-
throw new ArgumentException($"Endpoint property in connection string is not a valid URI: {dict[EndpointProperty]}.");
88+
throw new ArgumentException(InvalidEndpointProperty, nameof(endpoint));
9089
}
9190
var builder = new UriBuilder(endpointUri);
9291

@@ -96,21 +95,21 @@ internal static ParsedConnectionString Parse(string connectionString)
9695
{
9796
if (!Regex.IsMatch(v, ValidVersionRegex))
9897
{
99-
throw new ArgumentException(string.Format(InvalidVersionValueFormat, v), nameof(connectionString));
98+
throw new ArgumentException(string.Format(InvalidVersionValueFormat, v), nameof(version));
10099
}
101100
version = v;
102101
}
103102

104103
// parse and validate port.
105104
if (dict.TryGetValue(PortProperty, out var s))
106105
{
107-
if (int.TryParse(s, out var p) && p > 0 && p <= 0xFFFF)
106+
if (int.TryParse(s, out var port) && port > 0 && port <= 0xFFFF)
108107
{
109-
builder.Port = p;
108+
builder.Port = port;
110109
}
111110
else
112111
{
113-
throw new ArgumentException(InvalidPortValue, nameof(connectionString));
112+
throw new ArgumentException(InvalidPortValue, nameof(port));
114113
}
115114
}
116115

@@ -121,14 +120,18 @@ internal static ParsedConnectionString Parse(string connectionString)
121120
{
122121
if (!TryGetEndpointUri(clientEndpoint, out clientEndpointUri))
123122
{
124-
throw new ArgumentException($"{ClientEndpointProperty} property in connection string is not a valid URI: {clientEndpoint}.");
123+
throw new ArgumentException(InvalidClientEndpointProperty, nameof(clientEndpoint));
125124
}
126125
}
127126

127+
// try building accesskey.
128128
dict.TryGetValue(AuthTypeProperty, out var type);
129129
var accessKey = type?.ToLower() switch
130130
{
131-
"aad" => BuildAadAccessKey(builder.Uri, dict),
131+
TypeAzureAD => BuildAadAccessKey(builder.Uri, dict),
132+
TypeAzure => BuildAzureAccessKey(builder.Uri, dict),
133+
TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, dict),
134+
TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, dict),
132135
_ => BuildAccessKey(builder.Uri, dict),
133136
};
134137

@@ -194,5 +197,70 @@ private static AccessKey BuildAccessKey(Uri uri, Dictionary<string, string> dict
194197
}
195198
throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty);
196199
}
200+
201+
private static AccessKey BuildAzureAccessKey(Uri uri, Dictionary<string, string> dict)
202+
{
203+
return new AadAccessKey(uri, new DefaultAzureCredential());
204+
}
205+
206+
private static AccessKey BuildAzureAppAccessKey(Uri uri, Dictionary<string, string> dict)
207+
{
208+
if (!dict.TryGetValue(ClientIdProperty, out var clientId))
209+
{
210+
throw new ArgumentException(MissingClientIdProperty, ClientIdProperty);
211+
}
212+
213+
if (!dict.TryGetValue(TenantIdProperty, out var tenantId))
214+
{
215+
throw new ArgumentException(MissingTenantIdProperty, TenantIdProperty);
216+
}
217+
218+
if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
219+
{
220+
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret));
221+
}
222+
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
223+
{
224+
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath));
225+
}
226+
throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty);
227+
}
228+
229+
private static AccessKey BuildAzureMsiAccessKey(Uri uri, Dictionary<string, string> dict)
230+
{
231+
if (dict.TryGetValue(ClientIdProperty, out var clientId))
232+
{
233+
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId));
234+
}
235+
return new AadAccessKey(uri, new ManagedIdentityCredential());
236+
}
237+
238+
private static Dictionary<string, string> ToDictionary(string connectionString)
239+
{
240+
var properties = connectionString.Split(PropertySeparator, StringSplitOptions.RemoveEmptyEntries);
241+
if (properties.Length < 2)
242+
{
243+
throw new ArgumentException(MissingEndpointProperty, nameof(connectionString));
244+
}
245+
246+
var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
247+
foreach (var property in properties)
248+
{
249+
var kvp = property.Split(KeyValueSeparator, 2);
250+
if (kvp.Length != 2)
251+
{
252+
continue;
253+
}
254+
255+
var key = kvp[0].Trim();
256+
if (dict.ContainsKey(key))
257+
{
258+
throw new ArgumentException($"Duplicate properties found in connection string: {key}.");
259+
}
260+
261+
dict.Add(key, kvp[1].Trim());
262+
}
263+
return dict;
264+
}
197265
}
198266
}

test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public static IEnumerable<object[]> ServerEndpointTestData
3838

3939
[Theory]
4040
[InlineData("endpoint=https://aaa;AuthType=aad;clientId=123;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")]
41+
[InlineData("endpoint=https://aaa;AuthType=azure.app;clientId=123;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")]
4142
public void InvalidAzureApplication(string connectionString)
4243
{
4344
var exception = Assert.Throws<ArgumentException>(() => ConnectionStringParser.Parse(connectionString));
@@ -50,7 +51,7 @@ public void InvalidAzureApplication(string connectionString)
5051
public void InvalidClientEndpoint(string connectionString)
5152
{
5253
var exception = Assert.Throws<ArgumentException>(() => ConnectionStringParser.Parse(connectionString));
53-
Assert.Contains("ClientEndpoint property in connection string is not a valid URI", exception.Message);
54+
Assert.Contains("Invalid value for clientEndpoint property, it must be a valid URI. (Parameter 'clientEndpoint')", exception.Message);
5455
}
5556

5657
[Theory]
@@ -70,7 +71,7 @@ public void InvalidConnectionStrings(string connectionString)
7071
public void InvalidEndpoint(string connectionString)
7172
{
7273
var exception = Assert.Throws<ArgumentException>(() => ConnectionStringParser.Parse(connectionString));
73-
Assert.Contains("Endpoint property in connection string is not a valid URI", exception.Message);
74+
Assert.Contains("Invalid value for endpoint property, it must be a valid URI. (Parameter 'endpoint')", exception.Message);
7475
}
7576

7677
[Theory]
@@ -80,7 +81,7 @@ public void InvalidEndpoint(string connectionString)
8081
public void InvalidPort(string connectionString)
8182
{
8283
var exception = Assert.Throws<ArgumentException>(() => ConnectionStringParser.Parse(connectionString));
83-
Assert.Contains("Invalid value for port property.", exception.Message);
84+
Assert.Contains("Invalid value for port property, it must be an positive integer between (0, 65536) (Parameter 'port')", exception.Message);
8485
}
8586

8687
[Theory]
@@ -95,6 +96,7 @@ public void InvalidVersion(string connectionString, string version)
9596

9697
[Theory]
9798
[InlineData("endpoint=https://aaa;AuthType=aad;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")]
99+
[InlineData("endpoint=https://aaa;AuthType=azure.app;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")]
98100
public void TestAzureApplication(string connectionString)
99101
{
100102
var r = ConnectionStringParser.Parse(connectionString);
@@ -138,14 +140,26 @@ public void TestVersion(string connectionString, string expectedVersion)
138140
}
139141

140142
[Theory]
141-
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;")]
143+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientId=xxxx;")] // should ignore the clientId
144+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;tenantId=xxxx;")] // should ignore the tenantId
145+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientSecret=xxxx;")] // should ignore the clientSecret
146+
internal void TestDefaultAzureCredential(string expectedEndpoint, string connectionString)
147+
{
148+
var r = ConnectionStringParser.Parse(connectionString);
142149

143-
// simply ignore the clientSecret
144-
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientSecret=xxxx;")]
150+
Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/'));
151+
var aadAccessKey = Assert.IsType<AadAccessKey>(r.AccessKey);
152+
Assert.IsType<DefaultAzureCredential>(aadAccessKey.TokenCredential);
153+
Assert.Same(r.Endpoint, r.AccessKey.Endpoint);
154+
}
145155

146-
// simply ignore the tenantId
147-
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;tenantId=xxxx;")]
156+
[Theory]
157+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;")]
148158
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientId=123;")]
159+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;tenantId=xxxx;")] // should ignore the tenantId
160+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientSecret=xxxx;")] // should ignore the clientSecret
161+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;")]
162+
[InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;clientId=123;")]
149163
internal void TestManagedIdentity(string expectedEndpoint, string connectionString)
150164
{
151165
var r = ConnectionStringParser.Parse(connectionString);

0 commit comments

Comments
 (0)