Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions sdk/identity/Azure.Identity/src/AuthenticationFailedException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,27 @@ public AuthenticationFailedException(string message, Exception innerException)
{
}

internal static AuthenticationFailedException CreateAggregateException(string message, IList<Exception> innerExceptions)
internal static AuthenticationFailedException CreateAggregateException(string message, IList<Exception> exceptions)
{
return new AuthenticationFailedException(message, new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", innerExceptions.ToArray()));
// Build the credential unavailable message, this code is only reachable if all credentials throw AuthenticationFailedException
StringBuilder errorMsg = new StringBuilder(message);

bool allCredentialUnavailableException = true;
foreach (var exception in exceptions)
{
allCredentialUnavailableException &= exception is CredentialUnavailableException;
errorMsg.Append(Environment.NewLine).Append("- ").Append(exception.Message);
}

var innerException = exceptions.Count == 1
? exceptions[0]
: new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", exceptions);

// If all credentials have thrown CredentialUnavailableException, throw CredentialUnavailableException,
// otherwise throw AuthenticationFailedException
return allCredentialUnavailableException
? new CredentialUnavailableException(errorMsg.ToString(), innerException)
: new AuthenticationFailedException(errorMsg.ToString(), innerException);
}
}
}
80 changes: 40 additions & 40 deletions sdk/identity/Azure.Identity/src/ChainedTokenCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
using Azure.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core.Pipeline;

namespace Azure.Identity
{
Expand All @@ -22,6 +22,14 @@ public class ChainedTokenCredential : TokenCredential

private readonly TokenCredential[] _sources;

/// <summary>
/// Constructor for instrumenting in tests
/// </summary>
internal ChainedTokenCredential()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be protected internal instead of just internal for mocking purposes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have another public constructor, won't that be enough?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The public constructor requires you to pass one or more credential instances or it will throw. It's not a big deal since they could pass arbitrary instances and ignore them in their implementation. But it would just be more convenient if this were protected. Also, it adheres to the guideline, https://azure.github.io/azure-sdk/dotnet_introduction.html#dotnet-mocking-constructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor doesn't initialize ChainedTokenCredential properly, making it unusable. We either have to make it fully functional with _sources == null, which IMO is meaningless, or leave it internal for testing purposes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this would be for mocking purposes I wouldn't expect developers to use it unless they were overriding GetToken and GetTokenAsync as well. It would need to be documented as such as well.

{
_sources = Array.Empty<TokenCredential>();
}

/// <summary>
/// Creates an instance with the specified <see cref="TokenCredential"/> sources.
/// </summary>
Expand Down Expand Up @@ -53,29 +61,7 @@ public ChainedTokenCredential(params TokenCredential[] sources)
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises a <see cref="CredentialUnavailableException"/> will be skipped.</returns>
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
List<Exception> exceptions = new List<Exception>();

for (int i = 0; i < _sources.Length; i++)
{
try
{
return _sources[i].GetToken(requestContext, cancellationToken);
}
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);

throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
}

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
=> GetTokenImplAsync(false, requestContext, cancellationToken).EnsureCompleted();

/// <summary>
/// Sequentially calls <see cref="TokenCredential.GetToken"/> on all the specified sources, returning the first successfully obtained <see cref="AccessToken"/>. This method is called by Azure SDK clients. It isn't intended for use in application code.
Expand All @@ -84,28 +70,42 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>The first <see cref="AccessToken"/> returned by the specified sources. Any credential which raises a <see cref="CredentialUnavailableException"/> will be skipped.</returns>
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
List<Exception> exceptions = new List<Exception>();
=> await GetTokenImplAsync(true, requestContext, cancellationToken).ConfigureAwait(false);

for (int i = 0; i < _sources.Length; i++)
private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestContext requestContext, CancellationToken cancellationToken)
{
var groupScopeHandler = new ScopeGroupHandler(default);
try
{
try
List<Exception> exceptions = new List<Exception>();
foreach (TokenCredential source in _sources)
{
return await _sources[i].GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false);
try
{
AccessToken token = async
? await source.GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false)
: source.GetToken(requestContext, cancellationToken);
groupScopeHandler.Dispose(default, default);
return token;
}
catch (AuthenticationFailedException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);
throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
}
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);

throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
}
throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
catch (Exception exception)
{
groupScopeHandler.Fail(default, default, exception);
throw;
}

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
}
}
4 changes: 2 additions & 2 deletions sdk/identity/Azure.Identity/src/CredentialDiagnosticScope.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ namespace Azure.Identity
private readonly TokenRequestContext _context;
private readonly IScopeHandler _scopeHandler;

