Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,27 @@ namespace Microsoft.IdentityModel.Protocols
/// </summary>
/// <typeparam name="T">The type of <see cref="IDocumentRetriever"/>.</typeparam>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1001:TypesThatOwnDisposableFieldsShouldBeDisposable")]
public class ConfigurationManager<T> : BaseConfigurationManager, IConfigurationManager<T> where T : class
public partial class ConfigurationManager<T> : BaseConfigurationManager, IConfigurationManager<T> where T : class
{
// To prevent tearing, this needs to be only updated through AtomicUpdateSyncAfter.
// Reads should be done through the property SyncAfter.
#pragma warning disable IDE0044 // Add readonly modifier
#pragma warning disable CS0649 // Unused, it gets used in tests.
internal Action _onBackgroundTaskFinish;
#pragma warning restore CS0649 // Unused
#pragma warning restore IDE0044 // Add readonly modifier

private DateTime _syncAfter = DateTime.MinValue;
private DateTime SyncAfter => _syncAfter;
private DateTime SyncAfter
{
get => _syncAfter;
set => AtomicUpdateDateTime(ref _syncAfter, ref value);
}

// See comment above, this should only be updated through AtomicUpdateLastRequestRefresh,
// read through LastRequestRefresh.
private DateTime _lastRequestRefresh = DateTime.MinValue;
private DateTime LastRequestRefresh => _lastRequestRefresh;
private DateTime LastRequestRefresh
{
get => _lastRequestRefresh;
set => AtomicUpdateDateTime(ref _lastRequestRefresh, ref value);
}

private bool _isFirstRefreshRequest = true;
private readonly SemaphoreSlim _configurationNullLock = new SemaphoreSlim(1);
Expand All @@ -53,8 +63,32 @@ public class ConfigurationManager<T> : BaseConfigurationManager, IConfigurationM
// requesting a refresh, so it should be done immediately so the next
// call to GetConfiguration will return new configuration if the minimum
// refresh interval has passed.
bool _refreshRequested;
private bool _refreshRequested;

// Wait handle used to signal a background task to update the configuration.
// Handle starts unset, and AutoResetEvent.Set() sets it, this indicates that
// the background refresh task should immediately run.
private readonly AutoResetEvent _updateMetadataEvent = new(false);

// Background task that updates the configuration. Signaled with _updateMetadataEvent.
// Task should be started with EnsureBackgroundRefreshTaskIsRunning.
private Task _updateMetadataTask;

private readonly CancellationTokenSource _backgroundRefreshTaskCancellationTokenSource;

/// <summary>
/// Requests that background tasks be shutdown.
/// This only applies if 'Switch.Microsoft.IdentityModel.UpdateConfigAsBlocking' is not set or set to false.
/// Note that this does not influence <see cref="GetConfigurationAsync(CancellationToken)"/>.
/// If the background task stops, the next time the task would be signaled, the task will be
/// restarted unless <see cref="ShutdownBackgroundTask"/> is called.
/// If using a background task, the <see cref="Configuration"/> cannot
/// be used after calling this method.
/// </summary>
public void ShutdownBackgroundTask()
{
_backgroundRefreshTaskCancellationTokenSource.Cancel();
}

/// <summary>
/// Instantiates a new <see cref="ConfigurationManager{T}"/> that manages automatic and controls refreshing on configuration data.
Expand Down Expand Up @@ -117,6 +151,10 @@ public ConfigurationManager(string metadataAddress, IConfigurationRetriever<T> c
MetadataAddress = metadataAddress;
_docRetriever = docRetriever;
_configRetriever = configRetriever;
_backgroundRefreshTaskCancellationTokenSource = new CancellationTokenSource();

if (!AppContextSwitches.UpdateConfigAsBlocking)
EnsureBackgroundRefreshTaskIsRunning();
}

/// <summary>
Expand Down Expand Up @@ -153,8 +191,14 @@ public ConfigurationManager(string metadataAddress, IConfigurationRetriever<T> c
/// <summary>
/// Obtains an updated version of Configuration.
/// </summary>
/// <returns>Configuration of type T.</returns>
/// <remarks>If the time since the last call is less than <see cref="BaseConfigurationManager.AutomaticRefreshInterval"/> then <see cref="IConfigurationRetriever{T}.GetConfigurationAsync"/> is not called and the current Configuration is returned.</remarks>
/// <returns>Configuration of type <typeparamref name="T"/>.</returns>
/// <remarks>
/// If the time since the last call is less than <see cref="BaseConfigurationManager.AutomaticRefreshInterval"/>
/// then <see cref="IConfigurationRetriever{T}.GetConfigurationAsync"/> is not called and the current Configuration is returned.
/// If the configuration is not able to be updated, but a previous configuration was previously retrieved, the previous configuration is returned.
/// </remarks>
/// <exception cref="InvalidOperationException" >Throw if the configuration is unable to be retrieved.</exception>
/// <exception cref="InvalidConfigurationException" >Throw if the configuration fails to be validated by the <see cref="IConfigurationValidator{T}"/>.</exception>
public async Task<T> GetConfigurationAsync()
{
return await GetConfigurationAsync(CancellationToken.None).ConfigureAwait(false);
Expand All @@ -164,13 +208,27 @@ public async Task<T> GetConfigurationAsync()
/// Obtains an updated version of Configuration.
/// </summary>
/// <param name="cancel">CancellationToken</param>
/// <returns>Configuration of type T.</returns>
/// <remarks>If the time since the last call is less than <see cref="BaseConfigurationManager.AutomaticRefreshInterval"/> then <see cref="IConfigurationRetriever{T}.GetConfigurationAsync"/> is not called and the current Configuration is returned.</remarks>
/// <returns>Configuration of type <typeparamref name="T"/>.</returns>
/// <remarks>
/// If the time since the last call is less than <see cref="BaseConfigurationManager.AutomaticRefreshInterval"/>
/// then <see cref="IConfigurationRetriever{T}.GetConfigurationAsync"/> is not called and the current Configuration is returned.
/// If the configuration is not able to be updated, but a previous configuration was previously retrieved, the previous configuration is returned.
/// </remarks>
/// <exception cref="InvalidOperationException" >Throw if the configuration is unable to be retrieved.</exception>
/// <exception cref="InvalidConfigurationException" >Throw if the configuration fails to be validated by the <see cref="IConfigurationValidator{T}"/>.</exception>
public virtual async Task<T> GetConfigurationAsync(CancellationToken cancel)
{
if (_currentConfiguration != null && SyncAfter > _timeProvider.GetUtcNow())
return _currentConfiguration;

if (AppContextSwitches.UpdateConfigAsBlocking)
return await GetConfigurationWithBlockingAsync(cancel).ConfigureAwait(false);
else
return await GetConfigurationWithBackgroundTaskUpdatesAsync(cancel).ConfigureAwait(false);
}

private async Task<T> GetConfigurationWithBackgroundTaskUpdatesAsync(CancellationToken cancel)
{
Exception fetchMetadataFailure = null;

// LOGIC
Expand Down Expand Up @@ -246,38 +304,10 @@ public virtual async Task<T> GetConfigurationAsync(CancellationToken cancel)
{
if (Interlocked.CompareExchange(ref _configurationRetrieverState, ConfigurationRetrieverRunning, ConfigurationRetrieverIdle) == ConfigurationRetrieverIdle)
{
if (_refreshRequested)
if (SyncAfter <= _timeProvider.GetUtcNow())
{
_refreshRequested = false;

try
{
// Log as manual because RequestRefresh was called
TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
TelemetryConstants.Protocols.Manual);
}
#pragma warning disable CA1031 // Do not catch general exception types
catch
{ }
#pragma warning restore CA1031 // Do not catch general exception types

UpdateCurrentConfiguration();
}
else if (SyncAfter <= _timeProvider.GetUtcNow())
{
try
{
TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
TelemetryConstants.Protocols.Automatic);
}
#pragma warning disable CA1031 // Do not catch general exception types
catch
{ }
#pragma warning restore CA1031 // Do not catch general exception types

_ = Task.Run(UpdateCurrentConfiguration, CancellationToken.None);
EnsureBackgroundRefreshTaskIsRunning();
_updateMetadataEvent.Set();
}
else
{
Expand All @@ -300,13 +330,61 @@ public virtual async Task<T> GetConfigurationAsync(CancellationToken cancel)
fetchMetadataFailure));
}

private void EnsureBackgroundRefreshTaskIsRunning()
{
if (_backgroundRefreshTaskCancellationTokenSource.IsCancellationRequested)
return;

if (_updateMetadataTask == null || _updateMetadataTask.Status != TaskStatus.Running)
_updateMetadataTask = Task.Run(UpdateCurrentConfigurationUsingSignals);
}

private void TelemetryForUpdate()
{
var updateMode = _refreshRequested ? TelemetryConstants.Protocols.Manual : TelemetryConstants.Protocols.Automatic;

if (_refreshRequested)
_refreshRequested = false;

try
{
TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
updateMode);
}
#pragma warning disable CA1031 // Do not catch general exception types
catch
{ }
#pragma warning restore CA1031 // Do not catch general exception types
}

private void UpdateCurrentConfigurationUsingSignals()
{
try
{
while (!_backgroundRefreshTaskCancellationTokenSource.IsCancellationRequested)
{
if (_updateMetadataEvent.WaitOne(500))
{
UpdateCurrentConfiguration();
_onBackgroundTaskFinish?.Invoke();
}
}
}
catch (Exception ex)
{
TelemetryClient.LogBackgroundConfigurationRefreshFailure(MetadataAddress, ex);
}
}

/// <summary>
/// This should be called when the configuration needs to be updated either from RequestRefresh or AutomaticRefresh
/// The Caller should first check the state checking state using:
/// if (Interlocked.CompareExchange(ref _configurationRetrieverState, ConfigurationRetrieverRunning, ConfigurationRetrieverIdle) == ConfigurationRetrieverIdle).
/// </summary>
private void UpdateCurrentConfiguration()
{
TelemetryForUpdate();
#pragma warning disable CA1031 // Do not catch general exception types
long startTimestamp = _timeProvider.GetTimestamp();

Expand Down Expand Up @@ -368,25 +446,17 @@ private void UpdateConfiguration(T configuration)
_currentConfiguration = configuration;
var newSyncTime = DateTimeUtil.Add(_timeProvider.GetUtcNow().UtcDateTime, AutomaticRefreshInterval +
TimeSpan.FromSeconds(new Random().Next((int)AutomaticRefreshInterval.TotalSeconds / 20)));
AtomicUpdateSyncAfter(newSyncTime);
SyncAfter = newSyncTime;
}

private void AtomicUpdateSyncAfter(DateTime syncAfter)
private static void AtomicUpdateDateTime(ref DateTime field, ref DateTime value)
{
// DateTime's backing data is safe to treat as a long if the Kind is not local.
// _syncAfter will always be updated to a UTC time.
// time will always be updated to a UTC time.
// See the implementation of ToBinary on DateTime.
Interlocked.Exchange(
ref Unsafe.As<DateTime, long>(ref _syncAfter),
Unsafe.As<DateTime, long>(ref syncAfter));
}

private void AtomicUpdateLastRequestRefresh(DateTime lastRequestRefresh)
{
// See the comment in AtomicUpdateSyncAfter.
Interlocked.Exchange(
ref Unsafe.As<DateTime, long>(ref _lastRequestRefresh),
Unsafe.As<DateTime, long>(ref lastRequestRefresh));
ref Unsafe.As<DateTime, long>(ref field),
Unsafe.As<DateTime, long>(ref value));
}

