Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,49 @@
using System;
using Azure.Core;

namespace Microsoft.Azure.SignalR.Common
namespace Microsoft.Azure.SignalR.Common;

/// <summary>
/// The exception throws when AccessKey is not authorized.
/// </summary>
public class AzureSignalRAccessTokenNotAuthorizedException : AzureSignalRException
{
private const string Template = "{0} is not available for signing client tokens, {1}";

/// <summary>
/// The exception throws when AccessKey is not authorized.
/// Obsolete, <see cref="AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, Exception)"/>.
/// </summary>
public class AzureSignalRAccessTokenNotAuthorizedException : AzureSignalRException
/// <param name="message"></param>
[Obsolete]
public AzureSignalRAccessTokenNotAuthorizedException(string message) : base(message)
{
private const string Postfix = " appears to lack the permission to generate access tokens, see innerException for more details.";
}

/// <summary>
/// Obsolete, <see cref="AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, Exception)"/>.
/// </summary>
/// <param name="message"></param>
[Obsolete]
public AzureSignalRAccessTokenNotAuthorizedException(string message) : base(message)
{
}
/// <summary>
/// Obsolete, <see cref="AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, Exception)"/>.
/// </summary>
/// <param name="credentialName"></param>
/// <param name="inner"></param>
[Obsolete]
public AzureSignalRAccessTokenNotAuthorizedException(string credentialName, Exception inner) :
base(string.Format(Template, credentialName, GetInnerReason(inner)), inner)
{
}

/// <summary>
/// Obsolete, <see cref="AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, Exception)"/>.
/// </summary>
/// <param name="credentialName"></param>
/// <param name="innerException"></param>
[Obsolete]
public AzureSignalRAccessTokenNotAuthorizedException(string credentialName, Exception innerException) :
base(credentialName + Postfix, innerException)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="AzureSignalRAccessTokenNotAuthorizedException"/> class.
/// </summary>
internal AzureSignalRAccessTokenNotAuthorizedException(TokenCredential credential, Exception inner) :
base(string.Format(Template, credential.GetType().Name, GetInnerReason(inner)), inner)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="AzureSignalRAccessTokenNotAuthorizedException"/> class.
/// </summary>
internal AzureSignalRAccessTokenNotAuthorizedException(TokenCredential credential, Exception innerException) :
base(credential.GetType().Name + Postfix, innerException)
private static string GetInnerReason(Exception exception)
{
return exception switch
{
}
AzureSignalRUnauthorizedException => AzureSignalRUnauthorizedException.ErrorMessageMicrosoftEntra,
_ => exception.Message,
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

using System;
using System.Net;
using System.Runtime.Serialization;

namespace Microsoft.Azure.SignalR.Common;

#nullable enable

[Serializable]
public class AzureSignalRRuntimeException : AzureSignalRException
{
Expand All @@ -15,20 +18,20 @@ public class AzureSignalRRuntimeException : AzureSignalRException

internal HttpStatusCode StatusCode { get; private set; } = HttpStatusCode.InternalServerError;

public AzureSignalRRuntimeException(string requestUri, Exception inner) : base(GetExceptionMessage(HttpStatusCode.InternalServerError, requestUri, string.Empty), inner)
public AzureSignalRRuntimeException(string? requestUri, Exception inner) : base(GetExceptionMessage(HttpStatusCode.InternalServerError, requestUri, string.Empty), inner)
{
}

internal AzureSignalRRuntimeException(string requestUri, Exception inner, HttpStatusCode statusCode, string content) : base(GetExceptionMessage(statusCode, requestUri, content), inner)
internal AzureSignalRRuntimeException(string? requestUri, Exception inner, HttpStatusCode statusCode, string? content) : base(GetExceptionMessage(statusCode, requestUri, content), inner)
{
StatusCode = statusCode;
}

private static string GetExceptionMessage(HttpStatusCode statusCode, string requestUri, string content)
private static string GetExceptionMessage(HttpStatusCode statusCode, string? requestUri, string? content)
{
var reason = statusCode switch
{
HttpStatusCode.Forbidden => GetForbiddenReason(content),
HttpStatusCode.Forbidden => GetForbiddenReason(content ?? string.Empty),
_ => ErrorMessage,
};
return string.IsNullOrEmpty(requestUri) ? reason : $"{reason} Request Uri: {requestUri}";
Expand All @@ -38,4 +41,11 @@ private static string GetForbiddenReason(string content)
{
return content.Contains("nginx") ? NetworkErrorMessage : content;
}

#if NET8_0_OR_GREATER
[Obsolete]
#endif
protected AzureSignalRRuntimeException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,32 @@
using System;
using System.Runtime.Serialization;

namespace Microsoft.Azure.SignalR.Common
namespace Microsoft.Azure.SignalR.Common;

[Serializable]
public class AzureSignalRUnauthorizedException : AzureSignalRException
{
[Serializable]
public class AzureSignalRUnauthorizedException : AzureSignalRException
{
internal const string ErrorMessage = "Authorization failed. If you were using AccessKey, please check connection string and see if the AccessKey is correct. If you were using Azure Active Directory, please note that the role assignments will take up to 30 minutes to take effect if it was added recently.";
/// <summary>
/// TODO: Try parse token issuer to identify the possible cause of 401
/// </summary>
internal const string ErrorMessage = $"401 Unauthorized. If you were using AccessKey, {ErrorMessageLocalAuth}. If you were using Microsoft Entra authentication, {ErrorMessageMicrosoftEntra}";

internal const string ErrorMessageLocalAuth = "please check the connection string and see if the AccessKey is correct.";

public AzureSignalRUnauthorizedException(string requestUri, Exception innerException) : base(string.IsNullOrEmpty(requestUri) ? ErrorMessage : $"{ErrorMessage} Request Uri: {requestUri}", innerException)
{
}
internal const string ErrorMessageMicrosoftEntra = "please check if you set the AzureAuthorityHosts properly in your TokenCredentialOptions, see \"https://learn.microsoft.com/en-us/dotnet/api/azure.identity.azureauthorityhosts\".";

internal AzureSignalRUnauthorizedException(Exception innerException) : base(ErrorMessage, innerException)
{
}
public AzureSignalRUnauthorizedException(string requestUri, Exception innerException) : base(string.IsNullOrEmpty(requestUri) ? ErrorMessage : $"{ErrorMessage} Request Uri: {requestUri}", innerException)
{
}

internal AzureSignalRUnauthorizedException(Exception innerException) : base(ErrorMessage, innerException)
{
}

#if NET8_0_OR_GREATER
[Obsolete]
[Obsolete]
#endif
protected AzureSignalRUnauthorizedException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
protected AzureSignalRUnauthorizedException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
Expand Down Expand Up @@ -129,14 +131,15 @@ public async Task TestInitializeFailed()
Assert.IsType<InvalidOperationException>(exception.InnerException);
}

[Fact]
public async Task TestUpdateAccessKeyAfterInitializeFailed()
[Theory]
[ClassData(typeof(NotAuthorizedTestData))]
public async Task TestUpdateAccessKeyFailedThrowsNotAuthorizedException(AzureSignalRException e, string expectedErrorMessage)
{
var mockCredential = new Mock<TokenCredential>();
mockCredential.Setup(credential => credential.GetTokenAsync(
It.IsAny<TokenRequestContext>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception"));
.ThrowsAsync(e);
var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object);

var audience = "http://localhost/chat";
Expand All @@ -149,10 +152,11 @@ public async Task TestUpdateAccessKeyAfterInitializeFailed()
var exception = await Assert.ThrowsAsync<AzureSignalRAccessTokenNotAuthorizedException>(
async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm)
);
Assert.IsType<InvalidOperationException>(exception.InnerException);
Assert.Same(exception.InnerException, e);
Assert.Same(exception.InnerException, key.LastException);
Assert.StartsWith($"TokenCredentialProxy is not available for signing client tokens", exception.Message);
Assert.Contains(expectedErrorMessage, exception.Message);

Assert.NotNull(key.LastException);
var (kid, accessKey) = ("foo", DefaultSigningKey);
key.UpdateAccessKey(kid, accessKey);
Assert.Null(key.LastException);
Expand Down Expand Up @@ -233,6 +237,21 @@ public async Task ThrowForbiddenExceptionTest(string expected, string responseCo
Assert.EndsWith($"Request Uri: {endpoint}", ex.Message);
}

public class NotAuthorizedTestData : IEnumerable<object[]>
{
private const string DefaultUri = "https://microsoft.com";

public IEnumerator<object[]> GetEnumerator()
{
yield return [new AzureSignalRUnauthorizedException(new Exception()), AzureSignalRUnauthorizedException.ErrorMessageMicrosoftEntra];
yield return [new AzureSignalRRuntimeException(DefaultUri, new Exception(), HttpStatusCode.Forbidden, "nginx"), AzureSignalRRuntimeException.NetworkErrorMessage];
yield return [new AzureSignalRRuntimeException(DefaultUri, new Exception(), HttpStatusCode.Forbidden, "http-content"), "http-content"];
yield return [new AzureSignalRRuntimeException(DefaultUri, new Exception("inner-exception-message"), HttpStatusCode.NotFound, "http"), AzureSignalRRuntimeException.ErrorMessage];
}

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

private sealed class TestHttpClientFactory(HttpResponseMessage message) : IHttpClientFactory
{
public HttpClient CreateClient(string name)
Expand Down
Loading