diff --git a/sdk/identity/Azure.Identity/src/IdentityClient.cs b/sdk/identity/Azure.Identity/src/AadIdentityClient.cs similarity index 76% rename from sdk/identity/Azure.Identity/src/IdentityClient.cs rename to sdk/identity/Azure.Identity/src/AadIdentityClient.cs index 53bfee83ebec..48df6447ed43 100644 --- a/sdk/identity/Azure.Identity/src/IdentityClient.cs +++ b/sdk/identity/Azure.Identity/src/AadIdentityClient.cs @@ -19,24 +19,23 @@ namespace Azure.Identity { - internal class IdentityClient + internal class AadIdentityClient { - private static Lazy s_sharedClient = new Lazy(() => new IdentityClient()); + private static Lazy s_sharedClient = new Lazy(() => new AadIdentityClient()); private readonly IdentityClientOptions _options; private readonly HttpPipeline _pipeline; - private readonly Uri ImdsEndptoint = new Uri("http://169.254.169.254/metadata/identity/oauth2/token"); - private const string MsiApiVersion = "2018-02-01"; + private const string ClientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; - public IdentityClient(IdentityClientOptions options = null) + public AadIdentityClient(IdentityClientOptions options = null) { _options = options ?? new IdentityClientOptions(); _pipeline = HttpPipelineBuilder.Build(_options, bufferResponse: true); } - public static IdentityClient SharedClient { get { return s_sharedClient.Value; } } + public static AadIdentityClient SharedClient { get { return s_sharedClient.Value; } } public virtual async Task AuthenticateAsync(string tenantId, string clientId, string clientSecret, string[] scopes, CancellationToken cancellationToken = default) @@ -106,65 +105,6 @@ public virtual AccessToken Authenticate(string tenantId, string clientId, X509Ce throw response.CreateRequestFailedException(); } } - public virtual async Task AuthenticateManagedIdentityAsync(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) - { - using (Request request = CreateManagedIdentityAuthRequest(scopes, clientId)) - { - var response = await _pipeline.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); - - if (response.Status == 200 || response.Status == 201) - { - var result = await DeserializeAsync(response.ContentStream, cancellationToken).ConfigureAwait(false); - - return new Response(response, result); - } - - throw response.CreateRequestFailedException(); - } - } - - public virtual AccessToken AuthenticateManagedIdentity(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) - { - using (Request request = CreateManagedIdentityAuthRequest(scopes, clientId)) - { - var response = _pipeline.SendRequest(request, cancellationToken); - - if (response.Status == 200 || response.Status == 201) - { - var result = Deserialize(response.ContentStream); - - return new Response(response, result); - } - - throw response.CreateRequestFailedException(); - } - } - - private Request CreateManagedIdentityAuthRequest(string[] scopes, string clientId = null) - { - // covert the scopes to a resource string - string resource = ScopeUtilities.ScopesToResource(scopes); - - Request request = _pipeline.CreateRequest(); - - request.Method = HttpPipelineMethod.Get; - - request.Headers.Add("Metadata", "true"); - - // TODO: support MSI for hosted services - request.UriBuilder.Uri = ImdsEndptoint; - - request.UriBuilder.AppendQuery("api-version", MsiApiVersion); - - request.UriBuilder.AppendQuery("resource", Uri.EscapeDataString(resource)); - - if (!string.IsNullOrEmpty(clientId)) - { - request.UriBuilder.AppendQuery("client_id", Uri.EscapeDataString(clientId)); - } - - return request; - } private Request CreateClientSecretAuthRequest(string tenantId, string clientId, string clientSecret, string[] scopes) { diff --git a/sdk/identity/Azure.Identity/src/Base64Url.cs b/sdk/identity/Azure.Identity/src/Base64Url.cs index c5f39055ed74..a66274902598 100644 --- a/sdk/identity/Azure.Identity/src/Base64Url.cs +++ b/sdk/identity/Azure.Identity/src/Base64Url.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See License.txt in the project root for -// license information. +// Licensed under the MIT License. using System; using System.Text; diff --git a/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs b/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs index 750944599f0f..bf72c38e82a9 100644 --- a/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs +++ b/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See License.txt in the project root for -// license information. +// Licensed under the MIT License. + using Azure.Core; using System; @@ -15,7 +15,7 @@ public class ClientCertificateCredential : TokenCredential private string _tenantId; private string _clientId; private X509Certificate2 _clientCertificate; - private IdentityClient _client; + private AadIdentityClient _client; public ClientCertificateCredential(string tenantId, string clientId, X509Certificate2 clientCertificate) : this(tenantId, clientId, clientCertificate, null) @@ -30,7 +30,7 @@ public ClientCertificateCredential(string tenantId, string clientId, X509Certifi _clientCertificate = clientCertificate ?? throw new ArgumentNullException(nameof(clientCertificate)); - _client = (options != null) ? new IdentityClient(options) : IdentityClient.SharedClient; + _client = (options != null) ? new AadIdentityClient(options) : AadIdentityClient.SharedClient; } public override AccessToken GetToken(string[] scopes, CancellationToken cancellationToken = default) diff --git a/sdk/identity/Azure.Identity/src/ClientSecretCredential.cs b/sdk/identity/Azure.Identity/src/ClientSecretCredential.cs index a36dd07a0b8b..0f6011fc890a 100644 --- a/sdk/identity/Azure.Identity/src/ClientSecretCredential.cs +++ b/sdk/identity/Azure.Identity/src/ClientSecretCredential.cs @@ -13,7 +13,7 @@ public class ClientSecretCredential : TokenCredential private string _tenantId; private string _clientId; private string _clientSecret; - private IdentityClient _client; + private AadIdentityClient _client; public ClientSecretCredential(string tenantId, string clientId, string clientSecret) @@ -27,7 +27,7 @@ public ClientSecretCredential(string tenantId, string clientId, string clientSec _clientId = clientId; _clientSecret = clientSecret; - _client = (options != null) ? new IdentityClient(options) : IdentityClient.SharedClient; + _client = (options != null) ? new AadIdentityClient(options) : AadIdentityClient.SharedClient; } public override async Task GetTokenAsync(string[] scopes, CancellationToken cancellationToken = default) diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs new file mode 100644 index 000000000000..a8cb821b4fb2 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs @@ -0,0 +1,416 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Azure.Core; +using Azure.Core.Pipeline; +using System; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Identity +{ + internal class ManagedIdentityClient + { + private static Lazy s_sharedClient = new Lazy(() => new ManagedIdentityClient()); + + // IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + private static readonly Uri ImdsEndpoint = new Uri("http://169.254.169.254/metadata/identity/oauth2/token"); + private const string ImdsApiVersion = "2018-02-01"; + private const int ImdsAvailableTimeoutMs = 500; + + // MSI Constants. Docs for MSI are available here https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity + private const string MsiEndpointEnvironemntVariable = "MSI_ENDPOINT"; + private const string MsiSecretEnvironemntVariable = "MSI_SECRET"; + private const string AppServiceMsiApiVersion = "2017-09-01"; + + private static SemaphoreSlim s_initLock = new SemaphoreSlim(1, 1); + private static MsiType s_msiType; + private static Uri s_endpoint; + + private readonly IdentityClientOptions _options; + private readonly HttpPipeline _pipeline; + + public ManagedIdentityClient(IdentityClientOptions options = null) + { + _options = options ?? new IdentityClientOptions(); + + _pipeline = HttpPipelineBuilder.Build(_options, bufferResponse: true); + } + + private enum MsiType + { + Unknown = 0, + Imds = 1, + AppService = 2, + CloudShell = 3, + Unavailable = 4 + } + + public static ManagedIdentityClient SharedClient { get { return s_sharedClient.Value; } } + + public virtual async Task AuthenticateAsync(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) + { + MsiType msiType = await GetMsiTypeAsync(cancellationToken).ConfigureAwait(false); + + // if msi is unavailable or we were unable to determine the type return a default access token + if (msiType == MsiType.Unavailable || msiType == MsiType.Unknown) + { + return default; + } + + using (Request request = CreateAuthRequest(msiType, scopes, clientId)) + { + var response = await _pipeline.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); + + if (response.Status == 200 || response.Status == 201) + { + var result = await DeserializeAsync(response.ContentStream, cancellationToken).ConfigureAwait(false); + + return result; + } + + throw response.CreateRequestFailedException(); + } + } + + public virtual AccessToken Authenticate(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) + { + MsiType msiType = GetMsiType(cancellationToken); + + // if msi is unavailable or we were unable to determine the type return a default access token + if (msiType == MsiType.Unavailable || msiType == MsiType.Unknown) + { + return default; + } + + using (Request request = CreateAuthRequest(msiType, scopes, clientId)) + { + var response = _pipeline.SendRequest(request, cancellationToken); + + if (response.Status == 200 || response.Status == 201) + { + var result = Deserialize(response.ContentStream); + + return result; + } + + throw response.CreateRequestFailedException(); + } + } + + private Request CreateAuthRequest(MsiType msiType, string[] scopes, string clientId) + { + switch (msiType) + { + case MsiType.Imds: + return CreateImdsAuthRequest(scopes, clientId); + case MsiType.AppService: + return CreateAppServiceAuthRequest(scopes, clientId); + case MsiType.CloudShell: + return CreateCloudShellAuthRequest(scopes, clientId); + default: + return default; + } + } + + private async ValueTask GetMsiTypeAsync(CancellationToken cancellationToken) + { // if we haven't already determined the msi type + if (s_msiType == MsiType.Unknown) + { + // aquire the init lock + await s_initLock.WaitAsync(cancellationToken).ConfigureAwait(false); + + try + { + // check again if the we already determined the msiType now that we hold the lock + if (s_msiType == MsiType.Unknown) + { + string endpointEnvVar = Environment.GetEnvironmentVariable(MsiEndpointEnvironemntVariable); + string secretEnvVar = Environment.GetEnvironmentVariable(MsiSecretEnvironemntVariable); + + // if the env var MSI_ENDPOINT is set + if (!string.IsNullOrEmpty(endpointEnvVar)) + { + s_endpoint = new Uri(endpointEnvVar); + + // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the MsiType is AppService + if (!string.IsNullOrEmpty(secretEnvVar)) + { + s_msiType = MsiType.AppService; + } + // if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell + else + { + s_msiType = MsiType.CloudShell; + } + } + // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the MsiType is Imds + else if (await ImdsAvailableAsync(cancellationToken).ConfigureAwait(false)) + { + s_endpoint = ImdsEndpoint; + s_msiType = MsiType.Imds; + } + // if MSI_ENDPOINT is NOT set and IMDS enpoint is not available ManagedIdentity is not available + else + { + s_msiType = MsiType.Unavailable; + } + } + } + // release the init lock + finally + { + s_initLock.Release(); + } + } + + return s_msiType; + } + + + private MsiType GetMsiType(CancellationToken cancellationToken) + { + // if we haven't already determined the msi type + if (s_msiType == MsiType.Unknown) + { + // aquire the init lock + s_initLock.Wait(cancellationToken); + + try + { + // check again if the we already determined the msiType now that we hold the lock + if (s_msiType == MsiType.Unknown) + { + string endpointEnvVar = Environment.GetEnvironmentVariable(MsiEndpointEnvironemntVariable); + string secretEnvVar = Environment.GetEnvironmentVariable(MsiSecretEnvironemntVariable); + + // if the env var MSI_ENDPOINT is set + if (!string.IsNullOrEmpty(endpointEnvVar)) + { + s_endpoint = new Uri(endpointEnvVar); + + // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the MsiType is AppService + if (!string.IsNullOrEmpty(secretEnvVar)) + { + s_msiType = MsiType.AppService; + } + // if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell + else + { + s_msiType = MsiType.CloudShell; + } + } + // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the MsiType is Imds + else if (ImdsAvailable(cancellationToken)) + { + s_endpoint = ImdsEndpoint; + s_msiType = MsiType.Imds; + } + // if MSI_ENDPOINT is NOT set and IMDS enpoint is not available ManagedIdentity is not available + else + { + s_msiType = MsiType.Unavailable; + } + } + } + // release the init lock + finally + { + s_initLock.Release(); + } + } + + return s_msiType; + } + + private bool ImdsAvailable(CancellationToken cancellationToken) + { + // send a request without the Metadata header. This will result in a failed request, + // but we're just interested in if we get a response before the timeout of 500ms + // if we don't get a response we assume the imds endpoint is not available + using (Request request = _pipeline.CreateRequest()) + { + request.Method = HttpPipelineMethod.Get; + + request.UriBuilder.Uri = ImdsEndpoint; + + request.UriBuilder.AppendQuery("api-version", ImdsApiVersion); + + var imdsTimeout = new CancellationTokenSource(ImdsAvailableTimeoutMs).Token; + + try + { + var response = _pipeline.SendRequest(request, CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, imdsTimeout).Token); + + return true; + } + // we only want to handle the case when the imdsTimeout resulted in the request being cancelled. + // this indicates that the request timed out and that imds is not available. If the operation + // was user cancelled we don't wan't to handle the exception so s_identityAvailable will + // remain unset, as we still haven't determined if the imds endpoint is available. + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + return false; + } + } + } + + private async Task ImdsAvailableAsync(CancellationToken cancellationToken) + { + // send a request without the Metadata header. This will result in a failed request, + // but we're just interested in if we get a response before the timeout of 500ms + // if we don't get a response we assume the imds endpoint is not available + using (Request request = _pipeline.CreateRequest()) + { + request.Method = HttpPipelineMethod.Get; + + request.UriBuilder.Uri = ImdsEndpoint; + + request.UriBuilder.AppendQuery("api-version", ImdsApiVersion); + + var imdsTimeout = new CancellationTokenSource(ImdsAvailableTimeoutMs).Token; + + try + { + var response = await _pipeline.SendRequestAsync(request, CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, imdsTimeout).Token).ConfigureAwait(false); + + return true; + } + // we only want to handle the case when the imdsTimeout resulted in the request being cancelled. + // this indicates that the request timed out and that imds is not available. If the operation + // was user cancelled we don't wan't to handle the exception so s_identityAvailable will + // remain unset, as we still haven't determined if the imds endpoint is available. + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + return false; + } + } + } + + private Request CreateImdsAuthRequest(string[] scopes, string clientId) + { + // covert the scopes to a resource string + string resource = ScopeUtilities.ScopesToResource(scopes); + + Request request = _pipeline.CreateRequest(); + + request.Method = HttpPipelineMethod.Get; + + request.Headers.Add("Metadata", "true"); + + request.UriBuilder.Uri = s_endpoint; + + request.UriBuilder.AppendQuery("api-version", ImdsApiVersion); + + request.UriBuilder.AppendQuery("resource", Uri.EscapeDataString(resource)); + + if (!string.IsNullOrEmpty(clientId)) + { + request.UriBuilder.AppendQuery("client_id", Uri.EscapeDataString(clientId)); + } + + return request; + } + + private Request CreateAppServiceAuthRequest(string[] scopes, string clientId) + { + // covert the scopes to a resource string + string resource = ScopeUtilities.ScopesToResource(scopes); + + Request request = _pipeline.CreateRequest(); + + request.Method = HttpPipelineMethod.Get; + + request.Headers.Add("secret", Environment.GetEnvironmentVariable(MsiSecretEnvironemntVariable)); + + request.UriBuilder.Uri = s_endpoint; + + request.UriBuilder.AppendQuery("api-version", AppServiceMsiApiVersion); + + request.UriBuilder.AppendQuery("resource", Uri.EscapeDataString(resource)); + + if (!string.IsNullOrEmpty(clientId)) + { + request.UriBuilder.AppendQuery("client_id", Uri.EscapeDataString(clientId)); + } + + return request; + } + + private Request CreateCloudShellAuthRequest(string[] scopes, string clientId) + { + // covert the scopes to a resource string + string resource = ScopeUtilities.ScopesToResource(scopes); + + Request request = _pipeline.CreateRequest(); + + request.Method = HttpPipelineMethod.Post; + + request.Headers.Add(HttpHeader.Common.FormUrlEncodedContentType); + + request.UriBuilder.Uri = s_endpoint; + + request.Headers.Add("Metadata", "true"); + + var bodyStr = $"resource={Uri.EscapeDataString(resource)}"; + + if (!string.IsNullOrEmpty(clientId)) + { + bodyStr += $"&client_id={Uri.EscapeDataString(clientId)}"; + } + + ReadOnlyMemory content = Encoding.UTF8.GetBytes(bodyStr).AsMemory(); + + request.Content = HttpPipelineRequestContent.Create(content); + + return request; + } + + private async Task DeserializeAsync(Stream content, CancellationToken cancellationToken) + { + using (JsonDocument json = await JsonDocument.ParseAsync(content, default, cancellationToken).ConfigureAwait(false)) + { + return Deserialize(json.RootElement); + } + } + + private AccessToken Deserialize(Stream content) + { + using (JsonDocument json = JsonDocument.Parse(content)) + { + return Deserialize(json.RootElement); + } + } + + private AccessToken Deserialize(JsonElement json) + { + string accessToken = null; + + DateTimeOffset expiresOn = DateTimeOffset.MaxValue; + + if (json.TryGetProperty("access_token", out JsonElement accessTokenProp)) + { + accessToken = accessTokenProp.GetString(); + } + + if (json.TryGetProperty("expires_on", out JsonElement expiresOnProp)) + { + // if s_msiType is AppService expires_on will be a string formatted datetimeoffset + if (s_msiType == MsiType.AppService) + { + expiresOn = DateTimeOffset.Parse(expiresOnProp.GetString()); + } + // otherwise expires_on will be a unix timestamp seconds from epoch + else + { + expiresOn = DateTimeOffset.FromUnixTimeMilliseconds(expiresOnProp.GetInt64()); + } + } + + return new AccessToken(accessToken, expiresOn); + } + } +} \ No newline at end of file diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs b/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs index 7cd7b2a55631..094592182e6f 100644 --- a/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs +++ b/sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs @@ -14,23 +14,23 @@ namespace Azure.Identity public class ManagedIdentityCredential : TokenCredential { private string _clientId; - private IdentityClient _client; - + private ManagedIdentityClient _client; + public ManagedIdentityCredential(string clientId = null, IdentityClientOptions options = null) { _clientId = clientId; - _client = (options != null) ? new IdentityClient(options) : IdentityClient.SharedClient; + _client = (options != null) ? new ManagedIdentityClient(options) : ManagedIdentityClient.SharedClient; } public override async Task GetTokenAsync(string[] scopes, CancellationToken cancellationToken = default) { - return await this._client.AuthenticateManagedIdentityAsync(scopes, _clientId, cancellationToken).ConfigureAwait(false); + return await this._client.AuthenticateAsync(scopes, _clientId, cancellationToken).ConfigureAwait(false); } public override AccessToken GetToken(string[] scopes, CancellationToken cancellationToken = default) { - return this._client.AuthenticateManagedIdentity(scopes, _clientId, cancellationToken); + return this._client.Authenticate(scopes, _clientId, cancellationToken); } } } diff --git a/sdk/identity/Azure.Identity/tests/Mock/MockIdentityClient.cs b/sdk/identity/Azure.Identity/tests/Mock/MockIdentityClient.cs index c5150a4fc9e8..05bb8b773029 100644 --- a/sdk/identity/Azure.Identity/tests/Mock/MockIdentityClient.cs +++ b/sdk/identity/Azure.Identity/tests/Mock/MockIdentityClient.cs @@ -10,8 +10,48 @@ namespace Azure.Identity.Tests.Mock { - internal class MockIdentityClient : IdentityClient + internal class MockManagedIdentityClient : ManagedIdentityClient { + public MockManagedIdentityClient() + : this(LiveTokenFactory) + { + } + + public MockManagedIdentityClient(AccessToken token) + : this(() => token) + { + } + + public MockManagedIdentityClient(Func tokenFactory) + : this((scopes, tenantId, clientId, clientSecret, cancellationToken) => tokenFactory()) + { + } + + public MockManagedIdentityClient(Func tokenFactory) + { + _tokenFactory = tokenFactory; + } + + public override AccessToken Authenticate(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) + { + return CreateToken(scopes, clientId: clientId, cancellationToken: cancellationToken); + } + + public async override Task AuthenticateAsync(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) + { + return await CreateTokenAsync(scopes, clientId: clientId, cancellationToken: cancellationToken); + } + + private async Task CreateTokenAsync(string[] scopes, string tenantId = default, string clientId = default, string clientSecret = default, CancellationToken cancellationToken = default) + { + if (cancellationToken != default) + { + await Task.Delay(1000, cancellationToken).ConfigureAwait(false); + } + + return _tokenFactory(scopes, tenantId, clientId, clientSecret, cancellationToken); + } + private static AccessToken ExpiredTokenFactory(string[] scopes, string tenantId, string clientId, string clientSecret, CancellationToken cancellationToken) { return CreateAccessToken(scopes, tenantId, clientId, clientSecret, DateTimeOffset.UtcNow - TimeSpan.FromMinutes(1)); @@ -34,29 +74,70 @@ private static MockToken CreateMockToken(string[] scopes, string tenantId, strin return new MockToken().WithField("scopes", string.Join("+", scopes)).WithField("tenantId", tenantId).WithField("clientId", clientId).WithField("clientSecret", clientSecret); } - public static MockIdentityClient ExpiredTokenClient { get; } = new MockIdentityClient(ExpiredTokenFactory); + public static MockManagedIdentityClient ExpiredTokenClient { get; } = new MockManagedIdentityClient(ExpiredTokenFactory); + + public static MockManagedIdentityClient LiveTokenClient { get; } = new MockManagedIdentityClient(LiveTokenFactory); + + private Func _tokenFactory; - public static MockIdentityClient LiveTokenClient { get; } = new MockIdentityClient(LiveTokenFactory); + private AccessToken CreateToken(string[] scopes, string tenantId = default, string clientId = default, string clientSecret = default, CancellationToken cancellationToken = default) + { + if (cancellationToken != default) + { + Task.Delay(1000, cancellationToken).GetAwaiter().GetResult(); + } + + return _tokenFactory(scopes, tenantId, clientId, clientSecret, cancellationToken); + } + } + + internal class MockAadClient : AadIdentityClient + { + private static AccessToken ExpiredTokenFactory(string[] scopes, string tenantId, string clientId, string clientSecret, CancellationToken cancellationToken) + { + return CreateAccessToken(scopes, tenantId, clientId, clientSecret, DateTimeOffset.UtcNow - TimeSpan.FromMinutes(1)); + } + + private static AccessToken LiveTokenFactory(string[] scopes, string tenantId, string clientId, string clientSecret, CancellationToken cancellationToken) + { + return CreateAccessToken(scopes, tenantId, clientId, clientSecret, DateTimeOffset.UtcNow + TimeSpan.FromHours(1)); + } + + private static AccessToken CreateAccessToken(string[] scopes, string tenantId, string clientId, string clientSecret, DateTimeOffset expires) + { + MockToken token = CreateMockToken(scopes, tenantId, clientId, clientSecret); + + return new AccessToken(token.ToString(), expires); + } + + private static MockToken CreateMockToken(string[] scopes, string tenantId, string clientId, string clientSecret) + { + return new MockToken().WithField("scopes", string.Join("+", scopes)).WithField("tenantId", tenantId).WithField("clientId", clientId).WithField("clientSecret", clientSecret); + } + + public static MockAadClient ExpiredTokenClient { get; } = new MockAadClient(ExpiredTokenFactory); + + public static MockAadClient LiveTokenClient { get; } = new MockAadClient(LiveTokenFactory); private Func _tokenFactory; - public MockIdentityClient() + public MockAadClient() : this(LiveTokenFactory) { } - public MockIdentityClient(AccessToken token) + public MockAadClient(AccessToken token) : this(() => token) { } - public MockIdentityClient(Func tokenFactory) + public MockAadClient(Func tokenFactory) : this((scopes, tenantId, clientId, clientSecret, cancellationToken) => tokenFactory()) { } - public MockIdentityClient(Func tokenFactory) + public MockAadClient(Func tokenFactory) { _tokenFactory = tokenFactory; } @@ -70,16 +151,6 @@ public async override Task AuthenticateAsync(string tenantId, strin { return await CreateTokenAsync(scopes, clientId: clientId, tenantId: tenantId, clientSecret: clientSecret, cancellationToken: cancellationToken); } - public override AccessToken AuthenticateManagedIdentity(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) - { - return CreateToken(scopes, clientId: clientId, cancellationToken: cancellationToken); - } - - public async override Task AuthenticateManagedIdentityAsync(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) - { - - return await CreateTokenAsync(scopes, clientId: clientId, cancellationToken: cancellationToken); - } private async Task CreateTokenAsync(string[] scopes, string tenantId = default, string clientId = default, string clientSecret = default, CancellationToken cancellationToken = default) { diff --git a/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityCredentialTests.cs b/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityCredentialTests.cs index 9cba6dcd0858..071ee1e39419 100644 --- a/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityCredentialTests.cs @@ -19,7 +19,7 @@ public async Task CancellationTokenHonoredAsync() { var credential = new ManagedIdentityCredential(); - credential._client(new MockIdentityClient()); + credential._client(new MockManagedIdentityClient()); var cancellation = new CancellationTokenSource(); @@ -37,7 +37,7 @@ public async Task ScopesHonoredAsync() { var credential = new ManagedIdentityCredential(); - credential._client(new MockIdentityClient()); + credential._client(new MockManagedIdentityClient()); AccessToken defaultScopeToken = await credential.GetTokenAsync(MockScopes.Default); @@ -47,13 +47,15 @@ public async Task ScopesHonoredAsync() [Test] public async Task VerifyMSIRequest() { + var pingResponse = new MockResponse(400); + var response = new MockResponse(200); var expectedToken = "mock-msi-access-token"; response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_in\": 3600 }}"); - var mockTransport = new MockTransport(response); + var mockTransport = new MockTransport(pingResponse, response); var options = new IdentityClientOptions() { Transport = mockTransport }; @@ -63,7 +65,7 @@ public async Task VerifyMSIRequest() Assert.AreEqual(expectedToken, actualToken.Token); - MockRequest request = mockTransport.SingleRequest; + MockRequest request = mockTransport.Requests[1]; string query = request.UriBuilder.Query; diff --git a/sdk/identity/Azure.Identity/tests/TestAccessorExtensions.cs b/sdk/identity/Azure.Identity/tests/TestAccessorExtensions.cs index f456f967ed19..79be2c6840e2 100644 --- a/sdk/identity/Azure.Identity/tests/TestAccessorExtensions.cs +++ b/sdk/identity/Azure.Identity/tests/TestAccessorExtensions.cs @@ -29,16 +29,16 @@ public static string _client(this ClientSecretCredential credential) { return typeof(ClientSecretCredential).GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(credential) as string; } - public static void _client(this ClientSecretCredential credential, IdentityClient client) + public static void _client(this ClientSecretCredential credential, AadIdentityClient client) { typeof(ClientSecretCredential).GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(credential, client); } - public static string _client(this ManagedIdentityCredential credential) + public static ManagedIdentityClient _client(this ManagedIdentityCredential credential) { - return typeof(ManagedIdentityCredential).GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(credential) as string; + return typeof(ManagedIdentityCredential).GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(credential) as ManagedIdentityClient; } - public static void _client(this ManagedIdentityCredential credential, IdentityClient client) + public static void _client(this ManagedIdentityCredential credential, ManagedIdentityClient client) { typeof(ManagedIdentityCredential).GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(credential, client); }