Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions Microsoft.Azure.Cosmos/src/Authorization/CosmosScopeProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Authorization
{
using System;
using global::Azure.Core;

internal sealed class CosmosScopeProvider : IScopeProvider
{
private const string AadInvalidScopeErrorMessage = "AADSTS500011";
private const string AadDefaultScope = "https://cosmos.azure.com/.default";
private const string ScopeFormat = "https://{0}/.default";

private readonly string accountScope;
private readonly string overrideScope;
private string currentScope;

public CosmosScopeProvider(Uri accountEndpoint)
{
this.overrideScope = ConfigurationManager.AADScopeOverrideValue(defaultValue: null);
this.accountScope = string.Format(ScopeFormat, accountEndpoint.Host);
this.currentScope = this.overrideScope ?? this.accountScope;
}

public TokenRequestContext GetTokenRequestContext()
{
return new TokenRequestContext(new[] { this.currentScope });
}

public bool TryFallback(Exception exception)
{
// If override scope is set, never fallback
if (!string.IsNullOrEmpty(this.overrideScope))
{
return false;
}

// If already using fallback scope, do not fallback again
if (this.currentScope == CosmosScopeProvider.AadDefaultScope)
{
return false;
}

#pragma warning disable CDX1003 // DontUseExceptionToString
if (exception.ToString().Contains(CosmosScopeProvider.AadInvalidScopeErrorMessage) == true)
{
this.currentScope = CosmosScopeProvider.AadDefaultScope;
return true;
}
#pragma warning restore CDX1003 // DontUseExceptionToString

return false;
}
}
}
14 changes: 14 additions & 0 deletions Microsoft.Azure.Cosmos/src/Authorization/IScopeProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Authorization
{
using System;
using global::Azure.Core;

internal interface IScopeProvider
{
TokenRequestContext GetTokenRequestContext();
bool TryFallback(Exception ex);
}
}
93 changes: 43 additions & 50 deletions Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ namespace Microsoft.Azure.Cosmos
using System.Threading;
using System.Threading.Tasks;
using global::Azure;
using global::Azure.Core;
using global::Azure.Core;
using Microsoft.Azure.Cosmos.Authorization;
using Microsoft.Azure.Cosmos.Core.Trace;
using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
using Microsoft.Azure.Cosmos.Tracing;
Expand All @@ -36,9 +37,7 @@ internal sealed class TokenCredentialCache : IDisposable
// If the background refresh fails with less than a minute then just allow the request to hit the exception.
public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1);

private const string ScopeFormat = "https://{0}/.default";

private readonly TokenRequestContext tokenRequestContext;
private readonly IScopeProvider scopeProvider;
private readonly TokenCredential tokenCredential;
private readonly CancellationTokenSource cancellationTokenSource;
private readonly CancellationToken cancellationToken;
Expand All @@ -51,7 +50,7 @@ internal sealed class TokenCredentialCache : IDisposable
private Task<AccessToken>? currentRefreshOperation = null;
private AccessToken? cachedAccessToken = null;
private bool isBackgroundTaskRunning = false;
private bool isDisposed = false;
private bool isDisposed = false;

internal TokenCredentialCache(
TokenCredential tokenCredential,
Expand All @@ -65,14 +64,7 @@ internal TokenCredentialCache(
throw new ArgumentNullException(nameof(accountEndpoint));
}

string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null);

this.tokenRequestContext = new TokenRequestContext(new string[]
{
!string.IsNullOrEmpty(scopeOverride)
? scopeOverride
: string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host)
});
this.scopeProvider = new Microsoft.Azure.Cosmos.Authorization.CosmosScopeProvider(accountEndpoint);

if (backgroundTokenCredentialRefreshInterval.HasValue)
{
Expand Down Expand Up @@ -129,7 +121,7 @@ public void Dispose()
}

this.cancellationTokenSource.Cancel();
this.cancellationTokenSource.Dispose();
this.cancellationTokenSource.Dispose();
this.isDisposed = true;
}

Expand Down Expand Up @@ -171,11 +163,13 @@ private async Task<AccessToken> GetNewTokenAsync(

private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
ITrace trace)
{
{
Exception? lastException = null;
const int totalRetryCount = 2;
TokenRequestContext tokenRequestContext = default;

try
{
Exception? lastException = null;
const int totalRetryCount = 2;
for (int retry = 0; retry < totalRetryCount; retry++)
{
if (this.cancellationToken.IsCancellationRequested)
Expand All @@ -190,11 +184,13 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
{
{
try
{
{
tokenRequestContext = this.scopeProvider.GetTokenRequestContext();

this.cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: this.tokenRequestContext,
requestContext: tokenRequestContext,
cancellationToken: this.cancellationToken);

if (!this.cachedAccessToken.HasValue)
Expand All @@ -219,32 +215,15 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(

return this.cachedAccessToken.Value;
}
catch (RequestFailedException requestFailedException)
{
lastException = requestFailedException;
getTokenTrace.AddDatum(
$"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
requestFailedException.Message);

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

// Don't retry on auth failures
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}
}
catch (OperationCanceledException operationCancelled)
{
lastException = operationCancelled;
getTokenTrace.AddDatum(
$"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelled.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
operationCancelled.Message);
DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
Expand All @@ -255,15 +234,29 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
innerException: lastException,
trace: getTokenTrace);
}
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

// Don't retry on auth failures
if (exception is RequestFailedException requestFailedException &&
(requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden))
{
this.cachedAccessToken = default;
throw;
}
bool didFallback = this.scopeProvider.TryFallback(exception);

if (didFallback)
{
DefaultTrace.TraceInformation($"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}. Fallback attempted: {didFallback}");
}
}
}
}
Expand Down
Loading
Loading