diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs index cf0a296b30..20f29cfd2c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs @@ -1354,6 +1354,11 @@ internal void OnFeatureExtAck(int featureId, byte[] data) len = bLen; } + if (len < 0 || len > data.Length - i) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, len); + } + byte[] stateData = new byte[len]; Buffer.BlockCopy(data, i, stateData, 0, len); i += len; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs index 8f0d7fcba0..273f7b2233 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -746,6 +746,10 @@ internal static Exception ParsingErrorLength(ParsingErrorState state, int length { return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ParsingErrorLength, ((int)state).ToString(CultureInfo.InvariantCulture), length)); } + internal static Exception ParsingErrorLength(ParsingErrorState state, uint length) + { + return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ParsingErrorLength, ((int)state).ToString(CultureInfo.InvariantCulture), length)); + } internal static Exception ParsingErrorStatus(ParsingErrorState state, int status) { return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ParsingErrorStatus, ((int)state).ToString(CultureInfo.InvariantCulture), status)); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs index cf8f4b3893..fbdf2dbb53 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs @@ -84,6 +84,17 @@ internal static class TdsEnums public const int MAX_PACKET_SIZE = 32768; public const int MAX_SERVER_USER_NAME = 256; // obtained from luxor + // Maximum allowed data length for token payloads (feature ext ack, + // session state, fedauth info). Prevents a malicious server from causing + // unbounded memory allocation via spoofed token length fields. + internal const int MaxTokenDataLength = 1 << 20; // 1 MB + + // Maximum allowed data length for a DTC promote transaction propagation token. + internal const int MaxPromoteTransactionLength = 1 << 16; // 64 KB + + // Maximum valid wire size for datetime types (DateTimeOffset = 5 time + 3 date + 2 offset). + internal const int MaxDateTimeLength = 10; + // Severity 0 - 10 indicates informational (non-error) messages // Severity 11 - 16 indicates errors that can be corrected by user (syntax errors, etc...) // Severity 17 - 19 indicates failure due to insufficient resources in the server diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 088d41d94d..c2b17b22d7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -208,7 +208,7 @@ static TdsParser() { // For CoreCLR, we need to register the ANSI Code Page encoding provider before attempting to get an Encoding from a CodePage // For a default installation of SqlServer the encoding exchanged during Login is 1252. This encoding is not loaded by default - // See Remarks at https://msdn.microsoft.com/en-us/library/system.text.encodingprovider(v=vs.110).aspx + // See Remarks at https://msdn.microsoft.com/en-us/library/system.text.encodingprovider(v=vs.110).aspx // SqlClient needs to register the encoding providers to make sure that even basic scenarios work with Sql Server. Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); } @@ -683,7 +683,7 @@ internal void RemoveEncryption() // create a new packet encryption changes the internal packet size Bug# 228403 _physicalStateObj.ClearAllWritePackets(); - } + } internal void EnableMars() { @@ -1376,11 +1376,11 @@ internal void TdsLogin( int feOffset = length; // calculate and reserve the required bytes for the featureEx length = ApplyFeatureExData( - requestedFeatures, - recoverySessionData, + requestedFeatures, + recoverySessionData, fedAuthFeatureExtensionData, UserAgent.Ucs2Bytes, - useFeatureExt, + useFeatureExt, length ); @@ -2792,7 +2792,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle { _connHandler._federatedAuthenticationInfoReceived = true; SqlFedAuthInfo info; - + result = TryProcessFedAuthInfo(stateObj, tokenLength, out info); if (result != TdsOperationStatus.Done) { @@ -3348,6 +3348,10 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb // new value has 4 byte length return result; } + if (env._newLength < 0 || env._newLength > TdsEnums.MaxPromoteTransactionLength) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, env._newLength); + } // read new value with 4 byte length env._newBinValue = new byte[env._newLength]; result = stateObj.TryReadByteArray(env._newBinValue, env._newLength); @@ -3846,10 +3850,15 @@ private TdsOperationStatus TryProcessFeatureExtAck(TdsParserStateObject stateObj { return result; } - byte[] data = new byte[dataLen]; - if (dataLen > 0) + if (dataLen > (uint)TdsEnums.MaxTokenDataLength) { - result = stateObj.TryReadByteArray(data, checked((int)dataLen)); + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, dataLen); + } + int dataLength = (int)dataLen; + byte[] data = new byte[dataLength]; + if (dataLength > 0) + { + result = stateObj.TryReadByteArray(data, dataLength); if (result != TdsOperationStatus.Done) { return result; @@ -4169,6 +4178,10 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj, { throw SQL.ParsingErrorLength(ParsingErrorState.SessionStateLengthTooShort, length); } + if (length > TdsEnums.MaxTokenDataLength) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, length); + } uint seqNum; TdsOperationStatus result = stateObj.TryReadUInt32(out seqNum); if (result != TdsOperationStatus.Done) @@ -4218,6 +4231,10 @@ private TdsOperationStatus TryProcessSessionState(TdsParserStateObject stateObj, return result; } } + if (stateLen < 0 || stateLen > TdsEnums.MaxTokenDataLength) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, stateLen); + } byte[] buffer = null; lock (sdata._delta) { @@ -4435,6 +4452,10 @@ private TdsOperationStatus TryProcessFedAuthInfo(TdsParserStateObject stateObj, SqlClientEventSource.Log.TryTraceEvent(" FEDAUTHINFO token stream length too short for CountOfInfoIDs."); throw SQL.ParsingErrorLength(ParsingErrorState.FedAuthInfoLengthTooShortForCountOfInfoIds, tokenLen); } + if (tokenLen > TdsEnums.MaxTokenDataLength) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, tokenLen); + } // read how many FedAuthInfo options there are uint optionsCount; @@ -4912,14 +4933,20 @@ internal TdsOperationStatus TryProcessReturnValue(int length, } // always read as sql types - Debug.Assert(valLen < (ulong)(int.MaxValue), "ProcessReturnValue received data size > 2Gb"); - - int intlen = valLen > (ulong)(int.MaxValue) ? int.MaxValue : (int)valLen; + int intlen; if (rec.metaType.IsPlp) { intlen = int.MaxValue; // If plp data, read it all } + else if (valLen > (ulong)int.MaxValue) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, unchecked((int)valLen)); + } + else + { + intlen = (int)valLen; + } if (rec.type == SqlDbTypeExtensions.Vector) { @@ -5790,7 +5817,7 @@ private TdsOperationStatus TryCommonProcessMetaData(TdsParserStateObject stateOb { return result; } - + // read flags and set appropriate flags in structure byte flags; result = stateObj.TryReadByte(out flags); @@ -7119,7 +7146,7 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, return result; } - // Internally, we use Sqlbinary to deal with varbinary data and store it in + // Internally, we use Sqlbinary to deal with varbinary data and store it in // SqlBuffer as SqlBinary value. #if NET value.SqlBinary = SqlBinary.WrapBytes(b); @@ -7188,9 +7215,20 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, return TdsOperationStatus.Done; } + // length originates as a single byte on the wire (nullable datetime length prefix), + // but is kept as int to match the TDS parsing API surface where all lengths are int. + // Using byte here would require casts at all call sites and silently truncate values + // from the sql_variant path where lenData is computed arithmetically. private TdsOperationStatus TryReadSqlDateTime(SqlBuffer value, byte tdsType, int length, byte scale, TdsParserStateObject stateObj) { - Span datetimeBuffer = ((uint)length <= 16) ? stackalloc byte[16] : new byte[length]; + // DateTimeOffset is the largest datetime type at 10 bytes (5 time + 3 date + 2 offset). + // Reject anything larger to prevent heap allocation from spoofed metadata. + if (length < 0 || length > TdsEnums.MaxDateTimeLength) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, length); + } + + Span datetimeBuffer = stackalloc byte[TdsEnums.MaxDateTimeLength]; TdsOperationStatus result = stateObj.TryReadByteArray(datetimeBuffer, length); if (result != TdsOperationStatus.Done) @@ -7446,9 +7484,11 @@ internal TdsOperationStatus TryReadSqlValueInternal(SqlBuffer value, byte tdsTyp case TdsEnums.SQLVECTOR: { // Note: Better not come here with plp data!! - Debug.Assert(length <= TdsEnums.MAXSIZE); - byte[] b = new byte[length]; - result = stateObj.TryReadByteArrayWithContinue(length, isPlp: false, out b); + if (length < 0 || length > TdsEnums.MAXSIZE) + { + throw SQL.ParsingErrorLength(ParsingErrorState.CorruptedTdsStream, length); + } + result = stateObj.TryReadByteArrayWithContinue(length, isPlp: false, out byte[] b); if (result != TdsOperationStatus.Done) { return result; @@ -9278,7 +9318,7 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD /// internal int WriteVectorSupportFeatureRequest(bool write) { - const int len = 6; + const int len = 6; if (write) { @@ -10476,7 +10516,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet { isSqlVal = param.ParameterIsSqlType; // We have to forward the TYPE info, we need to know what type we are returning. Once we null the parameter we will no longer be able to distinguish what type were seeing. - // Output parameter of SqlDbType vector are defined through SqlParameter.Value. + // Output parameter of SqlDbType vector are defined through SqlParameter.Value. // This check ensures that we do not discard the parameter value when SqlDbType is vector. if (mt.SqlDbType != SqlDbTypeExtensions.Vector) { @@ -10761,7 +10801,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet Debug.Assert(udtVal != null, "GetBytes returned null instance. Make sure that it always returns non-null value"); size = udtVal.Length; - + if (size >= maxSupportedSize && maxsize != -1) { throw SQL.UDTInvalidSize(maxsize, maxSupportedSize); @@ -13263,7 +13303,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { if (type.NullableType == TdsEnums.SQLJSON) { - // TODO : Performance and BOM check. Saurabh + // TODO : Performance and BOM check. Saurabh byte[] jsonAsBytes = Encoding.UTF8.GetBytes((string)value); WriteInt(jsonAsBytes.Length, stateObj); return stateObj.WriteByteArray(jsonAsBytes, jsonAsBytes.Length, 0, canAccumulate: false); @@ -13921,13 +13961,13 @@ internal TdsOperationStatus TryReadPlpUnicodeCharsWithContinue(TdsParserStateObj } TdsOperationStatus result = TryReadPlpUnicodeChars( - ref temp, - 0, - length >> 1, - stateObj, - out length, + ref temp, + 0, + length >> 1, + stateObj, + out length, supportRentedBuff: !canContinue, // do not use the arraypool if we are going to keep the buffer in the snapshot - rentedBuff: ref buffIsRented, + rentedBuff: ref buffIsRented, startOffset, canContinue ); @@ -14137,7 +14177,7 @@ bool writeDataSizeToSnapshot stateObj._longlenleft--; if (writeDataSizeToSnapshot) { - // we need to write the single b1 byte to the array because we may run out of data + // we need to write the single b1 byte to the array because we may run out of data // and need to wait for another packet buff[offst] = (char)((b1 & 0xff)); currentPacketId = IncrementSnapshotDataSize(stateObj, restartingDataSizeCount, currentPacketId, 1); diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs index 7d8941a07d..c85365d98e 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs @@ -14,6 +14,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests; /// /// Tests connection routing using the enhanced routing feature extension and envchange token /// +// TODO: Do we need this collection? It serializes all tests within it, which we probably don't +// need since each test uses its own TDS Server with ephemeral listen port. [Collection("SimulatedServerTests")] public class ConnectionEnhancedRoutingTests { diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs index a8734f9b8a..8a41477f26 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs @@ -12,6 +12,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests { [Trait("Category", "flaky")] + // TODO: Do we need this collection? It serializes all tests within it, which we probably don't + // need since each test uses its own TDS Server with ephemeral listen port. [Collection("SimulatedServerTests")] public class ConnectionFailoverTests { @@ -173,7 +175,7 @@ public void NetworkTimeout_ShouldFail() InitialCatalog = "master",// Required for failover partner to work ConnectTimeout = 1, ConnectRetryInterval = 1, - ConnectRetryCount = 0, // Disable retry + ConnectRetryCount = 0, // Disable retry Encrypt = false, MultiSubnetFailover = false, #if NETFRAMEWORK @@ -460,7 +462,7 @@ public void TransientFault_WithUserProvidedPartner_ShouldConnectToPrimary(uint e FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", // User provided failover partner }; using SqlConnection connection = new(builder.ConnectionString); - + // Act connection.Open(); diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs index f0618ac269..8220871f22 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs @@ -11,6 +11,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests { + // TODO: Do we need this collection? It serializes all tests within it, which we probably don't + // need since each test uses its own TDS Server with ephemeral listen port. [Collection("SimulatedServerTests")] public class ConnectionReadOnlyRoutingTests { @@ -71,7 +73,7 @@ public void RecursivelyRoutedConnection(int layers) router.Start(); routingLayers.Push(router); lastEndpoint = router.EndPoint; - lastConnectionString = (new SqlConnectionStringBuilder() { + lastConnectionString = (new SqlConnectionStringBuilder() { DataSource = $"localhost,{lastEndpoint.Port}", ApplicationIntent = ApplicationIntent.ReadOnly, Encrypt = false @@ -114,8 +116,8 @@ public async Task RecursivelyRoutedAsyncConnection(int layers) router.Start(); routingLayers.Push(router); lastEndpoint = router.EndPoint; - lastConnectionString = (new SqlConnectionStringBuilder() { - DataSource = $"localhost,{lastEndpoint.Port}", + lastConnectionString = (new SqlConnectionStringBuilder() { + DataSource = $"localhost,{lastEndpoint.Port}", ApplicationIntent = ApplicationIntent.ReadOnly, Encrypt = false }).ConnectionString; diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs index 6d89246776..4f45e49782 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs @@ -11,6 +11,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests { [Trait("Category", "flaky")] + // TODO: Do we need this collection? It serializes all tests within it, which we probably don't + // need since each test uses its own TDS Server with ephemeral listen port. [Collection("SimulatedServerTests")] public class ConnectionRoutingTests { @@ -195,7 +197,7 @@ public void NetworkTimeoutAtRoutedLocation_RetryDisabled_ShouldFail() // Act var e = Assert.Throws(connection.Open); - // Assert + // Assert Assert.Equal(ConnectionState.Closed, connection.State); Assert.Contains("Connection Timeout Expired", e.Message); } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTestsAzure.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTestsAzure.cs index 31dfc366fc..74190e847b 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTestsAzure.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTestsAzure.cs @@ -11,6 +11,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests { [Trait("Category", "flaky")] + // TODO: Do we need this collection? It serializes all tests within it, which we probably don't + // need since each test uses its own TDS Server with ephemeral listen port. [Collection("SimulatedServerTests")] public class ConnectionRoutingTestsAzure : IDisposable { @@ -22,8 +24,8 @@ public ConnectionRoutingTestsAzure() adpHelper.AddAzureSqlServerEndpoint("localhost"); } - public void Dispose() - { + public void Dispose() + { adpHelper.Dispose(); } @@ -193,7 +195,7 @@ public void NetworkTimeoutAtRoutedLocation_RetryDisabled_ShouldFail() // Act var e = Assert.Throws(connection.Open); - // Assert + // Assert Assert.Equal(ConnectionState.Closed, connection.State); Assert.Contains("Connection Timeout Expired", e.Message); } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtAckBoundsTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtAckBoundsTests.cs new file mode 100644 index 0000000000..20c94c0361 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtAckBoundsTests.cs @@ -0,0 +1,163 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.SqlServer.TDS; +using Microsoft.SqlServer.TDS.FeatureExtAck; +using Microsoft.SqlServer.TDS.Servers; +using TDSDoneToken = global::Microsoft.SqlServer.TDS.Done.TDSDoneToken; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests; + +/// +/// Tests that the TDS parser rejects feature extension acknowledgment tokens +/// with data lengths exceeding protocol-reasonable bounds. This prevents a +/// malicious server from causing unbounded memory allocation on the client. +/// +// TODO: Do we need this collection? It serializes all tests within it, which we probably don't +// need since each test uses its own TDS Server with ephemeral listen port. +[Collection("SimulatedServerTests")] +public class FeatureExtAckBoundsTests : IDisposable +{ + private readonly TdsServerFixture _fixture; + private readonly TdsServer _server; + private readonly string _connectionString; + + public FeatureExtAckBoundsTests() + { + _fixture = new TdsServerFixture(); + _server = _fixture.TdsServer; + SqlConnectionStringBuilder builder = new() + { + DataSource = $"localhost,{_server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + Pooling = false + }; + _connectionString = builder.ConnectionString; + } + + public void Dispose() => _fixture.Dispose(); + + /// + /// Verifies that the TDS parser rejects a FeatureExtAck token whose data length + /// field exceeds (1 MB), throwing a + /// parsing error instead of attempting an unbounded heap allocation. + /// This guards against CVE denial-of-service via pre-auth memory exhaustion. + /// + [Fact] + public void FeatureExtAck_OversizedDataLength_ThrowsParsingError() + { + // Arrange: inject a malicious FeatureExtAck token with an absurdly large data length + _server.OnAuthenticationResponseCompleted = responseMessage => + { + // Remove any existing FeatureExtAck token + var existing = responseMessage.OfType().FirstOrDefault(); + if (existing != null) + { + responseMessage.Remove(existing); + } + + // Insert a malicious token with oversized data length before the DONE token + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + responseMessage.Insert(doneIndex, new MaliciousFeatureExtAckToken( + featureId: (TDSFeatureID)TdsEnums.FEATUREEXT_GLOBALTRANSACTIONS, + claimedDataLen: (uint)(TdsEnums.MaxTokenDataLength + 1))); + }; + + // Act & Assert: connection should fail with a parsing error, NOT an OutOfMemoryException + using SqlConnection connection = new(_connectionString); + Exception ex = Assert.ThrowsAny(connection.Open); + + // The exception message should indicate a corrupted TDS stream + // with the oversized length value, not an OOM from attempting the allocation. + Assert.Contains("Error state: 18", ex.Message); // ParsingErrorState.CorruptedTdsStream = 18 + Assert.Contains($"Length: {TdsEnums.MaxTokenDataLength + 1}", ex.Message); + } + + /// + /// Verifies that a FeatureExtAck token with a data length at exactly the + /// allowed maximum is accepted without error, confirming there is no + /// off-by-one in the bounds check. + /// + [Fact] + public void FeatureExtAck_MaxAllowedDataLength_DoesNotThrow() + { + // Arrange: inject a FeatureExtAck token whose declared data length equals + // MaxTokenDataLength exactly. The bounds check should NOT fire for this + // value. The connection will fail for other reasons (not enough data on + // the wire), but the error must NOT be state 18 (CorruptedTdsStream). + _server.OnAuthenticationResponseCompleted = responseMessage => + { + var existing = responseMessage.OfType().FirstOrDefault(); + if (existing != null) + { + responseMessage.Remove(existing); + } + + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + // Insert token with dataLen = MaxTokenDataLength (at boundary, not over) + responseMessage.Insert(doneIndex, new MaliciousFeatureExtAckToken( + featureId: (TDSFeatureID)TdsEnums.FEATUREEXT_GLOBALTRANSACTIONS, + claimedDataLen: (uint)TdsEnums.MaxTokenDataLength)); + }; + + using SqlConnection connection = new(_connectionString); + // The connection will fail (insufficient data for the claimed length), + // but the failure must NOT be the bounds-check error. + Exception ex = Assert.ThrowsAny(connection.Open); + Assert.DoesNotContain("Error state: 18", ex.Message); + } + + /// + /// A custom TDS packet token that writes a FEATUREEXTACK token with a fraudulently + /// large data length field. This simulates a malicious server attempting to cause + /// the client to allocate an unbounded byte array. + /// + private sealed class MaliciousFeatureExtAckToken : TDSPacketToken + { + private readonly TDSFeatureID _featureId; + private readonly uint _claimedDataLen; + + public MaliciousFeatureExtAckToken(TDSFeatureID featureId, uint claimedDataLen) + { + _featureId = featureId; + _claimedDataLen = claimedDataLen; + } + + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // Write the FEATUREEXTACK token type (0xAE) + destination.WriteByte((byte)TDSTokenType.FeatureExtAck); + + // Write the feature ID byte + destination.WriteByte((byte)_featureId); + + // Write the claimed data length (uint32, little-endian) — this is the lie + byte[] lenBytes = BitConverter.GetBytes(_claimedDataLen); + destination.Write(lenBytes, 0, 4); + + // Write only 1 byte of actual data (the client will try to read _claimedDataLen bytes + // but we only provide 1 — the bounds check should fire before the read attempt) + destination.WriteByte(0x01); + + // Write terminator + destination.WriteByte((byte)TDSFeatureID.Terminator); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs index e6342c4fa8..0141d44000 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs @@ -13,6 +13,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests; +// Serializes execution with other SimulatedServerTests classes to avoid port/resource conflicts. +// Required here because IClassFixture shares a single TdsServer instance across all tests in this class. [Collection("SimulatedServerTests")] public class FeatureExtensionNegotiationTests : IClassFixture { diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/TdsTokenBoundsTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/TdsTokenBoundsTests.cs new file mode 100644 index 0000000000..ef975dccb3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/TdsTokenBoundsTests.cs @@ -0,0 +1,1051 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.SqlServer.TDS; +using Microsoft.SqlServer.TDS.ColMetadata; +using Microsoft.SqlServer.TDS.Done; +using Microsoft.SqlServer.TDS.FeatureExtAck; +using Microsoft.SqlServer.TDS.Servers; +using Microsoft.SqlServer.TDS.SessionState; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests; + +/// +/// Tests that the TDS parser rejects various token types with data lengths +/// exceeding protocol-reasonable bounds, preventing unbounded memory allocation +/// from a malicious server. +/// +// Serializes execution with other SimulatedServerTests classes. Required here because +// DebugAssertSuppressor mutates the global Trace.Listeners collection, which is not +// safe to do concurrently with other tests that may trigger Debug.Assert. +[Collection("SimulatedServerTests")] +public class TdsTokenBoundsTests : IDisposable +{ + private readonly TdsServerFixture _fixture; + private readonly TdsServer _server; + private readonly string _connectionString; + + public TdsTokenBoundsTests() + { + _fixture = new TdsServerFixture(); + _server = _fixture.TdsServer; + SqlConnectionStringBuilder builder = new() + { + DataSource = $"localhost,{_server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + Pooling = false + }; + _connectionString = builder.ConnectionString; + } + + public void Dispose() => _fixture.Dispose(); + + // ────────────────────────────────────────────────────────────────────────── + // Test 1: SessionState token with oversized total length + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryProcessSessionState rejects a SessionState token (0xE4) + /// whose total length field exceeds (1 MB), + /// preventing unbounded memory allocation from a spoofed token length. + /// + [Fact] + public void SessionState_OversizedTotalLength_ThrowsParsingError() + { + _server.OnAuthenticationResponseCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + // Inject a SessionState token claiming a total length exceeding MaxTokenDataLength + responseMessage.Insert(doneIndex, new MaliciousSessionStateToken( + claimedTotalLength: (uint)(TdsEnums.MaxTokenDataLength + 1))); + }; + + using SqlConnection connection = new(_connectionString); + Exception ex = Assert.ThrowsAny(connection.Open); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains($"Length: {TdsEnums.MaxTokenDataLength + 1}", ex.Message); + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 2: SessionState token with oversized inner option length + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryProcessSessionState rejects an individual session state + /// option whose inner data length (encoded via the 0xFF + DWORD path) exceeds + /// , even when the outer token length is valid. + /// + [Fact] + public void SessionState_OversizedInnerOptionLength_ThrowsParsingError() + { + _server.OnAuthenticationResponseCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + // Inject a SessionState token with a valid outer length but an inner + // state option claiming a huge data length + responseMessage.Insert(doneIndex, new MaliciousSessionStateInnerLenToken( + innerClaimedLength: TdsEnums.MaxTokenDataLength + 1)); + }; + + using SqlConnection connection = new(_connectionString); + Exception ex = Assert.ThrowsAny(connection.Open); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains($"Length: {TdsEnums.MaxTokenDataLength + 1}", ex.Message); + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 2b: SessionState token with negative inner option length + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryProcessSessionState rejects an individual session state + /// option whose inner data length (encoded via the 0xFF + DWORD path) is negative, + /// which would be interpreted as a huge unsigned value if not bounds-checked. + /// + [Fact] + public void SessionState_NegativeInnerOptionLength_ThrowsParsingError() + { + _server.OnAuthenticationResponseCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + // Inject a SessionState token with an inner state option claiming + // a negative data length (-1 = 0xFFFFFFFF as uint32) + responseMessage.Insert(doneIndex, new MaliciousSessionStateInnerLenToken( + innerClaimedLength: -1)); + }; + + using SqlConnection connection = new(_connectionString); + Exception ex = Assert.ThrowsAny(connection.Open); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains("Length: -1", ex.Message); + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 3: SRECOVERY feature ack with malformed inner state data + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that the secondary parse of FEATUREEXT_SRECOVERY data in + /// SqlConnectionInternal.OnFeatureExtAck rejects inner state options + /// whose claimed length exceeds the remaining buffer, preventing an + /// out-of-bounds read or over-allocation. + /// + [Fact] + public void SRecovery_MalformedInnerStateLength_ThrowsParsingError() + { + _server.OnAuthenticationResponseCompleted = responseMessage => + { + // Remove existing FeatureExtAck if present + var existing = responseMessage.OfType().FirstOrDefault(); + if (existing != null) + { + responseMessage.Remove(existing); + } + + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + // Inject a FeatureExtAck with SessionRecovery feature containing + // inner state data where a state option claims a length exceeding the buffer + responseMessage.Insert(doneIndex, new MaliciousSRecoveryFeatureExtAckToken()); + }; + + using SqlConnection connection = new(_connectionString); + Exception ex = Assert.ThrowsAny(connection.Open); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains("Length: 999", ex.Message); + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 4: FedAuthInfo token with oversized total length + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryProcessFedAuthInfo rejects a FedAuthInfo token (0xEE) + /// whose total length exceeds (1 MB). + /// The token type is dispatched unconditionally by TryRun, so this check + /// fires regardless of whether federated authentication was negotiated. + /// + [Fact] + public void FedAuthInfo_OversizedTokenLength_ThrowsParsingError() + { + _server.OnAuthenticationResponseCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + // Inject a FedAuthInfo token with a total length exceeding MaxTokenDataLength. + // The parser dispatches on token type regardless of whether FedAuth was negotiated. + responseMessage.Insert(doneIndex, new MaliciousFedAuthInfoToken( + claimedLength: TdsEnums.MaxTokenDataLength + 1)); + }; + + using SqlConnection connection = new(_connectionString); + Exception ex = Assert.ThrowsAny(connection.Open); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains($"Length: {TdsEnums.MaxTokenDataLength + 1}", ex.Message); + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 5: ENV_PROMOTETRANSACTION with oversized newLength + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryProcessEnvChange rejects a PromoteTransaction + /// environment change token (type 15) whose inner newLength field exceeds + /// (64 KB). A malicious server + /// can set the outer uint16 token length to a small value while writing an + /// int32 inner length claiming gigabytes, causing unbounded allocation. + /// + [Fact] + public void EnvChange_PromoteTransaction_OversizedLength_ThrowsParsingError() + { + _server.OnAuthenticationResponseCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + // Inject a PromoteTransaction env change with a fraudulently large newLength + responseMessage.Insert(doneIndex, new MaliciousPromoteTransactionEnvChangeToken( + claimedNewLength: TdsEnums.MaxPromoteTransactionLength + 1)); + }; + + using SqlConnection connection = new(_connectionString); + Exception ex = Assert.ThrowsAny(connection.Open); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains($"Length: {TdsEnums.MaxPromoteTransactionLength + 1}", ex.Message); + } + + // ══════════════════════════════════════════════════════════════════════════ + // Malicious token helpers + // ══════════════════════════════════════════════════════════════════════════ + + /// + /// Writes a SessionState token (0xE4) with a fraudulently large total length. + /// Wire format: [0xE4][uint32 totalLen][uint32 seqNum][byte status][...] + /// The bounds check fires on the totalLen value before any data is read. + /// + private sealed class MaliciousSessionStateToken : TDSPacketToken + { + private readonly uint _claimedTotalLength; + + public MaliciousSessionStateToken(uint claimedTotalLength) + { + _claimedTotalLength = claimedTotalLength; + } + + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // Token type + destination.WriteByte((byte)TDSTokenType.SessionState); // 0xE4 + + // Total token length (uint32) — this is the fraudulent value + byte[] lenBytes = BitConverter.GetBytes(_claimedTotalLength); + destination.Write(lenBytes, 0, 4); + + // Write minimal valid-looking data: seqNum (4 bytes) + status (1 byte) + // The bounds check should fire before trying to process option data. + destination.Write(new byte[] { 0x01, 0x00, 0x00, 0x00 }, 0, 4); // seqNum = 1 + destination.WriteByte(0x01); // status = recoverable + } + } + + /// + /// Writes a SessionState token (0xE4) with a valid outer length but an inner + /// state option that claims a huge data length (using the 0xFF + DWORD encoding). + /// Wire format: [0xE4][uint32 totalLen][uint32 seqNum][byte status] + /// [byte stateId][0xFF][int32 innerLen][...data...] + /// The inner bounds check fires on the innerLen value. + /// + private sealed class MaliciousSessionStateInnerLenToken : TDSPacketToken + { + private readonly int _innerClaimedLength; + + public MaliciousSessionStateInnerLenToken(int innerClaimedLength) + { + _innerClaimedLength = innerClaimedLength; + } + + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // Token type + destination.WriteByte((byte)TDSTokenType.SessionState); // 0xE4 + + // Calculate total token length: + // seqNum(4) + status(1) + stateId(1) + lenMarker(1) + innerLen(4) + minimal data(1) + uint totalLength = 4 + 1 + 1 + 1 + 4 + 1; + + byte[] lenBytes = BitConverter.GetBytes(totalLength); + destination.Write(lenBytes, 0, 4); + + // Sequence number + destination.Write(new byte[] { 0x01, 0x00, 0x00, 0x00 }, 0, 4); + + // Status (recoverable) + destination.WriteByte(0x01); + + // State option: stateId + destination.WriteByte(0x00); // UserOptions state ID + + // Length marker: 0xFF means next 4 bytes are the DWORD length + destination.WriteByte(0xFF); + + // Inner claimed length — fraudulently large + byte[] innerLenBytes = BitConverter.GetBytes(_innerClaimedLength); + destination.Write(innerLenBytes, 0, 4); + + // Write only 1 byte of actual data + destination.WriteByte(0x42); + } + } + + /// + /// Writes a FeatureExtAck token (0xAE) with a SessionRecovery feature (ID=1) + /// that carries inner state data where a state option claims a length exceeding + /// the remaining buffer. This exercises the bounds check in + /// SqlConnectionInternal.OnFeatureExtAck for FEATUREEXT_SRECOVERY. + /// + private sealed class MaliciousSRecoveryFeatureExtAckToken : TDSPacketToken + { + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // Token type: FEATUREEXTACK + destination.WriteByte((byte)TDSTokenType.FeatureExtAck); // 0xAE + + // Feature ID: SessionRecovery = 0x01 + destination.WriteByte(0x01); + + // The feature data for SRECOVERY is parsed by SqlConnectionInternal.OnFeatureExtAck: + // It reads pairs of [stateId(1)][lenByte(1)][data(len)] or [stateId(1)][0xFF][int32 len][data(len)] + // We'll craft inner data where one option claims length > remaining buffer. + + // Inner data layout: + // stateId(1) + 0xFF marker(1) + int32 len(4) + 1 byte actual data + byte[] innerData; + using (var ms = new MemoryStream()) + { + ms.WriteByte(0x00); // stateId = 0 + + // Use 0xFF marker to indicate DWORD length + ms.WriteByte(0xFF); + + // Claim a length that exceeds the remaining buffer + // The remaining buffer after reading stateId + 0xFF + 4-byte-len will be 1 byte, + // but we claim 999 bytes + byte[] claimedLen = BitConverter.GetBytes(999); + ms.Write(claimedLen, 0, 4); + + // Only provide 1 byte of actual data + ms.WriteByte(0x42); + + innerData = ms.ToArray(); + } + + // Feature data length (uint32) + byte[] featureDataLen = BitConverter.GetBytes((uint)innerData.Length); + destination.Write(featureDataLen, 0, 4); + + // Feature data + destination.Write(innerData, 0, innerData.Length); + + // Terminator + destination.WriteByte((byte)TDSFeatureID.Terminator); // 0xFF + } + } + + /// + /// Writes a FedAuthInfo token (0xEE) with a fraudulently large total length. + /// Wire format: [0xEE][int32 tokenLen][...data...] + /// The bounds check fires on tokenLen before any data is read. + /// + private sealed class MaliciousFedAuthInfoToken : TDSPacketToken + { + private readonly int _claimedLength; + + public MaliciousFedAuthInfoToken(int claimedLength) + { + _claimedLength = claimedLength; + } + + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // Token type + destination.WriteByte(0xEE); // SQLFEDAUTHINFO + + // Token length (int32) — fraudulently large + byte[] lenBytes = BitConverter.GetBytes(_claimedLength); + destination.Write(lenBytes, 0, 4); + + // Write minimal data (at least sizeof(uint) to pass the lower bound check, + // but the upper bound check should fire first) + destination.Write(new byte[] { 0x01, 0x00, 0x00, 0x00 }, 0, 4); // optionsCount = 1 + } + } + + /// + /// Writes an EnvChange token (0xE3) with type PromoteTransaction (15) whose + /// inner int32 newLength exceeds . + /// Wire format: [0xE3][uint16 tokenLen][byte type=15][int32 newLen][data...][byte oldLen=0] + /// The outer uint16 tokenLen is set to accommodate the header but the int32 + /// newLength claims far more data than actually follows, triggering the bounds check. + /// + private sealed class MaliciousPromoteTransactionEnvChangeToken : TDSPacketToken + { + private readonly int _claimedNewLength; + + public MaliciousPromoteTransactionEnvChangeToken(int claimedNewLength) + { + _claimedNewLength = claimedNewLength; + } + + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // Token type: ENVCHANGE + destination.WriteByte((byte)TDSTokenType.EnvironmentChange); // 0xE3 + + // Outer token length (uint16): type(1) + newLength(4) + 1 byte data + oldLength(1) + // We write just enough to contain the header fields the parser reads before + // hitting the bounds check. The parser reads: type(1) + int32 newLen(4) = 5 bytes min. + ushort outerLength = 1 + 4 + 1 + 1; // type + newLen + 1 fake byte + oldLen + byte[] outerLenBytes = BitConverter.GetBytes(outerLength); + destination.Write(outerLenBytes, 0, 2); + + // Env change type: PromoteTransaction = 15 + destination.WriteByte(15); + + // newLength (int32) — fraudulently large + byte[] newLenBytes = BitConverter.GetBytes(_claimedNewLength); + destination.Write(newLenBytes, 0, 4); + + // Write 1 byte of fake data (the bounds check fires before attempting to read this much) + destination.WriteByte(0x00); + + // Old value length (byte) = 0 + destination.WriteByte(0x00); + } + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 6: Post-login batch response injection — PromoteTransaction + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that bounds checks fire during command execution (post-login) + /// by injecting a malicious PromoteTransaction env change token into the + /// SQL batch response. This exercises the OnSQLBatchCompleted hook + /// on the simulated server and proves the same bounds check fires regardless + /// of whether the token arrives during login or command execution. + /// + [Fact] + public void BatchResponse_PromoteTransaction_OversizedLength_ThrowsParsingError() + { + _server.OnSQLBatchCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + responseMessage.Insert(doneIndex, new MaliciousPromoteTransactionEnvChangeToken( + claimedNewLength: TdsEnums.MaxPromoteTransactionLength + 1)); + }; + + using SqlConnection connection = new(_connectionString); + connection.Open(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = "SELECT 1"; + + Exception ex = Assert.ThrowsAny( + () => command.ExecuteNonQuery()); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains($"Length: {TdsEnums.MaxPromoteTransactionLength + 1}", ex.Message); + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 7: Post-login batch response — DateTime oversized length + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryReadSqlDateTime rejects a TIME column value whose + /// data length byte exceeds the maximum datetime wire size (10 bytes). A + /// malicious server could set this length to a large value, causing the + /// parser to attempt an unbounded heap allocation. + /// + [Fact] + public void BatchResponse_DateTime_OversizedLength_ThrowsParsingError() + { + _server.OnSQLBatchCompleted = responseMessage => + { + // Replace the entire response with a crafted result set containing + // a TIME column whose ROW data length exceeds the maximum. + responseMessage.Clear(); + + // Use proper library tokens for COLMETADATA so framing is correct + var metadata = new TDSColMetadataToken(); + var col = new TDSColumnData(); + col.DataType = TDSDataType.TimeN; + col.DataTypeSpecific = (byte)7; // scale = 7 + col.Flags.IsNullable = true; + col.Name = string.Empty; + metadata.Columns.Add(col); + responseMessage.Add(metadata); + + // Add a malicious ROW token with oversized data length + responseMessage.Add(new MaliciousTimeRowToken()); + + // Add DONE token + responseMessage.Add(new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Count, TDSDoneTokenCommandType.Select, 1)); + }; + + using SqlConnection connection = new(_connectionString); + connection.Open(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = "MALICIOUS_QUERY_NOT_RECOGNIZED"; + + SqlDataReader reader = command.ExecuteReader(); + try + { + // Read() returns true (data is ready) but doesn't actually parse the + // column values — that happens on GetValue/GetTimeSpan. Force the read. + Assert.True(reader.Read()); + Exception ex = Assert.ThrowsAny( + () => reader.GetValue(0)); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains("Length: 11", ex.Message); + } + finally + { + // Disposing the reader after a corrupted stream causes the driver to + // attempt further TDS parsing during teardown, which can trip unrelated + // Debug.Assert calls in TdsParser. + using (new DebugAssertSuppressor()) + { + try { reader.Dispose(); } catch { } + } + } + } + + /// + /// Writes a ROW token (0xD1) with a single TimeN column whose data length + /// byte is set to 11 (exceeding the maximum datetime wire size of 10 bytes). + /// + private sealed class MaliciousTimeRowToken : TDSPacketToken + { + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // ROW token type + destination.WriteByte(0xD1); + + // TimeN data: length prefix (1 byte) = 11 (INVALID — max for time is 5, max datetime overall is 10) + destination.WriteByte(11); + + // Write 11 bytes of dummy data + destination.Write(new byte[11], 0, 11); + } + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 8: Post-login batch response — ReturnValue with oversized data length + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryProcessReturnValue rejects a RETURNVALUE token (0xAC) + /// whose inner data length (for a non-PLP IMAGE type) exceeds int.MaxValue. + /// A malicious server can craft a TEXT/IMAGE return value with a spoofed int32 + /// data length that becomes a huge value when cast to ulong, triggering unbounded + /// allocation. + /// + [Fact] + public void BatchResponse_ReturnValue_OversizedLength_ThrowsParsingError() + { + _server.OnSQLBatchCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + responseMessage.Insert(doneIndex, new MaliciousReturnValueToken()); + }; + + using SqlConnection connection = new(_connectionString); + connection.Open(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = "SELECT 1"; + + // The malicious RETURNVALUE token corrupts parser state before the + // exception propagates. Suppress Debug.Assert calls that fire in + // TdsParser during error handling and connection teardown. + using (new DebugAssertSuppressor()) + { + Exception ex = Assert.ThrowsAny( + () => command.ExecuteNonQuery()); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains("Length: -1", ex.Message); + } + } + + /// + /// Writes a RETURNVALUE token (0xAC) with an IMAGE (0x22) type whose data + /// length field is set to -1 (0xFFFFFFFF). When cast to ulong this exceeds + /// int.MaxValue, triggering the bounds check in TryProcessReturnValue. + /// Wire layout: + /// [0xAC] token + /// [uint16] ordinal + /// [byte] param name length (0) + /// [byte] status + /// [uint32] user type + /// [byte] flags1 + /// [byte] flags2 + /// [byte] tds type = 0x22 (IMAGE) + /// [int32] max length + /// [byte] textPtrLen = 16 + /// [16 bytes] textPtr + /// [8 bytes] timestamp + /// [int32] data length = -1 (INVALID) + /// + private sealed class MaliciousReturnValueToken : TDSPacketToken + { + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // RETURNVALUE token + destination.WriteByte(0xAC); + + // Ordinal (uint16 LE) + destination.WriteByte(0x00); + destination.WriteByte(0x00); + + // Param name length (byte) = 0 + destination.WriteByte(0x00); + + // Status (byte) = 0x01 (output parameter) + destination.WriteByte(0x01); + + // UserType (uint32 LE) = 0 + destination.Write(new byte[4], 0, 4); + + // Flags byte 1 (ignored) + destination.WriteByte(0x00); + + // Flags byte 2 + destination.WriteByte(0x00); + + // TDS type = SQLIMAGE (0x22) + destination.WriteByte(0x22); + + // MaxLen (int32 LE) — for IMAGE this is read via TryGetTokenLength + // which for 0x22 reads int32. Value doesn't matter much, just needs + // to be valid for MetaType lookup. + destination.Write(new byte[] { 0x10, 0x00, 0x00, 0x00 }, 0, 4); // 16 + + // -- TryProcessColumnHeaderNoNBC: IsLong && !IsPlp path -- + // TextPtr length (byte) = 16 + destination.WriteByte(0x10); + + // TextPtr data (16 bytes) + destination.Write(new byte[16], 0, 16); + + // Timestamp (8 bytes) + destination.Write(new byte[8], 0, 8); + + // Data length (int32 LE) = -1 (0xFFFFFFFF) + // (ulong)(-1) = 0xFFFFFFFFFFFFFFFF > int.MaxValue → triggers bounds check + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + } + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 9: Post-login batch response — PLP ReturnValue (regression guard) + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryProcessReturnValue correctly handles PLP + /// (Partially Length-Prefixed) return values without triggering the bounds + /// check. PLP types use the unknown-length sentinel (0xFFFFFFFFFFFFFFFE) + /// which must NOT be rejected by the non-PLP bounds check. + /// + [Fact] + public void BatchResponse_ReturnValue_PlpUnknownLength_Succeeds() + { + _server.OnSQLBatchCompleted = responseMessage => + { + int doneIndex = responseMessage.FindIndex(t => t is TDSDoneToken); + if (doneIndex < 0) + { + doneIndex = responseMessage.Count; + } + + responseMessage.Insert(doneIndex, new ValidPlpReturnValueToken()); + }; + + using SqlConnection connection = new(_connectionString); + connection.Open(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = "SELECT 1"; + + // Should NOT throw — PLP unknown-length sentinel is valid + command.ExecuteNonQuery(); + } + + /// + /// Writes a valid RETURNVALUE token (0xAC) with a PLP VARBINARY(MAX) type + /// using the unknown-length sentinel (0xFFFFFFFFFFFFFFFE) followed by an + /// immediate PLP terminator (chunk length = 0). This exercises the IsPlp + /// branch in TryProcessReturnValue and must NOT trigger the bounds check. + /// Wire layout: + /// [0xAC] token + /// [uint16] ordinal + /// [byte] param name length (0) + /// [byte] status + /// [uint32] user type + /// [byte] flags1 + /// [byte] flags2 + /// [byte] tds type = 0xA5 (BIGVARBINARY) + /// [uint16] max length = 0xFFFF (PLP marker) + /// [uint64] PLP length = 0xFFFFFFFFFFFFFFFE (unknown) + /// [uint32] chunk length = 0 (terminator) + /// + private sealed class ValidPlpReturnValueToken : TDSPacketToken + { + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // RETURNVALUE token + destination.WriteByte(0xAC); + + // Ordinal (uint16 LE) + destination.WriteByte(0x00); + destination.WriteByte(0x00); + + // Param name length (byte) = 0 + destination.WriteByte(0x00); + + // Status (byte) = 0x01 (output parameter) + destination.WriteByte(0x01); + + // UserType (uint32 LE) = 0 + destination.Write(new byte[4], 0, 4); + + // Flags byte 1 (ignored) + destination.WriteByte(0x00); + + // Flags byte 2 + destination.WriteByte(0x00); + + // TDS type = SQLBIGVARBINARY (0xA5) + destination.WriteByte(0xA5); + + // MaxLen (uint16 LE) = 0xFFFF — PLP marker + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + + // -- TryProcessColumnHeaderNoNBC: non-IsLong path -- + // TryGetDataLength → TryReadPlpLength: + // reads uint64 = 0xFFFFFFFFFFFFFFFE (unknown length sentinel) + destination.WriteByte(0xFE); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + destination.WriteByte(0xFF); + + // PLP chunk terminator (uint32 = 0) — empty data + destination.Write(new byte[4], 0, 4); + } + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 10: Post-login batch response — sql_variant with oversized binary + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryReadSqlValueInternal rejects a binary value inside + /// a sql_variant column whose inner data length exceeds + /// (8000 bytes). The bounds check in the + /// sql_variant deserialization path prevents unbounded heap allocation. + /// + [Fact] + public void BatchResponse_SqlVariantBinary_OversizedLength_ThrowsParsingError() + { + _server.OnSQLBatchCompleted = responseMessage => + { + responseMessage.Clear(); + + // COLMETADATA: one SSVariant column + var metadata = new TDSColMetadataToken(); + var col = new TDSColumnData(); + col.DataType = TDSDataType.SSVariant; + col.DataTypeSpecific = (uint)8009; // max length for SSVariant + col.Flags.IsNullable = true; + col.Name = string.Empty; + metadata.Columns.Add(col); + responseMessage.Add(metadata); + + // ROW with a sql_variant containing oversized binary data + responseMessage.Add(new MaliciousSqlVariantBinaryRowToken()); + + // DONE + responseMessage.Add(new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Count, TDSDoneTokenCommandType.Select, 1)); + }; + + using SqlConnection connection = new(_connectionString); + connection.Open(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = "MALICIOUS_QUERY_NOT_RECOGNIZED"; + + SqlDataReader reader = command.ExecuteReader(); + try + { + Assert.True(reader.Read()); + Exception ex = Assert.ThrowsAny( + () => reader.GetValue(0)); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains("Length: 8001", ex.Message); + } + finally + { + // Disposing the reader after a corrupted stream causes the driver to + // attempt further TDS parsing during teardown, which can trip unrelated + // Debug.Assert calls in TdsParser. + using (new DebugAssertSuppressor()) + { + try { reader.Dispose(); } catch { } + } + } + } + + // ────────────────────────────────────────────────────────────────────────── + // Test 11: Post-login batch response — sql_variant with negative inner length + // ────────────────────────────────────────────────────────────────────────── + + /// + /// Verifies that TryReadSqlValueInternal rejects a sql_variant binary + /// value whose declared total length is too small for the type overhead, + /// causing the computed inner data length to be negative. The bounds check + /// if (length < 0 || length > TdsEnums.MAXSIZE) catches this. + /// + [Fact] + public void BatchResponse_SqlVariantBinary_NegativeLength_ThrowsParsingError() + { + _server.OnSQLBatchCompleted = responseMessage => + { + responseMessage.Clear(); + + // COLMETADATA: one SSVariant column + var metadata = new TDSColMetadataToken(); + var col = new TDSColumnData(); + col.DataType = TDSDataType.SSVariant; + col.DataTypeSpecific = (uint)8009; // max length for SSVariant + col.Flags.IsNullable = true; + col.Name = string.Empty; + metadata.Columns.Add(col); + responseMessage.Add(metadata); + + // ROW with a sql_variant claiming a total length too small for its overhead + responseMessage.Add(new NegativeLenSqlVariantBinaryRowToken()); + + // DONE + responseMessage.Add(new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Count, TDSDoneTokenCommandType.Select, 1)); + }; + + using SqlConnection connection = new(_connectionString); + connection.Open(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = "MALICIOUS_QUERY_NOT_RECOGNIZED"; + + SqlDataReader reader = command.ExecuteReader(); + try + { + Assert.True(reader.Read()); + Exception ex = Assert.ThrowsAny( + () => reader.GetValue(0)); + Assert.Contains("Error state: 18", ex.Message); // CorruptedTdsStream + Assert.Contains("Length: -1", ex.Message); + } + finally + { + // Disposing the reader after a corrupted stream causes the driver to + // attempt further TDS parsing during teardown, which can trip unrelated + // Debug.Assert calls in TdsParser. + using (new DebugAssertSuppressor()) + { + try { reader.Dispose(); } catch { } + } + } + } + + /// + /// Writes a ROW token (0xD1) with a single SSVariant column containing a + /// BigVarBinary variant whose inner data length exceeds MAXSIZE (8000). + /// Wire layout for the variant: + /// [int32] total variant length = 8005 + /// [byte] inner type = 0xA5 (BigVarBinary) + /// [byte] cbPropBytes = 2 + /// [ushort] maxLen (property) = 8001 + /// [8001 bytes would be data, but we only write 4 to trigger the check] + /// lenData = 8005 - 2(SQLVARIANT_SIZE) - 2(cbProps) = 8001 > MAXSIZE → throws + /// + private sealed class MaliciousSqlVariantBinaryRowToken : TDSPacketToken + { + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // ROW token type + destination.WriteByte(0xD1); + + // SSVariant column data: total length (int32 LE) + // lenData = totalLength - SQLVARIANT_SIZE(2) - cbPropBytes(2) = totalLength - 4 + // We want lenData = 8001, so totalLength = 8005 + int totalLength = 8005; + byte[] lenBytes = BitConverter.GetBytes(totalLength); + destination.Write(lenBytes, 0, 4); + + // Inner type: BigVarBinary = 0xA5 + destination.WriteByte(0xA5); + + // cbPropBytes = 2 + destination.WriteByte(0x02); + + // Properties: maxLen (ushort) = 8001 + destination.WriteByte(0x41); // 8001 & 0xFF = 0x41 + destination.WriteByte(0x1F); // 8001 >> 8 = 0x1F + + // Write 4 bytes of dummy data (bounds check fires before trying to read 8001) + destination.Write(new byte[4], 0, 4); + } + } + + /// + /// Writes a ROW token (0xD1) with a single SSVariant column containing a + /// BigVarBinary variant whose total length (3) is too small to cover the + /// type overhead (SQLVARIANT_SIZE=2 + cbPropBytes=2 = 4), producing a + /// negative computed inner data length: lenData = 3 - 4 = -1. + /// + private sealed class NegativeLenSqlVariantBinaryRowToken : TDSPacketToken + { + public override bool Inflate(Stream source) => throw new NotSupportedException(); + + public override void Deflate(Stream destination) + { + // ROW token type + destination.WriteByte(0xD1); + + // SSVariant column data: total length (int32 LE) + // lenConsumed = SQLVARIANT_SIZE(2) + cbPropBytes(2) = 4 + // lenData = totalLength - lenConsumed = 3 - 4 = -1 → triggers < 0 check + int totalLength = 3; + byte[] lenBytes = BitConverter.GetBytes(totalLength); + destination.Write(lenBytes, 0, 4); + + // Inner type: BigVarBinary = 0xA5 + destination.WriteByte(0xA5); + + // cbPropBytes = 2 + destination.WriteByte(0x02); + + // Properties: maxLen (ushort) = 100 (arbitrary, just need 2 bytes) + destination.WriteByte(0x64); // 100 & 0xFF + destination.WriteByte(0x00); // 100 >> 8 + + // No data bytes — the bounds check fires before attempting to read + } + } + + /// + /// Temporarily suppresses Debug.Assert failures by clearing trace listeners. Used when + /// disposing resources after intentionally corrupting a TDS stream. A static lock serializes + /// access for the lifetime of the instance because Trace.Listeners is a global collection. + /// + private sealed class DebugAssertSuppressor : IDisposable + { + private static readonly object s_listenerLock = new(); + private readonly System.Diagnostics.TraceListener[] _listeners; + + public DebugAssertSuppressor() + { + System.Threading.Monitor.Enter(s_listenerLock); + try + { + _listeners = new System.Diagnostics.TraceListener[System.Diagnostics.Trace.Listeners.Count]; + System.Diagnostics.Trace.Listeners.CopyTo(_listeners, 0); + System.Diagnostics.Trace.Listeners.Clear(); + } + catch + { + System.Threading.Monitor.Exit(s_listenerLock); + throw; + } + } + + public void Dispose() + { + try + { + System.Diagnostics.Trace.Listeners.Clear(); + System.Diagnostics.Trace.Listeners.AddRange(_listeners); + } + finally + { + System.Threading.Monitor.Exit(s_listenerLock); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs index d2c81e6ca1..3ec8f6efaf 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs @@ -139,6 +139,13 @@ public GenericTdsServer(T arguments, QueryEngine queryEngine) public OnAuthenticationCompletedDelegate OnAuthenticationResponseCompleted { private get; set; } + /// + /// Delegate invoked after a SQL batch response is prepared but before it is + /// sent to the client. Tests can use this to inject or replace tokens in the + /// response message. + /// + public Action OnSQLBatchCompleted { get; set; } + public OnLogin7ValidatedDelegate OnLogin7Validated { private get; set; } @@ -451,6 +458,9 @@ public virtual TDSMessageCollection OnSQLBatchRequest(ITDSServerSession session, session.PacketSize = (uint)Arguments.PacketSize; } + // Allow tests to modify or inject tokens into the response + OnSQLBatchCompleted?.Invoke(responseMessage[0]); + return responseMessage; }