Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/core/Azure.Core/src/TokenCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public abstract class TokenCredential
/// <param name="requestContext">The <see cref="TokenRequestContext"/> with authentication information.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to use.</param>
/// <returns>A valid <see cref="AccessToken"/>.</returns>
public abstract Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken);
public abstract ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken);

/// <summary>
/// Gets an <see cref="AccessToken"/> for the specified set of scopes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
using Azure.Core.Tests;
using Castle.DynamicProxy;
Expand Down Expand Up @@ -36,10 +37,26 @@ public void Intercept(IInvocation invocation)
return;
}

var result = (Task)invocation.ReturnValue;
try
{
result.GetAwaiter().GetResult();
object returnValue = invocation.ReturnValue;
if (returnValue is Task t)
{
t.GetAwaiter().GetResult();
}
else
{
// Await ValueTask
Type returnType = returnValue.GetType();
MethodInfo getAwaiterMethod = returnType.GetMethod("GetAwaiter", BindingFlags.Instance | BindingFlags.Public);
MethodInfo getResultMethod = getAwaiterMethod.ReturnType.GetMethod("GetResult", BindingFlags.Instance | BindingFlags.Public);

getResultMethod.Invoke(
getAwaiterMethod.Invoke(returnValue, Array.Empty<object>()),
Array.Empty<object>());

}

expectedEvents.Add(expectedEventPrefix + ".Stop");
}
catch (Exception ex)
Expand Down
4 changes: 2 additions & 2 deletions sdk/core/Azure.Core/tests/TestFramework/TestRecording.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ public void DisableIdReuse()

private class TestCredential : TokenCredential
{
public override Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
return Task.FromResult(GetToken(requestContext, cancellationToken));
return new ValueTask<AccessToken>(GetToken(requestContext, cancellationToken));
}

public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void Intercept(IInvocation invocation)
}
else
{
invocation.ReturnValue = _taskFromResultMethod.MakeGenericMethod(returnType).Invoke(null, new[] { result });
SetAsyncResult(invocation, returnType, result);
}
}
catch (TargetInvocationException exception)
Expand All @@ -98,9 +98,51 @@ public void Intercept(IInvocation invocation)
}
else
{
invocation.ReturnValue = _taskFromExceptionMethod.MakeGenericMethod(methodInfo.ReturnType).Invoke(null, new[] { exception.InnerException });
SetAsyncException(invocation, returnType, exception.InnerException);
}
}
}

private void SetAsyncResult(IInvocation invocation, Type returnType, object result)
{
Type methodReturnType = invocation.Method.ReturnType;
if (methodReturnType.IsGenericType)
{
if (methodReturnType.GetGenericTypeDefinition() == typeof(Task<>))
{
invocation.ReturnValue = _taskFromResultMethod.MakeGenericMethod(returnType).Invoke(null, new[] { result });
return;
}
if (methodReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
invocation.ReturnValue = Activator.CreateInstance(typeof(ValueTask<>).MakeGenericType(returnType), result);
return;
}
}

throw new NotSupportedException();
}

private void SetAsyncException(IInvocation invocation, Type returnType, Exception result)
{
Type methodReturnType = invocation.Method.ReturnType;
if (methodReturnType.IsGenericType)
{
if (methodReturnType.GetGenericTypeDefinition() == typeof(Task<>))
{
invocation.ReturnValue = _taskFromExceptionMethod.MakeGenericMethod(returnType).Invoke(null, new[] { result });
return;
}

if (methodReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
var task = _taskFromExceptionMethod.MakeGenericMethod(returnType).Invoke(null, new[] { result });
invocation.ReturnValue = Activator.CreateInstance(typeof(ValueTask<>).MakeGenericType(returnType), task);
return;
}
}

throw new NotSupportedException();
}

