diff --git a/generator/.DevConfigs/3f64412c-8c12-48ef-a8d6-11747964c162.json b/generator/.DevConfigs/3f64412c-8c12-48ef-a8d6-11747964c162.json new file mode 100644 index 000000000000..d97ecc0ecc62 --- /dev/null +++ b/generator/.DevConfigs/3f64412c-8c12-48ef-a8d6-11747964c162.json @@ -0,0 +1,9 @@ +{ + "core": { + "updateMinimum": true, + "type": "patch", + "changeLogMessages": [ + "Fix SynchronizationLockException exception when using IMDS for credentials (https://github.com/aws/aws-sdk-net/issues/4199)" + ] + } +} \ No newline at end of file diff --git a/sdk/src/Core/Amazon.Runtime/Credentials/DefaultInstanceProfileAWSCredentials.cs b/sdk/src/Core/Amazon.Runtime/Credentials/DefaultInstanceProfileAWSCredentials.cs index 1490d56f224c..89a1b2fd2c20 100644 --- a/sdk/src/Core/Amazon.Runtime/Credentials/DefaultInstanceProfileAWSCredentials.cs +++ b/sdk/src/Core/Amazon.Runtime/Credentials/DefaultInstanceProfileAWSCredentials.cs @@ -1,8 +1,8 @@ -/* +/* * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. * A copy of the License is located at * * http://aws.amazon.com/apache2.0 @@ -16,7 +16,6 @@ using Amazon.Runtime.Internal.Util; using Amazon.Util; using System; -using System.Globalization; using System.Threading; using System.Threading.Tasks; @@ -29,22 +28,29 @@ namespace Amazon.Runtime internal class DefaultInstanceProfileAWSCredentials : AWSCredentials, IDisposable { private static readonly object _instanceLock = new object(); - private readonly ReaderWriterLockSlim _credentialsLock = new ReaderWriterLockSlim(); // Lock to control getting credentials across multiple threads. - private readonly Timer _credentialsRetrieverTimer; - private RefreshingAWSCredentials.CredentialsRefreshState _lastRetrievedCredentials; - private Logger _logger; + // Semaphore used to ensure only one thread/task access IMDS at a time. + // In the GetCredentials call once the lock is obtained, a second check is made + // to see if the credentials have already been refreshed so only one access to IMDS + // is required across multiple threads/tasks when credentials are expired. + private readonly SemaphoreSlim _credentialsSemaphore = new SemaphoreSlim(initialCount: 1, maxCount: 1); - private static readonly TimeSpan _neverTimespan = TimeSpan.FromMilliseconds(-1); - private static readonly TimeSpan _refreshRate = TimeSpan.FromMinutes(2); // EC2 refreshes credentials 2 min before expiration - private const string FailedToGetCredentialsMessage = "Failed to retrieve credentials from EC2 Instance Metadata Service."; - private static readonly TimeSpan _credentialsLockTimeout = TimeSpan.FromMinutes(1); + private readonly Timer _credentialsRetrieverTimer; + private volatile RefreshingAWSCredentials.CredentialsRefreshState _lastRetrievedCredentials; /// /// Control flag: in the event IMDS returns an expired credential, a refresh must be immediately /// retried, if it continues to fail, then retry every 5-10 minutes. /// - private static volatile bool _imdsRefreshFailed = false; + private volatile bool _previousRefreshFailed = false; + + private readonly IIMDSAccessMethods _imdsAccessMethods; + private readonly TimeSpan _refreshRate = TimeSpan.FromMinutes(2); // EC2 refreshes credentials 2 min before expiration + + private static readonly TimeSpan _neverTimespan = TimeSpan.FromMilliseconds(-1); + private const string FailedToGetCredentialsMessage = "Failed to retrieve credentials from EC2 Instance Metadata Service."; + private const string FailedToGetLockMessage = "Failed to obtain lock to refresh credentials from EC2 Instance Metadata Service."; + private static readonly TimeSpan _credentialsLockTimeout = TimeSpan.FromMinutes(1); private const string _usingExpiredCredentialsFromIMDS = "Attempting credential expiration extension due to a credential service availability issue. " + @@ -56,7 +62,7 @@ public static DefaultInstanceProfileAWSCredentials Instance { get { - CheckIsIMDSEnabled(); + IMDSAccessMethods.DefaultInstance.CheckIsIMDSEnabled(); if (_instance == null) { @@ -75,97 +81,109 @@ public static DefaultInstanceProfileAWSCredentials Instance private DefaultInstanceProfileAWSCredentials() { + _imdsAccessMethods = new IMDSAccessMethods(); + // if IMDS is turned off, no need to spin up the timer task if (!EC2InstanceMetadata.IsIMDSEnabled) { return; } - _logger = Logger.GetLogger(typeof(DefaultInstanceProfileAWSCredentials)); + _credentialsRetrieverTimer = new Timer(RenewCredentials, null, TimeSpan.Zero, _neverTimespan); // This invokes synchronous calls in seperate thread. + FeatureIdSources.Add(UserAgentFeatureId.CREDENTIALS_IMDS); + } + + internal static DefaultInstanceProfileAWSCredentials CreateTestDefaultInstanceProfileAWSCredentials( + IIMDSAccessMethods imdsAccessMethods, + TimeSpan refreshRate) + { + return new DefaultInstanceProfileAWSCredentials( + imdsAccessMethods, + refreshRate); + } + + private DefaultInstanceProfileAWSCredentials( + IIMDSAccessMethods imdsAccessMethods, + TimeSpan refreshRate) + { + _imdsAccessMethods = imdsAccessMethods; + _refreshRate = refreshRate; + _credentialsRetrieverTimer = new Timer(RenewCredentials, null, TimeSpan.Zero, _neverTimespan); // This invokes synchronous calls in seperate thread. FeatureIdSources.Add(UserAgentFeatureId.CREDENTIALS_IMDS); } #region Overrides - + /// /// Returns a copy of the most recent instance profile credentials. /// public override ImmutableCredentials GetCredentials() { - CheckIsIMDSEnabled(); - ImmutableCredentials credentials = null; + // The logger instance is not cached as member variable because the static constructor + // can fire before the adaptor plug ins have been registered causing them to not + // send the logs to the correct destination. + var logger = Logger.GetLogger(typeof(DefaultInstanceProfileAWSCredentials)); + _imdsAccessMethods.CheckIsIMDSEnabled(); + ImmutableCredentials credentials; - // Try to acquire read lock. The thread would be blocked if another thread has write lock. - if (_credentialsLock.TryEnterReadLock(_credentialsLockTimeout)) + // Copy current assignment of _lastRetrievedCredentials into a local variable in case another thread updates it. + var localLastRetrievedCredentials = _lastRetrievedCredentials; + + if (null != localLastRetrievedCredentials && !localLastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + { + credentials = localLastRetrievedCredentials.Credentials; + } + else { - try + logger.DebugFormat("Waiting on lock to refresh ECS IMDS"); + if (_credentialsSemaphore.Wait(_credentialsLockTimeout)) { - if (null != _lastRetrievedCredentials) + try { - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero) && - !_imdsRefreshFailed) - { - // this is the first failure - immediately try to renew - _imdsRefreshFailed = true; - _lastRetrievedCredentials = FetchCredentials(); - } + logger.DebugFormat("Obtained lock to refresh ECS IMDS"); - // if credentials are expired, we'll still return them, but log a message about - // them being expired. - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + // Check to see if another thread has already refreshed the credentials + localLastRetrievedCredentials = _lastRetrievedCredentials; + if (null != localLastRetrievedCredentials && !localLastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) { - _logger.InfoFormat(_usingExpiredCredentialsFromIMDS); + logger.DebugFormat("Another thread has refreshed credentials and reusing those credentials"); + credentials = localLastRetrievedCredentials.Credentials; } else { - _imdsRefreshFailed = false; + logger.DebugFormat("Fetching credentials from ECS IMDS"); + _lastRetrievedCredentials = _imdsAccessMethods.FetchCredentials(); + // if credentials are expired, we'll still return them, but log a message about + // them being expired. + if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + { + // If the previous refresh did not fail then this is the first failure and write the single + // expired warning log message for this specific incident. + if (!_previousRefreshFailed) + { + logger.InfoFormat(_usingExpiredCredentialsFromIMDS); + } + + _previousRefreshFailed = true; + } + else + { + _previousRefreshFailed = false; + } + + credentials = _lastRetrievedCredentials.Credentials; } - - return _lastRetrievedCredentials?.Credentials; - } - } - finally - { - _credentialsLock.ExitReadLock(); - } - } - - // If there's no credentials cached, hit IMDS directly. Try to acquire write lock. - if (_credentialsLock.TryEnterWriteLock(_credentialsLockTimeout)) - { - try - { - // Check for last retrieved credentials again in case other thread might have already fetched it. - if (null == _lastRetrievedCredentials) - { - _lastRetrievedCredentials = FetchCredentials(); } - - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero) && - !_imdsRefreshFailed) + finally { - // this is the first failure - immediately try to renew - _imdsRefreshFailed = true; - _lastRetrievedCredentials = FetchCredentials(); + logger.DebugFormat("Releasing lock after refreshing ECS IMDS"); + SafeReleaseSemaphore(); } - - // if credentials are expired, we'll still return them, but log a message about - // them being expired. - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) - { - _logger.InfoFormat(_usingExpiredCredentialsFromIMDS); - } - else - { - _imdsRefreshFailed = false; - } - - credentials = _lastRetrievedCredentials.Credentials; } - finally + else { - _credentialsLock.ExitWriteLock(); + throw new AmazonServiceException(FailedToGetLockMessage); } } @@ -182,79 +200,70 @@ public override ImmutableCredentials GetCredentials() /// public override async Task GetCredentialsAsync() { - CheckIsIMDSEnabled(); - ImmutableCredentials credentials = null; + // The logger instance is not cached as member variable because the static constructor + // can fire before the adaptor plug ins have been registered causing them to not + // send the logs to the correct destination. + var logger = Logger.GetLogger(typeof(DefaultInstanceProfileAWSCredentials)); + _imdsAccessMethods.CheckIsIMDSEnabled(); + ImmutableCredentials credentials; + + // Copy current assignment of _lastRetrievedCredentials into a local variable in case another thread updates it. + var localLastRetrievedCredentials = _lastRetrievedCredentials; - // Try to acquire read lock. The thread would be blocked if another thread has write lock. - if (_credentialsLock.TryEnterReadLock(_credentialsLockTimeout)) + if (null != localLastRetrievedCredentials && !localLastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) { - try + credentials = localLastRetrievedCredentials.Credentials; + } + else + { + logger.DebugFormat("Waiting on lock to refresh ECS IMDS"); + if (await _credentialsSemaphore.WaitAsync(_credentialsLockTimeout).ConfigureAwait(false)) { - if (null != _lastRetrievedCredentials) + try { - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero) && - !_imdsRefreshFailed) - { - // this is the first failure - immediately try to renew - _imdsRefreshFailed = true; - _lastRetrievedCredentials = await FetchCredentialsAsync().ConfigureAwait(false); - } + logger.DebugFormat("Obtained lock to refresh ECS IMDS"); - // if credentials are expired, we'll still return them, but log a message about - // them being expired. - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + // Check to see if another thread has already refreshed the credentials + localLastRetrievedCredentials = _lastRetrievedCredentials; + if (null != localLastRetrievedCredentials && !localLastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) { - _logger.InfoFormat(_usingExpiredCredentialsFromIMDS); + logger.DebugFormat("Another thread has refreshed credentials and reusing those credentials"); + credentials = localLastRetrievedCredentials.Credentials; } else { - _imdsRefreshFailed = false; + logger.DebugFormat("Fetching credentials from ECS IMDS"); + _lastRetrievedCredentials = await _imdsAccessMethods.FetchCredentialsAsync().ConfigureAwait(false); + // if credentials are expired, we'll still return them, but log a message about + // them being expired. + if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + { + // If the previous refresh did not fail then this is the first failure and write the single + // expired warning log message for this specific incident. + if (!_previousRefreshFailed) + { + logger.InfoFormat(_usingExpiredCredentialsFromIMDS); + } + + _previousRefreshFailed = true; + } + else + { + _previousRefreshFailed = false; + } + + credentials = _lastRetrievedCredentials.Credentials; } - - return _lastRetrievedCredentials?.Credentials; - } - } - finally - { - _credentialsLock.ExitReadLock(); - } - } - - // If there's no credentials cached, hit IMDS directly. Try to acquire write lock. - if (_credentialsLock.TryEnterWriteLock(_credentialsLockTimeout)) - { - try - { - // Check for last retrieved credentials again in case other thread might have already fetched it. - if (null == _lastRetrievedCredentials) - { - _lastRetrievedCredentials = await FetchCredentialsAsync().ConfigureAwait(false); - } - - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero) && - !_imdsRefreshFailed) - { - // this is the first failure - immediately try to renew - _imdsRefreshFailed = true; - _lastRetrievedCredentials = await FetchCredentialsAsync().ConfigureAwait(false); } - - // if credentials are expired, we'll still return them, but log a message about - // them being expired. - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + finally { - _logger.InfoFormat(_usingExpiredCredentialsFromIMDS); + logger.DebugFormat("Releasing lock after refreshing ECS IMDS"); + SafeReleaseSemaphore(); } - else - { - _imdsRefreshFailed = false; - } - - credentials = _lastRetrievedCredentials.Credentials; } - finally + else { - _credentialsLock.ExitWriteLock(); + throw new AmazonServiceException(FailedToGetLockMessage); } } @@ -269,105 +278,174 @@ public override async Task GetCredentialsAsync() #endregion #region Private members - private void RenewCredentials(object unused) + private void RenewCredentials(object _) { + var logger = Logger.GetLogger(typeof(DefaultInstanceProfileAWSCredentials)); + + // This would only be true for unit tests that want to disable the timer-based refresh for + // more predictable testing. + if (_refreshRate <= TimeSpan.Zero) + return; + // By default, the refreshRate will continue to be // _refreshRate, but if FetchCredentials() returns an expired credential, // the refresh rate will be adjusted var refreshRate = _refreshRate; - + var lockedObtained = false; try { - // if FetchCredentials() call were to fail, _lastRetrievedCredentials - // would remain unchanged and would continue to be returned in GetCredentials() - _lastRetrievedCredentials = FetchCredentials(); - - // check for a first time failure - if (!_imdsRefreshFailed && - _lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + logger.DebugFormat("[Background Timer] Waiting on lock to refresh ECS IMDS"); + if (!_isDisposed && _credentialsSemaphore.Wait(_credentialsLockTimeout)) { - // this is the first failure - immediately try to renew - _imdsRefreshFailed = true; - _lastRetrievedCredentials = FetchCredentials(); - } + lockedObtained = true; + logger.DebugFormat("[Background Timer] Obtained lock to refresh ECS IMDS"); - // first failure refresh failed OR subsequent refresh failed. - if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) - { - // relax the refresh rate to at least 5 minutes - refreshRate = TimeSpan.FromMinutes(new Random().Next(5, 11)); + // if FetchCredentials() call were to fail, _lastRetrievedCredentials + // would remain unchanged and would continue to be returned in GetCredentials() + _lastRetrievedCredentials = _imdsAccessMethods.FetchCredentials(); + + // If the previous refresh do not fail but this refresh did fail retry immediately. + if (!_previousRefreshFailed && _lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + { + logger.DebugFormat("[Background Timer] First refresh failed and trying an immediate retry fetching credentials"); + _lastRetrievedCredentials = _imdsAccessMethods.FetchCredentials(); + } + + // if credentials are expired, we'll still return them, but log a message about them being expired. + if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + { + logger.DebugFormat("[Background Timer] Credential refresh failed"); + // If the previous refresh did not fail then this is the first failure and write the single + // expired warning log message for this specific incident. + if (!_previousRefreshFailed) + { + logger.InfoFormat(_usingExpiredCredentialsFromIMDS); + } + + _previousRefreshFailed = true; + } + else + { + logger.DebugFormat("[Background Timer] Credential refresh succeeded"); + _previousRefreshFailed = false; + } + + + // first failure refresh failed OR subsequent refresh failed. + if (_lastRetrievedCredentials.IsExpiredWithin(TimeSpan.Zero)) + { + // relax the refresh rate to at least 5 minutes + refreshRate = TimeSpan.FromMinutes(new Random().Next(5, 11)); + } } else { - _imdsRefreshFailed = false; + logger.InfoFormat("[Background Timer] {0}", FailedToGetLockMessage); } + } catch (OperationCanceledException e) { - _logger.Error(e, "RenewCredentials task canceled"); + logger.Error(e, "[Background Timer] RenewCredentials task canceled"); } catch (Exception e) { // we want to suppress any exceptions from this timer task. - _logger.Error(e, FailedToGetCredentialsMessage); + logger.Error(e, FailedToGetCredentialsMessage); } finally { // re-invoke this task once after time specified by refreshRate set at beginning of this method _credentialsRetrieverTimer.Change(refreshRate, _neverTimespan); + + if (lockedObtained) + SafeReleaseSemaphore(); } } - private static RefreshingAWSCredentials.CredentialsRefreshState FetchCredentials() + // Helper to release the semaphore safely if Dispose() races with an in-flight caller/timer. + private void SafeReleaseSemaphore() { - var securityCredentials = EC2InstanceMetadata.IAMSecurityCredentials; - if (securityCredentials == null) - throw new AmazonServiceException("Unable to get IAM security credentials from EC2 Instance Metadata Service."); - - IAMSecurityCredentialMetadata metadata = GetMetadataFromSecurityCredentials(securityCredentials); + try + { + _credentialsSemaphore.Release(); + } + catch (ObjectDisposedException) + { + // Semaphore already disposed — safe to ignore. + // This can happen during the Dispose of this method and race conditions disposing the timer. + } + catch (SemaphoreFullException) + { + // Defensive: ignore if Release would overflow the semaphore. + } + } - return new RefreshingAWSCredentials.CredentialsRefreshState( - new ImmutableCredentials(metadata.AccessKeyId, metadata.SecretAccessKey, metadata.Token), - metadata.Expiration); + internal interface IIMDSAccessMethods + { + RefreshingAWSCredentials.CredentialsRefreshState FetchCredentials(); + Task FetchCredentialsAsync(); + void CheckIsIMDSEnabled(); } - private static async Task FetchCredentialsAsync() + + private class IMDSAccessMethods : IIMDSAccessMethods { - var securityCredentials = await EC2InstanceMetadata.GetIAMSecurityCredentialsAsync().ConfigureAwait(false); - if (securityCredentials == null) - throw new AmazonServiceException("Unable to get IAM security credentials from EC2 Instance Metadata Service."); + public static IMDSAccessMethods DefaultInstance = new IMDSAccessMethods(); - IAMSecurityCredentialMetadata metadata = GetMetadataFromSecurityCredentials(securityCredentials); + public RefreshingAWSCredentials.CredentialsRefreshState FetchCredentials() + { + var securityCredentials = EC2InstanceMetadata.IAMSecurityCredentials; + if (securityCredentials == null) + throw new AmazonServiceException("Unable to get IAM security credentials from EC2 Instance Metadata Service."); - return new RefreshingAWSCredentials.CredentialsRefreshState( - new ImmutableCredentials(metadata.AccessKeyId, metadata.SecretAccessKey, metadata.Token), - metadata.Expiration); - } + IAMSecurityCredentialMetadata metadata = GetMetadataFromSecurityCredentials(securityCredentials); - private static IAMSecurityCredentialMetadata GetMetadataFromSecurityCredentials(System.Collections.Generic.IDictionary securityCredentials) - { - string firstRole = null; - foreach (var role in securityCredentials.Keys) + return new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials(metadata.AccessKeyId, metadata.SecretAccessKey, metadata.Token), + metadata.Expiration); + } + + public async Task FetchCredentialsAsync() { - firstRole = role; - break; + var securityCredentials = await EC2InstanceMetadata.GetIAMSecurityCredentialsAsync().ConfigureAwait(false); + if (securityCredentials == null) + throw new AmazonServiceException("Unable to get IAM security credentials from EC2 Instance Metadata Service."); + + IAMSecurityCredentialMetadata metadata = GetMetadataFromSecurityCredentials(securityCredentials); + + return new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials(metadata.AccessKeyId, metadata.SecretAccessKey, metadata.Token), + metadata.Expiration); } - if (string.IsNullOrEmpty(firstRole)) - throw new AmazonServiceException("Unable to get EC2 instance role from EC2 Instance Metadata Service."); + private IAMSecurityCredentialMetadata GetMetadataFromSecurityCredentials(System.Collections.Generic.IDictionary securityCredentials) + { + string firstRole = null; + foreach (var role in securityCredentials.Keys) + { + firstRole = role; + break; + } - var metadata = securityCredentials[firstRole]; - if (metadata == null) - throw new AmazonServiceException("Unable to get credentials for role \"" + firstRole + "\" from EC2 Instance Metadata Service."); + if (string.IsNullOrEmpty(firstRole)) + throw new AmazonServiceException("Unable to get EC2 instance role from EC2 Instance Metadata Service."); - return metadata; - } + var metadata = securityCredentials[firstRole]; + if (metadata == null) + throw new AmazonServiceException("Unable to get credentials for role \"" + firstRole + "\" from EC2 Instance Metadata Service."); - private static void CheckIsIMDSEnabled() - { - // keep this behavior consistent with GetObjectFromResponse case. - if (!EC2InstanceMetadata.IsIMDSEnabled) throw new AmazonServiceException("Unable to retrieve credentials."); + return metadata; + } + + public void CheckIsIMDSEnabled() + { + // keep this behavior consistent with GetObjectFromResponse case. + if (!EC2InstanceMetadata.IsIMDSEnabled) throw new AmazonServiceException("Unable to retrieve credentials."); + } } + + #endregion #region IDisposable Support @@ -380,15 +458,28 @@ protected virtual void Dispose(bool disposing) { if (disposing) { + // Mark disposed early to reduce race window with the timer. + _isDisposed = true; + + // Stop and dispose timer while holding the instance lock to avoid + // concurrent RenewCredentials runs racing with disposal. lock (_instanceLock) { - _credentialsRetrieverTimer.Dispose(); - _logger = null; + try + { + _credentialsRetrieverTimer?.Change(Timeout.Infinite, Timeout.Infinite); + } + catch (ObjectDisposedException) { } + + _credentialsRetrieverTimer?.Dispose(); _instance = null; } + + // Dispose the semaphore after timer is stopped/disposed. + _credentialsSemaphore.Dispose(); } - _isDisposed = true; + // keep the flag set } } diff --git a/sdk/test/NetStandard/UnitTests/Core/Credentials/DefaultInstanceProfileAWSCredentialsTests.cs b/sdk/test/NetStandard/UnitTests/Core/Credentials/DefaultInstanceProfileAWSCredentialsTests.cs new file mode 100644 index 000000000000..bea69db58c3b --- /dev/null +++ b/sdk/test/NetStandard/UnitTests/Core/Credentials/DefaultInstanceProfileAWSCredentialsTests.cs @@ -0,0 +1,324 @@ +using Amazon.Runtime; +using System; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace UnitTests.NetStandard.Core.Credentials +{ + public sealed class DefaultInstanceProfileAWSCredentialsTests + { + [Fact] + public async Task GetCredentialsAsync_FetchesFromIMDS_WhenNoCached() + { + var expectedAccessKey = "AK123"; + var expectedSecret = "SK123"; + var expectedToken = "T123"; + + var stub = new SimpleStubImds + { + CredentialsToReturn = new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials(expectedAccessKey, expectedSecret, expectedToken), + DateTime.UtcNow.AddMinutes(30)) + }; + + // lower refresh rate to speed test + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(5)); + + var creds = await provider.GetCredentialsAsync().ConfigureAwait(false); + + Assert.Equal(expectedAccessKey, creds.AccessKey); + Assert.Equal(expectedSecret, creds.SecretKey); + Assert.Equal(expectedToken, creds.Token); + } + + [Fact] + public void GetCredentials_UsesCachedCredentials_WhenValid() + { + var initiallyReturned = new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials("AK_A", "SK_A", "T_A"), + DateTime.UtcNow.AddMinutes(30)); + + var laterReturned = new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials("AK_B", "SK_B", "T_B"), + DateTime.UtcNow.AddMinutes(30)); + + var stub = new SimpleStubImds + { + CredentialsToReturn = initiallyReturned + }; + + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(5)); + + // First call: populates cache + var first = provider.GetCredentials(); + Assert.Equal("AK_A", first.AccessKey); + + // Change what IMDS would return now; provider should keep cached credentials + stub.CredentialsToReturn = laterReturned; + var second = provider.GetCredentials(); + Assert.Equal("AK_A", second.AccessKey); + } + + [Fact] + public async Task GetCredentialsAsync_ConcurrentRequests_OnlyOneFetch() + { + var stub = new CountingDelayStubImds(delayMs: 100); + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(-1)); + + var tasks = Enumerable.Range(0, 5) + .Select(_ => provider.GetCredentialsAsync()) + .ToArray(); + + await Task.WhenAll(tasks).ConfigureAwait(false); + + // All tasks should have received the same access key + Assert.All(tasks, t => Assert.Equal("CONCURRENT_AK", t.Result.AccessKey)); + + // The stub should have been invoked exactly once + Assert.Equal(1, stub.FetchAsyncCallCount); + } + + [Fact] + public async Task GetCredentialsAsync_ConcurrentRequests_TimerRefreshed() + { + var stub = new CountingDelayStubImds(delayMs: 100); + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(1)); + + await Task.Delay(2000).ConfigureAwait(false); + + var tasks = Enumerable.Range(0, 5) + .Select(_ => provider.GetCredentialsAsync()) + .ToArray(); + + await Task.WhenAll(tasks).ConfigureAwait(false); + + // All tasks should have received the same access key + Assert.All(tasks, t => Assert.Equal("CONCURRENT_AK", t.Result.AccessKey)); + + // No async fetches should have occurred, as the timer would have refreshed already + Assert.Equal(0, stub.FetchAsyncCallCount); + + // The sync fetch should have been called at least once by the timer. Given the short + // time set for the refresh rate it is not possible to guarantee how many times it ran. + Assert.True(stub.FetchCallCount > 0); + } + + [Fact] + public async Task GetCredentialsAsync_ReturnsExpiredCredentials() + { + var expectedAccessKey = "AK123"; + var expectedSecret = "SK123"; + var expectedToken = "T123"; + + var stub = new SimpleStubImds + { + CredentialsToReturn = new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials(expectedAccessKey, expectedSecret, expectedToken), + DateTime.UtcNow.AddMinutes(-30)) + }; + + // lower refresh rate to speed test + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(-1)); + + var creds = await provider.GetCredentialsAsync().ConfigureAwait(false); + + Assert.Equal(expectedAccessKey, creds.AccessKey); + Assert.Equal(expectedSecret, creds.SecretKey); + Assert.Equal(expectedToken, creds.Token); + } + + [Fact] + public void GetCredentials_ReturnsExpiredCredentials() + { + var expectedAccessKey = "AK123"; + var expectedSecret = "SK123"; + var expectedToken = "T123"; + + var stub = new SimpleStubImds + { + CredentialsToReturn = new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials(expectedAccessKey, expectedSecret, expectedToken), + DateTime.UtcNow.AddMinutes(-30)) + }; + + // lower refresh rate to speed test + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(-1)); + + var creds = provider.GetCredentials(); + + Assert.Equal(expectedAccessKey, creds.AccessKey); + Assert.Equal(expectedSecret, creds.SecretKey); + Assert.Equal(expectedToken, creds.Token); + } + + [Fact] + public async Task GetCredentials_RecoverExpired() + { + var endCheckForExpiration = DateTime.UtcNow.AddSeconds(3); + var stub = new SimulatedIMDSOutageStubImds(TimeSpan.FromSeconds(5)); + + // lower refresh rate to speed test + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(-1)); + + while(DateTime.UtcNow < endCheckForExpiration) + { + var creds = provider.GetCredentials(); + + Assert.StartsWith("OUTAGE", creds.AccessKey); + Assert.StartsWith("OUTAGE", creds.SecretKey); + Assert.StartsWith("OUTAGE", creds.Token); + } + Assert.True(stub.CombinedFetchCallCount > 0); + + stub.CombinedFetchCallCount = 0; + await Task.Delay(TimeSpan.FromSeconds(3)); + + var currentCredentials = provider.GetCredentials(); + Assert.True(stub.CombinedFetchCallCount > 0); + + Assert.StartsWith("CURRENT", currentCredentials.AccessKey); + Assert.StartsWith("CURRENT", currentCredentials.SecretKey); + Assert.StartsWith("CURRENT", currentCredentials.Token); + } + + [Fact] + public async Task GetCredentialsAsync_RecoverExpired() + { + var endCheckForExpiration = DateTime.UtcNow.AddSeconds(3); + var stub = new SimulatedIMDSOutageStubImds(TimeSpan.FromSeconds(5)); + + // lower refresh rate to speed test + var provider = DefaultInstanceProfileAWSCredentials.CreateTestDefaultInstanceProfileAWSCredentials(stub, TimeSpan.FromSeconds(-1)); + + while (DateTime.UtcNow < endCheckForExpiration) + { + var creds = await provider.GetCredentialsAsync(); + + Assert.StartsWith("OUTAGE", creds.AccessKey); + Assert.StartsWith("OUTAGE", creds.SecretKey); + Assert.StartsWith("OUTAGE", creds.Token); + } + Assert.True(stub.CombinedFetchCallCount > 0); + + stub.CombinedFetchCallCount = 0; + await Task.Delay(TimeSpan.FromSeconds(3)); + + var currentCredentials = await provider.GetCredentialsAsync(); + Assert.True(stub.CombinedFetchCallCount > 0); + + Assert.StartsWith("CURRENT", currentCredentials.AccessKey); + Assert.StartsWith("CURRENT", currentCredentials.SecretKey); + Assert.StartsWith("CURRENT", currentCredentials.Token); + } + + #region Test stubs + + // Simple stub that returns a preconfigured sync/async credentials refresh state. + private class SimpleStubImds : DefaultInstanceProfileAWSCredentials.IIMDSAccessMethods + { + public RefreshingAWSCredentials.CredentialsRefreshState CredentialsToReturn { get; set; } + + public RefreshingAWSCredentials.CredentialsRefreshState FetchCredentials() + { + return CredentialsToReturn; + } + + public Task FetchCredentialsAsync() + { + return Task.FromResult(CredentialsToReturn); + } + + public void CheckIsIMDSEnabled() + { + // No-op for tests; production check not needed. + } + } + + // Stub that counts async fetch calls and delays to simulate long IMDS response. + private class CountingDelayStubImds : DefaultInstanceProfileAWSCredentials.IIMDSAccessMethods + { + private readonly int _delayMs; + public int FetchAsyncCallCount; + public int FetchCallCount; + + public CountingDelayStubImds(int delayMs) + { + _delayMs = delayMs; + } + + public RefreshingAWSCredentials.CredentialsRefreshState FetchCredentials() + { + Interlocked.Increment(ref FetchCallCount); + // Provide a synchronous immediate response for sync tests if needed. + return new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials("CONCURRENT_AK", "CONCURRENT_SK", "CONCURRENT_T"), + DateTime.UtcNow.AddMinutes(30)); + } + + public async Task FetchCredentialsAsync() + { + Interlocked.Increment(ref FetchAsyncCallCount); + await Task.Delay(_delayMs).ConfigureAwait(false); + return new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials("CONCURRENT_AK", "CONCURRENT_SK", "CONCURRENT_T"), + DateTime.UtcNow.AddMinutes(30)); + } + + public void CheckIsIMDSEnabled() + { + // No-op for tests. + } + } + + private class SimulatedIMDSOutageStubImds : DefaultInstanceProfileAWSCredentials.IIMDSAccessMethods + { + TimeSpan _expirationLength; + DateTime? _imdsOutageEndTime; + RefreshingAWSCredentials.CredentialsRefreshState _expiredCredentials; + public int CombinedFetchCallCount; + + public SimulatedIMDSOutageStubImds(TimeSpan expirationLength) + { + _expirationLength = expirationLength; + } + + public RefreshingAWSCredentials.CredentialsRefreshState FetchCredentials() + { + Interlocked.Increment(ref CombinedFetchCallCount); + + if (_expiredCredentials == null) + { + _expiredCredentials = new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials("OUTAGE_AK", "OUTAGE_SK", "OUTAGE_T"), + DateTime.UtcNow.AddMinutes(-30)); + + _imdsOutageEndTime = DateTime.UtcNow.Add(_expirationLength); + } + + if (DateTime.UtcNow < _imdsOutageEndTime) + { + return _expiredCredentials; + } + + return new RefreshingAWSCredentials.CredentialsRefreshState( + new ImmutableCredentials("CURRENT_AK", "CURRENT_SK", "CURRENT_T"), + DateTime.UtcNow.AddMinutes(30)); + } + + public Task FetchCredentialsAsync() + { + return Task.FromResult(FetchCredentials()); + } + + public void CheckIsIMDSEnabled() + { + // No-op for tests; production check not needed. + } + } + + #endregion + } +}