diff --git a/Directory.Packages.props b/Directory.Packages.props index c339554838..de52bd55a7 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -10,6 +10,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj index e6ffc2526b..86218c9805 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj @@ -37,6 +37,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index bb14e233a3..c859dda2f1 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -1084,6 +1084,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj index 1d9c1985bd..8b568b387c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj @@ -40,6 +40,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index e5630bc9c7..2bf0db0e8f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -73,7 +73,7 @@ - $(DefineConstants);DEBUG;DBG;_DEBUG;_LOGGING;RESOURCE_ANNOTATION_WORK; + $(DefineConstants);DEBUG;DBG;_DEBUG;_LOGGING;RESOURCE_ANNOTATION_WORK;INTERACTIVE_AUTH; Full False @@ -82,6 +82,11 @@ Pdbonly True + + + $(DefineConstants);INTERACTIVE_AUTH; + + @@ -1070,6 +1075,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs new file mode 100644 index 0000000000..1941560072 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +internal partial class Interop +{ + internal partial class Kernel32 + { + [DllImport("kernel32.dll")] + internal static extern IntPtr GetConsoleWindow(); + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs new file mode 100644 index 0000000000..1b8804e6d6 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + + +internal partial class Interop +{ + internal partial class User32 + { + internal enum GetAncestorFlags + { + GetParent = 1, + GetRoot = 2, + /// + /// Retrieves the owned root window by walking the chain of parent and owner windows returned by GetParent. + /// + GetRootOwner = 3 + } + + /// + /// Retrieves the handle to the ancestor of the specified window. + /// + /// A handle to the window whose ancestor is to be retrieved. + /// If this parameter is the desktop window, the function returns NULL. + /// The ancestor to be retrieved. + /// The return value is the handle to the ancestor window. + [DllImport("user32.dll", ExactSpelling = true)] + internal static extern IntPtr GetAncestor(IntPtr hwnd, GetAncestorFlags flags); + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs new file mode 100644 index 0000000000..4d5cf6c3fa --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; + +namespace Microsoft.Data.SqlClient +{ + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + { + private Func _parentActivityOrWindowFunc = null; + + private Func ParentActivityOrWindow => _parentActivityOrWindowFunc; + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs new file mode 100644 index 0000000000..3073124b6a --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Data.SqlClient +{ + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + { + private Func _parentActivityOrWindowFunc = null; + + private Func ParentActivityOrWindow + { + get + { + return _parentActivityOrWindowFunc != null ? _parentActivityOrWindowFunc : GetConsoleOrTerminalWindow; + } + + set + { + _parentActivityOrWindowFunc = value; + } + } + +#if NETFRAMEWORK + /// + public void SetIWin32WindowFunc(Func iWin32WindowFunc) => SetParentActivityOrWindow(iWin32WindowFunc); +#endif + + /// + /// + /// + /// + public void SetParentActivityOrWindow(Func parentActivityOrWindowFunc) => this.ParentActivityOrWindow = parentActivityOrWindowFunc; + + private object GetConsoleOrTerminalWindow() + { + IntPtr consoleHandle = Interop.Kernel32.GetConsoleWindow(); + IntPtr handle = Interop.User32.GetAncestor(consoleHandle, Interop.User32.GetAncestorFlags.GetRootOwner); + + return handle; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 0393de9beb..35926c5b67 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -13,12 +13,15 @@ using Microsoft.Data.Common.ConnectionString; using Microsoft.Extensions.Caching.Memory; using Microsoft.Identity.Client; +#if INTERACTIVE_AUTH +using Microsoft.Identity.Client.Broker; +#endif using Microsoft.Identity.Client.Extensibility; namespace Microsoft.Data.SqlClient { /// - public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider { /// /// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey. @@ -38,6 +41,31 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro private Func _deviceCodeFlowCallback; private ICustomWebUi _customWebUI = null; private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId; + private SynchronizationContext _synchronizationContext = null; + + private SynchronizationContext SynchronizationContext + { + get + { + if (_synchronizationContext != null) + { + return _synchronizationContext; + } + else if (SynchronizationContext.Current != null) + { + return SynchronizationContext.Current; + } + else + { + return new SynchronizationContext(); + } + } + + set + { + _synchronizationContext = value; + } + } /// public ActiveDirectoryAuthenticationProvider() @@ -56,7 +84,13 @@ public ActiveDirectoryAuthenticationProvider(Func device { if (applicationClientId != null) { - _applicationClientId = applicationClientId; + summary> + /// TODO + /// + /// + public void SetSynchronizationContext(SynchronizationContext synchronizationContext) => this.SynchronizationContext = synchronizationContext; + + /// < _applicationClientId = applicationClientId; } SetDeviceCodeFlowCallback(deviceCodeFlowCallbackMethod); } @@ -77,13 +111,6 @@ public static void ClearUserTokenCache() /// public void SetDeviceCodeFlowCallback(Func deviceCodeFlowCallbackMethod) => _deviceCodeFlowCallback = deviceCodeFlowCallbackMethod; - - /// - public void SetAcquireAuthorizationCodeAsyncCallback(Func> acquireAuthorizationCodeAsyncCallback) => _customWebUI = new CustomWebUi(acquireAuthorizationCodeAsyncCallback); - - /// - public override bool IsSupported(SqlAuthenticationMethod authentication) - { return authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated #pragma warning disable 0618 // Type or member is obsolete || authentication == SqlAuthenticationMethod.ActiveDirectoryPassword @@ -153,46 +180,6 @@ public override async Task AcquireTokenAsync(SqlAuthenti string audience = parameters.Authority.Substring(separatorIndex + 1); string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId; - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault) - { - // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(authority) }; - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI) - { - // Cache ManagedIdentityCredential based on scope, authority, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) - { - // Cache ClientSecretCredential based on scope, authority, audience, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryWorkloadIdentity) - { - // Cache WorkloadIdentityCredential based on authority and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(WorkloadIdentityCredential), authority, string.Empty, string.Empty, clientId); - // If either tenant id, client id, or the token file path are not specified when fetching the token, - // a CredentialUnavailableException will be thrown instead - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Workload Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - /* * Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows * (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend @@ -208,112 +195,229 @@ public override async Task AcquireTokenAsync(SqlAuthenti redirectUri = "http://localhost"; } #endif - PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId -#if NETFRAMEWORK - , _iWin32WindowFunc -#endif - ); - - AuthenticationResult result = null; - IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) + switch (parameters.AuthenticationMethod) { - result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + case SqlAuthenticationMethod.ActiveDirectoryDefault: + { + // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } - if (result == null) - { - if (!string.IsNullOrEmpty(parameters.UserId)) + case SqlAuthenticationMethod.ActiveDirectoryManagedIdentity: + case SqlAuthenticationMethod.ActiveDirectoryMSI: { - // The AcquireTokenByIntegratedWindowsAuth method is marked as obsolete in MSAL.NET - // but it is still a supported way to acquire tokens for Active Directory Integrated authentication. -#pragma warning disable CS0618 // Type or member is obsolete - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) -#pragma warning restore CS0618 // Type or member is obsolete - .WithCorrelationId(parameters.ConnectionId) - .WithUsername(parameters.UserId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); + // Cache ManagedIdentityCredential based on scope, authority, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } - else + case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal: { -#pragma warning disable CS0618 // Type or member is obsolete - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) -#pragma warning restore CS0618 // Type or member is obsolete - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); + // Cache ClientSecretCredential based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - } - #pragma warning disable 0618 // Type or member is obsolete - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) - #pragma warning restore 0618 // Type or member is obsolete - { - string pwCacheKey = GetAccountPwCacheKey(parameters); - object previousPw = s_accountPwCache.Get(pwCacheKey); - byte[] currPwHash = GetHash(parameters.Password); - - if (previousPw != null && - previousPw is byte[] previousPwBytes && - // Only get the cached token if the current password hash matches the previously used password hash - AreEqual(currPwHash, previousPwBytes)) - { - result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); - } - if (result == null) - { + case SqlAuthenticationMethod.ActiveDirectoryWorkloadIdentity: + { + // Cache WorkloadIdentityCredential based on authority and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(WorkloadIdentityCredential), authority, string.Empty, string.Empty, clientId); + // If either tenant id, client id, or the token file path are not specified when fetching the token, + // a CredentialUnavailableException will be thrown instead + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Workload Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + + //Public client auth methods + case SqlAuthenticationMethod.ActiveDirectoryPassword: + { + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, ParentActivityOrWindow); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); + + AuthenticationResult result = null; + + string pwCacheKey = GetAccountPwCacheKey(parameters); + object previousPw = s_accountPwCache.Get(pwCacheKey); + byte[] currPwHash = GetHash(parameters.Password); + + if (previousPw != null && + previousPw is byte[] previousPwBytes && + // Only get the cached token if the current password hash matches the previously used password hash + AreEqual(currPwHash, previousPwBytes)) + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + } + + if (result == null) + { + //TODO: need to use broker here too? #pragma warning disable CS0618 // Type or member is obsolete - result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) #pragma warning restore CS0618 // Type or member is obsolete - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - - // We cache the password hash to ensure future connection requests include a validated password - // when we check for a cached MSAL account. Otherwise, a connection request with the same username - // against the same tenant could succeed with an invalid password when we re-use the cached token. - using (ICacheEntry entry = s_accountPwCache.CreateEntry(pwCacheKey)) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + + // We cache the password hash to ensure future connection requests include a validated password + // when we check for a cached MSAL account. Otherwise, a connection request with the same username + // against the same tenant could succeed with an invalid password when we re-use the cached token. + using (ICacheEntry entry = s_accountPwCache.CreateEntry(pwCacheKey)) + { + entry.Value = GetHash(parameters.Password); + entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(s_accountPwCacheTtlInHours); + }; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } +#if INTERACTIVE_AUTH + case SqlAuthenticationMethod.ActiveDirectoryInteractive: + case SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow: + case SqlAuthenticationMethod.ActiveDirectoryIntegrated: { - entry.Value = GetHash(parameters.Password); - entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(s_accountPwCacheTtlInHours); + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, ParentActivityOrWindow); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); + + AuthenticationResult result = null; + + try + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + catch (MsalUiRequiredException) + { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + + + SynchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + + result = await state._taskCompletionSource.Task; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + if (result == null) + { + // If no existing 'account' is found, we request user to sign in interactively. + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + + + SynchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + + result = await state._taskCompletionSource.Task; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); } - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || - parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) - { - try - { - result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } - catch (MsalUiRequiredException) - { - // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, - // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), - // or the user needs to perform two factor authentication. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } +#endif + default: + { + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, ParentActivityOrWindow); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); + AuthenticationResult result = null; - if (result == null) - { - // If no existing 'account' is found, we request user to sign in interactively. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } - } - else - { - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | {0} authentication mode not supported by ActiveDirectoryAuthenticationProvider class.", parameters.AuthenticationMethod); - throw SQL.UnsupportedAuthenticationSpecified(parameters.AuthenticationMethod); - } + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + + if (result == null) + { + if (!string.IsNullOrEmpty(parameters.UserId)) + { + // The AcquireTokenByIntegratedWindowsAuth method is marked as obsolete in MSAL.NET + // but it is still a supported way to acquire tokens for Active Directory Integrated authentication. +#pragma warning disable CS0618 // Type or member is obsolete + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) +#pragma warning restore CS0618 // Type or member is obsolete + .WithCorrelationId(parameters.ConnectionId) + .WithUsername(parameters.UserId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + } + else + { +#pragma warning disable CS0618 // Type or member is obsolete + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) +#pragma warning restore CS0618 // Type or member is obsolete + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + } + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + { + try + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + catch (MsalUiRequiredException) + { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + if (result == null) + { + // If no existing 'account' is found, we request user to sign in interactively. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + } + else + { + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | {0} authentication mode not supported by ActiveDirectoryAuthenticationProvider class.", parameters.AuthenticationMethod); + throw SQL.UnsupportedAuthenticationSpecified(parameters.AuthenticationMethod); + } - return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } + } } private static async Task TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters, @@ -357,6 +461,110 @@ private static async Task TryAcquireTokenSilent(IPublicCli return result; } +#if INTERACTIVE_AUTH + private class InteractiveAuthStateObject + { + internal IPublicClientApplication app; + internal string[] scopes; + internal Guid connectionId; + internal string userId; + internal SqlAuthenticationMethod authenticationMethod; + internal CancellationTokenSource cts; + internal ICustomWebUi customWebUI; + internal Func deviceCodeFlowCallback; + internal TaskCompletionSource _taskCompletionSource; + } + + private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) + { + InteractiveAuthStateObject interactiveAuthStateObject = (InteractiveAuthStateObject)state; + + try + { + if (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) + { + CancellationTokenSource ctsInteractive = new(); +#if NET + /* + * On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser, + * but once the user finishes authentication, the web page is redirected in such a way that MSAL can intercept the Uri. + * MSAL cannot detect if the user navigates away or simply closes the browser. Apps using this technique are encouraged + * to define a timeout (via CancellationToken). We recommend a timeout of at least a few minutes, to take into account + * cases where the user is prompted to change password or perform 2FA. + * + * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/System-Browser-on-.Net-Core#system-browser-experience + */ + ctsInteractive.CancelAfter(180000); +#endif + + if (interactiveAuthStateObject.customWebUI != null) + { + var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .WithCustomWebUi(interactiveAuthStateObject.customWebUI) + .WithLoginHint(interactiveAuthStateObject.userId) + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); + interactiveAuthStateObject._taskCompletionSource.SetResult(result); + return; + } + else + { + /* + * We will use the MSAL Embedded or System web browser which changes by Default in MSAL according to this table: + * + * Framework Embedded System Default + * ------------------------------------------- + * .NET Classic Yes Yes^ Embedded + * .NET Core No Yes^ System + * .NET Standard No No NONE + * UWP Yes No Embedded + * Xamarin.Android Yes Yes System + * Xamarin.iOS Yes Yes System + * Xamarin.Mac Yes No Embedded + * + * ^ Requires "http://localhost" redirect URI + * + * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/MSAL.NET-uses-web-browser#at-a-glance + */ + var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .WithLoginHint(interactiveAuthStateObject.userId) + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); + interactiveAuthStateObject._taskCompletionSource.SetResult(result); + return; + } + } + else + { + AuthenticationResult result = await interactiveAuthStateObject.app.AcquireTokenWithDeviceCode(interactiveAuthStateObject.scopes, + deviceCodeResult => interactiveAuthStateObject.deviceCodeFlowCallback(deviceCodeResult)) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .ExecuteAsync(cancellationToken: interactiveAuthStateObject.cts.Token) + .ConfigureAwait(false); + interactiveAuthStateObject._taskCompletionSource.SetResult(result); + return; + } + } + catch (OperationCanceledException) + { + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation timed out while acquiring access token."); + + var error = (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? + SQL.ActiveDirectoryInteractiveTimeout() : + SQL.ActiveDirectoryDeviceFlowTimeout(); + + interactiveAuthStateObject._taskCompletionSource.SetException(error); + } + catch (Exception e) + { + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation failed while acquiring access token."); + interactiveAuthStateObject._taskCompletionSource.SetException(e); + } + } +#endif + private static async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback) { @@ -561,15 +769,16 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ ClientName = DbConnectionStringDefaults.ApplicationName, ClientVersion = Common.ADP.GetAssemblyVersion().ToString(), RedirectUri = publicClientAppKey._redirectUri, +#if INTERACTIVE_AUTH + BrokerOptions = new BrokerOptions(BrokerOptions.OperatingSystems.Windows) +#endif }) .WithAuthority(publicClientAppKey._authority); - #if NETFRAMEWORK - if (_iWin32WindowFunc is not null) + if (publicClientAppKey._parentActivityOrWindowFunc != null) { - builder.WithParentActivityOrWindow(_iWin32WindowFunc); + builder.WithParentActivityOrWindow(publicClientAppKey._parentActivityOrWindowFunc); } - #endif return builder.Build(); } @@ -652,23 +861,16 @@ internal class PublicClientAppKey { public readonly string _authority; public readonly string _redirectUri; - public readonly string _applicationClientId; -#if NETFRAMEWORK - public readonly Func _iWin32WindowFunc; -#endif + public readonly Func _parentActivityOrWindowFunc; public PublicClientAppKey(string authority, string redirectUri, string applicationClientId -#if NETFRAMEWORK - , Func iWin32WindowFunc -#endif + , Func parentActivityOrWindowFunc ) { _authority = authority; _redirectUri = redirectUri; _applicationClientId = applicationClientId; -#if NETFRAMEWORK - _iWin32WindowFunc = iWin32WindowFunc; -#endif + _parentActivityOrWindowFunc = parentActivityOrWindowFunc; } public override bool Equals(object obj) @@ -678,17 +880,14 @@ public override bool Equals(object obj) return (string.CompareOrdinal(_authority, pcaKey._authority) == 0 && string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0 && string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0 -#if NETFRAMEWORK - && pcaKey._iWin32WindowFunc == _iWin32WindowFunc -#endif + && pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc ); } return false; } public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId -#if NETFRAMEWORK - , _iWin32WindowFunc + , _parentActivityOrWindowFunc , _iWin32WindowFunc #endif ).GetHashCode(); }