Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix mkldnn subgraph with float64
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhennanQin committed Aug 12, 2019
1 parent 57927a9 commit 9147b38
Show file tree
Hide file tree
Showing 13 changed files with 264 additions and 62 deletions.
17 changes: 12 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,18 @@
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_

#if MXNET_USE_MKLDNN == 1
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include <algorithm>
#include <memory>
#include <vector>
#include "mkldnn.hpp"
#include "mxnet/graph_attr_types.h"
#include "mxnet/ndarray.h"
#include "mxnet/resource.h"
#include "mxnet/op_attr_types.h"
#include "mxnet/resource.h"
using namespace mkldnn;
namespace mxnet {

Expand Down Expand Up @@ -132,6 +133,11 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNRNN(const NDArray &input) {
int ndim = input.shape().ndim();
return (input.dtype() == mshadow::kFloat32) && (ndim == 3);
}

static inline bool SupportMKLDNNQuantize(int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
dtype == mshadow::kUint8;
Expand Down Expand Up @@ -569,7 +575,8 @@ class MKLDNNMemory {
}
};

void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
template <typename Compute, typename AttrState>
void FallBackCompute(Compute fn, const AttrState &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
Expand Down
23 changes: 21 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p
return mkldnn::memory::primitive_desc(data_md, pd.get_engine());
}

void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
template <typename Compute, typename AttrState>
void FallBackCompute(Compute fn, const AttrState &attrs_states,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
Expand Down Expand Up @@ -461,7 +462,7 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
out_blobs[i] = output.data();
}

fn(attrs, ctx, in_blobs, req, out_blobs);
fn(attrs_states, ctx, in_blobs, req, out_blobs);
for (size_t i = 0; i < out_blobs.size(); i++) {
if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
Expand Down Expand Up @@ -518,6 +519,24 @@ static bool SimilarArray(const mxnet::NDArray &arr1, const mxnet::NDArray &arr2,
return success.load();
}

template void FallBackCompute(void (*)(nnvm::NodeAttrs const &, OpContext const &,
std::vector<TBlob, std::allocator<TBlob> > const &,
std::vector<OpReqType, std::allocator<OpReqType> > const &,
std::vector<TBlob, std::allocator<TBlob> > const &),
nnvm::NodeAttrs const &, OpContext const &,
std::vector<NDArray, std::allocator<NDArray> > const &,
std::vector<OpReqType, std::allocator<OpReqType> > const &,
std::vector<NDArray, std::allocator<NDArray> > const &);

template void FallBackCompute(void (*)(OpStatePtr const &, OpContext const &,
std::vector<TBlob, std::allocator<TBlob> > const &,
std::vector<OpReqType, std::allocator<OpReqType> > const &,
std::vector<TBlob, std::allocator<TBlob> > const &),
OpStatePtr const &, OpContext const &,
std::vector<NDArray, std::allocator<NDArray> > const &,
std::vector<OpReqType, std::allocator<OpReqType> > const &,
std::vector<NDArray, std::allocator<NDArray> > const &);

void OpCheck::Init(const std::vector<mxnet::NDArray> &inputs_,
const std::vector<mxnet::NDArray> &outputs_) {
auto ctx = inputs_[0].ctx();
Expand Down
16 changes: 15 additions & 1 deletion src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,20 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
});
});
}

static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRNN(inputs[0])) {
RNNStatefulComputeCPU(state_ptr, ctx, inputs, req, outputs);
return;
}
int use_mkldnn_rnn = dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", 0);
FallBackCompute(RNNStatefulCompute<cpu>, state_ptr, ctx, inputs, req, outputs);
dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", use_mkldnn_rnn);
}
#endif

NNVM_REGISTER_OP(RNN)
Expand Down Expand Up @@ -717,7 +731,7 @@ The definition of GRU here is slightly different from paper but compatible with
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx)
Expand Down
33 changes: 28 additions & 5 deletions src/operator/subgraph/build_subgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
*/
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <mxnet/op_attr_types.h>
#include <unordered_set>
#include <stack>
#include <queue>
Expand Down Expand Up @@ -105,6 +104,28 @@ void ResetNodeLabels(const nnvm::Graph& g,
subgraph_nodes->clear();
}

/*
* \brief Prepare NodeAttr for node. NodeAttr will be used in SubgraphSelectorV2.
*/
static const std::shared_ptr<NodeAttr> PrepareNodeAttr(const nnvm::Graph& g,
const BiDirectedNode& node) {
const auto& indexed_graph = g.indexed_graph();
if (g.HasAttr("dtype") && g.HasAttr("shape") && g.HasAttr("dispatch_mode")) {
const auto& vdtype = g.GetAttr<nnvm::DTypeVector>("dtype");
const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& dispatch_modes = g.GetAttr<mxnet::DispatchModeVector>("dispatch_mode");
auto ret = std::make_shared<NodeAttr>();
ret->dispatch_mode = dispatch_modes[indexed_graph.node_id(node.node)];
for (const auto& e : node.node->inputs) {
ret->ishape.emplace_back(vshape[indexed_graph.entry_id(e)]);
ret->itype.emplace_back(vdtype[indexed_graph.entry_id(e)]);
}
return ret;
} else {
return nullptr;
}
}