public CredentialDiagnosticScope(string name, TokenRequestContext context, IScopeHandler scopeHandler)
public CredentialDiagnosticScope(ClientDiagnostics diagnostics, string name, TokenRequestContext context, IScopeHandler scopeHandler)
{
_name = name;
_scope = scopeHandler.CreateScope(name);
_scope = scopeHandler.CreateScope(diagnostics, name);
_context = context;
_scopeHandler = scopeHandler;
}
Expand Down
133 changes: 7 additions & 126 deletions sdk/identity/Azure.Identity/src/CredentialPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Azure.Core;
using Azure.Core.Diagnostics;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;

namespace Azure.Identity
{
internal class CredentialPipeline
internal class CredentialPipeline
{
private static readonly Lazy<CredentialPipeline> s_singleton = new Lazy<CredentialPipeline>(() => new CredentialPipeline(new TokenCredentialOptions()));

private readonly IScopeHandler _defaultScopeHandler;
private IScopeHandler _groupScopeHandler;
private static readonly IScopeHandler _defaultScopeHandler = new ScopeHandler();

private CredentialPipeline(TokenCredentialOptions options)
{
Expand All @@ -26,8 +22,6 @@ private CredentialPipeline(TokenCredentialOptions options)
HttpPipeline = HttpPipelineBuilder.Build(options, Array.Empty<HttpPipelinePolicy>(), Array.Empty<HttpPipelinePolicy>(), new CredentialResponseClassifier());

Diagnostics = new ClientDiagnostics(options);

_defaultScopeHandler = new ScopeHandler(Diagnostics);
}

public static CredentialPipeline GetInstance(TokenCredentialOptions options)
Expand All @@ -48,18 +42,18 @@ public IConfidentialClientApplication CreateMsalConfidentialClient(string tenant

public CredentialDiagnosticScope StartGetTokenScope(string fullyQualifiedMethod, TokenRequestContext context)
{
IScopeHandler scopeHandler = _groupScopeHandler ?? _defaultScopeHandler;
IScopeHandler scopeHandler = ScopeGroupHandler.Current ?? _defaultScopeHandler;

CredentialDiagnosticScope scope = new CredentialDiagnosticScope(fullyQualifiedMethod, context, scopeHandler);
CredentialDiagnosticScope scope = new CredentialDiagnosticScope(Diagnostics, fullyQualifiedMethod, context, scopeHandler);
scope.Start();
return scope;
}

public CredentialDiagnosticScope StartGetTokenScopeGroup(string fullyQualifiedMethod, TokenRequestContext context)
{
var scopeHandler = new ScopeGroupHandler(this, fullyQualifiedMethod);
var scopeHandler = new ScopeGroupHandler(fullyQualifiedMethod);

CredentialDiagnosticScope scope = new CredentialDiagnosticScope(fullyQualifiedMethod, context, scopeHandler);
CredentialDiagnosticScope scope = new CredentialDiagnosticScope(Diagnostics, fullyQualifiedMethod, context, scopeHandler);
scope.Start();
return scope;
}
Expand All @@ -74,123 +68,10 @@ public override bool IsRetriableResponse(HttpMessage message)

private class ScopeHandler : IScopeHandler
{
private readonly ClientDiagnostics _diagnostics;

public ScopeHandler(ClientDiagnostics diagnostics)
{
_diagnostics = diagnostics;
}

public DiagnosticScope CreateScope(string name) => _diagnostics.CreateScope(name);
public DiagnosticScope CreateScope(ClientDiagnostics diagnostics, string name) => diagnostics.CreateScope(name);
public void Start(string name, in DiagnosticScope scope) => scope.Start();
public void Dispose(string name, in DiagnosticScope scope) => scope.Dispose();
public void Fail(string name, in DiagnosticScope scope, Exception exception) => scope.Failed(exception);
}

private class ScopeGroupHandler : IScopeHandler
{
private readonly CredentialPipeline _pipeline;
private readonly string _groupName;
private Dictionary<string, (DateTime StartDateTime, Exception Exception)> _childScopes;

public ScopeGroupHandler(CredentialPipeline pipeline, string groupName)
{
_pipeline = pipeline;
_groupName = groupName;
}

public DiagnosticScope CreateScope(string name)
{
if (IsGroup(name))
{
_pipeline._groupScopeHandler = this;
return _pipeline.Diagnostics.CreateScope(name);
}

_childScopes ??= new Dictionary<string, (DateTime startDateTime, Exception exception)>();
_childScopes[name] = default;
return default;
}

public void Start(string name, in DiagnosticScope scope)
{
if (IsGroup(name))
{
scope.Start();
}
else
{
_childScopes[name] = (DateTime.UtcNow, default);
}
}

public void Dispose(string name, in DiagnosticScope scope)
{
if (!IsGroup(name))
{
return;
}

if (_childScopes != null)
{
var succeededScope = _childScopes.LastOrDefault(kvp => kvp.Value.Exception == default);
if (succeededScope.Key != default)
{
SucceedChildScope(succeededScope.Key, succeededScope.Value.StartDateTime);
}
}

scope.Dispose();
_pipeline._groupScopeHandler = default;
}

public void Fail(string name, in DiagnosticScope scope, Exception exception)
{
if (_childScopes == default)
{
scope.Failed(exception);
return;
}

if (IsGroup(name))
{
if (exception is OperationCanceledException)
{
var canceledScope = _childScopes.Last(kvp => kvp.Value.Exception == exception);
FailChildScope(canceledScope.Key, canceledScope.Value.StartDateTime, canceledScope.Value.Exception);
}
else
{
foreach (var childScope in _childScopes)
{
FailChildScope(childScope.Key, childScope.Value.StartDateTime, childScope.Value.Exception);
}
}

scope.Failed(exception);
}
else
{
_childScopes[name] = (_childScopes[name].StartDateTime, exception);
}
}

private void SucceedChildScope(string name, DateTime dateTime)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope(name);
scope.SetStartTime(dateTime);
scope.Start();
}

private void FailChildScope(string name, DateTime dateTime, Exception exception)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope(name);
scope.SetStartTime(dateTime);
scope.Start();
scope.Failed(exception);
}

private bool IsGroup(string name) => string.Equals(name, _groupName, StringComparison.Ordinal);
}
}
}
25 changes: 8 additions & 17 deletions sdk/identity/Azure.Identity/src/DefaultAzureCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace Azure.Identity
public class DefaultAzureCredential : TokenCredential
{
private const string DefaultExceptionMessage = "DefaultAzureCredential failed to retrieve a token from the included credentials.";
private const string UnhandledExceptionMessage = "DefaultAzureCredential authentication failed.";
private const string UnhandledExceptionMessage = "DefaultAzureCredential authentication failed due to an unhandled exception: ";
private static readonly TokenCredential[] s_defaultCredentialChain = GetDefaultAzureCredentialChain(new DefaultAzureCredentialFactory(null), new DefaultAzureCredentialOptions());

private readonly CredentialPipeline _pipeline;
Expand Down Expand Up @@ -143,7 +143,7 @@ private static async ValueTask<AccessToken> GetTokenFromCredentialAsync(TokenCre

private static async ValueTask<(AccessToken, TokenCredential)> GetTokenFromSourcesAsync(TokenCredential[] sources, TokenRequestContext requestContext, bool async, CancellationToken cancellationToken)
{
List<AuthenticationFailedException> exceptions = new List<AuthenticationFailedException>();
List<Exception> exceptions = new List<Exception>();

for (var i = 0; i < sources.Length && sources[i] != null; i++)
{
Expand All @@ -159,23 +159,14 @@ private static async ValueTask<AccessToken> GetTokenFromCredentialAsync(TokenCre
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);
throw AuthenticationFailedException.CreateAggregateException(UnhandledExceptionMessage + e.Message, exceptions);
}
}

// Build the credential unavailable message, this code is only reachable if all credentials throw AuthenticationFailedException
StringBuilder errorMsg = new StringBuilder(DefaultExceptionMessage);

bool allCredentialUnavailableException = true;
foreach (AuthenticationFailedException ex in exceptions)
{
allCredentialUnavailableException &= ex is CredentialUnavailableException;
errorMsg.Append(Environment.NewLine).Append("- ").Append(ex.Message);
}

// If all credentials have thrown CredentialUnavailableException, throw CredentialUnavailableException,
// otherwise throw AuthenticationFailedException
throw allCredentialUnavailableException
? new CredentialUnavailableException(errorMsg.ToString())
: new AuthenticationFailedException(errorMsg.ToString());
throw AuthenticationFailedException.CreateAggregateException(DefaultExceptionMessage, exceptions);
}

private static TokenCredential[] GetDefaultAzureCredentialChain(DefaultAzureCredentialFactory factory, DefaultAzureCredentialOptions options)
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/IScopeHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Azure.Identity
{
internal interface IScopeHandler
{
DiagnosticScope CreateScope(string name);
DiagnosticScope CreateScope(ClientDiagnostics diagnostics, string name);
void Start(string name, in DiagnosticScope scope);
void Dispose(string name, in DiagnosticScope scope);
void Fail(string name, in DiagnosticScope scope, Exception exception);
Expand Down
Loading