private static MethodInfo GetMethod(IInvocation invocation, string nonAsyncMethodName, Type[] types) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public EventHubSharedKeyCredential(string sharedAccessKeyName,
///
/// <returns>The token representing the shared access signature for this credential.</returns>
///
public override Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) => throw new InvalidOperationException(Resources.SharedKeyCredentialCannotGenerateTokens);
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) => throw new InvalidOperationException(Resources.SharedKeyCredentialCannotGenerateTokens);

/// <summary>
/// Coverts to shared access signature credential.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext,
///
/// <returns>The token representing the shared access signature for this credential.</returns>
///
public override Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext,
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext,
CancellationToken cancellationToken) => Credential.GetTokenAsync(requestContext, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext,
///
/// <returns>The token representing the shared access signature for this credential.</returns>
///
public override Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext,
CancellationToken cancellationToken) => Task.FromResult(new AccessToken(SharedAccessSignature.Value, SharedAccessSignature.SignatureExpiration));
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext,
CancellationToken cancellationToken) => new ValueTask<AccessToken>(new AccessToken(SharedAccessSignature.Value, SharedAccessSignature.SignatureExpiration));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public void GetPropertiesAsyncCreatesTheRequest()

mockCredential
.Setup(credential => credential.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.Is<CancellationToken>(value => value == cancellationSource.Token)))
.Returns(Task.FromResult(new AccessToken(tokenValue, DateTimeOffset.MaxValue)))
.Returns(new ValueTask<AccessToken>(new AccessToken(tokenValue, DateTimeOffset.MaxValue)))
.Verifiable();

mockConverter
Expand Down Expand Up @@ -273,7 +273,7 @@ public void GetPropertiesAsyncRespectsTheRetryPolicy(RetryOptions retryOptions)

mockCredential
.Setup(credential => credential.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.Is<CancellationToken>(value => value == cancellationSource.Token)))
.Returns(Task.FromResult(new AccessToken(tokenValue, DateTimeOffset.MaxValue)));
.Returns(new ValueTask<AccessToken>(new AccessToken(tokenValue, DateTimeOffset.MaxValue)));

mockConverter
.Setup(converter => converter.CreateEventHubPropertiesRequest(It.Is<string>(value => value == eventHubName), It.Is<string>(value => value == tokenValue)))
Expand Down Expand Up @@ -354,7 +354,7 @@ public void GetPartitionPropertiesAsyncCreatesTheRequest()

mockCredential
.Setup(credential => credential.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.Is<CancellationToken>(value => value == cancellationSource.Token)))
.Returns(Task.FromResult(new AccessToken(tokenValue, DateTimeOffset.MaxValue)))
.Returns(new ValueTask<AccessToken>(new AccessToken(tokenValue, DateTimeOffset.MaxValue)))
.Verifiable();

mockConverter
Expand Down Expand Up @@ -396,7 +396,7 @@ public void GetPartitionPropertiesAsyncRespectsTheRetryPolicy(RetryOptions retry

mockCredential
.Setup(credential => credential.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.Is<CancellationToken>(value => value == cancellationSource.Token)))
.Returns(Task.FromResult(new AccessToken(tokenValue, DateTimeOffset.MaxValue)));
.Returns(new ValueTask<AccessToken>(new AccessToken(tokenValue, DateTimeOffset.MaxValue)));

