diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj index e6ffc2526b..d9a4222b36 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj @@ -32,6 +32,9 @@ + + $(MicrosoftIdentityClientBrokerVersion) + diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj index 1d9c1985bd..bca9ec4fee 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj @@ -40,6 +40,9 @@ runtime; build; native; contentfiles; analyzers; buildtransitive + + $(MicrosoftIdentityClientBrokerVersion) + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index e5630bc9c7..bd5c6b266f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -73,7 +73,7 @@ - $(DefineConstants);DEBUG;DBG;_DEBUG;_LOGGING;RESOURCE_ANNOTATION_WORK; + $(DefineConstants);DEBUG;DBG;_DEBUG;_LOGGING;RESOURCE_ANNOTATION_WORK;INTERACTIVE_AUTH; Full False @@ -82,6 +82,12 @@ Pdbonly True + + $(DefineConstants);NETFRAMEWORK; + + + $(DefineConstants);INTERACTIVE_AUTH; + @@ -1069,6 +1075,21 @@ All runtime; build; native; contentfiles; analyzers; buildtransitive + + $(AzureIdentityVersion) + + + $(MicrosoftIdentityClientBrokerVersion) + + + $(MicrosoftIdentityModelProtocolsOpenIdConnectVersion) + + + $(MicrosoftIdentityModelJsonWebTokensVersion) + + + $(SystemBuffersVersion) + diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs new file mode 100644 index 0000000000..54208ee3b5 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs @@ -0,0 +1,11 @@ +using System; +using System.Runtime.InteropServices; + +internal partial class Interop +{ + internal partial class Kernel32 + { + [DllImport("kernel32.dll")] + internal static extern IntPtr GetConsoleWindow(); + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs new file mode 100644 index 0000000000..868b18aeaf --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs @@ -0,0 +1,30 @@ +using System; +using System.Runtime.InteropServices; + + +internal partial class Interop +{ + internal partial class User32 + { + internal enum GetAncestorFlags + { + GetParent = 1, + GetRoot = 2, + /// + /// Retrieves the owned root window by walking the chain of parent and owner windows returned by GetParent. + /// + GetRootOwner = 3 + } + + /// + /// Retrieves the handle to the ancestor of the specified window. + /// + /// A handle to the window whose ancestor is to be retrieved. + /// If this parameter is the desktop window, the function returns NULL. + /// The ancestor to be retrieved. + /// The return value is the handle to the ancestor window. + [DllImport("user32.dll", ExactSpelling = true)] + internal static extern IntPtr GetAncestor(IntPtr hwnd, GetAncestorFlags flags); + } +} + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs new file mode 100644 index 0000000000..53eb66e6e0 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs @@ -0,0 +1,12 @@ +using System; +using System.Threading; + +namespace Microsoft.Data.SqlClient +{ + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + { + private Func _parentActivityOrWindowFunc = null; + + private Func ParentActivityOrWindow => _parentActivityOrWindowFunc; + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs new file mode 100644 index 0000000000..c16edfc6b4 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs @@ -0,0 +1,41 @@ +using System; + +namespace Microsoft.Data.SqlClient +{ + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + { + private Func _parentActivityOrWindowFunc = null; + + private Func ParentActivityOrWindow + { + get + { + return _parentActivityOrWindowFunc != null ? _parentActivityOrWindowFunc : GetConsoleOrTerminalWindow; + } + + set + { + _parentActivityOrWindowFunc = value; + } + } + +#if NETFRAMEWORK + /// + public void SetIWin32WindowFunc(Func iWin32WindowFunc) => SetParentActivityOrWindow(iWin32WindowFunc); +#endif + + /// + /// + /// + /// + public void SetParentActivityOrWindow(Func parentActivityOrWindowFunc) => this.ParentActivityOrWindow = parentActivityOrWindowFunc; + + private object GetConsoleOrTerminalWindow() + { + IntPtr consoleHandle = Interop.Kernel32.GetConsoleWindow(); + IntPtr handle = Interop.User32.GetAncestor(consoleHandle, Interop.User32.GetAncestorFlags.GetRootOwner); + + return handle; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 0393de9beb..e6ac63cf77 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -13,12 +13,15 @@ using Microsoft.Data.Common.ConnectionString; using Microsoft.Extensions.Caching.Memory; using Microsoft.Identity.Client; +#if INTERACTIVE_AUTH +using Microsoft.Identity.Client.Broker; +#endif using Microsoft.Identity.Client.Extensibility; namespace Microsoft.Data.SqlClient { /// - public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider { /// /// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey. @@ -38,6 +41,31 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro private Func _deviceCodeFlowCallback; private ICustomWebUi _customWebUI = null; private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId; + private SynchronizationContext _synchronizationContext = null; + + private SynchronizationContext SynchronizationContext + { + get + { + if (_synchronizationContext != null) + { + return _synchronizationContext; + } + else if (SynchronizationContext.Current != null) + { + return SynchronizationContext.Current; + } + else + { + return new SynchronizationContext(); + } + } + + set + { + _synchronizationContext = value; + } + } /// public ActiveDirectoryAuthenticationProvider() @@ -81,6 +109,12 @@ public static void ClearUserTokenCache() /// public void SetAcquireAuthorizationCodeAsyncCallback(Func> acquireAuthorizationCodeAsyncCallback) => _customWebUI = new CustomWebUi(acquireAuthorizationCodeAsyncCallback); + /// + /// TODO + /// + /// + public void SetSynchronizationContext(SynchronizationContext synchronizationContext) => this.SynchronizationContext = synchronizationContext; + /// public override bool IsSupported(SqlAuthenticationMethod authentication) { @@ -109,13 +143,6 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication) _logger.LogInfo(_type, "BeforeUnload", $"being unloaded from SqlAuthProviders for {authentication}."); } -#if NETFRAMEWORK - private Func _iWin32WindowFunc = null; - - /// - public void SetIWin32WindowFunc(Func iWin32WindowFunc) => this._iWin32WindowFunc = iWin32WindowFunc; -#endif - /// public override async Task AcquireTokenAsync(SqlAuthenticationParameters parameters) @@ -153,46 +180,6 @@ public override async Task AcquireTokenAsync(SqlAuthenti string audience = parameters.Authority.Substring(separatorIndex + 1); string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId; - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault) - { - // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(authority) }; - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI) - { - // Cache ManagedIdentityCredential based on scope, authority, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) - { - // Cache ClientSecretCredential based on scope, authority, audience, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryWorkloadIdentity) - { - // Cache WorkloadIdentityCredential based on authority and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(WorkloadIdentityCredential), authority, string.Empty, string.Empty, clientId); - // If either tenant id, client id, or the token file path are not specified when fetching the token, - // a CredentialUnavailableException will be thrown instead - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Workload Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - /* * Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows * (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend @@ -208,22 +195,159 @@ public override async Task AcquireTokenAsync(SqlAuthenti redirectUri = "http://localhost"; } #endif - PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId -#if NETFRAMEWORK - , _iWin32WindowFunc -#endif - ); - - AuthenticationResult result = null; - IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) + switch (parameters.AuthenticationMethod) { - result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + case SqlAuthenticationMethod.ActiveDirectoryDefault: + { + // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } - if (result == null) - { - if (!string.IsNullOrEmpty(parameters.UserId)) + case SqlAuthenticationMethod.ActiveDirectoryManagedIdentity: + case SqlAuthenticationMethod.ActiveDirectoryMSI: + { + // Cache ManagedIdentityCredential based on scope, authority, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal: + { + // Cache ClientSecretCredential based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + + case SqlAuthenticationMethod.ActiveDirectoryWorkloadIdentity: + { + // Cache WorkloadIdentityCredential based on authority and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(WorkloadIdentityCredential), authority, string.Empty, string.Empty, clientId); + // If either tenant id, client id, or the token file path are not specified when fetching the token, + // a CredentialUnavailableException will be thrown instead + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Workload Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + + //Public client auth methods + case SqlAuthenticationMethod.ActiveDirectoryPassword: + { + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, ParentActivityOrWindow); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); + + AuthenticationResult result = null; + + string pwCacheKey = GetAccountPwCacheKey(parameters); + object previousPw = s_accountPwCache.Get(pwCacheKey); + byte[] currPwHash = GetHash(parameters.Password); + + if (previousPw != null && + previousPw is byte[] previousPwBytes && + // Only get the cached token if the current password hash matches the previously used password hash + AreEqual(currPwHash, previousPwBytes)) + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + } + + if (result == null) + { + //TODO: need to use broker here too? + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + + // We cache the password hash to ensure future connection requests include a validated password + // when we check for a cached MSAL account. Otherwise, a connection request with the same username + // against the same tenant could succeed with an invalid password when we re-use the cached token. + using (ICacheEntry entry = s_accountPwCache.CreateEntry(pwCacheKey)) + { + entry.Value = GetHash(parameters.Password); + entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(s_accountPwCacheTtlInHours); + }; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } +#if INTERACTIVE_AUTH + case SqlAuthenticationMethod.ActiveDirectoryInteractive: + case SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow: + case SqlAuthenticationMethod.ActiveDirectoryIntegrated: + { + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, ParentActivityOrWindow); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); + + AuthenticationResult result = null; + + try + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + catch (MsalUiRequiredException) + { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + + + SynchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + + result = await state._taskCompletionSource.Task; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + if (result == null) + { + // If no existing 'account' is found, we request user to sign in interactively. + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + + + SynchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + + result = await state._taskCompletionSource.Task; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } +#endif + default: { // The AcquireTokenByIntegratedWindowsAuth method is marked as obsolete in MSAL.NET // but it is still a supported way to acquire tokens for Active Directory Integrated authentication. @@ -357,12 +481,27 @@ private static async Task TryAcquireTokenSilent(IPublicCli return result; } - private static async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, - SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback) +#if INTERACTIVE_AUTH + private class InteractiveAuthStateObject { + internal IPublicClientApplication app; + internal string[] scopes; + internal Guid connectionId; + internal string userId; + internal SqlAuthenticationMethod authenticationMethod; + internal CancellationTokenSource cts; + internal ICustomWebUi customWebUI; + internal Func deviceCodeFlowCallback; + internal TaskCompletionSource _taskCompletionSource; + } + + private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) + { + InteractiveAuthStateObject interactiveAuthStateObject = (InteractiveAuthStateObject)state; + try { - if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) + if (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) { CancellationTokenSource ctsInteractive = new(); #if NET @@ -377,14 +516,17 @@ private static async Task AcquireTokenInteractiveDeviceFlo */ ctsInteractive.CancelAfter(180000); #endif - if (customWebUI != null) + + if (interactiveAuthStateObject.customWebUI != null) { - return await app.AcquireTokenInteractive(scopes) - .WithCorrelationId(connectionId) - .WithCustomWebUi(customWebUI) - .WithLoginHint(userId) + var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .WithCustomWebUi(interactiveAuthStateObject.customWebUI) + .WithLoginHint(interactiveAuthStateObject.userId) .ExecuteAsync(ctsInteractive.Token) .ConfigureAwait(false); + interactiveAuthStateObject._taskCompletionSource.SetResult(result); + return; } else { @@ -405,31 +547,43 @@ private static async Task AcquireTokenInteractiveDeviceFlo * * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/MSAL.NET-uses-web-browser#at-a-glance */ - return await app.AcquireTokenInteractive(scopes) - .WithCorrelationId(connectionId) - .WithLoginHint(userId) + var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .WithLoginHint(interactiveAuthStateObject.userId) .ExecuteAsync(ctsInteractive.Token) .ConfigureAwait(false); + interactiveAuthStateObject._taskCompletionSource.SetResult(result); + return; } } else { - AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, - deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)) - .WithCorrelationId(connectionId) - .ExecuteAsync(cancellationToken: cts.Token) + AuthenticationResult result = await interactiveAuthStateObject.app.AcquireTokenWithDeviceCode(interactiveAuthStateObject.scopes, + deviceCodeResult => interactiveAuthStateObject.deviceCodeFlowCallback(deviceCodeResult)) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .ExecuteAsync(cancellationToken: interactiveAuthStateObject.cts.Token) .ConfigureAwait(false); - return result; + interactiveAuthStateObject._taskCompletionSource.SetResult(result); + return; } } catch (OperationCanceledException) { SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation timed out while acquiring access token."); - throw (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? + + var error = (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? SQL.ActiveDirectoryInteractiveTimeout() : SQL.ActiveDirectoryDeviceFlowTimeout(); + + interactiveAuthStateObject._taskCompletionSource.SetException(error); + } + catch (Exception e) + { + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation failed while acquiring access token."); + interactiveAuthStateObject._taskCompletionSource.SetException(e); } } +#endif private static Task DefaultDeviceFlowCallback(DeviceCodeResult result) { @@ -561,6 +715,10 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ ClientName = DbConnectionStringDefaults.ApplicationName, ClientVersion = Common.ADP.GetAssemblyVersion().ToString(), RedirectUri = publicClientAppKey._redirectUri, +#if INTERACTIVE_AUTH + ParentActivityOrWindow = ParentActivityOrWindow + BrokerOptions = new BrokerOptions(BrokerOptions.OperatingSystems.Windows) +#endif }) .WithAuthority(publicClientAppKey._authority); @@ -653,22 +811,16 @@ internal class PublicClientAppKey public readonly string _authority; public readonly string _redirectUri; public readonly string _applicationClientId; -#if NETFRAMEWORK - public readonly Func _iWin32WindowFunc; -#endif + public readonly Func _parentActivityOrWindowFunc; public PublicClientAppKey(string authority, string redirectUri, string applicationClientId -#if NETFRAMEWORK - , Func iWin32WindowFunc -#endif + , Func parentActivityOrWindowFunc ) { _authority = authority; _redirectUri = redirectUri; _applicationClientId = applicationClientId; -#if NETFRAMEWORK - _iWin32WindowFunc = iWin32WindowFunc; -#endif + _parentActivityOrWindowFunc = parentActivityOrWindowFunc; } public override bool Equals(object obj) @@ -678,18 +830,14 @@ public override bool Equals(object obj) return (string.CompareOrdinal(_authority, pcaKey._authority) == 0 && string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0 && string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0 -#if NETFRAMEWORK - && pcaKey._iWin32WindowFunc == _iWin32WindowFunc -#endif + && pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc ); } return false; } public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId -#if NETFRAMEWORK - , _iWin32WindowFunc -#endif + , _parentActivityOrWindowFunc ).GetHashCode(); } @@ -738,6 +886,5 @@ public override bool Equals(object obj) public override int GetHashCode() => Tuple.Create(_tokenCredentialType, _authority, _scope, _audience, _clientId).GetHashCode(); } - } }