From bbb5636eee6434d44ccaed2741dfe48833e3963a Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Fri, 31 Jan 2025 16:23:01 -0800 Subject: [PATCH 01/12] Use string for SNI instead of byte[] --- .../src/Microsoft.Data.SqlClient.csproj | 3 + .../Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 13 +-- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 21 +++-- .../SqlClient/TdsParserStateObject.netcore.cs | 2 +- .../SqlClient/TdsParserStateObjectManaged.cs | 4 +- .../SqlClient/TdsParserStateObjectNative.cs | 13 ++- .../netfx/src/Microsoft.Data.SqlClient.csproj | 12 ++- .../Data/SqlClient/BufferWriterExtensions.cs | 25 ++++++ .../src/Microsoft/Data/SqlClient/TdsParser.cs | 31 +++---- .../SqlClient/TdsParserStateObject.netfx.cs | 4 +- .../Interop/Windows/Sni/SniNativeWrapper.cs | 86 +++++++++++++++---- .../SSPI/NativeSSPIContextProvider.cs | 2 +- .../SSPI/NegotiateSSPIContextProvider.cs | 8 +- .../SqlClient/SSPI/SSPIContextProvider.cs | 8 +- .../Microsoft/Data/SqlClient/SqlObjectPool.cs | 2 + .../Data/SqlClient/SqlObjectPools.cs | 34 ++++++++ .../src/Microsoft/Data/SqlClient/TdsParser.cs | 16 ++-- .../SqlClient/TdsParserSafeHandles.Windows.cs | 6 +- .../SQL/InstanceNameTest/InstanceNameTest.cs | 4 +- 19 files changed, 202 insertions(+), 92 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPools.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index a4531966aa..50ae9d9daa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -515,6 +515,9 @@ Microsoft\Data\SqlClient\SqlObjectPool.cs + + Microsoft\Data\SqlClient\SqlObjectPools.cs + Microsoft\Data\SqlClient\SqlParameter.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 463283cc73..f45dbff874 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Net; @@ -50,7 +51,7 @@ internal static SNIHandle CreateConnectionHandle( string fullServerName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spnBuffer, string serverSPN, bool flushCache, bool async, @@ -114,12 +115,12 @@ internal static SNIHandle CreateConnectionHandle( return sniHandle; } - private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN) + private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN) { Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName)); if (!string.IsNullOrWhiteSpace(serverSPN)) { - return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) }; + return new[] { serverSPN }; } string hostName = dataSource.ServerName; @@ -137,7 +138,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol); } - private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol) + private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol) { Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress)); IPHostEntry hostEntry = null; @@ -168,12 +169,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}"; // Set both SPNs with and without Port as Port is optional for default instance SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort); - return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) }; + return new[] { serverSpn, serverSpnWithDefaultPort }; } // else Named Pipes do not need to valid port SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn); - return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) }; + return new[] { serverSpn }; } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index b7bc73d620..a486f92be9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -112,8 +112,7 @@ internal sealed partial class TdsParser private bool _is2022 = false; - private byte[][] _sniSpnBuffer = null; - // UNDONE - need to have some for both instances - both command and default??? + private string[] _sniSpn = null; // SqlStatistics private SqlStatistics _statistics = null; @@ -390,7 +389,7 @@ internal void Connect( } else { - _sniSpnBuffer = null; + _sniSpn = null; SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID, authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString()); } @@ -402,7 +401,7 @@ internal void Connect( SqlClientEventSource.Log.TryTraceEvent(" Encryption will be disabled as target server is a SQL Local DB instance."); } - _sniSpnBuffer = null; + _sniSpn = null; _authenticationProvider = null; // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server @@ -441,7 +440,7 @@ internal void Connect( serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpnBuffer, + ref _sniSpn, false, true, fParallel, @@ -454,8 +453,6 @@ internal void Connect( hostNameInCertificate, serverCertificateFilename); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -470,6 +467,8 @@ internal void Connect( Debug.Fail("SNI returned status != success, but no error thrown?"); } + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + _server = serverInfo.ResolvedServerName; if (connHandler.PoolGroupProviderInfo != null) @@ -540,7 +539,7 @@ internal void Connect( _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpnBuffer, + ref _sniSpn, true, true, fParallel, @@ -553,8 +552,6 @@ internal void Connect( hostNameInCertificate, serverCertificateFilename); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -562,6 +559,8 @@ internal void Connect( ThrowExceptionAndWarning(_physicalStateObj); } + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId); Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); @@ -13317,7 +13316,7 @@ internal string TraceString() _fMARS ? bool.TrueString : bool.FalseString, _sessionPool == null ? "(null)" : _sessionPool.TraceString(), _is2005 ? bool.TrueString : bool.FalseString, - _sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null), + _sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index c7b9e3dc3b..01f1759530 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spn, bool flushCache, bool async, bool fParallel, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 28617dd69c..f9f211ca4a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spn, bool flushCache, bool async, bool parallel, @@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle( string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 3534b61740..adada6abaa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -144,7 +144,7 @@ internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref byte[][] spnBuffer, + ref string[] spn, bool flushCache, bool async, bool fParallel, @@ -157,22 +157,18 @@ internal override void CreatePhysicalSNIHandle( string hostNameInCertificate, string serverCertificateFilename) { - // We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer - spnBuffer = new byte[1][]; if (isIntegratedSecurity) { // now allocate proper length of buffer if (!string.IsNullOrEmpty(serverSPN)) { // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. - byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN); - Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); - spnBuffer[0] = srvSPN; SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN); } else { - spnBuffer[0] = new byte[SniNativeWrapper.SniMaxComposedSpnLength]; + // This will signal to the interop layer that we need to retrieve the SPN + serverSPN = string.Empty; } } @@ -180,8 +176,9 @@ internal override void CreatePhysicalSNIHandle( SQLDNSInfo cachedDNSInfo; bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName, + _sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); + spn = new[] { serverSPN.TrimEnd() }; } protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 0c3f1f747a..7b89fc6330 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -337,10 +337,13 @@ Microsoft\Data\Sql\SqlDataSourceEnumerator.cs - - Microsoft\Data\SqlClient\SqlObjectPool.cs - - + + Microsoft\Data\SqlClient\SqlObjectPool.cs + + + Microsoft\Data\SqlClient\SqlObjectPools.cs + + Microsoft\Data\SqlClient\AAsyncCallContext.cs @@ -854,6 +857,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs new file mode 100644 index 0000000000..efd15b1602 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs @@ -0,0 +1,25 @@ +using System.Buffers; +using System.Text; + +namespace Microsoft.Data.SqlClient +{ + internal static class BufferWriterExtensions + { + internal static long GetBytes(this Encoding encoding, string str, IBufferWriter bufferWriter) + { + var count = encoding.GetByteCount(str); + var array = ArrayPool.Shared.Rent(count); + + try + { + encoding.GetBytes(str, 0, str.Length, array, 0); + bufferWriter.Write(array); + return count; + } + finally + { + ArrayPool.Shared.Return(array); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 61b9de2451..3773e5f23c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -128,7 +128,7 @@ internal int ObjectID private bool _is2022 = false; - private byte[] _sniSpnBuffer = null; + private string _sniSpn = null; // UNDONE - need to have some for both instances - both command and default??? @@ -430,27 +430,24 @@ internal void Connect(ServerInfo serverInfo, // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - _authenticationProvider = _physicalStateObj.CreateSSPIContextProvider(); - if (!string.IsNullOrEmpty(serverInfo.ServerSPN)) { - // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. - byte[] srvSPN = Encoding.Unicode.GetBytes(serverInfo.ServerSPN); - Trace.Assert(srvSPN.Length <= SniNativeWrapper.SniMaxComposedSpnLength, "The provided SPN length exceeded the buffer size."); - _sniSpnBuffer = srvSPN; + _sniSpn = serverInfo.ServerSPN; SqlClientEventSource.Log.TryTraceEvent(" Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN); } else { - // now allocate proper length of buffer - _sniSpnBuffer = new byte[SniNativeWrapper.SniMaxComposedSpnLength]; + // Empty signifies to interop layer that SNI needs to be generated + _sniSpn = string.Empty; } + + _authenticationProvider = _physicalStateObj.CreateSSPIContextProvider(); SqlClientEventSource.Log.TryTraceEvent(" SSPI or Active Directory Authentication Library for SQL Server based integrated authentication"); } else { _authenticationProvider = null; - _sniSpnBuffer = null; + _sniSpn = null; switch (authType) { @@ -529,7 +526,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - _sniSpnBuffer, + ref _sniSpn, false, true, fParallel, @@ -539,8 +536,6 @@ internal void Connect(ServerInfo serverInfo, FQDNforDNSCache, hostNameInCertificate); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -555,6 +550,8 @@ internal void Connect(ServerInfo serverInfo, Debug.Fail("SNI returned status != success, but no error thrown?"); } + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + _server = serverInfo.ResolvedServerName; if (connHandler.PoolGroupProviderInfo != null) @@ -629,7 +626,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - _sniSpnBuffer, + ref _sniSpn, true, true, fParallel, @@ -639,8 +636,6 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ResolvedServerName, hostNameInCertificate); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -648,6 +643,8 @@ internal void Connect(ServerInfo serverInfo, ThrowExceptionAndWarning(_physicalStateObj); } + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + uint retCode = SniNativeWrapper.SniGetConnectionId(_physicalStateObj.Handle, ref _connHandler._clientConnectionId); Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); SqlClientEventSource.Log.TryTraceEvent(" Sending prelogin handshake"); @@ -13785,7 +13782,7 @@ internal string TraceString() _is2000 ? bool.TrueString : bool.FalseString, _is2000SP1 ? bool.TrueString : bool.FalseString, _is2005 ? bool.TrueString : bool.FalseString, - _sniSpnBuffer == null ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null), + _sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index 2f8d39678b..024d698822 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -242,7 +242,7 @@ internal void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - byte[] spnBuffer, + ref string spn, bool flushCache, bool async, bool fParallel, @@ -259,7 +259,7 @@ internal void CreatePhysicalSNIHandle( _ = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out SQLDNSInfo cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, timeout.MillisecondsRemainingInt, + _sessionHandle = new SNIHandle(myInfo, serverName, ref spn, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, transparentNetworkResolutionState, totalTimeout, ipPreference, cachedDNSInfo, hostNameInCertificate); } diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 7f7eedf41c..5e48c43f8c 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -199,8 +200,7 @@ private static unsafe uint SNISecGenClientContextWrapper( [In] ref uint pcbOut, [MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone, - byte* szServerInfo, - uint cbServerInfo, + ReadOnlySpan serverInfo, [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszUserName, [MarshalAsAttribute(UnmanagedType.LPWStr)] @@ -208,6 +208,7 @@ private static unsafe uint SNISecGenClientContextWrapper( { fixed (byte* pInPtr = pIn) fixed (byte* pOutPtr = pOut) + fixed (byte* pServerInfo = serverInfo) { return NativeMethods.SniSecGenClientContextWrapper( pConn, @@ -216,8 +217,8 @@ private static unsafe uint SNISecGenClientContextWrapper( pOutPtr, ref pcbOut, out pfDone, - szServerInfo, - cbServerInfo, + pServerInfo, + (uint)serverInfo.Length, pwszUserName, pwszPassword); } @@ -298,7 +299,7 @@ internal static unsafe uint SNIOpenSyncEx( ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, - byte[] spnBuffer, + ref string spn, byte[] instanceName, bool fOverrideCache, bool fSync, @@ -358,13 +359,59 @@ internal static unsafe uint SNIOpenSyncEx( clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; - if (spnBuffer != null) + if (spn != null) { - fixed (byte* pin_spnBuffer = &spnBuffer[0]) + // An empty string implies we need to find the SPN so we supply a buffer for the max size + if (spn.Length == 0) { - clientConsumerInfo.szSPN = pin_spnBuffer; - clientConsumerInfo.cchSPN = (uint)spnBuffer.Length; - return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + var array = ArrayPool.Shared.Rent(SniMaxComposedSpnLength); + array.AsSpan().Clear(); + + try + { + fixed (byte* pin_spnBuffer = array) + { + clientConsumerInfo.szSPN = pin_spnBuffer; + clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength; + + var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + + if (result == 0) + { + spn = Encoding.Unicode.GetString(array).TrimEnd('\0'); + } + + return result; + } + } + finally + { + ArrayPool.Shared.Return(array); + } + } + + // We have a value of the SPN, so we marshal that and send it to the native layer + else + { + var writer = SqlObjectPools.BufferWriter.Rent(); + + // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. + Encoding.Unicode.GetBytes(spn, writer); + Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); + + try + { + fixed (byte* pin_spnBuffer = writer.WrittenSpan) + { + clientConsumerInfo.szSPN = pin_spnBuffer; + clientConsumerInfo.cchSPN = (uint)writer.WrittenCount; + return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + } + } + finally + { + SqlObjectPools.BufferWriter.Return(writer); + } } } else @@ -544,23 +591,28 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w } #endif - - internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span OutBuff, ref uint sendLength, byte[] serverUserName) + internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, ref uint sendLength, string serverUserName) { - fixed (byte* pin_serverUserName = &serverUserName[0]) - //netcore fixed (byte* pInBuff = inBuff) + var serverWriter = SqlObjectPools.BufferWriter.Rent(); + + try { + Encoding.Unicode.GetBytes(serverUserName, serverWriter); + return SNISecGenClientContextWrapper( pConnectionObject, inBuff, - OutBuff, + outBuff, ref sendLength, out _, - pin_serverUserName, - (uint)serverUserName.Length, + serverWriter.WrittenSpan, null, null); } + finally + { + SqlObjectPools.BufferWriter.Return(serverWriter); + } } internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs index 438dd44ad5..fa4b07dd94 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs @@ -49,7 +49,7 @@ private void LoadSSPILibrary() } } - protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, byte[][] _sniSpnBuffer) + protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan _sniSpnBuffer) { #if NETFRAMEWORK SNIHandle handle = _physicalStateObj.Handle; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs index bddf7802ed..719d1d681d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs @@ -12,14 +12,14 @@ namespace Microsoft.Data.SqlClient internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider { private NegotiateAuthentication? _negotiateAuth = null; - - protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, byte[][] _sniSpnBuffer) + + protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverNames) { NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials; - for (int i = 0; i < _sniSpnBuffer.Length; i++) + for (int i = 0; i < serverNames.Length; i++) { - _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = Encoding.Unicode.GetString(_sniSpnBuffer[i]) }); + _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = serverNames[i] }); var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!; // Log session id, status code and the actual SPN used in the negotiation diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs index 027ee5146e..1b446290ce 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs @@ -26,18 +26,18 @@ private protected virtual void Initialize() { } - protected abstract void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, byte[][] _sniSpnBuffer); + protected abstract void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan _sniSpnBuffer); - internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, byte[] sniSpnBuffer) + internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string sniSpnBuffer) => SSPIData(receivedBuff, outgoingBlobWriter, new[] { sniSpnBuffer }); - internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, byte[][] sniSpnBuffer) + internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string[] sniSpnBuffer) { using (TrySNIEventScope.Create(nameof(SSPIContextProvider))) { try { - GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, sniSpnBuffer); + GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, sniSpnBuffer); } catch (Exception e) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs index d5cf2398ec..94b9be977a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs @@ -3,7 +3,9 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Diagnostics; +using System.Text; using System.Threading; namespace Microsoft.Data.SqlClient diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPools.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPools.cs new file mode 100644 index 0000000000..98b980180c --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPools.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.Text; +using System.Threading; + +namespace Microsoft.Data.SqlClient +{ + // This is a collection of general object pools that can be reused as needed. + internal static class SqlObjectPools + { + private static SqlObjectPool> _bufferWriter; + + internal static SqlObjectPool> BufferWriter + { + get + { + if (_bufferWriter is null) + { + // This is a shared pool that will retain the last 20 writers to be reused. If more than 20 are requested at a time, + // they will not be retained when returned to the pool. + Interlocked.CompareExchange(ref _bufferWriter, new(20, () => new(), a => a.Clear()), null); + } + + return _bufferWriter; + } + } + } + +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index d56a0679d4..aba63db1c0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -8,10 +8,6 @@ namespace Microsoft.Data.SqlClient { internal partial class TdsParser { - // This is a shared pool that will retain the last 20 writers to be reused. If more than 20 are requested at a time, - // they will not be retained when returned to the pool. - private static readonly SqlObjectPool> _writers = new(20, () => new(), a => a.Clear()); - internal void ProcessSSPI(int receivedLength) { Debug.Assert(_authenticationProvider is not null); @@ -30,15 +26,15 @@ internal void ProcessSSPI(int receivedLength) } // allocate send buffer and initialize length - var writer = _writers.Rent(); + var writer = SqlObjectPools.BufferWriter.Rent(); // make call for SSPI data - _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _sniSpnBuffer); + _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _sniSpn); // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! _physicalStateObj.WriteByteSpan(writer.WrittenSpan); - _writers.Return(writer); + SqlObjectPools.BufferWriter.Return(writer); ArrayPool.Shared.Return(receivedBuff, clearArray: true); // set message type so server knows its a SSPI response @@ -156,14 +152,14 @@ internal void TdsLogin( { if (rec.useSSPI) { - sspiWriter = _writers.Rent(); + sspiWriter = SqlObjectPools.BufferWriter.Rent(); // Call helper function for SSPI data and actual length. // Since we don't have SSPI data from the server, send null for the // byte[] buffer and 0 for the int length. Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'"); _physicalStateObj.SniContext = SniContext.Snix_LoginSspi; - _authenticationProvider.SSPIData(ReadOnlySpan.Empty, sspiWriter, _sniSpnBuffer); + _authenticationProvider.SSPIData(ReadOnlySpan.Empty, sspiWriter, _sniSpn); _physicalStateObj.SniContext = SniContext.Snix_Login; @@ -196,7 +192,7 @@ internal void TdsLogin( if (sspiWriter is not null) { - _writers.Return(sspiWriter); + SqlObjectPools.BufferWriter.Return(sspiWriter); } _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs index bf5871c57e..6d8c62631c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs @@ -150,7 +150,7 @@ internal sealed class SNIHandle : SafeHandle internal SNIHandle( ConsumerInfo myInfo, string serverName, - byte[] spnBuffer, + ref string spn, int timeout, out byte[] instanceName, bool flushCache, @@ -189,7 +189,7 @@ internal SNIHandle( myInfo, serverName, ref base.handle, - spnBuffer, + ref spn, instanceName, flushCache, fSync, @@ -205,7 +205,7 @@ internal SNIHandle( myInfo, serverName, ref base.handle, - spnBuffer, + ref spn, instanceName, flushCache, fSync, diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs index ae4074c1e5..9f6673332c 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs @@ -212,9 +212,9 @@ private static string GetSPNInfo(string dataSource, string inInstanceName) string serverSPN = ""; MethodInfo getSqlServerSPNs = sniProxyObj.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null); - byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN }); + string[] result = (string[])getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN }); - string spnInfo = Encoding.Unicode.GetString(result[0]); + string spnInfo = result[0]; return spnInfo; } From 66e55f753dfc3ca156f4f01a599f36851f705361 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Wed, 26 Feb 2025 16:25:30 -0800 Subject: [PATCH 02/12] spn naming feedback --- .../src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 6 +++--- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 12 ++++++------ .../src/Microsoft/Data/SqlClient/TdsParser.cs | 16 ++++++++-------- .../SqlClient/SSPI/NativeSSPIContextProvider.cs | 4 ++-- .../Data/SqlClient/SSPI/SSPIContextProvider.cs | 10 +++++----- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 4 ++-- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index f45dbff874..265f80246c 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -34,7 +34,7 @@ internal class SNIProxy /// Full server name from connection string /// Timer expiration /// Instance name - /// SPN + /// SPNs /// pre-defined SPN /// Flush packet cache /// Asynchronous connection @@ -51,7 +51,7 @@ internal static SNIHandle CreateConnectionHandle( string fullServerName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spnBuffer, + ref string[] spns, string serverSPN, bool flushCache, bool async, @@ -103,7 +103,7 @@ internal static SNIHandle CreateConnectionHandle( { try { - spnBuffer = GetSqlServerSPNs(details, serverSPN); + spns = GetSqlServerSPNs(details, serverSPN); } catch (Exception e) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index a486f92be9..945fad1c5a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -112,7 +112,7 @@ internal sealed partial class TdsParser private bool _is2022 = false; - private string[] _sniSpn = null; + private string[] _serverSpn = null; // SqlStatistics private SqlStatistics _statistics = null; @@ -389,7 +389,7 @@ internal void Connect( } else { - _sniSpn = null; + _serverSpn = null; SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID, authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString()); } @@ -401,7 +401,7 @@ internal void Connect( SqlClientEventSource.Log.TryTraceEvent(" Encryption will be disabled as target server is a SQL Local DB instance."); } - _sniSpn = null; + _serverSpn = null; _authenticationProvider = null; // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server @@ -440,7 +440,7 @@ internal void Connect( serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpn, + ref _serverSpn, false, true, fParallel, @@ -539,7 +539,7 @@ internal void Connect( _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpn, + ref _serverSpn, true, true, fParallel, @@ -13316,7 +13316,7 @@ internal string TraceString() _fMARS ? bool.TrueString : bool.FalseString, _sessionPool == null ? "(null)" : _sessionPool.TraceString(), _is2005 ? bool.TrueString : bool.FalseString, - _sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null), + _serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 3773e5f23c..d766716965 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -128,7 +128,7 @@ internal int ObjectID private bool _is2022 = false; - private string _sniSpn = null; + private string _serverSpn = null; // UNDONE - need to have some for both instances - both command and default??? @@ -432,13 +432,13 @@ internal void Connect(ServerInfo serverInfo, { if (!string.IsNullOrEmpty(serverInfo.ServerSPN)) { - _sniSpn = serverInfo.ServerSPN; + _serverSpn = serverInfo.ServerSPN; SqlClientEventSource.Log.TryTraceEvent(" Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN); } else { - // Empty signifies to interop layer that SNI needs to be generated - _sniSpn = string.Empty; + // Empty signifies to interop layer that SPN needs to be generated + _serverSpn = string.Empty; } _authenticationProvider = _physicalStateObj.CreateSSPIContextProvider(); @@ -447,7 +447,7 @@ internal void Connect(ServerInfo serverInfo, else { _authenticationProvider = null; - _sniSpn = null; + _serverSpn = null; switch (authType) { @@ -526,7 +526,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpn, + ref _serverSpn, false, true, fParallel, @@ -626,7 +626,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - ref _sniSpn, + ref _serverSpn, true, true, fParallel, @@ -13782,7 +13782,7 @@ internal string TraceString() _is2000 ? bool.TrueString : bool.FalseString, _is2000SP1 ? bool.TrueString : bool.FalseString, _is2005 ? bool.TrueString : bool.FalseString, - _sniSpn == null ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null), + _serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs index fa4b07dd94..220df6d568 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs @@ -49,7 +49,7 @@ private void LoadSSPILibrary() } } - protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan _sniSpnBuffer) + protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpn) { #if NETFRAMEWORK SNIHandle handle = _physicalStateObj.Handle; @@ -62,7 +62,7 @@ protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlo var sendLength = s_maxSSPILength; var outBuff = outgoingBlobWriter.GetSpan((int)sendLength); - if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, _sniSpnBuffer[0])) + if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpn[0])) { throw new InvalidOperationException(SQLMessage.SSPIGenerateError()); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs index 1b446290ce..587ad230cc 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs @@ -26,18 +26,18 @@ private protected virtual void Initialize() { } - protected abstract void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan _sniSpnBuffer); + protected abstract void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpn); - internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string sniSpnBuffer) - => SSPIData(receivedBuff, outgoingBlobWriter, new[] { sniSpnBuffer }); + internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string serverSpn) + => SSPIData(receivedBuff, outgoingBlobWriter, new[] { serverSpn }); - internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string[] sniSpnBuffer) + internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string[] serverSpn) { using (TrySNIEventScope.Create(nameof(SSPIContextProvider))) { try { - GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, sniSpnBuffer); + GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn); } catch (Exception e) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index aba63db1c0..51e6e459c6 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -29,7 +29,7 @@ internal void ProcessSSPI(int receivedLength) var writer = SqlObjectPools.BufferWriter.Rent(); // make call for SSPI data - _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _sniSpn); + _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _serverSpn); // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! _physicalStateObj.WriteByteSpan(writer.WrittenSpan); @@ -159,7 +159,7 @@ internal void TdsLogin( // byte[] buffer and 0 for the int length. Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'"); _physicalStateObj.SniContext = SniContext.Snix_LoginSspi; - _authenticationProvider.SSPIData(ReadOnlySpan.Empty, sspiWriter, _sniSpn); + _authenticationProvider.SSPIData(ReadOnlySpan.Empty, sspiWriter, _serverSpn); _physicalStateObj.SniContext = SniContext.Snix_Login; From c56f20ebd9e5aecd56129db6b7e5b5ac84c8953b Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Wed, 26 Feb 2025 16:27:23 -0800 Subject: [PATCH 03/12] try/finally --- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 109 ++++++++++-------- 1 file changed, 60 insertions(+), 49 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 51e6e459c6..0b2c7b5e3f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -28,14 +28,20 @@ internal void ProcessSSPI(int receivedLength) // allocate send buffer and initialize length var writer = SqlObjectPools.BufferWriter.Rent(); - // make call for SSPI data - _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _serverSpn); + try + { + // make call for SSPI data + _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _serverSpn); - // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! - _physicalStateObj.WriteByteSpan(writer.WrittenSpan); + // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! + _physicalStateObj.WriteByteSpan(writer.WrittenSpan); - SqlObjectPools.BufferWriter.Return(writer); - ArrayPool.Shared.Return(receivedBuff, clearArray: true); + } + finally + { + SqlObjectPools.BufferWriter.Return(writer); + ArrayPool.Shared.Return(receivedBuff, clearArray: true); + } // set message type so server knows its a SSPI response _physicalStateObj._outputMessageType = TdsEnums.MT_SSPI; @@ -139,60 +145,65 @@ internal void TdsLogin( // allocate memory for SSPI variables ArrayBufferWriter sspiWriter = null; - // only add lengths of password and username if not using SSPI or requesting federated authentication info - if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested)) + try { - checked + // only add lengths of password and username if not using SSPI or requesting federated authentication info + if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested)) { - length += (userName.Length * 2) + encryptedPasswordLengthInBytes - + encryptedChangePasswordLengthInBytes; + checked + { + length += (userName.Length * 2) + encryptedPasswordLengthInBytes + + encryptedChangePasswordLengthInBytes; + } } - } - else - { - if (rec.useSSPI) + else { - sspiWriter = SqlObjectPools.BufferWriter.Rent(); + if (rec.useSSPI) + { + sspiWriter = SqlObjectPools.BufferWriter.Rent(); - // Call helper function for SSPI data and actual length. - // Since we don't have SSPI data from the server, send null for the - // byte[] buffer and 0 for the int length. - Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'"); - _physicalStateObj.SniContext = SniContext.Snix_LoginSspi; - _authenticationProvider.SSPIData(ReadOnlySpan.Empty, sspiWriter, _serverSpn); + // Call helper function for SSPI data and actual length. + // Since we don't have SSPI data from the server, send null for the + // byte[] buffer and 0 for the int length. + Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'"); + _physicalStateObj.SniContext = SniContext.Snix_LoginSspi; + _authenticationProvider.SSPIData(ReadOnlySpan.Empty, sspiWriter, _serverSpn); - _physicalStateObj.SniContext = SniContext.Snix_Login; + _physicalStateObj.SniContext = SniContext.Snix_Login; - checked - { - length += (int)sspiWriter.WrittenCount; + checked + { + length += (int)sspiWriter.WrittenCount; + } } } - } - int feOffset = length; - // calculate and reserve the required bytes for the featureEx - length = ApplyFeatureExData(requestedFeatures, recoverySessionData, fedAuthFeatureExtensionData, useFeatureExt, length); - - WriteLoginData(rec, - requestedFeatures, - recoverySessionData, - fedAuthFeatureExtensionData, - encrypt, - encryptedPassword, - encryptedChangePassword, - encryptedPasswordLengthInBytes, - encryptedChangePasswordLengthInBytes, - useFeatureExt, - userName, - length, - feOffset, - clientInterfaceName, - sspiWriter is { } ? sspiWriter.WrittenSpan : ReadOnlySpan.Empty); - - if (sspiWriter is not null) + int feOffset = length; + // calculate and reserve the required bytes for the featureEx + length = ApplyFeatureExData(requestedFeatures, recoverySessionData, fedAuthFeatureExtensionData, useFeatureExt, length); + + WriteLoginData(rec, + requestedFeatures, + recoverySessionData, + fedAuthFeatureExtensionData, + encrypt, + encryptedPassword, + encryptedChangePassword, + encryptedPasswordLengthInBytes, + encryptedChangePasswordLengthInBytes, + useFeatureExt, + userName, + length, + feOffset, + clientInterfaceName, + sspiWriter is { } ? sspiWriter.WrittenSpan : ReadOnlySpan.Empty); + } + finally { - SqlObjectPools.BufferWriter.Return(sspiWriter); + if (sspiWriter is not null) + { + SqlObjectPools.BufferWriter.Return(sspiWriter); + } } _physicalStateObj.WritePacket(TdsEnums.HARDFLUSH); From 657b1b29cb5c6ff4d1f6b73114d34239f96b1579 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Wed, 26 Feb 2025 16:30:26 -0800 Subject: [PATCH 04/12] undo initialization change --- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 8 +-- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 53 ++++++++++--------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 945fad1c5a..24e08f6794 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -453,6 +453,8 @@ internal void Connect( hostNameInCertificate, serverCertificateFilename); + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -467,8 +469,6 @@ internal void Connect( Debug.Fail("SNI returned status != success, but no error thrown?"); } - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - _server = serverInfo.ResolvedServerName; if (connHandler.PoolGroupProviderInfo != null) @@ -552,6 +552,8 @@ internal void Connect( hostNameInCertificate, serverCertificateFilename); + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -559,8 +561,6 @@ internal void Connect( ThrowExceptionAndWarning(_physicalStateObj); } - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId); Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index d766716965..f1d250396e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -163,7 +163,7 @@ internal int ObjectID // now data length is 1 byte // First bit is 1 indicating client support failover partner with readonly intent private static readonly byte[] s_FeatureExtDataAzureSQLSupportFeatureRequest = { 0x01 }; - + // NOTE: You must take the internal connection's _parserLock before modifying this internal bool _asyncWrite = false; @@ -430,6 +430,8 @@ internal void Connect(ServerInfo serverInfo, // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { + _authenticationProvider = _physicalStateObj.CreateSSPIContextProvider(); + if (!string.IsNullOrEmpty(serverInfo.ServerSPN)) { _serverSpn = serverInfo.ServerSPN; @@ -441,7 +443,6 @@ internal void Connect(ServerInfo serverInfo, _serverSpn = string.Empty; } - _authenticationProvider = _physicalStateObj.CreateSSPIContextProvider(); SqlClientEventSource.Log.TryTraceEvent(" SSPI or Active Directory Authentication Library for SQL Server based integrated authentication"); } else @@ -536,6 +537,8 @@ internal void Connect(ServerInfo serverInfo, FQDNforDNSCache, hostNameInCertificate); + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -550,8 +553,6 @@ internal void Connect(ServerInfo serverInfo, Debug.Fail("SNI returned status != success, but no error thrown?"); } - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - _server = serverInfo.ResolvedServerName; if (connHandler.PoolGroupProviderInfo != null) @@ -636,6 +637,8 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ResolvedServerName, hostNameInCertificate); + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); + if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -643,8 +646,6 @@ internal void Connect(ServerInfo serverInfo, ThrowExceptionAndWarning(_physicalStateObj); } - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - uint retCode = SniNativeWrapper.SniGetConnectionId(_physicalStateObj.Handle, ref _connHandler._clientConnectionId); Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); SqlClientEventSource.Log.TryTraceEvent(" Sending prelogin handshake"); @@ -3390,7 +3391,7 @@ private TdsOperationStatus TryProcessDone(SqlCommand cmd, SqlDataReader reader, Debug.Assert(!((sqlTransaction != null && _distributedTransaction != null) || (_userStartedLocalTransaction != null && _distributedTransaction != null)) , "ProcessDone - have both distributed and local transactions not null!"); - */ + */ // WebData 112722 stateObj.DecrementOpenResultCount(); @@ -3877,8 +3878,8 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj, if (!recoverable) { checked - { - sdata._unrecoverableStatesCount++; + { + sdata._unrecoverableStatesCount++; } } } @@ -3899,8 +3900,8 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj, else { checked - { - sdata._unrecoverableStatesCount++; + { + sdata._unrecoverableStatesCount++; } } sv._recoverable = recoverable; @@ -3979,29 +3980,29 @@ private TdsOperationStatus TryProcessLoginAck(TdsParserStateObject stateObj, out { case TdsEnums.SQL2005_MAJOR << 24 | TdsEnums.SQL2005_RTM_MINOR: // 2005 if (increment != TdsEnums.SQL2005_INCREMENT) - { - throw SQL.InvalidTDSVersion(); + { + throw SQL.InvalidTDSVersion(); } _is2005 = true; break; case TdsEnums.SQL2008_MAJOR << 24 | TdsEnums.SQL2008_MINOR: if (increment != TdsEnums.SQL2008_INCREMENT) - { - throw SQL.InvalidTDSVersion(); + { + throw SQL.InvalidTDSVersion(); } _is2008 = true; break; case TdsEnums.SQL2012_MAJOR << 24 | TdsEnums.SQL2012_MINOR: if (increment != TdsEnums.SQL2012_INCREMENT) - { - throw SQL.InvalidTDSVersion(); + { + throw SQL.InvalidTDSVersion(); } _is2012 = true; break; case TdsEnums.TDS8_MAJOR << 24 | TdsEnums.TDS8_MINOR: if (increment != TdsEnums.TDS8_INCREMENT) - { - throw SQL.InvalidTDSVersion(); + { + throw SQL.InvalidTDSVersion(); } _is2022 = true; break; @@ -5934,7 +5935,7 @@ private TdsOperationStatus TryProcessColInfo(_SqlMetaDataSet columns, SqlDataRea for (int i = 0; i < columns.Length; i++) { _SqlMetaData col = columns[i]; - + result = stateObj.TryReadByte(out _); if (result != TdsOperationStatus.Done) { @@ -6471,7 +6472,7 @@ private TdsOperationStatus TryReadSqlStringValue(SqlBuffer value, byte type, int char[] cc = null; bool buffIsRented = false; result = TryReadPlpUnicodeChars(ref cc, 0, length >> 1, stateObj, out length, supportRentedBuff: true, rentedBuff: ref buffIsRented); - + if (result == TdsOperationStatus.Done) { if (length > 0) @@ -11370,7 +11371,7 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin actualLengthInBytes = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) - { + { // see comments above actualLengthInBytes = metadata.baseTI.length; } @@ -12278,7 +12279,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati _parser.WriteInt(count, _stateObj); // write length of chunk task = _stateObj.WriteByteArray(buffer, count, offset, canAccumulate: false); } - + return task ?? Task.CompletedTask; } catch (System.OutOfMemoryException) @@ -12511,7 +12512,7 @@ private async Task WriteTextFeed(TextDataFeed feed, Encoding encoding, bool need char[] inBuff = ArrayPool.Shared.Rent(constTextBufferSize); encoding = encoding ?? new UnicodeEncoding(false, false); - + using (ConstrainedTextWriter writer = new ConstrainedTextWriter(new StreamWriter(new TdsOutputStream(this, stateObj, null), encoding), size)) { if (needBom) @@ -13429,7 +13430,7 @@ internal TdsOperationStatus TryReadPlpUnicodeChars( int charsRead = 0; int charsLeft = 0; char[] newbuf; - + if (stateObj._longlen == 0) { Debug.Assert(stateObj._longlenleft == 0); @@ -13546,7 +13547,7 @@ internal TdsOperationStatus TryReadPlpUnicodeChars( totalCharsRead++; } if (stateObj._longlenleft == 0) - { + { // Read the next chunk or cleanup state if hit the end result = stateObj.TryReadPlpLength(false, out _); if (result != TdsOperationStatus.Done) From 2881ce4af34d660254de997d309e5904eefeeebc Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 08:42:48 -0800 Subject: [PATCH 05/12] more try/finally --- .../Interop/Windows/Sni/SniNativeWrapper.cs | 8 ++-- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 37 +++++++++++-------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 5e48c43f8c..cd28e2f162 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -395,12 +395,12 @@ internal static unsafe uint SNIOpenSyncEx( { var writer = SqlObjectPools.BufferWriter.Rent(); - // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. - Encoding.Unicode.GetBytes(spn, writer); - Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); - try { + // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. + Encoding.Unicode.GetBytes(spn, writer); + Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); + fixed (byte* pin_spnBuffer = writer.WrittenSpan) { clientConsumerInfo.szSPN = pin_spnBuffer; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 0b2c7b5e3f..8721b05658 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -14,32 +14,39 @@ internal void ProcessSSPI(int receivedLength) SniContext outerContext = _physicalStateObj.SniContext; _physicalStateObj.SniContext = SniContext.Snix_ProcessSspi; + // allocate received buffer based on length from SSPI message byte[] receivedBuff = ArrayPool.Shared.Rent(receivedLength); - // read SSPI data received from server - Debug.Assert(_physicalStateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); - TdsOperationStatus result = _physicalStateObj.TryReadByteArray(receivedBuff, receivedLength); - if (result != TdsOperationStatus.Done) + try { - throw SQL.SynchronousCallMayNotPend(); - } + // read SSPI data received from server + Debug.Assert(_physicalStateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); + TdsOperationStatus result = _physicalStateObj.TryReadByteArray(receivedBuff, receivedLength); + if (result != TdsOperationStatus.Done) + { + throw SQL.SynchronousCallMayNotPend(); + } - // allocate send buffer and initialize length - var writer = SqlObjectPools.BufferWriter.Rent(); + // allocate send buffer and initialize length + var writer = SqlObjectPools.BufferWriter.Rent(); - try - { - // make call for SSPI data - _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _serverSpn); + try + { + // make call for SSPI data + _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _serverSpn); - // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! - _physicalStateObj.WriteByteSpan(writer.WrittenSpan); + // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! + _physicalStateObj.WriteByteSpan(writer.WrittenSpan); + } + finally + { + SqlObjectPools.BufferWriter.Return(writer); + } } finally { - SqlObjectPools.BufferWriter.Return(writer); ArrayPool.Shared.Return(receivedBuff, clearArray: true); } From 644f031d3c657bc073f93063fb2cc390d84cb6e0 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 14:18:11 -0800 Subject: [PATCH 06/12] Update src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs Co-authored-by: Malcolm Daigle --- .../src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index adada6abaa..b07ac26b7b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -144,7 +144,7 @@ internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spn, + ref string[] spns, bool flushCache, bool async, bool fParallel, From c0c2df7a7908da90ed231ce12a2080a39073a5ee Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 14:18:20 -0800 Subject: [PATCH 07/12] Update src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs Co-authored-by: Malcolm Daigle --- .../Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index 01f1759530..b2763a786f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spn, + ref string[] spns, bool flushCache, bool async, bool fParallel, From 696991cc8d2040e17e61c8d1b626de5bda02d3fb Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 14:18:30 -0800 Subject: [PATCH 08/12] Update src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs Co-authored-by: Malcolm Daigle --- .../src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index f9f211ca4a..083027c36a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spn, + ref string[] spns, bool flushCache, bool async, bool parallel, From e1c46ec39b76e300fb3ce81f030fe9ff804f94a0 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 14:18:57 -0800 Subject: [PATCH 09/12] Update src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs Co-authored-by: Malcolm Daigle --- .../Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs index 220df6d568..facafc35cd 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs @@ -49,7 +49,7 @@ private void LoadSSPILibrary() } } - protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpn) + protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpns) { #if NETFRAMEWORK SNIHandle handle = _physicalStateObj.Handle; From a232ac8a3ec822517664c994a22b80cf4097f80a Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 14:19:12 -0800 Subject: [PATCH 10/12] Update src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs Co-authored-by: Malcolm Daigle --- .../src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs index 587ad230cc..e779224fe4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs @@ -26,7 +26,7 @@ private protected virtual void Initialize() { } - protected abstract void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpn); + protected abstract void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpns); internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string serverSpn) => SSPIData(receivedBuff, outgoingBlobWriter, new[] { serverSpn }); From 13ab3b44b91655f57a3f3d3da5322a2324e2a0f9 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 14:19:23 -0800 Subject: [PATCH 11/12] Update src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs Co-authored-by: Malcolm Daigle --- .../src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs index e779224fe4..c0f98a8c5c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs @@ -31,7 +31,7 @@ private protected virtual void Initialize() internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string serverSpn) => SSPIData(receivedBuff, outgoingBlobWriter, new[] { serverSpn }); - internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string[] serverSpn) + internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string[] serverSpns) { using (TrySNIEventScope.Create(nameof(SSPIContextProvider))) { From c6c95c25dce3b5741131f401abbf3582175f3d13 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Thu, 27 Feb 2025 14:23:47 -0800 Subject: [PATCH 12/12] naming consistency and other feedback --- .../Data/SqlClient/TdsParserStateObjectManaged.cs | 2 +- .../Data/SqlClient/TdsParserStateObjectNative.cs | 2 +- .../Microsoft/Data/SqlClient/BufferWriterExtensions.cs | 9 +++++---- .../Data/SqlClient/SSPI/NativeSSPIContextProvider.cs | 2 +- .../Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs | 9 ++++----- .../Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs | 2 +- .../src/Microsoft/Data/SqlClient/SqlObjectPool.cs | 2 -- 7 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 083027c36a..66606e26b8 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle( string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spns, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index b07ac26b7b..49192f21e4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -178,7 +178,7 @@ internal override void CreatePhysicalSNIHandle( _sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); - spn = new[] { serverSPN.TrimEnd() }; + spns = new[] { serverSPN.TrimEnd() }; } protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs index efd15b1602..e23cff2dcc 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/BufferWriterExtensions.cs @@ -1,4 +1,5 @@ -using System.Buffers; +using System; +using System.Buffers; using System.Text; namespace Microsoft.Data.SqlClient @@ -12,9 +13,9 @@ internal static long GetBytes(this Encoding encoding, string str, IBufferWriter< try { - encoding.GetBytes(str, 0, str.Length, array, 0); - bufferWriter.Write(array); - return count; + var length = encoding.GetBytes(str, 0, str.Length, array, 0); + bufferWriter.Write(array.AsSpan(0, length)); + return length; } finally { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs index facafc35cd..0a2fa8aeb7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs @@ -62,7 +62,7 @@ protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlo var sendLength = s_maxSSPILength; var outBuff = outgoingBlobWriter.GetSpan((int)sendLength); - if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpn[0])) + if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpns[0])) { throw new InvalidOperationException(SQLMessage.SSPIGenerateError()); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs index 719d1d681d..9a4eb457a4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs @@ -1,7 +1,6 @@ #if NET using System; -using System.Text; using System.Net.Security; using System.Buffers; @@ -12,14 +11,14 @@ namespace Microsoft.Data.SqlClient internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider { private NegotiateAuthentication? _negotiateAuth = null; - - protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverNames) + + protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpns) { NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials; - for (int i = 0; i < serverNames.Length; i++) + for (int i = 0; i < serverSpns.Length; i++) { - _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = serverNames[i] }); + _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = serverSpns[i] }); var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!; // Log session id, status code and the actual SPN used in the negotiation diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs index c0f98a8c5c..6aef7bfbff 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs @@ -37,7 +37,7 @@ internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outg { try { - GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn); + GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpns); } catch (Exception e) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs index 94b9be977a..d5cf2398ec 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlObjectPool.cs @@ -3,9 +3,7 @@ // See the LICENSE file in the project root for more information. using System; -using System.Buffers; using System.Diagnostics; -using System.Text; using System.Threading; namespace Microsoft.Data.SqlClient