diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 6246fc590d..1022ff06ac 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -188,6 +188,9 @@ Microsoft\Data\SqlClient\OnChangedEventHandler.cs + + Microsoft\Data\SqlClient\Packet.cs + Microsoft\Data\SqlClient\ParameterPeekAheadValue.cs @@ -581,6 +584,9 @@ Microsoft\Data\SqlClient\TdsParserStateObject.cs + + Microsoft\Data\SqlClient\TdsParserStateObject.Multiplexer.cs + Microsoft\Data\SqlClient\TdsParserStaticMethods.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index f0ae13e6cd..05fb82f7e4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -3724,7 +3724,7 @@ private TdsOperationStatus TryNextResult(out bool more) /// // user must call Read() to position on the first row - override public bool Read() + public override bool Read() { if (_currentTask != null) { @@ -4564,9 +4564,10 @@ internal TdsOperationStatus TrySetMetaData(_SqlMetaDataSet metaData, bool moreIn _metaDataConsumed = true; if (_parser != null) - { // There is a valid case where parser is null - // Peek, and if row token present, set _hasRows true since there is a - // row in the result + { + // There is a valid case where parser is null + // Peek, and if row token present, set _hasRows true since there is a + // row in the result byte b; TdsOperationStatus result = _stateObj.TryPeekByte(out b); if (result != TdsOperationStatus.Done) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs index e2d86dc210..7c10f0aa4f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs @@ -25,6 +25,10 @@ internal void PostReadAsyncForMars() _pMarsPhysicalConObj.IncrementPendingCallbacks(); SessionHandle handle = _pMarsPhysicalConObj.SessionHandle; + // we do not need to consider partial packets when making this read because we + // expect this read to pend. a partial packet should not exist at setup of the + // parser + Debug.Assert(_physicalStateObj.PartialPacket==null); temp = _pMarsPhysicalConObj.ReadAsync(handle, out error); Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index af920ed2a7..3f34ec4ded 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2047,11 +2047,19 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle if (!IsValidTdsToken(token)) { - Debug.Fail($"unexpected token; token = {token,-2:X2}"); +#if DEBUG + string message = stateObj.DumpBuffer(); + Debug.Fail(message); +#endif _state = TdsParserState.Broken; _connHandler.BreakConnection(); SqlClientEventSource.Log.TryTraceEvent(" Potential multi-threaded misuse of connection, unexpected TDS token found {0}", ObjectID); +#if DEBUG + throw new InvalidOperationException(message); +#else throw SQL.ParsingError(); +#endif + } int tokenLength; @@ -4143,6 +4151,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length, TdsParserStateObje { return result; } + byte len; result = stateObj.TryReadByte(out len); if (result != TdsOperationStatus.Done) @@ -4540,7 +4549,6 @@ internal TdsOperationStatus TryProcessCollation(TdsParserStateObject stateObj, o collation = null; return result; } - if (SqlCollation.Equals(_cachedCollation, info, sortId)) { collation = _cachedCollation; @@ -5263,7 +5271,7 @@ private TdsOperationStatus TryCommonProcessMetaData(TdsParserStateObject stateOb { // If the column is encrypted, we should have a valid cipherTable if (cipherTable != null) - { + { result = TryProcessTceCryptoMetadata(stateObj, col, cipherTable, columnEncryptionSetting, isReturnValue: false); if (result != TdsOperationStatus.Done) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index 4f655dc403..aa6986ca2d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -322,14 +322,22 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) stateObj.SendAttention(mustTakeWriteLock: true); PacketHandle syncReadPacket = default; + bool readFromNetwork = true; RuntimeHelpers.PrepareConstrainedRegions(); bool shouldDecrement = false; try { Interlocked.Increment(ref _readingCount); shouldDecrement = true; - - syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + readFromNetwork = !PartialPacketContainsCompletePacket(); + if (readFromNetwork) + { + syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + } + else + { + error = TdsEnums.SNI_SUCCESS; + } Interlocked.Decrement(ref _readingCount); shouldDecrement = false; @@ -342,7 +350,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) } else { - Debug.Assert(!IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); + Debug.Assert(!readFromNetwork || !IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); fail = true; // Subsequent read failed, time to give up. } } @@ -353,7 +361,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) Interlocked.Decrement(ref _readingCount); } - if (!IsPacketEmpty(syncReadPacket)) + if (readFromNetwork && !IsPacketEmpty(syncReadPacket)) { ReleasePacket(syncReadPacket); } @@ -393,60 +401,9 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) AssertValidState(); } - public void ProcessSniPacket(PacketHandle packet, uint error) + private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { - if (error != 0) - { - if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) - { - // Do nothing with callback if closed or broken and error not 0 - callback can occur - // after connection has been closed. PROBLEM IN NETLIB - DESIGN FLAW. - return; - } - - AddError(_parser.ProcessSNIError(this)); - AssertValidState(); - } - else - { - uint dataSize = 0; - - uint getDataError = SNIPacketGetData(packet, _inBuff, ref dataSize); - - if (getDataError == TdsEnums.SNI_SUCCESS) - { - if (_inBuff.Length < dataSize) - { - Debug.Assert(true, "Unexpected dataSize on Read"); - throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); - } - - _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; - _inBytesRead = (int)dataSize; - _inBytesUsed = 0; - - if (_snapshot != null) - { - _snapshot.AppendPacketData(_inBuff, _inBytesRead); - if (_snapshotReplay) - { - _snapshot.MoveNext(); -#if DEBUG - _snapshot.AssertCurrent(); -#endif - } - } - - SniReadStatisticsAndTracing(); - SqlClientEventSource.Log.TryAdvancedTraceBinEvent("TdsParser.ReadNetworkPacketAsyncCallback | INFO | ADV | State Object Id {0}, Packet read. In Buffer: {1}, In Bytes Read: {2}", ObjectID, _inBuff, _inBytesRead); - - AssertValidState(); - } - else - { - throw SQL.ParsingError(ParsingErrorState.ProcessSniPacketFailed); - } - } + return SNIPacketGetData(packet, _inBuff, ref dataSize); } private void ChangeNetworkPacketTimeout(int dueTime, int period) @@ -535,7 +492,7 @@ public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error) bool processFinallyBlock = true; try { - Debug.Assert(CheckPacket(packet, source) && source != null, "AsyncResult null on callback"); + Debug.Assert((packet.Type == 0 && PartialPacketContainsCompletePacket()) || (CheckPacket(packet, source) && source != null), "AsyncResult null on callback"); if (_parser.MARSOn) { @@ -547,7 +504,7 @@ public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error) // The timer thread may be unreliable under high contention scenarios. It cannot be // assumed that the timeout has happened on the timer thread callback. Check the timeout - // synchrnously and then call OnTimeoutSync to force an atomic change of state. + // synchronously and then call OnTimeoutSync to force an atomic change of state. if (TimeoutHasExpired) { OnTimeoutSync(asyncClose: true); @@ -1633,7 +1590,7 @@ internal void AssertStateIsClean() if ((parser != null) && (parser.State != TdsParserState.Closed) && (parser.State != TdsParserState.Broken)) { // Async reads - Debug.Assert(_snapshot == null && !_snapshotReplay, "StateObj has leftover snapshot state"); + Debug.Assert(_snapshot == null && _snapshotStatus == SnapshotStatus.NotActive, "StateObj has leftover snapshot state"); Debug.Assert(!_asyncReadWithoutSnapshot, "StateObj has AsyncReadWithoutSnapshot still enabled"); Debug.Assert(_executionContext == null, "StateObj has a stored execution context from an async read"); // Async writes diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 591977ef2d..feecd6f44f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -391,6 +391,9 @@ Microsoft\Data\SqlClient\OnChangedEventHandler.cs + + Microsoft\Data\SqlClient\Packet.cs + Microsoft\Data\SqlClient\ParameterPeekAheadValue.cs @@ -772,6 +775,9 @@ Microsoft\Data\SqlClient\TdsParserStateObject.cs + + Microsoft\Data\SqlClient\TdsParserStateObject.Multiplexer.cs + Microsoft\Data\SqlClient\TdsParserStaticMethods.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index e45765c0a4..a44cabd134 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -4364,6 +4364,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length, return result; } } + byte len; result = stateObj.TryReadByte(out len); if (result != TdsOperationStatus.Done) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index 0ce58d120a..343778cb4e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -453,14 +453,22 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) stateObj.SendAttention(mustTakeWriteLock: true); PacketHandle syncReadPacket = default; + bool readFromNetwork = true; RuntimeHelpers.PrepareConstrainedRegions(); bool shouldDecrement = false; try { Interlocked.Increment(ref _readingCount); shouldDecrement = true; - - syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + readFromNetwork = !PartialPacketContainsCompletePacket(); + if (readFromNetwork) + { + syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + } + else + { + error = TdsEnums.SNI_SUCCESS; + } Interlocked.Decrement(ref _readingCount); shouldDecrement = false; @@ -473,7 +481,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) } else { - Debug.Assert(!IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); + Debug.Assert(!readFromNetwork || !IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); fail = true; // Subsequent read failed, time to give up. } } @@ -484,7 +492,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) Interlocked.Decrement(ref _readingCount); } - if (!IsPacketEmpty(syncReadPacket)) + if (readFromNetwork && !IsPacketEmpty(syncReadPacket)) { // Be sure to release packet, otherwise it will be leaked by native. ReleasePacket(syncReadPacket); @@ -525,60 +533,9 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) AssertValidState(); } - public void ProcessSniPacket(PacketHandle packet, uint error) + private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { - if (error != 0) - { - if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) - { - // Do nothing with callback if closed or broken and error not 0 - callback can occur - // after connection has been closed. PROBLEM IN NETLIB - DESIGN FLAW. - return; - } - - AddError(_parser.ProcessSNIError(this)); - AssertValidState(); - } - else - { - uint dataSize = 0; - - uint getDataError = SniNativeWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize); - - if (getDataError == TdsEnums.SNI_SUCCESS) - { - if (_inBuff.Length < dataSize) - { - Debug.Assert(true, "Unexpected dataSize on Read"); - throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); - } - - _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; - _inBytesRead = (int)dataSize; - _inBytesUsed = 0; - - if (_snapshot != null) - { - _snapshot.AppendPacketData(_inBuff, _inBytesRead); - if (_snapshotReplay) - { - _snapshot.MoveNext(); -#if DEBUG - _snapshot.AssertCurrent(); -#endif - } - } - - SniReadStatisticsAndTracing(); - SqlClientEventSource.Log.TryAdvancedTraceBinEvent("TdsParser.ReadNetworkPacketAsyncCallback | INFO | ADV | State Object Id {0}, Packet read. In Buffer: {1}, In Bytes Read: {2}", ObjectID, _inBuff, _inBytesRead); - - AssertValidState(); - } - else - { - throw SQL.ParsingError(ParsingErrorState.ProcessSniPacketFailed); - } - } + return SniNativeWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize); } private void ChangeNetworkPacketTimeout(int dueTime, int period) @@ -1774,7 +1731,7 @@ internal void AssertStateIsClean() if ((parser != null) && (parser.State != TdsParserState.Closed) && (parser.State != TdsParserState.Broken)) { // Async reads - Debug.Assert(_snapshot == null && !_snapshotReplay, "StateObj has leftover snapshot state"); + Debug.Assert(_snapshot == null && _snapshotStatus == SnapshotStatus.NotActive, "StateObj has leftover snapshot state"); Debug.Assert(!_asyncReadWithoutSnapshot, "StateObj has AsyncReadWithoutSnapshot still enabled"); Debug.Assert(_executionContext == null, "StateObj has a stored execution context from an async read"); // Async writes diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs index 9ea289f5b7..b66154a2ae 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs @@ -20,6 +20,7 @@ private enum Tristate : byte internal const string SuppressInsecureTLSWarningString = @"Switch.Microsoft.Data.SqlClient.SuppressInsecureTLSWarning"; internal const string UseMinimumLoginTimeoutString = @"Switch.Microsoft.Data.SqlClient.UseOneSecFloorInTimeoutCalculationDuringLogin"; internal const string LegacyVarTimeZeroScaleBehaviourString = @"Switch.Microsoft.Data.SqlClient.LegacyVarTimeZeroScaleBehaviour"; + internal const string UseCompatibilityProcessSniString = @"Switch.Microsoft.Data.SqlClient.UseCompatibilityProcessSni"; // this field is accessed through reflection in tests and should not be renamed or have the type changed without refactoring NullRow related tests private static Tristate s_legacyRowVersionNullBehavior; @@ -28,6 +29,7 @@ private enum Tristate : byte private static Tristate s_useMinimumLoginTimeout; // this field is accessed through reflection in Microsoft.Data.SqlClient.Tests.SqlParameterTests and should not be renamed or have the type changed without refactoring related tests private static Tristate s_legacyVarTimeZeroScaleBehaviour; + private static Tristate s_useCompatProcessSni; #if NET static LocalAppContextSwitches() @@ -83,6 +85,24 @@ public static bool DisableTNIRByDefault } } #endif + public static bool UseCompatibilityProcessSni + { + get + { + if (s_useCompatProcessSni == Tristate.NotInitialized) + { + if (AppContext.TryGetSwitch(UseCompatibilityProcessSniString, out bool returnedValue) && returnedValue) + { + s_useCompatProcessSni = Tristate.True; + } + else + { + s_useCompatProcessSni = Tristate.False; + } + } + return s_useCompatProcessSni == Tristate.True; + } + } /// /// When using Encrypt=false in the connection string, a security warning is output to the console if the TLS version is 1.2 or lower. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs new file mode 100644 index 0000000000..b81270bf08 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs @@ -0,0 +1,189 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace Microsoft.Data.SqlClient +{ + /// + /// Contains a buffer for a partial or full packet and methods to get information about the status of + /// the packet that the buffer represents.
+ /// This class is used to contain partial packet data and helps ensure that the packet data is completely + /// received before a full packet is made available to the rest of the library + ///
+ internal sealed partial class Packet + { + public const int UnknownDataLength = -1; + + private int _dataLength; + private int _totalLength; + private byte[] _buffer; + + public Packet() + { + _dataLength = UnknownDataLength; + } + + /// + /// If the packet data has enough bytes available to determine the amount of data that should be present + /// in the packet then this property will be set to the count of data bytes in the packet,
+ /// otherwise this will be -1 + ///
+ public int DataLength + { + get + { + CheckDisposed(); + return _dataLength; + } + set + { + CheckDisposed(); + _dataLength = value; + } + } + + /// + /// A byte array containing bytes of data + /// + public byte[] Buffer + { + get + { + CheckDisposed(); + return _buffer; + } + set + { + CheckDisposed(); + _buffer = value; + } + } + + /// + /// The total count of bytes currently in the array including the tds header bytes + /// + public int CurrentLength + { + get + { + CheckDisposed(); + return _totalLength; + } + set + { + CheckDisposed(); + _totalLength = value; + } + } + + /// + /// If the packet data has enough bytes available to determine the length amount of data that should be present + /// in the packet then this property will return the count of data bytes that are expected to be in the packet.
+ /// If there are not enough bytes to determine the data byte count then this property will throw an exception.
+ ///
+ /// Call to check if there will be a value before using this property. + ///
+ public int RequiredLength + { + get + { + CheckDisposed(); + if (!HasDataLength) + { + throw new InvalidOperationException($"cannot get {nameof(RequiredLength)} when {nameof(HasDataLength)} is false"); + } + return TdsEnums.HEADER_LEN + _dataLength; + } + } + + /// + /// returns a boolean value indicating if there are enough total bytes available in the to read the tds header + /// + public bool HasHeader => _totalLength >= TdsEnums.HEADER_LEN; + + /// + /// returns a boolean value indicating if the value has been set. + /// + public bool HasDataLength => _dataLength >= 0; + + /// + /// returns a boolean value indicating whether the contains enough + /// data for a valid tds header, has a set and that the + /// is greater than or equal to the + tds header length.
+ ///
+ public bool ContainsCompletePacket => _dataLength != UnknownDataLength && (TdsEnums.HEADER_LEN + _dataLength) <= _totalLength; + + /// + /// returns a containing the first 8 bytes of the array which will + /// contain the TDS header bytes. This can be passed to static functions on to extract information from the + /// tds packet header.
+ /// Call before using this function. + ///
+ /// + public ReadOnlySpan GetHeaderSpan() => _buffer.AsSpan(0, TdsEnums.HEADER_LEN); + + [Conditional("DEBUG")] + internal void CheckDisposed() => CheckDisposedImpl(); + + [Conditional("DEBUG")] + internal void SetCreatedBy(int creator) => SetCreatedByImpl(creator); + + partial void SetCreatedByImpl(int creator); + + partial void CheckDisposedImpl(); + + public static void ThrowDisposed() + { + throw new ObjectDisposedException(nameof(Packet)); + } + + internal static byte GetStatusFromHeader(ReadOnlySpan header) => header[1]; + + internal static int GetDataLengthFromHeader(ReadOnlySpan header) + { + return (header[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + } + internal static int GetSpidFromHeader(ReadOnlySpan header) + { + return (header[TdsEnums.SPID_OFFSET] << 8 | header[TdsEnums.SPID_OFFSET + 1]); + } + internal static int GetIDFromHeader(ReadOnlySpan header) + { + return header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 4]; + } + + internal static int GetDataLengthFromHeader(Packet packet) => GetDataLengthFromHeader(packet.GetHeaderSpan()); + + internal static bool GetIsEOMFromHeader(ReadOnlySpan header) => GetStatusFromHeader(header) == 1; + } + +#if DEBUG + internal sealed partial class Packet + { + private int _createdBy; + private bool _disposed; + + public int CreatedBy => _createdBy; + + [Conditional("DEBUG")] + public void Dispose() + { + _disposed = true; + } + + partial void SetCreatedByImpl(int creator) => _createdBy = creator; + + partial void CheckDisposedImpl() + { + if (_disposed) + { + ThrowDisposed(); + } + } + } +#endif +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs new file mode 100644 index 0000000000..2bb72e9bf2 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -0,0 +1,559 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; + +namespace Microsoft.Data.SqlClient +{ +#if NETFRAMEWORK + using PacketHandle = IntPtr; +#endif + partial class TdsParserStateObject + { + private Packet _partialPacket; + internal Packet PartialPacket => _partialPacket; + + public void ProcessSniPacket(PacketHandle packet, uint error) + { + if (LocalAppContextSwitches.UseCompatibilityProcessSni) + { + ProcessSniPacketCompat(packet, error); + return; + } + + if (error != 0) + { + if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) + { + // Do nothing with callback if closed or broken and error not 0 - callback can occur + // after connection has been closed. PROBLEM IN NETLIB - DESIGN FLAW. + return; + } + + AddError(_parser.ProcessSNIError(this)); + AssertValidState(); + } + else + { + uint dataSize = 0; + bool usedPartialPacket = false; + uint getDataError = 0; + + if (PartialPacketContainsCompletePacket()) + { + Packet partialPacket = _partialPacket; + SetBuffer(partialPacket.Buffer, 0, partialPacket.CurrentLength); + ClearPartialPacket(); + getDataError = TdsEnums.SNI_SUCCESS; + usedPartialPacket = true; + } + else + { + if (_inBytesRead != 0) + { + SetBuffer(new byte[_inBuff.Length], 0, 0); + } + getDataError = GetSniPacket(packet, ref dataSize); + } + + if (getDataError == TdsEnums.SNI_SUCCESS) + { + if (_inBuff.Length < dataSize) + { + Debug.Assert(true, "Unexpected dataSize on Read"); + throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); + } + + if (!usedPartialPacket) + { + _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; + + SetBuffer(_inBuff, 0, (int)dataSize); + } + + bool recurse = false; + bool appended = false; + do + { + if (recurse && appended) + { + SetBuffer(new byte[_inBuff.Length], 0, 0); + appended = false; + } + MultiplexPackets( + _inBuff, _inBytesUsed, _inBytesRead, + PartialPacket, + out int newDataOffset, + out int newDataLength, + out Packet remainderPacket, + out bool consumeInputDirectly, + out bool consumePartialPacket, + out bool remainderPacketProduced, + out recurse + ); + bool bufferIsPartialCompleted = false; + + // if a partial packet was reconstructed it must be handled first + if (consumePartialPacket) + { + if (_snapshot != null) + { + _snapshot.AppendPacketData(PartialPacket.Buffer, PartialPacket.CurrentLength); + SetBuffer(new byte[_inBuff.Length], 0, 0); + appended = true; + } + else + { + SetBuffer(PartialPacket.Buffer, 0, PartialPacket.CurrentLength); + + } + bufferIsPartialCompleted = true; + ClearPartialPacket(); + } + + // if the remaining data can be processed directly it must be second + if (consumeInputDirectly) + { + // if some data was taken from the new packet adjust the counters + if (dataSize != newDataLength || 0 != newDataOffset) + { + SetBuffer(_inBuff, newDataOffset, newDataLength); + } + + if (_snapshot != null) + { + _snapshot.AppendPacketData(_inBuff, _inBytesRead); + // if we SetBuffer here to clear the packet buffer we will break the attention handling which relies + // on the attention containing packet remaining in the active buffer even if we're appending to the + // snapshot so we will have to use the appended variable to prevent the same buffer being added again + //// SetBuffer(new byte[_inBuff.Length], 0, 0); + appended = true; + } + else + { + SetBuffer(_inBuff, 0, _inBytesRead); + } + bufferIsPartialCompleted = true; + } + else + { + // whatever is in the input buffer should not be directly consumed + // and is contained in the partial or remainder packets so make sure + // we don't process it + if (!bufferIsPartialCompleted) + { + SetBuffer(_inBuff, 0, 0); + } + } + + // if there is a remainder it must be last + if (remainderPacketProduced) + { + SetPartialPacket(remainderPacket); + if (!bufferIsPartialCompleted) + { + // we are keeping the partial packet buffer so replace it with a new one + // unless we have already set the buffer to the partial packet buffer + SetBuffer(new byte[_inBuff.Length], 0, 0); + } + } + + } while (recurse && _snapshot != null); + + if (_snapshot != null) + { + if (_snapshotStatus != SnapshotStatus.NotActive && appended) + { + _snapshot.MoveNext(); + } + } + + SniReadStatisticsAndTracing(); + SqlClientEventSource.Log.TryAdvancedTraceBinEvent("TdsParser.ReadNetworkPacketAsyncCallback | INFO | ADV | State Object Id {0}, Packet read. In Buffer {1}, In Bytes Read: {2}", ObjectID, _inBuff, (ushort)_inBytesRead); + + AssertValidState(); + } + else + { + throw SQL.ParsingError(ParsingErrorState.ProcessSniPacketFailed); + } + } + } + + private void SetPartialPacket(Packet packet) + { + if (_partialPacket != null && packet != null) + { + throw new InvalidOperationException("partial packet cannot be non-null when setting to non=null"); + } + _partialPacket = packet; + } + + private void ClearPartialPacket() + { + Packet partialPacket = _partialPacket; + _partialPacket = null; +#if DEBUG + if (partialPacket != null) + { + partialPacket.Dispose(); + } +#endif + } + + // this check is used in two places that must be identical so it is + // extracted into a method, do not inline this method + internal bool PartialPacketContainsCompletePacket() + { + Packet partialPacket = _partialPacket; + return partialPacket != null && partialPacket.ContainsCompletePacket; + } + + private static void MultiplexPackets( + byte[] dataBuffer, int dataOffset, int dataLength, + Packet partialPacket, + out int newDataOffset, + out int newDataLength, + out Packet remainderPacket, + out bool consumeInputDirectly, + out bool consumePartialPacket, + out bool createdRemainderPacket, + out bool recurse + ) + { + Debug.Assert(dataBuffer != null); + + ReadOnlySpan data = dataBuffer.AsSpan(dataOffset, dataLength); + remainderPacket = null; + consumeInputDirectly = false; + consumePartialPacket = false; + createdRemainderPacket = false; + recurse = false; + + newDataLength = dataLength; + newDataOffset = dataOffset; + + int bytesConsumed = 0; + + if (partialPacket != null) + { + if (!partialPacket.HasDataLength) + { + // we need to get enough bytes to read the packet header + int headerBytesNeeded = Math.Max(0, TdsEnums.HEADER_LEN - partialPacket.CurrentLength); + if (headerBytesNeeded > 0) + { + int headerBytesAvailable = Math.Min(data.Length, headerBytesNeeded); + + Span headerTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, headerBytesAvailable); + ReadOnlySpan headerSource = data.Slice(0, headerBytesAvailable); + headerSource.CopyTo(headerTarget); + + partialPacket.CurrentLength = partialPacket.CurrentLength + headerBytesAvailable; + bytesConsumed += headerBytesAvailable; + data = data.Slice(headerBytesAvailable); + } + if (partialPacket.HasHeader) + { + partialPacket.DataLength = Packet.GetDataLengthFromHeader(partialPacket); + } + } + + if (partialPacket.HasDataLength) + { + if (partialPacket.CurrentLength < partialPacket.RequiredLength) + { + // the packet length is known so take as much data as possible from the incoming + // data to try and complete the packet + + int payloadBytesNeeded = partialPacket.DataLength - (partialPacket.CurrentLength - TdsEnums.HEADER_LEN); + int payloadBytesAvailable = Math.Min(data.Length, payloadBytesNeeded); + + ReadOnlySpan payloadSource = data.Slice(0, payloadBytesAvailable); + Span payloadTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, payloadBytesAvailable); + payloadSource.CopyTo(payloadTarget); + + partialPacket.CurrentLength = partialPacket.CurrentLength + payloadBytesAvailable; + bytesConsumed += payloadBytesAvailable; + data = data.Slice(payloadBytesAvailable); + } + else if (partialPacket.CurrentLength > partialPacket.RequiredLength) + { + // the partial packet contains a complete packet of data and also contains + // data from a following packet + + // the TDS spec requires that all packets be of the defined packet size apart from + // the last packet of a response. This means that is should not possible to have more than + // 2 packet fragments in a single buffer like this: + // - first packet caused the partial + // - second packet is the one we have just unpacked + // - third packet is the extra data we have found + // however, due to the timing of cancellation it is possible that a response token stream + // has ended before an attention message response is sent leaving us with a short final + // packet and an additional short cancel packet following it + + // this should only happen when the caller is trying to consume the partial packet + // and does not have new input data + + int remainderLength = partialPacket.CurrentLength - partialPacket.RequiredLength; + + partialPacket.CurrentLength = partialPacket.RequiredLength; + + remainderPacket = new Packet + { + Buffer = new byte[dataBuffer.Length], + CurrentLength = remainderLength, + }; + remainderPacket.SetCreatedBy(1); + + ReadOnlySpan remainderSource = partialPacket.Buffer.AsSpan(TdsEnums.HEADER_LEN + partialPacket.DataLength, remainderLength); + Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); + remainderSource.CopyTo(remainderTarget); + + createdRemainderPacket = true; + + recurse = SetupRemainderPacket(remainderPacket); + } + + if (partialPacket.CurrentLength == partialPacket.RequiredLength) + { + // partial packet has been completed + consumePartialPacket = true; + } + } + + if (bytesConsumed > 0) + { + if (data.Length > 0) + { + // some data has been taken from the buffer and put into the partial + // packet buffer. We have data left so move the data we have to the + // start of the buffer so we can pass the buffer back as zero based + // to the caller avoiding offset calculations in the rest of this method + Buffer.BlockCopy( + dataBuffer, dataOffset + bytesConsumed, // from + dataBuffer, dataOffset, // to + dataLength - bytesConsumed // for + ); +#if DEBUG + // for debugging purposes fill the removed data area with an easily + // recognisable pattern so we can see if it is misused + Span removed = dataBuffer.AsSpan(dataOffset + (dataLength - bytesConsumed), bytesConsumed); + removed.Fill(0xFF); +#endif + + // then recreate the data span so that we're looking at the data + // that has been moved + data = dataBuffer.AsSpan(dataOffset, dataLength - bytesConsumed); + } + + newDataLength = dataLength - bytesConsumed; + } + } + + // partial packet handling should not make decisions about consuming the input buffer + Debug.Assert(!consumeInputDirectly); + // partial packet handling may only create a remainder packet when it is trying to consume the partial packet and has no incoming data + Debug.Assert(!createdRemainderPacket || data.Length == 0); + + if (data.Length > 0) + { + if (data.Length >= TdsEnums.HEADER_LEN) + { + // we have enough bytes to read the packet header and see how + // much data we are expecting it to contain + int packetDataLength = Packet.GetDataLengthFromHeader(data); + + if (data.Length == TdsEnums.HEADER_LEN + packetDataLength) + { + if (!consumePartialPacket) + { + // we can tell the caller that they should directly consume the data in + // the input buffer, this is the happy path + consumeInputDirectly = true; + } + else + { + // we took some data from the input to reconstruct the partial packet + // so we can't tell the caller to directly consume the packet in the + // input buffer, we need to construct a new remainder packet and then + // tell them to consume it + remainderPacket = new Packet + { + Buffer = dataBuffer, + CurrentLength = data.Length + }; + remainderPacket.SetCreatedBy(2); + createdRemainderPacket = true; + recurse = SetupRemainderPacket(remainderPacket); + } + } + else if (data.Length < TdsEnums.HEADER_LEN + packetDataLength) + { + // an incomplete packet so create a remainder packet to pass back + remainderPacket = new Packet + { + Buffer = dataBuffer, + DataLength = packetDataLength, + CurrentLength = data.Length + }; + remainderPacket.SetCreatedBy(3); + createdRemainderPacket = true; + recurse = SetupRemainderPacket(remainderPacket); + } + else // implied: current length > required length + { + // more data than required so need to split it out, but we can't do that + // here so we need to tell the caller to take the remainder packet and then + // call this function again + if (consumePartialPacket) + { + // we are already telling the caller to consume the partial packet so we + // can't tell them it to also consume the data in the buffer directly + // so create a remainder packet and pass it back. + remainderPacket = new Packet + { + Buffer = new byte[dataBuffer.Length], + CurrentLength = data.Length + }; + remainderPacket.SetCreatedBy(4); + ReadOnlySpan remainderSource = data; + Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderPacket.CurrentLength); + remainderSource.CopyTo(remainderTarget); + + createdRemainderPacket = true; + + recurse = SetupRemainderPacket(remainderPacket); + } + else + { + newDataLength = TdsEnums.HEADER_LEN + packetDataLength; + int remainderLength = data.Length - (TdsEnums.HEADER_LEN + packetDataLength); + remainderPacket = new Packet + { + Buffer = new byte[dataBuffer.Length], + CurrentLength = remainderLength + }; + remainderPacket.SetCreatedBy(5); + + ReadOnlySpan remainderSource = data.Slice(TdsEnums.HEADER_LEN + packetDataLength); + Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); + remainderSource.CopyTo(remainderTarget); +#if DEBUG + // for debugging purposes fill the removed data area with an easily + // recognisable pattern so we can see if it is misused + Span removed = dataBuffer.AsSpan(TdsEnums.HEADER_LEN + packetDataLength, remainderLength); + removed.Fill(0xFF); +#endif + createdRemainderPacket = true; + recurse = SetupRemainderPacket(remainderPacket); + + consumeInputDirectly = true; + } + } + } + else + { + // either: + // 1) we took some data from the input to reconstruct the partial packet + // 2) there was less than a single packet header of data received + // in both cases we can't tell the caller to directly consume the packet + // in the input buffer, we need to construct a new remainder packet with + // the incomplete data and let the caller deal with it + remainderPacket = new Packet + { + Buffer = dataBuffer, + CurrentLength = data.Length + }; + remainderPacket.SetCreatedBy(6); + createdRemainderPacket = true; + recurse = SetupRemainderPacket(remainderPacket); + } + } + + if (consumePartialPacket && consumeInputDirectly) + { + throw new InvalidOperationException($"MultiplexPackets cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"); + } + } + + private static bool SetupRemainderPacket(Packet packet) + { + Debug.Assert(packet != null); + bool containsFullPacket = false; + if (packet.HasHeader) + { + packet.DataLength = Packet.GetDataLengthFromHeader(packet); + if (packet.HasDataLength && packet.CurrentLength >= packet.RequiredLength) + { + containsFullPacket = true; + } + } + + return containsFullPacket; + } + + + public void ProcessSniPacketCompat(PacketHandle packet, uint error) + { + if (error != 0) + { + if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) + { + // Do nothing with callback if closed or broken and error not 0 - callback can occur + // after connection has been closed. PROBLEM IN NETLIB - DESIGN FLAW. + return; + } + + AddError(_parser.ProcessSNIError(this)); + AssertValidState(); + } + else + { + uint dataSize = 0; + + uint getDataError = +#if NETFRAMEWORK + SniNativeWrapper. +#endif + SNIPacketGetData(packet, _inBuff, ref dataSize); + + if (getDataError == TdsEnums.SNI_SUCCESS) + { + if (_inBuff.Length < dataSize) + { + Debug.Assert(true, "Unexpected dataSize on Read"); + throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); + } + + _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; + _inBytesRead = (int)dataSize; + _inBytesUsed = 0; + + if (_snapshot != null) + { + _snapshot.AppendPacketData(_inBuff, _inBytesRead); + if (_snapshotStatus != SnapshotStatus.NotActive) + { + _snapshot.MoveNext(); +#if DEBUG + _snapshot.AssertCurrent(); +#endif + } + } + + SniReadStatisticsAndTracing(); + SqlClientEventSource.Log.TryAdvancedTraceBinEvent("TdsParser.ReadNetworkPacketAsyncCallback | INFO | ADV | State Object Id {0}, Packet read. In Buffer: {1}, In Bytes Read: {2}", ObjectID, _inBuff, _inBytesRead); + + AssertValidState(); + } + else + { + throw SQL.ParsingError(ParsingErrorState.ProcessSniPacketFailed); + } + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index b3cf628594..fe5ce76614 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Security; +using System.Security.Cryptography; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -64,6 +65,13 @@ public TimeoutState(int value) public int IdentityValue => _value; } + private enum SnapshotStatus + { + NotActive, + ReplayStarting, + ReplayRunning + } + private const int AttentionTimeoutSeconds = 5; // Ticks to consider a connection "good" after a successful I/O (10,000 ticks = 1 ms) @@ -215,7 +223,7 @@ public TimeoutState(int value) internal TaskCompletionSource _networkPacketTaskSource; private Timer _networkPacketTimeout; internal bool _syncOverAsync = true; - private bool _snapshotReplay; + private SnapshotStatus _snapshotStatus; private StateSnapshot _snapshot; private StateSnapshot _cachedSnapshot; internal ExecutionContext _executionContext; @@ -939,13 +947,11 @@ internal TdsOperationStatus TryProcessHeader() if (_partialHeaderBytesRead == _inputHeaderLen) { // All read + ReadOnlySpan header = _partialHeaderBuffer.AsSpan(0, TdsEnums.HEADER_LEN); _partialHeaderBytesRead = 0; - _inBytesPacket = ((int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - (int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; - - _messageStatus = _partialHeaderBuffer[1]; - _spid = _partialHeaderBuffer[TdsEnums.SPID_OFFSET] << 8 | - _partialHeaderBuffer[TdsEnums.SPID_OFFSET + 1]; + _messageStatus = Packet.GetStatusFromHeader(header); + _inBytesPacket = Packet.GetDataLengthFromHeader(header); + _spid = Packet.GetSpidFromHeader(header); SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); } @@ -981,11 +987,10 @@ internal TdsOperationStatus TryProcessHeader() else { // normal header processing... - _messageStatus = _inBuff[_inBytesUsed + 1]; - _inBytesPacket = (_inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; - _spid = _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET + 1]; + ReadOnlySpan header = _inBuff.AsSpan(_inBytesUsed, TdsEnums.HEADER_LEN); + _messageStatus = Packet.GetStatusFromHeader(header); + _inBytesPacket = Packet.GetDataLengthFromHeader(header); + _spid = Packet.GetSpidFromHeader(header); #if NET SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); #endif @@ -1127,9 +1132,7 @@ internal bool SetPacketSize(int size) // Allocate or re-allocate _inBuff. if (_inBuff == null) { - _inBuff = new byte[size]; - _inBytesRead = 0; - _inBytesUsed = 0; + SetBuffer(new byte[size], 0, 0); } else if (size != _inBuff.Length) { @@ -1139,28 +1142,24 @@ internal bool SetPacketSize(int size) // if we still have data left in the buffer we must keep that array reference and then copy into new one byte[] temp = _inBuff; - _inBuff = new byte[size]; - // copy remainder of unused data int remainingData = _inBytesRead - _inBytesUsed; - if ((temp.Length < _inBytesUsed + remainingData) || (_inBuff.Length < remainingData)) + if ((temp.Length < _inBytesUsed + remainingData) || (size < remainingData)) { - string errormessage = StringsHelper.GetString(Strings.SQL_InvalidInternalPacketSize) + ' ' + temp.Length + ", " + _inBytesUsed + ", " + remainingData + ", " + _inBuff.Length; + string errormessage = StringsHelper.GetString(Strings.SQL_InvalidInternalPacketSize) + ' ' + temp.Length + ", " + _inBytesUsed + ", " + remainingData + ", " + size; throw SQL.InvalidInternalPacketSize(errormessage); } - Buffer.BlockCopy(temp, _inBytesUsed, _inBuff, 0, remainingData); - _inBytesRead = _inBytesRead - _inBytesUsed; - _inBytesUsed = 0; + byte[] inBuff = new byte[size]; + Buffer.BlockCopy(temp, _inBytesUsed, inBuff, 0, remainingData); + SetBuffer(inBuff, 0, remainingData); AssertValidState(); } else { // buffer is empty - just create the new one that is double the size of the old one - _inBuff = new byte[size]; - _inBytesRead = 0; - _inBytesUsed = 0; + SetBuffer(new byte[size], 0, 0); } } @@ -1385,7 +1384,7 @@ internal TdsOperationStatus TryReadInt32(out int value) TdsOperationStatus result = TryReadByteArray(buffer, 4); if (result != TdsOperationStatus.Done) { - value = default; + value = 0; return result; } } @@ -1832,6 +1831,7 @@ internal int ReadPlpBytesChunk(byte[] buff, int offset, int len) // Every time you call this method increment the offset and decrease len by the value of totalBytesRead internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len, out int totalBytesRead) { + totalBytesRead = 0; int bytesRead; int bytesLeft; byte[] newbuf; @@ -1857,7 +1857,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len // If total length is known up front, allocate the whole buffer in one shot instead of realloc'ing and copying over each time if (buff == null && _longlen != TdsEnums.SQL_PLP_UNKNOWNLEN) { - if (_snapshot != null) + if (_snapshot != null && _snapshotStatus != SnapshotStatus.NotActive) { // if there is a snapshot and it contains a stored plp buffer take it // and try to use it if it is the right length @@ -1892,9 +1892,6 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len { buff = new byte[_longlenleft]; } - - totalBytesRead = 0; - while (bytesLeft > 0) { int bytesToRead = (int)Math.Min(_longlenleft, (ulong)bytesLeft); @@ -2000,13 +1997,17 @@ internal TdsOperationStatus TryReadNetworkPacket() TdsOperationStatus result = TdsOperationStatus.InvalidData; if (_snapshot != null) { - if (_snapshotReplay) + if (_snapshotStatus != SnapshotStatus.NotActive) { #if DEBUG - // in debug builds stack traces contain line numbers so if we want to be - // able to compare the stack traces they must all be created in the same - // location in the code - string stackTrace = Environment.StackTrace; + string stackTrace = null; + if (s_checkNetworkPacketRetryStacks) + { + // in debug builds stack traces contain line numbers so if we want to be + // able to compare the stack traces they must all be created in the same + // location in the code + stackTrace = Environment.StackTrace; + } #endif if (_snapshot.MoveNext()) { @@ -2018,24 +2019,39 @@ internal TdsOperationStatus TryReadNetworkPacket() #endif return TdsOperationStatus.Done; } -#if DEBUG else { +#if DEBUG if (s_checkNetworkPacketRetryStacks) { _lastStack = stackTrace; } - } #endif + } } // previous buffer is in snapshot _inBuff = new byte[_inBuff.Length]; + result = TdsOperationStatus.NeedMoreData; + } + + if (result == TdsOperationStatus.InvalidData && PartialPacket != null && !PartialPacket.ContainsCompletePacket) + { + result = TdsOperationStatus.NeedMoreData; } if (_syncOverAsync) { ReadSniSyncOverAsync(); + while (_inBytesRead == 0) + { + // a partial packet must have taken the packet data so we + // need to read more data to complete the packet, but we + // can't return NeedMoreData in sync mode so we have to + // spin fetching more data here until we have something + // that the caller can read + ReadSniSyncOverAsync(); + } return TdsOperationStatus.Done; } @@ -2070,7 +2086,7 @@ internal void ReadSniSyncOverAsync() } PacketHandle readPacket = default; - + bool readFromNetwork = !PartialPacketContainsCompletePacket(); uint error; RuntimeHelpers.PrepareConstrainedRegions(); @@ -2082,7 +2098,14 @@ internal void ReadSniSyncOverAsync() Interlocked.Increment(ref _readingCount); shouldDecrement = true; - readPacket = ReadSyncOverAsync(GetTimeoutRemaining(), out error); + if (readFromNetwork) + { + readPacket = ReadSyncOverAsync(GetTimeoutRemaining(), out error); + } + else + { + error = TdsEnums.SNI_SUCCESS; + } Interlocked.Decrement(ref _readingCount); shouldDecrement = false; @@ -2093,11 +2116,15 @@ internal void ReadSniSyncOverAsync() } if (TdsEnums.SNI_SUCCESS == error) - { // Success - process results! + { + // Success - process results! - Debug.Assert(!IsPacketEmpty(readPacket), "ReadNetworkPacket cannot be null in synchronous operation!"); + if (readFromNetwork) + { + Debug.Assert(!IsPacketEmpty(readPacket), "ReadNetworkPacket cannot be null in synchronous operation!"); + } - ProcessSniPacket(readPacket, 0); + ProcessSniPacket(readPacket, TdsEnums.SNI_SUCCESS); #if DEBUG if (s_forcePendingReadsToWaitForUser) { @@ -2109,9 +2136,12 @@ internal void ReadSniSyncOverAsync() #endif } else - { // Failure! - - Debug.Assert(!IsValidPacket(readPacket), "unexpected readPacket without corresponding SNIPacketRelease"); + { + // Failure! + if (readFromNetwork) + { + Debug.Assert(!IsValidPacket(readPacket), "unexpected readPacket without corresponding SNIPacketRelease"); + } ReadSniError(this, error); } @@ -2123,9 +2153,12 @@ internal void ReadSniSyncOverAsync() Interlocked.Decrement(ref _readingCount); } - if (!IsPacketEmpty(readPacket)) + if (readFromNetwork) { - ReleasePacket(readPacket); + if (!IsPacketEmpty(readPacket)) + { + ReleasePacket(readPacket); + } } AssertValidState(); @@ -2349,6 +2382,7 @@ internal void ReadSni(TaskCompletionSource completion) PacketHandle readPacket = default; uint error = 0; + bool readFromNetwork = true; RuntimeHelpers.PrepareConstrainedRegions(); try @@ -2392,21 +2426,35 @@ internal void ReadSni(TaskCompletionSource completion) finally { Interlocked.Increment(ref _readingCount); - - handle = SessionHandle; - if (!handle.IsNull) + try { - IncrementPendingCallbacks(); + handle = SessionHandle; + + readFromNetwork = !PartialPacketContainsCompletePacket(); + if (readFromNetwork) + { + if (!handle.IsNull) + { + IncrementPendingCallbacks(); - readPacket = ReadAsync(handle, out error); + readPacket = ReadAsync(handle, out error); - if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error)) + if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error)) + { + DecrementPendingCallbacks(false); // Failure - we won't receive callback! + } + } + } + else { - DecrementPendingCallbacks(false); // Failure - we won't receive callback! + readPacket = default; + error = TdsEnums.SNI_SUCCESS; } } - - Interlocked.Decrement(ref _readingCount); + finally + { + Interlocked.Decrement(ref _readingCount); + } } if (handle.IsNull) @@ -2416,12 +2464,12 @@ internal void ReadSni(TaskCompletionSource completion) if (TdsEnums.SNI_SUCCESS == error) { // Success - process results! - Debug.Assert(IsValidPacket(readPacket), "ReadNetworkPacket should not have been null on this async operation!"); + Debug.Assert(!readFromNetwork || IsValidPacket(readPacket) , "ReadNetworkPacket should not have been null on this async operation!"); // Evaluate this condition for MANAGED_SNI. This may not be needed because the network call is happening Async and only the callback can receive a success. ReadAsyncCallback(IntPtr.Zero, readPacket, 0); // Only release packet for Managed SNI as for Native SNI packet is released in finally block. - if (TdsParserStateObjectFactory.UseManagedSNI && !IsPacketEmpty(readPacket)) + if (TdsParserStateObjectFactory.UseManagedSNI && readFromNetwork && !IsPacketEmpty(readPacket)) { ReleasePacket(readPacket); } @@ -2459,7 +2507,7 @@ internal void ReadSni(TaskCompletionSource completion) { if (!TdsParserStateObjectFactory.UseManagedSNI) { - if (!IsPacketEmpty(readPacket)) + if (readFromNetwork && !IsPacketEmpty(readPacket)) { // Be sure to release packet, otherwise it will be leaked by native. ReleasePacket(readPacket); @@ -2526,44 +2574,54 @@ internal bool IsConnectionAlive(bool throwOnException) return isAlive; } - /* - - // leave this in. comes handy if you have to do Console.WriteLine style debugging ;) - private void DumpBuffer() { - Console.WriteLine("dumping buffer"); - Console.WriteLine("_inBytesRead = {0}", _inBytesRead); - Console.WriteLine("_inBytesUsed = {0}", _inBytesUsed); + /// + /// Creates a human-readable message containing the _inBytesRead, _inBytesUsed counters + /// and the used and unused portions of the _inBuff array to help diagnosing problems with + /// packet parsing. + /// + /// + internal string DumpBuffer() + { + StringBuilder buffer = new StringBuilder(128); + buffer.AppendLine("dumping buffer"); + buffer.AppendFormat("_inBytesRead = {0}", _inBytesRead).AppendLine(); + buffer.AppendFormat("_inBytesUsed = {0}", _inBytesUsed).AppendLine(); int cc = 0; // character counter int i; - Console.WriteLine("used buffer:"); - for (i=0; i< _inBytesUsed; i++) { + buffer.AppendLine("used buffer:"); + for (i=0; i< _inBytesUsed; i++) + { if (cc==16) { - Console.WriteLine(); + buffer.AppendLine(); cc = 0; } - Console.Write("{0,-2:X2} ", _inBuff[i]); + buffer.AppendFormat("{0,-2:X2} ", _inBuff[i]); cc++; } - if (cc>0) { - Console.WriteLine(); + if (cc>0) + { + buffer.AppendLine(); } cc = 0; - Console.WriteLine("unused buffer:"); - for (i=_inBytesUsed; i<_inBytesRead; i++) { - if (cc==16) { - Console.WriteLine(); + buffer.AppendLine("unused buffer:"); + for (i=_inBytesUsed; i<_inBytesRead; i++) + { + if (cc==16) + { + buffer.AppendLine(); cc = 0; } - Console.Write("{0,-2:X2} ", _inBuff[i]); + buffer.AppendFormat("{0,-2:X2} ", _inBuff[i]); cc++; } - if (cc>0) { - Console.WriteLine(); + if (cc>0) + { + buffer.AppendLine(); } + return buffer.ToString(); } - */ - + internal void SetSnapshot() { StateSnapshot snapshot = _snapshot; @@ -2577,7 +2635,7 @@ internal void SetSnapshot() } _snapshot = snapshot; _snapshot.CaptureAsStart(this); - _snapshotReplay = false; + _snapshotStatus = SnapshotStatus.NotActive; } internal void ResetSnapshot() @@ -2589,7 +2647,7 @@ internal void ResetSnapshot() snapshot.Clear(); Interlocked.CompareExchange(ref _cachedSnapshot, snapshot, null); } - _snapshotReplay = false; + _snapshotStatus = SnapshotStatus.NotActive; } sealed partial class StateSnapshot @@ -2611,15 +2669,21 @@ internal void Clear() PrevPacket.NextPacket = null; PrevPacket = null; } - SetDebugStackInternal(null); - SetDebugPacketIdInternal(0); + SetDebugStackImpl(null); + SetDebugPacketId(0); + SetDebugDataHash(); } - internal void SetDebugStack(string value) => SetDebugStackInternal(value); - internal void SetDebugPacketId(int value) => SetDebugPacketIdInternal(value); + internal void SetDebugStack(string value) => SetDebugStackImpl(value); + internal void SetDebugPacketId(int value) => SetDebugPacketIdImpl(value); + internal void SetDebugDataHash() => SetDebugDataHashImpl(); + + internal void CheckDebugDataHash() => CheckDebugDataHashImpl(); - partial void SetDebugStackInternal(string value); - partial void SetDebugPacketIdInternal(int value); + partial void SetDebugStackImpl(string value); + partial void SetDebugPacketIdImpl(int value); + partial void SetDebugDataHashImpl(); + partial void CheckDebugDataHashImpl(); } #if DEBUG @@ -2641,84 +2705,184 @@ public PacketDataDebugView(PacketData data) _data = data; } - [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] - public PacketData[] Items + public string Type { + + get + { + if (_data != null && _data.Buffer!=null) + { + switch (_data.Buffer[0]) + { + case 1: return nameof(TdsEnums.MT_SQL); + case 2: return nameof(TdsEnums.MT_LOGIN); + case 3: return nameof(TdsEnums.MT_RPC); + case 4: return nameof(TdsEnums.MT_TOKENS); + case 5: return nameof(TdsEnums.MT_BINARY); + case 6: return nameof(TdsEnums.MT_ATTN); + case 7: return nameof(TdsEnums.MT_BULK); + case 8: return nameof(TdsEnums.MT_FEDAUTH); + case 9: return nameof(TdsEnums.MT_CLOSE); + case 10: return nameof(TdsEnums.MT_ERROR); + case 11: return nameof(TdsEnums.MT_ACK); + case 12: return nameof(TdsEnums.MT_ECHO); + case 13: return nameof(TdsEnums.MT_LOGOUT); + case 14: return nameof(TdsEnums.MT_TRANS); + case 15: return nameof(TdsEnums.MT_OLEDB); + case 16: return nameof(TdsEnums.MT_LOGIN7); + case 17: return nameof(TdsEnums.MT_SSPI); + case 18: return nameof(TdsEnums.MT_PRELOGIN); + default: return _data.Buffer[0].ToString("X2"); + } + } + return ""; + } + } + + public string Status { get { - PacketData[] items = Array.Empty(); - if (_data != null) + if (_data != null && _data.Buffer != null && _data.Buffer.Length > 1) { - int count = 0; - for (PacketData current = _data; current != null; current = current?.NextPacket) + int status = Packet.GetStatusFromHeader(_data.Buffer); + StringBuilder buffer = new StringBuilder(10); + + if ((status & TdsEnums.ST_EOM) == TdsEnums.ST_EOM) + { + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_EOM)); + } + if ((status & TdsEnums.ST_AACK) == TdsEnums.ST_AACK) { - count++; + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_AACK)); } - items = new PacketData[count]; - int index = 0; - for (PacketData current = _data; current != null; current = current?.NextPacket, index++) + if ((status & TdsEnums.ST_BATCH) == TdsEnums.ST_BATCH) { - items[index] = current; + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_BATCH)); } + if ((status & TdsEnums.ST_RESET_CONNECTION) == TdsEnums.ST_RESET_CONNECTION) + { + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_RESET_CONNECTION)); + } + if ((status & TdsEnums.ST_RESET_CONNECTION_PRESERVE_TRANSACTION) == TdsEnums.ST_RESET_CONNECTION_PRESERVE_TRANSACTION) + { + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_RESET_CONNECTION_PRESERVE_TRANSACTION)); + } + + return buffer.ToString(); } - return items; + + return ""; } } + + public int Length => _data.DataLength; + + public int Spid => _data.SPID; + + public int PacketID => _data.PacketID; + + public ReadOnlySpan HeaderBytes => _data.GetHeaderSpan(); + + public ReadOnlySpan Data => _data.Buffer.AsSpan(TdsEnums.HEADER_LEN); + + public PacketData NextPacket => _data.NextPacket; + public PacketData PrevPacket => _data.PrevPacket; } - public int PacketId; + public int DebugPacketId; public string Stack; + public byte[] Hash; + + public int PacketID => Packet.GetIDFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); + + public int SPID => Packet.GetSpidFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); + + public bool IsEOM => Packet.GetIsEOMFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); + + public int DataLength => Packet.GetDataLengthFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); - partial void SetDebugStackInternal(string value) => Stack = value; + public ReadOnlySpan GetHeaderSpan() => Buffer.AsSpan(0, TdsEnums.HEADER_LEN); - partial void SetDebugPacketIdInternal(int value) => PacketId = value; + partial void SetDebugStackImpl(string value) => Stack = value; + partial void SetDebugPacketIdImpl(int value) => DebugPacketId = value; - public override string ToString() + partial void SetDebugDataHashImpl() { - //return $"{PacketId}: [{Buffer.Length}] ( {GetPacketDataOffset():D4}, {GetPacketTotalSize():D4} ) {(NextPacket != null ? @"->" : string.Empty)}"; - string byteString = null; - if (Buffer != null && Buffer.Length >= 12) + if (Buffer != null) { - ReadOnlySpan bytes = Buffer.AsSpan(0, 12); - StringBuilder buffer = new StringBuilder(12 * 3 + 10); - buffer.Append('{'); - for (int index = 0; index < bytes.Length; index++) + using (MD5 hasher = MD5.Create()) { - buffer.AppendFormat("{0:X2}", bytes[index]); - buffer.Append(", "); + Hash = hasher.ComputeHash(Buffer, 0, Read); } - buffer.Append("..."); - buffer.Append('}'); - byteString = buffer.ToString(); } - return $"{PacketId}: [{Read}] {byteString} {(NextPacket != null ? @"->" : string.Empty)}"; + else + { + Hash = null; + } + } - } -#endif - private sealed class PLPData - { - public readonly ulong SnapshotLongLen; - public readonly ulong SnapshotLongLenLeft; - - public PLPData(ulong snapshotLongLen, ulong snapshotLongLenLeft) + partial void CheckDebugDataHashImpl() { - SnapshotLongLen = snapshotLongLen; - SnapshotLongLenLeft = snapshotLongLenLeft; + if (Hash == null) + { + if (Buffer != null && Read > 0) + { + throw new InvalidOperationException("Packet modification detected. Hash is null but packet contains non-null buffer"); + } + } + else + { + byte[] checkHash = null; + using (MD5 hasher = MD5.Create()) + { + checkHash = hasher.ComputeHash(Buffer, 0, Read); + } + + for (int index = 0; index < Hash.Length; index++) + { + if (Hash[index] != checkHash[index]) + { + throw new InvalidOperationException("Packet modification detected. Hash from packet creation does not match hash from packet check"); + } + } + } } } +#endif private sealed class StateObjectData { private int _inBytesUsed; private int _inBytesPacket; - private PLPData _plpData; private byte _messageStatus; internal NullBitmap _nullBitmapInfo; private _SqlMetaDataSet _cleanupMetaData; internal _SqlMetaDataSetCollection _cleanupAltMetaDataSetArray; private SnapshottedStateFlags _state; + private ulong _longLen; + private ulong _longLenLeft; internal void Capture(TdsParserStateObject stateObj, bool trackStack = true) { @@ -2726,10 +2890,8 @@ internal void Capture(TdsParserStateObject stateObj, bool trackStack = true) _inBytesPacket = stateObj._inBytesPacket; _messageStatus = stateObj._messageStatus; _nullBitmapInfo = stateObj._nullBitmapInfo; // _nullBitmapInfo must be cloned before it is updated - if (stateObj._longlen != 0 || stateObj._longlenleft != 0) - { - _plpData = new PLPData(stateObj._longlen, stateObj._longlenleft); - } + _longLen = stateObj._longlen; + _longLenLeft = stateObj._longlenleft; _cleanupMetaData = stateObj._cleanupMetaData; _cleanupAltMetaDataSetArray = stateObj._cleanupAltMetaDataSetArray; // _cleanupAltMetaDataSetArray must be cloned before it is updated _state = stateObj._snapshottedState; @@ -2749,7 +2911,8 @@ internal void Clear(TdsParserStateObject stateObj, bool trackStack = true) _inBytesPacket = 0; _messageStatus = 0; _nullBitmapInfo = default; - _plpData = null; + _longLen = 0; + _longLenLeft = 0; _cleanupMetaData = null; _cleanupAltMetaDataSetArray = null; _state = SnapshottedStateFlags.None; @@ -2782,13 +2945,13 @@ internal void Restore(TdsParserStateObject stateObj) //else _stateObj._hasOpenResult is already == _snapshotHasOpenResult stateObj._snapshottedState = _state; + // reset plp state + stateObj._longlen = _longLen; + stateObj._longlenleft = _longLenLeft; + // Reset partially read state (these only need to be maintained if doing async without snapshot) stateObj._bTmpRead = 0; stateObj._partialHeaderBytesRead = 0; - - // reset plp state - stateObj._longlen = _plpData?.SnapshotLongLen ?? 0; - stateObj._longlenleft = _plpData?.SnapshotLongLenLeft ?? 0; } } @@ -2843,6 +3006,7 @@ internal void CheckStack(string trace) } } #endif + internal void CloneNullBitmapInfo() { if (_stateObj._nullBitmapInfo.ReferenceEquals(_replayStateData?._nullBitmapInfo ?? default)) @@ -2863,10 +3027,19 @@ internal void AppendPacketData(byte[] buffer, int read) { Debug.Assert(buffer != null, "packet data cannot be null"); Debug.Assert(read >= TdsEnums.HEADER_LEN, "minimum packet length is TdsEnums.HEADER_LEN"); + Debug.Assert(TdsEnums.HEADER_LEN + Packet.GetDataLengthFromHeader(buffer) == read, "partially read packets cannot be appended to the snapshot"); #if DEBUG for (PacketData current = _firstPacket; current != null; current = current.NextPacket) { - Debug.Assert(!ReferenceEquals(current.Buffer, buffer)); + if (ReferenceEquals(current.Buffer, buffer)) + { + // multiple packets are permitted to be in the same buffer because of partial packets + // but their contents cannot overlap + if ((current.Read + current.DataLength) > read) + { + Debug.Fail("duplicate or overlapping packet appended to snapshot"); + } + } } #endif PacketData packetData = _sparePacket; @@ -2883,6 +3056,7 @@ internal void AppendPacketData(byte[] buffer, int read) #if DEBUG packetData.SetDebugStack(_stateObj._lastStack); packetData.SetDebugPacketId(Interlocked.Increment(ref _packetCounter)); + packetData.SetDebugDataHash(); #endif if (_firstPacket is null) { @@ -2899,10 +3073,12 @@ internal void AppendPacketData(byte[] buffer, int read) internal bool MoveNext() { bool retval = false; + SnapshotStatus moveToMode = SnapshotStatus.ReplayRunning; bool moved = false; if (_current == null) { _current = _firstPacket; + moveToMode = SnapshotStatus.ReplayStarting; moved = true; } else if (_current.NextPacket != null) @@ -2913,10 +3089,9 @@ internal bool MoveNext() if (moved) { - _stateObj._inBuff = _current.Buffer; - _stateObj._inBytesUsed = 0; - _stateObj._inBytesRead = _current.Read; - _stateObj._snapshotReplay = true; + _stateObj.SetBuffer(_current.Buffer, 0, _current.Read); + _current.CheckDebugDataHash(); + _stateObj._snapshotStatus = moveToMode; retval = true; } @@ -2941,7 +3116,6 @@ internal void CaptureAsStart(TdsParserStateObject stateObj) _stateObj = stateObj; _replayStateData ??= new StateObjectData(); _replayStateData.Capture(stateObj); - #if DEBUG _rollingPend = 0; _rollingPendCount = 0; @@ -2976,6 +3150,7 @@ private void ClearState() _rollingPend = 0; _rollingPendCount = 0; _stateObj._lastStack = null; + _packetCounter = 0; #endif _stateObj = null; } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs index e6a6b9c73f..295a354349 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalAppContextSwitchesTests.cs @@ -14,6 +14,7 @@ public class LocalAppContextSwitchesTests [InlineData("LegacyRowVersionNullBehavior", false)] [InlineData("MakeReadAsyncBlocking", false)] [InlineData("UseMinimumLoginTimeout", true)] + [InlineData("UseCompatibilityProcessSni", false)] public void DefaultSwitchValue(string property, bool expectedDefaultValue) { var switchesType = typeof(SqlCommand).Assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj index 6d8603beab..8352e4d640 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj @@ -28,6 +28,7 @@ + @@ -65,10 +66,13 @@ + + + @@ -95,6 +99,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs new file mode 100644 index 0000000000..12baf6f2e9 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -0,0 +1,724 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers.Binary; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; +using Xunit; + +namespace Microsoft.Data.SqlClient.Tests +{ + public class MultiplexerTests + { + public static bool IsUsingCompatibilityProcessSni + { + get + { + if (AppContext.TryGetSwitch(@"Switch.Microsoft.Data.SqlClient.UseCompatibilityProcessSni", out bool foundValue)) + { + return foundValue; + } + return false; + } + } + + public static bool IsUsingModernProcessSni => !IsUsingCompatibilityProcessSni; + + [ExcludeFromCodeCoverage] + public static IEnumerable IsAsync() + { + yield return new object[] { false }; + yield return new object[] { true }; + } + + [ExcludeFromCodeCoverage] + public static IEnumerable OnlyAsync() { yield return new object[] { true }; } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void PassThroughSinglePacket(bool isAsync) + { + int dataSize = 20; + var a = CreatePacket(dataSize, 0xF); + List input = new List { a }; + List expected = new List { a }; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void PassThroughMultiplePacket(bool isAsync) + { + int dataSize = 40; + List input = CreatePackets(dataSize, 5, 6, 7, 8); + List expected = input; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void PassThroughMultiplePacketWithShortEnd(bool isAsync) + { + int dataSize = 40; + List input = CreatePackets((dataSize, 20), 5, 6, 7, 8); + List expected = input; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void ReconstructSinglePacket(bool isAsync) + { + int dataSize = 4; + var a = CreatePacket(dataSize, 0xF); + List input = SplitPacket(a, 6); + List expected = new List { a }; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void Reconstruct2Packets_Part_PartFull(bool isAsync) + { + int dataSize = 4; + var expected = CreatePackets(dataSize, 0xAA, 0xBB); + + var input = SplitPackets(dataSize, expected, + 6, // partial first packet + (6 + 6), // end of packet 0, start of packet 1 + 6 // end of packet 1 + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void Reconstruct2Packets_Full_FullPart_Part(bool isAsync) + { + int dataSize = 30; + var expected = new List { CreatePacket(30, 5), CreatePacket(10, 6), CreatePacket(30, 7) }; + + var input = SplitPackets(38, expected, + (8 + 30), // full + (8 + 10) + (8 + 12), // full, part next + 18 // part end + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void ReconstructMultiplePacketSequence(bool isAsync) + { + int dataSize = 40; + List expected = CreatePackets(dataSize, 5, 6, 7, 8); + List input = SplitPackets(dataSize, expected, + (8 + 40), + (8 + 23), + (17) + (8 + 23), + (17) + (8 + 23), + (17) + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void ReconstructMultiplePacketSequenceWithShortEnd(bool isAsync) + { + int dataSize = 40; + List expected = CreatePackets((dataSize, 20), 5, 6, 7, 8); + List input = SplitPackets(dataSize, expected, + (8 + 40), + (8 + 23), + (17) + (8 + 23), + (17) + (8 + 20) + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalTheory(nameof(IsUsingModernProcessSni)), MemberData(nameof(IsAsync))] + public static void Reconstruct3Packets_PartPartPart(bool isAsync) + { + int dataSize = 62; + + var expected = new List { CreatePacket(26, 5), CreatePacket(10, 6), CreatePacket(10, 7) }; + + var input = SplitPackets(70, expected, + (8 + 26) + (8 + 10) + (8 + 10) // = 70: full, full, part + ); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalFact(nameof(IsUsingModernProcessSni))] + public static void TrailingPartialPacketInSnapshotNotDuplicated() + { + int dataSize = 120; + + var expected = new List { CreatePacket(120, 5), CreatePacket(90, 6), CreatePacket(13, 7), }; + + var input = SplitPackets(120, expected, + (8 + 120), + (8 + 90) + (8 + 13) + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(true, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ConditionalFact(nameof(IsUsingModernProcessSni))] + public static void BetweenAsyncAttentionPacket() + { + int dataSize = 120; + var normalPacket = CreatePacket(120, 5); + var attentionPacket = CreatePacket(13, 6); + var input = new List { normalPacket, attentionPacket }; + + var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync: true); + + for (int index = 0; index < input.Count; index++) + { + stateObject.Current = input[index]; + stateObject.ProcessSniPacket(default, 0); + } + + Assert.NotNull(stateObject._inBuff); + Assert.Equal(21, stateObject._inBytesRead); + Assert.Equal(0, stateObject._inBytesUsed); + Assert.NotNull(stateObject._snapshot); + Assert.NotNull(stateObject._snapshot.List); + Assert.Equal(2, stateObject._snapshot.List.Count); + + } + + [ConditionalFact(nameof(IsUsingModernProcessSni))] + public static void MultipleFullPacketsInRemainderAreSplitCorrectly() + { + int dataSize = 800 - TdsEnums.HEADER_LEN; + List expected = new List + { + CreatePacket(dataSize, 5), CreatePacket(80, 6), CreatePacket(21, 7) + }; + + + List input = SplitPacket(CombinePackets(expected), 700); + + var stateObject = new TdsParserStateObject(input, dataSize, isAsync: false); + + var output = MultiplexPacketList(false, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [ExcludeFromCodeCoverage] + private static List MultiplexPacketList(bool isAsync, int dataSize, List input) + { + var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync); + var output = new List(); + + for (int index = 0; index < input.Count; index++) + { + stateObject.Current = input[index]; + + stateObject.ProcessSniPacket(default, 0); + + if (stateObject._inBytesRead > 0) + { + if ( + stateObject._inBytesRead < TdsEnums.HEADER_LEN + || + stateObject._inBytesRead != (TdsEnums.HEADER_LEN + + Packet.GetDataLengthFromHeader( + stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) + ) + { + Assert.Fail("incomplete packet exposed after call to ProcessSniPacket"); + } + + if (!isAsync) + { + output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, + stateObject._inBytesRead)); + } + } + } + + + if (!isAsync) + { + while (stateObject.PartialPacket != null) + { + stateObject.Current = default; + + stateObject.ProcessSniPacket(default, 0); + + if (stateObject._inBytesRead > 0) + { + if ( + stateObject._inBytesRead < TdsEnums.HEADER_LEN + || + stateObject._inBytesRead != (TdsEnums.HEADER_LEN + + Packet.GetDataLengthFromHeader( + stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) + ) + { + Assert.Fail( + "incomplete packet exposed after call to ProcessSniPacket with usePartialPacket"); + } + + output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, + stateObject._inBytesRead)); + } + } + + } + else + { + output = stateObject._snapshot.List; + } + + return output; + } + + [ExcludeFromCodeCoverage] + private static void ComparePacketLists(int dataSize, List expected, List output) + { + Assert.NotNull(expected); + Assert.NotNull(output); + Assert.Equal(expected.Count, output.Count); + + for (int index = 0; index < expected.Count; index++) + { + var a = expected[index]; + var b = output[index]; + + var compare = a.AsSpan().SequenceCompareTo(b.AsSpan()); + + if (compare != 0) + { + Assert.Fail($"expected data does not match output data at packet index {index}"); + } + } + } + + [ExcludeFromCodeCoverage] + public static PacketData CreatePacket(int dataSize, byte dataValue, int startOffset = 0, int endPadding = 0) + { + byte[] buffer = new byte[startOffset + TdsEnums.HEADER_LEN + dataSize + endPadding]; + Span packet = buffer.AsSpan(startOffset, TdsEnums.HEADER_LEN + dataSize); + WritePacket(packet, dataSize, dataValue, 1); + return new PacketData(buffer, startOffset, buffer.Length - endPadding); + } + + [ExcludeFromCodeCoverage] + public static List CreatePackets(DataSize sizes, params byte[] dataValues) + { + int count = dataValues.Length; + List list = new List(count); + + for (byte index = 0; index < count; index++) + { + int dataSize = sizes.GetSize(index == dataValues.Length - 1); + int packetSize = TdsEnums.HEADER_LEN + dataSize; + byte[] array = new byte[packetSize]; + WritePacket(array, dataSize, dataValues[index], index); + list.Add(new PacketData(array, 0, packetSize)); + } + + return list; + } + + [ExcludeFromCodeCoverage] + private static void WritePacket(Span buffer, int dataSize, byte dataValue, byte id) + { + Span header = buffer.Slice(0, TdsEnums.HEADER_LEN); + header[0] = 4; // Type, 4 - Raw Data + header[1] = 0; // Status, 0 - normal message + BinaryPrimitives.TryWriteInt16BigEndian(header.Slice(TdsEnums.HEADER_LEN_FIELD_OFFSET, 2), + (short)(TdsEnums.HEADER_LEN + dataSize)); // total length + BinaryPrimitives.TryWriteInt16BigEndian(header.Slice(TdsEnums.SPID_OFFSET, 2), short.MaxValue); // SPID + header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 4] = id; // PacketID + header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 5] = 0; // Window + + Span data = buffer.Slice(TdsEnums.HEADER_LEN, dataSize); + data.Fill(dataValue); + } + + [ExcludeFromCodeCoverage] + public static List SplitPacket(PacketData packet, int length) + { + List list = new List(2); + while (packet.Length > length) + { + list.Add(new PacketData(packet.Array, packet.Start, length)); + packet = new PacketData(packet.Array, packet.Start + length, packet.Length - length); + } + + if (packet.Length > 0) + { + list.Add(packet); + } + + return list; + } + + [ExcludeFromCodeCoverage] + public static List SplitPackets(int dataSize, List packets, params int[] lengths) + { + List list = new List(lengths.Length); + int packetSize = TdsEnums.HEADER_LEN + dataSize; + byte[][] arrays = new byte[lengths.Length][]; + for (int index = 0; index < lengths.Length; index++) + { + if (lengths[index] > packetSize) + { + throw new ArgumentOutOfRangeException( + $"segment size of an individual part cannot exceed the packet buffer size of the state object, max packet size: {packetSize}, supplied length: {lengths[index]}, at index: {index}"); + } + + arrays[index] = new byte[lengths[index]]; + } + + int targetOffset = 0; + int targetIndex = 0; + + int sourceOffset = 0; + int sourceIndex = 0; + + + do + { + Span targetSpan = Span.Empty; + if (targetOffset < arrays[targetIndex].Length) + { + targetSpan = arrays[targetIndex].AsSpan(targetOffset); + } + else + { + targetIndex += 1; + targetOffset = 0; + continue; + } + + Span sourceSpan = Span.Empty; + if (sourceOffset < packets[sourceIndex].Length) + { + sourceSpan = packets[sourceIndex].AsSpan(sourceOffset); + } + else + { + sourceIndex += 1; + sourceOffset = 0; + continue; + } + + int copy = Math.Min(targetSpan.Length, sourceSpan.Length); + if (copy > 0) + { + targetOffset += copy; + sourceOffset += copy; + sourceSpan.Slice(0, copy).CopyTo(targetSpan.Slice(0, copy)); + } + } while (sourceIndex < packets.Count && targetIndex < arrays.Length); + + foreach (var array in arrays) + { + list.Add(new PacketData(array, 0, array.Length)); + } + + return list; + } + + [ExcludeFromCodeCoverage] + public static PacketData CombinePackets(List packets) + { + int totalLength = SumPacketLengths(packets); + byte[] buffer = new byte[totalLength]; + int offset = 0; + for (int index = 0; index < packets.Count; index++) + { + PacketData packet = packets[index]; + Array.Copy(packet.Array, packet.Start, buffer, offset, packet.Length); + offset += packet.Length; + } + + return new PacketData(buffer, 0, totalLength); + } + + [ExcludeFromCodeCoverage] + public static int PacketSizeFromDataSize(int dataSize) => TdsEnums.HEADER_LEN + dataSize; + + [ExcludeFromCodeCoverage] + public static int DataSizeFromPacketSize(int packetSize) => packetSize - TdsEnums.HEADER_LEN; + + [ExcludeFromCodeCoverage] + public static int SumPacketLengths(List list) + { + int total = 0; + for (int index = 0; index < list.Count; index++) + { + total += list[index].Length; + } + return total; + } + + [ExcludeFromCodeCoverage] + public static List LoadPacketBinFiles(string directoryName) + { + // expects a set of files contained in a directory with the name + // formatted as packet_{number}_{dataSize}.bin each packet will be + // loaded into a byte[] + + string[] files = Directory.GetFiles(directoryName, "packet*.bin", SearchOption.TopDirectoryOnly); + SortedDictionary packets = new SortedDictionary(); + foreach (string file in files) + { + Match match = Regex.Match(file, @"packet_(?\d+)_(?\d+)\.bin"); + int number = int.Parse(match.Groups["number"].Value); + int size = int.Parse(match.Groups["size"].Value); + packets.Add( + number, + new PacketData( + System.IO.File.ReadAllBytes(file), + 0, + size + ) + ); + } + + return packets.Values.ToList(); + } + + [ExcludeFromCodeCoverage] + public static List NaiveReconstructPacketStream(List input) + { + int dataSize = input[0].Array.Length; + List output = new List(input.Count); + + byte[] currentBuffer = new byte[dataSize]; + int currentBufferOffset = 0; + + foreach (PacketData inputPacket in input) + { + int inputPacketOffset = 0; + while (inputPacketOffset < inputPacket.Length) + { + if (currentBufferOffset < dataSize) + { + int requiredCount = dataSize - currentBufferOffset; + int availableCount = inputPacket.Length - inputPacketOffset; + int copyCount = Math.Min(requiredCount, availableCount); + ReadOnlySpan copyFrom = inputPacket.Array.AsSpan(inputPacketOffset, copyCount); + Span copyTo = currentBuffer.AsSpan(currentBufferOffset, copyCount); + copyFrom.CopyTo(copyTo); + currentBufferOffset += copyCount; + inputPacketOffset += copyCount; + } + + if (currentBufferOffset == dataSize) + { + output.Add(new PacketData(currentBuffer, 0, dataSize)); + currentBufferOffset = 0; + currentBuffer = new byte[dataSize]; + } + } + } + + if (currentBufferOffset > 0) + { + output.Add(new PacketData(currentBuffer, 0, currentBufferOffset)); + } + + for (int index = 0; index < output.Count; index++) + { + PacketData packet = output[index]; + int expectedLength = 8 + Packet.GetDataLengthFromHeader(packet.Array); + if (expectedLength != packet.Length) + { + if (index != output.Count - 1) + { + throw new InvalidOperationException( + "non-terminal packet has a length mismatch between the packet header and amount of data available"); + } + else + { + byte[] remainder = new byte[dataSize]; + int remainderSize = packet.Length - expectedLength; + Span copyFrom = packet.Array.AsSpan(expectedLength, remainderSize); + Span copyTo = remainder.AsSpan(0, remainderSize); + copyFrom.CopyTo(copyTo); + copyFrom.Fill(0); + + PacketData replacementPacket = new PacketData(packet.Array, 0, expectedLength); + PacketData additionalPacket = new PacketData(remainder, 0, remainderSize); + output[index] = replacementPacket; + output.Add(additionalPacket); + } + } + } + + return output; + } + } + + [ExcludeFromCodeCoverage] + [DebuggerDisplay("{ToDebugString(),nq}")] + public readonly struct PacketData + { + public readonly byte[] Array; + public readonly int Start; + public readonly int Length; + + public PacketData(byte[] array, int start, int length) + { + Array = array; + Start = start; + Length = length; + } + + public Span AsSpan() + { + return Array == null ? Span.Empty : Array.AsSpan(Start, Length); + } + + public Span AsSpan(int start) + { + Span span = AsSpan(); + return span.Slice(start); + } + + public static PacketData Copy(byte[] array, int start, int length) + { + byte[] newArray = null; + if (array != null) + { + newArray = new byte[array.Length]; + Buffer.BlockCopy(array, start, newArray, start, length); + } + + return new PacketData(newArray, start, length); + } + + [ExcludeFromCodeCoverage] + public string ToDebugString() + { + StringBuilder buffer = new StringBuilder(128); + buffer.Append(Length); + + if (Array != null && Array.Length > 0) + { + if (Array.Length != Length) + { + buffer.AppendFormat(" (arr: {0})", Array.Length); + } + + buffer.Append(": {"); + buffer.AppendFormat("{0:D2}", Array[0]); + + int max = Math.Min(32, Array.Length); + for (int index = 1; index < max; index++) + { + buffer.Append(','); + buffer.AppendFormat("{0:D2}", Array[index]); + } + + if (Length > max) + { + buffer.Append(" ..."); + } + + buffer.Append('}'); + } + + return buffer.ToString(); + } + + } + + [ExcludeFromCodeCoverage] + [DebuggerStepThrough] + public struct DataSize + { + public DataSize(int commonSize) + { + CommonSize = commonSize; + LastSize = commonSize; + } + + public DataSize(int commonSize, int lastSize) + { + CommonSize = commonSize; + LastSize = lastSize; + } + + public int LastSize { get; set; } + public int CommonSize { get; set; } + + public int GetSize(bool isLast) + { + if (isLast) + { + return LastSize; + } + else + { + return CommonSize; + } + } + + public static implicit operator DataSize(int commonSize) + { + return new DataSize(commonSize, commonSize); + } + + public static implicit operator DataSize((int commonSize, int lastSize) values) + { + return new DataSize(values.commonSize, values.lastSize); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs new file mode 100644 index 0000000000..c512c3385b --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Reflection; +using Microsoft.Data.SqlClient.Tests; + +namespace Microsoft.Data.SqlClient +{ +#if NETFRAMEWORK + using PacketHandle = IntPtr; +#elif NETCOREAPP + internal struct PacketHandle + { + } +#endif + internal partial class TdsParserStateObject + { + internal int ObjectID = 1; + + internal class SQL + { + internal static Exception InvalidInternalPacketSize(string v) => throw new Exception(v ?? nameof(InvalidInternalPacketSize)); + + internal static Exception ParsingError(ParsingErrorState state) => throw new Exception(state.ToString()); + } + + internal static class SqlClientEventSource + { + internal static class Log + { + internal static void TryAdvancedTraceBinEvent(string message, params object[] values) + { + } + } + } + + private enum SnapshotStatus + { + NotActive, + ReplayStarting, + ReplayRunning + } + + internal enum TdsParserState + { + Closed, + OpenNotLoggedIn, + OpenLoggedIn, + Broken, + } + + private uint GetSniPacket(PacketHandle packet, ref uint dataSize) + { + return SNIPacketGetData(packet, _inBuff, ref dataSize); + } + + private class StringsHelper + { + internal static string GetString(string sqlMisc_InvalidArraySizeMessage) => Strings.SqlMisc_InvalidArraySizeMessage; + } + + internal class Strings + { + internal static string SqlMisc_InvalidArraySizeMessage = nameof(SqlMisc_InvalidArraySizeMessage); + + } + + public class Parser + { + internal object ProcessSNIError(TdsParserStateObject tdsParserStateObject) => "ProcessSNIError"; + public TdsParserState State = TdsParserState.OpenLoggedIn; + } + + sealed internal class LastIOTimer + { + internal long _value; + } + + internal sealed class Snapshot + { + public List List; + + public Snapshot() => List = new List(); + [DebuggerStepThrough] + internal void AssertCurrent() { } + [DebuggerStepThrough] + internal void AppendPacketData(byte[] buffer, int read) => List.Add(new PacketData(buffer, 0, read)); + [DebuggerStepThrough] + internal void MoveNext() + { + + } + } + + public List Input; + public PacketData Current; + public bool IsAsync { get => _snapshot != null; } + + public int _packetSize; + + internal Snapshot _snapshot; + public int _inBytesRead; + public int _inBytesUsed; + public byte[] _inBuff; + [DebuggerStepThrough] + public TdsParserStateObject(List input, int packetSize, bool isAsync) + { + _packetSize = packetSize; + _inBuff = new byte[_packetSize]; + Input = input; + if (isAsync) + { + _snapshot = new Snapshot(); + } + } + [DebuggerStepThrough] + private uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) + { + Span target = inBuff.AsSpan(0, _packetSize); + Span source = Current.Array.AsSpan(Current.Start, Current.Length); + source.CopyTo(target); + dataSize = (uint)Current.Length; + return TdsEnums.SNI_SUCCESS; + } + + [DebuggerStepThrough] + void SetBuffer(byte[] buffer, int inBytesUsed, int inBytesRead) + { + _inBuff = buffer; + _inBytesUsed = inBytesUsed; + _inBytesRead = inBytesRead; + } + + // stubs + private LastIOTimer _lastSuccessfulIOTimer = new LastIOTimer(); + private Parser _parser = new Parser(); + private SnapshotStatus _snapshotStatus = SnapshotStatus.NotActive; + + [DebuggerStepThrough] + private void SniReadStatisticsAndTracing() { } + [DebuggerStepThrough] + private void AssertValidState() { } + + [DebuggerStepThrough] + private void AddError(object value) => throw new Exception(value as string ?? "AddError"); + + internal static class LocalAppContextSwitches + { + public static bool UseCompatibilityProcessSni + { + get + { + var switchesType = typeof(SqlCommand).Assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); + + return (bool)switchesType.GetProperty(nameof(UseCompatibilityProcessSni), BindingFlags.Public | BindingFlags.Static).GetValue(null); + } + } + } + +#if NETFRAMEWORK + private SniNativeWrapperImpl _native; + internal SniNativeWrapperImpl SniNativeWrapper + { + get + { + if (_native == null) + { + _native = new SniNativeWrapperImpl(this); + } + return _native; + } + } + + internal class SniNativeWrapperImpl + { + private readonly TdsParserStateObject _parent; + internal SniNativeWrapperImpl(TdsParserStateObject parent) => _parent = parent; + + internal uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) => _parent.SNIPacketGetData(packet, inBuff, ref dataSize); + } +#endif + } + + internal static class TdsEnums + { + public const uint SNI_SUCCESS = 0; // The operation completed successfully. + // header constants + public const int HEADER_LEN = 8; + public const int HEADER_LEN_FIELD_OFFSET = 2; + public const int SPID_OFFSET = 4; + } + + internal enum ParsingErrorState + { + CorruptedTdsStream = 18, + ProcessSniPacketFailed = 19, + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs index e71d6d62f6..0ae12be917 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Text; @@ -36,9 +37,13 @@ public void CancelAsyncConnections() private void RunCancelAsyncConnections(SqlConnectionStringBuilder connectionStringBuilder) { SqlConnection.ClearAllPools(); - _watch = Stopwatch.StartNew(); - _random = new Random(4); // chosen via fair dice role. + ParallelLoopResult results = new ParallelLoopResult(); + ConcurrentDictionary tracker = new ConcurrentDictionary(); + + _random = new Random(4); // chosen via fair dice roll. + _watch = Stopwatch.StartNew(); + try { // Setup a timer so that we can see what is going on while our tasks run @@ -47,7 +52,7 @@ private void RunCancelAsyncConnections(SqlConnectionStringBuilder connectionStri results = Parallel.For( fromInclusive: 0, toExclusive: NumberOfTasks, - (int i) => DoManyAsync(connectionStringBuilder).GetAwaiter().GetResult()); + (int i) => DoManyAsync(i, tracker, connectionStringBuilder).GetAwaiter().GetResult()); } } catch (Exception ex) @@ -82,15 +87,15 @@ private void DisplaySummary() { count = _exceptionDetails.Count; } - _output.WriteLine($"{_watch.Elapsed} {_continue} Started:{_start} Done:{_done} InFlight:{_inFlight} RowsRead:{_rowsRead} ResultRead:{_resultRead} PoisonedEnded:{_poisonedEnded} nonPoisonedExceptions:{_nonPoisonedExceptions} PoisonedCleanupExceptions:{_poisonCleanUpExceptions} Count:{count} Found:{_found}"); } // This is the the main body that our Tasks run - private async Task DoManyAsync(SqlConnectionStringBuilder connectionStringBuilder) + private async Task DoManyAsync(int index, ConcurrentDictionary tracker, SqlConnectionStringBuilder connectionStringBuilder) { Interlocked.Increment(ref _start); Interlocked.Increment(ref _inFlight); + tracker[index] = true; using (SqlConnection marsConnection = new SqlConnection(connectionStringBuilder.ToString())) { @@ -100,15 +105,15 @@ private async Task DoManyAsync(SqlConnectionStringBuilder connectionStringBuilde } // First poison - await DoOneAsync(marsConnection, connectionStringBuilder.ToString(), poison: true); + await DoOneAsync(marsConnection, connectionStringBuilder.ToString(), poison: true, index); for (int i = 0; i < NumberOfNonPoisoned && _continue; i++) { // now run some without poisoning - await DoOneAsync(marsConnection, connectionStringBuilder.ToString()); + await DoOneAsync(marsConnection, connectionStringBuilder.ToString(),false,index); } } - + tracker.TryRemove(index, out var _); Interlocked.Decrement(ref _inFlight); Interlocked.Increment(ref _done); } @@ -117,7 +122,7 @@ private async Task DoManyAsync(SqlConnectionStringBuilder connectionStringBuilde // if we are poisoning we will // 1 - Interject some sleeps in the sql statement so that it will run long enough that we can cancel it // 2 - Setup a time bomb task that will cancel the command a random amount of time later - private async Task DoOneAsync(SqlConnection marsConnection, string connectionString, bool poison = false) + private async Task DoOneAsync(SqlConnection marsConnection, string connectionString, bool poison, int parent) { try { @@ -135,12 +140,12 @@ private async Task DoOneAsync(SqlConnection marsConnection, string connectionStr { if (marsConnection != null && marsConnection.State == System.Data.ConnectionState.Open) { - await RunCommand(marsConnection, builder.ToString(), poison); + await RunCommand(marsConnection, builder.ToString(), poison, parent); } else { await connection.OpenAsync(); - await RunCommand(connection, builder.ToString(), poison); + await RunCommand(connection, builder.ToString(), poison, parent); } } } @@ -176,7 +181,7 @@ private async Task DoOneAsync(SqlConnection marsConnection, string connectionStr } } - private async Task RunCommand(SqlConnection connection, string commandText, bool poison) + private async Task RunCommand(SqlConnection connection, string commandText, bool poison, int parent) { int rowsRead = 0; int resultRead = 0; @@ -211,7 +216,7 @@ private async Task RunCommand(SqlConnection connection, string commandText, bool } while (await reader.NextResultAsync() && _continue); } - catch when (poison) + catch (SqlException) when (poison) { // This looks a little strange, we failed to read above so this should fail too // But consider the case where this code is elsewhere (in the Dispose method of a class holding this logic) @@ -228,6 +233,10 @@ private async Task RunCommand(SqlConnection connection, string commandText, bool throw; } + catch (Exception ex) + { + Assert.Fail("unexpected exception: " + ex.GetType().Name + " " +ex.Message); + } } } finally