diff --git a/cmake/vcpkg-ports/pybind11/portfile.cmake b/cmake/vcpkg-ports/pybind11/portfile.cmake index 4e4cd30a26df1..e0dd402cf186b 100644 --- a/cmake/vcpkg-ports/pybind11/portfile.cmake +++ b/cmake/vcpkg-ports/pybind11/portfile.cmake @@ -1,10 +1,15 @@ -vcpkg_from_github( - OUT_SOURCE_PATH SOURCE_PATH - REPO pybind/pybind11 - REF "v${VERSION}" - # SHA512 for the zip (not tar.gz) file. +# Manually define the download for the .zip archive (to be consistent with deps.txt) +# If we used vcpkg_from_github, it would download the .tar.gz archive, +# which has different SHA512: 19bee2c76320e25202ee078b5680ff8a7acfb33494dec29dad984ab04de8bcb01340d9fec37c8cc5ac9015dfc367e60312dcd8506e66ce8f0af4c49db562ddef +vcpkg_download_distfile(ARCHIVE + URLS "https://github.com/pybind/pybind11/archive/refs/tags/v${VERSION}.zip" + FILENAME "pybind11-${VERSION}.zip" SHA512 786b1bf534ac67a8d5669f8babf67bb13e48b3a3da1b6344e43ae10a84b80bbc8fea5f12a65fd18739c341fefef5622c5dc096db964dff33cc62ea4259b2e2c1 - HEAD_REF master +) + +vcpkg_extract_source_archive( + SOURCE_PATH + ARCHIVE "${ARCHIVE}" ) vcpkg_cmake_configure( diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index a6b267c6802cf..81d4f2589151b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -476,9 +476,25 @@ internal static class NativeMethods static NativeMethods() { #if !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__ - // Register a custom DllImportResolver to handle platform-specific library loading. - // Replaces default resolution specifically on Windows for case-sensitivity. - NativeLibrary.SetDllImportResolver(typeof(NativeMethods).Assembly, DllImportResolver); + if (!OrtEnv.DisableDllImportResolver) + { + try + { + // Register a custom DllImportResolver to handle platform-specific library loading. + // Replaces default resolution specifically on Windows for case-sensitivity. + NativeLibrary.SetDllImportResolver(typeof(NativeMethods).Assembly, DllImportResolver); + } + catch (InvalidOperationException) + { + // A resolver is already registered for this assembly (e.g., by the host application). + // This is not fatal — the host's resolver will handle library loading. + System.Diagnostics.Trace.WriteLine( + "[OnnxRuntime] A DllImportResolver is already registered for this assembly. " + + "OnnxRuntime's built-in resolver will not be used. " + + "To suppress this message, set OrtEnv.DisableDllImportResolver = true " + + "before using any OnnxRuntime APIs."); + } + } #endif #if NETSTANDARD2_0 diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 6fcff438c5cf3..22f541e2207fa 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.OnnxRuntime /// /// /// This enum is used to determine whether a pre-compiled model can be used with specific execution providers - /// and devices, or if recompilation is needed. + /// and devices, or if recompilation is needed. /// public enum OrtCompiledModelCompatibility { @@ -77,14 +77,14 @@ public struct EnvironmentCreationOptions /// /// The singleton class OrtEnv contains the process-global ONNX Runtime environment. /// It sets up logging, creates system wide thread-pools (if Thread Pool options are provided) - /// and other necessary things for OnnxRuntime to function. - /// + /// and other necessary things for OnnxRuntime to function. + /// /// Create or access OrtEnv by calling the Instance() method. Instance() can be called multiple times. /// It would return the same instance. - /// + /// /// CreateInstanceWithOptions() provides a way to create environment with options. /// It must be called once before Instance() is called, otherwise it would not have effect. - /// + /// /// If the environment is not explicitly created, it will be created as needed, e.g., /// when creating a SessionOptions instance. /// @@ -93,6 +93,28 @@ public sealed class OrtEnv : SafeHandle #region Static members private static readonly int ORT_PROJECTION_CSHARP = 2; + /// + /// Set this to true before accessing any OnnxRuntime type to prevent OnnxRuntime + /// from registering its own DllImportResolver via + /// NativeLibrary.SetDllImportResolver. + /// This is useful when the host application needs to register its own custom resolver + /// for the OnnxRuntime assembly. Must be set before any OnnxRuntime API is used + /// (i.e., before the internal NativeMethods static constructor runs). + /// + /// + /// + /// // Disable OnnxRuntime's built-in resolver before any ORT usage + /// OrtEnv.DisableDllImportResolver = true; + /// + /// // Register your own resolver + /// NativeLibrary.SetDllImportResolver(typeof(OrtEnv).Assembly, MyCustomResolver); + /// + /// // Now use OnnxRuntime normally + /// var env = OrtEnv.Instance(); + /// + /// + public static bool DisableDllImportResolver { get; set; } = false; + private static readonly byte[] _defaultLogId = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(@"CSharpOnnxRuntime"); // This must be static and set before the first creation call, otherwise, has no effect. @@ -274,7 +296,7 @@ private static void SetLanguageProjection(OrtEnv env) /// /// Instantiates (if not already done so) a new OrtEnv instance with the default logging level /// and no other options. Otherwise returns the existing instance. - /// + /// /// It returns the same instance on every call - `OrtEnv` is singleton /// /// Returns a singleton instance of OrtEnv that represents native OrtEnv object @@ -523,7 +545,7 @@ public OrtLoggingLevel EnvLogLevel /// A registered execution provider library can be used by all sessions created with the OrtEnv instance. /// Devices the execution provider can utilize are added to the values returned by GetEpDevices() and can /// be used in SessionOptions.AppendExecutionProvider to select an execution provider for a device. - /// + /// /// Coming: A selection policy can be specified and ORT will automatically select the best execution providers /// and devices for the model. /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs index aa1b683acd668..c298a95392317 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs @@ -532,4 +532,100 @@ public void TestDllImportResolverDoesNotThrow() } } } + +#if !NETSTANDARD2_0 + [Collection("Ort Inference Tests")] + public class OrtEnvExternalDllImportResolverTest + { + private System.Reflection.Assembly LoadIsolatedOnnxRuntimeAssembly(out System.Runtime.Loader.AssemblyLoadContext alc) + { + // Load a fresh copy of the ONNX Runtime assembly into a new AssemblyLoadContext. + // This guarantees we get a clean slate for static fields/constructors, avoiding + // interference from other xUnit tests that may have already initialized OrtEnv + // in the default context. + // + // Native library resolution (e.g., onnxruntime.dll) falls through to the default + // ALC when the isolated context cannot resolve it, so P/Invoke calls still work. + alc = new System.Runtime.Loader.AssemblyLoadContext("IsolatedORT_" + Guid.NewGuid(), isCollectible: true); + string asmPath = typeof(OrtEnv).Assembly.Location; + return alc.LoadFromAssemblyPath(asmPath); + } + + /// + /// Verifies the scenario where an external caller registers a DllImportResolver FIRST, + /// and then OrtEnv is initialized. ORT's try/catch should handle the conflict gracefully. + /// + [Fact(DisplayName = "TestExternalResolverRegisteredFirst")] + public void TestExternalResolverRegisteredFirst() + { + var asm = LoadIsolatedOnnxRuntimeAssembly(out var alc); + try + { + // 1. External application registers its own resolver FIRST. + // Returning IntPtr.Zero means "not handled" — the runtime falls back to + // its default resolution logic, so native libraries still load normally. + NativeLibrary.SetDllImportResolver(asm, (libraryName, a, searchPath) => IntPtr.Zero); + + // 2. ORT initializes (triggers NativeMethods static constructor). + // It will attempt to register its own resolver, which will throw + // InvalidOperationException internally, but the try/catch safety net + // prevents an unhandled TypeInitializationException. + var ortEnvType = asm.GetType("Microsoft.ML.OnnxRuntime.OrtEnv"); + var instanceMethod = ortEnvType.GetMethod("Instance", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static); + var ortEnvInstance = instanceMethod.Invoke(null, null); + Assert.NotNull(ortEnvInstance); + + // Verify ORT is fully functional despite the resolver conflict. + var getVersionMethod = ortEnvType.GetMethod("GetVersionString", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance); + var version = (string)getVersionMethod.Invoke(ortEnvInstance, null); + Assert.False(string.IsNullOrEmpty(version)); + } + finally + { + alc.Unload(); + } + } + + /// + /// Verifies that setting DisableDllImportResolver = true BEFORE ORT initializes + /// successfully prevents ORT from registering its own resolver, leaving the assembly + /// free for the external application to register theirs LATER without throwing. + /// + [Fact(DisplayName = "TestDisableDllImportResolverWorks")] + public void TestDisableDllImportResolverWorks() + { + var asm = LoadIsolatedOnnxRuntimeAssembly(out var alc); + try + { + var ortEnvType = asm.GetType("Microsoft.ML.OnnxRuntime.OrtEnv"); + + // 1. Set OrtEnv.DisableDllImportResolver = true FIRST. + var disableProp = ortEnvType.GetProperty("DisableDllImportResolver", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static); + Assert.NotNull(disableProp); + disableProp.SetValue(null, true); + + // 2. ORT initializes (triggers NativeMethods static constructor). + // It should respect the flag and SKIP calling NativeLibrary.SetDllImportResolver. + var instanceMethod = ortEnvType.GetMethod("Instance", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static); + var ortEnvInstance = instanceMethod.Invoke(null, null); + Assert.NotNull(ortEnvInstance); + + // 3. External application registers its own resolver AFTER ORT initialized. + // If the flag works correctly, ORT skipped its own SetDllImportResolver call, + // so this registration should succeed without throwing InvalidOperationException. + // Returning IntPtr.Zero means "not handled" — falls back to default resolution. + var ex = Record.Exception(() => + { + NativeLibrary.SetDllImportResolver(asm, (libraryName, a, searchPath) => IntPtr.Zero); + }); + + Assert.Null(ex); // No InvalidOperationException = ORT correctly skipped registration + } + finally + { + alc.Unload(); + } + } + } +#endif } diff --git a/include/onnxruntime/core/framework/int2.h b/include/onnxruntime/core/framework/int2.h index 0d406d6fcd8d3..40af5746c9273 100644 --- a/include/onnxruntime/core/framework/int2.h +++ b/include/onnxruntime/core/framework/int2.h @@ -7,6 +7,7 @@ #include #include "core/common/common.h" #include +#include "onnxruntime_config.h" namespace onnxruntime { @@ -137,8 +138,16 @@ struct Int2x4Base { const size_t full_quads = src.size() / 4; // Process complete groups of 4 elements + for (; dst_i < full_quads; dst_i++) { +#if defined(__GNUC__) && defined(HAS_ARRAY_BOUNDS) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif dst[dst_i] = Int2x4Base(src[src_i], src[src_i + 1], src[src_i + 2], src[src_i + 3]); +#if defined(__GNUC__) && defined(HAS_ARRAY_BOUNDS) +#pragma GCC diagnostic pop +#endif src_i += 4; } diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 1ea147a0079cc..abeb930d8ab8d 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -13,7 +13,7 @@ * The maximum length of the Config Key is 1024 * * The string format of a SessionOptions Config Value is defined individually for each Config. - * The maximum length of the Config Value is 2048 + * The maximum length of the Config Value is 8192 */ // Key for disable PrePacking, @@ -385,6 +385,11 @@ static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm"; // If not provided, default is 4. static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; +// Enable the DQ->MatMulNBits fusion graph transformer. +// "0": disabled (default). "1": enabled. +// This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered. +static const char* const kOrtSessionOptionsEnableDQMatMulNBitsFusion = "session.enable_dq_matmulnbits_fusion"; + // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME // Meant to be used with SetEpDynamicOptions // Specify the type of workload for this session. diff --git a/onnxruntime/core/framework/config_options.h b/onnxruntime/core/framework/config_options.h index 1c356d8cfca56..6c0d6741dd92f 100644 --- a/onnxruntime/core/framework/config_options.h +++ b/onnxruntime/core/framework/config_options.h @@ -18,7 +18,7 @@ struct ConfigOptions { // Maximum key/value string lengths specified in // core/session/onnxruntime_session_options_config_keys.h static constexpr size_t kMaxKeyLength = 1024; - static constexpr size_t kMaxValueLength = 4096; + static constexpr size_t kMaxValueLength = 8192; std::unordered_map configurations; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index a745dd9f1376d..77aa460c72ef7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -1670,12 +1670,12 @@ MlasQ4Int8TileGemmKernelBlkLen32Avx2( // .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h:1531:13: note: array 'acc' declared here // 1531 | __m256 acc[NCols4]; // | ^ -#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS) +#ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Warray-bounds" #endif __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); -#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS) +#ifdef __clang__ #pragma clang diagnostic pop #endif if (BiasPtr != nullptr) { diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc new file mode 100644 index 0000000000000..f9ae13808cf2c --- /dev/null +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc @@ -0,0 +1,848 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/dq_matmulnbits_fusion.h" + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/constants.h" +#include "core/graph/graph_utils.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +#include +#include +#include +#include + +namespace onnxruntime { + +namespace { + +// --------------------------------------------------------------------------- +// Utility helpers +// --------------------------------------------------------------------------- + +bool IsUniformPackedUint4Value(const Initializer& init, uint8_t expected_nibble) { + if (init.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + return false; + } + + const size_t values_count = static_cast(init.size()); + if (values_count == 0) { + return false; + } + + const auto packed = init.DataAsByteSpan(); + const uint8_t expected = static_cast(expected_nibble & 0x0F); + for (size_t i = 0; i < values_count; ++i) { + const uint8_t byte = packed[i / 2]; + const uint8_t value = (i % 2 == 0) ? (byte & 0x0F) : ((byte >> 4) & 0x0F); + if (value != expected) { + return false; + } + } + + return true; +} + +bool HasRank2Shape(const ONNX_NAMESPACE::TensorProto& tp, int64_t dim0, int64_t dim1) { + return tp.dims_size() == 2 && tp.dims(0) == dim0 && tp.dims(1) == dim1; +} + +uint8_t GetPackedUint4Element(const uint8_t* packed, size_t index, size_t num_elements) { + ORT_ENFORCE(index < num_elements, "GetPackedUint4Element: index ", index, + " out of bounds (num_elements=", num_elements, ")"); + const uint8_t packed_byte = packed[index / 2]; + return (index % 2 == 0) ? static_cast(packed_byte & 0x0F) + : static_cast((packed_byte >> 4) & 0x0F); +} + +void PackUint4Rows(const Initializer& src, int64_t rows, int64_t cols, uint8_t* dst) { + const int64_t row_bytes = (cols + 1) / 2; + const size_t dst_bytes = SafeInt(rows) * row_bytes; + const size_t total_elements = SafeInt(rows) * cols; + memset(dst, 0, dst_bytes); + + const auto src_packed = src.DataAsByteSpan(); + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + const size_t src_index = SafeInt(r) * cols + c; + const uint8_t value = GetPackedUint4Element(src_packed.data(), src_index, total_elements); + + const size_t dst_index = SafeInt(r) * row_bytes + c / 2; + if ((c & 1) == 0) { + dst[dst_index] = value; + } else { + dst[dst_index] = static_cast(dst[dst_index] | (value << 4)); + } + } + } +} + +// Transpose and pack UINT4 weights from DQ axis=0 layout [K, N] to MatMulNBits layout [N, k_blocks, blob_size]. +// Source: row-major UINT4 with quantization along K (axis=0), shape [K, N]. +// The nibble ordering follows ONNX UINT4 convention: even indices in the low nibble, +// odd indices in the high nibble of each byte. +// Dest: UINT8 [N, k_blocks, block_size/2] where each byte packs two 4-bit weights. +void TransposePackWeightsAxis0( + const uint8_t* src_packed, int64_t K, int64_t N, int64_t block_size, + uint8_t* dst) { + const int64_t k_blocks = (K + block_size - 1) / block_size; + const int64_t blob_size = block_size / 2; + const size_t dst_bytes = SafeInt(N) * k_blocks * blob_size; + const size_t total_elements = SafeInt(K) * N; + memset(dst, 0, dst_bytes); + + for (int64_t n = 0; n < N; ++n) { + for (int64_t k = 0; k < K; ++k) { + const size_t src_index = SafeInt(k) * N + n; + const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements); + + const int64_t kb = k / block_size; + const int64_t off = k % block_size; + const size_t dst_byte = SafeInt(n) * k_blocks * blob_size + kb * blob_size + off / 2; + if (off % 2 == 0) { + dst[dst_byte] = static_cast((dst[dst_byte] & 0xF0) | val); + } else { + dst[dst_byte] = static_cast((dst[dst_byte] & 0x0F) | (val << 4)); + } + } + } +} + +// Transpose and pack UINT4 zero points from DQ axis=0 layout [k_blocks, N] to +// MatMulNBits layout UINT8 [N, ceil(k_blocks/2)]. +void TransposePackZPAxis0( + const uint8_t* src_packed, int64_t k_blocks, int64_t N, + uint8_t* dst) { + const int64_t zp_bytes_per_n = (k_blocks + 1) / 2; + const size_t dst_bytes = SafeInt(N) * zp_bytes_per_n; + const size_t total_elements = SafeInt(k_blocks) * N; + memset(dst, 0, dst_bytes); + + for (int64_t n = 0; n < N; ++n) { + for (int64_t kb = 0; kb < k_blocks; ++kb) { + const size_t src_index = SafeInt(kb) * N + n; + const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements); + + const size_t dst_byte = SafeInt(n) * zp_bytes_per_n + kb / 2; + if (kb % 2 == 0) { + dst[dst_byte] = static_cast((dst[dst_byte] & 0xF0) | val); + } else { + dst[dst_byte] = static_cast((dst[dst_byte] & 0x0F) | (val << 4)); + } + } + } +} + +// --------------------------------------------------------------------------- +// Match structs +// --------------------------------------------------------------------------- + +struct FusionMatch { + NodeIndex matmul_idx; + std::optional cast_idx; + NodeIndex transpose_idx; + NodeIndex reshape_idx; + NodeIndex dq_idx; +}; + +struct DirectDQMatch { + NodeIndex matmul_idx; + NodeIndex dq_idx; +}; + +// --------------------------------------------------------------------------- +// Shared Gemm validation (alpha=1, beta=1, transA=0, transB=0, bias 1-D [N]) +// --------------------------------------------------------------------------- + +bool ValidateGemmForFusion(const Node& gemm_node, int64_t N) { + if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm_node, "alpha"); + alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f) + return false; + if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm_node, "beta"); + beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f) + return false; + if (const auto* trans_a = graph_utils::GetNodeAttribute(gemm_node, "transA"); + trans_a && trans_a->i() != 0) + return false; + if (const auto* trans_b = graph_utils::GetNodeAttribute(gemm_node, "transB"); + trans_b && trans_b->i() != 0) + return false; + + const auto& inputs = gemm_node.InputDefs(); + if (inputs.size() > 2 && inputs[2] && inputs[2]->Exists()) { + const auto* bias_shape = inputs[2]->Shape(); + if (!bias_shape || bias_shape->dim_size() != 1 || + !utils::HasDimValue(bias_shape->dim(0)) || + bias_shape->dim(0).dim_value() != N) + return false; + } + return true; +} + +// --------------------------------------------------------------------------- +// Pattern 1 matching: DQ -> Reshape -> Transpose -> [Cast] -> MatMul/Gemm +// --------------------------------------------------------------------------- + +std::vector CollectReshapeTransposeMatches( + Graph& graph, + const std::vector& node_topology_list, + const logging::Logger& logger) { + std::vector matches; + + for (auto node_index : node_topology_list) { + auto* node = graph.GetNode(node_index); + if (!node) continue; + + if (node->OpType() != "MatMul" && node->OpType() != "Gemm") continue; + + const auto& mm_inputs = node->InputDefs(); + if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue; + + const Node* cast_node = nullptr; + const Node* transpose_node = graph.GetProducerNode(mm_inputs[1]->Name()); + if (transpose_node && transpose_node->OpType() == "Cast") { + cast_node = transpose_node; + if (cast_node->GetOutputEdgesCount() != 1) continue; + const auto& cast_inputs = cast_node->InputDefs(); + if (cast_inputs.empty() || !cast_inputs[0] || !cast_inputs[0]->Exists()) continue; + transpose_node = graph.GetProducerNode(cast_inputs[0]->Name()); + } + + if (!transpose_node || transpose_node->OpType() != "Transpose") continue; + if (transpose_node->GetOutputEdgesCount() != 1) continue; + + const auto& tp_inputs = transpose_node->InputDefs(); + if (tp_inputs.empty() || !tp_inputs[0] || !tp_inputs[0]->Exists()) continue; + const Node* reshape_node = graph.GetProducerNode(tp_inputs[0]->Name()); + if (!reshape_node || reshape_node->OpType() != "Reshape") continue; + if (reshape_node->GetOutputEdgesCount() != 1) continue; + + const auto& reshape_inputs = reshape_node->InputDefs(); + if (reshape_inputs.empty() || !reshape_inputs[0] || !reshape_inputs[0]->Exists()) continue; + const Node* dq_node = graph.GetProducerNode(reshape_inputs[0]->Name()); + if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue; + if (dq_node->GetOutputEdgesCount() != 1) continue; + + const auto& dq_attrs = dq_node->GetAttributes(); + { + auto it = dq_attrs.find("axis"); + if (it == dq_attrs.end() || it->second.i() != 2) continue; + } + int64_t block_size = 0; + { + auto it = dq_attrs.find("block_size"); + if (it == dq_attrs.end()) continue; + block_size = it->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) continue; + } + + const auto* weight_arg = dq_node->InputDefs()[0]; + if (!weight_arg || !weight_arg->Exists()) continue; + const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true); + if (!weight_const_tp) continue; + if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (weight_const_tp->dims_size() != 3) continue; + const int64_t N = weight_const_tp->dims(0); + const int64_t blocks = weight_const_tp->dims(1); + const int64_t bs_dim = weight_const_tp->dims(2); + if (N <= 0 || blocks <= 0 || bs_dim <= 0) continue; + if (bs_dim != block_size) continue; + const int64_t K = SafeInt(blocks) * bs_dim; + + const auto* scale_arg = dq_node->InputDefs()[1]; + if (!scale_arg || !scale_arg->Exists()) continue; + const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); + if (!scale_const_tp) continue; + int32_t dt_scale = scale_const_tp->data_type(); + if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + + const auto* a_arg = mm_inputs[0]; + if (!a_arg || !a_arg->TypeAsProto()) continue; + int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_a != dt_scale) continue; + + const auto* reshape_shape_arg = + reshape_node->InputDefs().size() > 1 ? reshape_node->InputDefs()[1] : nullptr; + if (!reshape_shape_arg || !reshape_shape_arg->Exists()) continue; + const auto* reshape_shape_tp = graph.GetConstantInitializer(reshape_shape_arg->Name(), true); + if (!reshape_shape_tp) continue; + + Initializer reshape_shape_init(graph, *reshape_shape_tp, graph.ModelPath()); + if (reshape_shape_init.size() != 2) continue; + + int64_t reshape_dim0 = 0; + int64_t reshape_dim1 = 0; + if (reshape_shape_init.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) { + const auto* shape_data = reshape_shape_init.data(); + reshape_dim0 = shape_data[0]; + reshape_dim1 = shape_data[1]; + } else if (reshape_shape_init.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) { + const auto* shape_data = reshape_shape_init.data(); + reshape_dim0 = shape_data[0]; + reshape_dim1 = shape_data[1]; + } else { + continue; + } + + auto resolve_reshape_dim = [](int64_t dim, int64_t expected) -> std::optional { + if (dim == expected || dim == 0 || dim == -1) { + return expected; + } + return std::nullopt; + }; + const auto resolved_reshape_dim0 = resolve_reshape_dim(reshape_dim0, N); + const auto resolved_reshape_dim1 = resolve_reshape_dim(reshape_dim1, K); + if (!resolved_reshape_dim0 || !resolved_reshape_dim1 || + *resolved_reshape_dim0 != N || *resolved_reshape_dim1 != K) { + continue; + } + + if (const auto* perm_attr = graph_utils::GetNodeAttribute(*transpose_node, "perm")) { + if (perm_attr->ints_size() != 2 || perm_attr->ints(0) != 1 || perm_attr->ints(1) != 0) { + continue; + } + } + + if (const auto* b_shape = mm_inputs[1]->Shape(); b_shape && b_shape->dim_size() == 2 && + utils::HasDimValue(b_shape->dim(0)) && utils::HasDimValue(b_shape->dim(1)) && + (b_shape->dim(0).dim_value() != K || b_shape->dim(1).dim_value() != N)) { + continue; + } + + if (const auto* a_shape = mm_inputs[0] ? mm_inputs[0]->Shape() : nullptr; + a_shape && a_shape->dim_size() >= 1) { + const int last_a_dim_idx = a_shape->dim_size() - 1; + if (utils::HasDimValue(a_shape->dim(last_a_dim_idx)) && + a_shape->dim(last_a_dim_idx).dim_value() != K) { + continue; + } + } + + const auto* y_shape = node->OutputDefs().empty() ? nullptr : node->OutputDefs()[0]->Shape(); + if (y_shape && y_shape->dim_size() >= 1) { + const int last_y_dim_idx = y_shape->dim_size() - 1; + if (utils::HasDimValue(y_shape->dim(last_y_dim_idx)) && + y_shape->dim(last_y_dim_idx).dim_value() != N) { + continue; + } + } + + if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; + + if (cast_node) { + const auto* cast_in = cast_node->InputDefs().empty() ? nullptr : cast_node->InputDefs()[0]; + const auto* cast_out = cast_node->OutputDefs().empty() ? nullptr : cast_node->OutputDefs()[0]; + if (!cast_in || !cast_out || !cast_in->TypeAsProto() || !cast_out->TypeAsProto()) continue; + if (cast_in->TypeAsProto()->tensor_type().elem_type() != + cast_out->TypeAsProto()->tensor_type().elem_type()) { + continue; + } + } + + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + if (has_zp) { + const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); + if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + } + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched pattern at MatMul node '" + << node->Name() << "'"; + + matches.push_back({node->Index(), + cast_node ? std::optional(cast_node->Index()) : std::nullopt, + transpose_node->Index(), + reshape_node->Index(), dq_node->Index()}); + } + + return matches; +} + +// --------------------------------------------------------------------------- +// Pattern 2 matching: direct DQ(axis=0, 2D UINT4) -> MatMul/Gemm +// --------------------------------------------------------------------------- + +std::vector CollectDirectDQMatches( + Graph& graph, + const std::vector& node_topology_list, + const std::unordered_set& skip_indices, + const logging::Logger& logger) { + std::vector direct_matches; + + for (auto node_index : node_topology_list) { + auto* node = graph.GetNode(node_index); + if (!node) continue; + + if (node->OpType() != "MatMul" && node->OpType() != "Gemm") continue; + if (skip_indices.count(node->Index())) continue; + + const auto& mm_inputs = node->InputDefs(); + if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue; + + const Node* dq_node = graph.GetProducerNode(mm_inputs[1]->Name()); + if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue; + if (dq_node->GetOutputEdgesCount() != 1) continue; + + const auto& dq_attrs = dq_node->GetAttributes(); + { + auto it = dq_attrs.find("axis"); + if (it == dq_attrs.end() || it->second.i() != 0) continue; + } + int64_t block_size = 0; + { + auto it = dq_attrs.find("block_size"); + if (it == dq_attrs.end()) continue; + block_size = it->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) continue; + } + + const auto* weight_arg = dq_node->InputDefs()[0]; + if (!weight_arg || !weight_arg->Exists()) continue; + const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true); + if (!weight_const_tp) continue; + if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (weight_const_tp->dims_size() != 2) continue; + const int64_t K = weight_const_tp->dims(0); + const int64_t N = weight_const_tp->dims(1); + if (K <= 0 || N <= 0 || K % block_size != 0) continue; + const int64_t k_blocks = K / block_size; + + const auto* scale_arg = dq_node->InputDefs()[1]; + if (!scale_arg || !scale_arg->Exists()) continue; + const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true); + if (!scale_const_tp) continue; + int32_t dt_scale = scale_const_tp->data_type(); + if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + if (!HasRank2Shape(*scale_const_tp, k_blocks, N)) continue; + + const auto* a_arg = mm_inputs[0]; + if (!a_arg || !a_arg->TypeAsProto()) continue; + int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_a != dt_scale) continue; + + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + if (has_zp) { + const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true); + if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (!HasRank2Shape(*zp_const_tp, k_blocks, N)) continue; + } + + if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue; + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched direct DQ->MatMul pattern at node '" + << node->Name() << "' (K=" << K << ", N=" << N << ", block_size=" << block_size << ")"; + direct_matches.push_back({node->Index(), dq_node->Index()}); + } + + return direct_matches; +} + +// --------------------------------------------------------------------------- +// Pattern 1 rewriting: DQ+Reshape+Transpose+[Cast]+MatMul/Gemm -> MatMulNBits +// --------------------------------------------------------------------------- + +void ApplyReshapeTransposeFusions( + Graph& graph, + const std::vector& matches, + int64_t accuracy_level, + bool& modified, + const logging::Logger& logger) { + for (const auto& match : matches) { + const Node* mm_node = graph.GetNode(match.matmul_idx); + const Node* cast_node = match.cast_idx ? graph.GetNode(*match.cast_idx) : nullptr; + const Node* tp_node = graph.GetNode(match.transpose_idx); + const Node* dq_node = graph.GetNode(match.dq_idx); + const Node* reshape_node = graph.GetNode(match.reshape_idx); + if (!mm_node || !tp_node || !dq_node || !reshape_node || + (match.cast_idx && !cast_node)) { + continue; + } + + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + + const auto& dq_attrs = dq_node->GetAttributes(); + const int64_t block_size = dq_attrs.at("block_size").i(); + + const ONNX_NAMESPACE::TensorProto* weight_tp = nullptr; + if (!graph.GetInitializedTensor(weight_arg->Name(), weight_tp) || !weight_tp) continue; + const ONNX_NAMESPACE::TensorProto* scale_tp = nullptr; + if (!graph.GetInitializedTensor(scale_arg->Name(), scale_tp) || !scale_tp) continue; + const ONNX_NAMESPACE::TensorProto* zp_tp = nullptr; + if (has_zp) { + if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; + } + + if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || + weight_tp->dims_size() != 3) { + continue; + } + + const int64_t N = weight_tp->dims(0); + const int64_t quant_num = weight_tp->dims(1); + const int64_t bs_dim = weight_tp->dims(2); + if (N <= 0 || quant_num <= 0 || bs_dim <= 0 || bs_dim != block_size) continue; + const int64_t K = SafeInt(quant_num) * bs_dim; + const int64_t blob_bytes = (block_size + 1) / 2; + + Initializer weight_src(graph, *weight_tp, graph.ModelPath()); + Initializer scale_src(graph, *scale_tp, graph.ModelPath()); + if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum( + ONNX_NAMESPACE::TensorProto_DataType_UINT8) + ->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum( + scale_src.data_type()) + ->GetElementType(); + + auto cpu_allocator = CPUAllocator::DefaultInstance(); + + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb"); + auto weight_dst = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator); + + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb"); + const int64_t scale_size = (TensorShape{N, quant_num}).Size(); + if (scale_src.size() != static_cast(scale_size)) continue; + auto scale_dst = Tensor(scale_type, TensorShape{scale_size}, cpu_allocator); + + std::string zp_dst_name; + std::optional zp_dst; + const int64_t zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); + + bool elide_default_uint4_zp8_input = false; + std::optional zp_src; + + const auto weight_bytes = weight_src.DataAsByteSpan(); + if (weight_bytes.size() != static_cast(weight_dst.SizeInBytes())) continue; + memcpy(weight_dst.MutableDataRaw(), weight_bytes.data(), weight_bytes.size()); + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + memcpy(scale_dst.MutableData(), scale_src.data(), + static_cast(scale_size) * sizeof(float)); + } else { + memcpy(scale_dst.MutableData(), scale_src.data(), + static_cast(scale_size) * sizeof(MLFloat16)); + } + + if (zp_tp) { + zp_src.emplace(graph, *zp_tp, graph.ModelPath()); + if (zp_src->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (zp_src->size() != static_cast(N * quant_num)) continue; + + const bool is_default_uint4_8 = + IsUniformPackedUint4Value(*zp_src, /*expected_nibble*/ 8); + if (is_default_uint4_8) { + elide_default_uint4_zp8_input = true; + } else { + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + PackUint4Rows(*zp_src, N, quant_num, zp_dst->MutableData()); + } + } else { + // DequantizeLinear default zero-point for uint4 is 0, while MatMulNBits + // default is 8. Emit explicit zeros to preserve semantics. + zp_dst_name = graph.GenerateNodeArgName("fused_DQ_zp_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator); + memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); + } + + auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); + auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); + std::optional zp_mnb_tp; + if (zp_dst && !elide_default_uint4_zp8_input) { + zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); + } + + NodeAttributes mnb_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); + + std::vector mnb_inputs; + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); + if (zp_mnb_tp) { + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); + } + + // MatMulNBits input layout: 0:A, 1:B, 2:scales, 3:zero_points(opt), 4:g_idx(opt), 5:bias(opt) + bool fused_with_bias = false; + if (mm_node->OpType() == "Gemm" && + mm_node->InputDefs().size() > 2 && + mm_node->InputDefs()[2] && + mm_node->InputDefs()[2]->Exists()) { + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + while (mnb_inputs.size() < 5) { + mnb_inputs.push_back(&empty_arg); + } + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[2])); + fused_with_bias = true; + } + + std::vector mnb_outputs; + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + + auto& mnb_node = graph.AddNode( + graph.GenerateNodeName("DQFusedMatMulNBits"), + "MatMulNBits", + "Fused from DQ+Reshape+Transpose+MatMul", + mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); + mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); + graph.RemoveNode(match.matmul_idx); + + if (match.cast_idx && graph.GetNode(*match.cast_idx)) { + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(*match.cast_idx)); + graph.RemoveNode(*match.cast_idx); + } + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.transpose_idx)); + graph.RemoveNode(match.transpose_idx); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.reshape_idx)); + graph.RemoveNode(match.reshape_idx); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx)); + graph.RemoveNode(match.dq_idx); + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused DQ+Reshape+Transpose" + << (match.cast_idx ? "+Cast" : "") + << "+MatMul/Gemm -> MatMulNBits" + << (fused_with_bias ? " (bias preserved)" : "") + << (elide_default_uint4_zp8_input ? " (default UINT4 zp8 elided)" : ""); + modified = true; + } +} + +// --------------------------------------------------------------------------- +// Pattern 2 rewriting: direct DQ(axis=0) + MatMul/Gemm -> MatMulNBits +// --------------------------------------------------------------------------- + +void ApplyDirectDQFusions( + Graph& graph, + const std::vector& matches, + int64_t accuracy_level, + bool& modified, + const logging::Logger& logger) { + for (const auto& match : matches) { + const Node* mm_node = graph.GetNode(match.matmul_idx); + const Node* dq_node = graph.GetNode(match.dq_idx); + if (!mm_node || !dq_node) continue; + + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + bool has_zp = zp_arg && zp_arg->Exists(); + + const auto& dq_attrs = dq_node->GetAttributes(); + const int64_t block_size = dq_attrs.at("block_size").i(); + + const ONNX_NAMESPACE::TensorProto* weight_tp = nullptr; + if (!graph.GetInitializedTensor(weight_arg->Name(), weight_tp) || !weight_tp) continue; + const ONNX_NAMESPACE::TensorProto* scale_tp = nullptr; + if (!graph.GetInitializedTensor(scale_arg->Name(), scale_tp) || !scale_tp) continue; + const ONNX_NAMESPACE::TensorProto* zp_tp = nullptr; + if (has_zp) { + if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue; + } + + if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 || + weight_tp->dims_size() != 2) continue; + + const int64_t K = weight_tp->dims(0); + const int64_t N = weight_tp->dims(1); + if (K <= 0 || N <= 0 || block_size <= 0 || K % block_size != 0) continue; + const int64_t k_blocks = K / block_size; + const int64_t blob_bytes = block_size / 2; + if (!HasRank2Shape(*scale_tp, k_blocks, N)) continue; + if (zp_tp && !HasRank2Shape(*zp_tp, k_blocks, N)) continue; + + Initializer weight_src(graph, *weight_tp, graph.ModelPath()); + const size_t required_weight_bytes = SafeInt(N) * k_blocks * blob_bytes; + if (weight_src.DataAsByteSpan().size() < required_weight_bytes) continue; + Initializer scale_src(graph, *scale_tp, graph.ModelPath()); + if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue; + + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum( + ONNX_NAMESPACE::TensorProto_DataType_UINT8) + ->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum( + scale_src.data_type()) + ->GetElementType(); + auto cpu_allocator = CPUAllocator::DefaultInstance(); + + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb"); + auto weight_dst = Tensor(uint8_type, TensorShape{N, k_blocks, blob_bytes}, cpu_allocator); + TransposePackWeightsAxis0(weight_src.DataAsByteSpan().data(), K, N, block_size, + weight_dst.MutableData()); + + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb"); + const int64_t scale_count = SafeInt(N) * k_blocks; + if (scale_src.size() != static_cast(scale_count)) continue; + auto scale_dst = Tensor(scale_type, TensorShape{scale_count}, cpu_allocator); + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + const float* src = scale_src.data(); + float* dst = scale_dst.MutableData(); + for (int64_t n = 0; n < N; ++n) + for (int64_t kb = 0; kb < k_blocks; ++kb) + dst[n * k_blocks + kb] = src[kb * N + n]; + } else { + const MLFloat16* src = scale_src.data(); + MLFloat16* dst = scale_dst.MutableData(); + for (int64_t n = 0; n < N; ++n) + for (int64_t kb = 0; kb < k_blocks; ++kb) + dst[n * k_blocks + kb] = src[kb * N + n]; + } + + std::string zp_dst_name; + std::optional zp_dst; + const int64_t zp_bytes_total = SafeInt(N) * ((k_blocks + 1) / 2); + + bool elide_zp = false; + + if (zp_tp) { + Initializer zp_src(graph, *zp_tp, graph.ModelPath()); + if (zp_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue; + if (zp_src.size() != static_cast(k_blocks * N)) continue; + + if (IsUniformPackedUint4Value(zp_src, 8)) { + elide_zp = true; + } else { + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); + TransposePackZPAxis0(zp_src.DataAsByteSpan().data(), k_blocks, N, + zp_dst->MutableData()); + } + } else { + // DQ default ZP for UINT4 is 0, MatMulNBits default is 8. Emit explicit zeros. + zp_dst_name = graph.GenerateNodeArgName("direct_DQ_zp_mnb"); + zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator); + memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); + } + + auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); + auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); + std::optional zp_mnb_tp; + if (zp_dst && !elide_zp) { + zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); + } + + NodeAttributes mnb_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), mnb_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs); + + std::vector mnb_inputs; + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); + if (zp_mnb_tp) { + mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); + } + + bool fused_with_bias = false; + if (mm_node->OpType() == "Gemm" && + mm_node->InputDefs().size() > 2 && + mm_node->InputDefs()[2] && + mm_node->InputDefs()[2]->Exists()) { + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + while (mnb_inputs.size() < 5) { + mnb_inputs.push_back(&empty_arg); + } + mnb_inputs.push_back(const_cast(mm_node->InputDefs()[2])); + fused_with_bias = true; + } + + std::vector mnb_outputs; + mnb_outputs.push_back(const_cast(mm_node->OutputDefs()[0])); + + auto& mnb_node = graph.AddNode( + graph.GenerateNodeName("DirectDQFusedMatMulNBits"), + "MatMulNBits", + "Fused from direct DQ(axis=0)+MatMul", + mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain); + mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType()); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx)); + graph.RemoveNode(match.matmul_idx); + + graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx)); + graph.RemoveNode(match.dq_idx); + + LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused direct DQ(axis=0)+MatMul/Gemm -> MatMulNBits" + << " (K=" << K << ", N=" << N << ", block_size=" << block_size << ")" + << (fused_with_bias ? " (bias preserved)" : "") + << (elide_zp ? " (default UINT4 zp8 elided)" : ""); + modified = true; + } +} + +} // namespace + +// --------------------------------------------------------------------------- +// DQMatMulNBitsFusion public interface +// --------------------------------------------------------------------------- + +DQMatMulNBitsFusion::DQMatMulNBitsFusion( + int64_t accuracy_level, + const InlinedHashSet& compatible_eps) + : GraphTransformer("DQMatMulNBitsFusion", compatible_eps), + accuracy_level_(accuracy_level) { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, + "MatMulNBits accuracy level must be between 0 and 4"); +} + +Status DQMatMulNBitsFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node = graph.GetNode(node_index); + if (!node) continue; + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); + } + + auto matches = CollectReshapeTransposeMatches(graph, node_topology_list, logger); + + std::unordered_set matched_matmul_indices; + for (const auto& m : matches) { + matched_matmul_indices.insert(m.matmul_idx); + } + + auto direct_matches = CollectDirectDQMatches(graph, node_topology_list, + matched_matmul_indices, logger); + + ApplyReshapeTransposeFusions(graph, matches, accuracy_level_, modified, logger); + ApplyDirectDQFusions(graph, direct_matches, accuracy_level_, modified, logger); + + return Status::OK(); +} + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h new file mode 100644 index 0000000000000..97c0debd760c0 --- /dev/null +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +// Fuses DequantizeLinear chains back into a single MatMulNBits contrib op. +// +// Supported patterns: +// Pattern 1: DQ(3D, UINT4, axis=2) -> Reshape(2D) -> Transpose([1,0]) +// -> [optional Cast] -> MatMul/Gemm => MatMulNBits +// Pattern 2: DQ(2D, UINT4, axis=0) -> MatMul/Gemm => MatMulNBits +// +// These patterns are produced when a quantized model goes through external +// toolchains that lower MatMulNBits to DQ + reshape/transpose + MatMul +// primitives, and then re-import the graph into ORT. +class DQMatMulNBitsFusion : public GraphTransformer { + public: + explicit DQMatMulNBitsFusion( + int64_t accuracy_level = 4, + const InlinedHashSet& compatible_eps = {}); + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const override; + + int64_t accuracy_level_; +}; + +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index fdd4f5aa27862..4edabbe6058ab 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -18,6 +18,8 @@ #if !defined(ORT_MINIMAL_BUILD) +#include "core/optimizer/dq_matmulnbits_fusion.h" + #include "core/mlas/inc/mlas.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/bias_dropout_fusion.h" @@ -274,6 +276,26 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } +#if !defined(DISABLE_CONTRIB_OPS) + { + const bool enable_dq_matmulnbits_fusion = + session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "0") == "1"; + if (enable_dq_matmulnbits_fusion && !disable_quant_qdq) { + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); + transformers.emplace_back(std::make_unique( + qdq_matmulnbits_accuracy_level)); + } + } +#else + ORT_ENFORCE(session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "0") != "1", + "DQ->MatMulNBits fusion requires contrib ops but DISABLE_CONTRIB_OPS is defined"); +#endif + // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index bb9130add215a..63d985052f996 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -121,7 +121,7 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu } void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, - std::unordered_map duration_per_batch_size) const { + const std::unordered_map& duration_per_batch_size) const { ORT_UNUSED_PARAMETER(session_id); ORT_UNUSED_PARAMETER(total_runs_since_last); ORT_UNUSED_PARAMETER(total_run_duration_since_last); @@ -157,4 +157,35 @@ void Telemetry::LogProviderOptions(const std::string& provider_id, ORT_UNUSED_PARAMETER(captureState); } +void Telemetry::LogModelLoadStart(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); +} + +void Telemetry::LogModelLoadEnd(uint32_t session_id, const common::Status& status) const { + ORT_UNUSED_PARAMETER(session_id); + ORT_UNUSED_PARAMETER(status); +} + +void Telemetry::LogSessionCreationEnd(uint32_t session_id, + const common::Status& status) const { + ORT_UNUSED_PARAMETER(session_id); + ORT_UNUSED_PARAMETER(status); +} + +void Telemetry::LogRegisterEpLibraryWithLibPath(const std::string& registration_name, + const std::string& lib_path) const { + ORT_UNUSED_PARAMETER(registration_name); + ORT_UNUSED_PARAMETER(lib_path); +} + +void Telemetry::LogRegisterEpLibraryStart(const std::string& registration_name) const { + ORT_UNUSED_PARAMETER(registration_name); +} + +void Telemetry::LogRegisterEpLibraryEnd(const std::string& registration_name, + const common::Status& status) const { + ORT_UNUSED_PARAMETER(registration_name); + ORT_UNUSED_PARAMETER(status); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index e74d7ed0180fd..20a58e5a87184 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -85,7 +85,7 @@ class Telemetry { const char* function, uint32_t line) const; virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, - std::unordered_map duration_per_batch_size) const; + const std::unordered_map& duration_per_batch_size) const; virtual void LogExecutionProviderEvent(LUID* adapterLuid) const; @@ -101,6 +101,21 @@ class Telemetry { const std::string& provider_options_string, bool captureState) const; + virtual void LogModelLoadStart(uint32_t session_id) const; + + virtual void LogModelLoadEnd(uint32_t session_id, const common::Status& status) const; + + virtual void LogSessionCreationEnd(uint32_t session_id, + const common::Status& status) const; + + virtual void LogRegisterEpLibraryWithLibPath(const std::string& registration_name, + const std::string& lib_path) const; + + virtual void LogRegisterEpLibraryStart(const std::string& registration_name) const; + + virtual void LogRegisterEpLibraryEnd(const std::string& registration_name, + const common::Status& status) const; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry); }; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 6d5a400be703b..30c24de6a92ed 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -457,7 +457,8 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id, TraceLoggingInt32(graph_optimization_level, "graphOptimizationLevel"), TraceLoggingBool(embed_ep_context, "embedEpContext"), TraceLoggingBool(has_external_initializers_file, "hasExternalInitializersFile"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogCompileModelComplete(uint32_t session_id, @@ -480,7 +481,8 @@ void WindowsTelemetry::LogCompileModelComplete(uint32_t session_id, TraceLoggingBool(success, "success"), TraceLoggingUInt32(error_code, "errorCode"), TraceLoggingUInt32(error_category, "errorCategory"), - TraceLoggingString(error_message.c_str(), "errorMessage")); + TraceLoggingString(error_message.c_str(), "errorMessage"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, @@ -528,7 +530,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status } void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, - std::unordered_map duration_per_batch_size) const { + const std::unordered_map& duration_per_batch_size) const { if (global_register_count_ == 0 || enabled_ == false) return; @@ -668,4 +670,116 @@ void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const } } +void WindowsTelemetry::LogModelLoadStart(uint32_t session_id) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "ModelLoadStart", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + +void WindowsTelemetry::LogModelLoadEnd(uint32_t session_id, const common::Status& status) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "ModelLoadEnd", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingBool(status.IsOK(), "isSuccess"), + TraceLoggingUInt32(status.Code(), "errorCode"), + TraceLoggingUInt32(status.Category(), "errorCategory"), + TraceLoggingString(status.IsOK() ? "" : status.ErrorMessage().c_str(), "errorMessage"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + +void WindowsTelemetry::LogSessionCreationEnd(uint32_t session_id, + const common::Status& status) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "SessionCreationEnd", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingBool(status.IsOK(), "isSuccess"), + TraceLoggingUInt32(status.Code(), "errorCode"), + TraceLoggingUInt32(status.Category(), "errorCategory"), + TraceLoggingString(status.IsOK() ? "" : status.ErrorMessage().c_str(), "errorMessage"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + +void WindowsTelemetry::LogRegisterEpLibraryWithLibPath(const std::string& registration_name, + const std::string& lib_path) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "RegisterEpLibraryWithLibPath", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingString(registration_name.c_str(), "registrationName"), + TraceLoggingString(lib_path.c_str(), "libPath"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + +void WindowsTelemetry::LogRegisterEpLibraryStart(const std::string& registration_name) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "RegisterEpLibraryStart", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingString(registration_name.c_str(), "registrationName"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + +void WindowsTelemetry::LogRegisterEpLibraryEnd(const std::string& registration_name, + const common::Status& status) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "RegisterEpLibraryEnd", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingString(registration_name.c_str(), "registrationName"), + TraceLoggingBool(status.IsOK(), "isSuccess"), + TraceLoggingUInt32(status.Code(), "errorCode"), + TraceLoggingUInt32(status.Category(), "errorCategory"), + TraceLoggingString(status.IsOK() ? "" : status.ErrorMessage().c_str(), "errorMessage"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 591a248d70ab8..b46dcfbd3feb5 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -78,7 +78,7 @@ class WindowsTelemetry : public Telemetry { const char* function, uint32_t line) const override; void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, - std::unordered_map duration_per_batch_size) const override; + const std::unordered_map& duration_per_batch_size) const override; void LogExecutionProviderEvent(LUID* adapterLuid) const override; @@ -94,6 +94,21 @@ class WindowsTelemetry : public Telemetry { const std::string& provider_options_string, bool captureState) const override; + void LogModelLoadStart(uint32_t session_id) const override; + + void LogModelLoadEnd(uint32_t session_id, const common::Status& status) const override; + + void LogSessionCreationEnd(uint32_t session_id, + const common::Status& status) const override; + + void LogRegisterEpLibraryWithLibPath(const std::string& registration_name, + const std::string& lib_path) const override; + + void LogRegisterEpLibraryStart(const std::string& registration_name) const override; + + void LogRegisterEpLibraryEnd(const std::string& registration_name, + const common::Status& status) const override; + using EtwInternalCallback = std::function; diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc index d8c81e5cb63e5..6ecbfaa3993ca 100644 --- a/onnxruntime/core/providers/cpu/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cpu/object_detection/roialign.cc @@ -294,6 +294,41 @@ Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, cons return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "First dimension (num_rois) of batch_indices and rois don't match"); } + + // Validate batch_indices values are within [0, batch_size). + // Only check when the tensor data is accessible from the host (CPU). + // For GPU tensors (e.g. CUDA EP), Data() returns a device pointer + // that cannot be safely dereferenced on the host. A device-side bounds + // check for the CUDA path would require passing batch_size into the + // CUDA kernel — tracked as a follow-up. + if (batch_indices_ptr->Location().device.Type() == OrtDevice::CPU) { + const int64_t batch_size = X_ptr->Shape()[0]; + const int64_t num_rois = batch_indices_dims[0]; + + auto check_bounds = [batch_size, num_rois](const auto* batch_indices_data) -> Status { + for (int64_t i = 0; i < num_rois; ++i) { + if (batch_indices_data[i] < 0 || batch_indices_data[i] >= batch_size) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "batch_indices value " + std::to_string(batch_indices_data[i]) + + " at index " + std::to_string(i) + + " is out of range [0, " + std::to_string(batch_size) + ")"); + } + } + return Status::OK(); + }; + + if (batch_indices_ptr->IsDataType()) { + auto status = check_bounds(batch_indices_ptr->Data()); + if (!status.IsOK()) return status; + } else if (batch_indices_ptr->IsDataType()) { + auto status = check_bounds(batch_indices_ptr->Data()); + if (!status.IsOK()) return status; + } else { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "batch_indices must be of type int64_t or int32_t"); + } + } + return Status::OK(); } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index e415143a6ddd1..c567c220b3ce7 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -306,8 +306,8 @@ class NvExecutionProvider : public IExecutionProvider { const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* /* resource_accountant */) const override; - int GetDeviceId() const { return device_id_; } - Status Sync() const; + int GetDeviceId() const override { return device_id_; } + Status Sync() const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc index 90e488a1eda18..a7c37cd481894 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc @@ -28,17 +28,32 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose); * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { + // Domain for TRT plugin custom ops (domain name: "trt.plugins"). Owns the OrtCustomOpDomain object. + // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession. static std::unique_ptr custom_op_domain = std::make_unique(); + + // Owns the TensorRTCustomOp objects for TRT plugins. Raw pointers are stored in custom_op_domain->custom_ops_. static std::vector> created_custom_op_list; + + // Domain for native custom ops (domain name: "trt"). Owns the OrtCustomOpDomain object. + // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession. static std::unique_ptr native_custom_op_domain = std::make_unique(); + + // Owns the TensorRTCustomOp objects for native custom ops. Raw pointers are stored in native_custom_op_domain->custom_ops_. + // Non-empty list indicates native custom ops have been registered (used to avoid re-registration on subsequent calls). static std::vector> native_custom_op_list; + + // Protects concurrent access to all the above static members. static std::mutex mutex; std::lock_guard lock(mutex); + + // Add already-initialized native ops to domain list + if (!native_custom_op_list.empty()) { + domain_list.push_back(native_custom_op_domain.get()); + } + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { domain_list.push_back(custom_op_domain.get()); - if (native_custom_op_domain->domain_ != "" && native_custom_op_domain->custom_ops_.size() > 0) { - domain_list.push_back(native_custom_op_domain.get()); - } return Status::OK(); } @@ -132,35 +147,36 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& } // Register native custom ops (register these independent of TRT plugin library availability) - const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"}; - int num_native_custom_ops = std::size(native_custom_ops_names); + if (native_custom_op_list.empty()) { + const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"}; + size_t num_native_custom_ops = std::size(native_custom_ops_names); + + for (size_t i = 0; i < num_native_custom_ops; i++) { + native_custom_op_list.push_back(std::make_unique(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr)); + native_custom_op_list.back()->SetName(native_custom_ops_names[i]); + native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get()); + } - for (int i = 0; i < num_native_custom_ops; i++) { - native_custom_op_list.push_back(std::make_unique(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr)); - native_custom_op_list.back()->SetName(native_custom_ops_names[i]); - native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get()); + native_custom_op_domain->domain_ = "trt"; + domain_list.push_back(native_custom_op_domain.get()); } - native_custom_op_domain->domain_ = "trt"; - domain_list.push_back(native_custom_op_domain.get()); return Status::OK(); } void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { - if (domain != nullptr) { - for (auto ptr : domain->custom_ops_) { - if (ptr != nullptr) { - delete ptr; - } - } - delete domain; - } + (void)domain; // Suppress unused parameter warning + // The domain and its custom ops are owned by static unique_ptrs in CreateTensorRTCustomOpDomainList(). + // Callers receive raw pointers via .get(). + // 1. Manually deleting them would cause a double-free when the static unique_ptrs are destroyed at program exit. + // 2. Resetting the static unique_ptrs is also unsafe because other EP instances or InferenceSession objects + // may still hold raw pointers to these same objects (handed out via domain_list). + // The static objects would be shared across EP instances and would persist for the program lifetime. } void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list) { - for (auto ptr : custom_op_domain_list) { - ReleaseTensorRTCustomOpDomain(ptr); - } + // Only clear the reference vector, don't delete the static domain objects. + custom_op_domain_list.clear(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index d1e449eb58870..388387fae4b0a 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -61,7 +61,7 @@ struct NvProviderFactory : IExecutionProviderFactory { std::unique_ptr CreateProvider() override; std::unique_ptr CreateProvider(const OrtSessionOptions& session_options, - const OrtLogger& session_logger); + const OrtLogger& session_logger) override; private: NvExecutionProviderInfo info_; @@ -109,7 +109,7 @@ struct Nv_Provider : Provider { return std::make_shared(info); } - std::shared_ptr CreateExecutionProviderFactory(const void* param) { + std::shared_ptr CreateExecutionProviderFactory(const void* param) override { if (param == nullptr) { LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Passed NULL options to CreateExecutionProviderFactory()"; return nullptr; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 1e9fafe8aa323..2418c8424422b 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -45,8 +45,14 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose); * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { + // Domain for TRT plugin custom ops (domain name: "trt.plugins"). Owns the OrtCustomOpDomain object. + // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession. static std::unique_ptr custom_op_domain = std::make_unique(); + + // Owns the TensorRTCustomOp objects for TRT plugins. Raw pointers are stored in custom_op_domain->custom_ops_. static std::vector> created_custom_op_list; + + // Protects concurrent access to all the above static members. static std::mutex mutex; std::lock_guard lock(mutex); if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { @@ -148,20 +154,18 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& } void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { - if (domain != nullptr) { - for (auto ptr : domain->custom_ops_) { - if (ptr != nullptr) { - delete ptr; - } - } - delete domain; - } + (void)domain; // Suppress unused parameter warning + // The domain and its custom ops are owned by static unique_ptrs in CreateTensorRTCustomOpDomainList(). + // Callers receive raw pointers via .get(). + // 1. Manually deleting them would cause a double-free when the static unique_ptrs are destroyed at program exit. + // 2. Resetting the static unique_ptrs is also unsafe because other EP instances or InferenceSession objects + // may still hold raw pointers to these same objects (handed out via domain_list). + // The static objects are shared across EP instances and persist for the program lifetime. } void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list) { - for (auto ptr : custom_op_domain_list) { - ReleaseTensorRTCustomOpDomain(ptr); - } + // Only clear the reference vector, don't delete the static domain objects. + custom_op_domain_list.clear(); } } // namespace onnxruntime diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 523ff8eaf13b8..7cd02e5413407 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -17,6 +17,7 @@ #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" #include "core/session/onnxruntime_env_config_keys.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/plugin_ep/ep_library_plugin.h" @@ -539,8 +540,13 @@ bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) { Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, std::unique_ptr ep_library, const std::vector& internal_factories) { + const Env& env = Env::Default(); + env.GetTelemetryProvider().LogRegisterEpLibraryStart(registration_name); + if (ep_libraries_.count(registration_name) > 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "library is already registered under ", registration_name); + auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "library is already registered under ", registration_name); + env.GetTelemetryProvider().LogRegisterEpLibraryEnd(registration_name, status); + return status; } auto status = Status::OK(); @@ -592,6 +598,7 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra }); } + env.GetTelemetryProvider().LogRegisterEpLibraryEnd(registration_name, status); return status; } @@ -611,6 +618,9 @@ Status Environment::CreateAndRegisterInternalEps() { Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path) { std::lock_guard lock{mutex_}; + std::string lib_file_name = std::filesystem::path(lib_path).filename().string(); + Env::Default().GetTelemetryProvider().LogRegisterEpLibraryWithLibPath(registration_name, lib_file_name); + std::vector internal_factories = {}; std::unique_ptr ep_library; @@ -896,8 +906,15 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(), ep_devices.data(), ep_devices.size(), &num_ep_devices))); + const auto* library_path = instance.library->LibraryPath(); for (size_t i = 0; i < num_ep_devices; ++i) { - if (ep_devices[i] != nullptr) { // should never happen but just in case... + if (ep_devices[i] != nullptr) { // should never happen but just in case... + if (library_path != nullptr) { + // Add library path to EP metadata if available. + // This is used by GenAI for custom library loading so we want to consistently set it. + ep_devices[i]->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path->string()); + } + instance.execution_devices.emplace_back(ep_devices[i]); // take ownership } } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 0c9b3c0663b5c..d24020424935a 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -37,6 +37,7 @@ #include "core/framework/plugin_ep_stream.h" #include "core/framework/transform_layout_functions.h" #include "core/framework/utils.h" +#include "core/graph/constants.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/graph/model_editor_api_types.h" @@ -730,6 +731,25 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const #endif // !defined(ORT_MINIMAL_BUILD) InferenceSession::~InferenceSession() { + // Flush any remaining RuntimePerf counters + ORT_TRY { + std::lock_guard telemetry_lock(telemetry_mutex_); + if (telemetry_.total_runs_since_last_ > 0) { + Env::Default().GetTelemetryProvider().LogRuntimePerf(session_id_, + telemetry_.total_runs_since_last_, + telemetry_.total_run_duration_since_last_, + telemetry_.duration_per_batch_size_); + } + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + LOGS(*session_logger_, ERROR) << "Error during telemetry flush: " << e.what(); + }); + } + ORT_CATCH(...) { + LOGS(*session_logger_, ERROR) << "Unknown error during telemetry flush"; + } + if (session_options_.enable_profiling) { ORT_TRY { EndProfiling(); @@ -969,7 +989,10 @@ common::Status InferenceSession::LoadWithLoader(std::function l(session_mutex_); if (is_model_loaded_) { // already loaded LOGS(*session_logger_, ERROR) << "This session already contains a loaded model."; @@ -1005,6 +1028,8 @@ common::Status InferenceSession::LoadWithLoader(std::function load_ort_format_model_bytes) { + const Env& env = Env::Default(); + env.GetTelemetryProvider().LogModelLoadStart(session_id_); + std::lock_guard l(session_mutex_); if (is_model_loaded_) { // already loaded @@ -1761,6 +1789,8 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort is_model_loaded_ = true; + env.GetTelemetryProvider().LogModelLoadEnd(session_id_, Status::OK()); + return Status::OK(); } @@ -2263,6 +2293,16 @@ common::Status InferenceSession::Initialize() { return Status::OK(); }; + // Enable DQ->MatMulNBits fusion if NvTensorRTRTX EP is registered. + if (execution_providers_.Get(onnxruntime::kNvTensorRTRTXExecutionProvider) != nullptr) { + if (session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "") == "") { + ORT_RETURN_IF_ERROR_SESSIONID_( + session_options_.config_options.AddConfigEntry( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "1")); + } + } + // add predefined transformers ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_, session_options_.graph_optimization_level, @@ -2603,6 +2643,12 @@ common::Status InferenceSession::Initialize() { } } + // Log session creation end telemetry + { + const Env& init_env = Env::Default(); + init_env.GetTelemetryProvider().LogSessionCreationEnd(session_id_, status); + } + return status; } #if defined(_MSC_VER) && !defined(__clang__) @@ -3163,24 +3209,31 @@ Status InferenceSession::Run(const RunOptions& run_options, break; } - // time to send telemetry? - { - // Adding lock_guard here to ensure that telemetry updates are thread-safe. - std::lock_guard telemetry_lock(telemetry_mutex_); - ++telemetry_.total_runs_since_last_; - telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp); - telemetry_.duration_per_batch_size_[batch_size] += TimeDiffMicroSeconds(tp); - - if (TimeDiffMicroSeconds(telemetry_.time_sent_last_) > Telemetry::kDurationBetweenSending) { - // send the telemetry - env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_, - telemetry_.total_run_duration_since_last_, - telemetry_.duration_per_batch_size_); - // reset counters - telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now(); - telemetry_.total_runs_since_last_ = 0; - telemetry_.total_run_duration_since_last_ = 0; - telemetry_.duration_per_batch_size_.clear(); + // Only include successful inferences in batch since failed inferences can skew the metric + if (retval.IsOK()) { + // time to send telemetry? + { + // Adding lock_guard here to ensure that telemetry updates are thread-safe. + std::lock_guard telemetry_lock(telemetry_mutex_); + ++telemetry_.total_runs_since_last_; + telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp); + telemetry_.duration_per_batch_size_[batch_size] += TimeDiffMicroSeconds(tp); + + // Emit RuntimePerf on scheduled interval + if ((TimeDiffMicroSeconds(telemetry_.time_sent_last_) > telemetry_.runtime_perf_interval_)) { + env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_, + telemetry_.total_run_duration_since_last_, + telemetry_.duration_per_batch_size_); + // reset counters + telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now(); + telemetry_.total_runs_since_last_ = 0; + telemetry_.total_run_duration_since_last_ = 0; + telemetry_.duration_per_batch_size_.clear(); + + // Double the interval, capping at kRuntimePerfMaxInterval + telemetry_.runtime_perf_interval_ = std::min(telemetry_.runtime_perf_interval_ * 2, + Telemetry::kRuntimePerfMaxInterval); + } } } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 1dbf0318c988c..e51dc773f9761 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -976,8 +976,10 @@ class InferenceSession { std::unordered_map duration_per_batch_size_; // the duration (us) of Run() calls per batch size since the last report TimePoint time_sent_last_; // the TimePoint of the last report - // Event Rate per provider < 20 peak events per second - constexpr static long long kDurationBetweenSending = 1000 * 1000 * 60 * 10; // duration in (us). send a report every 10 mins + // RuntimePerf backoff interval: starts at 2s between emissions, doubles each emission, caps at 10 min + constexpr static int64_t kRuntimePerfInitialInterval = 2 * 1000 * 1000; // 2 seconds in (us) + constexpr static int64_t kRuntimePerfMaxInterval = 1000 * 1000 * 60 * 10; // 10 minutes in (us) + int64_t runtime_perf_interval_ = kRuntimePerfInitialInterval; } telemetry_; mutable std::mutex telemetry_mutex_; // to ensure thread-safe access to telemetry data diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc index 42b65239de92c..0e2c4b4217702 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -22,11 +22,6 @@ OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_fa auto* ep_device = ep_devices[i]; if (ep_device) { ep_device->ep_factory = &ep_factory; - - // Add library path to EP metadata if available - if (library_path_.has_value()) { - ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path_->string()); - } } } diff --git a/onnxruntime/core/session/plugin_ep/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h index af5bc23143e33..fed9eb072c704 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library.h +++ b/onnxruntime/core/session/plugin_ep/ep_library.h @@ -20,6 +20,7 @@ class EpLibrary { EpLibrary() = default; virtual const char* RegistrationName() const = 0; + virtual const std::filesystem::path* LibraryPath() const { return nullptr; } virtual Status Load() { return Status::OK(); } virtual const std::vector& GetFactories() = 0; // valid after Load() virtual Status Unload() { return Status::OK(); } diff --git a/onnxruntime/core/session/plugin_ep/ep_library_plugin.h b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h index e044e91b61e37..cce52b3d5d282 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_plugin.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h @@ -25,6 +25,10 @@ class EpLibraryPlugin : public EpLibrary { return registration_name_.c_str(); } + const std::filesystem::path* LibraryPath() const override { + return &library_path_; + } + Status Load() override; const std::vector& GetFactories() override { diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index 45277b2828f56..f3147794bc823 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -22,7 +22,7 @@ class EpLibraryProviderBridge : public EpLibrary { public: EpLibraryProviderBridge(std::unique_ptr provider_library, std::unique_ptr ep_library_plugin, - std::optional library_path = std::nullopt) + std::filesystem::path library_path) : provider_library_{std::move(provider_library)}, ep_library_plugin_{std::move(ep_library_plugin)}, library_path_{std::move(library_path)} { @@ -32,6 +32,10 @@ class EpLibraryProviderBridge : public EpLibrary { return ep_library_plugin_->RegistrationName(); } + const std::filesystem::path* LibraryPath() const override { + return &library_path_; + } + const std::vector& GetFactories() override { return factory_ptrs_; } @@ -56,7 +60,7 @@ class EpLibraryProviderBridge : public EpLibrary { std::unique_ptr ep_library_plugin_; // Library path for EP metadata - std::optional library_path_; + std::filesystem::path library_path_; std::vector> factories_; std::vector factory_ptrs_; // for convenience diff --git a/onnxruntime/lora/adapter_format_utils.cc b/onnxruntime/lora/adapter_format_utils.cc index 2d061b6066a8a..6e2d204b04cf2 100644 --- a/onnxruntime/lora/adapter_format_utils.cc +++ b/onnxruntime/lora/adapter_format_utils.cc @@ -7,8 +7,9 @@ #include "core/framework/allocator.h" #include "core/common/common.h" #include "core/common/endian.h" -#include "core/framework/endian_utils.h" +#include "core/common/safeint.h" #include "core/common/span_utils.h" +#include "core/framework/endian_utils.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" #include "core/framework/ort_value.h" @@ -149,13 +150,36 @@ struct ReadDataForBigEndian { std::pair CreateOrtValueOverLoraParameter(const Parameter& param) { OrtValue result; + const auto* param_name = param.name(); + ORT_ENFORCE(param_name != nullptr, "Lora Parameter: name is missing"); + std::string name; - LoadStringFromLoraFormat(name, param.name()); + LoadStringFromLoraFormat(name, param_name); const auto data_type = param.data_type(); + ORT_ENFORCE(data_type != TensorDataType::UNDEFINED, + "Lora Param '", name, "': data_type is UNDEFINED"); + + const auto* dims = param.dims(); + ORT_ENFORCE(dims != nullptr && dims->size() > 0, + "Lora Param '", name, "': dims is missing or empty"); + + const auto* raw_data = param.raw_data(); + ORT_ENFORCE(raw_data != nullptr, + "Lora Param '", name, "': raw_data is missing"); + // Copying shape takes care of endianess using flatbuffers accessors - TensorShapeVector shape(param.dims()->begin(), param.dims()->end()); + TensorShapeVector shape(dims->begin(), dims->end()); + TensorShape tensor_shape(shape); const auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast(data_type))->GetElementType(); + const size_t expected_raw_data_size = SafeInt(tensor_shape.Size()) * elem_type->Size(); + if (raw_data->size() != expected_raw_data_size) { + ORT_THROW("Lora Param '", name, + "': raw_data size (", raw_data->size(), + ") does not match expected size (", expected_raw_data_size, + ") calculated from tensor shape and element type"); + } + static const OrtMemoryInfo cpu_meminfo(CPU, OrtAllocatorType::OrtDeviceAllocator); if constexpr (endian::native == endian::big) { @@ -166,16 +190,16 @@ std::pair CreateOrtValueOverLoraParameter(const Parameter // of raw data // const_cast is necessary due to Tensor class API Tensor::InitOrtValue(elem_type, - TensorShape(shape), - const_cast(param.raw_data()->data()), + tensor_shape, + const_cast(raw_data->data()), cpu_meminfo, result); } } else { // const_cast is necessary due to Tensor class API Tensor::InitOrtValue(elem_type, - TensorShape(shape), - const_cast(param.raw_data()->data()), + tensor_shape, + const_cast(raw_data->data()), cpu_meminfo, result); } diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc index 7b6679ffaf462..79bc34572a6f7 100644 --- a/onnxruntime/test/autoep/test_registration.cc +++ b/onnxruntime/test/autoep/test_registration.cc @@ -74,6 +74,15 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { auto options = test_ep_device->EpOptions(); ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); + // Verify the library path is present in the EP metadata + const char* metadata_library_path = metadata.GetValue(kOrtEpDevice_EpMetadataKey_LibraryPath); + ASSERT_NE(metadata_library_path, nullptr) << "Expected library_path to be present in EP metadata."; + + // Verify the library path matches the registered path + std::filesystem::path metadata_path{metadata_library_path}; + ASSERT_EQ(std::filesystem::canonical(metadata_path), std::filesystem::canonical(library_path)) + << "Expected library_path in EP metadata to match the registered library path."; + // the CPU device info will vary by machine so check for the lowest common denominator values Ort::ConstHardwareDevice device = test_ep_device->Device(); ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc index 0c55cf45abcdf..791250a6e1364 100644 --- a/onnxruntime/test/lora/lora_test.cc +++ b/onnxruntime/test/lora/lora_test.cc @@ -173,7 +173,6 @@ struct TestDataType { verify_load(lora_adapter); } }; - } // namespace TEST(LoraAdapterTest, Load) { @@ -199,6 +198,226 @@ TEST(LoraAdapterTest, Load) { } } +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ValidParam) { + // Build a valid adapter with a single float parameter, then call + // CreateOrtValueOverLoraParameter on the deserialized Parameter. + constexpr std::array shape = {8, 4}; + InlinedVector data(32); + std::iota(data.begin(), data.end(), 0.f); + + adapters::utils::AdapterFormatBuilder adapter_builder; + adapter_builder.AddParameter("valid_param", adapters::TensorDataType::FLOAT, + shape, ReinterpretAsSpan(gsl::make_span(data))); + + auto buffer = adapter_builder.Finish(kAdapterVersion, kModelVersion); + + const auto* adapter = adapters::utils::ValidateAndGetAdapterFromBytes(buffer); + ASSERT_NE(adapter, nullptr); + ASSERT_NE(adapter->parameters(), nullptr); + ASSERT_EQ(adapter->parameters()->size(), 1u); + + const auto* param = adapter->parameters()->Get(0); + auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param); + + ASSERT_EQ(name, "valid_param"); + ASSERT_TRUE(ort_value.IsTensor()); + + const auto& tensor = ort_value.Get(); + ASSERT_EQ(tensor.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + auto dims = tensor.Shape().GetDims(); + ASSERT_EQ(dims.size(), 2u); + ASSERT_EQ(dims[0], 8); + ASSERT_EQ(dims[1], 4); + + auto result_span = tensor.DataAsSpan(); + ASSERT_EQ(result_span.size(), 32u); + for (size_t i = 0; i < result_span.size(); ++i) { + ASSERT_EQ(static_cast(i), result_span[i]); + } +} + +#ifndef ORT_NO_EXCEPTIONS + +namespace { +// Helper that wraps a single Parameter offset into a finished Adapter flatbuffer +// and returns a pointer to the deserialized Parameter. +// The FlatBufferBuilder must outlive the returned pointer. +const adapters::Parameter* BuildAdapterAndGetParam(flatbuffers::FlatBufferBuilder& fbb, + flatbuffers::Offset param_offset) { + auto params_offset = fbb.CreateVector(¶m_offset, 1); + auto adapter_offset = adapters::CreateAdapter( + fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset); + adapters::FinishAdapterBuffer(fbb, adapter_offset); + + const auto* adapter = adapters::GetAdapter(fbb.GetBufferPointer()); + return adapter->parameters()->Get(0); +} +} // namespace + +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_RawDataSizeMismatch) { + // Craft a flatbuffer Parameter where raw_data has fewer bytes than + // shape (8 x 4) * sizeof(float) = 128 bytes. + // We supply only 64 bytes (half the expected amount) so the validation + // inside CreateOrtValueOverLoraParameter must throw. + flatbuffers::FlatBufferBuilder fbb; + + auto name_offset = fbb.CreateString("bad_param"); + std::vector dims = {8, 4}; + auto dims_offset = fbb.CreateVector(dims); + + // 8 * 4 floats = 32 elements = 128 bytes expected. + // Provide only 64 bytes (16 floats worth) to trigger the mismatch. + std::vector short_data(64, 0); + fbb.ForceVectorAlignment(short_data.size(), sizeof(uint8_t), 8); + auto data_offset = fbb.CreateVector(short_data); + + auto param_offset = adapters::CreateParameter( + fbb, name_offset, dims_offset, adapters::TensorDataType::FLOAT, data_offset); + + // Wrap the single parameter inside an Adapter so the buffer is valid flatbuffers. + auto params_offset = fbb.CreateVector(¶m_offset, 1); + auto adapter_offset = adapters::CreateAdapter( + fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset); + adapters::FinishAdapterBuffer(fbb, adapter_offset); + + auto* buf = fbb.GetBufferPointer(); + + // Retrieve the Parameter from the Adapter + const auto* adapter = adapters::GetAdapter(buf); + ASSERT_NE(adapter, nullptr); + ASSERT_NE(adapter->parameters(), nullptr); + ASSERT_EQ(adapter->parameters()->size(), 1u); + + const auto* param = adapter->parameters()->Get(0); + ASSERT_NE(param, nullptr); + + // The raw_data is 64 bytes but shape says 8x4 floats = 128 bytes. + // CreateOrtValueOverLoraParameter must throw. + ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException); +} + +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ExcessRawData) { + // Craft a flatbuffer Parameter where raw_data has MORE bytes than expected. + // Shape (2, 2) with float => 4 elements => 16 bytes expected, but we supply 32. + flatbuffers::FlatBufferBuilder fbb; + + auto name_offset = fbb.CreateString("excess_param"); + std::vector dims = {2, 2}; + auto dims_offset = fbb.CreateVector(dims); + + // 2 * 2 floats = 4 elements = 16 bytes expected. Supply 32. + std::vector excess_data(32, 0); + fbb.ForceVectorAlignment(excess_data.size(), sizeof(uint8_t), 8); + auto data_offset = fbb.CreateVector(excess_data); + + auto param_offset = adapters::CreateParameter( + fbb, name_offset, dims_offset, adapters::TensorDataType::FLOAT, data_offset); + + auto params_offset = fbb.CreateVector(¶m_offset, 1); + auto adapter_offset = adapters::CreateAdapter( + fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset); + adapters::FinishAdapterBuffer(fbb, adapter_offset); + + const auto* adapter = adapters::GetAdapter(fbb.GetBufferPointer()); + ASSERT_NE(adapter, nullptr); + + const auto* param = adapter->parameters()->Get(0); + ASSERT_NE(param, nullptr); + + // Excess data should also trigger the mismatch throw. + ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException); +} + +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_MissingName) { + // Parameter with null name should throw gracefully. + flatbuffers::FlatBufferBuilder fbb; + + std::vector dims = {2, 2}; + std::vector raw_data(16, 0); // 2*2 floats = 16 bytes + + // name is nullptr, all other fields are valid + auto param_offset = adapters::CreateParameterDirect( + fbb, /*name=*/nullptr, &dims, adapters::TensorDataType::FLOAT, &raw_data); + + const auto* param = BuildAdapterAndGetParam(fbb, param_offset); + ASSERT_NE(param, nullptr); + ASSERT_EQ(param->name(), nullptr); + + ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException); +} + +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_MissingDims) { + // Parameter with null dims should throw gracefully. + flatbuffers::FlatBufferBuilder fbb; + + std::vector raw_data(16, 0); + + // dims is nullptr + auto param_offset = adapters::CreateParameterDirect( + fbb, "no_dims_param", /*dims=*/nullptr, adapters::TensorDataType::FLOAT, &raw_data); + + const auto* param = BuildAdapterAndGetParam(fbb, param_offset); + ASSERT_NE(param, nullptr); + ASSERT_EQ(param->dims(), nullptr); + + ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException); +} + +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_EmptyDims) { + // Parameter with an empty dims vector should throw gracefully. + flatbuffers::FlatBufferBuilder fbb; + + std::vector empty_dims; + std::vector raw_data(16, 0); + + auto param_offset = adapters::CreateParameterDirect( + fbb, "empty_dims_param", &empty_dims, adapters::TensorDataType::FLOAT, &raw_data); + + const auto* param = BuildAdapterAndGetParam(fbb, param_offset); + ASSERT_NE(param, nullptr); + ASSERT_EQ(param->dims()->size(), 0u); + + ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException); +} + +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_MissingRawData) { + // Parameter with null raw_data should throw gracefully. + flatbuffers::FlatBufferBuilder fbb; + + std::vector dims = {2, 2}; + + // raw_data is nullptr + auto param_offset = adapters::CreateParameterDirect( + fbb, "no_data_param", &dims, adapters::TensorDataType::FLOAT, /*raw_data=*/nullptr); + + const auto* param = BuildAdapterAndGetParam(fbb, param_offset); + ASSERT_NE(param, nullptr); + ASSERT_EQ(param->raw_data(), nullptr); + + ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException); +} + +TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_UndefinedDataType) { + // Parameter with UNDEFINED data_type should throw gracefully. + flatbuffers::FlatBufferBuilder fbb; + + std::vector dims = {2, 2}; + std::vector raw_data(16, 0); + + // data_type defaults to UNDEFINED when not set + auto param_offset = adapters::CreateParameterDirect( + fbb, "undef_type_param", &dims, adapters::TensorDataType::UNDEFINED, &raw_data); + + const auto* param = BuildAdapterAndGetParam(fbb, param_offset); + ASSERT_NE(param, nullptr); + ASSERT_EQ(param->data_type(), adapters::TensorDataType::UNDEFINED); + + ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException); +} + +#endif // ORT_NO_EXCEPTIONS + #ifdef USE_CUDA TEST(LoraAdapterTest, VerifyDeviceCopy) { auto cpu_ep = DefaultCpuExecutionProvider(); diff --git a/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc new file mode 100644 index 0000000000000..8aa4c88052742 --- /dev/null +++ b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc @@ -0,0 +1,595 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Unit tests for the DQMatMulNBitsFusion graph transformer. +// Tests Pattern 1: DQ(3D,axis=2)->Reshape->Transpose([1,0])->[Cast]->MatMul/Gemm -> MatMulNBits +// Tests Pattern 2: DQ(2D,axis=0)->MatMul/Gemm -> MatMulNBits + +#include "core/common/span_utils.h" +#include "core/framework/int4.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/dq_matmulnbits_fusion.h" + +#include "test/test_environment.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" +#include "test/util/include/asserts.h" + +#include "gtest/gtest.h" + +#if !defined(DISABLE_CONTRIB_OPS) + +namespace onnxruntime { +namespace test { + +static std::vector MakePackedUint4(const std::vector& values) { + const size_t num_pairs = UInt4x2::CalcNumInt4Pairs(values.size()); + std::vector packed(num_pairs); + for (size_t i = 0; i < values.size(); i += 2) { + uint8_t lo = values[i] & 0x0F; + uint8_t hi = (i + 1 < values.size()) ? (values[i + 1] & 0x0F) : 0; + packed[i / 2] = UInt4x2(lo, hi); + } + return packed; +} + +static void BuildPattern1Graph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp, + bool with_cast, + bool use_gemm, + const std::vector* weight_values = nullptr, + const std::vector* scale_values = nullptr, + const std::vector* zp_values = nullptr) { + const int64_t num_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + const int64_t weight_elems = N * num_blocks * block_size; + std::vector w_vals; + if (weight_values) { + w_vals = *weight_values; + } else { + w_vals.resize(static_cast(weight_elems)); + for (size_t i = 0; i < w_vals.size(); ++i) { + w_vals[i] = static_cast(i % 16); + } + } + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer( + {N, num_blocks, block_size}, w_packed); + + std::vector s_vals; + if (scale_values) { + s_vals = *scale_values; + } else { + s_vals.resize(static_cast(N * num_blocks)); + for (size_t i = 0; i < s_vals.size(); ++i) { + s_vals[i] = 0.1f + 0.01f * static_cast(i % 10); + } + } + auto* scale_arg = builder.MakeInitializer({N, num_blocks, 1}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(2)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + + auto* dq_output = builder.MakeIntermediate(); + if (with_zp) { + std::vector z_vals; + if (zp_values) { + z_vals = *zp_values; + } else { + z_vals.resize(static_cast(N * num_blocks)); + for (size_t i = 0; i < z_vals.size(); ++i) { + z_vals[i] = 8; + } + } + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({N, num_blocks, 1}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + auto* reshape_shape = builder.MakeInitializer({2}, {N, K}); + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output}); + + NodeAttributes tp_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector{1, 0}), tp_attrs); + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs); + + NodeArg* matmul_b = transpose_output; + + if (with_cast) { + auto* cast_output = builder.MakeIntermediate(); + NodeAttributes cast_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("to", static_cast(1)), cast_attrs); + builder.AddNode("Cast", {transpose_output}, {cast_output}, "", &cast_attrs); + matmul_b = cast_output; + } + + if (use_gemm) { + builder.AddNode("Gemm", {input_a, matmul_b}, {output}); + } else { + builder.AddNode("MatMul", {input_a, matmul_b}, {output}); + } +} + +static void BuildPattern1GemmBiasGraph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp) { + const int64_t num_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + const int64_t weight_elems = N * num_blocks * block_size; + std::vector w_vals(static_cast(weight_elems)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({N, num_blocks, block_size}, w_packed); + + std::vector s_vals(static_cast(N * num_blocks)); + for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f; + auto* scale_arg = builder.MakeInitializer({N, num_blocks, 1}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(2)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + + if (with_zp) { + std::vector z_vals(static_cast(N * num_blocks), 8); + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({N, num_blocks, 1}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + auto* reshape_shape = builder.MakeInitializer({2}, {N, K}); + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output}); + + NodeAttributes tp_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector{1, 0}), tp_attrs); + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs); + + auto* bias_arg = builder.MakeInitializer({N}, std::vector(static_cast(N), 0.5f)); + builder.AddNode("Gemm", {input_a, transpose_output, bias_arg}, {output}); +} + +static void BuildPattern2Graph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp, + bool use_gemm) { + const int64_t k_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + std::vector w_vals(static_cast(K * N)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({K, N}, w_packed); + + std::vector s_vals(static_cast(k_blocks * N)); + for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f + 0.01f * static_cast(i % 10); + auto* scale_arg = builder.MakeInitializer({k_blocks, N}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + + if (with_zp) { + std::vector z_vals(static_cast(k_blocks * N), 8); + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({k_blocks, N}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + if (use_gemm) { + builder.AddNode("Gemm", {input_a, dq_output}, {output}); + } else { + builder.AddNode("MatMul", {input_a, dq_output}, {output}); + } +} + +static void BuildPattern2GemmBiasGraph(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size, + bool with_zp) { + const int64_t k_blocks = K / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + std::vector w_vals(static_cast(K * N)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({K, N}, w_packed); + + std::vector s_vals(static_cast(k_blocks * N)); + for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f; + auto* scale_arg = builder.MakeInitializer({k_blocks, N}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + + if (with_zp) { + std::vector z_vals(static_cast(k_blocks * N), 8); + auto zp_packed = MakePackedUint4(z_vals); + auto* zp_arg = builder.MakeInitializer({k_blocks, N}, zp_packed); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + } + + auto* bias_arg = builder.MakeInitializer({N}, std::vector(static_cast(N), 0.5f)); + builder.AddNode("Gemm", {input_a, dq_output, bias_arg}, {output}); +} + +static void BuildPattern1WrongAxis(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size) { + const int64_t num_blocks = K / block_size; + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + std::vector w_vals(static_cast(N * num_blocks * block_size)); + for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast(i % 16); + auto w_packed = MakePackedUint4(w_vals); + auto* weight_arg = builder.MakeInitializer({N, num_blocks, block_size}, w_packed); + + std::vector s_vals(static_cast(N * num_blocks), 0.1f); + auto* scale_arg = builder.MakeInitializer({N, num_blocks, 1}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + + auto* reshape_shape = builder.MakeInitializer({2}, {N, K}); + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output}); + + NodeAttributes tp_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector{1, 0}), tp_attrs); + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs); + + builder.AddNode("MatMul", {input_a, transpose_output}, {output}); +} + +static void BuildPattern2NonConstWeight(ModelTestBuilder& builder, + int64_t M, int64_t N, int64_t K, + int64_t block_size) { + const int64_t k_blocks = K / block_size; + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInput({K, N}, + UInt4x2(UInt4x2::min_val, 0), + UInt4x2(UInt4x2::max_val, 0)); + + std::vector s_vals(static_cast(k_blocks * N), 0.1f); + auto* scale_arg = builder.MakeInitializer({k_blocks, N}, s_vals); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs); + + builder.AddNode("MatMul", {input_a, dq_output}, {output}); +} + +static std::map CountOpsInGraphByDomain(const Graph& graph) { + std::map op_counts; + for (const auto& node : graph.Nodes()) { + std::string key = node.OpType(); + if (!node.Domain().empty() && node.Domain() != kOnnxDomain) { + key = node.Domain() + "." + key; + } + op_counts[key]++; + } + return op_counts; +} + +class DQMatMulNBitsFusionTest : public GraphTransformationTests {}; + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_NoZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, false, false, false); + }; + + auto pre_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["DequantizeLinear"], 1); + EXPECT_EQ(ops["Reshape"], 1); + EXPECT_EQ(ops["Transpose"], 1); + EXPECT_EQ(ops["MatMul"], 1); + return Status::OK(); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("DequantizeLinear"), 0); + EXPECT_EQ(ops.count("Reshape"), 0); + EXPECT_EQ(ops.count("Transpose"), 0); + EXPECT_EQ(ops.count("MatMul"), 0); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + const auto& attrs = node.GetAttributes(); + EXPECT_EQ(attrs.at("K").i(), K); + EXPECT_EQ(attrs.at("N").i(), N); + EXPECT_EQ(attrs.at("bits").i(), 4); + EXPECT_EQ(attrs.at("block_size").i(), block_size); + EXPECT_EQ(node.InputDefs().size(), static_cast(4)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_check, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithDefaultZP8) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, true, false, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("DequantizeLinear"), 0); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.InputDefs().size(), static_cast(3)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithNonDefaultZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector zp_vals(static_cast(N * num_blocks), 3); + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, true, false, false, + nullptr, nullptr, &zp_vals); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.InputDefs().size(), static_cast(4)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithCast) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, false, true, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Cast"), 0); + EXPECT_EQ(ops.count("MatMul"), 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_Gemm_WithBias) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1GemmBiasGraph(builder, M, N, K, block_size, true); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Gemm"), 0); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_GE(node.InputDefs().size(), static_cast(6)); + EXPECT_TRUE(node.InputDefs()[5] != nullptr); + EXPECT_TRUE(node.InputDefs()[5]->Exists()); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern1_Gemm_NoZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, false, false, true); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Gemm"), 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern2_MatMul_NoZP) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2Graph(builder, M, N, K, block_size, false, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("DequantizeLinear"), 0); + EXPECT_EQ(ops.count("MatMul"), 0); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.GetAttributes().at("K").i(), K); + EXPECT_EQ(node.GetAttributes().at("N").i(), N); + EXPECT_EQ(node.InputDefs().size(), static_cast(4)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern2_MatMul_WithDefaultZP8) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2Graph(builder, M, N, K, block_size, true, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_EQ(node.InputDefs().size(), static_cast(3)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Pattern2_Gemm_WithBias) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2GemmBiasGraph(builder, M, N, K, block_size, false); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops["com.microsoft.MatMulNBits"], 1); + EXPECT_EQ(ops.count("Gemm"), 0); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + EXPECT_GE(node.InputDefs().size(), static_cast(6)); + } + } + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Negative_Pattern1_WrongAxis) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1WrongAxis(builder, M, N, K, block_size); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("com.microsoft.MatMulNBits"), 0); + EXPECT_EQ(ops["MatMul"], 1); + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +TEST_F(DQMatMulNBitsFusionTest, Negative_Pattern2_NonConstWeight) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + + auto build = [&](ModelTestBuilder& builder) { + BuildPattern2NonConstWeight(builder, M, N, K, block_size); + }; + + auto post_check = [&](Graph& graph) -> Status { + auto ops = CountOpsInGraphByDomain(graph); + EXPECT_EQ(ops.count("com.microsoft.MatMulNBits"), 0); + EXPECT_EQ(ops["DequantizeLinear"], 1); + EXPECT_EQ(ops["MatMul"], 1); + return Status::OK(); + }; + + auto transformer = std::make_unique(4); + ASSERT_STATUS_OK(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, nullptr, post_check)); +} + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index dc687714f07cd..f7bfa3055f96d 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -9,6 +9,7 @@ #include "gtest/gtest.h" #include "core/optimizer/graph_transformer_utils.h" #include "core/session/inference_session.h" +#include "core/session/onnxruntime_session_options_config_keys.h" using namespace ONNX_NAMESPACE; @@ -69,5 +70,31 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } +TEST(GraphTransformerUtilsTests, TestDQMatMulNBitsFusionConfigWithContribGating) { + SessionOptions session_options; + const auto status = session_options.config_options.AddConfigEntry( + kOrtSessionOptionsEnableDQMatMulNBitsFusion, "1"); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + +#if defined(DISABLE_CONTRIB_OPS) + EXPECT_ANY_THROW({ + std::ignore = optimizer_utils::GenerateTransformers( + TransformerLevel::Level1, session_options, cpu_ep, logger); + }); +#else + auto transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level1, session_options, cpu_ep, logger); + + const bool has_dq_matmulnbits_fusion = + std::any_of(transformers.begin(), transformers.end(), [](const auto& transformer) { + return transformer && transformer->Name() == "DQMatMulNBitsFusion"; + }); + + EXPECT_TRUE(has_dq_matmulnbits_fusion); +#endif +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index 58a616717316e..1a1c1b6cde3b5 100644 --- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -812,5 +812,47 @@ TEST(RoiAlignTest, MismatchNumRois) { test.Run(OpTester::ExpectResult::kExpectFailure, "[ShapeInferenceError] Dimension mismatch in unification between 4 and 5"); } + +TEST(RoiAlignTest, BatchIndicesOutOfRange) { + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + + test.AddInput("X", {1, 1, 4, 4}, + {0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f, + 12.f, 13.f, 14.f, 15.f}); + test.AddInput("rois", {1, 4}, {0.f, 0.f, 3.f, 3.f}); + test.AddInput("batch_indices", {1}, {1}); // <-- failure condition + test.AddOutput("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, "batch_indices value 1 at index 0 is out of range [0, 1)", {}, nullptr, &execution_providers); +} + +TEST(RoiAlignTest, BatchIndicesNegative) { + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 2); + test.AddAttribute("output_width", 2); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f); + + test.AddInput("X", {1, 1, 4, 4}, + {0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f, + 12.f, 13.f, 14.f, 15.f}); + test.AddInput("rois", {1, 4}, {0.f, 0.f, 3.f, 3.f}); + test.AddInput("batch_indices", {1}, {-1}); // <-- failure condition + test.AddOutput("Y", {1, 1, 2, 2}, {0.f, 0.f, 0.f, 0.f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, "batch_indices value -1 at index 0 is out of range [0, 1)", {}, nullptr, &execution_providers); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index d12a586f662ac..ba58344e1e3e2 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/common/common.h" +#include "core/framework/config_options.h" #include "core/graph/constants.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -25,7 +26,7 @@ TEST(CApiTest, session_options_deterministic_compute) { TEST(CApiTest, session_options_oversized_affinity_string) { Ort::SessionOptions options; - std::string long_affinity_str(onnxruntime::kMaxStrLen + 1, '0'); + std::string long_affinity_str(ConfigOptions::kMaxValueLength + 1, '0'); try { options.AddConfigEntry(kOrtSessionOptionsConfigIntraOpThreadAffinities, long_affinity_str.c_str()); ASSERT_TRUE(false) << "Creation of config should have thrown exception";