From d5621d5d0d4e82662b7c82d63acd39d7efb3c34f Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Tue, 15 Jul 2025 09:44:19 -0700 Subject: [PATCH 1/2] Expose SSPI context provider as public --- .../SqlConnection.xml | 13 +++++ .../SspiAuthenticationParameters.xml | 28 ++++++++++ .../SspiContextProvider.xml | 20 +++++++ src/Microsoft.Data.SqlClient.sln | 2 + .../netcore/ref/Microsoft.Data.SqlClient.cs | 33 ++++++++++++ .../Microsoft/Data/SqlClient/SqlConnection.cs | 26 ++++++--- .../SqlClient/SqlInternalConnectionTds.cs | 6 ++- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 2 +- .../netfx/ref/Microsoft.Data.SqlClient.cs | 33 ++++++++++++ .../Microsoft/Data/SqlClient/SqlConnection.cs | 27 +++++++--- .../SqlClient/SqlInternalConnectionTds.cs | 6 ++- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 2 +- .../ConnectionPool/SqlConnectionPoolKey.cs | 20 ++++++- .../SSPI/NativeSspiContextProvider.cs | 4 +- .../SSPI/NegotiateSspiContextProvider.cs | 2 +- .../SSPI/SspiAuthenticationParameters.cs | 9 +++- .../SqlClient/SSPI/SspiContextProvider.cs | 17 ++++-- .../Data/SqlClient/SqlConnectionFactory.cs | 5 +- .../IntegratedAuthenticationTest.cs | 53 +++++++++++++++++++ 19 files changed, 275 insertions(+), 33 deletions(-) create mode 100644 doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml create mode 100644 doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml index 11602bd078..161326f6d1 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml @@ -2133,6 +2133,19 @@ The following sample tries to open a connection to an invalid database to simula Returns 0 if the connection is inactive on the client side. + + + Gets or sets the instance for customizing the SSPI context. If not set, the default for the platform will be used. + + + An instance. + + + + The SspiContextProvider is a part of the connection pool key. Care should be taken when using this property to ensure the implementation returns a stable identity per resource. + + + Indicates the state of the during the most recent network operation performed on the connection. diff --git a/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml b/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml new file mode 100644 index 0000000000..4623d86052 --- /dev/null +++ b/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml @@ -0,0 +1,28 @@ + + + + + Provides parameters used during SSPI authentication. + + + Creates an instance of the SspiAuthenticationParameters. + The name of the server. + The resource (often the server service principal name). + + + Gets the resource (often the server service principal name). + + + Gets the server name. + + + Gets or sets the user id if available. + + + Gets or sets the database name if available. + + + Gets or sets the password if available. + + + diff --git a/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml b/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml new file mode 100644 index 0000000000..2ea58cb80b --- /dev/null +++ b/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml @@ -0,0 +1,20 @@ + + + + + Provides the ability to customize SSPI context generation. + + + Creates an instance of the SSPIContextProvider. + + + Generates an SSPI outgoing blob given the incoming blob. + Incoming blob + Outgoing blob + Gets the authentication parameters associated with this connection. + + true if the context was generated, otherwise false. + + + + diff --git a/src/Microsoft.Data.SqlClient.sln b/src/Microsoft.Data.SqlClient.sln index e4d29d999c..bf0cbe8cc2 100644 --- a/src/Microsoft.Data.SqlClient.sln +++ b/src/Microsoft.Data.SqlClient.sln @@ -149,6 +149,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Data.SqlClient", ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventArgs.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventArgs.xml ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventHandler.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventHandler.xml ..\doc\snippets\Microsoft.Data.SqlClient\SqlTransaction.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlTransaction.xml + ..\doc\snippets\Microsoft.Data.SqlClient\SspiAuthenticationParameters.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SspiAuthenticationParameters.xml + ..\doc\snippets\Microsoft.Data.SqlClient\SspiContextProvider.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SspiContextProvider.xml EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Data.SqlClient.DataClassification", "Microsoft.Data.SqlClient.DataClassification", "{5D1F0032-7B0D-4FB6-A969-FCFB25C9EA1D}" diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index 3e26aa251a..0409ac83ad 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -929,6 +929,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect [System.ComponentModel.BrowsableAttribute(false)] [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } } + /// + public SspiContextProvider SspiContextProvider { get { throw null; } set { } } /// [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public override string Database { get { throw null; } } @@ -1976,6 +1978,37 @@ public sealed class SqlConfigurableRetryFactory /// public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; } } + /// + public abstract class SspiContextProvider + { + /// + protected abstract bool GenerateContext(System.ReadOnlySpan incomingBlob, System.Buffers.IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); + } + /// + public sealed class SspiAuthenticationParameters + { + /// + public SspiAuthenticationParameters(string serverName, string resource) + { + ServerName = serverName; + Resource = resource; + } + + /// + public string Resource { get; } + + /// + public string ServerName { get; } + + /// + public string UserId { get; set; } + + /// + public string DatabaseName { get; set; } + + /// + public string Password { get; set; } + } } namespace Microsoft.Data.SqlClient.Diagnostics { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index ce5e960f8d..d6d5334597 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -91,6 +91,7 @@ private static readonly Dictionary private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders; private Func> _accessTokenCallback; + private SspiContextProvider _sspiContextProvider; internal bool HasColumnEncryptionKeyStoreProvidersRegistered => _customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0; @@ -650,7 +651,7 @@ public override string ConnectionString CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions); } } - ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback)); + ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider)); _connectionString = value; // Change _connectionString value only after value is validated CacheConnectionStringProperties(); } @@ -710,7 +711,7 @@ public string AccessToken } // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null, sspiContextProvider: null)); _accessToken = value; } } @@ -733,11 +734,22 @@ public Func + public SspiContextProvider SspiContextProvider + { + get { return _sspiContextProvider; } + set + { + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value)); + _sspiContextProvider = value; + } + } + /// [ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)] [ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)] @@ -1035,7 +1047,7 @@ public SqlCredential Credential _credential = value; // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback, sspiContextProvider: null)); } } @@ -2265,7 +2277,7 @@ public static void ChangePassword(string connectionString, string newPassword) throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null, sspiContextProvider: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); if (connectionOptions.IntegratedSecurity) @@ -2314,7 +2326,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); @@ -2352,7 +2364,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString { con?.Dispose(); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null); SqlConnectionFactory.SingletonInstance.ClearPool(key); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index c75eb35a0e..8db440a5ac 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -135,6 +135,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa SqlFedAuthToken _fedAuthToken = null; internal byte[] _accessTokenInBytes; internal readonly Func> _accessTokenCallback; + internal readonly SspiContextProvider _sspiContextProvider; private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper; @@ -460,8 +461,8 @@ internal SqlInternalConnectionTds( bool applyTransientFaultHandling = false, string accessToken = null, IDbConnectionPool pool = null, - Func> accessTokenCallback = null) : base(connectionOptions) + Func> accessTokenCallback = null, + SspiContextProvider sspiContextProvider = null) : base(connectionOptions) { #if DEBUG if (reconnectSessionData != null) @@ -514,6 +515,7 @@ internal SqlInternalConnectionTds( } _accessTokenCallback = accessTokenCallback; + _sspiContextProvider = sspiContextProvider; _activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper(); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 4b6a82b7ba..8e3868e3f2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -409,7 +409,7 @@ internal void Connect(ServerInfo serverInfo, // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - _authenticationProvider = _physicalStateObj.CreateSspiContextProvider(); + _authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider(); SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | SSPI or Active Directory Authentication Library loaded for SQL Server based integrated authentication"); } diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index a3bb3b46c2..50ddd52f0e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -810,6 +810,8 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden [System.ComponentModel.BrowsableAttribute(false)] [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } } + /// + public SspiContextProvider SspiContextProvider { get { throw null; } set { } } /// [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public override string Database { get { throw null; } } @@ -1959,6 +1961,37 @@ public sealed class SqlConfigurableRetryFactory /// public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; } } + /// + public abstract class SspiContextProvider + { + /// + protected abstract bool GenerateContext(System.ReadOnlySpan incomingBlob, System.Buffers.IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); + } + /// + public sealed class SspiAuthenticationParameters + { + /// + public SspiAuthenticationParameters(string serverName, string resource) + { + ServerName = serverName; + Resource = resource; + } + + /// + public string Resource { get; } + + /// + public string ServerName { get; } + + /// + public string UserId { get; set; } + + /// + public string DatabaseName { get; set; } + + /// + public string Password { get; set; } + } } namespace Microsoft.Data.SqlClient.Server { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs index 3eac0d16ff..53ebe6fd17 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -89,6 +89,8 @@ private static readonly Dictionary private Func> _accessTokenCallback; + private SspiContextProvider _sspiContextProvider; + internal bool HasColumnEncryptionKeyStoreProvidersRegistered => _customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0; @@ -577,6 +579,17 @@ internal int ConnectRetryInterval get => ((SqlConnectionString)ConnectionOptions).ConnectRetryInterval; } + /// + public SspiContextProvider SspiContextProvider + { + get { return _sspiContextProvider; } + set + { + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value)); + _sspiContextProvider = value; + } + } + /// [DefaultValue("")] #pragma warning disable 618 // ignore obsolete warning about RecommendedAsConfigurable to use SettingsBindableAttribute @@ -645,7 +658,7 @@ public override string ConnectionString CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions); } } - ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback)); + ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider)); _connectionString = value; // Change _connectionString value only after value is validated CacheConnectionStringProperties(); } @@ -705,7 +718,7 @@ public string AccessToken _accessToken = value; // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null, sspiContextProvider: null)); } } @@ -727,7 +740,7 @@ public Func> _accessTokenCallback; + internal readonly SspiContextProvider _sspiContextProvider; private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper; @@ -472,8 +473,8 @@ internal SqlInternalConnectionTds( bool applyTransientFaultHandling = false, string accessToken = null, IDbConnectionPool pool = null, - Func> accessTokenCallback = null) + Func> accessTokenCallback = null, + SspiContextProvider sspiContextProvider = null) : base(connectionOptions) { #if DEBUG @@ -525,6 +526,7 @@ internal SqlInternalConnectionTds( } _accessTokenCallback = accessTokenCallback; + _sspiContextProvider = sspiContextProvider; _activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper(); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index f89613204f..406789d67d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -411,7 +411,7 @@ internal void Connect(ServerInfo serverInfo, // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - _authenticationProvider = _physicalStateObj.CreateSspiContextProvider(); + _authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider(); if (!string.IsNullOrEmpty(serverInfo.ServerSPN)) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs index 207c0a8e1a..31da3521df 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs @@ -18,10 +18,12 @@ internal class SqlConnectionPoolKey : DbConnectionPoolKey private readonly SqlCredential _credential; private readonly string _accessToken; private Func> _accessTokenCallback; + private SspiContextProvider _sspiContextProvider; internal SqlCredential Credential => _credential; internal string AccessToken => _accessToken; internal Func> AccessTokenCallback => _accessTokenCallback; + internal SspiContextProvider SspiContextProvider => _sspiContextProvider; internal override string ConnectionString { @@ -33,12 +35,18 @@ internal override string ConnectionString } } - internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken, Func> accessTokenCallback) : base(connectionString) + internal SqlConnectionPoolKey( + string connectionString, + SqlCredential credential, + string accessToken, + Func> accessTokenCallback, + SspiContextProvider sspiContextProvider) : base(connectionString) { Debug.Assert(credential == null || accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time."); _credential = credential; _accessToken = accessToken; _accessTokenCallback = accessTokenCallback; + _sspiContextProvider = sspiContextProvider; CalculateHashCode(); } @@ -47,6 +55,8 @@ private SqlConnectionPoolKey(SqlConnectionPoolKey key) : base(key) _credential = key.Credential; _accessToken = key.AccessToken; _accessTokenCallback = key._accessTokenCallback; + _sspiContextProvider = key._sspiContextProvider; + CalculateHashCode(); } @@ -61,7 +71,8 @@ public override bool Equals(object obj) && _credential == key._credential && ConnectionString == key.ConnectionString && _accessTokenCallback == key._accessTokenCallback - && string.CompareOrdinal(_accessToken, key._accessToken) == 0); + && string.CompareOrdinal(_accessToken, key._accessToken) == 0 + && _sspiContextProvider == key._sspiContextProvider); } public override int GetHashCode() @@ -94,6 +105,11 @@ private void CalculateHashCode() _hashValue = _hashValue * 17 + _accessTokenCallback.GetHashCode(); } } + + if (_sspiContextProvider != null) + { + _hashValue = _hashValue * 17 + _sspiContextProvider.GetHashCode(); + } } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs index 5935b149c8..1cc4af3e9c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs @@ -49,7 +49,7 @@ private void LoadSSPILibrary() } } - protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) + protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) { #if NETFRAMEWORK SNIHandle handle = _physicalStateObj.Handle; @@ -62,7 +62,7 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlo var sendLength = s_maxSSPILength; var outBuff = outgoingBlobWriter.GetSpan((int)sendLength); - if (0 != SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource)) + if (SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource) != 0) { return false; } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs index a74651cf2d..9bf5f97e83 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs @@ -13,7 +13,7 @@ internal sealed class NegotiateSspiContextProvider : SspiContextProvider, IDispo { private NegotiateAuthentication? _negotiateAuth; - protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) + protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) { var negotiateAuth = GetNegotiateAuthenticationForParams(authParams); var sendBuff = negotiateAuth.GetOutgoingBlob(incomingBlob, out var statusCode)!; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs index dce0858360..ad6c92853f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs @@ -2,22 +2,29 @@ namespace Microsoft.Data.SqlClient { - internal sealed class SspiAuthenticationParameters + /// + public sealed class SspiAuthenticationParameters { + /// public SspiAuthenticationParameters(string serverName, string resource) { ServerName = serverName; Resource = resource; } + /// public string Resource { get; } + /// public string ServerName { get; } + /// public string? UserId { get; set; } + /// public string? DatabaseName { get; set; } + /// public string? Password { get; set; } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs index f45ccee4fd..0cec692cee 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs @@ -6,7 +6,8 @@ namespace Microsoft.Data.SqlClient { - internal abstract class SspiContextProvider + /// + public abstract class SspiContextProvider { private TdsParser _parser = null!; private ServerInfo _serverInfo = null!; @@ -16,6 +17,11 @@ internal abstract class SspiContextProvider private protected TdsParserStateObject _physicalStateObj = null!; + /// + protected SspiContextProvider() + { + } + #if NET /// /// for details as to what and means and why there are two. @@ -58,7 +64,8 @@ private protected virtual void Initialize() { } - protected abstract bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); + /// + protected abstract bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); internal void WriteSSPIContext(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter) { @@ -94,9 +101,9 @@ private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBuff { try { - SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateSspiClientContext), authParams.Resource); + SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateContext), authParams.Resource); - return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams); + return GenerateContext(incomingBlob, outgoingBlobWriter, authParams); } catch (Exception e) { @@ -105,7 +112,7 @@ private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBuff } } - protected void SSPIError(string error, string procedure) + private protected void SSPIError(string error, string procedure) { Debug.Assert(!string.IsNullOrEmpty(procedure), "TdsParser.SSPIError called with an empty or null procedure string"); Debug.Assert(!string.IsNullOrEmpty(error), "TdsParser.SSPIError called with an empty or null error string"); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs index ae4c7bc98d..00a86bdeac 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs @@ -103,7 +103,7 @@ protected override DbConnectionInternal CreateConnection( // This first connection is established to SqlExpress to get the instance name // of the UserInstance. SqlConnectionString sseopt = new SqlConnectionString(opt, opt.DataSource, userInstance: true, setEnlistValue: false); - sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling); + sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling, sspiContextProvider: key.SspiContextProvider); // NOTE: Retrieve here. This user instance name will be used below to connect to the Sql Express User Instance. instanceName = sseConnection.InstanceName; @@ -157,7 +157,8 @@ protected override DbConnectionInternal CreateConnection( applyTransientFaultHandling, key.AccessToken, pool, - key.AccessTokenCallback); + key.AccessTokenCallback, + key.SspiContextProvider); } protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs index e043b2253c..af650a408e 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Buffers; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests @@ -56,6 +58,39 @@ public static void IntegratedAuthenticationTest_ServerSPN() TryOpenConnectionWithIntegratedAuthentication(builder.ConnectionString); } + [ConditionalFact(nameof(IsIntegratedSecurityEnvironmentSet), nameof(AreConnectionStringsSetup))] + public static void CustomSspiContextGeneratorTest() + { + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + builder.IntegratedSecurity = true; + Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out int port, out string instanceName)); + // Build the SPN for the server we are connecting to + builder.ServerSPN = $"MSSQLSvc/{DataTestUtility.GetMachineFQDN(hostname)}"; + if (!string.IsNullOrWhiteSpace(instanceName)) + { + builder.ServerSPN += ":" + instanceName; + } + + using SqlConnection conn = new(builder.ConnectionString) + { + SspiContextProvider = new TestSspiContextProvider(), + }; + + try + { + conn.Open(); + + Assert.Fail("Expected to use custom SSPI context provider"); + } + catch (SspiTestException sspi) + { + Assert.Equal(sspi.AuthParams.ServerName, builder.DataSource); + Assert.Equal(sspi.AuthParams.DatabaseName, builder.InitialCatalog); + Assert.Equal(sspi.AuthParams.UserId, builder.UserID); + Assert.Equal(sspi.AuthParams.Password, builder.Password); + } + } + private static void TryOpenConnectionWithIntegratedAuthentication(string connectionString) { using (SqlConnection connection = new SqlConnection(connectionString)) @@ -63,5 +98,23 @@ private static void TryOpenConnectionWithIntegratedAuthentication(string connect connection.Open(); } } + + private sealed class TestSspiContextProvider : SspiContextProvider + { + protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) + { + throw new SspiTestException(authParams); + } + } + + private sealed class SspiTestException : Exception + { + public SspiTestException(SspiAuthenticationParameters authParams) + { + AuthParams = authParams; + } + + public SspiAuthenticationParameters AuthParams { get; } + } } } From e894bb19737cc8165d248e23448b128a0c729950 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Tue, 15 Jul 2025 11:38:03 -0700 Subject: [PATCH 2/2] fix merge issue --- .../Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs index 9bf5f97e83..b9b6cbf36e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs @@ -20,7 +20,7 @@ protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBuffer // Log session id, status code and the actual SPN used in the negotiation SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSspiContextProvider), - nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, negotiateAuth.TargetName); + nameof(GenerateContext), _physicalStateObj.SessionId, statusCode, negotiateAuth.TargetName); if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded) {