diff --git a/src/Accounts/Accounts.Test/AzureRMProfileTests.cs b/src/Accounts/Accounts.Test/AzureRMProfileTests.cs index 35f8262da4ff..210529a2223b 100644 --- a/src/Accounts/Accounts.Test/AzureRMProfileTests.cs +++ b/src/Accounts/Accounts.Test/AzureRMProfileTests.cs @@ -1081,7 +1081,7 @@ public void CanRenewTokenLogin() Assert.Equal(keyVaultToken2, account.GetProperty(AzureAccount.Property.KeyVaultAccessToken)); var factory = new ClientFactory(); var rmClient = factory.CreateArmClient(profile.DefaultContext, AzureEnvironment.Endpoint.ResourceManager); - var rmCred = rmClient.Credentials as TokenCredentials; + var rmCred = rmClient.Credentials as RenewingTokenCredential; Assert.NotNull(rmCred); var message = new HttpRequestMessage(HttpMethod.Get, rmClient.BaseUri.ToString()); rmCred.ProcessHttpRequestAsync(message, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult(); @@ -1089,7 +1089,7 @@ public void CanRenewTokenLogin() Assert.NotNull(message.Headers.Authorization.Parameter); Assert.Contains(accessToken2, message.Headers.Authorization.Parameter); var graphClient = factory.CreateArmClient(profile.DefaultContext, AzureEnvironment.Endpoint.Graph); - var graphCred = graphClient.Credentials as TokenCredentials; + var graphCred = graphClient.Credentials as RenewingTokenCredential; Assert.NotNull(graphCred); var graphMessage = new HttpRequestMessage(HttpMethod.Get, rmClient.BaseUri.ToString()); graphCred.ProcessHttpRequestAsync(graphMessage, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult(); diff --git a/src/Accounts/Authentication/Authentication/ExternalAccessToken.cs b/src/Accounts/Authentication/Authentication/ExternalAccessToken.cs new file mode 100644 index 000000000000..11638d805392 --- /dev/null +++ b/src/Accounts/Authentication/Authentication/ExternalAccessToken.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Azure.Commands.Common.Authentication.Authentication +{ + public class ExternalAccessToken : IAccessToken + { + public string AccessToken + { + get; set; + } + + public string LoginType + { + get; set; + } + + public string TenantId + { + get; set; + } + + public string UserId + { + get; set; + } + + private readonly Func _refresh; + + public ExternalAccessToken(string token, Func refresh = null) + { + this.AccessToken = token; + this._refresh = refresh; + } + + public void AuthorizeRequest(Action authTokenSetter) + { + AccessToken = (_refresh == null) ? AccessToken : _refresh(); + authTokenSetter("Bearer", AccessToken); + } + } +} diff --git a/src/Accounts/Authentication/Factories/AuthenticationFactory.cs b/src/Accounts/Authentication/Factories/AuthenticationFactory.cs index 3bc2ce7e7568..96472f66b63e 100644 --- a/src/Accounts/Authentication/Factories/AuthenticationFactory.cs +++ b/src/Accounts/Authentication/Factories/AuthenticationFactory.cs @@ -21,6 +21,8 @@ using System.Security; using Microsoft.Azure.Commands.Common.Authentication.Properties; using System.Threading.Tasks; +using Microsoft.Azure.Commands.Common.Authentication.Authentication; +using System.Management.Automation; namespace Microsoft.Azure.Commands.Common.Authentication.Factories { @@ -302,7 +304,7 @@ public ServiceClientCredentials GetServiceClientCredentials(IAzureContext contex case AzureAccount.AccountType.Certificate: throw new NotSupportedException(AzureAccount.AccountType.Certificate.ToString()); case AzureAccount.AccountType.AccessToken: - return new TokenCredentials(GetEndpointToken(context.Account, targetEndpoint)); + return new RenewingTokenCredential(new ExternalAccessToken (GetEndpointToken(context.Account, targetEndpoint), () => GetEndpointToken(context.Account, targetEndpoint))); }