Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix names for post quantize fuse
Browse files Browse the repository at this point in the history
  • Loading branch information
bgawrych committed Oct 19, 2021
1 parent f0b5603 commit f149a85
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/operator/subgraph/dnnl/dnnl_batch_dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ inline static bool DNNLBatchDotStorageType(const nnvm::NodeAttrs& attrs,
return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}

NNVM_REGISTER_OP(_sg_dnnl_batch_dot)
.describe(R"code(_sg_dnnl_batch_dot)code" ADD_FILELINE)
NNVM_REGISTER_OP(_sg_onednn_batch_dot)
.describe(R"code(_sg_onednn_batch_dot)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLDotParam>(attrs.parsed);
// two normal inputs + min/max for quantized version
Expand Down Expand Up @@ -158,7 +158,7 @@ NNVM_REGISTER_OP(_sg_dnnl_batch_dot)
.set_attr<FQuantizedOp>("FQuantizedOp",
[](const NodeAttrs& attrs) {
nnvm::ObjectPtr node = nnvm::Node::Create();
node->attrs.op = Op::Get("_sg_dnnl_batch_dot");
node->attrs.op = Op::Get("_sg_onednn_batch_dot");
node->attrs.name = "quantized_" + attrs.name;
node->attrs.dict = attrs.dict;
node->attrs.dict["quantized"] = "True";
Expand Down
2 changes: 1 addition & 1 deletion src/operator/subgraph/dnnl/dnnl_batch_dot_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class SgDNNLBatchDotProperty : public SubgraphProperty {
});

n->attrs.name = node_name.str();
n->attrs.op = Op::Get("_sg_dnnl_batch_dot");
n->attrs.op = Op::Get("_sg_onednn_batch_dot");
CHECK(n->attrs.op);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(sym));
n->attrs.dict["transpose_a"] = std::to_string(param.transpose_a);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class SgDNNLMatmulPostQuantizeSelector : public SubgraphSelector {

bool Select(const nnvm::Node& n) override {
if ((!disable_all) && (n.op() == Op::Get("_sg_onednn_selfatt_qk") ||
n.op() == Op::Get("_sg_onednn_selfatt_valatt"))) {
n.op() == Op::Get("_sg_onednn_selfatt_valatt") ||
n.op() == Op::Get("_sg_onednn_batch_dot"))) {
status = disable_all ? kSuccess : kStart;
matched_list.clear();
matched_list.push_back(&n);
Expand Down Expand Up @@ -152,7 +153,8 @@ class SgDNNLMatmulPostQuantizeProperty : public SubgraphProperty {
if (node->is_variable())
return;
if (node->op() == Op::Get("_sg_onednn_selfatt_qk") ||
node->op() == Op::Get("_sg_onednn_selfatt_valatt")) {
node->op() == Op::Get("_sg_onednn_selfatt_valatt") ||
node->op() == Op::Get("_sg_onednn_batch_dot")) {
interleaved_node = node;
} else if (node->op() == Op::Get("_contrib_requantize")) {
requantize_node = node;
Expand Down

0 comments on commit f149a85

Please sign in to comment.