From f149a85fc2126c4378e2342818c8c396d8b9b719 Mon Sep 17 00:00:00 2001 From: "B. Gawrych" Date: Tue, 19 Oct 2021 18:55:49 +0200 Subject: [PATCH] Fix names for post quantize fuse --- src/operator/subgraph/dnnl/dnnl_batch_dot.cc | 6 +++--- src/operator/subgraph/dnnl/dnnl_batch_dot_property.h | 2 +- .../subgraph/dnnl/dnnl_matmul_post_quantize_property.h | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc index 009d7e90f861..d48f5171c6cf 100644 --- a/src/operator/subgraph/dnnl/dnnl_batch_dot.cc +++ b/src/operator/subgraph/dnnl/dnnl_batch_dot.cc @@ -114,8 +114,8 @@ inline static bool DNNLBatchDotStorageType(const nnvm::NodeAttrs& attrs, return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); } -NNVM_REGISTER_OP(_sg_dnnl_batch_dot) - .describe(R"code(_sg_dnnl_batch_dot)code" ADD_FILELINE) +NNVM_REGISTER_OP(_sg_onednn_batch_dot) + .describe(R"code(_sg_onednn_batch_dot)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { auto const& param = nnvm::get(attrs.parsed); // two normal inputs + min/max for quantized version @@ -158,7 +158,7 @@ NNVM_REGISTER_OP(_sg_dnnl_batch_dot) .set_attr("FQuantizedOp", [](const NodeAttrs& attrs) { nnvm::ObjectPtr node = nnvm::Node::Create(); - node->attrs.op = Op::Get("_sg_dnnl_batch_dot"); + node->attrs.op = Op::Get("_sg_onednn_batch_dot"); node->attrs.name = "quantized_" + attrs.name; node->attrs.dict = attrs.dict; node->attrs.dict["quantized"] = "True"; diff --git a/src/operator/subgraph/dnnl/dnnl_batch_dot_property.h b/src/operator/subgraph/dnnl/dnnl_batch_dot_property.h index 4975bdd2afb0..3e42f4ad798d 100644 --- a/src/operator/subgraph/dnnl/dnnl_batch_dot_property.h +++ b/src/operator/subgraph/dnnl/dnnl_batch_dot_property.h @@ -79,7 +79,7 @@ class SgDNNLBatchDotProperty : public SubgraphProperty { }); n->attrs.name = node_name.str(); - n->attrs.op = Op::Get("_sg_dnnl_batch_dot"); + n->attrs.op = Op::Get("_sg_onednn_batch_dot"); CHECK(n->attrs.op); n->attrs.subgraphs.emplace_back(std::make_shared(sym)); n->attrs.dict["transpose_a"] = std::to_string(param.transpose_a); diff --git a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h index d681c4d3421a..6fbd97fd1f56 100644 --- a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h +++ b/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h @@ -53,7 +53,8 @@ class SgDNNLMatmulPostQuantizeSelector : public SubgraphSelector { bool Select(const nnvm::Node& n) override { if ((!disable_all) && (n.op() == Op::Get("_sg_onednn_selfatt_qk") || - n.op() == Op::Get("_sg_onednn_selfatt_valatt"))) { + n.op() == Op::Get("_sg_onednn_selfatt_valatt") || + n.op() == Op::Get("_sg_onednn_batch_dot"))) { status = disable_all ? kSuccess : kStart; matched_list.clear(); matched_list.push_back(&n); @@ -152,7 +153,8 @@ class SgDNNLMatmulPostQuantizeProperty : public SubgraphProperty { if (node->is_variable()) return; if (node->op() == Op::Get("_sg_onednn_selfatt_qk") || - node->op() == Op::Get("_sg_onednn_selfatt_valatt")) { + node->op() == Op::Get("_sg_onednn_selfatt_valatt") || + node->op() == Op::Get("_sg_onednn_batch_dot")) { interleaved_node = node; } else if (node->op() == Op::Get("_contrib_requantize")) { requantize_node = node;