/*!
* \brief This function traverses the nodes in a computation graph from a starting
* node following the input edges and output edges, and marks all nodes that
Expand Down Expand Up @@ -153,7 +174,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector
CHECK_LT(nid, simple_nodes.size());
const bool select_input =
(snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
subgraph_selector->SelectInput(*cur_node, *snode);
subgraph_selector->SelectInput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
if (select_input) {
// e.node is a subgraph node
snode->label = label;
Expand All @@ -170,7 +191,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector
CHECK_LT(nid, simple_nodes.size());
const bool select_output =
(snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
subgraph_selector->SelectOutput(*cur_node, *snode);
subgraph_selector->SelectOutput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
if (select_output) {
// it->first is a subgraph node
snode->label = label;
Expand Down Expand Up @@ -325,14 +346,16 @@ void SelectSubgraphNodes(nnvm::Graph* g, SubgraphSelectorV2Ptr subgraph_selector
std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors,
const BiDirectedNode* node, const size_t snid, size_t* subgraph_id) {
const auto& indexed_graph = g->indexed_graph();

auto node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
};
if (simple_nodes[snid]->label == -1 && subgraph_selector->Select(*node)) {
if ((simple_nodes[snid]->label == -1) &&
subgraph_selector->Select(*node, PrepareNodeAttr(*g, *node))) {
// pre-select nodes that can be grouped in a subgraph
std::vector<BiDirectedNode*> preselected_nodes;
PreSelectSubgraphNodes(*g, subgraph_selector, *subgraph_id, snid, simple_nodes,
&preselected_nodes);
&preselected_nodes);

// filter out unqualified pre-selected nodes
std::vector<BiDirectedNode*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
Expand Down
10 changes: 5 additions & 5 deletions src/operator/subgraph/mkldnn/mkldnn_conv_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../../tensor/matrix_op-inl.h"
#include "../common.h"
#include "../subgraph_property.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -61,15 +61,15 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
disable_conv_sum_(dis_conv_sum),
quantize_(quantize) {}

bool Select(const nnvm::Node &n) override {
bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
if (n.op() && n.op()->name == "Convolution") {
const auto &param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
if (param.kernel.ndim() == 2) {
if (param.kernel.ndim() == 2 && SupportMKLDNNAttr(node_attr)) {
status_ = disable_all_ ? kSuccess : kStart;
matched_list_.clear();
matched_list_.push_back(&n);
return true;
}
}
}
return false;
}
Expand Down Expand Up @@ -161,7 +161,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
CHECK_GE(matched_list_.size(), 1);
auto new_selector = SgMKLDNNConvSelector(disable_all_, disable_conv_bn_, disable_conv_act_,
disable_conv_sum_, quantize_);
new_selector.Select(*matched_list_[0]);
new_selector.Select(*matched_list_[0], nullptr);
*this = new_selector;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "../../nn/fully_connected-inl.h"
#include "../../quantization/requantize-inl.h"
#include "../common.h"
#include "../subgraph_property.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {
Expand Down
8 changes: 4 additions & 4 deletions src/operator/subgraph/mkldnn/mkldnn_fc_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <string>
#include <vector>
#include "../common.h"
#include "../subgraph_property.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {
Expand All @@ -53,8 +53,8 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
public:
explicit SgMKLDNNFCSelector(const bool dis_fc_relu) : disable_fc_relu(dis_fc_relu) {}

bool Select(const nnvm::Node &n) override {
if (n.op() == Op::Get("FullyConnected")) {
bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
if (n.op() == Op::Get("FullyConnected") && SupportMKLDNNAttr(node_attr)) {
status = disable_fc_relu ? kSuccess : kStart;
matched_list.clear();
matched_list.push_back(&n);
Expand Down Expand Up @@ -119,7 +119,7 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
void Reset() override {
CHECK_GE(matched_list.size(), 1);
auto new_selector = SgMKLDNNFCSelector(disable_fc_relu);
new_selector.Select(*matched_list[0]);
new_selector.Select(*matched_list[0], nullptr);
*this = new_selector;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <string>
#include <vector>
#include "../common.h"
#include "../subgraph_property.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {
Expand Down
8 changes: 4 additions & 4 deletions src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_POST_QUANTIZE_PROPERTY_H_
#if MXNET_USE_MKLDNN == 1

#include <set>
#include <string>
#include <vector>
#include <set>
#include "../common.h"
#include "../subgraph_property.h"
#include "../../nn/mkldnn/mkldnn_convolution-inl.h"
#include "mkldnn_conv-inl.h"
#include "../../quantization/requantize-inl.h"
#include "../common.h"
#include "mkldnn_conv-inl.h"
#include "mkldnn_subgraph_base-inl.h"

namespace mxnet {
namespace op {
Expand Down
42 changes: 42 additions & 0 deletions src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
#if MXNET_USE_MKLDNN == 1

#include "../subgraph_property.h"

namespace mxnet {
namespace op {

static inline bool SupportMKLDNNAttr(const std::shared_ptr<NodeAttr>& node_attr) {
if (node_attr) {
int ndim = node_attr->ishape[0].ndim();
return (node_attr->dispatch_mode == DispatchMode::kFComputeEx) &&
(node_attr->itype[0] == mshadow::kFloat32) && (ndim == 1 || ndim == 2 || ndim == 4);
} else {
return true;
}
}

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
Loading

0 comments on commit 9147b38

Please sign in to comment.