Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,17 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod
return false;
}

// Reject fusion if the weight or scale initializer is shared (consumed by more than one node).
// When two DQ nodes reference the same initializer (tied embedding pattern, issue #28306),
// the first fusion would consume the initializer, leaving the second DQ unable to find it —
// causing a crash in TransposeDQWeightsForMatMulNBits with "Missing required scale".
const auto* weight_arg = weight_dq->InputDefs()[0];
const auto* scale_arg = weight_dq->InputDefs()[1];
if (graph.GetConsumerNodes(weight_arg->Name()).size() > 1 ||
graph.GetConsumerNodes(scale_arg->Name()).size() > 1) {
return false;
}

if (is_gemm) {
// If there's a second DQ node (for bias), it must feed input 2
if (dq_nodes.size() == 2) {
Expand Down
80 changes: 80 additions & 0 deletions onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,86 @@ TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch_Cuda)
RunDQMatMulNotConverted_TypeShapeMismatch<Int4x2, true>({2, 12, 37}, {2, 37, 12}, 0, 16, 0, DefaultCudaExecutionProvider());
}

// Two MatMul nodes sharing the SAME weight and scale initializers (tied embedding pattern).
// Regression test for issue #28306: the second DQ->MatMul fusion used to crash with
// "Missing required scale" because the first fusion consumed the shared initializer.
// Both DQ nodes should be rejected from fusion when weight or scale is shared.
template <typename T, bool use_zp>
typename std::enable_if<std::is_same_v<T, Int4x2> || std::is_same_v<T, UInt4x2>, void>::type
RunDQMatMulNotConverted_SharedWeight(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& weight_shape,
const int64_t axis,
const int64_t block_size,
int64_t accuracy_level) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput(input_shape, -100.0f, 100.0f);
auto* input2_arg = builder.MakeInput(input_shape, -100.0f, 100.0f);
auto* output_arg = builder.MakeOutput();

NodeAttributes attrs;
utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs);
utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs);

auto scale_shape = std::vector<int64_t>{weight_shape};
scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size;

// Both DQ nodes share the SAME weight and scale initializers (tied embedding).
auto* shared_weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0));
auto* shared_scales_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f);

auto* dq1_output = builder.MakeIntermediate();
auto* dq2_output = builder.MakeIntermediate();

if constexpr (use_zp) {
auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0));
builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg, zp_arg}, {dq1_output}, "", &attrs);
builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg, zp_arg}, {dq2_output}, "", &attrs);
} else {
builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg}, {dq1_output}, "", &attrs);
builder.AddNode("DequantizeLinear", {shared_weight_arg, shared_scales_arg}, {dq2_output}, "", &attrs);
}

builder.AddNode("MatMul", {input1_arg, dq1_output}, {output_arg});
// Second MatMul is sunk to avoid unused-output; use a separate output node.
auto* output2_arg = builder.MakeOutput();
builder.AddNode("MatMul", {input2_arg, dq2_output}, {output2_arg});
};

auto check_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
// Fusion must NOT happen: shared initializers prevent safe fusion.
EXPECT_EQ(op_to_count["MatMul"], 2);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2);
EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 0);
};

std::function<void(SessionOptions&)> add_session_options_fn{};
if (accuracy_level >= 0) {
add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) {
std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
std::to_string(accuracy_level).c_str());
};
}

TransformerTester(build_test_case,
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
21 /*opset_version*/,
1e-5 /*per_sample_tolerance*/,
1e-5 /*relative_per_sample_tolerance*/,
nullptr,
add_session_options_fn);
}

// Shared-weight regression coverage: both int4 flavors, with and without a
// shared zero-point initializer. Fusion must be rejected in every case.
TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_SharedWeight) {
  const std::vector<int64_t> activation_shape{12, 12};
  const std::vector<int64_t> weight_shape{12, 37};
  // No zero point.
  RunDQMatMulNotConverted_SharedWeight<UInt4x2, false>(activation_shape, weight_shape, 0, 16, 0);
  RunDQMatMulNotConverted_SharedWeight<Int4x2, false>(activation_shape, weight_shape, 0, 16, 0);
  // With a shared zero point.
  RunDQMatMulNotConverted_SharedWeight<UInt4x2, true>(activation_shape, weight_shape, 0, 16, 0);
  RunDQMatMulNotConverted_SharedWeight<Int4x2, true>(activation_shape, weight_shape, 0, 16, 0);
}

// Input1
// | DQ
// \ /
Expand Down
Loading