/// <summary>
Expand All @@ -408,15 +478,32 @@ public override async Task<BaseConfiguration> GetBaseConfigurationAsync(Cancella
/// <para>If <see cref="BaseConfigurationManager.RefreshInterval"/> == <see cref="TimeSpan.MaxValue"/> then this method does nothing.</para>
/// </summary>
public override void RequestRefresh()
{
if (AppContextSwitches.UpdateConfigAsBlocking)
{
RequestRefreshBlocking();
}
else
{
RequestRefreshBackgroundThread();
}
}

private void RequestRefreshBackgroundThread()
{
DateTime now = _timeProvider.GetUtcNow().UtcDateTime;

if (now >= DateTimeUtil.Add(LastRequestRefresh, RefreshInterval) || _isFirstRefreshRequest)
{
_isFirstRefreshRequest = false;
AtomicUpdateSyncAfter(now);
AtomicUpdateLastRequestRefresh(now);
_refreshRequested = true;

if (Interlocked.CompareExchange(ref _configurationRetrieverState, ConfigurationRetrieverRunning, ConfigurationRetrieverIdle) == ConfigurationRetrieverIdle)
{
_refreshRequested = true;
LastRequestRefresh = now;
EnsureBackgroundRefreshTaskIsRunning();
_updateMetadataEvent.Set();
}
}
}

Expand Down
Loading
Loading