diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index ef9e1b0cad490..12fbedd068866 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -763,6 +763,17 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod
     return false;
   }
 
+  // Reject fusion if the weight or scale initializer is shared (consumed by more than one node).
+  // When two DQ nodes reference the same initializer (tied embedding pattern, issue #28306),
+  // the first fusion would consume the initializer, leaving the second DQ unable to find it
+  // and crashing TransposeDQWeightsForMatMulNBits with "Missing required scale".
+  const auto* weight_arg = weight_dq->InputDefs()[0];
+  const auto* scale_arg = weight_dq->InputDefs()[1];
+  if (graph.GetConsumerNodes(weight_arg->Name()).size() > 1 ||
+      graph.GetConsumerNodes(scale_arg->Name()).size() > 1) {
+    return false;
+  }
+
   if (is_gemm) {
     // If there's a second DQ node (for bias), it must feed input 2
     if (dq_nodes.size() == 2) {
diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc
index 03005e3a07386..aa65a5e80d3f6 100644
--- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc
@@ -319,6 +319,86 @@ TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch_Cuda)
   RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider());
 }
 
+// Two MatMul nodes sharing the SAME weight and scale initializers (tied embedding pattern).
+// Regression test for issue #28306: the second DQ->MatMul fusion used to crash with
+// "Missing required scale" because the first fusion consumed the shared initializer.
+// Both DQ nodes should be rejected from fusion when the weight or scale initializer is shared.
+template <typename T, bool use_zp>
+typename std::enable_if<std::is_same_v<T, Int4x2> || std::is_same_v<T, UInt4x2>, void>::type
+RunDQMatMulNotConverted_SharedWeight(const std::vector<int64_t>& input_shape,
+                                     const std::vector<int64_t>& weight_shape,
+                                     const int64_t axis,
+                                     const int64_t block_size,
+                                     int64_t accuracy_level) {
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    auto* input1_arg = builder.MakeInput<float>(input_shape, -100.0f, 100.0f);
+    auto* input2_arg = builder.MakeInput<float>(input_shape, -100.0f, 100.0f);
+    auto* output_arg = builder.MakeOutput();
+
+    NodeAttributes attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs);
+    utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs);
+
+    auto scale_shape = std::vector<int64_t>{weight_shape};
+    scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size;
+
+    // Both DQ nodes share the SAME weight and scale initializers (tied embedding).
+    auto* shared_weight_arg = builder.MakeInitializer<T>(weight_shape, T(T::min_val, 0), T(T::max_val, 0));
+    auto* shared_scales_arg = builder.MakeInitializer<float>(scale_shape, 8.0f, 12.0f);
+
+    auto* dq1_output = builder.MakeIntermediate();
+    auto* dq2_output = builder.MakeIntermediate();
+
+    if constexpr (use_zp) {
+      auto* zp_arg = builder.MakeInitializer<T>(scale_shape, T(0, 0), T(2, 0));
+      builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg, zp_arg}, {dq1_output}, "", &attrs);
+      builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg, zp_arg}, {dq2_output}, "", &attrs);
+    } else {
+      builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg}, {dq1_output}, "", &attrs);
+      builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg}, {dq2_output}, "", &attrs);
+    }
+
+    builder.AddNode("MatMul", {input1_arg, dq1_output}, {output_arg});
+    // Route the second MatMul to its own graph output so it is not pruned as unused.
+    auto* output2_arg = builder.MakeOutput();
+    builder.AddNode("MatMul", {input2_arg, dq2_output}, {output2_arg});
+  };
+
+  auto check_graph = [&](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
+    // Fusion must NOT happen: shared initializers prevent safe fusion.
+    EXPECT_EQ(op_to_count["MatMul"], 2);
+    EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2);
+    EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0);
+  };
+
+  std::function<void(SessionOptions&)> add_session_options_fn{};
+  if (accuracy_level >= 0) {
+    add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) {
+      std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
+                                                            std::to_string(accuracy_level).c_str());
+    };
+  }
+
+  TransformerTester(build_test_case,
+                    check_graph,
+                    TransformerLevel::Level1,
+                    TransformerLevel::Level2,
+                    21 /*opset_version*/,
+                    1e-5 /*per_sample_tolerance*/,
+                    1e-5 /*relative_per_sample_tolerance*/,
+                    nullptr,
+                    add_session_options_fn);
+}
+
+TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_SharedWeight) {
+  RunDQMatMulNotConverted_SharedWeight<Int4x2, false>({12, 12}, {12, 37}, 0, 16, 0);
+  RunDQMatMulNotConverted_SharedWeight<Int4x2, true>({12, 12}, {12, 37}, 0, 16, 0);
+  RunDQMatMulNotConverted_SharedWeight<UInt4x2, false>({12, 12}, {12, 37}, 0, 16, 0);
+  RunDQMatMulNotConverted_SharedWeight<UInt4x2, true>({12, 12}, {12, 37}, 0, 16, 0);
+}
+
 // Input1
 //   |  DQ
 //    \  /
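Reviewer note: the standalone sketch below is illustrative only and is not part of the patch or the ONNX Runtime API. ToyGraph, RunSelector, and the tensor names "W"/"S" are invented stand-ins. It models the failure mode the new guard prevents: without the shared-initializer check, the first DQ->MatMul fusion consumes the tied weight and scale tensors, so the second fusion fails with "Missing required scale"; with the check, both fusions are rejected and the graph stays valid.

// toy_fusion_guard.cc -- standalone sketch, NOT OnnxRuntime code.
#include <iostream>
#include <map>
#include <string>

struct ToyGraph {
  std::map<std::string, int> consumers;   // initializer name -> consuming-node count
  std::map<std::string, bool> alive;      // initializer name -> still present in graph

  bool SharedInitializer(const std::string& name) const { return consumers.at(name) > 1; }
  void Consume(const std::string& name) { alive[name] = false; }  // fusion packs and removes it
};

void RunSelector(bool reject_shared) {
  ToyGraph g;
  g.consumers = {{"W", 2}, {"S", 2}};     // tied embedding: two DQ nodes share W and S
  g.alive = {{"W", true}, {"S", true}};

  for (int dq = 1; dq <= 2; ++dq) {
    if (!g.alive.at("S")) {               // pre-fix failure mode (issue #28306)
      std::cout << "  DQ" << dq << ": Missing required scale\n";
      continue;
    }
    if (reject_shared && (g.SharedInitializer("W") || g.SharedInitializer("S"))) {
      std::cout << "  DQ" << dq << ": fusion rejected (shared initializer)\n";
      continue;
    }
    g.Consume("W");
    g.Consume("S");
    std::cout << "  DQ" << dq << ": fused into MatMulNBits\n";
  }
}

int main() {
  std::cout << "Without guard (old behavior):\n";
  RunSelector(/*reject_shared=*/false);
  std::cout << "With guard (this PR):\n";
  RunSelector(/*reject_shared=*/true);
}

Running the sketch prints "fused" then "Missing required scale" for the unguarded pass, and "fusion rejected" twice for the guarded pass, which is exactly the behavior the regression test above asserts at the graph level (two MatMul and two DequantizeLinear nodes remain, zero MatMulNBits).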