From 88f59cd9ecac6eeed4d1f4a5086dc5741f7b2bf9 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Thu, 8 Aug 2019 13:05:19 +0800
Subject: [PATCH 1/7] Fix mkldnn subgraph with float64

---
 src/operator/nn/mkldnn/mkldnn_base-inl.h      |  17 ++-
 src/operator/nn/mkldnn/mkldnn_base.cc         |  23 +++-
 src/operator/rnn.cc                           |  16 ++-
 src/operator/subgraph/build_subgraph.cc       |  33 +++++-
 .../subgraph/mkldnn/mkldnn_conv_property.h    |  10 +-
 .../mkldnn/mkldnn_fc_post_quantize_property.h |   2 +-
 .../subgraph/mkldnn/mkldnn_fc_property.h      |   8 +-
 ...kldnn_post_quantize_align_scale_property.h |   2 +-
 .../mkldnn/mkldnn_post_quantize_property.h    |   8 +-
 .../mkldnn/mkldnn_subgraph_base-inl.h         |  42 +++++++
 src/operator/subgraph/subgraph_property.h     | 112 ++++++++++++------
 tests/python/mkl/test_subgraph.py             |  21 ++++
 tests/python/unittest/test_operator.py        |  32 +++++
 13 files changed, 264 insertions(+), 62 deletions(-)
 create mode 100644 src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 5db9d6e5defc..34f4a0b0b062 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -47,17 +47,18 @@
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
 
 #if MXNET_USE_MKLDNN == 1
+#include <algorithm>
 #include <iterator>
+#include <memory>
 #include <string>
 #include <unordered_map>
-#include <utility>
 #include <vector>
-#include <algorithm>
-#include <memory>
 #include "mkldnn.hpp"
+#include "mxnet/graph_attr_types.h"
 #include "mxnet/ndarray.h"
-#include "mxnet/resource.h"
 #include "mxnet/op_attr_types.h"
+#include "mxnet/resource.h"
 using namespace mkldnn;
 namespace mxnet {
@@ -132,6 +133,11 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
   return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
+static inline bool SupportMKLDNNRNN(const NDArray &input) {
+  int ndim = input.shape().ndim();
+  return (input.dtype() == mshadow::kFloat32) && (ndim == 3);
+}
+
 static inline bool SupportMKLDNNQuantize(int dtype) {
   return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
          dtype == mshadow::kUint8;
@@ -569,7 +575,8 @@ class MKLDNNMemory {
   }
 };
 
-void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
+template <typename Compute, typename AttrState>
+void FallBackCompute(Compute fn, const AttrState &attrs,
                      const OpContext &ctx,
                      const std::vector<NDArray> &inputs,
                      const std::vector<OpReqType> &req,
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index a13337b122c3..862947eb726a 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -420,7 +420,8 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p
   return mkldnn::memory::primitive_desc(data_md, pd.get_engine());
 }
 
-void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
+template <typename Compute, typename AttrState>
+void FallBackCompute(Compute fn, const AttrState &attrs_states,
                      const OpContext &ctx,
                      const std::vector<NDArray> &inputs,
                      const std::vector<OpReqType> &req,
@@ -461,7 +462,7 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
     out_blobs[i] = output.data();
   }
 
-  fn(attrs, ctx, in_blobs, req, out_blobs);
+  fn(attrs_states, ctx, in_blobs, req, out_blobs);
   for (size_t i = 0; i < out_blobs.size(); i++) {
     if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
       mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
@@ -518,6 +519,24 @@ static bool SimilarArray(const mxnet::NDArray &arr1, const mxnet::NDArray &arr2,
   return success.load();
 }
 
+template void FallBackCompute(void (*)(nnvm::NodeAttrs const &, OpContext const &,
+                                       std::vector<TBlob, std::allocator<TBlob> > const &,
+                                       std::vector<OpReqType, std::allocator<OpReqType> > const &,
+                                       std::vector<TBlob, std::allocator<TBlob> > const &),
+                              nnvm::NodeAttrs const &, OpContext const &,
+                              std::vector<NDArray, std::allocator<NDArray> > const &,
+                              std::vector<OpReqType, std::allocator<OpReqType> > const &,
+                              std::vector<NDArray, std::allocator<NDArray> > const &);
+
+template void FallBackCompute(void (*)(OpStatePtr const &, OpContext const &,
+                                       std::vector<TBlob, std::allocator<TBlob> > const &,
+                                       std::vector<OpReqType, std::allocator<OpReqType> > const &,
+                                       std::vector<TBlob, std::allocator<TBlob> > const &),
+                              OpStatePtr const &, OpContext const &,
+                              std::vector<NDArray, std::allocator<NDArray> > const &,
+                              std::vector<OpReqType, std::allocator<OpReqType> > const &,
+                              std::vector<NDArray, std::allocator<NDArray> > const &);
+
 void OpCheck::Init(const std::vector<mxnet::NDArray> &inputs_,
                    const std::vector<mxnet::NDArray> &outputs_) {
   auto ctx = inputs_[0].ctx();
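A note on the explicit instantiations added above: FallBackCompute is now a function template, parameterized over both the kernel signature (stateless FCompute vs. stateful compute) and the attribute/state argument (nnvm::NodeAttrs vs. OpStatePtr), but its definition stays in mkldnn_base.cc, so that translation unit must explicitly instantiate every combination callers link against. A minimal, self-contained sketch of the idiom (hypothetical names, not MXNet code):

#include <vector>

struct Attrs { int id; };     // stands in for nnvm::NodeAttrs
struct StatePtr { int id; };  // stands in for OpStatePtr

// The definition lives in this .cc-style file; without the explicit
// instantiations below, other translation units would get link errors.
template <typename Fn, typename State>
void Dispatch(Fn fn, const State &state, const std::vector<int> &data) {
  fn(state, data);  // one body serves both the stateless and stateful cases
}

void RunStateless(const Attrs &, const std::vector<int> &) {}
void RunStateful(const StatePtr &, const std::vector<int> &) {}

// Explicit instantiations, mirroring the two FallBackCompute variants.
template void Dispatch(void (*)(const Attrs &, const std::vector<int> &),
                       const Attrs &, const std::vector<int> &);
template void Dispatch(void (*)(const StatePtr &, const std::vector<int> &),
                       const StatePtr &, const std::vector<int> &);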
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 86fb1c7d1ec6..5eb274abfa18 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -631,6 +631,20 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
     });
   });
 }
+
+static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+  if (SupportMKLDNNRNN(inputs[0])) {
+    RNNStatefulComputeCPU(state_ptr, ctx, inputs, req, outputs);
+    return;
+  }
+  int use_mkldnn_rnn = dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
+  dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", 0);
+  FallBackCompute(RNNStatefulCompute<cpu>, state_ptr, ctx, inputs, req, outputs);
+  dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", use_mkldnn_rnn);
+}
 #endif
 
 NNVM_REGISTER_OP(RNN)
@@ -717,7 +731,7 @@ The definition of GRU here is slightly different from paper but compatible with
 .set_attr<FStatefulCompute>("FStatefulCompute", RNNStatefulCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx", RNNStatefulComputeCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx", RNNStatefulComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
 .set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx)
diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc
index 0420e2248077..8c6fb23ae3cf 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -24,7 +24,6 @@
  */
 #include <nnvm/graph.h>
 #include <nnvm/pass.h>
-#include <unordered_set>
 #include <stack>
 #include <queue>
 #include <string>
@@ -105,6 +104,28 @@ void ResetNodeLabels(const nnvm::Graph& g,
   subgraph_nodes->clear();
 }
 
+/*
+ * \brief Prepare NodeAttr for node. NodeAttr will be used in SubgraphSelectorV2.
+ */
+static const std::shared_ptr<NodeAttr> PrepareNodeAttr(const nnvm::Graph& g,
+                                                       const BiDirectedNode& node) {
+  const auto& indexed_graph = g.indexed_graph();
+  if (g.HasAttr("dtype") && g.HasAttr("shape") && g.HasAttr("dispatch_mode")) {
+    const auto& vdtype = g.GetAttr<nnvm::DTypeVector>("dtype");
+    const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
+    const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+    auto ret = std::make_shared<NodeAttr>();
+    ret->dispatch_mode = dispatch_modes[indexed_graph.node_id(node.node)];
+    for (const auto& e : node.node->inputs) {
+      ret->ishape.emplace_back(vshape[indexed_graph.entry_id(e)]);
+      ret->itype.emplace_back(vdtype[indexed_graph.entry_id(e)]);
+    }
+    return ret;
+  } else {
+    return nullptr;
+  }
+}
+
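Concretely, PrepareNodeAttr packages the graph-level inference results (dtype, shape, dispatch mode) into one per-node record that selectors can inspect, returning nullptr when inference has not run yet. A toy, self-contained version of the same bookkeeping (simplified stand-in types, not the real nnvm::Graph API):

#include <memory>
#include <vector>

struct NodeAttrLite {
  int dispatch_mode;
  std::vector<std::vector<int>> ishape;  // one shape per input
  std::vector<int> itype;                // one dtype per input
};

// Gather this node's input dtypes/shapes out of whole-graph vectors that are
// indexed by entry id, or return nullptr when nothing has been inferred.
std::shared_ptr<NodeAttrLite> Prepare(bool inferred, int node_dispatch_mode,
                                      const std::vector<int> &input_entry_ids,
                                      const std::vector<std::vector<int>> &vshape,
                                      const std::vector<int> &vdtype) {
  if (!inferred) return nullptr;  // selectors then stay permissive
  auto ret = std::make_shared<NodeAttrLite>();
  ret->dispatch_mode = node_dispatch_mode;
  for (int eid : input_entry_ids) {
    ret->ishape.push_back(vshape[eid]);
    ret->itype.push_back(vdtype[eid]);
  }
  return ret;
}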
 /*!
  * \brief This function traverses the nodes in a computation graph from a starting
  * node following the input edges and output edges, and marks all nodes that
@@ -153,7 +174,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector
       CHECK_LT(nid, simple_nodes.size());
       const bool select_input =
           (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
-          subgraph_selector->SelectInput(*cur_node, *snode);
+          subgraph_selector->SelectInput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
       if (select_input) {
         // e.node is a subgraph node
         snode->label = label;
@@ -170,7 +191,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector
       CHECK_LT(nid, simple_nodes.size());
       const bool select_output =
           (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
-          subgraph_selector->SelectOutput(*cur_node, *snode);
+          subgraph_selector->SelectOutput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
       if (select_output) {
         // it->first is a subgraph node
         snode->label = label;
@@ -325,14 +346,16 @@ void SelectSubgraphNodes(nnvm::Graph* g, SubgraphSelectorV2Ptr subgraph_selector
                          std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors,
                          const BiDirectedNode* node, const size_t snid, size_t* subgraph_id) {
   const auto& indexed_graph = g->indexed_graph();
+
   auto node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
     return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
   };
-  if (simple_nodes[snid]->label == -1 && subgraph_selector->Select(*node)) {
+  if ((simple_nodes[snid]->label == -1) &&
+      subgraph_selector->Select(*node, PrepareNodeAttr(*g, *node))) {
     // pre-select nodes that can be grouped in a subgraph
     std::vector<BiDirectedNode*> preselected_nodes;
     PreSelectSubgraphNodes(*g, subgraph_selector, *subgraph_id, snid, simple_nodes,
-                          &preselected_nodes);
+                           &preselected_nodes);
     // filter out unqualified pre-selected nodes
     std::vector<BiDirectedNode*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
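The traversal these hooks feed is easier to see stripped of the pass plumbing: seed a subgraph where Select says yes, grow the candidate set through input/output edges the selector accepts, then hand the candidates to Filter. A toy, self-contained rendering (no labels, exclusion sets, or subgraph ids):

#include <functional>
#include <queue>
#include <unordered_set>
#include <vector>

struct Node { std::vector<Node*> inputs; std::vector<Node*> outputs; };

// Seed on select(), then grow through accepted edges, visiting each node once.
std::vector<Node*> GrowSubgraph(Node *seed,
                                const std::function<bool(Node*)> &select,
                                const std::function<bool(Node*, Node*)> &select_edge) {
  std::vector<Node*> picked;
  if (!select(seed)) return picked;
  std::unordered_set<Node*> seen{seed};
  std::queue<Node*> work;
  work.push(seed);
  picked.push_back(seed);
  while (!work.empty()) {
    Node *cur = work.front();
    work.pop();
    auto visit = [&](Node *nbr) {
      if (!seen.count(nbr) && select_edge(cur, nbr)) {
        seen.insert(nbr);
        picked.push_back(nbr);
        work.push(nbr);
      }
    };
    for (Node *in : cur->inputs) visit(in);
    for (Node *out : cur->outputs) visit(out);
  }
  return picked;  // the real pass then runs Filter() over this candidate set
}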
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
index 42ea9ea67fcf..40b3f7c1d010 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
@@ -28,7 +28,7 @@
 #include "../../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../../tensor/matrix_op-inl.h"
 #include "../common.h"
-#include "../subgraph_property.h"
+#include "mkldnn_subgraph_base-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -61,15 +61,15 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
         disable_conv_sum_(dis_conv_sum),
         quantize_(quantize) {}
 
-  bool Select(const nnvm::Node &n) override {
+  bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
     if (n.op() && n.op()->name == "Convolution") {
       const auto &param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
-      if (param.kernel.ndim() == 2) {
+      if (param.kernel.ndim() == 2 && SupportMKLDNNAttr(node_attr)) {
         status_ = disable_all_ ? kSuccess : kStart;
         matched_list_.clear();
         matched_list_.push_back(&n);
         return true;
-      }
+      }
     }
     return false;
   }
@@ -161,7 +161,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
     CHECK_GE(matched_list_.size(), 1);
     auto new_selector = SgMKLDNNConvSelector(disable_all_, disable_conv_bn_,
                                              disable_conv_act_, disable_conv_sum_, quantize_);
-    new_selector.Select(*matched_list_[0]);
+    new_selector.Select(*matched_list_[0], nullptr);
     *this = new_selector;
   }
 };
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h
index 8b5c08802986..f4f252bc92e9 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h
@@ -33,7 +33,7 @@
 #include "../../nn/fully_connected-inl.h"
 #include "../../quantization/requantize-inl.h"
 #include "../common.h"
-#include "../subgraph_property.h"
+#include "mkldnn_subgraph_base-inl.h"
 
 namespace mxnet {
 namespace op {
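Two details of the selector changes above deserve a gloss: each fusion selector is a tiny state machine (roughly kFail/kStart/kSuccess, tracking how far a fusion pattern has matched), and Reset() re-seeds it by calling the two-argument Select with nullptr — the support check treats a missing NodeAttr as "accept", and the seed node already passed the real check when it was first selected. A compact, self-contained sketch of that shape (toy types, not the MXNet classes):

#include <memory>
#include <vector>

struct NodeAttr { int itype0; };  // simplified stand-in; 0 ~ kFloat32

class FusionSelector {
 public:
  enum Status { kFail, kStart, kSuccess };

  bool Select(const int &node, const std::shared_ptr<NodeAttr> &attr) {
    // A nullptr attribute means "no inferred info": accept, as SupportMKLDNNAttr does.
    if (attr && attr->itype0 != 0) return false;
    status_ = kStart;
    matched_.clear();
    matched_.push_back(&node);
    return true;
  }

  void Reset() {
    // Re-seed from the first matched node (assumed non-empty; real code CHECKs).
    // nullptr skips the dtype re-check the node already passed.
    FusionSelector fresh;
    fresh.Select(*matched_[0], nullptr);
    *this = fresh;
  }

 private:
  Status status_ = kFail;
  std::vector<const int*> matched_;
};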
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
index 17d30d272ad1..6dcd114d9ec4 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
@@ -32,7 +32,7 @@
 #include <vector>
 #include "../common.h"
 #include "../../tensor/matrix_op-inl.h"
-#include "../subgraph_property.h"
+#include "mkldnn_subgraph_base-inl.h"
 #include "mkldnn_fc-inl.h"
 
 namespace mxnet {
@@ -58,8 +58,8 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
         disable_fc_eltwise_(dis_fc_eltwise),
         quantized_(quantized) {}
 
-  bool Select(const nnvm::Node &n) override {
-    if (n.op() == Op::Get("FullyConnected")) {
+  bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (n.op() == Op::Get("FullyConnected") && SupportMKLDNNAttr(node_attr)) {
       status_ = disable_fc_eltwise_ ? kSuccess : kStart;
       matched_list_.clear();
       matched_list_.push_back(&n);
       return true;
@@ -150,7 +150,7 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
   void Reset() override {
     CHECK_GE(matched_list_.size(), 1);
     auto new_selector = SgMKLDNNFCSelector(disable_fc_eltwise_, quantized_);
-    new_selector.Select(*matched_list_[0]);
+    new_selector.Select(*matched_list_[0], nullptr);
     *this = new_selector;
   }
 };
diff --git a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h
index 5c5037e7a116..c05c2a8e4a6a 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h
@@ -24,7 +24,7 @@
 #include <string>
 #include <vector>
 #include "../common.h"
-#include "../subgraph_property.h"
+#include "mkldnn_subgraph_base-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h
index e78b8d1bfa42..38b08968d8a5 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h
@@ -20,14 +20,14 @@
 #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_POST_QUANTIZE_PROPERTY_H_
 #if MXNET_USE_MKLDNN == 1
 
+#include <memory>
 #include <string>
 #include <vector>
-#include <memory>
-#include "../common.h"
-#include "../subgraph_property.h"
 #include "../../nn/mkldnn/mkldnn_convolution-inl.h"
-#include "mkldnn_conv-inl.h"
 #include "../../quantization/requantize-inl.h"
+#include "../common.h"
+#include "mkldnn_conv-inl.h"
+#include "mkldnn_subgraph_base-inl.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
new file mode 100644
index 000000000000..4c8a7ab285b3
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include "../subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+static inline bool SupportMKLDNNAttr(const std::shared_ptr<NodeAttr>& node_attr) {
+  if (node_attr) {
+    int ndim = node_attr->ishape[0].ndim();
+    return (node_attr->dispatch_mode == DispatchMode::kFComputeEx) &&
+           (node_attr->itype[0] == mshadow::kFloat32) &&
+           (ndim == 1 || ndim == 2 || ndim == 4);
+  } else {
+    return true;
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
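The effect of this check on the float64 case is easy to demonstrate in isolation: once attributes have been inferred, a first input that is not float32 (or has an unsupported rank, or will not dispatch to the MKLDNN FComputeEx path) keeps the node out of the fused subgraph, while a missing NodeAttr stays permissive. A runnable toy equivalent (stand-in types; dtype tag 0 plays the role of kFloat32):

#include <iostream>
#include <memory>
#include <vector>

enum class Dispatch { kFCompute, kFComputeEx };
struct AttrLite {
  Dispatch dispatch_mode;
  std::vector<int> indim;  // rank of each input
  std::vector<int> itype;  // dtype tag of each input; 0 ~ float32
};

bool SupportAttrLite(const std::shared_ptr<AttrLite> &a) {
  if (!a) return true;  // nothing inferred yet: accept
  int ndim = a->indim[0];
  return a->dispatch_mode == Dispatch::kFComputeEx && a->itype[0] == 0 &&
         (ndim == 1 || ndim == 2 || ndim == 4);
}

int main() {
  auto fp64 = std::make_shared<AttrLite>(AttrLite{Dispatch::kFComputeEx, {2}, {1}});
  std::cout << SupportAttrLite(nullptr) << '\n';  // 1: no attributes, accept
  std::cout << SupportAttrLite(fp64) << '\n';     // 0: float64 input, reject
}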
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
index b8b125fdb5de..3d3dd692834f 100644
--- a/src/operator/subgraph/subgraph_property.h
+++ b/src/operator/subgraph/subgraph_property.h
@@ -20,12 +20,14 @@
 #ifndef MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
 #define MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
 
-#include <nnvm/node.h>
 #include <dmlc/base.h>
 #include <dmlc/thread_local.h>
+#include <nnvm/node.h>
+#include <memory>
+#include <string>
+#include <tuple>
 #include <unordered_map>
 #include <vector>
-#include <string>
 
 namespace mxnet {
 namespace op {
@@ -53,6 +55,12 @@ struct BiDirectedNode {
   std::unordered_map<nnvm::Node*, std::vector<nnvm::NodeEntry*>> outputs;
 };  // struct BiDirectedNode
 
+struct NodeAttr {
+  DispatchMode dispatch_mode;
+  ShapeVector ishape;
+  std::vector<int> itype;
+};
+
 /*
  * This provides criteria for the graph partitioning algorithm to select
  * nodes to subgraphs.
@@ -80,21 +88,41 @@ class SubgraphSelector {
   /*!
    * \brief Determines if to search for other nodes to form a subgraph from the seed_node.
    */
-  virtual bool Select(const nnvm::Node &seed_node) = 0;
+  virtual bool Select(const nnvm::Node& seed_node) {
+    LOG(FATAL) << "No Select is implemented.";
+    return false;
+  }
+  virtual bool Select(const nnvm::Node& seed_node, const std::shared_ptr<NodeAttr>& node_attr) {
+    return Select(seed_node);
+  }
   /*!
    * \brief Determines if to select input_node when traverse to the cur_node.
    * \param cur_node the node for determining whether its input_node should be selected
    * \param input_node the input node of the cur_node
    * \return true if input_node is selected
    */
-  virtual bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) = 0;
+  virtual bool SelectInput(const nnvm::Node& cur_node, const nnvm::Node& input_node) {
+    LOG(FATAL) << "No SelectInput is implemented.";
+    return false;
+  }
+  virtual bool SelectInput(const nnvm::Node& cur_node, const nnvm::Node& input_node,
+                           const std::shared_ptr<NodeAttr>& input_node_attr) {
+    return SelectInput(cur_node, input_node);
+  }
   /*!
    * \brief Determines if to select output_node when traverse to the cur_node.
    * \param cur_node the node for determining whether its output_node should be selected
    * \param output_node the output node of the cur_node
    * \return true if output_node is selected
    */
-  virtual bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) = 0;
+  virtual bool SelectOutput(const nnvm::Node& cur_node, const nnvm::Node& output_node) {
+    LOG(FATAL) << "No SelectOutput is implemented.";
+    return false;
+  }
+  virtual bool SelectOutput(const nnvm::Node& cur_node, const nnvm::Node& output_node,
+                            const std::shared_ptr<NodeAttr>& output_node_attr) {
+    return SelectOutput(cur_node, output_node);
+  }
   /*!
    * \brief Post processes pre-selected subgraph nodes. Return a list of nodes that
    * users want to keep in subgraph(s).
@@ -119,31 +147,48 @@ class SubgraphSelectorV2 {
   /*!
    * \brief Determines if to search for other nodes to form a subgraph from the seed_node.
    */
-  virtual bool Select(const BiDirectedNode& seed_node) = 0;
+  virtual bool Select(const BiDirectedNode& seed_node) {
+    LOG(FATAL) << "No Select is implemented.";
+    return false;
+  }
+  virtual bool Select(const BiDirectedNode& seed_node, const std::shared_ptr<NodeAttr>& node_attr) {
+    return Select(seed_node);
+  };
   /*!
    * \brief Determines if to select input_node when traverse to the cur_node.
    * \param cur_node the node for determining whether its input_node should be selected
    * \param input_node the input node of the cur_node
    * \return true if input_node is selected
   */
-  virtual bool SelectInput(const BiDirectedNode& cur_node,
-                           const BiDirectedNode& input_node) = 0;
+  virtual bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode& input_node) {
+    LOG(FATAL) << "No SelectInput is implemented.";
+    return false;
+  }
+  virtual bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode& input_node,
+                           const std::shared_ptr<NodeAttr>& input_node_attr) {
+    return SelectInput(cur_node, input_node);
+  }
   /*!
    * \brief Determines if to select output_node when traverse to the cur_node.
    * \param cur_node the node for determining whether its output_node should be selected
    * \param output_node the output node of the cur_node
    * \return true if output_node is selected
    */
-  virtual bool SelectOutput(const BiDirectedNode& cur_node,
-                            const BiDirectedNode& output_node) = 0;
+  virtual bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) {
+    LOG(FATAL) << "No SelectOutput is implemented.";
+    return false;
+  }
+  virtual bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node,
+                            const std::shared_ptr<NodeAttr>& output_node_attr) {
+    return SelectOutput(cur_node, output_node);
+  }
   /*!
    * \brief Post processes pre-selected subgraph nodes. Return a list of nodes that
    * users want to keep in subgraph(s).
    * \param candidates re-selected subgraph nodes to filt
    * \return a list of nodes to keep
    */
-  virtual std::vector<BiDirectedNode*> Filter(
-      const std::vector<BiDirectedNode*>& candidates) {
+  virtual std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) {
     return candidates;
   }
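The interface change above is a standard backward-compatibility move: the old pure-virtual hooks become plain virtuals with failing default bodies, and the new attribute-taking overloads default to delegating to them, so selectors written against either signature keep compiling. The one C++ wrinkle to remember is overload name hiding when a subclass overrides only one of the pair — virtual dispatch through the base pointer is unaffected, but direct calls on the derived type need a using-declaration. A self-contained sketch:

#include <memory>

struct Attr {};

class Selector {
 public:
  virtual ~Selector() = default;
  // Legacy hook: pre-NodeAttr selectors override this one.
  virtual bool Select(const int &node) { return false; }  // the real code LOG(FATAL)s
  // New hook: defaults to the legacy hook, so old subclasses keep working.
  virtual bool Select(const int &node, const std::shared_ptr<Attr> &attr) {
    return Select(node);
  }
};

class LegacySelector : public Selector {
 public:
  using Selector::Select;  // re-expose the two-argument overload for direct calls
  bool Select(const int &node) override { return node > 0; }
};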
@@ -162,22 +207,22 @@ class SubgraphSelectorV2Bridge : public SubgraphSelectorV2 {
 
   virtual ~SubgraphSelectorV2Bridge() {}
 
-  bool Select(const BiDirectedNode& seed_node) override {
-    return ss_ptr_->Select(*seed_node.node);
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    return ss_ptr_->Select(*seed_node.node, node_attr);
   }
 
-  bool SelectInput(const BiDirectedNode& cur_node,
-                   const BiDirectedNode& input_node) override {
-    return ss_ptr_->SelectInput(*cur_node.node, *input_node.node);
+  bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode& input_node,
+                   const std::shared_ptr<NodeAttr>& node_attr) override {
+    return ss_ptr_->SelectInput(*cur_node.node, *input_node.node, node_attr);
   }
 
-  bool SelectOutput(const BiDirectedNode& cur_node,
-                    const BiDirectedNode& output_node) override {
-    return ss_ptr_->SelectOutput(*cur_node.node, *output_node.node);
+  bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node,
+                    const std::shared_ptr<NodeAttr>& node_attr) override {
+    return ss_ptr_->SelectOutput(*cur_node.node, *output_node.node, node_attr);
   }
 
-  std::vector<BiDirectedNode*> Filter(
-      const std::vector<BiDirectedNode*>& candidates) override {
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
     std::unordered_map<nnvm::Node*, BiDirectedNode*> node_2_snode_map;
     std::vector<nnvm::Node*> n_candidates;
     for (auto i : candidates) {
@@ -276,7 +321,7 @@ class SubgraphProperty {
    * \param subgraph_id subgraph id
    */
   virtual void AdjustSubgraphNode(const std::vector<nnvm::Node*>& subgraph_nodes,
-                                  const SubgraphSelectorV2Ptr &subgraph_selector,
+                                  const SubgraphSelectorV2Ptr& subgraph_selector,
                                   const int subgraph_id = 0) const {
     CHECK_EQ(GetPropertyType(), kAdjust);
     LOG(FATAL) << "Not implement AdjustSubgraphNode() for this subgraph property.";
@@ -309,7 +354,7 @@ class SubgraphProperty {
   /*!
    * \brief Set an attr with name in the attr map.
   */
-  template<typename T>
+  template <typename T>
   SubgraphProperty& SetAttr(const std::string& name, const T& value) {
     attrs_[name] = std::make_shared<dmlc::any>(value);
     return *this;
@@ -317,7 +362,7 @@ class SubgraphProperty {
   /*!
    * \brief Get the attr with the name.
   */
-  template<typename T>
+  template <typename T>
   const T& GetAttr(const std::string& name) const {
     auto it = attrs_.find(name);
     CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
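SetAttr/GetAttr above are a string-keyed, type-erased property bag built on dmlc::any; returning *this from SetAttr is what enables the chained set_attr(...) style used by the entry classes below. A rough std::any equivalent of the mechanics (C++17 sketch, not the dmlc API):

#include <any>
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

class AttrMap {
 public:
  template <typename T>
  AttrMap &SetAttr(const std::string &name, const T &value) {
    attrs_[name] = std::make_shared<std::any>(value);
    return *this;  // enables map.SetAttr(...).SetAttr(...)
  }
  template <typename T>
  const T &GetAttr(const std::string &name) const {
    auto it = attrs_.find(name);
    assert(it != attrs_.end());           // the real code CHECKs with a message
    auto *p = std::any_cast<T>(it->second.get());
    assert(p != nullptr);                 // a wrong T would return nullptr here
    return *p;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<std::any>> attrs_;
};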
@@ -355,7 +400,7 @@ class SubgraphPropertyEntry {
  public:
   explicit SubgraphPropertyEntry(std::shared_ptr<SubgraphProperty> entry) : entry_(entry) {}
 
-  template<typename T>
+  template <typename T>
   SubgraphPropertyEntry set_attr(const std::string& name, const T value) const {
     if (entry_) entry_->SetAttr(name, value);
     return *this;
@@ -371,7 +416,7 @@ class SubgraphBackend {
   /*!
    * \brief Set an attr with name in the attr map.
   */
-  template<typename T>
+  template <typename T>
   SubgraphBackend& SetAttr(const std::string& name, const T& value) {
     attrs_[name] = std::make_shared<dmlc::any>(value);
     return *this;
@@ -379,7 +424,7 @@ class SubgraphBackend {
   /*!
    * \brief Get the attr with the name.
   */
-  template<typename T>
+  template <typename T>
   const T& GetAttr(const std::string& name) const {
     auto it = attrs_.find(name);
     CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
@@ -427,7 +472,7 @@ class SubgraphBackendEntry {
  public:
   explicit SubgraphBackendEntry(SubgraphBackendPtr entry) : entry_(entry) {}
 
-  template<typename T>
+  template <typename T>
   SubgraphBackendEntry set_attr(const std::string& name, const T value) const {
     entry_->SetAttr(name, value);
     return *this;
@@ -448,8 +493,8 @@ class SubgraphBackendRegistry {
 
   SubgraphBackendPtr& GetSubgraphBackend(const std::string& name) {
     auto it = backend_map_.find(name);
-    CHECK(it != backend_map_.end()) << "SubgraphProperty " << name
-                                    << " is not found in SubgraphBackendRegistry";
+    CHECK(it != backend_map_.end())
+        << "SubgraphProperty " << name << " is not found in SubgraphBackendRegistry";
     return it->second;
   }
 
@@ -481,7 +526,7 @@ class SubgraphBackendRegistry {
 // This set is only used for the testing purpose.
 // key: property name, value: op name set
 typedef dmlc::ThreadLocalStore<std::unordered_map<std::string, std::unordered_set<std::string>>>
-        SubgraphPropertyOpNameSet;
+    SubgraphPropertyOpNameSet;
 
 #define DECLARE_PROPERTY_EX(NAME, SubgraphPropertyType, X) \
   static const DMLC_ATTRIBUTE_UNUSED auto __make_##SubgraphPropertyType##_##Name##_##X##__
@@ -492,8 +537,7 @@ typedef dmlc::ThreadLocalStore<std::unordered_map<std::string, std::unordered_set<std::string>>>
       SubgraphBackendRegistry::Get()->__REGISTER_PROPERTY__(#Name, &SubgraphPropertyType::Create)
 
-#define DECLARE_BACKEND(Name) \
-  static const DMLC_ATTRIBUTE_UNUSED auto __make_##Name##__
+#define DECLARE_BACKEND(Name) static const DMLC_ATTRIBUTE_UNUSED auto __make_##Name##__
 
 #define MXNET_REGISTER_SUBGRAPH_BACKEND(Name) \
   DECLARE_BACKEND(Name) = SubgraphBackendRegistry::Get()->__REGISTER_BACKEND__(#Name)
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index eba77f2c8710..bfdf6fac7dc2 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -845,6 +845,27 @@ def test_neg_fc_relu():
     syms, attrs, excluded_attrs = neg_fc_relu(no_bias, dshape, flatten)
     check_neg_fusion(syms, attrs, excluded_attrs, dshape, name='fc')
 
+def test_float64_fallback():
+  sym = mx.sym.FullyConnected(
+    mx.sym.Variable('in'),
+    mx.sym.Variable('w'),
+    mx.sym.Variable('b'),
+    num_hidden=2
+  )
+
+  dtype = 'float64'
+  ex = sym.bind(mx.cpu(),
+                {
+                  'in': mx.nd.array([[2, 3, 4]], dtype=dtype),
+                  'w': mx.nd.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+                  'b': mx.nd.array([7, 8], dtype=dtype)
+                },
+                args_grad=None,
+                grad_req='write'
+  )
+  ex.forward()
+  ex.outputs[0].wait_to_read()
+
 if __name__ == "__main__":
   import nose
   nose.runmodule()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ceee51a3e503..9dc5f0c21c38 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -271,6 +271,38 @@ def test_rnnrelu_dropout():
     out = exe.forward(is_train=True)
     out[0].wait_to_read()
 
+def test_RNN_float64():
+    sym = mx.sym.RNN(
+        mx.sym.Variable('in'),
+        mx.sym.Variable('par'),
+        mx.sym.Variable('s'),
+        state_size = (2),
+        num_layers = 1,
+        mode = 'rnn_tanh'
+    )
+
+    dtype = 'float64'
+    explicit_grad = {
+        'in': mx.nd.ones([2, 1, 2], dtype=dtype),
+        'par': mx.nd.ones([12], dtype=dtype),
+        's': mx.nd.ones([1, 1, 2], dtype=dtype)
+    }
+
+    args_grad = explicit_grad
+    grad_req = 'write'
+
+    ex = sym.bind(mx.cpu(),
+                  {
+                      'in': mx.nd.ones([2, 1, 2], dtype=dtype),
+                      'par': mx.nd.ones([12], dtype=dtype),
+                      's': mx.nd.ones([1, 1, 2], dtype=dtype)
+                  },
+                  args_grad = args_grad,
+                  grad_req = grad_req
+    )
+    ex.forward()
+    ex.outputs[0].wait_to_read()
+
 def np_softmax(x, axis=-1, temperature=1.0):
     x = x - np.max(x, axis=axis, keepdims=True)
     x = np.exp(x/temperature)
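Taken together, patch 1 boils down to one decision per operator call: run the MKLDNN kernel when every input is eligible (float32, supported rank), otherwise densify the arrays and fall back to the reference kernel. A bird's-eye toy sketch of that decision (stand-in types, not the real NDArray/TBlob machinery):

#include <vector>

struct Tensor {
  int dtype;           // 0 ~ float32, 1 ~ float64
  bool opaque_layout;  // true when held in an MKLDNN-specific format
};

Tensor ToDefaultLayout(const Tensor &t) { return {t.dtype, false}; }

void RunOp(std::vector<Tensor> *io) {
  bool supported = true;
  for (const Tensor &t : *io)
    if (t.dtype != 0) supported = false;
  if (supported) {
    // fast path: the MKLDNN kernel consumes opaque layouts directly
  } else {
    for (Tensor &t : *io) t = ToDefaultLayout(t);  // densify, as FallBackCompute does
    // ...then invoke the plain CPU kernel on the dense data
  }
}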
From bc4d3b5b925d07374754c13f844c67daaf13edfa Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Mon, 12 Aug 2019 16:45:50 +0800
Subject: [PATCH 2/7] Fix ci

Change-Id: I0bb4e8d7a0a534aa661601887cc633cb9c4fcadf
---
 src/operator/subgraph/subgraph_property.h | 2 +-
 src/operator/tensor/matrix_op-inl.h       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
index 3d3dd692834f..0e386eb0ec2a 100644
--- a/src/operator/subgraph/subgraph_property.h
+++ b/src/operator/subgraph/subgraph_property.h
@@ -153,7 +153,7 @@ class SubgraphSelectorV2 {
   }
   virtual bool Select(const BiDirectedNode& seed_node, const std::shared_ptr<NodeAttr>& node_attr) {
     return Select(seed_node);
-  };
+  }
   /*!
    * \brief Determines if to select input_node when traverse to the cur_node.
    * \param cur_node the node for determining whether its input_node should be selected
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 611dd7287206..4319a3b1a93e 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -2635,7 +2635,7 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape ishape = in_attrs->at(split_enum::kData);
   if (!mxnet::ndim_is_known(dshape)) return false;
   if (param.axis >= 0) {
-    CHECK_LT(static_cast<size_t>(param.axis), dshape.ndim());
+    CHECK_LT(param.axis, dshape.ndim());
  } else {
     CHECK_LT(param.axis + dshape.ndim(), dshape.ndim());
   }
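The Split change above simplifies the axis bounds check, presumably so the comparison stays between two signed ints (TShape::ndim() returns int in this codebase, and the surrounding commit is a CI fix). For reference, the usual normalization of a possibly-negative axis looks like this (simplified sketch with asserts, not the MXNet helper):

#include <cassert>

// A negative axis counts from the end, so both branches must land in [0, ndim).
int NormalizeAxis(int axis, int ndim) {
  if (axis >= 0) {
    assert(axis < ndim);
    return axis;
  }
  assert(axis + ndim >= 0);  // i.e. axis >= -ndim
  return axis + ndim;
}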
From 5ea26ed3065358a69468d2e61d1d897e1f57c2c6 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Tue, 13 Aug 2019 09:55:47 +0800
Subject: [PATCH 3/7] Fix test

Change-Id: I96c529abe7adb6def90a22f03b3432263ef12fda
---
 tests/python/unittest/test_operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 9dc5f0c21c38..d6dd3374e670 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -291,7 +291,7 @@ def test_RNN_float64():
     args_grad = explicit_grad
     grad_req = 'write'
 
-    ex = sym.bind(mx.cpu(),
+    ex = sym.bind(default_context(),
                   {
                       'in': mx.nd.ones([2, 1, 2], dtype=dtype),
                       'par': mx.nd.ones([12], dtype=dtype),

From 6e912e62d6a1bd7f9ca63401e76dd305191c552b Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Thu, 29 Aug 2019 10:02:19 +0800
Subject: [PATCH 4/7] Update dmlc-core

Change-Id: I472fb7bbffc16ed8c36494ab49838b08c59b2f12
---
 .gitmodules        | 2 +-
 3rdparty/dmlc-core | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index 90ef157f0eec..205d01a3827d 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
 [submodule "3rdparty/dmlc-core"]
   path = 3rdparty/dmlc-core
-  url = https://github.com/dmlc/dmlc-core.git
+  url = https://github.com/ZhennanQin/dmlc-core.git
 [submodule "3rdparty/ps-lite"]
   path = 3rdparty/ps-lite
   url = https://github.com/dmlc/ps-lite
diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index f1ff6cc117f4..a1dbe9d0aa49 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit f1ff6cc117f4e95169a9f62be549c8fe3e15c20f
+Subproject commit a1dbe9d0aa49c28ed5679771fe9472515c0e4168

From 48a6ef3e2f9d1f6d0c62fde5a9d73b5e9ae4fa8f Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Mon, 16 Sep 2019 14:06:39 +0800
Subject: [PATCH 5/7] pin to offical dmlc

Change-Id: I5a27dc83b892bf8fcb34bb089449d1d3b6e9beed
---
 .gitmodules        | 2 +-
 3rdparty/dmlc-core | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index 205d01a3827d..90ef157f0eec 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
 [submodule "3rdparty/dmlc-core"]
   path = 3rdparty/dmlc-core
-  url = https://github.com/ZhennanQin/dmlc-core.git
+  url = https://github.com/dmlc/dmlc-core.git
 [submodule "3rdparty/ps-lite"]
   path = 3rdparty/ps-lite
   url = https://github.com/dmlc/ps-lite
diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index a1dbe9d0aa49..9088d2ee02cd 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit a1dbe9d0aa49c28ed5679771fe9472515c0e4168
+Subproject commit 9088d2ee02cdfe393fd0569af4aefebe94f8b105

From fc97b9df1ebfe089392345be6f4469ecf7e3a5aa Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Wed, 18 Sep 2019 10:11:38 +0800
Subject: [PATCH 6/7] Fix GPU CI

Change-Id: I285947e01bdb0651c2c7830ed4eb76931a09b754
---
 tests/python/unittest/test_operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2d13d072d8d8..57b672a9aa70 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -330,7 +330,7 @@ def test_RNN_float64():
     args_grad = explicit_grad
     grad_req = 'write'
 
-    ex = sym.bind(default_context(),
+    ex = sym.bind(mx.cpu(),
                   {
                       'in': mx.nd.ones([2, 1, 2], dtype=dtype),
                       'par': mx.nd.ones([12], dtype=dtype),
                       's': mx.nd.ones([1, 1, 2], dtype=dtype)

From 6d3a6b3ccead0d33f960ae8cf12efacf36cfa7a2 Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Thu, 19 Sep 2019 07:55:42 +0800
Subject: [PATCH 7/7] Fix GPU CI

Change-Id: I6f23b51d6bda44f6ae18766ebe390118740bb9c7
---
 tests/python/unittest/test_operator.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 57b672a9aa70..77b8080422d9 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -311,6 +311,8 @@ def test_rnnrelu_dropout():
     out[0].wait_to_read()
 
 def test_RNN_float64():
+    if default_context().device_type == 'gpu':
+        return
     sym = mx.sym.RNN(
         mx.sym.Variable('in'),
         mx.sym.Variable('par'),
         mx.sym.Variable('s'),
@@ -330,7 +332,7 @@ def test_RNN_float64():
     args_grad = explicit_grad
     grad_req = 'write'
 
-    ex = sym.bind(mx.cpu(),
+    ex = sym.bind(default_context(),
                   {
                       'in': mx.nd.ones([2, 1, 2], dtype=dtype),
                       'par': mx.nd.ones([12], dtype=dtype),
                       's': mx.nd.ones([1, 1, 2], dtype=dtype)