From a86441f0e96e4b475f5e244d5a34d1f05696ddaa Mon Sep 17 00:00:00 2001 From: Levi Broderick Date: Mon, 17 Mar 2025 17:15:05 -0700 Subject: [PATCH 1/4] Reduce unsafe usage in StreamUtils --- src/Microsoft.ML.Core/Utilities/Stream.cs | 258 +++++----------------- 1 file changed, 58 insertions(+), 200 deletions(-) diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 93af2a167b..e94ae91f76 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -3,9 +3,11 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections; using System.Collections.Generic; using System.IO; +using System.Runtime.InteropServices; using System.Text; using System.Threading; using Microsoft.ML.Runtime; @@ -14,8 +16,6 @@ namespace Microsoft.ML.Internal.Utilities { internal static partial class Utils { - private const int _bulkReadThresholdInBytes = 4096; - public static void CloseEx(this Stream stream) { if (stream == null) @@ -468,7 +468,7 @@ public static float[] ReadFloatArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadFloatArray(reader, size); } @@ -482,22 +482,7 @@ public static float[] ReadFloatArray(this BinaryReader reader, int size) return null; var values = new float[size]; - long bufferSizeInBytes = (long)size * sizeof(float); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - values[i] = reader.ReadFloat(); - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - } - } - } + ReadBinaryDataIntoSpan(reader, values.AsSpan()); return values; } @@ -509,67 +494,24 @@ public static void ReadFloatArray(this BinaryReader reader, float[] array, 
int s Contracts.Assert(0 <= start && start < array.Length); Contracts.Assert(0 < count && count <= array.Length - start); - long bufferReadLengthInBytes = (long)count * sizeof(float); - if (bufferReadLengthInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < count; i++) - array[start + i] = reader.ReadFloat(); - } - else - { - unsafe - { - fixed (void* dst = array) - { - long bufferBeginOffsetInBytes = (long)start * sizeof(float); - long bufferSizeInBytes = ((long)array.Length - start) * sizeof(float); - ReadBytes(reader, (byte*)dst + bufferBeginOffsetInBytes, bufferSizeInBytes, bufferReadLengthInBytes); - } - } - } + ReadBinaryDataIntoSpan(reader, array.AsSpan(start, count)); } public static float[] ReadSingleArray(this BinaryReader reader) { - Contracts.AssertValue(reader); - int size = reader.ReadInt32(); - Contracts.CheckDecode(size >= 0); - return ReadSingleArray(reader, size); + return reader.ReadFloatArray(); } public static float[] ReadSingleArray(this BinaryReader reader, int size) { - Contracts.AssertValue(reader); - Contracts.Assert(size >= 0); - if (size == 0) - return null; - var values = new float[size]; - - long bufferSizeInBytes = (long)size * sizeof(float); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - values[i] = reader.ReadSingle(); - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - } - } - } - - return values; + return reader.ReadFloatArray(size); } public static double[] ReadDoubleArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadDoubleArray(reader, size); } @@ -582,22 +524,7 @@ public static double[] ReadDoubleArray(this BinaryReader reader, int size) return null; var values = new double[size]; - long bufferSizeInBytes = (long)size * 
sizeof(double); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - values[i] = reader.ReadDouble(); - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - } - } - } + ReadBinaryDataIntoSpan(reader, values.AsSpan()); return values; } @@ -606,7 +533,7 @@ public static int[] ReadIntArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadIntArray(reader, size); } @@ -620,22 +547,7 @@ public static int[] ReadIntArray(this BinaryReader reader, int size) return null; var values = new int[size]; - long bufferSizeInBytes = (long)size * sizeof(int); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - values[i] = reader.ReadInt32(); - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - } - } - } + ReadBinaryDataIntoSpan(reader, values.AsSpan()); return values; } @@ -644,7 +556,7 @@ public static uint[] ReadUIntArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadUIntArray(reader, size); } @@ -658,22 +570,7 @@ public static uint[] ReadUIntArray(this BinaryReader reader, int size) return null; var values = new uint[size]; - long bufferSizeInBytes = (long)size * sizeof(uint); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - values[i] = reader.ReadUInt32(); - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - } - } - } + ReadBinaryDataIntoSpan(reader, values.AsSpan()); return values; } @@ -682,7 +579,7 @@ 
public static long[] ReadLongArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadLongArray(reader, size); } @@ -696,22 +593,7 @@ public static long[] ReadLongArray(this BinaryReader reader, int size) return null; var values = new long[size]; - long bufferSizeInBytes = (long)size * sizeof(long); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - values[i] = reader.ReadInt64(); - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - } - } - } + ReadBinaryDataIntoSpan(reader, values.AsSpan()); return values; } @@ -720,7 +602,7 @@ public static bool[] ReadBoolArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadBoolArray(reader, size); } @@ -734,28 +616,9 @@ public static bool[] ReadBoolArray(this BinaryReader reader, int size) return null; var values = new bool[size]; - long bufferSizeInBytes = (long)size * sizeof(bool); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - { - byte b = reader.ReadByte(); - Contracts.CheckDecode(b <= 1); - values[i] = b != 0; - } - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - for (long i = 0; i < size; i++) - Contracts.CheckDecode(*((byte*)dst + i) <= 1); - } - } - } + // It is in general not safe to populate a bool[] with untrusted input. + // The call below assumes the input stream is trusted. 
+ ReadBinaryDataIntoSpan(reader, values.AsSpan()); return values; } @@ -764,7 +627,7 @@ public static char[] ReadCharArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadCharArray(reader, size); } @@ -778,22 +641,7 @@ public static char[] ReadCharArray(this BinaryReader reader, int size) return null; var values = new char[size]; - long bufferSizeInBytes = (long)size * sizeof(char); - if (bufferSizeInBytes < _bulkReadThresholdInBytes) - { - for (int i = 0; i < size; i++) - values[i] = (char)reader.ReadInt16(); - } - else - { - unsafe - { - fixed (void* dst = values) - { - ReadBytes(reader, dst, bufferSizeInBytes, bufferSizeInBytes); - } - } - } + ReadBinaryDataIntoSpan(reader, values.AsSpan()); return values; } @@ -802,7 +650,7 @@ public static byte[] ReadByteArray(this BinaryReader reader) { Contracts.AssertValue(reader); - int size = reader.ReadInt32(); + int size = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(size >= 0); return ReadByteArray(reader, size); } @@ -821,49 +669,59 @@ public static byte[] ReadByteArray(this BinaryReader reader, int size) public static BitArray ReadBitArray(this BinaryReader reader) { - int numBits = reader.ReadInt32(); + int numBits = reader.ReadInt32(); // reading trusted capacity value from data stream Contracts.CheckDecode(numBits >= 0); if (numBits == 0) return null; - var numBytes = (numBits + 7) / 8; + var numBytes = (numBits + 7) / 8; // trusted capacity value expected not to integer overflow var bytes = reader.ReadByteArray(numBytes); var returnArray = new BitArray(bytes); returnArray.Length = numBits; return returnArray; } - public static unsafe void ReadBytes(this BinaryReader reader, void* destination, long destinationSizeInBytes, long bytesToRead, ref byte[] work) + private static unsafe void 
ReadBinaryDataIntoSpan(BinaryReader reader, Span destination) where T : unmanaged { Contracts.AssertValue(reader); - Contracts.Assert(bytesToRead >= 0); - Contracts.Assert(destinationSizeInBytes >= bytesToRead); - Contracts.Assert(destination != null); - Contracts.AssertValueOrNull(work); + Contracts.Assert(!destination.IsEmpty); + + // There are two considerations here. First, we want to keep all temporary arrays (even pooled arrays) + // under some threshold size. Second, when we project the Span to bytes, we need to do it in chunks, + // as trying to project the entire span at once will lead to integer overflow if the byte length + // exceeds int.MaxValue. - // Size our read buffer to 70KB to stay off the LOH. - const int blockSize = 70 * 1024; - int desiredWorkSize = (int)Math.Min(blockSize, bytesToRead); - EnsureSize(ref work, desiredWorkSize); + const int maxChunkSizeInBytes = 70 * 1024; + int maxChunkSizeInElements = maxChunkSizeInBytes / sizeof(T); + Contracts.Assert(maxChunkSizeInElements > 0, "Unexpectedly large T."); - fixed (void* src = work) + // Rent a byte[] instead of a T[] to allow reuse of buffers across different types T. + byte[] rentedArray = ArrayPool.Shared.Rent(maxChunkSizeInElements * sizeof(T)); + try { - long offset = 0; - while (offset < bytesToRead) + while (!destination.IsEmpty) { - int toRead = (int)Math.Min(bytesToRead - offset, blockSize); - int read = reader.Read(work, 0, toRead); - Contracts.CheckDecode(read == toRead); - Buffer.MemoryCopy(src, (byte*)destination + offset, destinationSizeInBytes - offset, read); - offset += read; + int numElementsToReadThisChunk = Math.Min(maxChunkSizeInElements, destination.Length); + int rentedArrayOffset = 0; + int numBytesRemainingToReadThisChunk = numElementsToReadThisChunk * sizeof(T); // n.b.
not necessarily populating the entire rented array + + do + { + int numBytesReadJustNow = reader.Read(rentedArray, rentedArrayOffset, numBytesRemainingToReadThisChunk); + rentedArrayOffset += numBytesReadJustNow; + numBytesRemainingToReadThisChunk -= numBytesReadJustNow; + } while (numBytesRemainingToReadThisChunk > 0); + + // Copy the rented array to the destination span (projected as bytes). + // This projection as bytes is safe as long as T is a primitive numeric type (integers, floats). + // Avoid projecting the Span as a Span to avoid potential alignment issues. + rentedArray.AsSpan(0, numElementsToReadThisChunk * sizeof(T)).CopyTo(MemoryMarshal.AsBytes(destination.Slice(0, numElementsToReadThisChunk))); + destination = destination.Slice(numElementsToReadThisChunk); } - Contracts.Assert(offset == bytesToRead); } - } - - public static unsafe void ReadBytes(this BinaryReader reader, void* destination, long destinationSizeInBytes, long bytesToRead) - { - byte[] work = null; - ReadBytes(reader, destination, destinationSizeInBytes, bytesToRead, ref work); + finally + { + ArrayPool.Shared.Return(rentedArray); + } } /// From f370efc3e9b6cbe31e987f71fd369a723aa996b0 Mon Sep 17 00:00:00 2001 From: Levi Broderick Date: Mon, 17 Mar 2025 17:58:25 -0700 Subject: [PATCH 2/4] Remove unsafe code from ToByteArrayExtensions --- .../Utils/ToByteArrayExtensions.cs | 365 +++++++----------- 1 file changed, 146 insertions(+), 219 deletions(-) diff --git a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs index 3855b3a125..82852cd745 100644 --- a/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs +++ b/src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs @@ -4,6 +4,7 @@ using System; using System.Linq; +using System.Runtime.InteropServices; using System.Text; using Microsoft.ML.Internal.Utilities; @@ -47,13 +48,11 @@ public static int SizeInBytes(this short a) return sizeof(short); } - public static unsafe 
void ToByteArray(this short a, byte[] buffer, ref int position) + public static void ToByteArray(this short a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - short* pDest = (short*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = short. + // It writes machine-endian and handles unaligned byte buffers properly. + MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(short); } @@ -71,13 +70,11 @@ public static int SizeInBytes(this ushort a) return sizeof(ushort); } - public static unsafe void ToByteArray(this ushort a, byte[] buffer, ref int position) + public static void ToByteArray(this ushort a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - ushort* pDest = (ushort*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = ushort. + // It writes machine-endian and handles unaligned byte buffers properly. + MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(ushort); } @@ -95,24 +92,17 @@ public static int SizeInBytes(this int a) return sizeof(int); } - public static unsafe void ToByteArray(this int a, byte[] buffer, ref int position) + public static void ToByteArray(this int a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - int* pDest = (int*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = int. + // It writes machine-endian and handles unaligned byte buffers properly. 
+ MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(int); } - public static unsafe int ToInt(this byte[] buffer, ref int position) + public static int ToInt(this byte[] buffer, ref int position) { - int a; - fixed (byte* pBuffer = buffer) - { - int* pIntBuffer = (int*)(pBuffer + position); - a = *pIntBuffer; - } + int a = BitConverter.ToInt32(buffer, position); position += sizeof(int); return a; } @@ -124,24 +114,17 @@ public static int SizeInBytes(this uint a) return sizeof(uint); } - public static unsafe void ToByteArray(this uint a, byte[] buffer, ref int position) + public static void ToByteArray(this uint a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - uint* pDest = (uint*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = uint. + // It writes machine-endian and handles unaligned byte buffers properly. + MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(uint); } - public static unsafe uint ToUInt(this byte[] buffer, ref int position) + public static uint ToUInt(this byte[] buffer, ref int position) { - uint a; - fixed (byte* pBuffer = buffer) - { - uint* pIntBuffer = (uint*)(pBuffer + position); - a = *pIntBuffer; - } + uint a = BitConverter.ToUInt32(buffer, position); position += sizeof(uint); return a; } @@ -153,13 +136,11 @@ public static int SizeInBytes(this long a) return sizeof(long); } - public static unsafe void ToByteArray(this long a, byte[] buffer, ref int position) + public static void ToByteArray(this long a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - long* pDest = (long*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = long. + // It writes machine-endian and handles unaligned byte buffers properly. 
+ MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(long); } @@ -177,13 +158,11 @@ public static int SizeInBytes(this ulong a) return sizeof(ulong); } - public static unsafe void ToByteArray(this ulong a, byte[] buffer, ref int position) + public static void ToByteArray(this ulong a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - ulong* pDest = (ulong*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = ulong. + // It writes machine-endian and handles unaligned byte buffers properly. + MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(ulong); } @@ -201,13 +180,11 @@ public static int SizeInBytes(this float a) return sizeof(float); } - public static unsafe void ToByteArray(this float a, byte[] buffer, ref int position) + public static void ToByteArray(this float a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - float* pDest = (float*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = float. + // It writes machine-endian and handles unaligned byte buffers properly. + MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(float); } @@ -225,13 +202,11 @@ public static int SizeInBytes(this double a) return sizeof(double); } - public static unsafe void ToByteArray(this double a, byte[] buffer, ref int position) + public static void ToByteArray(this double a, byte[] buffer, ref int position) { - fixed (byte* pBuffer = buffer) - { - double* pDest = (double*)(pBuffer + position); - *pDest = a; - } + // Per docs, MemoryMarshal.Write<...> is safe for T: = double. + // It writes machine-endian and handles unaligned byte buffers properly. 
+ MemoryMarshal.Write(buffer.AsSpan(position), ref a); position += sizeof(double); } @@ -246,7 +221,7 @@ public static double ToDouble(this byte[] buffer, ref int position) public static int SizeInBytes(this string a) { - return sizeof(int) + Encoding.Unicode.GetByteCount(a); + return checked(sizeof(int) + Encoding.Unicode.GetByteCount(a)); } public static void ToByteArray(this string a, byte[] buffer, ref int position) @@ -279,7 +254,7 @@ public static string ToString(this byte[] buffer, ref int position) public static int SizeInBytes(this byte[] a) { - return sizeof(int) + Utils.Size(a) * sizeof(byte); + return checked(sizeof(int) + Utils.Size(a) * sizeof(byte)); } public static void ToByteArray(this byte[] a, byte[] buffer, ref int position) @@ -303,37 +278,31 @@ public static byte[] ToByteArray(this byte[] buffer, ref int position) public static int SizeInBytes(this short[] a) { - return sizeof(int) + Utils.Size(a) * sizeof(short); + return checked(sizeof(int) + Utils.Size(a) * sizeof(short)); } - public static unsafe void ToByteArray(this short[] a, byte[] buffer, ref int position) + public static void ToByteArray(this short[] a, byte[] buffer, ref int position) { int length = a.Length; length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (short* pA = a) - { - short* pBuffer = (short*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. 
+ MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(short); } - public static unsafe short[] ToShortArray(this byte[] buffer, ref int position) + public static short[] ToShortArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream + int byteLength = checked(length * sizeof(short)); // if this overflows, we couldn't have populated buffer anyway short[] a = new short[length]; - fixed (byte* tmpBuffer = buffer) - fixed (short* pA = a) - { - short* pBuffer = (short*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(short); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. + buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -342,37 +311,31 @@ public static unsafe short[] ToShortArray(this byte[] buffer, ref int position) public static int SizeInBytes(this ushort[] a) { - return sizeof(int) + Utils.Size(a) * sizeof(ushort); + return checked(sizeof(int) + Utils.Size(a) * sizeof(ushort)); } - public static unsafe void ToByteArray(this ushort[] a, byte[] buffer, ref int position) + public static void ToByteArray(this ushort[] a, byte[] buffer, ref int position) { int length = a.Length; length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (ushort* pA = a) - { - ushort* pBuffer = (ushort*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. 
+ MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(ushort); } - public static unsafe ushort[] ToUShortArray(this byte[] buffer, ref int position) + public static ushort[] ToUShortArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream + int byteLength = checked(length * sizeof(ushort)); // if this overflows, we couldn't have populated buffer anyway ushort[] a = new ushort[length]; - fixed (byte* tmpBuffer = buffer) - fixed (ushort* pA = a) - { - ushort* pBuffer = (ushort*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(ushort); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. + buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -381,42 +344,36 @@ public static unsafe ushort[] ToUShortArray(this byte[] buffer, ref int position public static int SizeInBytes(this int[] array) { - return sizeof(int) + Utils.Size(array) * sizeof(int); + return checked(sizeof(int) + Utils.Size(array) * sizeof(int)); } - public static unsafe void ToByteArray(this int[] a, byte[] buffer, ref int position) + public static void ToByteArray(this int[] a, byte[] buffer, ref int position) { int length = Utils.Size(a); length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (int* pA = a) - { - int* pBuffer = (int*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. 
+ MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(int); } - public static unsafe int[] ToIntArray(this byte[] buffer, ref int position) + public static int[] ToIntArray(this byte[] buffer, ref int position) => buffer.ToIntArray(ref position, buffer.ToInt(ref position)); - public static unsafe int[] ToIntArray(this byte[] buffer, ref int position, int length) + public static int[] ToIntArray(this byte[] buffer, ref int position, int length) { if (length == 0) return null; + int byteLength = checked(length * sizeof(int)); // if this overflows, we couldn't have populated buffer anyway int[] a = new int[length]; - fixed (byte* tmpBuffer = buffer) - fixed (int* pA = a) - { - int* pBuffer = (int*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(int); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. + buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -425,37 +382,31 @@ public static unsafe int[] ToIntArray(this byte[] buffer, ref int position, int public static int SizeInBytes(this uint[] array) { - return sizeof(int) + Utils.Size(array) * sizeof(uint); + return checked(sizeof(int) + Utils.Size(array) * sizeof(uint)); } - public static unsafe void ToByteArray(this uint[] a, byte[] buffer, ref int position) + public static void ToByteArray(this uint[] a, byte[] buffer, ref int position) { int length = a.Length; length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (uint* pA = a) - { - uint* pBuffer = (uint*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. 
In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. + MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(uint); } - public static unsafe uint[] ToUIntArray(this byte[] buffer, ref int position) + public static uint[] ToUIntArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream + int byteLength = checked(length * sizeof(uint)); // if this overflows, we couldn't have populated buffer anyway uint[] a = new uint[length]; - fixed (byte* tmpBuffer = buffer) - fixed (uint* pA = a) - { - uint* pBuffer = (uint*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(uint); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. + buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -464,37 +415,31 @@ public static unsafe uint[] ToUIntArray(this byte[] buffer, ref int position) public static int SizeInBytes(this long[] array) { - return sizeof(int) + Utils.Size(array) * sizeof(long); + return checked(sizeof(int) + Utils.Size(array) * sizeof(long)); } - public static unsafe void ToByteArray(this long[] a, byte[] buffer, ref int position) + public static void ToByteArray(this long[] a, byte[] buffer, ref int position) { int length = a.Length; length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (long* pA = a) - { - long* pBuffer = (long*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. 
In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. + MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(long); } - public static unsafe long[] ToLongArray(this byte[] buffer, ref int position) + public static long[] ToLongArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream + int byteLength = checked(length * sizeof(long)); // if this overflows, we couldn't have populated buffer anyway long[] a = new long[length]; - fixed (byte* tmpBuffer = buffer) - fixed (long* pA = a) - { - long* pBuffer = (long*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(long); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. + buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -503,37 +448,31 @@ public static unsafe long[] ToLongArray(this byte[] buffer, ref int position) public static int SizeInBytes(this ulong[] array) { - return sizeof(int) + Utils.Size(array) * sizeof(ulong); + return checked(sizeof(int) + Utils.Size(array) * sizeof(ulong)); } - public static unsafe void ToByteArray(this ulong[] a, byte[] buffer, ref int position) + public static void ToByteArray(this ulong[] a, byte[] buffer, ref int position) { int length = a.Length; length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (ulong* pA = a) - { - ulong* pBuffer = (ulong*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. 
In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. + MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(ulong); } - public static unsafe ulong[] ToULongArray(this byte[] buffer, ref int position) + public static ulong[] ToULongArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream + int byteLength = checked(length * sizeof(ulong)); // if this overflows, we couldn't have populated buffer anyway ulong[] a = new ulong[length]; - fixed (byte* tmpBuffer = buffer) - fixed (ulong* pA = a) - { - ulong* pBuffer = (ulong*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(ulong); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. + buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -542,37 +481,31 @@ public static unsafe ulong[] ToULongArray(this byte[] buffer, ref int position) public static int SizeInBytes(this float[] array) { - return sizeof(int) + Utils.Size(array) * sizeof(float); + return checked(sizeof(int) + Utils.Size(array) * sizeof(float)); } - public static unsafe void ToByteArray(this float[] a, byte[] buffer, ref int position) + public static void ToByteArray(this float[] a, byte[] buffer, ref int position) { int length = a.Length; length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (float* pA = a) - { - float* pBuffer = (float*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. 
In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. + MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(float); } - public static unsafe float[] ToFloatArray(this byte[] buffer, ref int position) + public static float[] ToFloatArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream + int byteLength = checked(length * sizeof(float)); // if this overflows, we couldn't have populated buffer anyway float[] a = new float[length]; - fixed (byte* tmpBuffer = buffer) - fixed (float* pA = a) - { - float* pBuffer = (float*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(float); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. + buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -581,37 +514,31 @@ public static unsafe float[] ToFloatArray(this byte[] buffer, ref int position) public static int SizeInBytes(this double[] array) { - return sizeof(int) + Utils.Size(array) * sizeof(double); + return checked(sizeof(int) + Utils.Size(array) * sizeof(double)); } - public static unsafe void ToByteArray(this double[] a, byte[] buffer, ref int position) + public static void ToByteArray(this double[] a, byte[] buffer, ref int position) { int length = a.Length; length.ToByteArray(buffer, ref position); - fixed (byte* tmpBuffer = buffer) - fixed (double* pA = a) - { - double* pBuffer = (double*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pBuffer[i] = pA[i]; - } + // MemoryMarshal.AsBytes is type-safe but could fail if the source buffer is so long + // that its byte length can't be represented as an int32. 
In this case, we're ok with + // AsBytes throwing an exception early, since we know the length of our destination byte + // buffer is limited to an int32 length anyway. + MemoryMarshal.AsBytes(a.AsSpan()).CopyTo(buffer.AsSpan(position)); position += length * sizeof(double); } - public static unsafe double[] ToDoubleArray(this byte[] buffer, ref int position) + public static double[] ToDoubleArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream + int byteLength = checked(length * sizeof(double)); // if this overflows, we couldn't have populated buffer anyway double[] a = new double[length]; - fixed (byte* tmpBuffer = buffer) - fixed (double* pA = a) - { - double* pBuffer = (double*)(tmpBuffer + position); - for (int i = 0; i < length; ++i) - pA[i] = pBuffer[i]; - } - position += length * sizeof(double); + // MemoryMarshal.AsBytes is type-safe. The checked block above prevents failure here. 
+ buffer.AsSpan(position, byteLength).CopyTo(MemoryMarshal.AsBytes(a.AsSpan())); + position += byteLength; return a; } @@ -622,7 +549,7 @@ public static int SizeInBytes(this double[][] array) { if (Utils.Size(array) == 0) return sizeof(int); - return sizeof(int) + array.Sum(x => x.SizeInBytes()); + return checked(sizeof(int) + array.Sum(x => x.SizeInBytes())); } public static void ToByteArray(this double[][] a, byte[] buffer, ref int position) @@ -636,7 +563,7 @@ public static void ToByteArray(this double[][] a, byte[] buffer, ref int positio public static double[][] ToDoubleJaggedArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream double[][] a = new double[length][]; for (int i = 0; i < a.Length; ++i) { @@ -668,7 +595,7 @@ public static void ToByteArray(this string[] a, byte[] buffer, ref int position) public static string[] ToStringArray(this byte[] buffer, ref int position) { - int length = buffer.ToInt(ref position); + int length = buffer.ToInt(ref position); // reading trusted length from input stream string[] a = new string[length]; for (int i = 0; i < a.Length; ++i) { From a88f36fcaabb63351b712eb1206412cd70234671 Mon Sep 17 00:00:00 2001 From: Levi Broderick Date: Mon, 17 Mar 2025 18:17:40 -0700 Subject: [PATCH 3/4] Remove unsafe code in VectorUtils - Only where we know for a fact the JIT will elide bounds checks --- .../Utils/VectorUtils.cs | 148 +++++------------- 1 file changed, 39 insertions(+), 109 deletions(-) diff --git a/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs b/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs index 6685f61f17..c192ee2e42 100644 --- a/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs +++ b/src/Microsoft.ML.FastTree/Utils/VectorUtils.cs @@ -18,41 +18,27 @@ public static double GetVectorSize(double[] vector) } // Normalizes the vector to have size of 1 - public static unsafe void NormalizeVectorSize(double[] 
vector) + public static void NormalizeVectorSize(double[] vector) { double size = GetVectorSize(vector); - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (double* pVector = vector) - { - for (int i = 0; i < length; i++) - { - pVector[i] /= size; - } - } + vector[i] /= size; } } // Center vector to have mean = 0 - public static unsafe void CenterVector(double[] vector) + public static void CenterVector(double[] vector) { double mean = GetMean(vector); - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (double* pVector = vector) - { - for (int i = 0; i < length; i++) - { - pVector[i] = (pVector[i] - mean); - } - } + vector[i] -= mean; } } // Normalizes the vector to have mean = 0 and std = 1 - public static unsafe void NormalizeVector(double[] vector) + public static void NormalizeVector(double[] vector) { double mean = GetMean(vector); double std = GetStandardDeviation(vector, mean); @@ -60,27 +46,20 @@ public static unsafe void NormalizeVector(double[] vector) } // Normalizes the vector to have mean = 0 and std = 1 - public static unsafe void NormalizeVector(double[] vector, double mean, double std) + public static void NormalizeVector(double[] vector, double mean, double std) { - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (double* pVector = vector) - { - for (int i = 0; i < length; i++) - { - pVector[i] = (pVector[i] - mean) / std; - } - } + vector[i] = (vector[i] - mean) / std; } } - public static unsafe double GetDotProduct(double[] vector1, double[] vector2) + public static double GetDotProduct(double[] vector1, double[] vector2) { return GetDotProduct(vector1, vector2, vector1.Length); } - public static unsafe double GetDotProduct(float[] vector1, float[] vector2) + public static double GetDotProduct(float[] vector1, float[] vector2) { return GetDotProduct(vector1, vector2, vector1.Length); } @@ -119,38 +98,24 @@ public static unsafe 
double GetDotProduct(float[] vector1, float[] vector2, int return product; } - public static unsafe double GetMean(double[] vector) + public static double GetMean(double[] vector) { double sum = 0; - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (double* pVector = vector) - { - for (int i = 0; i < length; i++) - { - sum += pVector[i]; - } - } + sum += vector[i]; } - return sum / length; + return sum / vector.Length; } - public static unsafe double GetMean(float[] vector) + public static double GetMean(float[] vector) { double sum = 0; - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (float* pVector = vector) - { - for (int i = 0; i < length; i++) - { - sum += pVector[i]; - } - } + sum += vector[i]; } - return sum / length; + return sum / vector.Length; } public static double GetStandardDeviation(double[] vector) @@ -158,42 +123,28 @@ public static double GetStandardDeviation(double[] vector) return GetStandardDeviation(vector, GetMean(vector)); } - public static unsafe double GetStandardDeviation(double[] vector, double mean) + public static double GetStandardDeviation(double[] vector, double mean) { double sum = 0; - int length = vector.Length; double tmp; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (double* pVector = vector) - { - for (int i = 0; i < length; i++) - { - tmp = pVector[i] - mean; - sum += tmp * tmp; - } - } + tmp = vector[i] - mean; + sum += tmp * tmp; } - return Math.Sqrt(sum / length); + return Math.Sqrt(sum / vector.Length); } - public static unsafe int GetIndexOfMax(double[] vector) + public static int GetIndexOfMax(double[] vector) { - int length = vector.Length; double max = vector[0]; int maxIdx = 0; - unsafe + for (int i = 1; i < vector.Length; i++) { - fixed (double* pVector = vector) + if (vector[i] > max) { - for (int i = 1; i < length; i++) - { - if (pVector[i] > max) - { - max = pVector[i]; - maxIdx = i; - } - } + max = vector[i]; + 
maxIdx = i; } } return maxIdx; @@ -253,50 +204,29 @@ public static unsafe void AddInPlace(double[] vector1, double[] vector2) } // Mutiplies the second vector from the first one (vector1[i] /= val) - public static unsafe void MutiplyInPlace(double[] vector, double val) + public static void MutiplyInPlace(double[] vector, double val) { - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (double* pVector = vector) - { - for (int i = 0; i < length; i++) - { - pVector[i] *= val; - } - } + vector[i] *= val; } } // Divides the second vector from the first one (vector1[i] /= val) - public static unsafe void DivideInPlace(double[] vector, double val) + public static void DivideInPlace(double[] vector, double val) { - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (double* pVector = vector) - { - for (int i = 0; i < length; i++) - { - pVector[i] /= val; - } - } + vector[i] /= val; } } // Divides the second vector from the first one (vector1[i] /= val) - public static unsafe void DivideInPlace(float[] vector, float val) + public static void DivideInPlace(float[] vector, float val) { - int length = vector.Length; - unsafe + for (int i = 0; i < vector.Length; i++) { - fixed (float* pVector = vector) - { - for (int i = 0; i < length; i++) - { - pVector[i] /= val; - } - } + vector[i] /= val; } } From 4c46e4f7004c774e49b55b01908e787c4ac99416 Mon Sep 17 00:00:00 2001 From: Levi Broderick Date: Wed, 19 Mar 2025 13:17:49 -0700 Subject: [PATCH 4/4] Simplify array rental logic in Stream.cs --- src/Microsoft.ML.Core/Utilities/Stream.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index e94ae91f76..51dd247bbe 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -695,7 +695,7 @@ private static unsafe void ReadBinaryDataIntoSpan(BinaryReader reader, 
Span 0, "Unexpectedly large T."); // Rent a byte[] instead of a T[] to allow reuse of buffers across different types T. - byte[] rentedArray = ArrayPool<byte>.Shared.Rent(maxChunkSizeInElements * sizeof(T)); + byte[] rentedArray = ArrayPool<byte>.Shared.Rent(maxChunkSizeInBytes); try { while (!destination.IsEmpty)