Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[master] Add aliases for subgraph operators to be compatible with old models #20679

Merged
merged 5 commits into from
Oct 31, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@
'zeros_like',
'_sg_onednn_conv',
'_sg_onednn_fully_connected',
'_sg_mkldnn_conv',
'_sg_mkldnn_fully_connected',
'broadcast_mul',
'Convolution_v1',
'IdentityAttachKLSparseReg',
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,10 @@
'_sg_onednn_fully_connected',
'_sg_onednn_selfatt_qk',
'_sg_onednn_selfatt_valatt',
'_sg_mkldnn_conv',
'_sg_mkldnn_fully_connected',
'_sg_mkldnn_selfatt_qk',
'_sg_mkldnn_selfatt_valatt',
])

# Functions that have to be cast to FP32 only for
Expand Down
3 changes: 2 additions & 1 deletion src/operator/quantization/quantize_graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ inline QuantizeType NeedQuantize(ObjectPtr node,
need = false;
if (need) {
if ((quantize_granularity == "channel-wise") &&
(node->op() == Op::Get("_sg_onednn_fully_connected"))) {
(node->op() == Op::Get("_sg_onednn_fully_connected") ||
bartekkuncer marked this conversation as resolved.
Show resolved Hide resolved
node->op() == Op::Get("_sg_mkldnn_fully_connected"))) {
quantized_node->attrs.dict["channel_wise_quantize"] = "True";
}
quantized_node_map->insert(std::make_pair(node, quantized_node));
Expand Down
1 change: 1 addition & 0 deletions src/operator/subgraph/dnnl/dnnl_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ bool SgDNNLAvoidConvQuantizeInput(const NodeAttrs& attrs,
}

NNVM_REGISTER_OP(_sg_onednn_conv)
.add_alias("_sg_mkldnn_conv")
.describe(R"code(_sg_onednn_conv)code" ADD_FILELINE)
.set_num_inputs(SgDNNLConvNumInputs)
.set_num_outputs([](const NodeAttrs& attrs) {
Expand Down
1 change: 1 addition & 0 deletions src/operator/subgraph/dnnl/dnnl_fc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ static bool SgDNNLAvoidFCQuantizeInput(const NodeAttrs& attrs,
}

NNVM_REGISTER_OP(_sg_onednn_fully_connected)
.add_alias("_sg_mkldnn_fully_connected")
.describe(R"code(_sg_onednn_fully_connected)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelector {

bool Select(const nnvm::Node& n) override {
if (n.op() && support_requantize_fusion_op_name.count(n.op()->name)) {
if (n.op()->name == "_sg_onednn_conv") {
if (n.op()->name == "_sg_onednn_conv" || n.op()->name == "_sg_mkldnn_conv") {
auto const& param = nnvm::get<DNNLConvFusionParam>(n.attrs.parsed);
if (param.full_conv_param.dnnl_param.quantized) {
status = kStart;
Expand Down
2 changes: 2 additions & 0 deletions src/operator/subgraph/dnnl/dnnl_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ nnvm::ObjectPtr SgDNNLSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
}

NNVM_REGISTER_OP(_sg_onednn_selfatt_qk)
.add_alias("_sg_mkldnn_selfatt_qk")
.describe(R"code(_sg_onednn_selfatt_qk)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
Expand Down Expand Up @@ -700,6 +701,7 @@ void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
}

NNVM_REGISTER_OP(_sg_onednn_selfatt_valatt)
.add_alias("_sg_mkldnn_selfatt_valatt")
.describe(R"code(_sg_onednn_selfatt_valatt)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLSelfAttParam>(attrs.parsed);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class SgDNNLTransformerPostQuantizeSelector : public SubgraphSelector {

bool Select(const nnvm::Node& n) override {
if ((!disable_all) && (n.op() == Op::Get("_sg_onednn_selfatt_qk") ||
n.op() == Op::Get("_sg_onednn_selfatt_valatt"))) {
n.op() == Op::Get("_sg_onednn_selfatt_valatt") ||
bartekkuncer marked this conversation as resolved.
Show resolved Hide resolved
n.op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) {
status = disable_all ? kSuccess : kStart;
matched_list.clear();
matched_list.push_back(&n);
Expand Down Expand Up @@ -152,7 +154,9 @@ class SgDNNLTransformerPostQuantizeProperty : public SubgraphProperty {
if (node->is_variable())
return;
if (node->op() == Op::Get("_sg_onednn_selfatt_qk") ||
node->op() == Op::Get("_sg_onednn_selfatt_valatt")) {
node->op() == Op::Get("_sg_onednn_selfatt_valatt") ||
node->op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) {
interleaved_node = node;
} else if (node->op() == Op::Get("_contrib_requantize")) {
requantize_node = node;
Expand Down