From c4e5c4a139fbb557d5bb691cfe6bf2839d6e1631 Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Mon, 23 Sep 2024 16:15:39 -0700 Subject: [PATCH 01/11] WAM POC --- .../ref/Microsoft.Data.SqlClient.csproj | 1 + .../netfx/ref/Microsoft.Data.SqlClient.csproj | 1 + .../netfx/src/Microsoft.Data.SqlClient.csproj | 4 +- .../ActiveDirectoryAuthenticationProvider.cs | 107 +++++++++++++++--- 4 files changed, 98 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj index 41a471847e..0f186408b6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj @@ -28,5 +28,6 @@ + diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj index c3e68d9a8a..0ee39b5404 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj @@ -26,5 +26,6 @@ + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 730a747261..9c08dddee4 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -766,7 +766,9 @@ - + + + $(SystemTextEncodingsWebVersion) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 276efb441e..75de8815e8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -4,14 +4,19 @@ using System; using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; using System.Threading; using System.Threading.Tasks; +using System.Windows.Forms; using Azure.Core; using Azure.Identity; +using Microsoft.Data.Common; using Microsoft.Extensions.Caching.Memory; using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Broker; using Microsoft.Identity.Client.Extensibility; namespace Microsoft.Data.SqlClient @@ -107,10 +112,19 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication) } #if NETFRAMEWORK - private Func _iWin32WindowFunc = null; + /// + /// + /// + public delegate Task InvokeDelegate(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, + SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback); + + private System.Windows.Forms.Control _control = null; - /// - public void SetIWin32WindowFunc(Func iWin32WindowFunc) => this._iWin32WindowFunc = iWin32WindowFunc; + /// + /// + /// + /// + public void SetIWin32WindowFunc(System.Windows.Forms.Control control) => this._control = control; #endif /// @@ -207,7 +221,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti #endif PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId #if NETFRAMEWORK - , _iWin32WindowFunc + , () => (IWin32Window) _control #endif ); @@ -284,14 +298,37 @@ previousPw is byte[] previousPwBytes && // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), // or the user needs to perform two factor authentication. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); + +#if NETFRAMEWORK + Func> func = async () => + { + return await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback + ).ConfigureAwait(false); + }; + + result = await (Task)_control.Invoke(func); +#else + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback + ).ConfigureAwait(false); +#endif SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } if (result == null) { // If no existing 'account' is found, we request user to sign in interactively. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); +#if NETFRAMEWORK + Func> func = async () => + { + return await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback + ).ConfigureAwait(false); + }; + + result = await (Task)_control.Invoke(func); +#else + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback + ).ConfigureAwait(false); +#endif SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } } @@ -365,6 +402,7 @@ private static async Task AcquireTokenInteractiveDeviceFlo */ ctsInteractive.CancelAfter(180000); #endif + if (customWebUI != null) { return await app.AcquireTokenInteractive(scopes) @@ -393,12 +431,12 @@ private static async Task AcquireTokenInteractiveDeviceFlo * * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/MSAL.NET-uses-web-browser#at-a-glance */ - return await app.AcquireTokenInteractive(scopes) - .WithCorrelationId(connectionId) - .WithLoginHint(userId) - .ExecuteAsync(ctsInteractive.Token) - .ConfigureAwait(false); - } + return await app.AcquireTokenInteractive(scopes) + .WithCorrelationId(connectionId) + .WithLoginHint(userId) + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); + } } else { @@ -545,14 +583,15 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ IPublicClientApplication publicClientApplication; #if NETFRAMEWORK - if (_iWin32WindowFunc != null) + if (_control != null) { publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) .WithAuthority(publicClientAppKey._authority) .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) .WithRedirectUri(publicClientAppKey._redirectUri) - .WithParentActivityOrWindow(_iWin32WindowFunc) + .WithParentActivityOrWindow(() => _control) + .WithBroker(new BrokerOptions(BrokerOptions.OperatingSystems.Windows)) .Build(); } else @@ -569,6 +608,46 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ return publicClientApplication; } + + // This is your window handle! + static IntPtr GetParent() + { + Process currentProcess = Process.GetCurrentProcess(); + return currentProcess.MainWindowHandle; + } + + enum GetAncestorFlags + { + GetParent = 1, + GetRoot = 2, + /// + /// Retrieves the owned root window by walking the chain of parent and owner windows returned by GetParent. + /// + GetRootOwner = 3 + } + + /// + /// Retrieves the handle to the ancestor of the specified window. + /// + /// A handle to the window whose ancestor is to be retrieved. + /// If this parameter is the desktop window, the function returns NULL. + /// The ancestor to be retrieved. + /// The return value is the handle to the ancestor window. + [DllImport("user32.dll", ExactSpelling = true)] + static extern IntPtr GetAncestor(IntPtr hwnd, GetAncestorFlags flags); + + [DllImport("kernel32.dll")] + static extern IntPtr GetConsoleWindow(); + + // This is your window handle! + IntPtr GetConsoleOrTerminalWindow() + { + IntPtr consoleHandle = GetConsoleWindow(); + IntPtr handle = GetAncestor(consoleHandle, GetAncestorFlags.GetRootOwner); + + return handle; + } + private static TokenCredentialData CreateTokenCredentialInstance(TokenCredentialKey tokenCredentialKey, string secret) { if (tokenCredentialKey._tokenCredentialType == typeof(DefaultAzureCredential)) From 6b0a7b9a4d6b0efcc5166633f2ff19fa40b892c2 Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Wed, 25 Sep 2024 13:55:58 -0700 Subject: [PATCH 02/11] Use synchronization context. Keep broker dependency for now to test progress. --- .../src/Microsoft.Data.SqlClient.csproj | 3 +- .../ActiveDirectoryAuthenticationProvider.cs | 174 +++++++++++------- 2 files changed, 108 insertions(+), 69 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index b9f6de4113..03a57430d3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -962,7 +962,8 @@ - + + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 75de8815e8..5eab0cecb2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -10,7 +10,6 @@ using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Windows.Forms; using Azure.Core; using Azure.Identity; using Microsoft.Data.Common; @@ -112,20 +111,31 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication) } #if NETFRAMEWORK + + + /// + public void SetIWin32WindowFunc(Func iWin32WindowFunc) => SetParentActivityOrWindow(iWin32WindowFunc); +#endif + + private Func _parentActivityOrWindowFunc = null; + /// /// /// - public delegate Task InvokeDelegate(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, - SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback); + /// + public void SetParentActivityOrWindow(Func parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc; - private System.Windows.Forms.Control _control = null; + private delegate Task InvokeDelegate(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, + SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback); + + private SynchronizationContext _synchronizationContext = null; /// /// /// - /// - public void SetIWin32WindowFunc(System.Windows.Forms.Control control) => this._control = control; -#endif + /// + public void SetSynchronizationContext(SynchronizationContext synchronizationContext) => this._synchronizationContext = synchronizationContext; + /// @@ -219,11 +229,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti redirectUri = "http://localhost"; } #endif - PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId -#if NETFRAMEWORK - , () => (IWin32Window) _control -#endif - ); + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, _parentActivityOrWindowFunc); AuthenticationResult result = null; IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); @@ -299,36 +305,55 @@ previousPw is byte[] previousPwBytes && // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), // or the user needs to perform two factor authentication. -#if NETFRAMEWORK - Func> func = async () => + using (SemaphoreSlim waitHandle = new SemaphoreSlim(0)) { - return await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback - ).ConfigureAwait(false); - }; + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _waitHandle = waitHandle + }; + + _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + + await waitHandle.WaitAsync().ConfigureAwait(false); + result = state.result; + } - result = await (Task)_control.Invoke(func); -#else - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback - ).ConfigureAwait(false); -#endif SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } if (result == null) { // If no existing 'account' is found, we request user to sign in interactively. -#if NETFRAMEWORK - Func> func = async () => + using (SemaphoreSlim waitHandle = new SemaphoreSlim(0)) { - return await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback - ).ConfigureAwait(false); - }; + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _waitHandle = waitHandle + }; + + _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + + await waitHandle.WaitAsync().ConfigureAwait(false); + result = state.result; + } + - result = await (Task)_control.Invoke(func); -#else - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback - ).ConfigureAwait(false); -#endif SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } } @@ -382,12 +407,28 @@ private static async Task TryAcquireTokenSilent(IPublicCli return result; } - private static async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, - SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback) + private struct InteractiveAuthStateObject + { + internal IPublicClientApplication app; + internal string[] scopes; + internal Guid connectionId; + internal string userId; + internal SqlAuthenticationMethod authenticationMethod; + internal CancellationTokenSource cts; + internal ICustomWebUi customWebUI; + internal Func deviceCodeFlowCallback; + internal AuthenticationResult result; + internal SemaphoreSlim _waitHandle; + } + + + private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) { + InteractiveAuthStateObject interactiveAuthStateObject = (InteractiveAuthStateObject)state; + try { - if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) + if (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) { CancellationTokenSource ctsInteractive = new(); #if NET6_0_OR_GREATER @@ -403,14 +444,16 @@ private static async Task AcquireTokenInteractiveDeviceFlo ctsInteractive.CancelAfter(180000); #endif - if (customWebUI != null) + if (interactiveAuthStateObject.customWebUI != null) { - return await app.AcquireTokenInteractive(scopes) - .WithCorrelationId(connectionId) - .WithCustomWebUi(customWebUI) - .WithLoginHint(userId) + var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .WithCustomWebUi(interactiveAuthStateObject.customWebUI) + .WithLoginHint(interactiveAuthStateObject.userId) .ExecuteAsync(ctsInteractive.Token) .ConfigureAwait(false); + interactiveAuthStateObject.result = result; + return; } else { @@ -431,30 +474,37 @@ private static async Task AcquireTokenInteractiveDeviceFlo * * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/MSAL.NET-uses-web-browser#at-a-glance */ - return await app.AcquireTokenInteractive(scopes) - .WithCorrelationId(connectionId) - .WithLoginHint(userId) + var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .WithLoginHint(interactiveAuthStateObject.userId) .ExecuteAsync(ctsInteractive.Token) .ConfigureAwait(false); + interactiveAuthStateObject.result = result; + return; } } else { - AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, - deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)) - .WithCorrelationId(connectionId) - .ExecuteAsync(cancellationToken: cts.Token) + AuthenticationResult result = await interactiveAuthStateObject.app.AcquireTokenWithDeviceCode(interactiveAuthStateObject.scopes, + deviceCodeResult => interactiveAuthStateObject.deviceCodeFlowCallback(deviceCodeResult)) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .ExecuteAsync(cancellationToken: interactiveAuthStateObject.cts.Token) .ConfigureAwait(false); - return result; + interactiveAuthStateObject.result = result; + return; } } catch (OperationCanceledException) { SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation timed out while acquiring access token."); - throw (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? + throw (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? SQL.ActiveDirectoryInteractiveTimeout() : SQL.ActiveDirectoryDeviceFlowTimeout(); } + finally + { + interactiveAuthStateObject._waitHandle.Release(); + } } private static Task DefaultDeviceFlowCallback(DeviceCodeResult result) @@ -582,20 +632,18 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ { IPublicClientApplication publicClientApplication; -#if NETFRAMEWORK - if (_control != null) + if (_parentActivityOrWindowFunc != null) { publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) .WithAuthority(publicClientAppKey._authority) .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) .WithRedirectUri(publicClientAppKey._redirectUri) - .WithParentActivityOrWindow(() => _control) + .WithParentActivityOrWindow(_parentActivityOrWindowFunc) .WithBroker(new BrokerOptions(BrokerOptions.OperatingSystems.Windows)) .Build(); } else -#endif { publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) .WithAuthority(publicClientAppKey._authority) @@ -704,22 +752,16 @@ internal class PublicClientAppKey public readonly string _authority; public readonly string _redirectUri; public readonly string _applicationClientId; -#if NETFRAMEWORK - public readonly Func _iWin32WindowFunc; -#endif + public readonly Func _parentActivityOrWindowFunc; public PublicClientAppKey(string authority, string redirectUri, string applicationClientId -#if NETFRAMEWORK - , Func iWin32WindowFunc -#endif + , Func parentActivityOrWindowFunc ) { _authority = authority; _redirectUri = redirectUri; _applicationClientId = applicationClientId; -#if NETFRAMEWORK - _iWin32WindowFunc = iWin32WindowFunc; -#endif + _parentActivityOrWindowFunc = parentActivityOrWindowFunc; } public override bool Equals(object obj) @@ -729,18 +771,14 @@ public override bool Equals(object obj) return (string.CompareOrdinal(_authority, pcaKey._authority) == 0 && string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0 && string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0 -#if NETFRAMEWORK - && pcaKey._iWin32WindowFunc == _iWin32WindowFunc -#endif + && pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc ); } return false; } public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId -#if NETFRAMEWORK - , _iWin32WindowFunc -#endif + , _parentActivityOrWindowFunc ).GetHashCode(); } From f2022aeb72a853deef0b511e338494028d63217c Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Wed, 25 Sep 2024 14:34:16 -0700 Subject: [PATCH 03/11] Lower broker version. Use class for reference behavior. --- .../netcore/src/Microsoft.Data.SqlClient.csproj | 2 +- .../Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 03a57430d3..c5dbf29cab 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -963,7 +963,7 @@ - + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 5eab0cecb2..cb565ce4fb 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -407,7 +407,7 @@ private static async Task TryAcquireTokenSilent(IPublicCli return result; } - private struct InteractiveAuthStateObject + private class InteractiveAuthStateObject { internal IPublicClientApplication app; internal string[] scopes; @@ -419,6 +419,7 @@ private struct InteractiveAuthStateObject internal Func deviceCodeFlowCallback; internal AuthenticationResult result; internal SemaphoreSlim _waitHandle; + } From 14767bce367ce483b8ee00b600202bd4a52b390a Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Thu, 26 Sep 2024 13:46:52 -0700 Subject: [PATCH 04/11] switch to task completion source --- .../ActiveDirectoryAuthenticationProvider.cs | 97 ++++++++----------- 1 file changed, 41 insertions(+), 56 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index cb565ce4fb..cdae07fc64 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -305,26 +305,21 @@ previousPw is byte[] previousPwBytes && // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), // or the user needs to perform two factor authentication. - using (SemaphoreSlim waitHandle = new SemaphoreSlim(0)) + InteractiveAuthStateObject state = new InteractiveAuthStateObject() { - InteractiveAuthStateObject state = new InteractiveAuthStateObject() - { - app = app, - scopes = scopes, - connectionId = parameters.ConnectionId, - userId = parameters.UserId, - authenticationMethod = parameters.AuthenticationMethod, - cts = cts, - customWebUI = _customWebUI, - deviceCodeFlowCallback = _deviceCodeFlowCallback, - _waitHandle = waitHandle - }; - - _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); - - await waitHandle.WaitAsync().ConfigureAwait(false); - result = state.result; - } + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + + _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + result = await state._taskCompletionSource.Task; SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } @@ -332,27 +327,21 @@ previousPw is byte[] previousPwBytes && if (result == null) { // If no existing 'account' is found, we request user to sign in interactively. - using (SemaphoreSlim waitHandle = new SemaphoreSlim(0)) + InteractiveAuthStateObject state = new InteractiveAuthStateObject() { - InteractiveAuthStateObject state = new InteractiveAuthStateObject() - { - app = app, - scopes = scopes, - connectionId = parameters.ConnectionId, - userId = parameters.UserId, - authenticationMethod = parameters.AuthenticationMethod, - cts = cts, - customWebUI = _customWebUI, - deviceCodeFlowCallback = _deviceCodeFlowCallback, - _waitHandle = waitHandle - }; - - _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); - - await waitHandle.WaitAsync().ConfigureAwait(false); - result = state.result; - } + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + result = await state._taskCompletionSource.Task; SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } @@ -417,9 +406,7 @@ private class InteractiveAuthStateObject internal CancellationTokenSource cts; internal ICustomWebUi customWebUI; internal Func deviceCodeFlowCallback; - internal AuthenticationResult result; - internal SemaphoreSlim _waitHandle; - + internal TaskCompletionSource _taskCompletionSource; } @@ -453,7 +440,7 @@ private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) .WithLoginHint(interactiveAuthStateObject.userId) .ExecuteAsync(ctsInteractive.Token) .ConfigureAwait(false); - interactiveAuthStateObject.result = result; + interactiveAuthStateObject._taskCompletionSource.SetResult(result); return; } else @@ -475,14 +462,14 @@ private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) * * https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/MSAL.NET-uses-web-browser#at-a-glance */ - var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) - .WithCorrelationId(interactiveAuthStateObject.connectionId) - .WithLoginHint(interactiveAuthStateObject.userId) - .ExecuteAsync(ctsInteractive.Token) - .ConfigureAwait(false); - interactiveAuthStateObject.result = result; + var result = await interactiveAuthStateObject.app.AcquireTokenInteractive(interactiveAuthStateObject.scopes) + .WithCorrelationId(interactiveAuthStateObject.connectionId) + .WithLoginHint(interactiveAuthStateObject.userId) + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); + interactiveAuthStateObject._taskCompletionSource.SetResult(result); return; - } + } } else { @@ -491,20 +478,19 @@ private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) .WithCorrelationId(interactiveAuthStateObject.connectionId) .ExecuteAsync(cancellationToken: interactiveAuthStateObject.cts.Token) .ConfigureAwait(false); - interactiveAuthStateObject.result = result; + interactiveAuthStateObject._taskCompletionSource.SetResult(result); return; } } catch (OperationCanceledException) { SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation timed out while acquiring access token."); - throw (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? + + var error = (interactiveAuthStateObject.authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) ? SQL.ActiveDirectoryInteractiveTimeout() : SQL.ActiveDirectoryDeviceFlowTimeout(); - } - finally - { - interactiveAuthStateObject._waitHandle.Release(); + + interactiveAuthStateObject._taskCompletionSource.SetException(error); } } @@ -828,6 +814,5 @@ public override bool Equals(object obj) public override int GetHashCode() => Tuple.Create(_tokenCredentialType, _authority, _scope, _audience, _clientId).GetHashCode(); } - } } From 37d77eea28797b1f524f503cc2ab6052004c2745 Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Thu, 26 Sep 2024 14:05:56 -0700 Subject: [PATCH 05/11] Route all exceptions back to task completion source to avoid crashing the ui thread. --- .../Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index cdae07fc64..09009bbd14 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -492,6 +492,11 @@ private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) interactiveAuthStateObject._taskCompletionSource.SetException(error); } + catch (Exception e) + { + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenInteractiveDeviceFlowAsync | Operation failed while acquiring access token."); + interactiveAuthStateObject._taskCompletionSource.SetException(e); + } } private static Task DefaultDeviceFlowCallback(DeviceCodeResult result) From 86d23d4c102a3eeaa383db4b611c04950125ecb8 Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Fri, 27 Sep 2024 12:57:55 -0700 Subject: [PATCH 06/11] Conditionally include broker dependency --- .../ref/Microsoft.Data.SqlClient.csproj | 4 ++- .../src/Microsoft.Data.SqlClient.csproj | 10 +++++- .../netfx/ref/Microsoft.Data.SqlClient.csproj | 4 ++- .../netfx/src/Microsoft.Data.SqlClient.csproj | 10 ++++-- .../ActiveDirectoryAuthenticationProvider.cs | 33 ++++--------------- tools/props/Versions.props | 2 ++ 6 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj index 0f186408b6..9e6cdaeb90 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.csproj @@ -28,6 +28,8 @@ - + + $(MicrosoftIdentityClientBrokerVersion) + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index c5dbf29cab..94463ccaf2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -20,6 +20,9 @@ $(NoWarn);IL2026;IL2057;IL2072;IL2075 + + $(DefineConstants);INTERACTIVE_AUTH; + true @@ -32,6 +35,9 @@ $([System.IO.Path]::Combine('$(IntermediateOutputPath)','$(TargetFramework)','$(TargetFrameworkMoniker).AssemblyAttributes$(DefaultLanguageSourceExtension)')) + + $(DefineConstants);INTERACTIVE_AUTH; + @@ -963,11 +969,13 @@ - + + $(MicrosoftIdentityClientBrokerVersion) + diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj index 0ee39b5404..2decbb7d8c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.csproj @@ -26,6 +26,8 @@ - + + $(MicrosoftIdentityClientBrokerVersion) + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 9c08dddee4..46feaf1a2d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -77,7 +77,7 @@ - $(DefineConstants);DEBUG;DBG;_DEBUG;_LOGGING;RESOURCE_ANNOTATION_WORK; + $(DefineConstants);DEBUG;DBG;_DEBUG;_LOGGING;RESOURCE_ANNOTATION_WORK;INTERACTIVE_AUTH; Full False @@ -89,6 +89,9 @@ $(DefineConstants);NETFRAMEWORK; + + $(DefineConstants);INTERACTIVE_AUTH; + @@ -767,8 +770,6 @@ - - $(SystemTextEncodingsWebVersion) @@ -780,6 +781,9 @@ $(AzureIdentityVersion) + + $(MicrosoftIdentityClientBrokerVersion) + $(MicrosoftIdentityModelProtocolsOpenIdConnectVersion) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 09009bbd14..712c8fc737 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -12,10 +12,13 @@ using System.Threading.Tasks; using Azure.Core; using Azure.Identity; -using Microsoft.Data.Common; using Microsoft.Extensions.Caching.Memory; using Microsoft.Identity.Client; + +#if INTERACTIVE_AUTH using Microsoft.Identity.Client.Broker; +#endif + using Microsoft.Identity.Client.Extensibility; namespace Microsoft.Data.SqlClient @@ -234,31 +237,8 @@ public override async Task AcquireTokenAsync(SqlAuthenti AuthenticationResult result = null; IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) - { - result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); - if (result == null) - { - if (!string.IsNullOrEmpty(parameters.UserId)) - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .WithUsername(parameters.UserId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - } - else - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - } - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) { string pwCacheKey = GetAccountPwCacheKey(parameters); object previousPw = s_accountPwCache.Get(pwCacheKey); @@ -292,7 +272,8 @@ previousPw is byte[] previousPwBytes && } } else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || - parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { try { diff --git a/tools/props/Versions.props b/tools/props/Versions.props index 379cd30bd6..88ff9e00ce 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -30,6 +30,8 @@ 1.11.4 + 4.61.3 + 4.61.3 6.0.1 7.5.0 7.5.0 From b5ceede46ae924237b9c9b7a0bc7d489ddce7695 Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Fri, 27 Sep 2024 13:25:33 -0700 Subject: [PATCH 07/11] Refactor to switch statement. --- .../ActiveDirectoryAuthenticationProvider.cs | 271 +++++++++--------- 1 file changed, 140 insertions(+), 131 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 712c8fc737..7e07af5507 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -177,46 +177,6 @@ public override async Task AcquireTokenAsync(SqlAuthenti string audience = parameters.Authority.Substring(separatorIndex + 1); string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId; - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault) - { - // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(authority) }; - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI) - { - // Cache ManagedIdentityCredential based on scope, authority, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) - { - // Cache ClientSecretCredential based on scope, authority, audience, and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId); - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryWorkloadIdentity) - { - // Cache WorkloadIdentityCredential based on authority and clientId - TokenCredentialKey tokenCredentialKey = new(typeof(WorkloadIdentityCredential), authority, string.Empty, string.Empty, clientId); - // If either tenant id, client id, or the token file path are not specified when fetching the token, - // a CredentialUnavailableException will be thrown instead - AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Workload Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); - return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); - } - /* * Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows * (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend @@ -232,108 +192,157 @@ public override async Task AcquireTokenAsync(SqlAuthenti redirectUri = "http://localhost"; } #endif - PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, _parentActivityOrWindowFunc); - AuthenticationResult result = null; - IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); + switch (parameters.AuthenticationMethod) + { + case SqlAuthenticationMethod.ActiveDirectoryDefault: + { + // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + case SqlAuthenticationMethod.ActiveDirectoryManagedIdentity: + case SqlAuthenticationMethod.ActiveDirectoryMSI: + { + // Cache ManagedIdentityCredential based on scope, authority, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } + case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal: + { + // Cache ClientSecretCredential based on scope, authority, audience, and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId); + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) - { - string pwCacheKey = GetAccountPwCacheKey(parameters); - object previousPw = s_accountPwCache.Get(pwCacheKey); - byte[] currPwHash = GetHash(parameters.Password); - - if (previousPw != null && - previousPw is byte[] previousPwBytes && - // Only get the cached token if the current password hash matches the previously used password hash - AreEqual(currPwHash, previousPwBytes)) - { - result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); - } + case SqlAuthenticationMethod.ActiveDirectoryWorkloadIdentity: + { + // Cache WorkloadIdentityCredential based on authority and clientId + TokenCredentialKey tokenCredentialKey = new(typeof(WorkloadIdentityCredential), authority, string.Empty, string.Empty, clientId); + // If either tenant id, client id, or the token file path are not specified when fetching the token, + // a CredentialUnavailableException will be thrown instead + AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Workload Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); + return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); + } - if (result == null) - { - result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - - // We cache the password hash to ensure future connection requests include a validated password - // when we check for a cached MSAL account. Otherwise, a connection request with the same username - // against the same tenant could succeed with an invalid password when we re-use the cached token. - using (ICacheEntry entry = s_accountPwCache.CreateEntry(pwCacheKey)) + //Public client auth methods + case SqlAuthenticationMethod.ActiveDirectoryPassword: { - entry.Value = GetHash(parameters.Password); - entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(s_accountPwCacheTtlInHours); - }; + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, _parentActivityOrWindowFunc); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || - parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow || - parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) - { - try - { - result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } - catch (MsalUiRequiredException) - { - // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, - // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), - // or the user needs to perform two factor authentication. + AuthenticationResult result = null; + + string pwCacheKey = GetAccountPwCacheKey(parameters); + object previousPw = s_accountPwCache.Get(pwCacheKey); + byte[] currPwHash = GetHash(parameters.Password); + + if (previousPw != null && + previousPw is byte[] previousPwBytes && + // Only get the cached token if the current password hash matches the previously used password hash + AreEqual(currPwHash, previousPwBytes)) + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + } - InteractiveAuthStateObject state = new InteractiveAuthStateObject() + if (result == null) + { + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + + // We cache the password hash to ensure future connection requests include a validated password + // when we check for a cached MSAL account. Otherwise, a connection request with the same username + // against the same tenant could succeed with an invalid password when we re-use the cached token. + using (ICacheEntry entry = s_accountPwCache.CreateEntry(pwCacheKey)) + { + entry.Value = GetHash(parameters.Password); + entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(s_accountPwCacheTtlInHours); + }; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } + case SqlAuthenticationMethod.ActiveDirectoryInteractive: + case SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow: + case SqlAuthenticationMethod.ActiveDirectoryIntegrated: { - app = app, - scopes = scopes, - connectionId = parameters.ConnectionId, - userId = parameters.UserId, - authenticationMethod = parameters.AuthenticationMethod, - cts = cts, - customWebUI = _customWebUI, - deviceCodeFlowCallback = _deviceCodeFlowCallback, - _taskCompletionSource = new TaskCompletionSource() - }; - - _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); - result = await state._taskCompletionSource.Task; - - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, _parentActivityOrWindowFunc); + IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); - if (result == null) - { - // If no existing 'account' is found, we request user to sign in interactively. - InteractiveAuthStateObject state = new InteractiveAuthStateObject() + AuthenticationResult result = null; + + try + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + catch (MsalUiRequiredException) + { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + + _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + result = await state._taskCompletionSource.Task; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + if (result == null) + { + // If no existing 'account' is found, we request user to sign in interactively. + InteractiveAuthStateObject state = new InteractiveAuthStateObject() + { + app = app, + scopes = scopes, + connectionId = parameters.ConnectionId, + userId = parameters.UserId, + authenticationMethod = parameters.AuthenticationMethod, + cts = cts, + customWebUI = _customWebUI, + deviceCodeFlowCallback = _deviceCodeFlowCallback, + _taskCompletionSource = new TaskCompletionSource() + }; + + _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + result = await state._taskCompletionSource.Task; + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } + default: { - app = app, - scopes = scopes, - connectionId = parameters.ConnectionId, - userId = parameters.UserId, - authenticationMethod = parameters.AuthenticationMethod, - cts = cts, - customWebUI = _customWebUI, - deviceCodeFlowCallback = _deviceCodeFlowCallback, - _taskCompletionSource = new TaskCompletionSource() - }; - - _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); - result = await state._taskCompletionSource.Task; - - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } - } - else - { - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | {0} authentication mode not supported by ActiveDirectoryAuthenticationProvider class.", parameters.AuthenticationMethod); - throw SQL.UnsupportedAuthenticationSpecified(parameters.AuthenticationMethod); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | {0} authentication mode not supported by ActiveDirectoryAuthenticationProvider class.", parameters.AuthenticationMethod); + throw SQL.UnsupportedAuthenticationSpecified(parameters.AuthenticationMethod); + } } - - return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); } private static async Task TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters, From 4fbdd68d1a01f3f63021c3176df5cd1a6fabf368 Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Fri, 27 Sep 2024 15:30:54 -0700 Subject: [PATCH 08/11] Remove interactive auth option unless enabled at compile time. --- .../ActiveDirectoryAuthenticationProvider.cs | 40 +++++++------------ 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 7e07af5507..f3d44e89f7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -255,6 +255,7 @@ previousPw is byte[] previousPwBytes && if (result == null) { + //TODO: need to use broker here too? result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) .WithCorrelationId(parameters.ConnectionId) .ExecuteAsync(cancellationToken: cts.Token) @@ -274,6 +275,7 @@ previousPw is byte[] previousPwBytes && return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); } +#if INTERACTIVE_AUTH case SqlAuthenticationMethod.ActiveDirectoryInteractive: case SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow: case SqlAuthenticationMethod.ActiveDirectoryIntegrated: @@ -337,6 +339,7 @@ previousPw is byte[] previousPwBytes && return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); } +#endif default: { SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | {0} authentication mode not supported by ActiveDirectoryAuthenticationProvider class.", parameters.AuthenticationMethod); @@ -386,6 +389,7 @@ private static async Task TryAcquireTokenSilent(IPublicCli return result; } +#if INTERACTIVE_AUTH private class InteractiveAuthStateObject { internal IPublicClientApplication app; @@ -399,7 +403,6 @@ private class InteractiveAuthStateObject internal TaskCompletionSource _taskCompletionSource; } - private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) { InteractiveAuthStateObject interactiveAuthStateObject = (InteractiveAuthStateObject)state; @@ -488,6 +491,7 @@ private static async void AcquireTokenInteractiveDeviceFlowAsync(object state) interactiveAuthStateObject._taskCompletionSource.SetException(e); } } +#endif private static Task DefaultDeviceFlowCallback(DeviceCodeResult result) { @@ -612,30 +616,16 @@ private static bool AreEqual(byte[] a1, byte[] a2) private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) { - IPublicClientApplication publicClientApplication; - - if (_parentActivityOrWindowFunc != null) - { - publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) - .WithAuthority(publicClientAppKey._authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(publicClientAppKey._redirectUri) - .WithParentActivityOrWindow(_parentActivityOrWindowFunc) - .WithBroker(new BrokerOptions(BrokerOptions.OperatingSystems.Windows)) - .Build(); - } - else - { - publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) - .WithAuthority(publicClientAppKey._authority) - .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) - .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) - .WithRedirectUri(publicClientAppKey._redirectUri) - .Build(); - } - - return publicClientApplication; + return PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId) + .WithAuthority(publicClientAppKey._authority) + .WithClientName(Common.DbConnectionStringDefaults.ApplicationName) + .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) + .WithRedirectUri(publicClientAppKey._redirectUri) +#if INTERACTIVE_AUTH + .WithParentActivityOrWindow(_parentActivityOrWindowFunc) + .WithBroker(new BrokerOptions(BrokerOptions.OperatingSystems.Windows)) +#endif + .Build(); } From c32327d04c13ae9425630515321a8bd6f2764a4f Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Fri, 18 Oct 2024 13:08:15 -0700 Subject: [PATCH 09/11] Add default synchronization context behavior. --- .../SqlClient/ActiveDirectoryAuthenticationProvider.cs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index f3d44e89f7..699e56e99c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -331,7 +331,15 @@ previousPw is byte[] previousPwBytes && _taskCompletionSource = new TaskCompletionSource() }; - _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + if (_synchronizationContext == null) + { + var tempSC = new SynchronizationContext(); + tempSC.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + } + else + { + _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + } result = await state._taskCompletionSource.Task; SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); From 0ded77a9eb25acd1e7a378e6f1acc994c56ce40b Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Fri, 18 Oct 2024 15:22:49 -0700 Subject: [PATCH 10/11] Move window handle behavior to windows specific partial class. --- .../src/Microsoft.Data.SqlClient.csproj | 16 ++- .../netfx/src/Microsoft.Data.SqlClient.csproj | 11 ++ .../Kernel32/Interop.GetConsoleWindow.cs | 11 ++ .../Windows/User32/Interop.GetAncestor.cs | 30 +++++ ...iveDirectoryAuthenticationProvider.Unix.cs | 12 ++ ...DirectoryAuthenticationProvider.Windows.cs | 41 ++++++ .../ActiveDirectoryAuthenticationProvider.cs | 126 ++++++------------ tools/props/Versions.props | 2 +- tools/props/VersionsNet8OrLater.props | 2 +- 9 files changed, 162 insertions(+), 89 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 05e8d8bd04..744e8cbad5 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -131,8 +131,8 @@ Microsoft\Data\SqlClient\AAsyncCallContext.cs - - Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProvider.cs + + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProviderBase.cs Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs @@ -706,6 +706,9 @@ Common\Interop\Windows\Kernel32\Interop.FileTypes.cs + + Common\Interop\Windows\Kernel32\Interop.GetConsoleWindow.cs + Common\Interop\Windows\Kernel32\Interop.IoControlCodeAccess.cs @@ -727,6 +730,9 @@ Common\Interop\Windows\NtDll\Interop.RtlNtStatusToDosError.cs + + Common\Interop\Windows\NtDll\Interop.GetAncestor.cs + Microsoft\Data\Common\AdapterUtil.Windows.cs @@ -738,6 +744,9 @@ Microsoft\Data\Sql\SqlDataSourceEnumerator.Windows.cs + + + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProvider.Windows.cs Microsoft\Data\SqlClient\SqlColumnEncryptionCngProvider.Windows.cs @@ -808,6 +817,9 @@ Microsoft\Data\ProviderBase\DbConnectionPoolIdentity.Unix.cs + + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProviderBase.Unix.cs + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 98d2c1ae31..35872e2f18 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -223,6 +223,9 @@ Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProvider.cs + + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProvider.Windows.cs + Microsoft\Data\SqlClient\AlwaysEncryptedEnclaveProviderUtils.cs @@ -689,6 +692,14 @@ Microsoft\Data\SqlDbTypeExtensions.cs + + + Common\Interop\Windows\Kernel32\Interop.GetConsoleWindow.cs + + + Common\Interop\Windows\NtDll\Interop.GetAncestor.cs + + diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs new file mode 100644 index 0000000000..54208ee3b5 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Kernel32/Interop.GetConsoleWindow.cs @@ -0,0 +1,11 @@ +using System; +using System.Runtime.InteropServices; + +internal partial class Interop +{ + internal partial class Kernel32 + { + [DllImport("kernel32.dll")] + internal static extern IntPtr GetConsoleWindow(); + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs new file mode 100644 index 0000000000..868b18aeaf --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/User32/Interop.GetAncestor.cs @@ -0,0 +1,30 @@ +using System; +using System.Runtime.InteropServices; + + +internal partial class Interop +{ + internal partial class User32 + { + internal enum GetAncestorFlags + { + GetParent = 1, + GetRoot = 2, + /// + /// Retrieves the owned root window by walking the chain of parent and owner windows returned by GetParent. + /// + GetRootOwner = 3 + } + + /// + /// Retrieves the handle to the ancestor of the specified window. + /// + /// A handle to the window whose ancestor is to be retrieved. + /// If this parameter is the desktop window, the function returns NULL. + /// The ancestor to be retrieved. + /// The return value is the handle to the ancestor window. + [DllImport("user32.dll", ExactSpelling = true)] + internal static extern IntPtr GetAncestor(IntPtr hwnd, GetAncestorFlags flags); + } +} + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs new file mode 100644 index 0000000000..53eb66e6e0 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Unix.cs @@ -0,0 +1,12 @@ +using System; +using System.Threading; + +namespace Microsoft.Data.SqlClient +{ + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + { + private Func _parentActivityOrWindowFunc = null; + + private Func ParentActivityOrWindow => _parentActivityOrWindowFunc; + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs new file mode 100644 index 0000000000..c16edfc6b4 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.Windows.cs @@ -0,0 +1,41 @@ +using System; + +namespace Microsoft.Data.SqlClient +{ + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + { + private Func _parentActivityOrWindowFunc = null; + + private Func ParentActivityOrWindow + { + get + { + return _parentActivityOrWindowFunc != null ? _parentActivityOrWindowFunc : GetConsoleOrTerminalWindow; + } + + set + { + _parentActivityOrWindowFunc = value; + } + } + +#if NETFRAMEWORK + /// + public void SetIWin32WindowFunc(Func iWin32WindowFunc) => SetParentActivityOrWindow(iWin32WindowFunc); +#endif + + /// + /// + /// + /// + public void SetParentActivityOrWindow(Func parentActivityOrWindowFunc) => this.ParentActivityOrWindow = parentActivityOrWindowFunc; + + private object GetConsoleOrTerminalWindow() + { + IntPtr consoleHandle = Interop.Kernel32.GetConsoleWindow(); + IntPtr handle = Interop.User32.GetAncestor(consoleHandle, Interop.User32.GetAncestorFlags.GetRootOwner); + + return handle; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 699e56e99c..b5f0f4c97f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -4,8 +4,6 @@ using System; using System.Collections.Concurrent; -using System.Diagnostics; -using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; using System.Threading; @@ -14,17 +12,15 @@ using Azure.Identity; using Microsoft.Extensions.Caching.Memory; using Microsoft.Identity.Client; - #if INTERACTIVE_AUTH using Microsoft.Identity.Client.Broker; #endif - using Microsoft.Identity.Client.Extensibility; namespace Microsoft.Data.SqlClient { /// - public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider + public sealed partial class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider { /// /// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey. @@ -44,6 +40,31 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro private Func _deviceCodeFlowCallback; private ICustomWebUi _customWebUI = null; private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId; + private SynchronizationContext _synchronizationContext = null; + + private SynchronizationContext SynchronizationContext + { + get + { + if (_synchronizationContext != null) + { + return _synchronizationContext; + } + else if (SynchronizationContext.Current != null) + { + return SynchronizationContext.Current; + } + else + { + return new SynchronizationContext(); + } + } + + set + { + _synchronizationContext = value; + } + } /// public ActiveDirectoryAuthenticationProvider() @@ -87,6 +108,12 @@ public static void ClearUserTokenCache() /// public void SetAcquireAuthorizationCodeAsyncCallback(Func> acquireAuthorizationCodeAsyncCallback) => _customWebUI = new CustomWebUi(acquireAuthorizationCodeAsyncCallback); + /// + /// TODO + /// + /// + public void SetSynchronizationContext(SynchronizationContext synchronizationContext) => this.SynchronizationContext = synchronizationContext; + /// public override bool IsSupported(SqlAuthenticationMethod authentication) { @@ -113,33 +140,6 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication) _logger.LogInfo(_type, "BeforeUnload", $"being unloaded from SqlAuthProviders for {authentication}."); } -#if NETFRAMEWORK - - - /// - public void SetIWin32WindowFunc(Func iWin32WindowFunc) => SetParentActivityOrWindow(iWin32WindowFunc); -#endif - - private Func _parentActivityOrWindowFunc = null; - - /// - /// - /// - /// - public void SetParentActivityOrWindow(Func parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc; - - private delegate Task InvokeDelegate(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, - SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback); - - private SynchronizationContext _synchronizationContext = null; - - /// - /// - /// - /// - public void SetSynchronizationContext(SynchronizationContext synchronizationContext) => this._synchronizationContext = synchronizationContext; - - /// public override async Task AcquireTokenAsync(SqlAuthenticationParameters parameters) @@ -236,7 +236,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti //Public client auth methods case SqlAuthenticationMethod.ActiveDirectoryPassword: { - PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, _parentActivityOrWindowFunc); + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, ParentActivityOrWindow); IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); AuthenticationResult result = null; @@ -280,7 +280,7 @@ previousPw is byte[] previousPwBytes && case SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow: case SqlAuthenticationMethod.ActiveDirectoryIntegrated: { - PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, _parentActivityOrWindowFunc); + PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId, ParentActivityOrWindow); IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false); AuthenticationResult result = null; @@ -309,7 +309,9 @@ previousPw is byte[] previousPwBytes && _taskCompletionSource = new TaskCompletionSource() }; - _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + + SynchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + result = await state._taskCompletionSource.Task; SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); @@ -331,15 +333,9 @@ previousPw is byte[] previousPwBytes && _taskCompletionSource = new TaskCompletionSource() }; - if (_synchronizationContext == null) - { - var tempSC = new SynchronizationContext(); - tempSC.Post(AcquireTokenInteractiveDeviceFlowAsync, state); - } - else - { - _synchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); - } + + SynchronizationContext.Post(AcquireTokenInteractiveDeviceFlowAsync, state); + result = await state._taskCompletionSource.Task; SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); @@ -630,52 +626,12 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ .WithClientVersion(Common.ADP.GetAssemblyVersion().ToString()) .WithRedirectUri(publicClientAppKey._redirectUri) #if INTERACTIVE_AUTH - .WithParentActivityOrWindow(_parentActivityOrWindowFunc) + .WithParentActivityOrWindow(ParentActivityOrWindow) .WithBroker(new BrokerOptions(BrokerOptions.OperatingSystems.Windows)) #endif .Build(); } - - // This is your window handle! - static IntPtr GetParent() - { - Process currentProcess = Process.GetCurrentProcess(); - return currentProcess.MainWindowHandle; - } - - enum GetAncestorFlags - { - GetParent = 1, - GetRoot = 2, - /// - /// Retrieves the owned root window by walking the chain of parent and owner windows returned by GetParent. - /// - GetRootOwner = 3 - } - - /// - /// Retrieves the handle to the ancestor of the specified window. - /// - /// A handle to the window whose ancestor is to be retrieved. - /// If this parameter is the desktop window, the function returns NULL. - /// The ancestor to be retrieved. - /// The return value is the handle to the ancestor window. - [DllImport("user32.dll", ExactSpelling = true)] - static extern IntPtr GetAncestor(IntPtr hwnd, GetAncestorFlags flags); - - [DllImport("kernel32.dll")] - static extern IntPtr GetConsoleWindow(); - - // This is your window handle! - IntPtr GetConsoleOrTerminalWindow() - { - IntPtr consoleHandle = GetConsoleWindow(); - IntPtr handle = GetAncestor(consoleHandle, GetAncestorFlags.GetRootOwner); - - return handle; - } - private static TokenCredentialData CreateTokenCredentialInstance(TokenCredentialKey tokenCredentialKey, string secret) { if (tokenCredentialKey._tokenCredentialType == typeof(DefaultAzureCredential)) diff --git a/tools/props/Versions.props b/tools/props/Versions.props index 88ff9e00ce..07cf0f9588 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -32,7 +32,7 @@ 1.11.4 4.61.3 4.61.3 - 6.0.1 + 6.0.2 7.5.0 7.5.0 diff --git a/tools/props/VersionsNet8OrLater.props b/tools/props/VersionsNet8OrLater.props index e67ec9ef6e..5ea4521bdc 100644 --- a/tools/props/VersionsNet8OrLater.props +++ b/tools/props/VersionsNet8OrLater.props @@ -3,7 +3,7 @@ 8.0.0 - 8.0.0 + 8.0.1 9.0.0-beta.24157.1 2.6.3 From 3fbce21d218ff2834451591e63097aa2165211bd Mon Sep 17 00:00:00 2001 From: Malcolm Daigle Date: Fri, 18 Oct 2024 15:30:49 -0700 Subject: [PATCH 11/11] Fix compile include --- .../netcore/src/Microsoft.Data.SqlClient.csproj | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 744e8cbad5..12a858c50b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -131,8 +131,8 @@ Microsoft\Data\SqlClient\AAsyncCallContext.cs - - Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProviderBase.cs + + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProvider.cs Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs @@ -818,7 +818,7 @@ Microsoft\Data\ProviderBase\DbConnectionPoolIdentity.Unix.cs - Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProviderBase.Unix.cs + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProvider.Unix.cs