Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Accounts/Accounts.Test/AzureRMProfileTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1081,15 +1081,15 @@ public void CanRenewTokenLogin()
Assert.Equal(keyVaultToken2, account.GetProperty(AzureAccount.Property.KeyVaultAccessToken));
var factory = new ClientFactory();
var rmClient = factory.CreateArmClient<MockServiceClient>(profile.DefaultContext, AzureEnvironment.Endpoint.ResourceManager);
var rmCred = rmClient.Credentials as TokenCredentials;
var rmCred = rmClient.Credentials as RenewingTokenCredential;
Assert.NotNull(rmCred);
var message = new HttpRequestMessage(HttpMethod.Get, rmClient.BaseUri.ToString());
rmCred.ProcessHttpRequestAsync(message, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult();
Assert.NotNull(message.Headers.Authorization);
Assert.NotNull(message.Headers.Authorization.Parameter);
Assert.Contains(accessToken2, message.Headers.Authorization.Parameter);
var graphClient = factory.CreateArmClient<MockServiceClient>(profile.DefaultContext, AzureEnvironment.Endpoint.Graph);
var graphCred = graphClient.Credentials as TokenCredentials;
var graphCred = graphClient.Credentials as RenewingTokenCredential;
Assert.NotNull(graphCred);
var graphMessage = new HttpRequestMessage(HttpMethod.Get, rmClient.BaseUri.ToString());
graphCred.ProcessHttpRequestAsync(graphMessage, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// ----------------------------------------------------------------------------------
//
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Net.Http.Headers;
using System.Text;
using System.Threading.Tasks;
using ClientRuntime = Microsoft.Rest;

namespace Microsoft.Azure.Commands.Common.Authentication.Authentication
{
internal class RenewingAccessTokenProvider : ClientRuntime.ITokenProvider
Comment thread
VeryEarly marked this conversation as resolved.
Outdated
{
private const string _type = "Bearer";
private readonly Func<string> _accessToken;

/// <summary>
/// Create a token provider that returns the given
/// access token.
/// </summary>
/// <param name="accessToken">The access token to return.</param>
public RenewingAccessTokenProvider(Func<string> accessToken)
{
_accessToken = accessToken;
}

/// <summary>
/// Returns the static access token.
/// </summary>
/// <param name="cancellationToken">The cancellation token for this action.
/// <returns>The access token.</returns>
public Task<AuthenticationHeaderValue> GetAuthenticationHeaderAsync(System.Threading.CancellationToken cancellationToken)
{
return Task.FromResult(new AuthenticationHeaderValue(_type, _accessToken()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,28 @@
using System.Net.Http.Headers;
using System.Threading;
using System.Net.Http;
using System;

namespace Microsoft.Azure.Commands.Common.Authentication
{
public class RenewingTokenCredential : ServiceClientCredentials
{
private IAccessToken _token;
private readonly Func<IAccessToken> _refresh;


public RenewingTokenCredential(IAccessToken token)
public RenewingTokenCredential(IAccessToken token, Func<IAccessToken> refresh = null)
{
_token = token;
_refresh = refresh;
}

public override Task ProcessHttpRequestAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
return Task.Run( () => _token.AuthorizeRequest((type, token) => request.Headers.Authorization = new AuthenticationHeaderValue(type, token)));
return Task.Run(() =>
{
_token = (_refresh == null) ? _token : _refresh();
_token.AuthorizeRequest((type, token) => request.Headers.Authorization = new AuthenticationHeaderValue(type, token));
});
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
using System.Security;
using Microsoft.Azure.Commands.Common.Authentication.Properties;
using System.Threading.Tasks;
using Microsoft.Azure.Commands.Common.Authentication.Authentication;
using System.Management.Automation;

namespace Microsoft.Azure.Commands.Common.Authentication.Factories
{
Expand Down Expand Up @@ -302,7 +304,7 @@ public ServiceClientCredentials GetServiceClientCredentials(IAzureContext contex
case AzureAccount.AccountType.Certificate:
throw new NotSupportedException(AzureAccount.AccountType.Certificate.ToString());
case AzureAccount.AccountType.AccessToken:
return new TokenCredentials(GetEndpointToken(context.Account, targetEndpoint));
return new RenewingTokenCredential(new RawAccessToken { AccessToken = GetEndpointToken(context.Account, targetEndpoint) }, () => new RawAccessToken { AccessToken = GetEndpointToken(context.Account, targetEndpoint) });
}


Expand Down