diff --git a/sdk/identity/Azure.Identity/src/DeviceCodeCredential.cs b/sdk/identity/Azure.Identity/src/DeviceCodeCredential.cs index d45b346c9d86..c2aa3cda3f9c 100644 --- a/sdk/identity/Azure.Identity/src/DeviceCodeCredential.cs +++ b/sdk/identity/Azure.Identity/src/DeviceCodeCredential.cs @@ -198,7 +198,7 @@ private async ValueTask GetTokenImplAsync(bool async, TokenRequestC { try { - AuthenticationResult result = await Client.AcquireTokenSilentAsync(requestContext.Scopes, (AuthenticationAccount)Record, async, cancellationToken).ConfigureAwait(false); + AuthenticationResult result = await Client.AcquireTokenSilentAsync(requestContext.Scopes, Record, async, cancellationToken).ConfigureAwait(false); return scope.Succeeded(new AccessToken(result.AccessToken, result.ExpiresOn)); } diff --git a/sdk/identity/Azure.Identity/src/InteractiveBrowserCredential.cs b/sdk/identity/Azure.Identity/src/InteractiveBrowserCredential.cs index 498dec8e468a..c0bcec8d8595 100644 --- a/sdk/identity/Azure.Identity/src/InteractiveBrowserCredential.cs +++ b/sdk/identity/Azure.Identity/src/InteractiveBrowserCredential.cs @@ -181,7 +181,7 @@ private async ValueTask GetTokenImplAsync(bool async, TokenRequestC { try { - AuthenticationResult result = await Client.AcquireTokenSilentAsync(requestContext.Scopes, (AuthenticationAccount)Record, async, cancellationToken).ConfigureAwait(false); + AuthenticationResult result = await Client.AcquireTokenSilentAsync(requestContext.Scopes, Record, async, cancellationToken).ConfigureAwait(false); return scope.Succeeded(new AccessToken(result.AccessToken, result.ExpiresOn)); } diff --git a/sdk/identity/Azure.Identity/src/MsalPublicClient.cs b/sdk/identity/Azure.Identity/src/MsalPublicClient.cs index 8841f476ed92..ded0a5450573 100644 --- a/sdk/identity/Azure.Identity/src/MsalPublicClient.cs +++ b/sdk/identity/Azure.Identity/src/MsalPublicClient.cs @@ -52,6 +52,17 @@ public virtual async ValueTask AcquireTokenSilentAsync(str IPublicClientApplication client = await GetClientAsync(async, cancellationToken).ConfigureAwait(false); return await client.AcquireTokenSilent(scopes, account).ExecuteAsync(async, cancellationToken).ConfigureAwait(false); } + public virtual async ValueTask AcquireTokenSilentAsync(string[] scopes, AuthenticationRecord record, bool async, CancellationToken cancellationToken) + { + IPublicClientApplication client = await GetClientAsync(async, cancellationToken).ConfigureAwait(false); + + // if the user specified a TenantId when they created the client we want to authenticate to that tenant. + // otherwise we should authenticate with the tenant specified by the authentication record since that's the tenant the + // user authenticated to originally. + return await client.AcquireTokenSilent(scopes, (AuthenticationAccount)record) + .WithAuthority(Pipeline.AuthorityHost.AbsoluteUri, TenantId ?? record.TenantId) + .ExecuteAsync(async, cancellationToken).ConfigureAwait(false); + } public virtual async ValueTask AcquireTokenInteractiveAsync(string[] scopes, Prompt prompt, bool async, CancellationToken cancellationToken) { diff --git a/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialLiveTests.cs b/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialLiveTests.cs index 6d245ba817e6..73f44b785b7d 100644 --- a/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialLiveTests.cs +++ b/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialLiveTests.cs @@ -85,5 +85,23 @@ public async Task AuthenticateWithSharedTokenCacheAsync() Assert.NotNull(token.Token); } + + [Test] + [Ignore("This test is an integration test which can only be run with user interaction")] + // This test should be run with an MSA account to validate that the refresh for MSA accounts works properly + public async Task AuthenticateWithMSAWithSubsequentSilentRefresh() + { + var cred = new InteractiveBrowserCredential(); + + // this should pop browser + var authRecord = await cred.AuthenticateAsync(); + + Assert.NotNull(authRecord); + + // this should not pop browser + AccessToken token = await cred.GetTokenAsync(new TokenRequestContext(new string[] { "https://vault.azure.net/.default" })).ConfigureAwait(false); + + Assert.NotNull(token.Token); + } } } diff --git a/sdk/identity/Azure.Identity/tests/Mock/MockMsalPublicClient.cs b/sdk/identity/Azure.Identity/tests/Mock/MockMsalPublicClient.cs index a3976b3516eb..0d70c907b4e1 100644 --- a/sdk/identity/Azure.Identity/tests/Mock/MockMsalPublicClient.cs +++ b/sdk/identity/Azure.Identity/tests/Mock/MockMsalPublicClient.cs @@ -73,6 +73,18 @@ public override ValueTask AcquireTokenSilentAsync(string[] throw new NotImplementedException(); } + public override ValueTask AcquireTokenSilentAsync(string[] scopes, AuthenticationRecord record, bool async, CancellationToken cancellationToken) + { + Func factory = SilentAuthFactory ?? AuthFactory; + + if (factory != null) + { + return new ValueTask(factory(scopes)); + } + + throw new NotImplementedException(); + } + public override ValueTask AcquireTokenWithDeviceCodeAsync(string[] scopes, Func deviceCodeCallback, bool async, CancellationToken cancellationToken) { Func factory = DeviceCodeAuthFactory ?? AuthFactory;