--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -43,7 +43,7 @@
 }
 
 // Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits.
-// Used by both DQMatMulToMatMulNBitsAction and DQCastMatMulToMatMulNBitsAction.
+// Used by DQMatMulToMatMulNBitsAction.
 struct TransposedQuantizedTensors {
   Tensor weight;
   Tensor scale;
@@ -486,149 +486,9 @@
   return Status::OK();
 }
 
-DQCastMatMulToMatMulNBitsAction::DQCastMatMulToMatMulNBitsAction(
-    int64_t accuracy_level,
-    concurrency::ThreadPool* intra_op_thread_pool)
-    : accuracy_level_{accuracy_level},
-      intra_op_thread_pool_{intra_op_thread_pool} {
-  ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");
-}
-
-Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
-  // Selected nodes layout (from DQCastMatMulToMatMulNBitsSelector):
-  //   Input(0) = DQ node
-  //   Input(1) = Cast on input B (between DQ and MatMul)
-  //   Target() = MatMul node
-  auto* dq_node = selected_nodes.Input(0);
-  auto* cast_b_node = selected_nodes.Input(1);
-  auto& matmul_node = selected_nodes.Target();
-
-  // --- Transpose DQ weights/scales/zp via shared helper ---
-  TransposedQuantizedTensors transposed;
-  ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits(
-      graph, *dq_node, "fused_DQ_Cast_MatMul", intra_op_thread_pool_, transposed));
-
-  // MatMulNBits operates in the DQ scale dtype.
-  // Always insert Cast on input A (to DQ dtype) and Cast on output (DQ dtype to MatMul output dtype).
-  // ORT's redundant cast elimination optimizer will clean up unnecessary casts later.
-
-  // Determine DQ output element type (e.g., fp16)
-  int32_t dq_output_dtype = cast_b_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  // Determine MatMul output element type (e.g., fp32)
-  int32_t matmul_output_dtype = matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-
-  const auto& dq_attrs = dq_node->GetAttributes();
-  const auto* weight_arg = dq_node->InputDefs()[0];
-  auto K = weight_arg->Shape()->dim(0).dim_value();
-  auto N = weight_arg->Shape()->dim(1).dim_value();
-  auto block_size = dq_attrs.at("block_size").i();
-  int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type();
-  auto bits = DQWeightBits(dt_weight);
-
-  // --- Create fp16 NodeArg for MatMulNBits input A ---
-  NodeArg* matmul_input_a = matmul_node.MutableInputDefs()[0];
-  ONNX_NAMESPACE::TypeProto input_a_fp16_type;
-  input_a_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype);
-  if (matmul_input_a->Shape()) {
-    *input_a_fp16_type.mutable_tensor_type()->mutable_shape() =
-        matmul_input_a->TypeAsProto()->tensor_type().shape();
-  }
-  auto cast_a_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_input_a_cast");
-  NodeArg* input_a_arg = &graph.GetOrCreateNodeArg(cast_a_out_name, &input_a_fp16_type);
-
-  // --- Create fp16 NodeArg for MatMulNBits output ---
-  ONNX_NAMESPACE::TypeProto output_fp16_type;
-  output_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype);
-  if (matmul_node.OutputDefs()[0]->Shape()) {
-    *output_fp16_type.mutable_tensor_type()->mutable_shape() =
-        matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().shape();
-  }
-  auto mnb_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_matmulnbits_out");
-  NodeArg* mnb_output_arg = &graph.GetOrCreateNodeArg(mnb_out_name, &output_fp16_type);
-
-  // --- Create MatMulNBits node ---
-  NodeAttributes attrs;
-  utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs);
-  utils::SetNodeAttribute(utils::MakeAttribute("N", N), attrs);
-  utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs);
-  utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs);
-  utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), attrs);
-
-  auto& new_node = graph.AddNode(
-      graph.GenerateNodeName(matmul_node.Name() + "_MatMulNBits"),
-      "MatMulNBits",
-      "Fused DQ+Cast+MatMul to MatMulNBits",
-      {input_a_arg},
-      {mnb_output_arg},
-      &attrs,
-      kMSDomain);
-
-  const auto& target_provider = matmul_node.GetExecutionProviderType();
-  new_node.SetExecutionProviderType(target_provider.empty() ? kCpuExecutionProvider : target_provider);
-
-  // Add transposed weight, scale, zp to inputs
-  auto& input_defs = new_node.MutableInputDefs();
-  input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight)));
-  new_node.MutableInputArgsCount().push_back(1);
-
-  input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale)));
-  new_node.MutableInputArgsCount().push_back(1);
-
-  if (transposed.zero_point_proto) {
-    input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, *transposed.zero_point_proto, std::move(*transposed.zero_point)));
-    new_node.MutableInputArgsCount().push_back(1);
-  }
-
-  // --- Insert Cast on input A: matmul_input_dtype -> dq_output_dtype ---
-  {
-    NodeAttributes cast_attrs;
-    utils::SetNodeAttribute(
-        utils::MakeAttribute("to", static_cast<int64_t>(dq_output_dtype)),
-        cast_attrs);
-    auto& cast_node = graph.AddNode(
-        graph.GenerateNodeName(matmul_node.Name() + "_Cast_input_a"),
-        "Cast", "",
-        {matmul_input_a},
-        {input_a_arg},
-        &cast_attrs,
-        kOnnxDomain);
-    cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType());
-  }
-
-  // --- Insert Cast on output: dq_output_dtype -> matmul_output_dtype ---
-  {
-    NodeAttributes cast_attrs;
-    utils::SetNodeAttribute(
-        utils::MakeAttribute("to", static_cast<int64_t>(matmul_output_dtype)),
-        cast_attrs);
-    auto& cast_node = graph.AddNode(
-        graph.GenerateNodeName(matmul_node.Name() + "_Cast_output"),
-        "Cast", "",
-        {mnb_output_arg},
-        {const_cast<NodeArg*>(matmul_node.OutputDefs()[0])},
-        &cast_attrs,
-        kOnnxDomain);
-    cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType());
-  }
-
-  // --- Remove original nodes ---
-  auto remove_node = [&graph](Node* node) {
-    if (node) {
-      graph_utils::RemoveNodeOutputEdges(graph, *node);
-      graph.RemoveNode(node->Index());
-    }
-  };
-
-  remove_node(&matmul_node);
-  remove_node(cast_b_node);
-  remove_node(dq_node);
-
-  return Status::OK();
-}
-
 static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
   NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
   NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
   NTO::NodeLocation dq_bias{NTO::NodeType::kInput, 2};
   NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
   NTO::NodeLocation q{NTO::NodeType::kOutput, 0};
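
For reference, a sketch of the subgraph rewrite the removed Run() performed, reconstructed from the code above (tensor names are illustrative, not from the source):

    A (fp32) ------------------------------> MatMul --> Y (fp32)
    B_q --> DQ (fp16) --> Cast(fp16->fp32) --^

became

    A (fp32) --> Cast(fp32->fp16) --> MatMulNBits(B_q, scales[, zp]) --> Cast(fp16->fp32) --> Y (fp32)

with ORT's redundant cast elimination expected to remove any Casts that turn out to be unnecessary.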
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -107,20 +107,6 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew {
   concurrency::ThreadPool* intra_op_thread_pool_;
 };
 
-// Used together with DQCastMatMulToMatMulNBitsSelector.
-// Handles DQ -> Cast(fp16->fp32) -> MatMul fusion to MatMulNBits,
-// including optional Cast on input A and output type alignment.
-struct DQCastMatMulToMatMulNBitsAction : public Action {
-  DQCastMatMulToMatMulNBitsAction(int64_t accuracy_level,
-                                  concurrency::ThreadPool* intra_op_thread_pool);
-
-  Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;
-
- private:
-  int64_t accuracy_level_;
-  concurrency::ThreadPool* intra_op_thread_pool_;
-};
-
 struct GemmReplaceWithQuant : public Action {
   GemmReplaceWithQuant();
 
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -7,6 +7,7 @@
 #include <unordered_map>
 
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
+
 #include "core/mlas/inc/mlas.h"
 
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h"
@@ -306,7 +307,12 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi
                                                          intra_op_thread_pool);
 
 #if !defined(ORT_MINIMAL_BUILD)
-  std::vector<const char*> providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider};
+  // Include "" (empty string) to match nodes not yet assigned to an EP.
+  // For FP16 models on CPU EP, FP16 MatMul nodes are not claimed during partitioning
+  // (no FP16 MatMul kernel on CPU), leaving their EP unassigned. The DQ->MatMul fusion
+  // should still apply; the action assigns kCpuExecutionProvider to the resulting
+  // MatMulNBits node (which has both float and float16 CPU kernels).
+  std::vector<const char*> providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider, ""};
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DQMatMulToMatMulNBitsSelector>(providers);
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
                                                          {{"MatMul", {}}},
@@ -316,25 +322,6 @@
 #else
   qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
 #endif
-
-  // DQ -> Cast(fp16->fp32) -> MatMul pattern.
-  // Handles FP16 models where Cast nodes are inserted between DQ and MatMul.
-  const std::string cast_action_name{"DQCastMatMulToMatMulNBits"};
-
-  std::unique_ptr<Action> cast_action =
-      std::make_unique<QDQ::DQCastMatMulToMatMulNBitsAction>(qdq_matmulnbits_accuracy_level,
-                                                             intra_op_thread_pool);
-
-#if !defined(ORT_MINIMAL_BUILD)
-  std::unique_ptr<NodeSelector> cast_selector =
-      std::make_unique<QDQ::DQCastMatMulToMatMulNBitsSelector>(providers);
-  qdq_selector_action_registry.RegisterSelectorAndAction(cast_action_name,
-                                                         {{"MatMul", {}}},
-                                                         std::move(cast_selector),
-                                                         std::move(cast_action));
-#else
-  qdq_selector_action_registry.RegisterAction(cast_action_name, std::move(cast_action));
-#endif
 }
 
 void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
@@ -416,7 +403,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer(
       apply_context,
       // this transformer is compatible with CPU, DML, ACL and CUDA EP.
       // There is further EP control on the rule level.
-      {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider}} {
+      // Also accept nodes with empty EP (unassigned) so that individual selectors
+      // that include "" in their compatible providers can match unassigned nodes.
+      {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider, ""}} {
 }
 
 }  // namespace onnxruntime
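
The semantics that make the added "" entry effective can be distilled into the following standalone sketch; it mirrors the EP check visible in the removed Select() in qdq_selectors.cc below, but EpCompatible is a hypothetical name, not an ORT API:

#include <algorithm>
#include <string>
#include <vector>

// Sketch of the selectors' EP-compatibility test.
bool EpCompatible(const std::vector<std::string>& compatible_providers,
                  const std::string& node_ep) {
  // An empty providers list means "no EP filtering at all".
  if (compatible_providers.empty()) return true;
  // A node not claimed during partitioning has node_ep == "", so it only
  // matches if "" was registered as one of the compatible providers.
  return std::find(compatible_providers.begin(), compatible_providers.end(),
                   node_ep) != compatible_providers.end();
}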
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -651,75 +651,6 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod
   return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]);
 }
 
-std::optional<NodesToOptimizeIndices>
-DQCastMatMulToMatMulNBitsSelector::Select(const GraphViewer& graph_viewer, const Node& node) const {
-  // Check EP compatibility
-  const std::string_view node_ep = node.GetExecutionProviderType();
-  if (!compatible_providers_.empty() &&
-      std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) {
-    return std::nullopt;
-  }
-
-  const auto& graph = graph_viewer.GetGraph();
-
-  // node must be MatMul
-  if (node.OpType() != "MatMul") {
-    return std::nullopt;
-  }
-
-  if (node.InputDefs().size() < 2) {
-    return std::nullopt;
-  }
-
-  // Check input B: must be Cast(fp16->fp32)
-  const Node* cast_b = graph_viewer.GetProducerNode(node.InputDefs()[1]->Name());
-  if (!cast_b || cast_b->OpType() != "Cast") {
-    return std::nullopt;
-  }
-
-  const auto& cast_b_attrs = cast_b->GetAttributes();
-  auto to_iter = cast_b_attrs.find("to");
-  if (to_iter == cast_b_attrs.end() ||
-      to_iter->second.i() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) {
-    return std::nullopt;
-  }
-
-  // Cast B input must be fp16
-  if (!cast_b->InputDefs()[0]->TypeAsProto() ||
-      cast_b->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
-          ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
-    return std::nullopt;
-  }
-
-  // Cast B must have exactly 1 output edge (to MatMul) and not be a graph output
-  if (!optimizer_utils::CheckOutputEdges(graph, *cast_b, 1)) {
-    return std::nullopt;
-  }
-
-  // Cast B's input must come from a DQ node
-  const Node* dq_node = graph_viewer.GetProducerNode(cast_b->InputDefs()[0]->Name());
-  if (!dq_node || dq_node->OpType() != QDQ::DQOpName) {
-    return std::nullopt;
-  }
-
-  // DQ must have exactly 1 output edge (to Cast B) and not be a graph output
-  if (!optimizer_utils::CheckOutputEdges(graph, *dq_node, 1)) {
-    return std::nullopt;
-  }
-
-  if (!ValidateBlockwiseDQForMatMulNBits(graph, *dq_node)) {
-    return std::nullopt;
-  }
-
-  // Build selection
-  NodesToOptimizeIndicesBuilder builder;
-  builder.input_nodes.push_back(dq_node->Index());
-  builder.input_nodes.push_back(cast_b->Index());
-  builder.target_node = node.Index();
-
-  return builder.Build();
-}
-
 bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
                                   const std::vector<const Node*>& dq_nodes,
                                   const std::vector<const Node*>& q_nodes) const {
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -461,27 +461,6 @@ class DQMatMulToMatMulNBitsSelector : public BaseSelector {
       : BaseSelector(std::make_unique<DQMatMulNodeGroupSelector>(), compatible_providers) {}
 };
 
-// Convert "DQ -> Cast(fp16->fp32) -> MatMul" to "MatMulNBits".
-// Handles Cast(fp16->fp32) between DQ and MatMul on input B, and optionally on input A.
-// Selection layout:
-//   input_nodes[0] = DQ node
-//   input_nodes[1] = Cast on input B (between DQ and MatMul)
-//   target_node = MatMul
-//   output_nodes = {}
-class DQCastMatMulToMatMulNBitsSelector : public NodeSelector {
- public:
-  explicit DQCastMatMulToMatMulNBitsSelector(gsl::span<const char*> compatible_providers = {})
-      : compatible_providers_(compatible_providers.begin(), compatible_providers.end()) {}
-
-  DQCastMatMulToMatMulNBitsSelector(DQCastMatMulToMatMulNBitsSelector&& rhs) noexcept
-      : compatible_providers_(std::move(rhs.compatible_providers_)) {}
-
-  std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override;
-
- private:
-  std::vector<std::string> compatible_providers_;
-};
-
 // Input: DQ nodes for A, B and optional C
 // Output: optional Q node for Y
 class GemmSelector : public BaseSelector {
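
For context, a minimal sketch of exercising this fusion from the public C++ API. It assumes the accuracy level is surfaced through the session config key kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel from onnxruntime_session_options_config_keys.h, and model.onnx is a placeholder for a model containing the DQ -> MatMul pattern; neither appears in this diff:

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // Extended optimizations include the QDQ selector/action transformer.
  so.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
  // Forwarded as the accuracy_level attribute of fused MatMulNBits nodes.
  so.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "4");
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  return 0;
}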