diff --git a/cmake/vcpkg-ports/pybind11/portfile.cmake b/cmake/vcpkg-ports/pybind11/portfile.cmake
index 4e4cd30a26df1..e0dd402cf186b 100644
--- a/cmake/vcpkg-ports/pybind11/portfile.cmake
+++ b/cmake/vcpkg-ports/pybind11/portfile.cmake
@@ -1,10 +1,15 @@
-vcpkg_from_github(
- OUT_SOURCE_PATH SOURCE_PATH
- REPO pybind/pybind11
- REF "v${VERSION}"
- # SHA512 for the zip (not tar.gz) file.
+# Manually define the download for the .zip archive (to be consistent with deps.txt)
+# If we used vcpkg_from_github, it would download the .tar.gz archive,
+# which has different SHA512: 19bee2c76320e25202ee078b5680ff8a7acfb33494dec29dad984ab04de8bcb01340d9fec37c8cc5ac9015dfc367e60312dcd8506e66ce8f0af4c49db562ddef
+vcpkg_download_distfile(ARCHIVE
+ URLS "https://github.com/pybind/pybind11/archive/refs/tags/v${VERSION}.zip"
+ FILENAME "pybind11-${VERSION}.zip"
SHA512 786b1bf534ac67a8d5669f8babf67bb13e48b3a3da1b6344e43ae10a84b80bbc8fea5f12a65fd18739c341fefef5622c5dc096db964dff33cc62ea4259b2e2c1
- HEAD_REF master
+)
+
+vcpkg_extract_source_archive(
+ SOURCE_PATH
+ ARCHIVE "${ARCHIVE}"
)
vcpkg_cmake_configure(
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index a6b267c6802cf..81d4f2589151b 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -476,9 +476,25 @@ internal static class NativeMethods
static NativeMethods()
{
#if !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__
- // Register a custom DllImportResolver to handle platform-specific library loading.
- // Replaces default resolution specifically on Windows for case-sensitivity.
- NativeLibrary.SetDllImportResolver(typeof(NativeMethods).Assembly, DllImportResolver);
+ if (!OrtEnv.DisableDllImportResolver)
+ {
+ try
+ {
+ // Register a custom DllImportResolver to handle platform-specific library loading.
+ // Replaces default resolution specifically on Windows for case-sensitivity.
+ NativeLibrary.SetDllImportResolver(typeof(NativeMethods).Assembly, DllImportResolver);
+ }
+ catch (InvalidOperationException)
+ {
+ // A resolver is already registered for this assembly (e.g., by the host application).
+ // This is not fatal — the host's resolver will handle library loading.
+ System.Diagnostics.Trace.WriteLine(
+ "[OnnxRuntime] A DllImportResolver is already registered for this assembly. "
+ + "OnnxRuntime's built-in resolver will not be used. "
+ + "To suppress this message, set OrtEnv.DisableDllImportResolver = true "
+ + "before using any OnnxRuntime APIs.");
+ }
+ }
#endif
#if NETSTANDARD2_0
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
index 6fcff438c5cf3..22f541e2207fa 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
@@ -12,7 +12,7 @@ namespace Microsoft.ML.OnnxRuntime
///
///
/// This enum is used to determine whether a pre-compiled model can be used with specific execution providers
- /// and devices, or if recompilation is needed.
+ /// and devices, or if recompilation is needed.
///
public enum OrtCompiledModelCompatibility
{
@@ -77,14 +77,14 @@ public struct EnvironmentCreationOptions
///
/// The singleton class OrtEnv contains the process-global ONNX Runtime environment.
/// It sets up logging, creates system wide thread-pools (if Thread Pool options are provided)
- /// and other necessary things for OnnxRuntime to function.
- ///
+ /// and other necessary things for OnnxRuntime to function.
+ ///
/// Create or access OrtEnv by calling the Instance() method. Instance() can be called multiple times.
/// It would return the same instance.
- ///
+ ///
/// CreateInstanceWithOptions() provides a way to create environment with options.
/// It must be called once before Instance() is called, otherwise it would not have effect.
- ///
+ ///
/// If the environment is not explicitly created, it will be created as needed, e.g.,
/// when creating a SessionOptions instance.
///
@@ -93,6 +93,28 @@ public sealed class OrtEnv : SafeHandle
#region Static members
private static readonly int ORT_PROJECTION_CSHARP = 2;
+ ///
+ /// Set this to true before accessing any OnnxRuntime type to prevent OnnxRuntime
+ /// from registering its own DllImportResolver via
+ /// NativeLibrary.SetDllImportResolver.
+ /// This is useful when the host application needs to register its own custom resolver
+ /// for the OnnxRuntime assembly. Must be set before any OnnxRuntime API is used
+ /// (i.e., before the internal NativeMethods static constructor runs).
+ ///
+ ///
+ ///
+ /// // Disable OnnxRuntime's built-in resolver before any ORT usage
+ /// OrtEnv.DisableDllImportResolver = true;
+ ///
+ /// // Register your own resolver
+ /// NativeLibrary.SetDllImportResolver(typeof(OrtEnv).Assembly, MyCustomResolver);
+ ///
+ /// // Now use OnnxRuntime normally
+ /// var env = OrtEnv.Instance();
+ ///
+ ///
+ public static bool DisableDllImportResolver { get; set; } = false;
+
private static readonly byte[] _defaultLogId = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(@"CSharpOnnxRuntime");
// This must be static and set before the first creation call, otherwise, has no effect.
@@ -274,7 +296,7 @@ private static void SetLanguageProjection(OrtEnv env)
///
/// Instantiates (if not already done so) a new OrtEnv instance with the default logging level
/// and no other options. Otherwise returns the existing instance.
- ///
+ ///
/// It returns the same instance on every call - `OrtEnv` is singleton
///
/// Returns a singleton instance of OrtEnv that represents native OrtEnv object
@@ -523,7 +545,7 @@ public OrtLoggingLevel EnvLogLevel
/// A registered execution provider library can be used by all sessions created with the OrtEnv instance.
/// Devices the execution provider can utilize are added to the values returned by GetEpDevices() and can
/// be used in SessionOptions.AppendExecutionProvider to select an execution provider for a device.
- ///
+ ///
/// Coming: A selection policy can be specified and ORT will automatically select the best execution providers
/// and devices for the model.
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
index aa1b683acd668..c298a95392317 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
@@ -532,4 +532,100 @@ public void TestDllImportResolverDoesNotThrow()
}
}
}
+
+#if !NETSTANDARD2_0
+ [Collection("Ort Inference Tests")]
+ public class OrtEnvExternalDllImportResolverTest
+ {
+ private System.Reflection.Assembly LoadIsolatedOnnxRuntimeAssembly(out System.Runtime.Loader.AssemblyLoadContext alc)
+ {
+ // Load a fresh copy of the ONNX Runtime assembly into a new AssemblyLoadContext.
+ // This guarantees we get a clean slate for static fields/constructors, avoiding
+ // interference from other xUnit tests that may have already initialized OrtEnv
+ // in the default context.
+ //
+ // Native library resolution (e.g., onnxruntime.dll) falls through to the default
+ // ALC when the isolated context cannot resolve it, so P/Invoke calls still work.
+ alc = new System.Runtime.Loader.AssemblyLoadContext("IsolatedORT_" + Guid.NewGuid(), isCollectible: true);
+ string asmPath = typeof(OrtEnv).Assembly.Location;
+ return alc.LoadFromAssemblyPath(asmPath);
+ }
+
+ ///
+ /// Verifies the scenario where an external caller registers a DllImportResolver FIRST,
+ /// and then OrtEnv is initialized. ORT's try/catch should handle the conflict gracefully.
+ ///
+ [Fact(DisplayName = "TestExternalResolverRegisteredFirst")]
+ public void TestExternalResolverRegisteredFirst()
+ {
+ var asm = LoadIsolatedOnnxRuntimeAssembly(out var alc);
+ try
+ {
+ // 1. External application registers its own resolver FIRST.
+ // Returning IntPtr.Zero means "not handled" — the runtime falls back to
+ // its default resolution logic, so native libraries still load normally.
+ NativeLibrary.SetDllImportResolver(asm, (libraryName, a, searchPath) => IntPtr.Zero);
+
+ // 2. ORT initializes (triggers NativeMethods static constructor).
+ // It will attempt to register its own resolver, which will throw
+ // InvalidOperationException internally, but the try/catch safety net
+ // prevents an unhandled TypeInitializationException.
+ var ortEnvType = asm.GetType("Microsoft.ML.OnnxRuntime.OrtEnv");
+ var instanceMethod = ortEnvType.GetMethod("Instance", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static);
+ var ortEnvInstance = instanceMethod.Invoke(null, null);
+ Assert.NotNull(ortEnvInstance);
+
+ // Verify ORT is fully functional despite the resolver conflict.
+ var getVersionMethod = ortEnvType.GetMethod("GetVersionString", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance);
+ var version = (string)getVersionMethod.Invoke(ortEnvInstance, null);
+ Assert.False(string.IsNullOrEmpty(version));
+ }
+ finally
+ {
+ alc.Unload();
+ }
+ }
+
+ ///
+ /// Verifies that setting DisableDllImportResolver = true BEFORE ORT initializes
+ /// successfully prevents ORT from registering its own resolver, leaving the assembly
+ /// free for the external application to register theirs LATER without throwing.
+ ///
+ [Fact(DisplayName = "TestDisableDllImportResolverWorks")]
+ public void TestDisableDllImportResolverWorks()
+ {
+ var asm = LoadIsolatedOnnxRuntimeAssembly(out var alc);
+ try
+ {
+ var ortEnvType = asm.GetType("Microsoft.ML.OnnxRuntime.OrtEnv");
+
+ // 1. Set OrtEnv.DisableDllImportResolver = true FIRST.
+ var disableProp = ortEnvType.GetProperty("DisableDllImportResolver", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static);
+ Assert.NotNull(disableProp);
+ disableProp.SetValue(null, true);
+
+ // 2. ORT initializes (triggers NativeMethods static constructor).
+ // It should respect the flag and SKIP calling NativeLibrary.SetDllImportResolver.
+ var instanceMethod = ortEnvType.GetMethod("Instance", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static);
+ var ortEnvInstance = instanceMethod.Invoke(null, null);
+ Assert.NotNull(ortEnvInstance);
+
+ // 3. External application registers its own resolver AFTER ORT initialized.
+ // If the flag works correctly, ORT skipped its own SetDllImportResolver call,
+ // so this registration should succeed without throwing InvalidOperationException.
+ // Returning IntPtr.Zero means "not handled" — falls back to default resolution.
+ var ex = Record.Exception(() =>
+ {
+ NativeLibrary.SetDllImportResolver(asm, (libraryName, a, searchPath) => IntPtr.Zero);
+ });
+
+ Assert.Null(ex); // No InvalidOperationException = ORT correctly skipped registration
+ }
+ finally
+ {
+ alc.Unload();
+ }
+ }
+ }
+#endif
}
diff --git a/include/onnxruntime/core/framework/int2.h b/include/onnxruntime/core/framework/int2.h
index 0d406d6fcd8d3..40af5746c9273 100644
--- a/include/onnxruntime/core/framework/int2.h
+++ b/include/onnxruntime/core/framework/int2.h
@@ -7,6 +7,7 @@
#include
#include "core/common/common.h"
#include
+#include "onnxruntime_config.h"
namespace onnxruntime {
@@ -137,8 +138,16 @@ struct Int2x4Base {
const size_t full_quads = src.size() / 4;
// Process complete groups of 4 elements
+
for (; dst_i < full_quads; dst_i++) {
+#if defined(__GNUC__) && defined(HAS_ARRAY_BOUNDS)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Warray-bounds"
+#endif
dst[dst_i] = Int2x4Base(src[src_i], src[src_i + 1], src[src_i + 2], src[src_i + 3]);
+#if defined(__GNUC__) && defined(HAS_ARRAY_BOUNDS)
+#pragma GCC diagnostic pop
+#endif
src_i += 4;
}
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 1ea147a0079cc..abeb930d8ab8d 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -13,7 +13,7 @@
* The maximum length of the Config Key is 1024
*
* The string format of a SessionOptions Config Value is defined individually for each Config.
- * The maximum length of the Config Value is 2048
+ * The maximum length of the Config Value is 8192
*/
// Key for disable PrePacking,
@@ -385,6 +385,11 @@ static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm";
// If not provided, default is 4.
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
+// Enable the DQ->MatMulNBits fusion graph transformer.
+// "0": disabled (default). "1": enabled.
+// This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered.
+static const char* const kOrtSessionOptionsEnableDQMatMulNBitsFusion = "session.enable_dq_matmulnbits_fusion";
+
// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
// Meant to be used with SetEpDynamicOptions
// Specify the type of workload for this session.
diff --git a/onnxruntime/core/framework/config_options.h b/onnxruntime/core/framework/config_options.h
index 1c356d8cfca56..6c0d6741dd92f 100644
--- a/onnxruntime/core/framework/config_options.h
+++ b/onnxruntime/core/framework/config_options.h
@@ -18,7 +18,7 @@ struct ConfigOptions {
// Maximum key/value string lengths specified in
// core/session/onnxruntime_session_options_config_keys.h
static constexpr size_t kMaxKeyLength = 1024;
- static constexpr size_t kMaxValueLength = 4096;
+ static constexpr size_t kMaxValueLength = 8192;
std::unordered_map<std::string, std::string> configurations;
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h
index a745dd9f1376d..77aa460c72ef7 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h
@@ -1670,12 +1670,12 @@ MlasQ4Int8TileGemmKernelBlkLen32Avx2(
// .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h:1531:13: note: array 'acc' declared here
// 1531 | __m256 acc[NCols4];
// | ^
-#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS)
+#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Warray-bounds"
#endif
__m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]);
-#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS)
+#ifdef __clang__
#pragma clang diagnostic pop
#endif
if (BiasPtr != nullptr) {
diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc
new file mode 100644
index 0000000000000..f9ae13808cf2c
--- /dev/null
+++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc
@@ -0,0 +1,848 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/dq_matmulnbits_fusion.h"
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include "core/common/common.h"
+#include "core/common/safeint.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/graph/constants.h"
+#include "core/graph/graph_utils.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+
+#include <cstring>
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+namespace onnxruntime {
+
+namespace {
+
+// ---------------------------------------------------------------------------
+// Utility helpers
+// ---------------------------------------------------------------------------
+
+bool IsUniformPackedUint4Value(const Initializer& init, uint8_t expected_nibble) {
+ if (init.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
+ return false;
+ }
+
+ const size_t values_count = static_cast<size_t>(init.size());
+ if (values_count == 0) {
+ return false;
+ }
+
+ const auto packed = init.DataAsByteSpan();
+ const uint8_t expected = static_cast<uint8_t>(expected_nibble & 0x0F);
+ for (size_t i = 0; i < values_count; ++i) {
+ const uint8_t byte = packed[i / 2];
+ const uint8_t value = (i % 2 == 0) ? (byte & 0x0F) : ((byte >> 4) & 0x0F);
+ if (value != expected) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool HasRank2Shape(const ONNX_NAMESPACE::TensorProto& tp, int64_t dim0, int64_t dim1) {
+ return tp.dims_size() == 2 && tp.dims(0) == dim0 && tp.dims(1) == dim1;
+}
+
+uint8_t GetPackedUint4Element(const uint8_t* packed, size_t index, size_t num_elements) {
+ ORT_ENFORCE(index < num_elements, "GetPackedUint4Element: index ", index,
+ " out of bounds (num_elements=", num_elements, ")");
+ const uint8_t packed_byte = packed[index / 2];
+ return (index % 2 == 0) ? static_cast<uint8_t>(packed_byte & 0x0F)
+ : static_cast<uint8_t>((packed_byte >> 4) & 0x0F);
+}
+
+void PackUint4Rows(const Initializer& src, int64_t rows, int64_t cols, uint8_t* dst) {
+ const int64_t row_bytes = (cols + 1) / 2;
+ const size_t dst_bytes = SafeInt<size_t>(rows) * row_bytes;
+ const size_t total_elements = SafeInt<size_t>(rows) * cols;
+ memset(dst, 0, dst_bytes);
+
+ const auto src_packed = src.DataAsByteSpan();
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 0; c < cols; ++c) {
+ const size_t src_index = SafeInt<size_t>(r) * cols + c;
+ const uint8_t value = GetPackedUint4Element(src_packed.data(), src_index, total_elements);
+
+ const size_t dst_index = SafeInt<size_t>(r) * row_bytes + c / 2;
+ if ((c & 1) == 0) {
+ dst[dst_index] = value;
+ } else {
+ dst[dst_index] = static_cast<uint8_t>(dst[dst_index] | (value << 4));
+ }
+ }
+ }
+}
+
+// Transpose and pack UINT4 weights from DQ axis=0 layout [K, N] to MatMulNBits layout [N, k_blocks, blob_size].
+// Source: row-major UINT4 with quantization along K (axis=0), shape [K, N].
+// The nibble ordering follows ONNX UINT4 convention: even indices in the low nibble,
+// odd indices in the high nibble of each byte.
+// Dest: UINT8 [N, k_blocks, block_size/2] where each byte packs two 4-bit weights.
+void TransposePackWeightsAxis0(
+ const uint8_t* src_packed, int64_t K, int64_t N, int64_t block_size,
+ uint8_t* dst) {
+ const int64_t k_blocks = (K + block_size - 1) / block_size;
+ const int64_t blob_size = block_size / 2;
+ const size_t dst_bytes = SafeInt<size_t>(N) * k_blocks * blob_size;
+ const size_t total_elements = SafeInt<size_t>(K) * N;
+ memset(dst, 0, dst_bytes);
+
+ for (int64_t n = 0; n < N; ++n) {
+ for (int64_t k = 0; k < K; ++k) {
+ const size_t src_index = SafeInt<size_t>(k) * N + n;
+ const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements);
+
+ const int64_t kb = k / block_size;
+ const int64_t off = k % block_size;
+ const size_t dst_byte = SafeInt<size_t>(n) * k_blocks * blob_size + kb * blob_size + off / 2;
+ if (off % 2 == 0) {
+ dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0xF0) | val);
+ } else {
+ dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0x0F) | (val << 4));
+ }
+ }
+ }
+}
+
+// Transpose and pack UINT4 zero points from DQ axis=0 layout [k_blocks, N] to
+// MatMulNBits layout UINT8 [N, ceil(k_blocks/2)].
+void TransposePackZPAxis0(
+ const uint8_t* src_packed, int64_t k_blocks, int64_t N,
+ uint8_t* dst) {
+ const int64_t zp_bytes_per_n = (k_blocks + 1) / 2;
+ const size_t dst_bytes = SafeInt<size_t>(N) * zp_bytes_per_n;
+ const size_t total_elements = SafeInt<size_t>(k_blocks) * N;
+ memset(dst, 0, dst_bytes);
+
+ for (int64_t n = 0; n < N; ++n) {
+ for (int64_t kb = 0; kb < k_blocks; ++kb) {
+ const size_t src_index = SafeInt<size_t>(kb) * N + n;
+ const uint8_t val = GetPackedUint4Element(src_packed, src_index, total_elements);
+
+ const size_t dst_byte = SafeInt<size_t>(n) * zp_bytes_per_n + kb / 2;
+ if (kb % 2 == 0) {
+ dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0xF0) | val);
+ } else {
+ dst[dst_byte] = static_cast<uint8_t>((dst[dst_byte] & 0x0F) | (val << 4));
+ }
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Match structs
+// ---------------------------------------------------------------------------
+
+struct FusionMatch {
+ NodeIndex matmul_idx;
+ std::optional<NodeIndex> cast_idx;
+ NodeIndex transpose_idx;
+ NodeIndex reshape_idx;
+ NodeIndex dq_idx;
+};
+
+struct DirectDQMatch {
+ NodeIndex matmul_idx;
+ NodeIndex dq_idx;
+};
+
+// ---------------------------------------------------------------------------
+// Shared Gemm validation (alpha=1, beta=1, transA=0, transB=0, bias 1-D [N])
+// ---------------------------------------------------------------------------
+
+bool ValidateGemmForFusion(const Node& gemm_node, int64_t N) {
+ if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm_node, "alpha");
+ alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f)
+ return false;
+ if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm_node, "beta");
+ beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f)
+ return false;
+ if (const auto* trans_a = graph_utils::GetNodeAttribute(gemm_node, "transA");
+ trans_a && trans_a->i() != 0)
+ return false;
+ if (const auto* trans_b = graph_utils::GetNodeAttribute(gemm_node, "transB");
+ trans_b && trans_b->i() != 0)
+ return false;
+
+ const auto& inputs = gemm_node.InputDefs();
+ if (inputs.size() > 2 && inputs[2] && inputs[2]->Exists()) {
+ const auto* bias_shape = inputs[2]->Shape();
+ if (!bias_shape || bias_shape->dim_size() != 1 ||
+ !utils::HasDimValue(bias_shape->dim(0)) ||
+ bias_shape->dim(0).dim_value() != N)
+ return false;
+ }
+ return true;
+}
+
+// ---------------------------------------------------------------------------
+// Pattern 1 matching: DQ -> Reshape -> Transpose -> [Cast] -> MatMul/Gemm
+// ---------------------------------------------------------------------------
+
+std::vector<FusionMatch> CollectReshapeTransposeMatches(
+ Graph& graph,
+ const std::vector<NodeIndex>& node_topology_list,
+ const logging::Logger& logger) {
+ std::vector matches;
+
+ for (auto node_index : node_topology_list) {
+ auto* node = graph.GetNode(node_index);
+ if (!node) continue;
+
+ if (node->OpType() != "MatMul" && node->OpType() != "Gemm") continue;
+
+ const auto& mm_inputs = node->InputDefs();
+ if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue;
+
+ const Node* cast_node = nullptr;
+ const Node* transpose_node = graph.GetProducerNode(mm_inputs[1]->Name());
+ if (transpose_node && transpose_node->OpType() == "Cast") {
+ cast_node = transpose_node;
+ if (cast_node->GetOutputEdgesCount() != 1) continue;
+ const auto& cast_inputs = cast_node->InputDefs();
+ if (cast_inputs.empty() || !cast_inputs[0] || !cast_inputs[0]->Exists()) continue;
+ transpose_node = graph.GetProducerNode(cast_inputs[0]->Name());
+ }
+
+ if (!transpose_node || transpose_node->OpType() != "Transpose") continue;
+ if (transpose_node->GetOutputEdgesCount() != 1) continue;
+
+ const auto& tp_inputs = transpose_node->InputDefs();
+ if (tp_inputs.empty() || !tp_inputs[0] || !tp_inputs[0]->Exists()) continue;
+ const Node* reshape_node = graph.GetProducerNode(tp_inputs[0]->Name());
+ if (!reshape_node || reshape_node->OpType() != "Reshape") continue;
+ if (reshape_node->GetOutputEdgesCount() != 1) continue;
+
+ const auto& reshape_inputs = reshape_node->InputDefs();
+ if (reshape_inputs.empty() || !reshape_inputs[0] || !reshape_inputs[0]->Exists()) continue;
+ const Node* dq_node = graph.GetProducerNode(reshape_inputs[0]->Name());
+ if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue;
+ if (dq_node->GetOutputEdgesCount() != 1) continue;
+
+ const auto& dq_attrs = dq_node->GetAttributes();
+ {
+ auto it = dq_attrs.find("axis");
+ if (it == dq_attrs.end() || it->second.i() != 2) continue;
+ }
+ int64_t block_size = 0;
+ {
+ auto it = dq_attrs.find("block_size");
+ if (it == dq_attrs.end()) continue;
+ block_size = it->second.i();
+ if (block_size < 16 || ((block_size - 1) & block_size)) continue;
+ }
+
+ const auto* weight_arg = dq_node->InputDefs()[0];
+ if (!weight_arg || !weight_arg->Exists()) continue;
+ const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true);
+ if (!weight_const_tp) continue;
+ if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue;
+ if (weight_const_tp->dims_size() != 3) continue;
+ const int64_t N = weight_const_tp->dims(0);
+ const int64_t blocks = weight_const_tp->dims(1);
+ const int64_t bs_dim = weight_const_tp->dims(2);
+ if (N <= 0 || blocks <= 0 || bs_dim <= 0) continue;
+ if (bs_dim != block_size) continue;
+ const int64_t K = SafeInt<int64_t>(blocks) * bs_dim;
+
+ const auto* scale_arg = dq_node->InputDefs()[1];
+ if (!scale_arg || !scale_arg->Exists()) continue;
+ const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true);
+ if (!scale_const_tp) continue;
+ int32_t dt_scale = scale_const_tp->data_type();
+ if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
+ dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue;
+
+ const auto* a_arg = mm_inputs[0];
+ if (!a_arg || !a_arg->TypeAsProto()) continue;
+ int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type();
+ if (dt_a != dt_scale) continue;
+
+ const auto* reshape_shape_arg =
+ reshape_node->InputDefs().size() > 1 ? reshape_node->InputDefs()[1] : nullptr;
+ if (!reshape_shape_arg || !reshape_shape_arg->Exists()) continue;
+ const auto* reshape_shape_tp = graph.GetConstantInitializer(reshape_shape_arg->Name(), true);
+ if (!reshape_shape_tp) continue;
+
+ Initializer reshape_shape_init(graph, *reshape_shape_tp, graph.ModelPath());
+ if (reshape_shape_init.size() != 2) continue;
+
+ int64_t reshape_dim0 = 0;
+ int64_t reshape_dim1 = 0;
+ if (reshape_shape_init.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
+ const auto* shape_data = reshape_shape_init.data<int64_t>();
+ reshape_dim0 = shape_data[0];
+ reshape_dim1 = shape_data[1];
+ } else if (reshape_shape_init.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
+ const auto* shape_data = reshape_shape_init.data<int32_t>();
+ reshape_dim0 = shape_data[0];
+ reshape_dim1 = shape_data[1];
+ } else {
+ continue;
+ }
+
+ auto resolve_reshape_dim = [](int64_t dim, int64_t expected) -> std::optional<int64_t> {
+ if (dim == expected || dim == 0 || dim == -1) {
+ return expected;
+ }
+ return std::nullopt;
+ };
+ const auto resolved_reshape_dim0 = resolve_reshape_dim(reshape_dim0, N);
+ const auto resolved_reshape_dim1 = resolve_reshape_dim(reshape_dim1, K);
+ if (!resolved_reshape_dim0 || !resolved_reshape_dim1 ||
+ *resolved_reshape_dim0 != N || *resolved_reshape_dim1 != K) {
+ continue;
+ }
+
+ if (const auto* perm_attr = graph_utils::GetNodeAttribute(*transpose_node, "perm")) {
+ if (perm_attr->ints_size() != 2 || perm_attr->ints(0) != 1 || perm_attr->ints(1) != 0) {
+ continue;
+ }
+ }
+
+ if (const auto* b_shape = mm_inputs[1]->Shape(); b_shape && b_shape->dim_size() == 2 &&
+ utils::HasDimValue(b_shape->dim(0)) && utils::HasDimValue(b_shape->dim(1)) &&
+ (b_shape->dim(0).dim_value() != K || b_shape->dim(1).dim_value() != N)) {
+ continue;
+ }
+
+ if (const auto* a_shape = mm_inputs[0] ? mm_inputs[0]->Shape() : nullptr;
+ a_shape && a_shape->dim_size() >= 1) {
+ const int last_a_dim_idx = a_shape->dim_size() - 1;
+ if (utils::HasDimValue(a_shape->dim(last_a_dim_idx)) &&
+ a_shape->dim(last_a_dim_idx).dim_value() != K) {
+ continue;
+ }
+ }
+
+ const auto* y_shape = node->OutputDefs().empty() ? nullptr : node->OutputDefs()[0]->Shape();
+ if (y_shape && y_shape->dim_size() >= 1) {
+ const int last_y_dim_idx = y_shape->dim_size() - 1;
+ if (utils::HasDimValue(y_shape->dim(last_y_dim_idx)) &&
+ y_shape->dim(last_y_dim_idx).dim_value() != N) {
+ continue;
+ }
+ }
+
+ if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue;
+
+ if (cast_node) {
+ const auto* cast_in = cast_node->InputDefs().empty() ? nullptr : cast_node->InputDefs()[0];
+ const auto* cast_out = cast_node->OutputDefs().empty() ? nullptr : cast_node->OutputDefs()[0];
+ if (!cast_in || !cast_out || !cast_in->TypeAsProto() || !cast_out->TypeAsProto()) continue;
+ if (cast_in->TypeAsProto()->tensor_type().elem_type() !=
+ cast_out->TypeAsProto()->tensor_type().elem_type()) {
+ continue;
+ }
+ }
+
+ const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr;
+ bool has_zp = zp_arg && zp_arg->Exists();
+ if (has_zp) {
+ const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true);
+ if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue;
+ }
+
+ LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched pattern at MatMul node '"
+ << node->Name() << "'";
+
+ matches.push_back({node->Index(),
+ cast_node ? std::optional<NodeIndex>(cast_node->Index()) : std::nullopt,
+ transpose_node->Index(),
+ reshape_node->Index(), dq_node->Index()});
+ }
+
+ return matches;
+}
+
+// ---------------------------------------------------------------------------
+// Pattern 2 matching: direct DQ(axis=0, 2D UINT4) -> MatMul/Gemm
+// ---------------------------------------------------------------------------
+
+std::vector<DirectDQMatch> CollectDirectDQMatches(
+ Graph& graph,
+ const std::vector<NodeIndex>& node_topology_list,
+ const std::unordered_set<NodeIndex>& skip_indices,
+ const logging::Logger& logger) {
+ std::vector<DirectDQMatch> direct_matches;
+
+ for (auto node_index : node_topology_list) {
+ auto* node = graph.GetNode(node_index);
+ if (!node) continue;
+
+ if (node->OpType() != "MatMul" && node->OpType() != "Gemm") continue;
+ if (skip_indices.count(node->Index())) continue;
+
+ const auto& mm_inputs = node->InputDefs();
+ if (mm_inputs.size() < 2 || !mm_inputs[1] || !mm_inputs[1]->Exists()) continue;
+
+ const Node* dq_node = graph.GetProducerNode(mm_inputs[1]->Name());
+ if (!dq_node || dq_node->OpType() != "DequantizeLinear") continue;
+ if (dq_node->GetOutputEdgesCount() != 1) continue;
+
+ const auto& dq_attrs = dq_node->GetAttributes();
+ {
+ auto it = dq_attrs.find("axis");
+ if (it == dq_attrs.end() || it->second.i() != 0) continue;
+ }
+ int64_t block_size = 0;
+ {
+ auto it = dq_attrs.find("block_size");
+ if (it == dq_attrs.end()) continue;
+ block_size = it->second.i();
+ if (block_size < 16 || ((block_size - 1) & block_size)) continue;
+ }
+
+ const auto* weight_arg = dq_node->InputDefs()[0];
+ if (!weight_arg || !weight_arg->Exists()) continue;
+ const auto* weight_const_tp = graph.GetConstantInitializer(weight_arg->Name(), true);
+ if (!weight_const_tp) continue;
+ if (weight_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue;
+ if (weight_const_tp->dims_size() != 2) continue;
+ const int64_t K = weight_const_tp->dims(0);
+ const int64_t N = weight_const_tp->dims(1);
+ if (K <= 0 || N <= 0 || K % block_size != 0) continue;
+ const int64_t k_blocks = K / block_size;
+
+ const auto* scale_arg = dq_node->InputDefs()[1];
+ if (!scale_arg || !scale_arg->Exists()) continue;
+ const auto* scale_const_tp = graph.GetConstantInitializer(scale_arg->Name(), true);
+ if (!scale_const_tp) continue;
+ int32_t dt_scale = scale_const_tp->data_type();
+ if (dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
+ dt_scale != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue;
+ if (!HasRank2Shape(*scale_const_tp, k_blocks, N)) continue;
+
+ const auto* a_arg = mm_inputs[0];
+ if (!a_arg || !a_arg->TypeAsProto()) continue;
+ int32_t dt_a = a_arg->TypeAsProto()->tensor_type().elem_type();
+ if (dt_a != dt_scale) continue;
+
+ const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr;
+ bool has_zp = zp_arg && zp_arg->Exists();
+ if (has_zp) {
+ const auto* zp_const_tp = graph.GetConstantInitializer(zp_arg->Name(), true);
+ if (!zp_const_tp || zp_const_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue;
+ if (!HasRank2Shape(*zp_const_tp, k_blocks, N)) continue;
+ }
+
+ if (node->OpType() == "Gemm" && !ValidateGemmForFusion(*node, N)) continue;
+
+ LOGS(logger, INFO) << "DQMatMulNBitsFusion: matched direct DQ->MatMul pattern at node '"
+ << node->Name() << "' (K=" << K << ", N=" << N << ", block_size=" << block_size << ")";
+ direct_matches.push_back({node->Index(), dq_node->Index()});
+ }
+
+ return direct_matches;
+}
+
+// ---------------------------------------------------------------------------
+// Pattern 1 rewriting: DQ+Reshape+Transpose+[Cast]+MatMul/Gemm -> MatMulNBits
+// ---------------------------------------------------------------------------
+
+void ApplyReshapeTransposeFusions(
+    Graph& graph,
+    const std::vector& matches,  // NOTE(review): element type stripped by patch mangling -- restore the match-struct name
+    int64_t accuracy_level,
+    bool& modified,
+    const logging::Logger& logger) {
+  for (const auto& match : matches) {
+    const Node* mm_node = graph.GetNode(match.matmul_idx);  // may be null if an earlier fusion consumed the node
+    const Node* cast_node = match.cast_idx ? graph.GetNode(*match.cast_idx) : nullptr;
+    const Node* tp_node = graph.GetNode(match.transpose_idx);
+    const Node* dq_node = graph.GetNode(match.dq_idx);
+    const Node* reshape_node = graph.GetNode(match.reshape_idx);
+    if (!mm_node || !tp_node || !dq_node || !reshape_node ||
+        (match.cast_idx && !cast_node)) {
+      continue;
+    }
+
+    const auto* weight_arg = dq_node->InputDefs()[0];
+    const auto* scale_arg = dq_node->InputDefs()[1];
+    const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr;
+    bool has_zp = zp_arg && zp_arg->Exists();
+
+    const auto& dq_attrs = dq_node->GetAttributes();
+    const int64_t block_size = dq_attrs.at("block_size").i();
+
+    const ONNX_NAMESPACE::TensorProto* weight_tp = nullptr;
+    if (!graph.GetInitializedTensor(weight_arg->Name(), weight_tp) || !weight_tp) continue;
+    const ONNX_NAMESPACE::TensorProto* scale_tp = nullptr;
+    if (!graph.GetInitializedTensor(scale_arg->Name(), scale_tp) || !scale_tp) continue;
+    const ONNX_NAMESPACE::TensorProto* zp_tp = nullptr;
+    if (has_zp) {
+      if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue;
+    }
+
+    if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 ||
+        weight_tp->dims_size() != 3) {
+      continue;
+    }
+
+    const int64_t N = weight_tp->dims(0);
+    const int64_t quant_num = weight_tp->dims(1);
+    const int64_t bs_dim = weight_tp->dims(2);
+    if (N <= 0 || quant_num <= 0 || bs_dim <= 0 || bs_dim != block_size) continue;
+    const int64_t K = SafeInt<int64_t>(quant_num) * bs_dim;
+    const int64_t blob_bytes = (block_size + 1) / 2;  // packed UINT4: two nibbles per byte
+
+    Initializer weight_src(graph, *weight_tp, graph.ModelPath());
+    Initializer scale_src(graph, *scale_tp, graph.ModelPath());
+    if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
+        scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+      continue;
+    }
+
+    auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(
+                          ONNX_NAMESPACE::TensorProto_DataType_UINT8)
+                          ->GetElementType();
+    auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(
+                          scale_src.data_type())
+                          ->GetElementType();
+
+    auto cpu_allocator = CPUAllocator::DefaultInstance();
+
+    auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb");
+    auto weight_dst = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator);
+
+    auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb");
+    const int64_t scale_size = (TensorShape{N, quant_num}).Size();
+    if (scale_src.size() != static_cast<size_t>(scale_size)) continue;
+    auto scale_dst = Tensor(scale_type, TensorShape{scale_size}, cpu_allocator);
+
+    std::string zp_dst_name;
+    std::optional<Tensor> zp_dst;
+    const int64_t zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size();
+
+    bool elide_default_uint4_zp8_input = false;
+    std::optional<Initializer> zp_src;
+
+    const auto weight_bytes = weight_src.DataAsByteSpan();
+    if (weight_bytes.size() != static_cast<size_t>(weight_dst.SizeInBytes())) continue;
+    memcpy(weight_dst.MutableDataRaw(), weight_bytes.data(), weight_bytes.size());
+
+    if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      memcpy(scale_dst.MutableData<float>(), scale_src.data<float>(),
+             static_cast<size_t>(scale_size) * sizeof(float));
+    } else {
+      memcpy(scale_dst.MutableData<MLFloat16>(), scale_src.data<MLFloat16>(),
+             static_cast<size_t>(scale_size) * sizeof(MLFloat16));
+    }
+
+    if (zp_tp) {
+      zp_src.emplace(graph, *zp_tp, graph.ModelPath());
+      if (zp_src->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue;
+      if (zp_src->size() != static_cast<size_t>(N * quant_num)) continue;
+
+      const bool is_default_uint4_8 =
+          IsUniformPackedUint4Value(*zp_src, /*expected_nibble*/ 8);
+      if (is_default_uint4_8) {
+        elide_default_uint4_zp8_input = true;  // MatMulNBits already defaults to 8 when zp is absent
+      } else {
+        zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb");
+        zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator);
+        PackUint4Rows(*zp_src, N, quant_num, zp_dst->MutableData<uint8_t>());
+      }
+    } else {
+      // DequantizeLinear default zero-point for uint4 is 0, while MatMulNBits
+      // default is 8. Emit explicit zeros to preserve semantics.
+      zp_dst_name = graph.GenerateNodeArgName("fused_DQ_zp_mnb");
+      zp_dst = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator);
+      memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes());
+    }
+
+    auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true);
+    auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true);
+    std::optional<ONNX_NAMESPACE::TensorProto> zp_mnb_tp;
+    if (zp_dst && !elide_default_uint4_zp8_input) {
+      zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true));
+    }
+
+    NodeAttributes mnb_attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast<int64_t>(4)), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs);
+
+    std::vector<NodeArg*> mnb_inputs;
+    mnb_inputs.push_back(const_cast<NodeArg*>(mm_node->InputDefs()[0]));
+    mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst)));
+    mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst)));
+    if (zp_mnb_tp) {
+      mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst)));
+    }
+
+    // MatMulNBits input layout: 0:A, 1:B, 2:scales, 3:zero_points(opt), 4:g_idx(opt), 5:bias(opt)
+    bool fused_with_bias = false;
+    if (mm_node->OpType() == "Gemm" &&
+        mm_node->InputDefs().size() > 2 &&
+        mm_node->InputDefs()[2] &&
+        mm_node->InputDefs()[2]->Exists()) {
+      NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr);
+      while (mnb_inputs.size() < 5) {
+        mnb_inputs.push_back(&empty_arg);  // fill optional slots so the bias lands at input index 5
+      }
+      mnb_inputs.push_back(const_cast<NodeArg*>(mm_node->InputDefs()[2]));
+      fused_with_bias = true;
+    }
+
+    std::vector<NodeArg*> mnb_outputs;
+    mnb_outputs.push_back(const_cast<NodeArg*>(mm_node->OutputDefs()[0]));
+
+    auto& mnb_node = graph.AddNode(
+        graph.GenerateNodeName("DQFusedMatMulNBits"),
+        "MatMulNBits",
+        "Fused from DQ+Reshape+Transpose+MatMul",
+        mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain);
+    mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType());
+
+    graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx));
+    graph.RemoveNode(match.matmul_idx);
+
+    if (match.cast_idx && graph.GetNode(*match.cast_idx)) {
+      graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(*match.cast_idx));
+      graph.RemoveNode(*match.cast_idx);
+    }
+
+    graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.transpose_idx));
+    graph.RemoveNode(match.transpose_idx);
+
+    graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.reshape_idx));
+    graph.RemoveNode(match.reshape_idx);
+
+    graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx));
+    graph.RemoveNode(match.dq_idx);
+
+    LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused DQ+Reshape+Transpose"
+                       << (match.cast_idx ? "+Cast" : "")
+                       << "+MatMul/Gemm -> MatMulNBits"
+                       << (fused_with_bias ? " (bias preserved)" : "")
+                       << (elide_default_uint4_zp8_input ? " (default UINT4 zp8 elided)" : "");
+    modified = true;
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Pattern 2 rewriting: direct DQ(axis=0) + MatMul/Gemm -> MatMulNBits
+// ---------------------------------------------------------------------------
+
+void ApplyDirectDQFusions(
+    Graph& graph,
+    const std::vector& matches,  // NOTE(review): element type stripped by patch mangling -- restore the match-struct name
+    int64_t accuracy_level,
+    bool& modified,
+    const logging::Logger& logger) {
+  for (const auto& match : matches) {
+    const Node* mm_node = graph.GetNode(match.matmul_idx);
+    const Node* dq_node = graph.GetNode(match.dq_idx);
+    if (!mm_node || !dq_node) continue;
+
+    const auto* weight_arg = dq_node->InputDefs()[0];
+    const auto* scale_arg = dq_node->InputDefs()[1];
+    const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr;
+    bool has_zp = zp_arg && zp_arg->Exists();
+
+    const auto& dq_attrs = dq_node->GetAttributes();
+    const int64_t block_size = dq_attrs.at("block_size").i();
+
+    const ONNX_NAMESPACE::TensorProto* weight_tp = nullptr;
+    if (!graph.GetInitializedTensor(weight_arg->Name(), weight_tp) || !weight_tp) continue;
+    const ONNX_NAMESPACE::TensorProto* scale_tp = nullptr;
+    if (!graph.GetInitializedTensor(scale_arg->Name(), scale_tp) || !scale_tp) continue;
+    const ONNX_NAMESPACE::TensorProto* zp_tp = nullptr;
+    if (has_zp) {
+      if (!graph.GetInitializedTensor(zp_arg->Name(), zp_tp) || !zp_tp) continue;
+    }
+
+    if (weight_tp->data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4 ||
+        weight_tp->dims_size() != 2) continue;
+
+    const int64_t K = weight_tp->dims(0);
+    const int64_t N = weight_tp->dims(1);
+    if (K <= 0 || N <= 0 || block_size <= 0 || K % block_size != 0) continue;
+    const int64_t k_blocks = K / block_size;
+    const int64_t blob_bytes = block_size / 2;  // matcher guarantees block_size is an even power of two
+    if (!HasRank2Shape(*scale_tp, k_blocks, N)) continue;
+    if (zp_tp && !HasRank2Shape(*zp_tp, k_blocks, N)) continue;
+
+    Initializer weight_src(graph, *weight_tp, graph.ModelPath());
+    const size_t required_weight_bytes = SafeInt<size_t>(N) * k_blocks * blob_bytes;
+    if (weight_src.DataAsByteSpan().size() < required_weight_bytes) continue;
+    Initializer scale_src(graph, *scale_tp, graph.ModelPath());
+    if (scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
+        scale_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) continue;
+
+    auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(
+                          ONNX_NAMESPACE::TensorProto_DataType_UINT8)
+                          ->GetElementType();
+    auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(
+                          scale_src.data_type())
+                          ->GetElementType();
+    auto cpu_allocator = CPUAllocator::DefaultInstance();
+
+    auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_mnb");
+    auto weight_dst = Tensor(uint8_type, TensorShape{N, k_blocks, blob_bytes}, cpu_allocator);
+    TransposePackWeightsAxis0(weight_src.DataAsByteSpan().data(), K, N, block_size,
+                              weight_dst.MutableData<uint8_t>());
+
+    auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_mnb");
+    const int64_t scale_count = SafeInt<int64_t>(N) * k_blocks;
+    if (scale_src.size() != static_cast<size_t>(scale_count)) continue;
+    auto scale_dst = Tensor(scale_type, TensorShape{scale_count}, cpu_allocator);
+
+    // Transpose scales from DQ layout [k_blocks, N] to MatMulNBits layout [N, k_blocks].
+    if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      const float* src = scale_src.data<float>();
+      float* dst = scale_dst.MutableData<float>();
+      for (int64_t n = 0; n < N; ++n)
+        for (int64_t kb = 0; kb < k_blocks; ++kb)
+          dst[n * k_blocks + kb] = src[kb * N + n];
+    } else {
+      const MLFloat16* src = scale_src.data<MLFloat16>();
+      MLFloat16* dst = scale_dst.MutableData<MLFloat16>();
+      for (int64_t n = 0; n < N; ++n)
+        for (int64_t kb = 0; kb < k_blocks; ++kb)
+          dst[n * k_blocks + kb] = src[kb * N + n];
+    }
+
+    std::string zp_dst_name;
+    std::optional<Tensor> zp_dst;
+    const int64_t zp_bytes_total = SafeInt<int64_t>(N) * ((k_blocks + 1) / 2);
+
+    bool elide_zp = false;
+
+    if (zp_tp) {
+      Initializer zp_src(graph, *zp_tp, graph.ModelPath());
+      if (zp_src.data_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT4) continue;
+      if (zp_src.size() != static_cast<size_t>(k_blocks * N)) continue;
+
+      if (IsUniformPackedUint4Value(zp_src, 8)) {
+        elide_zp = true;  // MatMulNBits already defaults to 8 when zp is absent
+      } else {
+        zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_mnb");
+        zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator);
+        TransposePackZPAxis0(zp_src.DataAsByteSpan().data(), k_blocks, N,
+                             zp_dst->MutableData<uint8_t>());
+      }
+    } else {
+      // DQ default ZP for UINT4 is 0, MatMulNBits default is 8. Emit explicit zeros.
+      zp_dst_name = graph.GenerateNodeArgName("direct_DQ_zp_mnb");
+      zp_dst = Tensor(uint8_type, TensorShape{zp_bytes_total}, cpu_allocator);
+      memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes());
+    }
+
+    auto weight_mnb_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true);
+    auto scale_mnb_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true);
+    std::optional<ONNX_NAMESPACE::TensorProto> zp_mnb_tp;
+    if (zp_dst && !elide_zp) {
+      zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true));
+    }
+
+    NodeAttributes mnb_attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast<int64_t>(4)), mnb_attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), mnb_attrs);
+
+    std::vector<NodeArg*> mnb_inputs;
+    mnb_inputs.push_back(const_cast<NodeArg*>(mm_node->InputDefs()[0]));
+    mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst)));
+    mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst)));
+    if (zp_mnb_tp) {
+      mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst)));
+    }
+
+    bool fused_with_bias = false;
+    if (mm_node->OpType() == "Gemm" &&
+        mm_node->InputDefs().size() > 2 &&
+        mm_node->InputDefs()[2] &&
+        mm_node->InputDefs()[2]->Exists()) {
+      NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr);
+      while (mnb_inputs.size() < 5) {
+        mnb_inputs.push_back(&empty_arg);  // fill optional slots so the bias lands at input index 5
+      }
+      mnb_inputs.push_back(const_cast<NodeArg*>(mm_node->InputDefs()[2]));
+      fused_with_bias = true;
+    }
+
+    std::vector<NodeArg*> mnb_outputs;
+    mnb_outputs.push_back(const_cast<NodeArg*>(mm_node->OutputDefs()[0]));
+
+    auto& mnb_node = graph.AddNode(
+        graph.GenerateNodeName("DirectDQFusedMatMulNBits"),
+        "MatMulNBits",
+        "Fused from direct DQ(axis=0)+MatMul",
+        mnb_inputs, mnb_outputs, &mnb_attrs, kMSDomain);
+    mnb_node.SetExecutionProviderType(mm_node->GetExecutionProviderType());
+
+    graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.matmul_idx));
+    graph.RemoveNode(match.matmul_idx);
+
+    graph_utils::RemoveNodeOutputEdges(graph, *graph.GetNode(match.dq_idx));
+    graph.RemoveNode(match.dq_idx);
+
+    LOGS(logger, INFO) << "DQMatMulNBitsFusion: fused direct DQ(axis=0)+MatMul/Gemm -> MatMulNBits"
+                       << " (K=" << K << ", N=" << N << ", block_size=" << block_size << ")"
+                       << (fused_with_bias ? " (bias preserved)" : "")
+                       << (elide_zp ? " (default UINT4 zp8 elided)" : "");
+    modified = true;
+  }
+}
+
+} // namespace
+
+// ---------------------------------------------------------------------------
+// DQMatMulNBitsFusion public interface
+// ---------------------------------------------------------------------------
+
+DQMatMulNBitsFusion::DQMatMulNBitsFusion(
+    int64_t accuracy_level,
+    const InlinedHashSet<std::string_view>& compatible_eps)  // restored stripped template argument
+    : GraphTransformer("DQMatMulNBitsFusion", compatible_eps),
+      accuracy_level_(accuracy_level) {
+  ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4,
+              "MatMulNBits accuracy level must be between 0 and 4");
+}
+
+Status DQMatMulNBitsFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+                                      const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+  // Recurse into subgraphs first so nested graphs are also fused.
+  for (auto node_index : node_topology_list) {
+    auto* node = graph.GetNode(node_index);
+    if (!node) continue;
+    ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
+  }
+
+  auto matches = CollectReshapeTransposeMatches(graph, node_topology_list, logger);
+
+  std::unordered_set<NodeIndex> matched_matmul_indices;
+  for (const auto& m : matches) {
+    matched_matmul_indices.insert(m.matmul_idx);  // prevent double-matching a MatMul in both patterns
+  }
+
+  auto direct_matches = CollectDirectDQMatches(graph, node_topology_list,
+                                               matched_matmul_indices, logger);
+
+  ApplyReshapeTransposeFusions(graph, matches, accuracy_level_, modified, logger);
+  ApplyDirectDQFusions(graph, direct_matches, accuracy_level_, modified, logger);
+
+  return Status::OK();
+}
+
+} // namespace onnxruntime
+
+#endif // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h
new file mode 100644
index 0000000000000..97c0debd760c0
--- /dev/null
+++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.h
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include <cstdint>
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+// Fuses DequantizeLinear chains back into a single MatMulNBits contrib op.
+//
+// Supported patterns:
+//   Pattern 1: DQ(3D, UINT4, axis=2) -> Reshape(2D) -> Transpose([1,0])
+//              -> [optional Cast] -> MatMul/Gemm  => MatMulNBits
+//   Pattern 2: DQ(2D, UINT4, axis=0) -> MatMul/Gemm  => MatMulNBits
+//
+// These patterns are produced when a quantized model goes through external
+// toolchains that lower MatMulNBits to DQ + reshape/transpose + MatMul
+// primitives, and then re-import the graph into ORT.
+class DQMatMulNBitsFusion : public GraphTransformer {
+ public:
+  explicit DQMatMulNBitsFusion(
+      int64_t accuracy_level = 4,
+      const InlinedHashSet<std::string_view>& compatible_eps = {});
+
+ private:
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level,
+                   const logging::Logger& logger) const override;
+
+  int64_t accuracy_level_;  // forwarded to the MatMulNBits "accuracy_level" attribute (validated 0-4)
+};
+
+}  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index fdd4f5aa27862..4edabbe6058ab 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -18,6 +18,8 @@
#if !defined(ORT_MINIMAL_BUILD)
+#include "core/optimizer/dq_matmulnbits_fusion.h"
+
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/attention_fusion.h"
#include "core/optimizer/bias_dropout_fusion.h"
@@ -274,6 +276,26 @@ InlinedVector> GenerateTransformers(
transformers.emplace_back(std::make_unique());
}
+#if !defined(DISABLE_CONTRIB_OPS)
+  {
+    // Opt-in via session config; skipped entirely when QDQ handling is disabled.
+    const bool enable_dq_matmulnbits_fusion =
+        session_options.config_options.GetConfigOrDefault(
+            kOrtSessionOptionsEnableDQMatMulNBitsFusion, "0") == "1";
+    if (enable_dq_matmulnbits_fusion && !disable_quant_qdq) {
+      const int64_t qdq_matmulnbits_accuracy_level =
+          ParseStringWithClassicLocale<int64_t>(
+              session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
+                                                                "4"));
+      transformers.emplace_back(std::make_unique<DQMatMulNBitsFusion>(
+          qdq_matmulnbits_accuracy_level));
+    }
+  }
+#else
+  ORT_ENFORCE(session_options.config_options.GetConfigOrDefault(
+                  kOrtSessionOptionsEnableDQMatMulNBitsFusion, "0") != "1",
+              "DQ->MatMulNBits fusion requires contrib ops but DISABLE_CONTRIB_OPS is defined");
+#endif
+
// run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around.
// shouldn't affect the end result - just easier to debug any issue if it's last.
transformers.emplace_back(std::make_unique(std::move(cpu_allocator)));
diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc
index bb9130add215a..63d985052f996 100644
--- a/onnxruntime/core/platform/telemetry.cc
+++ b/onnxruntime/core/platform/telemetry.cc
@@ -121,7 +121,7 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu
}
void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
- std::unordered_map duration_per_batch_size) const {
+ const std::unordered_map& duration_per_batch_size) const {
ORT_UNUSED_PARAMETER(session_id);
ORT_UNUSED_PARAMETER(total_runs_since_last);
ORT_UNUSED_PARAMETER(total_run_duration_since_last);
@@ -157,4 +157,35 @@ void Telemetry::LogProviderOptions(const std::string& provider_id,
ORT_UNUSED_PARAMETER(captureState);
}
+void Telemetry::LogModelLoadStart(uint32_t session_id) const {  // base default: no-op; platform subclasses (e.g. WindowsTelemetry) override
+ ORT_UNUSED_PARAMETER(session_id);
+}
+
+void Telemetry::LogModelLoadEnd(uint32_t session_id, const common::Status& status) const {  // no-op default
+ ORT_UNUSED_PARAMETER(session_id);
+ ORT_UNUSED_PARAMETER(status);
+}
+
+void Telemetry::LogSessionCreationEnd(uint32_t session_id,
+ const common::Status& status) const {  // no-op default
+ ORT_UNUSED_PARAMETER(session_id);
+ ORT_UNUSED_PARAMETER(status);
+}
+
+void Telemetry::LogRegisterEpLibraryWithLibPath(const std::string& registration_name,
+ const std::string& lib_path) const {  // no-op default
+ ORT_UNUSED_PARAMETER(registration_name);
+ ORT_UNUSED_PARAMETER(lib_path);
+}
+
+void Telemetry::LogRegisterEpLibraryStart(const std::string& registration_name) const {  // no-op default
+ ORT_UNUSED_PARAMETER(registration_name);
+}
+
+void Telemetry::LogRegisterEpLibraryEnd(const std::string& registration_name,
+ const common::Status& status) const {  // no-op default
+ ORT_UNUSED_PARAMETER(registration_name);
+ ORT_UNUSED_PARAMETER(status);
+}
+
} // namespace onnxruntime
diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h
index e74d7ed0180fd..20a58e5a87184 100644
--- a/onnxruntime/core/platform/telemetry.h
+++ b/onnxruntime/core/platform/telemetry.h
@@ -85,7 +85,7 @@ class Telemetry {
const char* function, uint32_t line) const;
virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
- std::unordered_map duration_per_batch_size) const;
+ const std::unordered_map& duration_per_batch_size) const;
virtual void LogExecutionProviderEvent(LUID* adapterLuid) const;
@@ -101,6 +101,21 @@ class Telemetry {
const std::string& provider_options_string,
bool captureState) const;
+ virtual void LogModelLoadStart(uint32_t session_id) const;
+
+ virtual void LogModelLoadEnd(uint32_t session_id, const common::Status& status) const;
+
+ virtual void LogSessionCreationEnd(uint32_t session_id,
+ const common::Status& status) const;
+
+ virtual void LogRegisterEpLibraryWithLibPath(const std::string& registration_name,
+ const std::string& lib_path) const;
+
+ virtual void LogRegisterEpLibraryStart(const std::string& registration_name) const;
+
+ virtual void LogRegisterEpLibraryEnd(const std::string& registration_name,
+ const common::Status& status) const;
+
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry);
};
diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc
index 6d5a400be703b..30c24de6a92ed 100644
--- a/onnxruntime/core/platform/windows/telemetry.cc
+++ b/onnxruntime/core/platform/windows/telemetry.cc
@@ -457,7 +457,8 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id,
TraceLoggingInt32(graph_optimization_level, "graphOptimizationLevel"),
TraceLoggingBool(embed_ep_context, "embedEpContext"),
TraceLoggingBool(has_external_initializers_file, "hasExternalInitializersFile"),
- TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
+ TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
void WindowsTelemetry::LogCompileModelComplete(uint32_t session_id,
@@ -480,7 +481,8 @@ void WindowsTelemetry::LogCompileModelComplete(uint32_t session_id,
TraceLoggingBool(success, "success"),
TraceLoggingUInt32(error_code, "errorCode"),
TraceLoggingUInt32(error_category, "errorCategory"),
- TraceLoggingString(error_message.c_str(), "errorMessage"));
+ TraceLoggingString(error_message.c_str(), "errorMessage"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
@@ -528,7 +530,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
}
void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
- std::unordered_map duration_per_batch_size) const {
+ const std::unordered_map& duration_per_batch_size) const {
if (global_register_count_ == 0 || enabled_ == false)
return;
@@ -668,4 +670,116 @@ void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const
}
}
+void WindowsTelemetry::LogModelLoadStart(uint32_t session_id) const {
+ if (global_register_count_ == 0 || enabled_ == false)  // skip when the ETW provider is unregistered or telemetry is disabled
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "ModelLoadStart",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ // Telemetry info
+ TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt32(session_id, "sessionId"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
+}
+
+void WindowsTelemetry::LogModelLoadEnd(uint32_t session_id, const common::Status& status) const {
+ if (global_register_count_ == 0 || enabled_ == false)  // guard: provider not active
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "ModelLoadEnd",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ // Telemetry info
+ TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt32(session_id, "sessionId"),
+ TraceLoggingBool(status.IsOK(), "isSuccess"),
+ TraceLoggingUInt32(status.Code(), "errorCode"),
+ TraceLoggingUInt32(status.Category(), "errorCategory"),
+ TraceLoggingString(status.IsOK() ? "" : status.ErrorMessage().c_str(), "errorMessage"),  // only emit a message on failure
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
+}
+
+void WindowsTelemetry::LogSessionCreationEnd(uint32_t session_id,
+ const common::Status& status) const {
+ if (global_register_count_ == 0 || enabled_ == false)  // guard: provider not active
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "SessionCreationEnd",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ // Telemetry info
+ TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt32(session_id, "sessionId"),
+ TraceLoggingBool(status.IsOK(), "isSuccess"),
+ TraceLoggingUInt32(status.Code(), "errorCode"),
+ TraceLoggingUInt32(status.Category(), "errorCategory"),
+ TraceLoggingString(status.IsOK() ? "" : status.ErrorMessage().c_str(), "errorMessage"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
+}
+
+void WindowsTelemetry::LogRegisterEpLibraryWithLibPath(const std::string& registration_name,
+ const std::string& lib_path) const {
+ if (global_register_count_ == 0 || enabled_ == false)  // guard: provider not active
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "RegisterEpLibraryWithLibPath",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ // Telemetry info
+ TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingString(registration_name.c_str(), "registrationName"),
+ TraceLoggingString(lib_path.c_str(), "libPath"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
+}
+
+void WindowsTelemetry::LogRegisterEpLibraryStart(const std::string& registration_name) const {
+ if (global_register_count_ == 0 || enabled_ == false)  // guard: provider not active
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "RegisterEpLibraryStart",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ // Telemetry info
+ TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingString(registration_name.c_str(), "registrationName"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
+}
+
+void WindowsTelemetry::LogRegisterEpLibraryEnd(const std::string& registration_name,
+ const common::Status& status) const {
+ if (global_register_count_ == 0 || enabled_ == false)  // guard: provider not active
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "RegisterEpLibraryEnd",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ // Telemetry info
+ TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingString(registration_name.c_str(), "registrationName"),
+ TraceLoggingBool(status.IsOK(), "isSuccess"),
+ TraceLoggingUInt32(status.Code(), "errorCode"),
+ TraceLoggingUInt32(status.Category(), "errorCategory"),
+ TraceLoggingString(status.IsOK() ? "" : status.ErrorMessage().c_str(), "errorMessage"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
+}
+
} // namespace onnxruntime
diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h
index 591a248d70ab8..b46dcfbd3feb5 100644
--- a/onnxruntime/core/platform/windows/telemetry.h
+++ b/onnxruntime/core/platform/windows/telemetry.h
@@ -78,7 +78,7 @@ class WindowsTelemetry : public Telemetry {
const char* function, uint32_t line) const override;
void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
- std::unordered_map duration_per_batch_size) const override;
+ const std::unordered_map& duration_per_batch_size) const override;
void LogExecutionProviderEvent(LUID* adapterLuid) const override;
@@ -94,6 +94,21 @@ class WindowsTelemetry : public Telemetry {
const std::string& provider_options_string,
bool captureState) const override;
+ void LogModelLoadStart(uint32_t session_id) const override;
+
+ void LogModelLoadEnd(uint32_t session_id, const common::Status& status) const override;
+
+ void LogSessionCreationEnd(uint32_t session_id,
+ const common::Status& status) const override;
+
+ void LogRegisterEpLibraryWithLibPath(const std::string& registration_name,
+ const std::string& lib_path) const override;
+
+ void LogRegisterEpLibraryStart(const std::string& registration_name) const override;
+
+ void LogRegisterEpLibraryEnd(const std::string& registration_name,
+ const common::Status& status) const override;
+
using EtwInternalCallback = std::function;
diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc
index d8c81e5cb63e5..6ecbfaa3993ca 100644
--- a/onnxruntime/core/providers/cpu/object_detection/roialign.cc
+++ b/onnxruntime/core/providers/cpu/object_detection/roialign.cc
@@ -294,6 +294,41 @@ Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, cons
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"First dimension (num_rois) of batch_indices and rois don't match");
}
+
+ // Validate batch_indices values are within [0, batch_size).
+ // Only check when the tensor data is accessible from the host (CPU).
+ // For GPU tensors (e.g. CUDA EP), Data() returns a device pointer
+ // that cannot be safely dereferenced on the host. A device-side bounds
+ // check for the CUDA path would require passing batch_size into the
+ // CUDA kernel — tracked as a follow-up.
+ if (batch_indices_ptr->Location().device.Type() == OrtDevice::CPU) {
+ const int64_t batch_size = X_ptr->Shape()[0];
+ const int64_t num_rois = batch_indices_dims[0];
+
+ auto check_bounds = [batch_size, num_rois](const auto* batch_indices_data) -> Status {
+ for (int64_t i = 0; i < num_rois; ++i) {
+ if (batch_indices_data[i] < 0 || batch_indices_data[i] >= batch_size) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "batch_indices value " + std::to_string(batch_indices_data[i]) +
+ " at index " + std::to_string(i) +
+ " is out of range [0, " + std::to_string(batch_size) + ")");
+ }
+ }
+ return Status::OK();
+ };
+
+ if (batch_indices_ptr->IsDataType<int64_t>()) {
+ auto status = check_bounds(batch_indices_ptr->Data<int64_t>());
+ if (!status.IsOK()) return status;
+ } else if (batch_indices_ptr->IsDataType<int32_t>()) {
+ auto status = check_bounds(batch_indices_ptr->Data<int32_t>());
+ if (!status.IsOK()) return status;
+ } else {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "batch_indices must be of type int64_t or int32_t");
+ }
+ }
+
return Status::OK();
}
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
index e415143a6ddd1..c567c220b3ce7 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
@@ -306,8 +306,8 @@ class NvExecutionProvider : public IExecutionProvider {
const GraphOptimizerRegistry& graph_optimizer_registry,
IResourceAccountant* /* resource_accountant */) const override;
- int GetDeviceId() const { return device_id_; }
- Status Sync() const;
+ int GetDeviceId() const override { return device_id_; }
+ Status Sync() const override;
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc
index 90e488a1eda18..a7c37cd481894 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc
@@ -28,17 +28,32 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose);
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
+ // Domain for TRT plugin custom ops (domain name: "trt.plugins"). Owns the OrtCustomOpDomain object.
+ // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession.
static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
+
+ // Owns the TensorRTCustomOp objects for TRT plugins. Raw pointers are stored in custom_op_domain->custom_ops_.
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
+
+ // Domain for native custom ops (domain name: "trt"). Owns the OrtCustomOpDomain object.
+ // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession.
static std::unique_ptr<OrtCustomOpDomain> native_custom_op_domain = std::make_unique<OrtCustomOpDomain>();
+
+ // Owns the TensorRTCustomOp objects for native custom ops. Raw pointers are stored in native_custom_op_domain->custom_ops_.
+ // Non-empty list indicates native custom ops have been registered (used to avoid re-registration on subsequent calls).
static std::vector<std::unique_ptr<TensorRTCustomOp>> native_custom_op_list;
+
+ // Protects concurrent access to all the above static members.
static std::mutex mutex;
std::lock_guard lock(mutex);
+
+ // Add already-initialized native ops to domain list
+ if (!native_custom_op_list.empty()) {
+ domain_list.push_back(native_custom_op_domain.get());
+ }
+
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(custom_op_domain.get());
- if (native_custom_op_domain->domain_ != "" && native_custom_op_domain->custom_ops_.size() > 0) {
- domain_list.push_back(native_custom_op_domain.get());
- }
return Status::OK();
}
@@ -132,35 +147,36 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector&
}
// Register native custom ops (register these independent of TRT plugin library availability)
- const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
- int num_native_custom_ops = std::size(native_custom_ops_names);
+ if (native_custom_op_list.empty()) {
+ const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
+ size_t num_native_custom_ops = std::size(native_custom_ops_names);
+
+ for (size_t i = 0; i < num_native_custom_ops; i++) {
+ native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
+ native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
+ native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
+ }
- for (int i = 0; i < num_native_custom_ops; i++) {
- native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
- native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
- native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
+ native_custom_op_domain->domain_ = "trt";
+ domain_list.push_back(native_custom_op_domain.get());
}
- native_custom_op_domain->domain_ = "trt";
- domain_list.push_back(native_custom_op_domain.get());
return Status::OK();
}
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
- if (domain != nullptr) {
- for (auto ptr : domain->custom_ops_) {
- if (ptr != nullptr) {
- delete ptr;
- }
- }
- delete domain;
- }
+ (void)domain; // Suppress unused parameter warning
+ // The domain and its custom ops are owned by static unique_ptrs in CreateTensorRTCustomOpDomainList().
+ // Callers receive raw pointers via .get().
+ // 1. Manually deleting them would cause a double-free when the static unique_ptrs are destroyed at program exit.
+ // 2. Resetting the static unique_ptrs is also unsafe because other EP instances or InferenceSession objects
+ // may still hold raw pointers to these same objects (handed out via domain_list).
+ // The static objects would be shared across EP instances and would persist for the program lifetime.
}
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
- for (auto ptr : custom_op_domain_list) {
- ReleaseTensorRTCustomOpDomain(ptr);
- }
+ // Only clear the reference vector, don't delete the static domain objects.
+ custom_op_domain_list.clear();
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
index d1e449eb58870..388387fae4b0a 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
@@ -61,7 +61,7 @@ struct NvProviderFactory : IExecutionProviderFactory {
std::unique_ptr<IExecutionProvider> CreateProvider() override;
std::unique_ptr<IExecutionProvider> CreateProvider(const OrtSessionOptions& session_options,
- const OrtLogger& session_logger);
+ const OrtLogger& session_logger) override;
private:
NvExecutionProviderInfo info_;
@@ -109,7 +109,7 @@ struct Nv_Provider : Provider {
return std::make_shared<NvProviderFactory>(info);
}
- std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* param) {
+ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* param) override {
if (param == nullptr) {
LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Passed NULL options to CreateExecutionProviderFactory()";
return nullptr;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
index 1e9fafe8aa323..2418c8424422b 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
@@ -45,8 +45,14 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose);
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
+ // Domain for TRT plugin custom ops (domain name: "trt.plugins"). Owns the OrtCustomOpDomain object.
+ // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession.
static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
+
+ // Owns the TensorRTCustomOp objects for TRT plugins. Raw pointers are stored in custom_op_domain->custom_ops_.
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
+
+ // Protects concurrent access to all the above static members.
static std::mutex mutex;
std::lock_guard lock(mutex);
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
@@ -148,20 +154,18 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector&
}
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
- if (domain != nullptr) {
- for (auto ptr : domain->custom_ops_) {
- if (ptr != nullptr) {
- delete ptr;
- }
- }
- delete domain;
- }
+ (void)domain; // Suppress unused parameter warning
+ // The domain and its custom ops are owned by static unique_ptrs in CreateTensorRTCustomOpDomainList().
+ // Callers receive raw pointers via .get().
+ // 1. Manually deleting them would cause a double-free when the static unique_ptrs are destroyed at program exit.
+ // 2. Resetting the static unique_ptrs is also unsafe because other EP instances or InferenceSession objects
+ // may still hold raw pointers to these same objects (handed out via domain_list).
+ // The static objects are shared across EP instances and persist for the program lifetime.
}
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
- for (auto ptr : custom_op_domain_list) {
- ReleaseTensorRTCustomOpDomain(ptr);
- }
+ // Only clear the reference vector, don't delete the static domain objects.
+ custom_op_domain_list.clear();
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc
index 523ff8eaf13b8..7cd02e5413407 100644
--- a/onnxruntime/core/session/environment.cc
+++ b/onnxruntime/core/session/environment.cc
@@ -17,6 +17,7 @@
#include "core/session/allocator_adapters.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_env_config_keys.h"
+#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"
#include "core/session/plugin_ep/ep_factory_internal.h"
#include "core/session/plugin_ep/ep_library_internal.h"
#include "core/session/plugin_ep/ep_library_plugin.h"
@@ -539,8 +540,13 @@ bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) {
Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name,
std::unique_ptr<EpLibrary> ep_library,
const std::vector<EpFactoryInternal*>& internal_factories) {
+ const Env& env = Env::Default();
+ env.GetTelemetryProvider().LogRegisterEpLibraryStart(registration_name);
+
if (ep_libraries_.count(registration_name) > 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "library is already registered under ", registration_name);
+ auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "library is already registered under ", registration_name);
+ env.GetTelemetryProvider().LogRegisterEpLibraryEnd(registration_name, status);
+ return status;
}
auto status = Status::OK();
@@ -592,6 +598,7 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra
});
}
+ env.GetTelemetryProvider().LogRegisterEpLibraryEnd(registration_name, status);
return status;
}
@@ -611,6 +618,9 @@ Status Environment::CreateAndRegisterInternalEps() {
Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path) {
std::lock_guard lock{mutex_};
+ std::string lib_file_name = std::filesystem::path(lib_path).filename().string();
+ Env::Default().GetTelemetryProvider().LogRegisterEpLibraryWithLibPath(registration_name, lib_file_name);
+
std::vector<EpFactoryInternal*> internal_factories = {};
std::unique_ptr<EpLibrary> ep_library;
@@ -896,8 +906,15 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u
factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(),
ep_devices.data(), ep_devices.size(), &num_ep_devices)));
+ const auto* library_path = instance.library->LibraryPath();
for (size_t i = 0; i < num_ep_devices; ++i) {
- if (ep_devices[i] != nullptr) { // should never happen but just in case...
+ if (ep_devices[i] != nullptr) { // should never happen but just in case...
+ if (library_path != nullptr) {
+ // Add library path to EP metadata if available.
+ // This is used by GenAI for custom library loading so we want to consistently set it.
+ ep_devices[i]->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path->string());
+ }
+
instance.execution_devices.emplace_back(ep_devices[i]); // take ownership
}
}
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 0c9b3c0663b5c..d24020424935a 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -37,6 +37,7 @@
#include "core/framework/plugin_ep_stream.h"
#include "core/framework/transform_layout_functions.h"
#include "core/framework/utils.h"
+#include "core/graph/constants.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/graph/model_editor_api_types.h"
@@ -730,6 +731,25 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const
#endif // !defined(ORT_MINIMAL_BUILD)
InferenceSession::~InferenceSession() {
+ // Flush any remaining RuntimePerf counters
+ ORT_TRY {
+ std::lock_guard telemetry_lock(telemetry_mutex_);
+ if (telemetry_.total_runs_since_last_ > 0) {
+ Env::Default().GetTelemetryProvider().LogRuntimePerf(session_id_,
+ telemetry_.total_runs_since_last_,
+ telemetry_.total_run_duration_since_last_,
+ telemetry_.duration_per_batch_size_);
+ }
+ }
+ ORT_CATCH(const std::exception& e) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ LOGS(*session_logger_, ERROR) << "Error during telemetry flush: " << e.what();
+ });
+ }
+ ORT_CATCH(...) {
+ LOGS(*session_logger_, ERROR) << "Unknown error during telemetry flush";
+ }
+
if (session_options_.enable_profiling) {
ORT_TRY {
EndProfiling();
@@ -969,7 +989,10 @@ common::Status InferenceSession::LoadWithLoader(std::function l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
@@ -1005,6 +1028,8 @@ common::Status InferenceSession::LoadWithLoader(std::function load_ort_format_model_bytes) {
+ const Env& env = Env::Default();
+ env.GetTelemetryProvider().LogModelLoadStart(session_id_);
+
std::lock_guard l(session_mutex_);
if (is_model_loaded_) { // already loaded
@@ -1761,6 +1789,8 @@ Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort
is_model_loaded_ = true;
+ env.GetTelemetryProvider().LogModelLoadEnd(session_id_, Status::OK());
+
return Status::OK();
}
@@ -2263,6 +2293,16 @@ common::Status InferenceSession::Initialize() {
return Status::OK();
};
+ // Enable DQ->MatMulNBits fusion if NvTensorRTRTX EP is registered.
+ if (execution_providers_.Get(onnxruntime::kNvTensorRTRTXExecutionProvider) != nullptr) {
+ if (session_options_.config_options.GetConfigOrDefault(
+ kOrtSessionOptionsEnableDQMatMulNBitsFusion, "") == "") {
+ ORT_RETURN_IF_ERROR_SESSIONID_(
+ session_options_.config_options.AddConfigEntry(
+ kOrtSessionOptionsEnableDQMatMulNBitsFusion, "1"));
+ }
+ }
+
// add predefined transformers
ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_,
session_options_.graph_optimization_level,
@@ -2603,6 +2643,12 @@ common::Status InferenceSession::Initialize() {
}
}
+ // Log session creation end telemetry
+ {
+ const Env& init_env = Env::Default();
+ init_env.GetTelemetryProvider().LogSessionCreationEnd(session_id_, status);
+ }
+
return status;
}
#if defined(_MSC_VER) && !defined(__clang__)
@@ -3163,24 +3209,31 @@ Status InferenceSession::Run(const RunOptions& run_options,
break;
}
- // time to send telemetry?
- {
- // Adding lock_guard here to ensure that telemetry updates are thread-safe.
- std::lock_guard telemetry_lock(telemetry_mutex_);
- ++telemetry_.total_runs_since_last_;
- telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp);
- telemetry_.duration_per_batch_size_[batch_size] += TimeDiffMicroSeconds(tp);
-
- if (TimeDiffMicroSeconds(telemetry_.time_sent_last_) > Telemetry::kDurationBetweenSending) {
- // send the telemetry
- env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_,
- telemetry_.total_run_duration_since_last_,
- telemetry_.duration_per_batch_size_);
- // reset counters
- telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now();
- telemetry_.total_runs_since_last_ = 0;
- telemetry_.total_run_duration_since_last_ = 0;
- telemetry_.duration_per_batch_size_.clear();
+ // Only include successful inferences in batch since failed inferences can skew the metric
+ if (retval.IsOK()) {
+ // time to send telemetry?
+ {
+ // Adding lock_guard here to ensure that telemetry updates are thread-safe.
+ std::lock_guard telemetry_lock(telemetry_mutex_);
+ ++telemetry_.total_runs_since_last_;
+ telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp);
+ telemetry_.duration_per_batch_size_[batch_size] += TimeDiffMicroSeconds(tp);
+
+ // Emit RuntimePerf on scheduled interval
+ if ((TimeDiffMicroSeconds(telemetry_.time_sent_last_) > telemetry_.runtime_perf_interval_)) {
+ env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_,
+ telemetry_.total_run_duration_since_last_,
+ telemetry_.duration_per_batch_size_);
+ // reset counters
+ telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now();
+ telemetry_.total_runs_since_last_ = 0;
+ telemetry_.total_run_duration_since_last_ = 0;
+ telemetry_.duration_per_batch_size_.clear();
+
+ // Double the interval, capping at kRuntimePerfMaxInterval
+ telemetry_.runtime_perf_interval_ = std::min(telemetry_.runtime_perf_interval_ * 2,
+ Telemetry::kRuntimePerfMaxInterval);
+ }
}
}
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 1dbf0318c988c..e51dc773f9761 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -976,8 +976,10 @@ class InferenceSession {
std::unordered_map<int64_t, int64_t> duration_per_batch_size_; // the duration (us) of Run() calls per batch size since the last report
TimePoint time_sent_last_; // the TimePoint of the last report
- // Event Rate per provider < 20 peak events per second
- constexpr static long long kDurationBetweenSending = 1000 * 1000 * 60 * 10; // duration in (us). send a report every 10 mins
+ // RuntimePerf backoff interval: starts at 2s between emissions, doubles each emission, caps at 10 min
+ constexpr static int64_t kRuntimePerfInitialInterval = 2 * 1000 * 1000; // 2 seconds in (us)
+ constexpr static int64_t kRuntimePerfMaxInterval = 1000 * 1000 * 60 * 10; // 10 minutes in (us)
+ int64_t runtime_perf_interval_ = kRuntimePerfInitialInterval;
} telemetry_;
mutable std::mutex telemetry_mutex_; // to ensure thread-safe access to telemetry data
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc
index 42b65239de92c..0e2c4b4217702 100644
--- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc
+++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc
@@ -22,11 +22,6 @@ OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_fa
auto* ep_device = ep_devices[i];
if (ep_device) {
ep_device->ep_factory = &ep_factory;
-
- // Add library path to EP metadata if available
- if (library_path_.has_value()) {
- ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path_->string());
- }
}
}
diff --git a/onnxruntime/core/session/plugin_ep/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h
index af5bc23143e33..fed9eb072c704 100644
--- a/onnxruntime/core/session/plugin_ep/ep_library.h
+++ b/onnxruntime/core/session/plugin_ep/ep_library.h
@@ -20,6 +20,7 @@ class EpLibrary {
EpLibrary() = default;
virtual const char* RegistrationName() const = 0;
+ virtual const std::filesystem::path* LibraryPath() const { return nullptr; }
virtual Status Load() { return Status::OK(); }
virtual const std::vector<OrtEpFactory*>& GetFactories() = 0; // valid after Load()
virtual Status Unload() { return Status::OK(); }
diff --git a/onnxruntime/core/session/plugin_ep/ep_library_plugin.h b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h
index e044e91b61e37..cce52b3d5d282 100644
--- a/onnxruntime/core/session/plugin_ep/ep_library_plugin.h
+++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h
@@ -25,6 +25,10 @@ class EpLibraryPlugin : public EpLibrary {
return registration_name_.c_str();
}
+ const std::filesystem::path* LibraryPath() const override {
+ return &library_path_;
+ }
+
Status Load() override;
const std::vector<OrtEpFactory*>& GetFactories() override {
diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h
index 45277b2828f56..f3147794bc823 100644
--- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h
+++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h
@@ -22,7 +22,7 @@ class EpLibraryProviderBridge : public EpLibrary {
public:
EpLibraryProviderBridge(std::unique_ptr<ProviderLibrary> provider_library,
std::unique_ptr<EpLibraryPlugin> ep_library_plugin,
- std::optional<std::filesystem::path> library_path = std::nullopt)
+ std::filesystem::path library_path)
: provider_library_{std::move(provider_library)},
ep_library_plugin_{std::move(ep_library_plugin)},
library_path_{std::move(library_path)} {
@@ -32,6 +32,10 @@ class EpLibraryProviderBridge : public EpLibrary {
return ep_library_plugin_->RegistrationName();
}
+ const std::filesystem::path* LibraryPath() const override {
+ return &library_path_;
+ }
+
const std::vector<OrtEpFactory*>& GetFactories() override {
return factory_ptrs_;
}
@@ -56,7 +60,7 @@ class EpLibraryProviderBridge : public EpLibrary {
std::unique_ptr<EpLibraryPlugin> ep_library_plugin_;
// Library path for EP metadata
- std::optional<std::filesystem::path> library_path_;
+ std::filesystem::path library_path_;
std::vector<std::unique_ptr<EpFactoryInternal>> factories_;
std::vector<OrtEpFactory*> factory_ptrs_; // for convenience
diff --git a/onnxruntime/lora/adapter_format_utils.cc b/onnxruntime/lora/adapter_format_utils.cc
index 2d061b6066a8a..6e2d204b04cf2 100644
--- a/onnxruntime/lora/adapter_format_utils.cc
+++ b/onnxruntime/lora/adapter_format_utils.cc
@@ -7,8 +7,9 @@
#include "core/framework/allocator.h"
#include "core/common/common.h"
#include "core/common/endian.h"
-#include "core/framework/endian_utils.h"
+#include "core/common/safeint.h"
#include "core/common/span_utils.h"
+#include "core/framework/endian_utils.h"
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "core/framework/ort_value.h"
@@ -149,13 +150,36 @@ struct ReadDataForBigEndian {
std::pair<std::string, OrtValue> CreateOrtValueOverLoraParameter(const Parameter& param) {
OrtValue result;
+ const auto* param_name = param.name();
+ ORT_ENFORCE(param_name != nullptr, "Lora Parameter: name is missing");
+
std::string name;
- LoadStringFromLoraFormat(name, param.name());
+ LoadStringFromLoraFormat(name, param_name);
const auto data_type = param.data_type();
+ ORT_ENFORCE(data_type != TensorDataType::UNDEFINED,
+ "Lora Param '", name, "': data_type is UNDEFINED");
+
+ const auto* dims = param.dims();
+ ORT_ENFORCE(dims != nullptr && dims->size() > 0,
+ "Lora Param '", name, "': dims is missing or empty");
+
+ const auto* raw_data = param.raw_data();
+ ORT_ENFORCE(raw_data != nullptr,
+ "Lora Param '", name, "': raw_data is missing");
+
// Copying shape takes care of endianess using flatbuffers accessors
- TensorShapeVector shape(param.dims()->begin(), param.dims()->end());
+ TensorShapeVector shape(dims->begin(), dims->end());
+ TensorShape tensor_shape(shape);
const auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int32_t>(data_type))->GetElementType();
+ const size_t expected_raw_data_size = SafeInt<size_t>(tensor_shape.Size()) * elem_type->Size();
+ if (raw_data->size() != expected_raw_data_size) {
+ ORT_THROW("Lora Param '", name,
+ "': raw_data size (", raw_data->size(),
+ ") does not match expected size (", expected_raw_data_size,
+ ") calculated from tensor shape and element type");
+ }
+
static const OrtMemoryInfo cpu_meminfo(CPU, OrtAllocatorType::OrtDeviceAllocator);
if constexpr (endian::native == endian::big) {
@@ -166,16 +190,16 @@ std::pair CreateOrtValueOverLoraParameter(const Parameter
// of raw data
// const_cast is necessary due to Tensor class API
Tensor::InitOrtValue(elem_type,
- TensorShape(shape),
- const_cast<uint8_t*>(param.raw_data()->data()),
+ tensor_shape,
+ const_cast<uint8_t*>(raw_data->data()),
cpu_meminfo,
result);
}
} else {
// const_cast is necessary due to Tensor class API
Tensor::InitOrtValue(elem_type,
- TensorShape(shape),
- const_cast<uint8_t*>(param.raw_data()->data()),
+ tensor_shape,
+ const_cast<uint8_t*>(raw_data->data()),
cpu_meminfo,
result);
}
diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc
index 7b6679ffaf462..79bc34572a6f7 100644
--- a/onnxruntime/test/autoep/test_registration.cc
+++ b/onnxruntime/test/autoep/test_registration.cc
@@ -74,6 +74,15 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) {
auto options = test_ep_device->EpOptions();
ASSERT_STREQ(options.GetValue("run_really_fast"), "true");
+ // Verify the library path is present in the EP metadata
+ const char* metadata_library_path = metadata.GetValue(kOrtEpDevice_EpMetadataKey_LibraryPath);
+ ASSERT_NE(metadata_library_path, nullptr) << "Expected library_path to be present in EP metadata.";
+
+ // Verify the library path matches the registered path
+ std::filesystem::path metadata_path{metadata_library_path};
+ ASSERT_EQ(std::filesystem::canonical(metadata_path), std::filesystem::canonical(library_path))
+ << "Expected library_path in EP metadata to match the registered library path.";
+
// the CPU device info will vary by machine so check for the lowest common denominator values
Ort::ConstHardwareDevice device = test_ep_device->Device();
ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU);
diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc
index 0c55cf45abcdf..791250a6e1364 100644
--- a/onnxruntime/test/lora/lora_test.cc
+++ b/onnxruntime/test/lora/lora_test.cc
@@ -173,7 +173,6 @@ struct TestDataType {
verify_load(lora_adapter);
}
};
-
} // namespace
TEST(LoraAdapterTest, Load) {
@@ -199,6 +198,226 @@ TEST(LoraAdapterTest, Load) {
}
}
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ValidParam) {
+ // Build a valid adapter with a single float parameter, then call
+ // CreateOrtValueOverLoraParameter on the deserialized Parameter.
+ constexpr std::array<int64_t, 2> shape = {8, 4};
+ InlinedVector<float> data(32);
+ std::iota(data.begin(), data.end(), 0.f);
+
+ adapters::utils::AdapterFormatBuilder adapter_builder;
+ adapter_builder.AddParameter("valid_param", adapters::TensorDataType::FLOAT,
+ shape, ReinterpretAsSpan<const uint8_t>(gsl::make_span(data)));
+
+ auto buffer = adapter_builder.Finish(kAdapterVersion, kModelVersion);
+
+ const auto* adapter = adapters::utils::ValidateAndGetAdapterFromBytes(buffer);
+ ASSERT_NE(adapter, nullptr);
+ ASSERT_NE(adapter->parameters(), nullptr);
+ ASSERT_EQ(adapter->parameters()->size(), 1u);
+
+ const auto* param = adapter->parameters()->Get(0);
+ auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param);
+
+ ASSERT_EQ(name, "valid_param");
+ ASSERT_TRUE(ort_value.IsTensor());
+
+ const auto& tensor = ort_value.Get<Tensor>();
+ ASSERT_EQ(tensor.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
+
+ auto dims = tensor.Shape().GetDims();
+ ASSERT_EQ(dims.size(), 2u);
+ ASSERT_EQ(dims[0], 8);
+ ASSERT_EQ(dims[1], 4);
+
+ auto result_span = tensor.DataAsSpan<float>();
+ ASSERT_EQ(result_span.size(), 32u);
+ for (size_t i = 0; i < result_span.size(); ++i) {
+ ASSERT_EQ(static_cast<float>(i), result_span[i]);
+ }
+}
+
+#ifndef ORT_NO_EXCEPTIONS
+
+namespace {
+// Helper that wraps a single Parameter offset into a finished Adapter flatbuffer
+// and returns a pointer to the deserialized Parameter.
+// The FlatBufferBuilder must outlive the returned pointer.
+const adapters::Parameter* BuildAdapterAndGetParam(flatbuffers::FlatBufferBuilder& fbb,
+ flatbuffers::Offset<adapters::Parameter> param_offset) {
+ auto params_offset = fbb.CreateVector(&param_offset, 1);
+ auto adapter_offset = adapters::CreateAdapter(
+ fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset);
+ adapters::FinishAdapterBuffer(fbb, adapter_offset);
+
+ const auto* adapter = adapters::GetAdapter(fbb.GetBufferPointer());
+ return adapter->parameters()->Get(0);
+}
+} // namespace
+
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_RawDataSizeMismatch) {
+ // Craft a flatbuffer Parameter where raw_data has fewer bytes than
+ // shape (8 x 4) * sizeof(float) = 128 bytes.
+ // We supply only 64 bytes (half the expected amount) so the validation
+ // inside CreateOrtValueOverLoraParameter must throw.
+ flatbuffers::FlatBufferBuilder fbb;
+
+ auto name_offset = fbb.CreateString("bad_param");
+ std::vector<int64_t> dims = {8, 4};
+ auto dims_offset = fbb.CreateVector(dims);
+
+ // 8 * 4 floats = 32 elements = 128 bytes expected.
+ // Provide only 64 bytes (16 floats worth) to trigger the mismatch.
+ std::vector<uint8_t> short_data(64, 0);
+ fbb.ForceVectorAlignment(short_data.size(), sizeof(uint8_t), 8);
+ auto data_offset = fbb.CreateVector(short_data);
+
+ auto param_offset = adapters::CreateParameter(
+ fbb, name_offset, dims_offset, adapters::TensorDataType::FLOAT, data_offset);
+
+ // Wrap the single parameter inside an Adapter so the buffer is valid flatbuffers.
+ auto params_offset = fbb.CreateVector(&param_offset, 1);
+ auto adapter_offset = adapters::CreateAdapter(
+ fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset);
+ adapters::FinishAdapterBuffer(fbb, adapter_offset);
+
+ auto* buf = fbb.GetBufferPointer();
+
+ // Retrieve the Parameter from the Adapter
+ const auto* adapter = adapters::GetAdapter(buf);
+ ASSERT_NE(adapter, nullptr);
+ ASSERT_NE(adapter->parameters(), nullptr);
+ ASSERT_EQ(adapter->parameters()->size(), 1u);
+
+ const auto* param = adapter->parameters()->Get(0);
+ ASSERT_NE(param, nullptr);
+
+ // The raw_data is 64 bytes but shape says 8x4 floats = 128 bytes.
+ // CreateOrtValueOverLoraParameter must throw.
+ ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
+}
+
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ExcessRawData) {
+ // Craft a flatbuffer Parameter where raw_data has MORE bytes than expected.
+ // Shape (2, 2) with float => 4 elements => 16 bytes expected, but we supply 32.
+ flatbuffers::FlatBufferBuilder fbb;
+
+ auto name_offset = fbb.CreateString("excess_param");
+ std::vector<int64_t> dims = {2, 2};
+ auto dims_offset = fbb.CreateVector(dims);
+
+ // 2 * 2 floats = 4 elements = 16 bytes expected. Supply 32.
+ std::vector<uint8_t> excess_data(32, 0);
+ fbb.ForceVectorAlignment(excess_data.size(), sizeof(uint8_t), 8);
+ auto data_offset = fbb.CreateVector(excess_data);
+
+ auto param_offset = adapters::CreateParameter(
+ fbb, name_offset, dims_offset, adapters::TensorDataType::FLOAT, data_offset);
+
+ auto params_offset = fbb.CreateVector(&param_offset, 1);
+ auto adapter_offset = adapters::CreateAdapter(
+ fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset);
+ adapters::FinishAdapterBuffer(fbb, adapter_offset);
+
+ const auto* adapter = adapters::GetAdapter(fbb.GetBufferPointer());
+ ASSERT_NE(adapter, nullptr);
+
+ const auto* param = adapter->parameters()->Get(0);
+ ASSERT_NE(param, nullptr);
+
+ // Excess data should also trigger the mismatch throw.
+ ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
+}
+
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_MissingName) {
+ // Parameter with null name should throw gracefully.
+ flatbuffers::FlatBufferBuilder fbb;
+
+ std::vector<int64_t> dims = {2, 2};
+ std::vector<uint8_t> raw_data(16, 0); // 2*2 floats = 16 bytes
+
+ // name is nullptr, all other fields are valid
+ auto param_offset = adapters::CreateParameterDirect(
+ fbb, /*name=*/nullptr, &dims, adapters::TensorDataType::FLOAT, &raw_data);
+
+ const auto* param = BuildAdapterAndGetParam(fbb, param_offset);
+ ASSERT_NE(param, nullptr);
+ ASSERT_EQ(param->name(), nullptr);
+
+ ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
+}
+
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_MissingDims) {
+ // Parameter with null dims should throw gracefully.
+ flatbuffers::FlatBufferBuilder fbb;
+
+ std::vector<uint8_t> raw_data(16, 0);
+
+ // dims is nullptr
+ auto param_offset = adapters::CreateParameterDirect(
+ fbb, "no_dims_param", /*dims=*/nullptr, adapters::TensorDataType::FLOAT, &raw_data);
+
+ const auto* param = BuildAdapterAndGetParam(fbb, param_offset);
+ ASSERT_NE(param, nullptr);
+ ASSERT_EQ(param->dims(), nullptr);
+
+ ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
+}
+
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_EmptyDims) {
+ // Parameter with an empty dims vector should throw gracefully.
+ flatbuffers::FlatBufferBuilder fbb;
+
+ std::vector<int64_t> empty_dims;
+ std::vector<uint8_t> raw_data(16, 0);
+
+ auto param_offset = adapters::CreateParameterDirect(
+ fbb, "empty_dims_param", &empty_dims, adapters::TensorDataType::FLOAT, &raw_data);
+
+ const auto* param = BuildAdapterAndGetParam(fbb, param_offset);
+ ASSERT_NE(param, nullptr);
+ ASSERT_EQ(param->dims()->size(), 0u);
+
+ ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
+}
+
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_MissingRawData) {
+ // Parameter with null raw_data should throw gracefully.
+ flatbuffers::FlatBufferBuilder fbb;
+
+ std::vector<int64_t> dims = {2, 2};
+
+ // raw_data is nullptr
+ auto param_offset = adapters::CreateParameterDirect(
+ fbb, "no_data_param", &dims, adapters::TensorDataType::FLOAT, /*raw_data=*/nullptr);
+
+ const auto* param = BuildAdapterAndGetParam(fbb, param_offset);
+ ASSERT_NE(param, nullptr);
+ ASSERT_EQ(param->raw_data(), nullptr);
+
+ ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
+}
+
+TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_UndefinedDataType) {
+ // Parameter with UNDEFINED data_type should throw gracefully.
+ flatbuffers::FlatBufferBuilder fbb;
+
+ std::vector<int64_t> dims = {2, 2};
+ std::vector<uint8_t> raw_data(16, 0);
+
+ // data_type defaults to UNDEFINED when not set
+ auto param_offset = adapters::CreateParameterDirect(
+ fbb, "undef_type_param", &dims, adapters::TensorDataType::UNDEFINED, &raw_data);
+
+ const auto* param = BuildAdapterAndGetParam(fbb, param_offset);
+ ASSERT_NE(param, nullptr);
+ ASSERT_EQ(param->data_type(), adapters::TensorDataType::UNDEFINED);
+
+ ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
+}
+
+#endif // ORT_NO_EXCEPTIONS
+
#ifdef USE_CUDA
TEST(LoraAdapterTest, VerifyDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
diff --git a/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc
new file mode 100644
index 0000000000000..8aa4c88052742
--- /dev/null
+++ b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc
@@ -0,0 +1,595 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Unit tests for the DQMatMulNBitsFusion graph transformer.
+// Tests Pattern 1: DQ(3D,axis=2)->Reshape->Transpose([1,0])->[Cast]->MatMul/Gemm -> MatMulNBits
+// Tests Pattern 2: DQ(2D,axis=0)->MatMul/Gemm -> MatMulNBits
+
+#include "core/common/span_utils.h"
+#include "core/framework/int4.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/optimizer/dq_matmulnbits_fusion.h"
+
+#include "test/test_environment.h"
+#include "test/unittest_util/framework_test_utils.h"
+#include "test/unittest_util/graph_transform_test_builder.h"
+#include "test/optimizer/graph_transform_test_fixture.h"
+#include "test/util/include/asserts.h"
+
+#include "gtest/gtest.h"
+
+#if !defined(DISABLE_CONTRIB_OPS)
+
+namespace onnxruntime {
+namespace test {
+
+static std::vector<UInt4x2> MakePackedUint4(const std::vector<uint8_t>& values) {
+ const size_t num_pairs = UInt4x2::CalcNumInt4Pairs(values.size());
+ std::vector<UInt4x2> packed(num_pairs);
+ for (size_t i = 0; i < values.size(); i += 2) {
+ uint8_t lo = values[i] & 0x0F;
+ uint8_t hi = (i + 1 < values.size()) ? (values[i + 1] & 0x0F) : 0;
+ packed[i / 2] = UInt4x2(lo, hi);
+ }
+ return packed;
+}
+
+static void BuildPattern1Graph(ModelTestBuilder& builder,
+ int64_t M, int64_t N, int64_t K,
+ int64_t block_size,
+ bool with_zp,
+ bool with_cast,
+ bool use_gemm,
+ const std::vector<uint8_t>* weight_values = nullptr,
+ const std::vector<float>* scale_values = nullptr,
+ const std::vector<uint8_t>* zp_values = nullptr) {
+ const int64_t num_blocks = K / block_size;
+
+ auto* input_a = builder.MakeInput<float>({M, K}, -1.0f, 1.0f);
+ auto* output = builder.MakeOutput();
+
+ const int64_t weight_elems = N * num_blocks * block_size;
+ std::vector<uint8_t> w_vals;
+ if (weight_values) {
+ w_vals = *weight_values;
+ } else {
+ w_vals.resize(static_cast<size_t>(weight_elems));
+ for (size_t i = 0; i < w_vals.size(); ++i) {
+ w_vals[i] = static_cast<uint8_t>(i % 16);
+ }
+ }
+ auto w_packed = MakePackedUint4(w_vals);
+ auto* weight_arg = builder.MakeInitializer<UInt4x2>(
+ {N, num_blocks, block_size}, w_packed);
+
+ std::vector<float> s_vals;
+ if (scale_values) {
+ s_vals = *scale_values;
+ } else {
+ s_vals.resize(static_cast<size_t>(N * num_blocks));
+ for (size_t i = 0; i < s_vals.size(); ++i) {
+ s_vals[i] = 0.1f + 0.01f * static_cast<float>(i % 10);
+ }
+ }
+ auto* scale_arg = builder.MakeInitializer<float>({N, num_blocks, 1}, s_vals);
+
+ NodeAttributes dq_attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast<int64_t>(2)), dq_attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs);
+
+ auto* dq_output = builder.MakeIntermediate();
+ if (with_zp) {
+ std::vector<uint8_t> z_vals;
+ if (zp_values) {
+ z_vals = *zp_values;
+ } else {
+ z_vals.resize(static_cast<size_t>(N * num_blocks));
+ for (size_t i = 0; i < z_vals.size(); ++i) {
+ z_vals[i] = 8;
+ }
+ }
+ auto zp_packed = MakePackedUint4(z_vals);
+ auto* zp_arg = builder.MakeInitializer<UInt4x2>({N, num_blocks, 1}, zp_packed);
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs);
+ } else {
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs);
+ }
+
+ auto* reshape_shape = builder.MakeInitializer<int64_t>({2}, {N, K});
+ auto* reshape_output = builder.MakeIntermediate();
+ builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output});
+
+ NodeAttributes tp_attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector<int64_t>{1, 0}), tp_attrs);
+ auto* transpose_output = builder.MakeIntermediate();
+ builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs);
+
+ NodeArg* matmul_b = transpose_output;
+
+ if (with_cast) {
+ auto* cast_output = builder.MakeIntermediate();
+ NodeAttributes cast_attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("to", static_cast<int64_t>(1)), cast_attrs);
+ builder.AddNode("Cast", {transpose_output}, {cast_output}, "", &cast_attrs);
+ matmul_b = cast_output;
+ }
+
+ if (use_gemm) {
+ builder.AddNode("Gemm", {input_a, matmul_b}, {output});
+ } else {
+ builder.AddNode("MatMul", {input_a, matmul_b}, {output});
+ }
+}
+
+static void BuildPattern1GemmBiasGraph(ModelTestBuilder& builder,
+ int64_t M, int64_t N, int64_t K,
+ int64_t block_size,
+ bool with_zp) {
+ const int64_t num_blocks = K / block_size;
+
+ auto* input_a = builder.MakeInput<float>({M, K}, -1.0f, 1.0f);
+ auto* output = builder.MakeOutput();
+
+ const int64_t weight_elems = N * num_blocks * block_size;
+ std::vector<uint8_t> w_vals(static_cast<size_t>(weight_elems));
+ for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast<uint8_t>(i % 16);
+ auto w_packed = MakePackedUint4(w_vals);
+ auto* weight_arg = builder.MakeInitializer<UInt4x2>({N, num_blocks, block_size}, w_packed);
+
+ std::vector<float> s_vals(static_cast<size_t>(N * num_blocks));
+ for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f;
+ auto* scale_arg = builder.MakeInitializer<float>({N, num_blocks, 1}, s_vals);
+
+ NodeAttributes dq_attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast<int64_t>(2)), dq_attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs);
+ auto* dq_output = builder.MakeIntermediate();
+
+ if (with_zp) {
+ std::vector<uint8_t> z_vals(static_cast<size_t>(N * num_blocks), 8);
+ auto zp_packed = MakePackedUint4(z_vals);
+ auto* zp_arg = builder.MakeInitializer<UInt4x2>({N, num_blocks, 1}, zp_packed);
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs);
+ } else {
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs);
+ }
+
+ auto* reshape_shape = builder.MakeInitializer<int64_t>({2}, {N, K});
+ auto* reshape_output = builder.MakeIntermediate();
+ builder.AddNode("Reshape", {dq_output, reshape_shape}, {reshape_output});
+
+ NodeAttributes tp_attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("perm", std::vector<int64_t>{1, 0}), tp_attrs);
+ auto* transpose_output = builder.MakeIntermediate();
+ builder.AddNode("Transpose", {reshape_output}, {transpose_output}, "", &tp_attrs);
+
+ auto* bias_arg = builder.MakeInitializer<float>({N}, std::vector<float>(static_cast<size_t>(N), 0.5f));
+ builder.AddNode("Gemm", {input_a, transpose_output, bias_arg}, {output});
+}
+
+static void BuildPattern2Graph(ModelTestBuilder& builder,
+ int64_t M, int64_t N, int64_t K,
+ int64_t block_size,
+ bool with_zp,
+ bool use_gemm) {
+ const int64_t k_blocks = K / block_size;
+
+ auto* input_a = builder.MakeInput<float>({M, K}, -1.0f, 1.0f);
+ auto* output = builder.MakeOutput();
+
+ std::vector<uint8_t> w_vals(static_cast<size_t>(K * N));
+ for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast<uint8_t>(i % 16);
+ auto w_packed = MakePackedUint4(w_vals);
+ auto* weight_arg = builder.MakeInitializer<UInt4x2>({K, N}, w_packed);
+
+ std::vector<float> s_vals(static_cast<size_t>(k_blocks * N));
+ for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f + 0.01f * static_cast<float>(i % 10);
+ auto* scale_arg = builder.MakeInitializer<float>({k_blocks, N}, s_vals);
+
+ NodeAttributes dq_attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast<int64_t>(0)), dq_attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs);
+ auto* dq_output = builder.MakeIntermediate();
+
+ if (with_zp) {
+ std::vector<uint8_t> z_vals(static_cast<size_t>(k_blocks * N), 8);
+ auto zp_packed = MakePackedUint4(z_vals);
+ auto* zp_arg = builder.MakeInitializer<UInt4x2>({k_blocks, N}, zp_packed);
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs);
+ } else {
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs);
+ }
+
+ if (use_gemm) {
+ builder.AddNode("Gemm", {input_a, dq_output}, {output});
+ } else {
+ builder.AddNode("MatMul", {input_a, dq_output}, {output});
+ }
+}
+
+static void BuildPattern2GemmBiasGraph(ModelTestBuilder& builder,
+ int64_t M, int64_t N, int64_t K,
+ int64_t block_size,
+ bool with_zp) {
+ const int64_t k_blocks = K / block_size;
+
+ auto* input_a = builder.MakeInput<float>({M, K}, -1.0f, 1.0f);
+ auto* output = builder.MakeOutput();
+
+ std::vector<uint8_t> w_vals(static_cast<size_t>(K * N));
+ for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast<uint8_t>(i % 16);
+ auto w_packed = MakePackedUint4(w_vals);
+ auto* weight_arg = builder.MakeInitializer<UInt4x2>({K, N}, w_packed);
+
+ std::vector<float> s_vals(static_cast<size_t>(k_blocks * N));
+ for (size_t i = 0; i < s_vals.size(); ++i) s_vals[i] = 0.1f;
+ auto* scale_arg = builder.MakeInitializer<float>({k_blocks, N}, s_vals);
+
+ NodeAttributes dq_attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast<int64_t>(0)), dq_attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs);
+ auto* dq_output = builder.MakeIntermediate();
+
+ if (with_zp) {
+ std::vector<uint8_t> z_vals(static_cast<size_t>(k_blocks * N), 8);
+ auto zp_packed = MakePackedUint4(z_vals);
+ auto* zp_arg = builder.MakeInitializer<UInt4x2>({k_blocks, N}, zp_packed);
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs);
+ } else {
+ builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &dq_attrs);
+ }
+
+ auto* bias_arg = builder.MakeInitializer<float>({N}, std::vector<float>(static_cast<size_t>(N), 0.5f));
+ builder.AddNode("Gemm", {input_a, dq_output, bias_arg}, {output});
+}
+
+static void BuildPattern1WrongAxis(ModelTestBuilder& builder,
+ int64_t M, int64_t N, int64_t K,
+ int64_t block_size) {
+ const int64_t num_blocks = K / block_size;
+ auto* input_a = builder.MakeInput<float>({M, K}, -1.0f, 1.0f);
+ auto* output = builder.MakeOutput();
+
+ std::vector<uint8_t> w_vals(static_cast<size_t>(N * num_blocks * block_size));
+ for (size_t i = 0; i < w_vals.size(); ++i) w_vals[i] = static_cast<uint8_t>(i % 16);
+ auto w_packed = MakePackedUint4(w_vals);
+ auto* weight_arg = builder.MakeInitializer