Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,23 @@ void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex);
}

bool MatMulNBitsNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
return true;
std::cout << "MatMulNBitsNodeGroupSelector::Check is not implemented yet." << std::endl;
// we should check that the first and third inputs hav DQ nodes, and that the output has a Q node
if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, 2)) {
return false;
}
// MatMulNBits has 2 DQ inputs and 1 Q output
if (dq_nodes.size() != 2 || q_nodes.size() != 1) {
return false;
}

return true;
}

bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,21 @@ class DQMatMulNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
};

// MatMulNBits node group selector
class MatMulNBitsNodeGroupSelector : public NodeGroupSelector {
public:
explicit MatMulNBitsNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}

private:
bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;

bool allow_16bit_;
bool allow_4bit_;
};

// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmNodeGroupSelector : public NodeGroupSelector {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetReciprocalOpVersionsMap() {
static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() {
return {{"MatMul", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetMatMulNBitsOpVersionsMap() {
return {{"MatMulNBits", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetGemmOpVersionsMap() {
return {{"Gemm", {}}};
}
Expand Down Expand Up @@ -245,6 +248,13 @@ void RegisterGemmSelector(Selectors& qdq_selectors) {
std::move(selector));
}

void RegisterMatMulNbitsSelector(Selectors& qdq_selectors) {
/* register selector for MatMulNBits op */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<MatMulNBitsNodeGroupSelector>();
qdq_selectors.RegisterSelector(GetMatMulNBitsOpVersionsMap(),
std::move(selector));
}

void RegisterInstanceAndLayerNormalizationSelector(Selectors& qdq_selectors) {
/* register selector for InstanceNormalization op */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<InstanceAndLayerNormalizationNodeGroupSelector>();
Expand Down Expand Up @@ -313,6 +323,7 @@ void SelectorManager::CreateSelectors() {
RegisterEinsumSelector(qdq_selectors_);
RegisterReciprocalSelector(qdq_selectors_);
RegisterMatMulSelector(qdq_selectors_);
RegisterMatMulNbitsSelector(qdq_selectors_);
RegisterGemmSelector(qdq_selectors_);
RegisterInstanceAndLayerNormalizationSelector(qdq_selectors_);
RegisterBatchNormalizationSelector(qdq_selectors_);
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateMatMulOpBuilder("MatMul", *this);
}

{
CreateMatMulNBitsOpBuilder("MatMulNBits", *this);
}

{
CreateMeanOpBuilder("Mean", *this);
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistratio

void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateMatMulNBitsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper,
Qnn_DataType_t qnn_data_type = tensor_info.qnn_data_type;
output_qnn_dtypes.push_back(qnn_data_type);
}
if (IsCpuBackend(qnn_model_wrapper.GetQnnBackendType())) {
if (IsIrBackend(qnn_model_wrapper.GetQnnBackendType())) {
// Currently QnnIr has no constraints on datatypes
return Status::OK();
} else if (IsCpuBackend(qnn_model_wrapper.GetQnnBackendType())) {
return CheckCpuDataTypes(input_qnn_dtypes, output_qnn_dtypes);
} else if (IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) {
return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class BaseOpBuilder : public IOpBuilder {
{"QuantizeLinear", QNN_OP_QUANTIZE},

{"MatMul", QNN_OP_MAT_MUL},
{"MatMulNBits", "MatMulNBits"},

{"Elu", QNN_OP_ELU},
{"Relu", QNN_OP_RELU},
Expand Down
Loading