
Fixes
bgawrych committed Oct 20, 2021
1 parent 488a78a commit 880c751
Showing 4 changed files with 19 additions and 15 deletions.
20 changes: 14 additions & 6 deletions src/operator/nn/dnnl/dnnl_batch_dot.cc
@@ -135,16 +135,24 @@ void DNNLBatchDotFwd::Execute(const OpContext& ctx,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
auto engine = mxnet::CpuEngine::Get()->get_engine();
-  auto lhs = dnnl::memory(
-      fwd_pd->src_desc(), engine, reinterpret_cast<void*>(inputs[DotIn::lhs].data().dptr_));
-  auto rhs = dnnl::memory(
-      fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_));
+  auto lhs = inputs[DotIn::lhs];
+  auto rhs = inputs[DotIn::rhs];
+  // Created primitive descriptor assumes that both inputs are in default format
+  if (lhs.IsDNNLData())
+    lhs = lhs.Reorder2Default();
+  if (rhs.IsDNNLData())
+    rhs = rhs.Reorder2Default();
+
+  auto lhs_mem =
+      dnnl::memory(fwd_pd->src_desc(), engine, reinterpret_cast<void*>(lhs.data().dptr_));
+  auto rhs_mem =
+      dnnl::memory(fwd_pd->weights_desc(), engine, reinterpret_cast<void*>(rhs.data().dptr_));
dnnl_output_t out_mem = CreateDNNLMem(
outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);

dnnl_args_map_t args = {
-      {DNNL_ARG_SRC, lhs},
-      {DNNL_ARG_WEIGHTS, rhs},
+      {DNNL_ARG_SRC, lhs_mem},
+      {DNNL_ARG_WEIGHTS, rhs_mem},
{DNNL_ARG_DST, *out_mem.second},
};

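The hunk above is the substance of the fix: the old code handed each input's raw dptr_ to a dnnl::memory built from the primitive's plain-layout descriptor, which breaks whenever the NDArray is actually stored in an opaque oneDNN layout; the new code first reorders such inputs back to the default layout and builds the argument map from the reordered buffers (lhs_mem/rhs_mem). A minimal sketch of that pattern, assuming the mxnet::NDArray helpers used in the diff (IsDNNLData() / Reorder2Default()); this is illustrative, not an excerpt of the commit:

// Sketch only (not part of the commit): the reorder-before-wrap pattern
// applied above, restated as a small helper. IsDNNLData()/Reorder2Default()
// are the NDArray methods visible in the diff; everything else is illustrative.
#include <mxnet/ndarray.h>

namespace {

// Return an NDArray whose buffer uses the default (plain) layout, so that its
// raw dptr_ matches a dnnl::memory created from a plain memory descriptor.
mxnet::NDArray EnsureDefaultLayout(const mxnet::NDArray& arr) {
  return arr.IsDNNLData() ? arr.Reorder2Default() : arr;
}

}  // namespace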
8 changes: 3 additions & 5 deletions src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -40,21 +40,19 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN)

MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLConvProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBNReLUProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);

MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", Context::CPU());

MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLConvProperty).set_attr("quantize", true);

MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCProperty).set_attr("quantize", true);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty)
-    .set_attr("quantize", true);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerQKProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerValAttProperty);

+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty)
+    .set_attr("quantize", true);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
2 changes: 1 addition & 1 deletion src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
@@ -136,7 +136,7 @@ class SgDNNLTransformerQKSelector : public SubgraphSelector {
}

std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
-    if (status_ == kFail) {
+    if (status_ != kSuccess) {
return std::vector<nnvm::Node*>(0);
} else {
std::vector<nnvm::Node*> ret;
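The one-line predicate change above makes Filter stricter: previously only an explicit kFail state cleared the candidate list, so a QK pattern match that stalled in an intermediate state could still be handed to the fuser; now anything short of kSuccess is discarded. An illustrative sketch of the new behaviour (the enum values here are placeholders, not the selector's real status states):

// Placeholder types; only the predicate mirrors the change in
// SgDNNLTransformerQKSelector::Filter.
#include <vector>

enum class Status { Fail, Partial, Success };

template <typename Node>
std::vector<Node*> Filter(Status status, const std::vector<Node*>& candidates) {
  // Old guard: if (status == Status::Fail) return {};
  // New guard: drop everything unless the whole pattern matched.
  if (status != Status::Success) {
    return std::vector<Node*>(0);
  }
  return candidates;
}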
4 changes: 1 addition & 3 deletions tests/python/dnnl/subgraphs/test_matmul_subgraph.py
@@ -24,7 +24,6 @@
from mxnet.test_utils import assert_almost_equal, assert_almost_equal_with_err
from mxnet.util import use_np
import math
-from subgraph_common import check_fusion, check_neg_fusion, DATA_SHAPE

@use_np
@pytest.mark.parametrize('batch_size', [1, 32])
@@ -68,14 +67,13 @@ def forward(self, x, mask):
net.hybridize()
ref_out = net(in_data, mask)

fused_net.optimize_for(in_data, mask, backend="DNNL")
fused_net.optimize_for(in_data, mask, backend="ONEDNN")
out = fused_net(in_data, mask)
mx.nd.waitall()

for i in range(len(out)):
assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())


calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
exclude_layers=None,