mockConverter
.Setup(converter => converter.CreatePartitionPropertiesRequest(It.Is<string>(value => value == eventHubName), It.Is<string>(value => value == partitionId), It.Is<string>(value => value == tokenValue)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ public void ReceiveAsyncRespectsTheRetryPolicy(RetryOptions retryOptions)

mockCredential
.Setup(credential => credential.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.Is<CancellationToken>(value => value == cancellationSource.Token)))
.Returns(Task.FromResult(new AccessToken(tokenValue, DateTimeOffset.MaxValue)));
.Returns(new ValueTask<AccessToken>(new AccessToken(tokenValue, DateTimeOffset.MaxValue)));

mockScope
.Setup(scope => scope.OpenConsumerLinkAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public async Task GetTokenAsyncPassesAlongTheClaims()

mockCredential
.Setup(credential => credential.GetTokenAsync(It.Is<TokenRequestContext>(value => value.Scopes == requiredClaims), It.IsAny<CancellationToken>()))
.Returns(Task.FromResult(new AccessToken("blah", DateTimeOffset.Parse("2015-10-27T00:00:00Z"))))
.Returns(new ValueTask<AccessToken>(new AccessToken("blah", DateTimeOffset.Parse("2015-10-27T00:00:00Z"))))
.Verifiable();

await provider.GetTokenAsync(new Uri("http://www.here.com"), "nobody", requiredClaims);
Expand All @@ -70,7 +70,7 @@ public async Task GetTokenAsyncPopulatesUsingTheCredentialAccessToken()

mockCredential
.Setup(credential => credential.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.Returns(Task.FromResult(new AccessToken(tokenValue, expires)));
.Returns(new ValueTask<AccessToken>(new AccessToken(tokenValue, expires)));

CbsToken cbsToken = await provider.GetTokenAsync(new Uri("http://www.here.com"), "nobody", new string[0]);

Expand Down Expand Up @@ -114,7 +114,7 @@ public async Task GetTokenAsyncSetsTheCorrectTypeForOtherTokens()

mockCredential
.Setup(credential => credential.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.Returns(Task.FromResult(new AccessToken(tokenValue, expires)));
.Returns(new ValueTask<AccessToken>(new AccessToken(tokenValue, expires)));

CbsToken cbsToken = await provider.GetTokenAsync(new Uri("http://www.here.com"), "nobody", new string[0]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>An <see cref="AccessToken"/> which can be used to authenticate service client calls.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _clientDiagnostics.CreateScope("Azure.Identity.AuthorizationCodeCredential.GetToken");

Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/ChainedTokenCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>The first non default <see cref="AccessToken"/> returned by the specified sources. If all credentials in the chain return default a default <see cref="AccessToken"/> is returned.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
AccessToken token = new AccessToken();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>An <see cref="AccessToken"/> which can be used to authenticate service client calls.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
return await _client.AuthenticateAsync(TenantId, ClientId, ClientCertificate, requestContext.Scopes, cancellationToken).ConfigureAwait(false);
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/ClientSecretCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public ClientSecretCredential(string tenantId, string clientId, string clientSec
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>An <see cref="AccessToken"/> which can be used to authenticate service client calls.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
return await _client.AuthenticateAsync(TenantId, ClientId, ClientSecret, requestContext.Scopes, cancellationToken).ConfigureAwait(false);
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/DefaultAzureCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
throw new AuthenticationFailedException(CredentialNotFoundMessage);
}

public override Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
throw new AuthenticationFailedException(CredentialNotFoundMessage);
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/DeviceCodeCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>An <see cref="AccessToken"/> which can be used to authenticate service client calls.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _clientDiagnostics.CreateScope("Azure.Identity.DeviceCodeCredential.GetToken");

Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/EnvironmentCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>An <see cref="AccessToken"/> which can be used to authenticate service client calls, or a default <see cref="AccessToken"/>.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
return (_credential != null) ? await _credential.GetTokenAsync(requestContext, cancellationToken).ConfigureAwait(false) : default;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public override AccessToken GetToken(TokenRequestContext requestContext, Cancell
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>An <see cref="AccessToken"/> which can be used to authenticate service client calls.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
if (_account != null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions
/// <param name="requestContext">The details of the authentication request.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
/// <returns>An <see cref="AccessToken"/> which can be used to authenticate service client calls, or a default <see cref="AccessToken"/> if no managed identity is available.</returns>
public override async Task<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken = default)
{
return await _client.AuthenticateAsync(requestContext.Scopes, _clientId, cancellationToken).ConfigureAwait(false);
}
Expand Down
Loading