diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index e4c4ac20fc02..ce98cf491150 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -200,12 +200,12 @@ The following environments can be used to profile the application without changi * MXNET_PROFILER_AUTOSTART - Values: 0(false) or 1(true) ```(default=0)``` - - Set to 1, MXNet starts the profiler automatically. The profiling result is stored into profile.json in the working directory. + - Set to 1, MXNet starts the profiler automatically. The profiling result is stored into profile.json in the working directory. * MXNET_PROFILER_MODE - Values: 0(false) or 1(true) ```(default=0)``` - - If set to '0', profiler records the events of the symbolic operators. - - If set to '1', profiler records the events of all operators. + - If set to '0', profiler records the events of the symbolic operators. + - If set to '1', profiler records the events of all operators. ## Interface between Python and the C API @@ -241,14 +241,14 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. * MXNET_CUDA_ALLOW_TENSOR_CORE - 0(false) or 1(true) ```(default=1)``` - - If set to '0', disallows Tensor Core use in CUDA ops. - - If set to '1', allows Tensor Core use in CUDA ops. + - If set to '0', disallows Tensor Core use in CUDA ops. + - If set to '1', allows Tensor Core use in CUDA ops. - This variable can only be set once in a session. * MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION - 0(false) or 1(true) ```(default=0)``` - - If set to '0', disallows implicit type conversions to Float16 to use Tensor Cores - - If set to '1', allows CUDA ops like RNN and Convolution to use TensorCores even with Float32 input data by using implicit type casting to Float16. Only has an effect if `MXNET_CUDA_ALLOW_TENSOR_CORE` is `1`. + - If set to '0', disallows implicit type conversions to Float16 to use Tensor Cores + - If set to '1', allows CUDA ops like RNN and Convolution to use TensorCores even with Float32 input data by using implicit type casting to Float16. Only has an effect if `MXNET_CUDA_ALLOW_TENSOR_CORE` is `1`. * MXNET_CUDA_LIB_CHECKING - 0(false) or 1(true) ```(default=1)``` @@ -328,6 +328,17 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. with float32. - Model accuracies do not necessarily improve with this environment variable turned on. +* MXNET_USE_FUSION + - Values: 0(false) or 1(true) ```(default=1)``` + - If set to '1', MXNet will try to fuse some of the operations (currently pointwise operations only). + - It works in symbolic execution as well as in Gluon models hybridized with the ```static_alloc=True``` option. + - Only applies when MXNet has been compiled with CUDA (```pip install mxnet-cuXX``` or built from source with ```USE_CUDA=1```) and is running on a GPU. + +* MXNET_FUSION_VERBOSE + - Values: 0(false) or 1(true) ```(default=0)``` + - Only applies when MXNet has been compiled with CUDA and the ```MXNET_USE_FUSION``` option is enabled. + - If set to '1', MXNet will print the code it generated for the fused operators.
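For illustration, a minimal sketch of how these two variables are typically exercised from Python — this is not part of the patch itself, and it assumes a CUDA-enabled build of MXNet (e.g. ```pip install mxnet-cuXX```), an available GPU, and a toy Gluon model chosen purely as an example:

```python
import os

# Fusion is on by default; set the variables explicitly (before the graph is
# built) to make the behaviour visible. MXNET_FUSION_VERBOSE asks MXNet to
# print the generated code for each fused operator.
os.environ["MXNET_USE_FUSION"] = "1"
os.environ["MXNET_FUSION_VERBOSE"] = "1"

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(16, activation="relu"),
        nn.Dense(1))
net.initialize(ctx=mx.gpu(0))
# Pointwise fusion applies to Gluon models hybridized with static_alloc=True.
net.hybridize(static_alloc=True)

x = mx.nd.random.uniform(shape=(8, 4), ctx=mx.gpu(0))
net(x).wait_to_read()  # first forward pass builds the graph and runs the fusion pass

# Setting MXNET_USE_FUSION=0 instead disables the pass entirely.
```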
+ Settings for Minimum Memory Usage --------------------------------- - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1``` diff --git a/src/common/exec_utils.cc b/src/common/exec_utils.cc new file mode 100644 index 000000000000..6782abd8b21f --- /dev/null +++ b/src/common/exec_utils.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file exec_utils.cc + * \brief Implementation of executor util functions. + */ + +#include "exec_utils.h" +#include +#include +#include + +namespace mxnet { +namespace common { + +void CopyGraph(nnvm::Graph *dst, const nnvm::Graph &src, bool copy_variables) { + using nnvm::Node; + using nnvm::NodePtr; + using nnvm::NodeEntry; + std::unordered_map old_new; + // use DFSVisit to copy all the nodes + DFSVisit(src.outputs, [&old_new, copy_variables](const NodePtr& node) { + NodePtr np; + if (copy_variables || !node->is_variable()) { + np = Node::Create(); + np->attrs = node->attrs; + } else { + np = node; + } + old_new[node.get()] = std::move(np); + }); + // connect nodes of new graph + for (const auto &kv : old_new) { + for (const NodeEntry& e : kv.first->inputs) { + Node *ptr = e.node.get(); + kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); + } + for (const NodePtr& p : kv.first->control_deps) { + kv.second->control_deps.emplace_back(old_new[p.get()]); + } + } + // set the head + for (const NodeEntry &e : src.outputs) { + (*dst).outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); + } +} + +bool CheckForInputNameDuplicates(const nnvm::IndexedGraph &idx) { + std::unordered_set names; + for (const auto& nid : idx.input_nodes()) { + const std::string &name = idx[nid].source->attrs.name; + if (names.count(name)) { + LOG(WARNING) << "Variable name " << name << " is used more than once!"; + return false; + } + names.insert(name); + } + return true; +} + +} // namespace common +} // namespace mxnet diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index d8b7a33bf22b..3bd2ef3597a9 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -621,6 +621,25 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, return g; } +/*! + * \brief Copy the graph, optionally leaving original Variable nodes. + * + * \param dst destination graph + * \param src source graph being copied + * \param copy_variable whether to copy or reuse Variable nodes from the + * source graph + */ +void CopyGraph(nnvm::Graph *dst, const nnvm::Graph &src, bool copy_variables); + +/*! + * \brief Check whether graph contains any duplicated names in its inputs. 
+ * + * \param idx Indexed graph being checked + * + * \return true if there are no duplicates, false otherwise + */ +bool CheckForInputNameDuplicates(const nnvm::IndexedGraph &idx); + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_EXEC_UTILS_H_ diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index f544d6ba3392..25a326171510 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -34,10 +34,34 @@ #include #include #include +#include +#include namespace mxnet { namespace exec { +template +using FAccessSubgraphAttr = std::function, + std::vector> + (const NodeAttrs& attrs)>; + +using FAccessSubgraphShape = FAccessSubgraphAttr; +using FAccessSubgraphType = FAccessSubgraphAttr; +using FAccessSubgraphStorageType = FAccessSubgraphAttr; + +template +using FProvideSubgraphAttr = std::function &nodes, + const std::vector> &in_attrs, + const std::vector> &out_attrs)>; +using FProvideSubgraphShape = FProvideSubgraphAttr; +using FProvideSubgraphType = FProvideSubgraphAttr; +using FProvideSubgraphStorageType = FProvideSubgraphAttr; + +using TIsFusion = bool; +using TIsFusionHelper = bool; + /*! \brief reuse graph definition */ using nnvm::Graph; @@ -170,6 +194,24 @@ void AttachOpResources(const Graph& g, */ Graph DetectInplaceAddTo(Graph g); +/*! + * \brief Fuse pointwise operations in the forward pass. + * + * \param g input graph (needs to be entire graph, not just forward part) + * + * \return graph with fused pointwise operations in the forward pass + */ +Graph FusePointwiseForward(Graph&& g); + +/*! + * \brief Fuse pointwise operations in the backward pass. + * + * \param g input graph (needs to be entire graph, not just forward part) + * + * \return graph with fused pointwise operations in the backward pass + */ +Graph FusePointwiseBackward(Graph&& g); + /*! * \brief Infer shapes in the graph given the information. * \param graph The input graph. diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index d92253266f35..24fe4d995342 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include "./exec_pass.h" @@ -337,6 +338,7 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, if (!need_grad_) return g; for (size_t i = 0; i < g.outputs.size(); ++i) { NodeEntry ngrad(nnvm::Node::Create(), 0, 0); + ngrad.node->attrs.name = "_head_grad_" + std::to_string(i); head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i])); head_grad_map_[ngrad.node.get()] = i; } @@ -377,6 +379,7 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, for (const auto &e : g_grad.outputs) { g.outputs.push_back(e); } + return g; } @@ -796,6 +799,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::NodeEntryMap& feed_dict) { nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types); + // The following code of shape and dtype inferences and argument // initialization is for simple_bind only. Regular bind operation // should do this differently. @@ -976,6 +980,7 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping, this); return exec; } + /*! * \brief This function is triggered by both simple_bind * and bind flows. 
@@ -993,6 +998,41 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, // setup gradient nnvm::Graph g = InitFullGraph(symbol, grad_req_types); +#if MXNET_USE_CUDA && !defined(_WIN32) + if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) { + nnvm::Graph unoptimized_graph; + common::CopyGraph(&unoptimized_graph, g, false); + + if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) { + g.attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs_); + g = FusePointwiseForward(std::move(g)); + g.attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs_); + g = FusePointwiseBackward(std::move(g)); + // Check the topological order of inputs + const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes(); + const auto &new_inputs = g.indexed_graph().input_nodes(); + if (original_inputs.size() != new_inputs.size()) { + LOG(WARNING) + << "Number of inputs after fusion does not match original number of inputs. " + << "This is most probably a bug. Disabling fusion for this run."; + g = unoptimized_graph; + } else { + for (size_t i = 0; i < new_inputs.size(); ++i) { + if (unoptimized_graph.indexed_graph()[original_inputs[i]].source->attrs.name != + g.indexed_graph()[new_inputs[i]].source->attrs.name) { + LOG(WARNING) << "Disabling fusion due to altered topological order of inputs."; + g = unoptimized_graph; + break; + } + } + } + } else { + LOG(WARNING) + << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!"; + } + } +#endif // MXNET_USE_CUDA + // create "device" and "context" attrs for the graph g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, @@ -1946,7 +1986,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, symbol = exec::BuildSubgraph(symbol, backend, arg_shape_map, arg_dtype_map, arg_stype_map, default_ctx, group2ctx, &tmp_in_arg_ctxes, &tmp_arg_grad_ctxes, &tmp_grad_req_types, &tmp_aux_state_ctxes, verbose); - exec->Init(symbol, default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes, + exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes, tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, tmp_grad_req_types, shared_arg_names, &tmp_in_args, &tmp_arg_grads, &tmp_aux_states, shared_buffer, shared_exec); @@ -1985,7 +2025,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, } if (!init) { // init without subgraph - exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + exec->Init(symbol.Copy(), default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, shared_arg_names, in_args, arg_grads, aux_states, shared_buffer, shared_exec); } @@ -2017,8 +2057,8 @@ Executor *Executor::Bind(nnvm::Symbol symbol, verbose); } } - exec->Init(symbol, default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store, tmp_grad_req_type, - tmp_aux_states, reinterpret_cast(shared_exec)); + exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store, + tmp_grad_req_type, tmp_aux_states, reinterpret_cast(shared_exec)); return exec; } } // namespace mxnet diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index d72325392604..80e4084c478e 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -63,6 +63,156 @@ bool ApplyOpInferAttr(const nnvm::Graph& g, return true; } +template +inline void GetAttrFromForwardNode(const uint32_t 
nid, + const nnvm::IndexedGraph &idx, + std::vector* rshape_ptr, + IsNone fis_none) { + std::vector& rshape = *rshape_ptr; + const nnvm::IndexedGraph::Node& inode = idx[nid]; + // gradient function, used to get node correspondence. + static auto& fgrad = + Op::GetAttr("FGradient"); + nnvm::NodePtr fwd_ptr = inode.source->control_deps[0]; + const nnvm::IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; + // use gradient function to find out the correspondence. + std::vector ograd(fwd_ptr->num_outputs()); + for (size_t i = 0; i < ograd.size(); ++i) { + ograd[i].index = static_cast(i); + } + // input gradient list + const std::vector& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); + const nnvm::Node* igrad_node = nullptr; + // Input gradient assignement + for (size_t i = 0; i < igrad.size(); ++i) { + if (igrad[i].node->op() == inode.source->op()) { + uint32_t eid = idx.entry_id(nid, igrad[i].index); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; + } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) { + // Need to skip empty forward shape, because it may not be + // available now and it is possible to infer the forward + // shape in one of the next a few passes + CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) + << "Backward shape inconsistent with the forward shape"; + } + if (igrad_node == nullptr) { + igrad_node = igrad[i].node.get(); + } else { + CHECK(igrad_node == igrad[i].node.get()); + } + } + } + // out grad entries + CHECK(igrad_node != nullptr) + << "Cannot find matching backward op for " << inode.source->attrs.name; + for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { + const nnvm::NodeEntry& e = igrad_node->inputs[i]; + if (e.node == nullptr) { + uint32_t eid = idx.entry_id(inode.inputs[i]); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; + } + } + } +} + +template +void GetAttrFromFusedNode(uint32_t nid, + const nnvm::IndexedGraph& idx, + std::vector* rshape_ptr, + IsNone fis_none, + const std::string& infer_fusion_name) { + std::vector& rshape = *rshape_ptr; + const auto& inode = idx[nid]; + // gradient function, used to get node correspondence. + static auto& fgrad = + Op::GetAttr("FGradient"); + nnvm::NodePtr fused_fwd_ptr = inode.source->control_deps[0]; + static auto& finfer_fused_shape = + Op::GetAttr(infer_fusion_name); + auto finfer = finfer_fused_shape.get(fused_fwd_ptr->op(), nullptr); + CHECK(finfer != nullptr) << "Operator " << fused_fwd_ptr->attrs.name << + " is marked as Fusion but does not allow accessing attributes"; + const auto& inferred_attrs = finfer(fused_fwd_ptr->attrs); + const auto& fwd_ptr = std::get<0>(inferred_attrs); + const auto& input_attrs = std::get<1>(inferred_attrs); + const auto& output_attrs = std::get<2>(inferred_attrs); + + // use gradient function to find out the correspondence. 
+ std::vector ograd(fwd_ptr->num_outputs()); + for (size_t i = 0; i < ograd.size(); ++i) { + ograd[i].index = static_cast(i); + } + // input gradient list + const std::vector& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); + const nnvm::Node* igrad_node = nullptr; + // Set the attributes of output gradients + // using attributes of forward node inputs + for (size_t i = 0; i < igrad.size(); ++i) { + if (igrad[i].node->op() == inode.source->op()) { + uint32_t eid = idx.entry_id(nid, igrad[i].index); + if (fis_none(rshape[eid])) { + rshape[eid] = input_attrs[i]; + } else if (!fis_none(input_attrs[i])) { + // Need to skip empty forward shape, because it may not be + // available now and it is possible to infer the forward + // shape in one of the next a few passes + CHECK_EQ(rshape[eid], input_attrs[i]) + << "Backward shape inconsistent with the forward shape"; + } + if (igrad_node == nullptr) { + igrad_node = igrad[i].node.get(); + } else { + CHECK(igrad_node == igrad[i].node.get()); + } + } + } + + // Set the attributes of input gradients + // using attributes of forward node outputs + CHECK(igrad_node != nullptr) + << "Cannot find matching backward op for " << inode.source->attrs.name; + for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { + const nnvm::NodeEntry& e = igrad_node->inputs[i]; + if (e.node == nullptr) { + uint32_t eid = idx.entry_id(inode.inputs[i]); + if (fis_none(rshape[eid])) { + rshape[eid] = output_attrs[e.index]; + } + } + } +} + +template +void ProvideAttrToFusion(const uint32_t nid, + const nnvm::IndexedGraph& idx, + const std::vector& rshape, + const std::string& provide_fusion_name) { + const auto& inode = idx[nid]; + std::vector> in_attrs; + std::vector> out_attrs; + for (const auto& dep_node : inode.source->control_deps) { + in_attrs.push_back({}); + out_attrs.push_back({}); + auto ¤t_in_attrs = in_attrs.back(); + auto ¤t_out_attrs = out_attrs.back(); + uint32_t dep_node_id = idx.node_id(dep_node.get()); + for (const auto& e : idx[dep_node_id].inputs) { + current_in_attrs.push_back(rshape[idx.entry_id(e)]); + } + for (size_t i = 0; i < dep_node->num_outputs(); ++i) { + current_out_attrs.push_back(rshape[idx.entry_id(dep_node_id, i)]); + } + } + auto provide = + Op::GetAttr(provide_fusion_name).get(inode.source->op(), nullptr); + CHECK(provide != nullptr) << + "Encountered Fusion operator that does not implement providing subgraph attr " << + provide_fusion_name << "."; + provide(inode.source->attrs, inode.source->control_deps, in_attrs, out_attrs); +} + /*!\brief * This is a duplicate of the InferAttr function in nnvm with minor modification * to support inferring storage type whose function signature is different from @@ -73,6 +223,7 @@ bool ApplyOpInferAttr(const nnvm::Graph& g, * \param ret graph used for attribute inference * \param emmpty_val empty value of the attribute * \param infer_name name of the function used for attribute inference + * \param infer_fusion_name name of the function used for accessing attributes in fused nodes * \param input_name name of the attribute in the graph used to store the * input data for attribute inference * \param attr_key_name name of the attribute used for inference for variable nodes @@ -90,10 +241,13 @@ bool ApplyOpInferAttr(const nnvm::Graph& g, * \param default_mode_val default value of the dispatch mode attribute on the node. 
Used * for storage type inference */ -template +template nnvm::Graph InferAttr(nnvm::Graph &&ret, const AttrType empty_val, const char* infer_name, + const char* infer_fusion_name, + const char* provide_fusion_name, const char* input_name, const char* attr_key_name, const char* attr_name, @@ -114,9 +268,6 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, Op::GetAttr(infer_name); static auto& is_backward = Op::GetAttr("TIsBackward"); - // gradient function, used to get node correspondence. - static auto& fgrad = - Op::GetAttr("FGradient"); // reshape shape vector AttrVector rshape; // dispatch mode vector @@ -209,53 +360,19 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val); } } else if (is_backward.get(inode.source->op(), false) && - inode.control_deps.size() && bwd_identity_assign) { + inode.source->control_deps.size() && bwd_identity_assign) { CHECK(dispatch_mode_name == nullptr) << "Backward inference for node attributes is not available"; - CHECK_GE(inode.control_deps.size(), 1U) + CHECK_GE(inode.source->control_deps.size(), 1U) << "BackwardOp need to have control_deps to its forward op"; - const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; nnvm::NodePtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; - // use gradient function to find out the correspondence. - std::vector ograd(fwd_ptr->num_outputs()); - for (size_t i = 0; i < ograd.size(); ++i) { - ograd[i].index = static_cast(i); - } - // input gradient list - auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); - const nnvm::Node* igrad_node = nullptr; - // Input gradient assignement - for (size_t i = 0; i < igrad.size(); ++i) { - if (igrad[i].node->op() == inode.source->op()) { - uint32_t eid = idx.entry_id(nid, igrad[i].index); - if (fis_none(rshape[eid])) { - rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; - } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) { - // Need to skip empty forward shape, because it may not be - // available now and it is possible to infer the forward - // shape in one of the next a few passes - CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) - << "Backward shape inconsistent with the forward shape"; - } - if (igrad_node == nullptr) { - igrad_node = igrad[i].node.get(); - } else { - CHECK(igrad_node == igrad[i].node.get()); - } - } - } - // out grad entries - CHECK(igrad_node != nullptr) - << "Cannot find matching backward op for " << inode.source->attrs.name; - for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { - const nnvm::NodeEntry& e = igrad_node->inputs[i]; - if (e.node == nullptr) { - uint32_t eid = idx.entry_id(inode.inputs[i]); - if (fis_none(rshape[eid])) { - rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; - } - } + + static auto& is_fusion_helper = Op::GetAttr("TIsFusionHelper"); + if (!is_fusion_helper.get(fwd_ptr->op(), false)) { + GetAttrFromForwardNode(nid, idx, &rshape, fis_none); + } else { + GetAttrFromFusedNode(nid, idx, &rshape, fis_none, infer_fusion_name); } } else { DispatchMode* dispatch_mode = nullptr; @@ -280,6 +397,10 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, if (finfer != nullptr) { // Call inference function of the operator. 
try { + static auto& is_fusion = Op::GetAttr("TIsFusion"); + if (is_fusion.get(inode.source->op(), false)) { + ProvideAttrToFusion(nid, idx, rshape, provide_fusion_name); + } forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs, nid, &ishape, &oshape, dispatch_mode); } catch (const std::exception& e) { @@ -394,9 +515,6 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, Op::GetAttr(infer_name); static auto& is_backward = Op::GetAttr("TIsBackward"); - // gradient function, used to get node correspondence. - static auto& fgrad = - Op::GetAttr("FGradient"); // reshape shape vector AttrVector rshape; // dispatch mode vector @@ -500,53 +618,20 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val); } } else if (is_backward.get(inode.source->op(), false) && - inode.control_deps.size() && bwd_identity_assign) { + inode.source->control_deps.size() && bwd_identity_assign) { CHECK(dispatch_mode_name == nullptr) << "Backward inference for node attributes is not available"; - CHECK_GE(inode.control_deps.size(), 1U) + CHECK_GE(inode.source->control_deps.size(), 1U) << "BackwardOp need to have control_deps to its forward op"; - const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; nnvm::NodePtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; - // use gradient function to find out the correspondence. - std::vector ograd(fwd_ptr->num_outputs()); - for (size_t i = 0; i < ograd.size(); ++i) { - ograd[i].index = static_cast(i); - } - // input gradient list - auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); - const nnvm::Node* igrad_node = nullptr; - // Input gradient assignement - for (size_t i = 0; i < igrad.size(); ++i) { - if (igrad[i].node->op() == inode.source->op()) { - uint32_t eid = idx.entry_id(nid, igrad[i].index); - if (fis_none(rshape[eid])) { - rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; - } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) { - // Need to skip empty forward shape, because it may not be - // available now and it is possible to infer the forward - // shape in one of the next a few passes - CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) - << "Backward shape inconsistent with the forward shape"; - } - if (igrad_node == nullptr) { - igrad_node = igrad[i].node.get(); - } else { - CHECK(igrad_node == igrad[i].node.get()); - } - } - } - // out grad entries - CHECK(igrad_node != nullptr) - << "Cannot find matching backward op for " << inode.source->attrs.name; - for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { - const nnvm::NodeEntry& e = igrad_node->inputs[i]; - if (e.node == nullptr) { - uint32_t eid = idx.entry_id(inode.inputs[i]); - if (fis_none(rshape[eid])) { - rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; - } - } + + static auto& is_fusion_helper = Op::GetAttr("TIsFusionHelper"); + if (!is_fusion_helper.get(fwd_ptr->op(), false)) { + GetAttrFromForwardNode(nid, idx, &rshape, fis_none); + } else { + GetAttrFromFusedNode(nid, idx, &rshape, fis_none, + "FAccessSubgraphShape"); } } else { DispatchMode* dispatch_mode = nullptr; @@ -581,6 +666,11 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, if (finfer != nullptr) { // Call inference function of the operator. 
try { + static auto& is_fusion = Op::GetAttr("TIsFusion"); + if (is_fusion.get(inode.source->op(), false)) { + ProvideAttrToFusion(nid, idx, rshape, + "FProvideSubgraphShape"); + } forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs, nid, &ishape, &oshape, dispatch_mode); } catch (const std::exception& e) { @@ -686,10 +776,11 @@ nnvm::Graph InferType(nnvm::Graph&& graph, if (dtype_attr_key.length() != 0) { graph.attrs["dtype_attr_key"] = std::make_shared(dtype_attr_key); } - return InferAttr( + return InferAttr( std::move(graph), -1, - "FInferType", "dtype_inputs", "dtype_attr_key", - "dtype", "dtype_num_unknown_nodes", + "FInferType", "FAccessSubgraphType", "FProvideSubgraphType", + "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, common::SameType, true, nullptr); } @@ -719,10 +810,12 @@ nnvm::Graph InferStorageType(nnvm::Graph&& graph, } // for storage type, the backward attr is not necessarily the same as it's correspondence - nnvm::Graph ret = InferAttr( + nnvm::Graph ret = InferAttr( std::move(graph), -1, - "FInferStorageType", "storage_type_inputs", "storage_type_attr_key", - "storage_type", "storage_type_num_unknown_nodes", + "FInferStorageType", "FAccessSubgraphStorageType", "FProvideSubgraphStorageType", + "storage_type_inputs", "storage_type_attr_key", "storage_type", + "storage_type_num_unknown_nodes", [](const int t) { return t == -1; }, common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable); diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc new file mode 100644 index 000000000000..c6e2405cb2a4 --- /dev/null +++ b/src/executor/pointwise_fusion_pass.cc @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file pointwise_fusion_pass.cc + * \brief Pass applying pointwise fusion. 
+ * \author Clement Fuji Tsang + */ + +#include +#include +#include +#include +#include +#include +#include +#include "./simple_partition_pass.h" +#include "../operator/fusion/fused_op-inl.h" +#include "../operator/fusion/fused_op.h" +#include "../operator/operator_common.h" + +#if MXNET_USE_CUDA + +namespace mxnet { +namespace exec { +namespace { + bool IsFusionCompatible(nnvm::Node* n) { + using namespace mxnet::fusion; + if (n->op() == nullptr) + return false; + std::string op_name = n->op()->name; + if (ops_desc.count(op_name)) + return true; + if (slice_ops.count(op_name)) + return false; + if (std::find(variable_io_ops.begin(), + variable_io_ops.end(), + op_name) != + variable_io_ops.end()) + return true; + return false; + } + + bool IsInputsOnlyCompatible(nnvm::Node* n) { + using namespace mxnet::fusion; + if (n->op() == nullptr) + return false; + std::string op_name = n->op()->name; + if (slice_ops.count(op_name)) { + if (op_name == "slice") { + // slice with non-default step attribute is not supported + // currently + if (n->attrs.dict.count("step") && + !(n->attrs.dict.at("step") == "()" || + n->attrs.dict.at("step") == "[]")) { + return false; + } + } + return true; + } + return false; + } + + nnvm::NodePtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) { + nnvm::Symbol subgraph_sym; + auto node = nnvm::Node::Create(); + subgraph_sym.outputs = subgraph.outputs; + node->attrs.subgraphs.emplace_back(std::make_shared(subgraph_sym)); + std::ostringstream name_oss; + // the name of the new node will be the concatenation of all the node names in the subgraph + DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) { + if (n->op() != nullptr) + name_oss << n->op()->name << "_"; + }); + auto subgraph_name = name_oss.str(); + subgraph_name.pop_back(); + node->attrs.name = subgraph_name; + node->attrs.dict["num_inputs"] = std::to_string(inputs_size); + node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size()); + node->attrs.op = Op::Get("_FusedOp"); + node->op()->attr_parser(&(node->attrs)); + return node; + } +} // namespace + +/*! + * \brief Replace a set of nodes by a subgraph node. + * This function is used specifically in pointwise fusion. 
+ */ +template +Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector& subgraph_sets, + FCreateNode create_subgraph_node) { + for (auto subgraph_set : subgraph_sets) { + // Create MXNet subgraph + Graph subgraph; + const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set); + subgraph.outputs.resize(sub_outputs_in_main.size()); + for (auto p : sub_outputs_in_main) { + subgraph.outputs[p.second] = p.first; + } + // To generate a subgraph an input has to be replaced by data node (no op) + // and it has to be agnostic to the node from which it's an output + // (For example, even if two inputs are two different outputs from the same node, + // they need to be replaced by two completely separate data nodes) + auto inputs = GetSubgraphInputs(subgraph, subgraph_set); + auto subgraph_node = create_subgraph_node(subgraph, inputs.size()); + subgraph_node->inputs = inputs; + // replug inputs of node out of subgraph to be output of the subgraph node + // if it was a node in the subgraph + DFSVisit(g.outputs, + [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::NodePtr node) { + if (!subgraph_set.count(node.get())) { + for (auto &e : node->inputs) { + auto it = sub_outputs_in_main.find(e); + if (it != sub_outputs_in_main.end()) { + e.node = subgraph_node; + e.index = it->second; + } + } + } + }); + // replug outputs of the graph to be output of the subgraph node + // if it was a node in the subgraph + for (auto &e : g.outputs) { + auto it = sub_outputs_in_main.find(e); + if (it != sub_outputs_in_main.end()) { + e.node = subgraph_node; + e.index = it->second; + } + } + // move control dependencies between nodes of the subgraph and out of the subgraph + // to a dependencies between the subgraph node and the nodes out of the subgraph + DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::NodePtr& node) { + if (subgraph_set.count(node.get())) { + auto it = node->control_deps.begin(); + static auto& is_fusion = Op::GetAttr("TIsFusionHelper"); + std::vector new_control_deps; + while (it != node->control_deps.end()) { + if (subgraph_set.count(it->get())) { + new_control_deps.push_back(*it); + } else { + if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) { + uint32_t node_id = subgraph_node->control_deps.size(); + subgraph_node->control_deps.push_back(*it); + auto helper_node = op::MakeNode("_FusedOpOutHelper", + subgraph_node->attrs.name + "_" + + node->attrs.name + "_outhelper", + nullptr, + nullptr, + nullptr); + helper_node->attrs.parsed = + FusedOpHelperParamPtr(new FusedOpHelperParam( + nnvm::get(subgraph_node->attrs.parsed), + node_id)); + new_control_deps.push_back(helper_node); + } else { + new_control_deps.push_back(*it); + } + } + ++it; + } + node->control_deps = new_control_deps; + } + }); + + const auto& index = subgraph.indexed_graph(); + DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::NodePtr& node) { + for (auto &e : node->control_deps) { + if (subgraph_set.count(e.get())) { + uint32_t node_id = index.node_id(e.get()); + auto helper_node = op::MakeNode("_FusedOpHelper", + subgraph_node->attrs.name + "_" + + node->attrs.name + "_helper", + nullptr, + nullptr, + nullptr); + helper_node->attrs.parsed = + FusedOpHelperParamPtr(new FusedOpHelperParam( + nnvm::get(subgraph_node->attrs.parsed), + node_id)); + e = helper_node; + } + } + }); + } + Graph new_graph; + new_graph.outputs = g.outputs; + return new_graph; +} + +/* \brief Add nodes as inputs to the subgraph. 
This is used for operations + * which are only compatible when they are the first nodes in the + * subgraph. + */ +template +void AddInputsOnlyCompatible(const Graph &g, + std::vector >* subsets, + IsCompatible is_compatible) { + std::unordered_map node2setidx; + size_t subgraphs_fullsize = 0; + for (auto& s : *subsets) { + subgraphs_fullsize += s.size(); + } + node2setidx.reserve(subgraphs_fullsize); + for (size_t i = 0; i < subsets->size(); ++i) { + for (auto& n : (*subsets)[i]) { + node2setidx.insert({n, i}); + } + } + std::vector > to_add(subsets->size()); + DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::NodePtr& n) { + const auto& it = node2setidx.find(n.get()); + if (it != node2setidx.end()) { + for (auto& e : n->inputs) { + if (is_compatible(e.node.get())) + to_add[it->second].push_back(e.node.get()); + } + } + }); + + // Avoid duplicating the node that is input of two subsets + std::unordered_set added; + for (size_t i = 0; i < subsets->size(); ++i) { + std::vector heads; + for (auto n : subsets->at(i)) { + for (auto e : n->inputs) { + if (!subsets->at(i).count(e.node.get())) + heads.push_back(e); + } + } + for (size_t j = 0; j < to_add[i].size(); ++j) { + if (!added.count(to_add[i][j])) { + bool make_cycle = false; + const auto& node = to_add[i][j]; + std::vector _heads; + std::copy_if(heads.begin(), heads.end(), std::back_inserter(_heads), + [&node](const nnvm::NodeEntry& n) { + return n.node.get() != node; + }); + DFSVisit(_heads, [&make_cycle, &node](const nnvm::NodePtr& n) { + if (n.get() == node) + make_cycle = true; + }); + if (!make_cycle) { + (*subsets)[i].insert(to_add[i][j]); + added.insert(to_add[i][j]); + } + } + } + } +} + +Graph FusePointwiseForward(Graph &&g) { + Graph ret; + g.indexed_graph(); + const auto& num_forward_outputs = g.GetAttr("num_forward_outputs"); + Graph fg; + fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(), + g.outputs.begin() + num_forward_outputs); + auto subsets = GetCompatibleSubsets(fg, IsFusionCompatible); + AddInputsOnlyCompatible(fg, &subsets, IsInputsOnlyCompatible); + g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode); + ret.outputs = g.outputs; + return ret; +} + +Graph FusePointwiseBackward(Graph &&g) { + Graph ret; + g.indexed_graph(); + const auto& num_forward_outputs = g.GetAttr("num_forward_outputs"); + Graph fg; + fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(), + g.outputs.begin() + num_forward_outputs); + std::unordered_set exclusion_set; + DFSVisit(fg.outputs, [&exclusion_set](const nnvm::NodePtr& n) { + exclusion_set.insert(n.get()); + }); + auto subsets = GetCompatibleSubsets(g, [&exclusion_set](nnvm::Node* n) { + if (exclusion_set.count(n)) + return false; + return IsFusionCompatible(n); + }); + g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode); + ret.outputs = g.outputs; + return ret; +} + +} // namespace exec +} // namespace mxnet + +#endif // MXNET_USE_CUDA diff --git a/src/executor/simple_partition_pass.h b/src/executor/simple_partition_pass.h new file mode 100644 index 000000000000..5b26a4523c13 --- /dev/null +++ b/src/executor/simple_partition_pass.h @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file simple_partition_pass.h + * \brief Simple pass for partitioning a graph. + * \author Clement Fuji Tsang + */ +#ifndef MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_ +#define MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "exec_pass.h" + +namespace mxnet { +namespace exec { + + +/*! + * \brief Custom graph class, which contains bi-directional nodes + * required for traversing in both directions (from outputs to inputs + * and vice versa). It is a non-owning layer on top of NNVM graph, since + * NNVM graph enables traversing only in 1 direction (from outputs to inputs). + */ +class BidirectionalGraph { + public: + struct Node { + nnvm::Node* nnvmptr; + std::vector inputs; + std::vector outputs; + }; + + explicit BidirectionalGraph(const Graph &g) { + auto& idx = g.indexed_graph(); + auto num_nodes = idx.num_nodes(); + nodes.reserve(num_nodes); + nnvm2nid.reserve(num_nodes); + outputs.reserve(idx.outputs().size()); + // Create all the nodes in a new graph from + // nodes in the NNVM graph and store them + // in nodes array + DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) { + Node new_node; + new_node.nnvmptr = n.get(); + nnvm2nid[n.get()] = static_cast(nodes.size()); + nodes.emplace_back(std::move(new_node)); + }); + // Create all connections between nodes in + // the graph (both directions) + for (const auto& it : nnvm2nid) { + nnvm::Node* nnvmnode = it.first; + uint32_t nid = it.second; + for (auto& n : nnvmnode->inputs) { + uint32_t input_nid = nnvm2nid[n.node.get()]; + nodes[input_nid].outputs.emplace_back(&nodes[nid]); + nodes[nid].inputs.emplace_back(&nodes[input_nid]); + } + } + // Create output connections from the graph + for (auto& e : g.outputs) { + uint32_t nid = nnvm2nid[e.node.get()]; + outputs.emplace_back(&nodes[nid]); + } + } + + /* \brief Get all subsets of nodes, where: + * - graph constructed from nodes in each subset is a connected graph + * - every node fulfills a predicate is_compatible + * - if nodes u and v are part of a subset, then for each path between + * u and v in the original directed graph, all nodes on those paths + * are also part of the subset + * \param is_compatible A function taking nnvm::Node* and returning bool + * which identifies which nodes should be included in + * subsets. 
+ */ + template + std::vector> get_subsets(FCompatible is_compatible) { + std::vector> subgraphs; + std::unordered_set incomp_set; + std::unordered_set all_set(nodes.size()); + std::vector separation_sets; + // Check each node for compatibility + // and, if it is incompatible, mark nodes + // on each side of it as not possible to be + // in the same subset + for (Node& node : nodes) { + if (!is_compatible(node.nnvmptr)) { + incomp_set.insert(&node); + std::unordered_set in_graph; + std::unordered_set out_graph; + std::vector dummy_head; + dummy_head.emplace_back(&node); + DFS(dummy_head, false, [&out_graph, &is_compatible](Node* node) { + if (is_compatible(node->nnvmptr)) + out_graph.insert(node); + }); + DFS(dummy_head, true, [&in_graph, is_compatible](Node* node) { + if (is_compatible(node->nnvmptr)) + in_graph.insert(node); + }); + if (!(in_graph.empty() || out_graph.empty())) + separation_sets.push_back(std::make_pair(in_graph, out_graph)); + } + all_set.emplace(&node); + } + IncompMap incomp_map; + std::unordered_set comp_set; + comp_set.insert(all_set.begin(), all_set.end()); + for (Node* n : incomp_set) { + comp_set.erase(n); + } + // For each node construct the map of nodes that cannot be in + // the same subset + for (Node* n : comp_set) { + for (PairSet p : separation_sets) { + if (p.first.count(n)) { + incomp_map[n].insert(p.second.begin(), p.second.end()); + } else if (p.second.count(n)) { + incomp_map[n].insert(p.first.begin(), p.first.end()); + } + } + for (Node* incomp_n : incomp_set) { + incomp_map[n].erase(incomp_n); + } + } + std::unordered_set unused_set; + unused_set.reserve(comp_set.size()); + + for (auto& n : comp_set) { + unused_set.insert(n); + } + std::unordered_set visited; + std::deque stack(outputs.begin(), outputs.end()); + // Create subsets + while (!stack.empty()) { + Node* vertex = stack.front(); + stack.pop_front(); + if (!visited.count(vertex)) { + visited.insert(vertex); + if (unused_set.count(vertex)) { + subgraphs.emplace_back(naive_grow_subgraph(vertex, &unused_set, &incomp_map)); + } + for (Node* input : vertex->inputs) { + stack.emplace_back(input); + } + } + } + return subgraphs; + } + + private: + using PairSet = std::pair, std::unordered_set>; + using PairVec = std::pair, std::vector>; + using IncompMap = std::unordered_map>; + + /* \brief Traverse the graph using DFS in either direction. + * \param heads Starting nodes for the DFS algorithm. + * \param reverse If true, DFS will traverse the graph from + * outputs to inputs. Otherwise, it will + * traverse the graph from inputs to outputs. + * \param fvisit Function to call on each visited node. + */ + template + void DFS(const std::vector& heads, bool reverse, FVisit fvisit) { + std::unordered_set visited; + std::vector vec(heads.begin(), heads.end()); + visited.reserve(heads.size()); + while (!vec.empty()) { + Node* vertex = vec.back(); + vec.pop_back(); + if (visited.count(vertex) == 0) { + visited.insert(vertex); + fvisit(vertex); + std::vector nexts = reverse ? vertex->inputs : vertex->outputs; + for (Node* node : nexts) { + if (visited.count(node) == 0) { + vec.emplace_back(node); + } + } + } + } + } + + /* \brief Get the connected subgraph that contains the head node, + * using only previously unused nodes, according to the rules + * from the incompatibility map. + * \param head Node which needs to be part of the returned subgraph. + * \param unused_set Only nodes from this set will be considered when + * adding to the growing subgraph.
+ * \param incomp_map Map containing data on which nodes are incompatible + * to be in the same subgraph. + */ + std::unordered_set naive_grow_subgraph(Node* head, + std::unordered_set* unused_set, + IncompMap* incomp_map) { + std::unordered_set subgraph; + std::unordered_set incomp_set; + std::deque stack; + stack.emplace_back(head); + while (!stack.empty()) { + Node* vertex = stack.back(); + stack.pop_back(); + if (unused_set->count(vertex) && !incomp_set.count(vertex)) { + unused_set->erase(vertex); + subgraph.insert(vertex); + incomp_set.insert((*incomp_map)[vertex].begin(), (*incomp_map)[vertex].end()); + // Traverse the graph in both directions + for (Node* input : vertex->inputs) { + if (unused_set->count(input) && !incomp_set.count(input)) { + stack.emplace_back(input); + } + } + for (Node* output : vertex->outputs) { + if (unused_set->count(output) && !incomp_set.count(output)) { + stack.emplace_back(output); + } + } + } + } + return subgraph; + } + + friend class Graph; + + std::vector nodes; + std::unordered_map nnvm2nid; + std::vector outputs; +}; // class BidirectionalGraph + +using NodeEntrySet = std::unordered_set; +using NodeRawPtrSet = std::unordered_set; + +/*! + * \brief Get the output nodes of the subgraph in the main graph. + * \return a map between the node in the main graph and the output index of the subgraph node +*/ +nnvm::NodeEntryMap GetSubgraphOutputs(Graph g, NodeRawPtrSet subgraph_set) { + nnvm::NodeEntryMap outputs; + uint32_t count = 0; + for (auto& e : g.outputs) { + if (subgraph_set.count(e.node.get()) && !outputs.count(e)) { + outputs.insert({e, count++}); + } + } + DFSVisit(g.outputs, [&subgraph_set, &outputs, &count](const nnvm::NodePtr &node){ + if (!subgraph_set.count(node.get())) { + for (auto& e : node->inputs) { + if (subgraph_set.count(e.node.get()) && !outputs.count(e)) { + outputs.insert({e, count++}); + } + } + } + }); + return outputs; +} + +/*! + * \brief Create new input nodes of the subgraph and plug them in. + * \return the inputs of the subgraph node in the main graph +*/ +std::vector GetSubgraphInputs(Graph g, NodeRawPtrSet subgraph_set) { + std::vector inputs; + nnvm::NodeEntryMap entry_map; + DFSVisit(g.outputs, [&subgraph_set, &inputs, &entry_map](const nnvm::NodePtr &node){ + if (subgraph_set.count(node.get())) { + for (auto &e : node->inputs) { + if (!subgraph_set.count(e.node.get())) { + if (entry_map.count(e)) { + e = entry_map[e]; + } else { + auto new_node = nnvm::Node::Create(); + new_node->attrs.name = "input_" + std::to_string(inputs.size()); + entry_map.insert({e, nnvm::NodeEntry{new_node, 0, 0}}); + inputs.push_back(e); + e.node = new_node; + e.index = 0; + } + } + } + } + }); + // Fix ordering w.r.t. topology + Graph _g; + _g.outputs = g.outputs; + const auto &idx = _g.indexed_graph(); + std::sort(inputs.begin(), inputs.end(), + [&idx, &entry_map](const nnvm::NodeEntry lhs, const nnvm::NodeEntry rhs) { + return idx.entry_id(entry_map.at(lhs)) < idx.entry_id(entry_map.at(rhs)); + }); + return inputs; +} + +std::unordered_map GetGraphInputsMap(const Graph& g) { + std::unordered_map outputs; + auto& idx = g.indexed_graph(); + outputs.reserve(idx.num_nodes()); + std::vector input_nodes = idx.input_nodes(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + outputs[input_nodes[i]] = static_cast(i); + } + return outputs; +} + +/*! + * \brief Helper function to display what nodes are in a specific subset.
+ */ +void dispNodesSet(Graph g, NodeRawPtrSet s) { + DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){ + if (s.count(n.get())) { + std::cout << " Y " << n->attrs.name << std::endl; + } else { + std::cout << " N " << n->attrs.name << std::endl; + } + }); +} + +/*! + * \brief Replace a set of nodes by a subgraph node. + */ +template +Graph ReplaceSubgraphs(Graph&& g, const std::vector& subgraph_sets, + FCreateNode create_subgraph_node) { + for (auto subgraph_set : subgraph_sets) { + // Create MXNet subgraph + Graph subgraph; + const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set); + subgraph.outputs.resize(sub_outputs_in_main.size()); + for (auto p : sub_outputs_in_main) { + subgraph.outputs[p.second] = p.first; + } + // To generate a subgraph an input has to be replaced by data node (no op) + // and it has to be agnostic to the node from which it's an output + // (For example, even if two inputs are two different outputs from the same node, + // they need to be replaced by two completely separate data nodes) + auto inputs = GetSubgraphInputs(subgraph, subgraph_set); + auto subgraph_node = create_subgraph_node(subgraph); + subgraph_node->inputs = inputs; + // replug inputs of node out of subgraph to be output of the subgraph node + // if it was a node in the subgraph + DFSVisit(g.outputs, + [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::NodePtr node) { + if (!subgraph_set.count(node.get())) { + for (auto &e : node->inputs) { + auto it = sub_outputs_in_main.find(e); + if (it != sub_outputs_in_main.end()) { + e.node = subgraph_node; + e.index = it->second; + } + } + } + }); + // replug outputs of the graph to be output of the subgraph node + // if it was a node in the subgraph + for (auto &e : g.outputs) { + auto it = sub_outputs_in_main.find(e); + if (it != sub_outputs_in_main.end()) { + e.node = subgraph_node; + e.index = it->second; + } + } + // move control dependencies between nodes of the subgraph and out of the subgraph + // to a dependencies between the subgraph node and the nodes out of the subgraph + DFSVisit(g.outputs, [&subgraph_node, &subgraph_set](const nnvm::NodePtr& node) { + for (auto &e : node->control_deps) { + if (subgraph_set.count(e.get())) + e = subgraph_node; + } + }); + DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::NodePtr& node) { + auto it = node->control_deps.begin(); + while (it != node->control_deps.end()) { + if (subgraph_set.count(it->get())) { + ++it; + } else { + subgraph_node->control_deps.push_back(*it); + it = node->control_deps.erase(it); + } + } + }); + } + Graph new_graph; + new_graph.outputs = g.outputs; + return new_graph; +} + +/* \brief Get all subsets of nodes, where: + * - graph constructed from nodes in each subset is a connected graph + * - every node fulfills a predicate is_compatible + * - if nodes u and v are part of a subset, then for each path between + * u and v in the original directed graph, all nodes on those paths + * are also part of the subset + * \param g NNVM graph + * \param is_compatible A function taking nnvm::Node* and returning bool + * which identifies which nodes should be included in + * subsets. 
+ */ +template +std::vector GetCompatibleSubsets(const Graph& g, FCompatible is_compatible) { + BidirectionalGraph biG = BidirectionalGraph(g); + std::vector> subsets = + biG.get_subsets(is_compatible); + std::vector nnvm_subsets; + nnvm_subsets.reserve(subsets.size()); + for (auto& subset : subsets) { + if (subset.size() > 1) { + NodeRawPtrSet node_set; + node_set.reserve(subset.size()); + for (auto& n : subset) { + node_set.insert(n->nnvmptr); + } + nnvm_subsets.push_back(node_set); + } + } + return nnvm_subsets; +} + +} // namespace exec +} // namespace mxnet +#endif // MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_ diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 6818d757ab79..f64450e5b9f7 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -34,7 +34,10 @@ constexpr uint32_t kEidNotExist = std::numeric_limits::max(); struct CachedOp::GraphInfo { nnvm::Graph fwd_graph; + nnvm::Graph grad_graph; nnvm::Graph full_graph; + std::vector ograd_entries; + std::unordered_map fwd_input_to_grad_output; std::vector bwd_output_reqs; std::vector bwd_input_eid; }; @@ -45,13 +48,167 @@ struct CachedOp::DynamicRuntime { std::vector op_states; }; +void CreateFullGraph(const nnvm::Symbol& sym, + nnvm::Graph* fwd_graph, + nnvm::Graph* grad_graph, + nnvm::Graph* full_graph, + std::vector* ograd_entries, + std::unordered_map* fwd_input_to_grad_output) { + using namespace nnvm; + static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; + static const auto _copy_op = Op::Get("_copy"); + { + NodeEntryMap dedup_out; + for (const NodeEntry& nodeEntry : sym.outputs) { + if (dedup_out.find(nodeEntry) != dedup_out.end()) { + NodePtr copy_node = Node::Create(); + copy_node->attrs.op = _copy_op; + copy_node->attrs.name = + nodeEntry.node->attrs.name + "_copy" + std::to_string(dedup_out[nodeEntry]++); + copy_node->inputs.emplace_back(nodeEntry); + if (_copy_op->attr_parser != nullptr) { + _copy_op->attr_parser(&(copy_node->attrs)); + } + fwd_graph->outputs.emplace_back(std::move(copy_node)); + } else { + dedup_out.emplace(nodeEntry, 0); + fwd_graph->outputs.push_back(nodeEntry); + } + } + } + + // construct backward graph + { + ograd_entries->reserve(fwd_graph->outputs.size()); + for (size_t i = 0; i < fwd_graph->outputs.size(); ++i) { + nnvm::NodePtr np = Node::Create(); + np->attrs.name = "_head_grad_" + std::to_string(i); + ograd_entries->emplace_back(np); + } + + std::vector xs; + const IndexedGraph& indexed_graph = fwd_graph->indexed_graph(); + for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) { + const uint32_t node_id = indexed_graph.input_nodes()[i]; + if (indexed_graph.mutable_input_nodes().count(node_id)) + continue; + (*fwd_input_to_grad_output)[i] = xs.size(); + xs.emplace_back(indexed_graph[node_id].weak_ref.lock()); + } + + CHECK(!xs.empty()) + << "There are no inputs in computation graph that require gradients."; + + *grad_graph = pass::MXGradient( + *fwd_graph, fwd_graph->outputs, xs, *ograd_entries, + exec::AggregateGradient, nullptr, nullptr, + zero_ops, "_copy"); + } + + // construct full graph + { + full_graph->outputs = fwd_graph->outputs; + for (const auto& i : grad_graph->outputs) full_graph->outputs.emplace_back(i); + } +} + +void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) { + const auto& idx = fwd_graph->indexed_graph(); + CHECK_GE(idx.input_nodes().size(), 1) << "CachedOp requires at least 1 input"; + + std::vector ref_count(idx.num_node_entries(), 0); + for (const auto& i : 
idx.input_nodes()) ++ref_count[idx.entry_id(i, 0)]; + for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; + for (size_t i = 0; i < idx.num_nodes(); ++i) { + for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; + } + + fwd_graph->attrs[AddPrefix(CachedOp::FORWARD, CachedOp::REF_COUNT)] = + std::make_shared(std::move(ref_count)); + + size_t num_forward_nodes = idx.num_nodes(); + size_t num_forward_entries = idx.num_node_entries(); + + const auto& full_idx = full_graph.indexed_graph(); + + std::vector temp_ref_count(full_idx.num_node_entries(), 0); + for (size_t i = num_forward_nodes; i < full_idx.num_nodes(); ++i) { + for (const auto& j : full_idx[i].inputs) { + ++temp_ref_count[full_idx.entry_id(j)]; + } + } + + auto full_ref_count = fwd_graph->GetAttr >(AddPrefix(CachedOp::FORWARD, + CachedOp::REF_COUNT)); + for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += temp_ref_count[i]; + fwd_graph->attrs[AddPrefix(CachedOp::FULL, CachedOp::REF_COUNT)] = + std::make_shared(std::move(full_ref_count)); +} + +void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * fwd_graph, nnvm::Graph * grad_graph, + const Context& context, size_t num_forward_outputs, const bool inlining) { +#if MXNET_USE_CUDA && !defined(_WIN32) + if (context.dev_mask() == kGPU && + !inlining && + dmlc::GetEnv("MXNET_USE_FUSION", true)) { + nnvm::Graph unoptimized_graph; + common::CopyGraph(&unoptimized_graph, *full_graph, false); + + if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) { + full_graph->attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs); + *full_graph = exec::FusePointwiseForward(std::move(*full_graph)); + full_graph->attrs["num_forward_outputs"] = std::make_shared(num_forward_outputs); + *full_graph = exec::FusePointwiseBackward(std::move(*full_graph)); + // Check the topological order of inputs + const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes(); + const auto &new_inputs = full_graph->indexed_graph().input_nodes(); + if (original_inputs.size() != new_inputs.size()) { + LOG(WARNING) + << "Number of inputs after fusion does not match original number of inputs. " + << "This is most probably a bug. 
Disabling fusion for this run."; + *full_graph = unoptimized_graph; + } else { + for (size_t i = 0; i < new_inputs.size(); ++i) { + if (unoptimized_graph.indexed_graph()[original_inputs[i]].source->attrs.name != + full_graph->indexed_graph()[new_inputs[i]].source->attrs.name) { + LOG(WARNING) << "Disabling fusion due to altered topological order of inputs."; + *full_graph = unoptimized_graph; + break; + } + } + } + } else { + LOG(WARNING) + << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!"; + } + } +#endif // MXNET_USE_CUDA + + *fwd_graph = nnvm::Graph(); + fwd_graph->outputs = std::vector(full_graph->outputs.begin(), + full_graph->outputs.begin() + + num_forward_outputs); + *grad_graph = nnvm::Graph(); + grad_graph->outputs = std::vector(full_graph->outputs.begin() + + num_forward_outputs, + full_graph->outputs.end()); + SetRefCounts(fwd_graph, *full_graph); +} + struct CachedOp::CachedOpState { CachedOpState(const Context& context_, const nnvm::Graph& fwd_graph_, - const nnvm::Graph& full_graph_) { + const nnvm::Graph& full_graph_, + const bool inlining_) { context = context_; - info.fwd_graph = fwd_graph_; - info.full_graph = full_graph_; + nnvm::Symbol sym; + sym.outputs = fwd_graph_.outputs; + CreateFullGraph(sym.Copy(), &info.fwd_graph, &info.grad_graph, + &info.full_graph, &info.ograd_entries, + &info.fwd_input_to_grad_output); + + OptimizeGraph(&info.full_graph, &info.fwd_graph, &info.grad_graph, + context_, fwd_graph_.outputs.size(), inlining_); size_t max_nodes = info.full_graph.indexed_graph().num_nodes(); size_t max_entries = info.full_graph.indexed_graph().num_node_entries(); @@ -96,10 +253,6 @@ struct CachedOp::CachedOpState { CachedOp::CachedOp( const nnvm::Symbol& sym, const std::vector >& flags) { - using namespace nnvm; - using namespace imperative; - static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; - static const auto _copy_op = Op::Get("_copy"); config_.Init(flags); this->dynamic_shape_checked_ = false; @@ -107,38 +260,14 @@ CachedOp::CachedOp( CHECK(config_.static_alloc) << "static_alloc must be True when static_shape is True"; } - // construct forward graph + auto grad_graph = nnvm::Graph(); + std::unordered_map fwd_input_to_grad_output; + CreateFullGraph(sym, &fwd_graph_, &grad_graph, &full_graph_, + &ograd_entries_, &fwd_input_to_grad_output); + { - NodeEntryMap dedup_out; - for (const NodeEntry& nodeEntry : sym.outputs) { - if (dedup_out.find(nodeEntry) != dedup_out.end()) { - NodePtr copy_node = Node::Create(); - copy_node->attrs.op = _copy_op; - copy_node->attrs.name = - nodeEntry.node->attrs.name + "_copy" + std::to_string(dedup_out[nodeEntry]++); - copy_node->inputs.emplace_back(nodeEntry); - if (_copy_op->attr_parser != nullptr) { - _copy_op->attr_parser(&(copy_node->attrs)); - } - fwd_graph_.outputs.emplace_back(std::move(copy_node)); - } else { - dedup_out.emplace(nodeEntry, 0); - fwd_graph_.outputs.push_back(nodeEntry); - } - } const auto& idx = fwd_graph_.indexed_graph(); - CHECK_GE(idx.input_nodes().size(), 1) << "CachedOp requires at least 1 input"; - - std::vector ref_count(idx.num_node_entries(), 0); - for (const auto& i : idx.input_nodes()) ++ref_count[idx.entry_id(i, 0)]; - for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; - for (size_t i = 0; i < idx.num_nodes(); ++i) { - for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; - } - - fwd_graph_.attrs["forward_ref_count"] = - std::make_shared(std::move(ref_count)); - + bwd_output_reqs_ = 
std::vector(grad_graph.outputs.size(), kWriteTo); inlining_ = !config_.static_alloc && (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; } @@ -159,53 +288,9 @@ CachedOp::CachedOp( } } - // construct backward graph - { - ograd_entries_.reserve(fwd_graph_.outputs.size()); - for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) - ograd_entries_.emplace_back(Node::Create()); - - std::vector xs; - const IndexedGraph& indexed_graph = fwd_graph_.indexed_graph(); - for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) { - const uint32_t node_id = indexed_graph.input_nodes()[i]; - if (indexed_graph.mutable_input_nodes().count(node_id)) - continue; - fwd_input_to_grad_output_[i] = xs.size(); - xs.emplace_back(indexed_graph[node_id].weak_ref.lock()); - } - - CHECK(!xs.empty()) - << "There are no inputs in computation graph that require gradients."; - - grad_graph_ = pass::MXGradient( - fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_, - exec::AggregateGradient, nullptr, nullptr, - zero_ops, "_copy"); - } - - // construct full graph + // Set the backward dependency vectors { - size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); - size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries(); - - full_graph_.outputs = fwd_graph_.outputs; - bwd_output_reqs_ = std::vector(grad_graph_.outputs.size(), kWriteTo); - for (const auto& i : grad_graph_.outputs) full_graph_.outputs.emplace_back(i); const auto& idx = full_graph_.indexed_graph(); - - std::vector ref_count(idx.num_node_entries(), 0); - for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { - for (const auto& j : idx[i].inputs) { - ++ref_count[idx.entry_id(j)]; - } - } - - auto full_ref_count = fwd_graph_.GetAttr >("forward_ref_count"); - for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += ref_count[i]; - fwd_graph_.attrs["full_ref_count"] = - std::make_shared(std::move(full_ref_count)); - size_t num_forward_inputs = num_inputs(); size_t num_forward_outputs = num_outputs(); for (uint32_t i = 0; i < ograd_entries_.size(); ++i) { @@ -223,6 +308,8 @@ CachedOp::CachedOp( bwd_out_dep_.push_back(i); } } + + SetRefCounts(&fwd_graph_, full_graph_); } CachedOp::~CachedOp() { @@ -411,10 +498,10 @@ bool CachedOp::SetBackwardGraph( info->bwd_output_reqs = reqs; info->bwd_input_eid.clear(); g = nnvm::Graph(); - g.outputs = fwd_graph_.outputs; - for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) { + g.outputs = info->fwd_graph.outputs; + for (size_t i = 0; i < info->grad_graph.outputs.size(); ++i) { if (info->bwd_output_reqs[i] == kNullOp) continue; - g.outputs.emplace_back(grad_graph_.outputs[i]); + g.outputs.emplace_back(info->grad_graph.outputs[i]); } g.attrs["context"] = std::make_shared( std::vector(g.indexed_graph().num_nodes(), default_ctx)); @@ -425,12 +512,12 @@ bool CachedOp::SetBackwardGraph( if (info->bwd_input_eid.size() != inputs.size()) { info->bwd_input_eid.clear(); SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, - ograd_entries_, idx, &info->bwd_input_eid); + info->ograd_entries, idx, &info->bwd_input_eid); CHECK_EQ(inputs.size(), info->bwd_input_eid.size()); } - size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); - size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries(); + size_t num_forward_nodes = info->fwd_graph.indexed_graph().num_nodes(); + size_t num_forward_entries = info->fwd_graph.indexed_graph().num_node_entries(); if (!g.attrs.count("backward_ref_count")) { std::vector 
ref_count(idx.num_node_entries(), 0); @@ -509,7 +596,8 @@ OpStatePtr CachedOp::GetCachedOpState( return i; } } - auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_); + auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_, + inlining_); cached_op_states_[ctx].push_back(state_ptr); return state_ptr; @@ -917,8 +1005,10 @@ OpStatePtr CachedOp::Forward( CHECK_EQ(inputs.size(), num_inputs()); Context default_ctx = inputs[0]->ctx(); + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); - const auto& idx = fwd_graph_.indexed_graph(); + const auto& idx = state.info.fwd_graph.indexed_graph(); for (size_t i = 0; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->ctx(), default_ctx) << "CachedOp requires all inputs to live on the same context. But " @@ -986,9 +1076,9 @@ void CachedOp::DynamicBackward( auto& buff = runtime.buff; auto& states = runtime.op_states; - size_t num_forward_outputs = fwd_graph_.outputs.size(); - size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); - size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries(); + size_t num_forward_outputs = runtime.info.fwd_graph.outputs.size(); + size_t num_forward_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes(); + size_t num_forward_entries = runtime.info.fwd_graph.indexed_graph().num_node_entries(); buff.resize(idx.num_node_entries()); std::vector arrays; arrays.reserve(buff.size()); @@ -1084,9 +1174,9 @@ void CachedOp::StaticBackward( if (config_.static_shape) { for (auto i : config_.param_indices) { - const auto iter = fwd_input_to_grad_output_.find(i); - if (iter == fwd_input_to_grad_output_.end()) continue; - auto entry = grad_graph_.outputs[iter->second]; + const auto iter = state.info.fwd_input_to_grad_output.find(i); + if (iter == state.info.fwd_input_to_grad_output.end()) continue; + auto entry = state.info.grad_graph.outputs[iter->second]; if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); if (!arrays[eid]->IsSame(*outputs[iter->second]) || @@ -1101,9 +1191,9 @@ void CachedOp::StaticBackward( } } for (auto i : config_.data_indices) { - const auto iter = fwd_input_to_grad_output_.find(i); - if (iter == fwd_input_to_grad_output_.end()) continue; - auto entry = grad_graph_.outputs[iter->second]; + const auto iter = state.info.fwd_input_to_grad_output.find(i); + if (iter == state.info.fwd_input_to_grad_output.end()) continue; + auto entry = state.info.grad_graph.outputs[iter->second]; if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); state.array_reqs[eid] = reqs[iter->second]; @@ -1113,8 +1203,8 @@ void CachedOp::StaticBackward( arrays[eid] = outputs[iter->second]; } } else { - for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) { - auto entry = grad_graph_.outputs[i]; + for (size_t i = 0; i < state.info.grad_graph.outputs.size(); ++i) { + auto entry = state.info.grad_graph.outputs[i]; if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); state.array_reqs[eid] = reqs[i]; diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index db049d59ed80..0b8bc3528543 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -198,13 +198,11 @@ class CachedOp { CachedOpConfig config_; nnvm::Graph fwd_graph_; - nnvm::Graph grad_graph_; nnvm::Graph full_graph_; bool inlining_; bool dynamic_shape_checked_; std::vector ograd_entries_; std::vector bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_; - std::unordered_map fwd_input_to_grad_output_; std::vector 
save_inputs_, save_outputs_; std::vector bwd_output_reqs_; diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index b2ffd1096b2b..5c2a2f9585a1 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -305,7 +305,9 @@ std::vector Imperative::Backward( std::vector ograd_entries; ograd_entries.reserve(ograds.size()); for (size_t i = 0; i < outputs.size(); ++i) { - ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0}); + nnvm::NodePtr np = Node::Create(); + np->attrs.name = "_head_grad_" + std::to_string(i); + ograd_entries.emplace_back(NodeEntry{np, 0, 0}); AGInfo& info = AGInfo::Create(ograd_entries.back().node); info.ctx = outputs[i]->ctx(); if (ograds[i] != nullptr) { diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h new file mode 100644 index 000000000000..3085bfd1dc07 --- /dev/null +++ b/src/operator/fusion/fused_op-inl.h @@ -0,0 +1,999 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_ +#define MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_ + +#include +#include +#include + +#if MXNET_USE_CUDA + +namespace mxnet { + +namespace fusion { + +const char fp16_support_string[] = R"code( +struct __align__(2) __half { + __host__ __device__ __half() { } + unsigned short __x; +}; +/* Definitions of intrinsics */ +__device__ inline __half __float2half(const float f) { + __half val; + asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(val.__x) : "f"(f)); + return val; +} +__device__ inline float __half2float(const __half h) { + float val; + asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(h.__x)); + return val; +} + +typedef __half half; +)code"; + +const char type_support_string[] = R"code( +using float32 = float; +using float64 = double; +using float16 = half; +using uint8 = unsigned char; +using int8 = char; +using int32 = int; +using int64 = long long; +)code"; + +const std::map>> ops_desc = { + {"elemwise_add" , {{"op::add(%, %)", "_0", "_1"}}}, + {"_plus" , {{"op::add(%, %)", "_0", "_1"}}}, + {"_Plus" , {{"op::add(%, %)", "_0", "_1"}}}, + {"_add" , {{"op::add(%, %)", "_0", "_1"}}}, + {"elemwise_sub" , {{"op::sub(%, %)", "_0", "_1"}}}, + {"_minus" , {{"op::sub(%, %)", "_0", "_1"}}}, + {"_Minus" , {{"op::sub(%, %)", "_0", "_1"}}}, + {"_sub" , {{"op::sub(%, %)", "_0", "_1"}}}, + {"elemwise_mul" , {{"op::mul(%, %)", "_0", "_1"}}}, + {"_mul" , {{"op::mul(%, %)", "_0", "_1"}}}, + {"_Mul" , {{"op::mul(%, %)", "_0", "_1"}}}, + {"elemwise_div" , {{"op::div(%, %)", "_0", "_1"}}}, + {"_div" , {{"op::div(%, %)", "_0", "_1"}}}, + {"_Div" , {{"op::div(%, %)", "_0", "_1"}}}, + {"_Power" , {{"op::power(%, %)", "_0", "_1"}}}, + {"_power" , {{"op::power(%, %)", "_0", "_1"}}}, + {"_Maximum" , {{"op::max(%, %)", "_0", "_1"}}}, + {"_maximum" 
, {{"op::max(%, %)", "_0", "_1"}}}, + {"_Minimum" , {{"op::min(%, %)", "_0", "_1"}}}, + {"_minimum" , {{"op::min(%, %)", "_0", "_1"}}}, + {"amp_cast" , {{"op::identity(%)", "_0"}}}, + {"_backward_amp_cast" , {{"op::identity(%)", "_0"}}}, + {"relu" , {{"op::relu(%)", "_0"}}}, + {"sigmoid" , {{"op::sigmoid(%)", "_0"}}}, + {"softsign" , {{"op::softsign(%)", "_0"}}}, + {"exp" , {{"op::exp(%)", "_0"}}}, + {"expm1" , {{"op::expm1(%)", "_0"}}}, + {"log" , {{"op::log(%)", "_0"}}}, + {"log10" , {{"op::log10(%)", "_0"}}}, + {"log2" , {{"op::log2(%)", "_0"}}}, + {"log1p" , {{"op::log1p(%)", "_0"}}}, + {"degrees" , {{"op::degrees(%)", "_0"}}}, + {"radians" , {{"op::radians(%)", "_0"}}}, + {"sin" , {{"op::sin(%)", "_0"}}}, + {"cos" , {{"op::cos(%)", "_0"}}}, + {"tan" , {{"op::tan(%)", "_0"}}}, + {"arcsin" , {{"op::arcsin(%)", "_0"}}}, + {"arccos" , {{"op::arccos(%)", "_0"}}}, + {"arctan" , {{"op::arctan(%)", "_0"}}}, + {"sinh" , {{"op::sinh(%)", "_0"}}}, + {"cosh" , {{"op::cosh(%)", "_0"}}}, + {"tanh" , {{"op::tanh(%)", "_0"}}}, + {"arcsinh" , {{"op::arcsinh(%)", "_0"}}}, + {"arccosh" , {{"op::arccosh(%)", "_0"}}}, + {"arctanh" , {{"op::arctanh(%)", "_0"}}}, + {"sqrt" , {{"op::sqrt(%)", "_0"}}}, + {"rsqrt" , {{"op::rsqrt(%)", "_0"}}}, + {"cbrt" , {{"op::cbrt(%)", "_0"}}}, + {"rcbrt" , {{"op::rcbrt(%)", "_0"}}}, + {"square" , {{"op::square(%)", "_0"}}}, + {"squeeze" , {{"op::identity(%)", "_0"}}}, + {"zeros_like" , {{"op::zero(%)", "_0"}}}, + {"ones_like" , {{"op::one(%)", "_0"}}}, + {"flatten" , {{"op::identity(%)", "_0"}}}, + {"Reshape" , {{"op::identity(%)", "_0"}}}, + {"reshape" , {{"op::identity(%)", "_0"}}}, + {"_backward_reshape" , {{"op::identity(%)", "_0"}}}, + {"expand_dims" , {{"op::identity(%)", "_0"}}}, + {"round" , {{"op::round(%)", "_0"}}}, + {"rint" , {{"op::rint(%)", "_0"}}}, + {"fix" , {{"op::fix(%)", "_0"}}}, + {"floor" , {{"op::floor(%)", "_0"}}}, + {"ceil" , {{"op::ceil(%)", "_0"}}}, + {"trunc" , {{"op::trunc(%)", "_0"}}}, + {"sign" , {{"op::sign(%)", "_0"}}}, + {"reciprocal" , {{"op::reciprocal(%)", "_0"}}}, + {"abs" , {{"op::abs(%)", "_0"}}}, + {"gamma" , {{"op::gamma(%)", "_0"}}}, + {"gammaln" , {{"op::gammaln(%)", "_0"}}}, + {"erf" , {{"op::erf(%)", "_0"}}}, + {"erfinv" , {{"op::erfinv(%)", "_0"}}}, + {"_copy" , {{"op::identity(%)", "_0"}}}, + {"_identity_with_attr_like_rhs" , {{"op::identity(%)", "_0"}}}, + {"_plus_scalar" , {{"op::add(%, float(%))", "_0", "scalar"}}}, + {"_PlusScalar" , {{"op::add(%, float(%))", "_0", "scalar"}}}, + {"_minus_scalar" , {{"op::sub(%, float(%))", "_0", "scalar"}}}, + {"_MinusScalar" , {{"op::sub(%, float(%))", "_0", "scalar"}}}, + {"_rminus_scalar" , {{"(-op::sub(%, float(%)))", "_0", "scalar"}}}, + {"_RMinusScalar" , {{"(-op::sub(%, float(%)))", "_0", "scalar"}}}, + {"_mul_scalar" , {{"op::mul(%, float(%))", "_0", "scalar"}}}, + {"_MulScalar" , {{"op::mul(%, float(%))", "_0", "scalar"}}}, + {"_div_scalar" , {{"op::mul(%, 1.0f/float(%))", "_0", "scalar"}}}, + {"_DivScalar" , {{"op::mul(%, 1.0f/float(%))", "_0", "scalar"}}}, + {"_rdiv_scalar" , {{"op::rdiv(%, float(%))", "_0", "scalar"}}}, + {"_power_scalar" , {{"op::power(%, float(%))", "_0", "scalar"}}}, + {"_PowerScalar" , {{"op::power(%, float(%))", "_0", "scalar"}}}, + {"_rpower_scalar" , {{"op::rpow(%, float(%))", "_0", "scalar"}}}, + {"_RPowerScalar" , {{"op::rpow(%, float(%))", "_0", "scalar"}}}, + {"_RDivScalar" , {{"op::rdiv(%, float(%))", "_0", "scalar"}}}, + {"Cast" , {{"op::cast<%>(%)", "dtype", "_0"}}}, + {"cast" , {{"op::cast<%>(%)", "dtype", "_0"}}}, + {"Activation" , {{"op::%(%)", 
"act_type", "_0"}}}, + {"clip" , {{"op::clip(%, %, %)", "_0", "a_min", "a_max"}}}, + {"_zeros" , {{"op::zero<%>()", "dtype"}}}, + {"_ones" , {{"op::one<%>()", "dtype"}}}, + {"negative" , {{"(-%)", "_0"}}}, + {"_hypot" , {{"op::hypot(%, %)", "_0", "_1"}}}, + {"_hypot_scalar" , {{"op::hypot(%, float(%))", "_0", "scalar"}}}, + {"_backward_relu" , {{"op::backward_relu(%, %)", "_1", "_0"}}}, + {"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_1", "_0"}}}, + {"_backward_expm1" , {{"op::backward_expm1(%, %)", "_1", "_0"}}}, + {"_backward_log" , {{"op::backward_log(%, %)", "_1", "_0"}}}, + {"_backward_log10" , {{"op::backward_log10(%, %)", "_1", "_0"}}}, + {"_backward_log2" , {{"op::backward_log2(%, %)", "_1", "_0"}}}, + {"_backward_log1p" , {{"op::backward_log1p(%, %)", "_1", "_0"}}}, + {"_backward_sin" , {{"op::backward_sin(%, %)", "_1", "_0"}}}, + {"_backward_cos" , {{"op::backward_cos(%, %)", "_1", "_0"}}}, + {"_backward_tan" , {{"op::backward_tan(%, %)", "_1", "_0"}}}, + {"_backward_arcsin" , {{"op::backward_arcsin(%, %)", "_1", "_0"}}}, + {"_backward_arccos" , {{"op::backward_arccos(%, %)", "_1", "_0"}}}, + {"_backward_arctan" , {{"op::backward_arctan(%, %)", "_1", "_0"}}}, + {"_backward_sinh" , {{"op::backward_sinh(%, %)", "_1", "_0"}}}, + {"_backward_cosh" , {{"op::backward_cosh(%, %)", "_1", "_0"}}}, + {"_backward_tanh" , {{"op::backward_tanh(%, %)", "_1", "_0"}}}, + {"_backward_arcsinh" , {{"op::backward_arcsinh(%, %)", "_1", "_0"}}}, + {"_backward_arccosh" , {{"op::backward_arccosh(%, %)", "_1", "_0"}}}, + {"_backward_arctanh" , {{"op::backward_arctanh(%, %)", "_1", "_0"}}}, + {"_backward_sqrt" , {{"op::backward_sqrt(%, %)", "_1", "_0"}}}, + {"_backward_rsqrt" , {{"op::backward_rsqrt(%, %)", "_1", "_0"}}}, + {"_backward_cbrt" , {{"op::backward_cbrt(%, %)", "_1", "_0"}}}, + {"_backward_rcbrt" , {{"op::backward_rcbrt(%, %)", "_1", "_0"}}}, + {"_backward_square" , {{"op::backward_square(%, %)", "_1", "_0"}}}, + {"_backward_div_scalar" , {{"(% * 1.0f/float(%))", "_0", "scalar"}}}, + {"_backward_div_scalar" , {{"(% * 1.0f/float(%))", "_0", "scalar"}}}, + {"_backward_rdiv_scalar" , {{"(-% * float(%) / (% * %))", "_0", + "scalar", "_1", "_1"}}}, + {"_backward_hypot_scalar" , {{"(% * % / op::hypot(%, float(%)))", + "_0", "_1", "_1", "scalar"}}}, + {"_backward_radians" , {{"op::radians(%)", "_0"}}}, + {"_backward_erf" , {{"op::backward_erf(%, %)", "_1", "_0"}}}, + {"_backward_erfinv" , {{"op::backward_erfinv(%, %)", "_1", "_0"}}}, + {"_backward_reciprocal" , {{"op::backward_reciprocal(%, %)", "_1", "_0"}}}, + {"_backward_abs" , {{"(% * op::sign(%))", "_0", "_1"}}}, + {"_backward_degrees" , {{"op::degrees(%)", "_0"}}}, + {"_backward_sign" , {{"op::zero(%)", "_0"}}}, + {"_backward_clip" , {{"op::backward_clip(%, %, %, %)", "_1", "_0", + "a_min", "a_max"}}}, + {"smooth_l1" , {{"op::smooth_l1(%, float(%))", "_0", "scalar"}}}, + {"_backward_smooth_l1" , {{"op::backward_smooth_l1(%, float(%), %)", + "_1", "scalar", "_0"}}}, + // TODO(ptredak): arange + // TODO(ptredak): LeakyRelu + // TODO(ptredak): mod and rmod + {"_backward_sub" , {{"(%)", "_0"}, + {"(-(%))", "_0"}}}, + {"_backward_mul" , {{"(% * %)", "_0", "_2"}, + {"(% * %)", "_0", "_1"}}}, + {"_backward_mul_scalar" , {{"(% * float(%))", "_0", "scalar"}}}, + {"_backward_div" , {{"(% / %)", "_0", "_2"}, + {"(-% * % / (% * %))", "_0", "_1", "_2", "_2"}}}, + {"_backward_power" , {{"(% * % * powf(%, % - 1))", "_0", "_2", "_1", "_2"}, + {"(% * powf(%, %) * logf(%))", "_0", "_1", "_2", "_1"}}}, + {"_backward_power_scalar" , {{"(% * float(%) * 
powf(%, float(%) - 1))", + "_0", "scalar", "_1", "scalar"}}}, + {"_backward_rpower_scalar" , {{"(% * % * logf(float(%)))", "_0", "_1", "scalar"}}}, + {"_backward_maximum" , {{"((% >= %) ? % : 0)", "_1", "_2", "_0"}, + {"((% >= %) ? 0 : %)", "_1", "_2", "_0"}}}, + {"_backward_minimum" , {{"((% <= %) ? % : 0)", "_1", "_2", "_0"}, + {"((% <= %) ? 0 : %)", "_1", "_2", "_0"}}}, + {"_backward_hypot" , {{"(% * % / op::hypot(%, %))", "_0", "_1", "_1", "_2"}, + {"(% * % / op::hypot(%, %))", "_0", "_2", "_1", "_2"}}} +}; + +const std::map slice_ops = { + {"slice_axis" , ""}, + {"slice" , ""}, + {"slice_like" , ""}, + {"broadcast_like" , ""}, +}; + +const std::vector variable_io_ops = { + "add_n", + "_backward_Activation", + "amp_multicast", + "_backward_amp_multicast", + "_backward_cast" +}; + +const char function_definitions[] = R"code( + +#define INT_MAX (2147483647) + +namespace op { + +template +struct LoadType { + using Type = DType; +}; + +template <> +struct LoadType { + using Type = float; +}; + +template +inline typename LoadType::Type load(const DType input) { + return input; +} + +template <> +inline float load(const half input) { + return __half2float(input); +} + +template +inline DType1 store(const DType2 input, DType1* ref) { + return input; +} + +template +inline half store(const DType input, half* ref) { + return __float2half(input); +} + +template +struct VectorConfig { + static_assert(size >= 4, "VectorConfig needs to have size of at least 4B"); + using IndexType = float; +}; + +template <> +struct VectorConfig<8> { + using IndexType = double; +}; + +template <> +struct VectorConfig<16> { + using IndexType = double2; +}; + +template <> +struct VectorConfig<32> { + using IndexType = double4; +}; + +template +inline DType add_elem(const DType& x, const DType& y) { + return x + y; +} + +template <> +inline half add_elem(const half& x, const half& y) { + return __float2half(__half2float(x) + __half2float(y)); +} + +template +union VectorType { + typename VectorConfig::IndexType y; + DType x[nvec]; + VectorType () {}; + VectorType (const VectorType& y2) { + y = y2.y; + } + VectorType (const decltype(y) &y2) { + y = y2; + } + inline VectorType& operator+=(const VectorType& rhs) { + #pragma unroll + for (int i = 0; i < nvec; ++i) { + x[i] = add_elem(x[i], rhs.x[i]); + } + return *this; + } +}; + +template +struct Shape { + int x[ndim]; + size_t size; + inline const int& operator [](const int i) const { + return x[i]; + } + inline int& operator [](const int i) { + return x[i]; + } + inline void set(const int def) { + #pragma unroll + for (int i = 0; i < ndim; i++) { + x[i] = def; + } + } +}; + +template <> +struct Shape<0> { + size_t size; +}; + +template +inline VectorType load_index(const DType * input, int i, const Shape &shape) { + if (i < shape.size) { + const auto* vector_input = reinterpret_cast< + const typename VectorConfig::IndexType *>( + input + i); + VectorType ret = {*vector_input}; + return ret; + } else { + VectorType ret({0}); + return ret; + } +} + +template +inline VectorType global_load_index(const DType * input, int i, const Shape &shape) { + if (i < shape.size) { + const auto* vector_input = reinterpret_cast< + const typename VectorConfig::IndexType *>( + input + i); + VectorType ret = {__ldg(vector_input)}; + return ret; + } else { + VectorType ret({0}); + return ret; + } +} + +template +inline VectorType load_slice(const DType * input, const Shape& shape, Shape begin, Shape end, int offset) { + int idx[nvec]; + + Shape ref_strides; + Shape strides; + 
ref_strides[ndim-1] = 1; + strides[ndim-1] = 1; + #pragma unroll + for (int dim = ndim-1; dim >=0; dim--) { + if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim]; + if (end[dim] < 0) end[dim] = shape[dim] - end[dim]; + if (end[dim] == INT_MAX) end[dim] = shape[dim]; + if (dim > 0) { + ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); + strides[dim-1] = strides[dim] * shape[dim]; + } + } + #pragma unroll + for (int j = 0; j < nvec; j++) { + idx[j] = 0; + int ref_idx = offset + j; + #pragma unroll + for (int dim = 0; dim < ndim; dim++) { + int stride = ref_strides[dim]; + if (shape[dim] > 1) { + idx[j] += (ref_idx / stride + begin[dim]) * strides[dim]; + } + ref_idx = ref_idx % stride; + } + } + VectorType ret; + #pragma unroll + for (int j = 0; j < nvec; j++) { + ret.x[j] = *(input + idx[j]); + } + return ret; +} + +template +inline VectorType fast_load_slice(const DType * input, const Shape& shape, Shape begin, Shape end, int offset) { + int idx = 0; + + Shape ref_strides; + Shape strides; + ref_strides[ndim-1] = 1; + strides[ndim-1] = 1; + #pragma unroll + for (int dim = ndim-1; dim >=0; dim--) { + if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim]; + if (end[dim] < 0) end[dim] = shape[dim] - end[dim]; + if (end[dim] == INT_MAX) end[dim] = shape[dim]; + if (dim > 0) { + ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); + strides[dim-1] = strides[dim] * shape[dim]; + } + } + int ref_idx = offset; + #pragma unroll + for (int dim = 0; dim < ndim; dim++) { + int stride = ref_strides[dim]; + if (shape[dim] > 1) { + idx += (ref_idx / stride + begin[dim]) * strides[dim]; + } + ref_idx = ref_idx % stride; + } + return global_load_index(input, idx, shape); +} + +template +inline void store_index(const VectorType value, int i, + DType * output, const Shape& shape) { + if (i < (shape.size + nvec - 1) / nvec) { + auto vector_output = reinterpret_cast< + typename VectorConfig::IndexType *>(output); + vector_output[i] = value.y; + } +} + +template +inline void store_add_index(const VectorType value, int i, + DType * output, const Shape& shape) { + if (i < (shape.size + nvec - 1) / nvec) { + auto vector_output = reinterpret_cast< + typename VectorConfig::IndexType *>(output); + VectorType ret(vector_output[i]); + ret += value; + vector_output[i] = ret.y; + } +} + +template +inline DType identity(const DType val) { + return val; +} + +template +inline DType add(const DType a, const DType2 b) { + return a + b; +} + +template +inline DType sub(const DType a, const DType2 b) { + return a - b; +} + +template +inline DType mul(const DType a, const DType2 b) { + return a * b; +} + +template +inline DType div(const DType a, const DType2 b) { + return a / b; +} + +template +inline DType rdiv(const DType a, const DType2 b) { + return b / a; +} + +template +inline DType power(const DType a, const DType2 b) { + return powf(a, b); +} + +template +inline DType rpow(const DType a, const DType2 b) { + return powf(b, a); +} + +template +inline DType max(const DType a, const DType2 b) { + return a > b ? a : b; +} + +template +inline DType min(const DType a, const DType2 b) { + return a < b ? a : b; +} + +template +inline DType hypot(const DType a, const DType2 b) { + return hypotf(a, b); +} + +template +inline typename LoadType::Type cast(const DType val) { + return static_cast::Type>(val); +} + +// activations + +template +inline DType relu(const DType val) { + return val > 0 ? 
val : 0; +} + +template +inline DType sigmoid(const DType val) { + return 1.f/(1 + expf(-val)); +} + +template +inline DType softrelu(const DType val) { + return logf(1 + expf(val)); +} + +template +inline DType softsign(const DType val) { + return val / (1 + fabsf(val)); +} + +// exp and log + +template +inline DType exp(const DType val) { + return expf(val); +} + +template +inline DType expm1(const DType val) { + return expm1f(val); +} + +template +inline DType log(const DType val) { + return logf(val); +} + +template +inline DType log10(const DType val) { + return log10f(val); +} + +template +inline DType log2(const DType val) { + return log2f(val); +} + +template +inline DType log1p(const DType val) { + return log1pf(val); +} + +// trigonometric + +constexpr double pi = 3.14159265358979323846; + +template +inline DType degrees(const DType val) { + return (val / pi) * 180; +} + +template +inline DType radians(const DType val) { + return (val / 180.0) * pi; +} + +template +inline DType sin(const DType val) { + return sinf(val); +} + +template +inline DType cos(const DType val) { + return cosf(val); +} + +template +inline DType tan(const DType val) { + return tanf(val); +} + +template +inline DType arcsin(const DType val) { + return asinf(val); +} + +template +inline DType arccos(const DType val) { + return acosf(val); +} + +template +inline DType arctan(const DType val) { + return atanf(val); +} + +template +inline DType sinh(const DType val) { + return sinhf(val); +} + +template +inline DType cosh(const DType val) { + return coshf(val); +} + +template +inline DType tanh(const DType val) { + return tanhf(val); +} + +template +inline DType arcsinh(const DType val) { + return asinhf(val); +} + +template +inline DType arccosh(const DType val) { + return acoshf(val); +} + +template +inline DType arctanh(const DType val) { + return atanhf(val); +} + +// sqrt + +template +inline DType sqrt(const DType val) { + return sqrtf(val); +} + +template +inline DType rsqrt(const DType val) { + return rsqrtf(val); +} + +template +inline DType cbrt(const DType val) { + return cbrtf(val); +} + +template +inline DType rcbrt(const DType val) { + return rcbrtf(val); +} + +template +inline DType square(const DType val) { + return val * val; +} + +template +inline typename LoadType::Type zero(const DType val) { + return 0; +} + +template +inline typename LoadType::Type zero() { + return 0; +} + +template +inline typename LoadType::Type one(const DType val) { + return 1; +} + +template +inline typename LoadType::Type one() { + return 1; +} + +template +inline DType round(const DType val) { + return roundf(val); +} + +template +inline DType rint(const DType val) { + return rintf(val); +} + +template +inline DType fix(const DType val) { + const auto floor = floorf(val); + const auto ceil = ceilf(val); + return (floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil; +} + +template +inline DType floor(const DType val) { + return floorf(val); +} + +template +inline DType ceil(const DType val) { + return ceilf(val); +} + +template +inline DType trunc(const DType val) { + return truncf(val); +} + +template +inline DType clip(const DType val, const float a_min, const float a_max) { + return max(min(val, a_max), a_min); +} + +template +inline DType sign(const DType val) { + if (val < 0) return -1; + return val > 0 ? 
1 : 0; +} + +template +inline DType reciprocal(const DType val) { + return 1.0f / val; +} + +template +inline DType abs(const DType val) { + return fabsf(val); +} + +template +inline DType gamma(const DType val) { + return tgammaf(val); +} + +template +inline DType gammaln(const DType val) { + return lgammaf(val); +} + +template +inline DType erf(const DType val) { + return erff(val); +} + +template +inline DType erfinv(const DType val) { + return erfinvf(val); +} + +template +inline DType1 smooth_l1(const DType1 val, const DType2 scalar) { + const auto bsq = scalar * scalar; + const auto ibsq = 1.0f / bsq; + if (val > ibsq) { + return val - 0.5f * ibsq; + } else if (val < -ibsq) { + return -val - 0.5f * ibsq; + } else { + return 0.5f * val * val * bsq; + } +} + +} // namespace op + +)code"; + +const char backward_function_definitions[] = R"code( + +namespace op { + +template +inline DTypeGrad backward_relu(const DType val, const DTypeGrad grad) { + return val > 0 ? grad : 0; +} + +template +inline DTypeGrad backward_sigmoid(const DType out, const DTypeGrad grad) { + return grad * out * (1 - out); +} + +template +inline DTypeGrad backward_softrelu(const DType val, const DTypeGrad grad) { + return grad * sigmoid(val); +} + +template +inline DTypeGrad backward_softsign(const DType val, const DTypeGrad grad) { + const DType ap1 = 1 + fabsf(val); + return grad / (ap1 * ap1); +} + +template +inline DTypeGrad backward_exp(const DType val, const DTypeGrad grad) { + return grad * expf(val); +} + +template +inline DTypeGrad backward_expm1(const DType val, const DTypeGrad grad) { + return grad * expf(val); +} + +template +inline DTypeGrad backward_log(const DType val, const DTypeGrad grad) { + return grad / val; +} + +template +inline DTypeGrad backward_log10(const DType val, const DTypeGrad grad) { + return grad / (val * logf(10)); +} + +template +inline DTypeGrad backward_log2(const DType val, const DTypeGrad grad) { + return grad / (val * logf(2)); +} + +template +inline DTypeGrad backward_log1p(const DType val, const DTypeGrad grad) { + return grad / (1 + val); +} + +template +inline DTypeGrad backward_sin(const DType val, const DTypeGrad grad) { + return grad * cosf(val); +} + +template +inline DTypeGrad backward_cos(const DType val, const DTypeGrad grad) { + return -grad * sinf(val); +} + +// Uses output from tan +template +inline DTypeGrad backward_tan(const DType out, const DTypeGrad grad) { + return grad * (out * out + 1); +} + +template +inline DTypeGrad backward_arcsin(const DType val, const DTypeGrad grad) { + return grad / sqrtf(1 - val*val); +} + +template +inline DTypeGrad backward_arccos(const DType val, const DTypeGrad grad) { + return -grad / sqrtf(1 - val*val); +} + +template +inline DTypeGrad backward_arctan(const DType val, const DTypeGrad grad) { + return grad / (1 + val*val); +} + +template +inline DTypeGrad backward_sinh(const DType val, const DTypeGrad grad) { + return grad * coshf(val); +} + +template +inline DTypeGrad backward_cosh(const DType val, const DTypeGrad grad) { + return grad * sinhf(val); +} + +// Uses tanh output +template +inline DTypeGrad backward_tanh(const DType out, const DTypeGrad grad) { + return grad * (1 - out * out); +} + +template +inline DTypeGrad backward_arcsinh(const DType val, const DTypeGrad grad) { + return grad / sqrtf(val * val + 1); +} + +template +inline DTypeGrad backward_arccosh(const DType val, const DTypeGrad grad) { + return grad / sqrtf(val * val - 1); +} + +template +inline DTypeGrad backward_arctanh(const DType val, const DTypeGrad 
grad) { + return grad / (1 - val * val); +} + +template +inline DTypeGrad backward_sqrt(const DType out, const DTypeGrad grad) { + return 0.5 * grad / out; +} + +template +inline DTypeGrad backward_rsqrt(const DType val, const DTypeGrad grad) { + const DType inv = 1 / val; + return -0.5 * grad * sqrtf(inv) * inv; +} + +template +inline DTypeGrad backward_cbrt(const DType out, const DTypeGrad grad) { + return grad / (3.0f * out * out); +} + +template +inline DTypeGrad backward_rcbrt(const DType val, const DTypeGrad grad) { + const DType inv = 1 / val; + return -1.f/3.f * grad * cbrtf(inv) * inv; +} + +template +inline DTypeGrad backward_square(const DType val, const DTypeGrad grad) { + return 2 * val * grad; +} + +template +inline DTypeGrad backward_clip(const DType val, const DTypeGrad grad, const float a_min, const float a_max) { + if (val > a_max || val < a_min) { + return 0; + } else { + return grad; + } +} + +template +inline DTypeGrad backward_reciprocal(const DType val, const DTypeGrad grad) { + return -grad / (val * val); +} + +template +inline DTypeGrad backward_erf(const DType val, const DTypeGrad grad) { + return 2.0f / sqrt(pi) * exp(-(val*val)) * grad; +} + +template +inline DTypeGrad backward_erfinv(const DType val, const DTypeGrad grad) { + return 0.5f * sqrt(pi) * exp(val * val) * grad; +} + +template +inline DTypeGrad backward_smooth_l1(const DType val, const DType2 scalar, const DTypeGrad grad) { + auto bsq = scalar * scalar; + auto ibsq = 1.0f / bsq; + if (val > ibsq) { + return grad; + } else if (val < -ibsq) { + return -grad; + } else { + return bsq * val * grad; + } +} + +} // namespace op + +)code"; + +const char kernel_begin[] = R"code( +const int tid = threadIdx.x + blockIdx.x * blockDim.x; +for (int i = tid; i < N; i+= gridDim.x * blockDim.x) { + int offset = i*nvec; + +)code"; + +const char kernel_end[] = R"code( +} +} +)code"; + +} // namespace fusion + +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_ diff --git a/src/operator/fusion/fused_op.cc b/src/operator/fusion/fused_op.cc new file mode 100644 index 000000000000..071215b840a5 --- /dev/null +++ b/src/operator/fusion/fused_op.cc @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#include "./fused_op.h" +#include "../operator_common.h" +#include "../../executor/exec_pass.h" + +#if MXNET_USE_CUDA + +namespace mxnet { + +DMLC_REGISTER_PARAMETER(FusedOpConfig); + +std::mutex FusedOp::mutex_; + +void FusedOpParamParser(nnvm::NodeAttrs* attrs) { + FusedOpConfig param; + try { + param.Init(attrs->dict); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + attrs->parsed = FusedOpPtr(new FusedOp(attrs, param)); +} + +FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) { + this->inputs_ = std::vector(config.num_inputs); + this->outputs_ = std::vector(config.num_outputs); + this->subgraph_ = nnvm::Graph(); + this->subgraph_.outputs = attrs->subgraphs[0]->outputs; + this->initialized_ = false; + this->cc_major_ = -1; + this->cc_minor_ = -1; +} + +bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + this->subgraph_.attrs.erase("shape"); + this->subgraph_.attrs.erase("shape_inputs"); + std::vector input_shapes(*in_attrs); + this->subgraph_ = mxnet::exec::InferShape(std::move(this->subgraph_), + std::move(input_shapes), + "__shape__"); + + const auto& g = this->subgraph_.indexed_graph(); + const auto& input_nids = g.input_nodes(); + + std::vector out_shapes; + const std::vector shapes = this->subgraph_.GetAttr("shape"); + for (auto& e : g.outputs()) { + out_shapes.push_back(shapes[g.entry_id(e)]); + } + CHECK_EQ(out_shapes.size(), out_attrs->size()); + for (size_t i = 0; i < out_attrs->size(); ++i) { + op::shape_assign(&(out_attrs->at(i)), out_shapes[i]); + } + + // assign to in_attrs + for (size_t i = 0; i < in_attrs->size(); ++i) { + const auto eid = g.entry_id(input_nids[i], 0); + SHAPE_ASSIGN_CHECK(*in_attrs, i, shapes[eid]); + } + + bool inferred = true; + for (const auto& attr : *in_attrs) { + inferred = inferred && !op::shape_is_none(attr); + } + for (const auto& attr : *out_attrs) { + inferred = inferred && !op::shape_is_none(attr); + } + if (inferred) { + std::lock_guard lock(my_mutex_); + intermediate_shapes_.push_back({*in_attrs, *out_attrs, shapes}); + } + return inferred; +} + +bool FusedOp::InferType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + this->subgraph_.attrs.erase("dtype"); + this->subgraph_.attrs.erase("dtype_inputs"); + std::vector input_types(*in_attrs); + this->subgraph_ = mxnet::exec::InferType(std::move(this->subgraph_), + std::move(input_types), + "__dtype__"); + + const auto& g = this->subgraph_.indexed_graph(); + const auto& input_nids = g.input_nodes(); + + std::vector out_types; + const std::vector types = this->subgraph_.GetAttr("dtype"); + for (auto& e : g.outputs()) { + out_types.push_back(types[g.entry_id(e)]); + } + CHECK_EQ(out_types.size(), out_attrs->size()); + for (size_t i = 0; i < out_attrs->size(); ++i) { + op::type_assign(&(out_attrs->at(i)), out_types[i]); + } + + // assign to in_attrs + for (size_t i = 0; i < in_attrs->size(); ++i) { + const auto eid = g.entry_id(input_nids[i], 0); + TYPE_ASSIGN_CHECK(*in_attrs, i, types[eid]); + } + + bool inferred = true; + for (const auto& attr : *in_attrs) { + inferred = inferred && !op::type_is_none(attr); + } + for (const auto& attr : *out_attrs) { + inferred = inferred && 
!op::type_is_none(attr); + } + if (inferred) { + std::lock_guard lock(my_mutex_); + intermediate_dtypes_.push_back({*in_attrs, *out_attrs, types}); + } + return inferred; +} + +template +std::tuple, + std::vector> + FusedOp::GetAttrs(const std::string& attr_name, + const uint32_t node_id) { + const auto& g = this->subgraph_.indexed_graph(); + const std::vector attrs = this->subgraph_.GetAttr>(attr_name); + const auto& node = g[node_id]; + std::vector inputs, outputs; + for (const auto& e : node.inputs) { + inputs.emplace_back(attrs[g.entry_id(e)]); + } + outputs.resize(node.source->num_outputs()); + for (size_t i = 0; i < g.num_nodes(); ++i) { + if (i == node_id) continue; + const auto& other_node = g[i]; + for (const auto& e : other_node.inputs) { + if (e.node_id == node_id) { + outputs[e.index] = attrs[g.entry_id(e)]; + } + } + } + for (const auto& e : g.outputs()) { + if (e.node_id == node_id) { + outputs[e.index] = attrs[g.entry_id(e)]; + } + } + + return std::make_tuple(node.weak_ref.lock(), + inputs, + outputs); +} + +bool FusedOpInferShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const FusedOpPtr& op = nnvm::get(attrs.parsed); + return op->InferShape(attrs, in_attrs, out_attrs); +} + +bool FusedOpInferType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const FusedOpPtr& op = nnvm::get(attrs.parsed); + return op->InferType(attrs, in_attrs, out_attrs); +} + +void FusedOpProvideShape(const nnvm::NodeAttrs& attrs, + const std::vector& nodes, + const std::vector> &in_attrs, + const std::vector> &out_attrs) { + const FusedOpPtr& op = nnvm::get(attrs.parsed); + op->ProvideShape(nodes, in_attrs, out_attrs); +} + +void FusedOpProvideType(const nnvm::NodeAttrs& attrs, + const std::vector& nodes, + const std::vector> &in_attrs, + const std::vector> &out_attrs) { + const FusedOpPtr& op = nnvm::get(attrs.parsed); + op->ProvideType(nodes, in_attrs, out_attrs); +} + +void FusedOpProvideStorageType(const nnvm::NodeAttrs& attrs, + const std::vector& nodes, + const std::vector> &in_attrs, + const std::vector> &out_attrs) {} + +NNVM_REGISTER_OP(_FusedOp) +.set_attr("TIsFusion", true) +.set_num_inputs([](const NodeAttrs& attrs) { + const FusedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_inputs(); + }) +.set_num_outputs([](const NodeAttrs& attrs) { + const FusedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_outputs(); + }) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs) { + const FusedOpPtr& op = nnvm::get(attrs.parsed); + const auto num_inputs = op->num_inputs(); + const auto num_outputs = op->num_outputs(); + std::vector > ret; + for (unsigned int i = 0; i < num_inputs; ++i) { + for (unsigned int j = 0; j < num_outputs; ++j) { + ret.emplace_back(i, j); + } + } + return ret; + }) +.set_attr("FProvideSubgraphShape", FusedOpProvideShape) +.set_attr("FProvideSubgraphType", FusedOpProvideType) +.set_attr("FProvideSubgraphStorageType", + FusedOpProvideStorageType) +.set_attr("FInferShape", FusedOpInferShape) +.set_attr("FInferType", FusedOpInferType) +.set_attr_parser(FusedOpParamParser) +.add_argument("data", "NDArray-or-Symbol[]", "Data"); + +std::tuple, + std::vector> +FusedOpHelperShape(const NodeAttrs& attrs) { + const auto& p = nnvm::get(attrs.parsed); + const auto& op = p->op; + const auto& node_id = p->node_id; + return op->GetAttrs("shape", node_id); +} + +std::tuple, + std::vector> +FusedOpHelperType(const NodeAttrs& attrs) { + const auto& p = nnvm::get(attrs.parsed); + const auto& op = p->op; 
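+  // The helper op only proxies attribute queries: node_id identifies the node
+  // inside the fused subgraph, and the owning FusedOp returns the dtypes it
+  // recorded for that node during type inference.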
+ const auto& node_id = p->node_id; + return op->GetAttrs("dtype", node_id); +} + +NNVM_REGISTER_OP(_FusedOpHelper) +.set_num_inputs(0) +.set_num_outputs(0) +.set_attr("TIsGhost", true) +.set_attr("TIsFusionHelper", true) +.set_attr("FAccessSubgraphShape", FusedOpHelperShape) +.set_attr("FAccessSubgraphType", FusedOpHelperType); + + +std::tuple, + std::vector> +FusedOpOutHelperShape(const NodeAttrs& attrs) { + const auto& p = nnvm::get(attrs.parsed); + const auto& op = p->op; + const auto& node_id = p->node_id; + return op->GetAuxShape(node_id); +} + +std::tuple, + std::vector> +FusedOpOutHelperType(const NodeAttrs& attrs) { + const auto& p = nnvm::get(attrs.parsed); + const auto& op = p->op; + const auto& node_id = p->node_id; + return op->GetAuxType(node_id); +} + +NNVM_REGISTER_OP(_FusedOpOutHelper) +.set_num_inputs(0) +.set_num_outputs(0) +.set_attr("TIsGhost", true) +.set_attr("TIsFusionHelper", true) +.set_attr("FAccessSubgraphShape", FusedOpOutHelperShape) +.set_attr("FAccessSubgraphType", FusedOpOutHelperType); + +} // namespace mxnet + +#endif // MXNET_USE_CUDA diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu new file mode 100644 index 000000000000..f6df38bac247 --- /dev/null +++ b/src/operator/fusion/fused_op.cu @@ -0,0 +1,746 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include "./fused_op.h" +#include "./fused_op-inl.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" +#include "../../executor/exec_pass.h" +#include "../../common/cuda_utils.h" + +namespace mxnet { + +namespace { + +inline std::string mshadowTypeToString(int type) { + switch (type) { + case mshadow::kFloat32: + return "float"; + case mshadow::kFloat64: + return "double"; + case mshadow::kFloat16: + return "half"; + case mshadow::kUint8: + return "unsigned char"; + case mshadow::kInt8: + return "char"; + case mshadow::kInt32: + return "int"; + case mshadow::kInt64: + return "long long"; + default: + LOG(FATAL) << "Unknown type enum " << type; + } + return ""; +} + +inline int mshadowTypeToVectorLength(int type) { + switch (type) { + case mshadow::kFloat32: + return 1; + case mshadow::kFloat64: + return 1; + case mshadow::kFloat16: + return 2; + case mshadow::kUint8: + return 4; + case mshadow::kInt8: + return 4; + case mshadow::kInt32: + return 1; + case mshadow::kInt64: + return 1; + default: + LOG(FATAL) << "Unknown type enum " << type; + } + return 0; +} + +inline void replaceString(std::string *input, const std::string old, const std::string repl) { + size_t pos = 0; + while ((pos = input->find(old, pos)) != std::string::npos) { + input->replace(pos, old.size(), repl); + pos += repl.size(); + } +} + +inline std::vector splitStringToVector(const std::string& input, const std::string def) { + size_t pos_start = 0, pos_end; + const std::string& s = input.substr(1, input.length()-2); + std::vector res; + + auto convert_token = [def](std::string token){ + if (token == def) { + return 0; + } + return std::stoi(token); + }; + + while ((pos_end = s.find(",", pos_start)) != std::string::npos) { + std::string token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + 1; + if (token.length() > 0) { + res.push_back(convert_token(token)); + } + } + + if (pos_start < s.length()) { + res.push_back(convert_token(s.substr(pos_start))); + } + return res; +} + +std::string ParseOpDescription(const std::vector& op_desc, + const std::map, std::string>& variables, + const nnvm::IndexedGraph::Node& node) { + const auto* source = node.source; + std::string fmt = op_desc[0]; + for (size_t j = 1; j < op_desc.size(); ++j) { + const std::string& desc = op_desc[j]; + std::string sub; + if (desc[0] == '_') { + // Argument + const int arg_id = std::stoi(desc.substr(1)); + sub = variables.at({node.inputs[arg_id].node_id, node.inputs[arg_id].index}); + } else { + sub = source->attrs.dict.at(desc); + } + size_t pos = fmt.find("%"); + CHECK_NE(pos, std::string::npos); + fmt.replace(pos, 1, sub); + } + return fmt; +} + +void AddShape(const mxnet::TShape& shape, + std::vector>* shapes) { + // We need alignment to 8 bytes for size_t in the Shape struct + // so if ndim is odd, there will be 4B of padding + int ndim = shape.ndim(); + const int offset = ndim % 2 == 0 ? 
2 : 3; + shapes->push_back(std::vector(ndim + offset)); + std::vector& tensor_shapes = shapes->back(); + size_t total_size = 1; + for (int i = ndim-1; i >= 0; i--) { + tensor_shapes[i] = shape[i]; + total_size *= shape[i]; + } + size_t * shape_size_ptr = reinterpret_cast(&tensor_shapes[ndim + offset - 2]); + *shape_size_ptr = total_size; +} + +void AddPointerAndShape(const TBlob& data, + std::vector *ptrs, + std::vector>* shapes, + mshadow::Stream * s) { + using namespace mshadow; + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + Tensor tensor = data.FlatTo1D(s); + ptrs->push_back(tensor.dptr_); + AddShape(data.shape_, shapes); + }); +} + +} // namespace + +void FusedOp::GenerateCode(int kernel_index, const std::vector &req, + const std::vector &in_dtypes, + const std::vector &out_dtypes, + const std::vector &in_ndims, + const std::vector &out_ndims, + const mxnet::ShapeVector &node_shapes, + const std::vector &node_dtypes, + const int nvec, + const std::string &kernel_name, + std::vector* check_shapes) { + const auto& g = this->subgraph_.indexed_graph(); + std::string code = ""; + int temp_name_counter = 0; + using NodeEntry = nnvm::IndexedGraph::NodeEntry; + std::map, std::string> variables; + std::map load_index; + bool check_shapes_compile = true; + + std::vector outputs(g.num_nodes()); + + for (size_t i = 0; i < g.num_nodes(); ++i) { + const auto& node = g[i]; + if (node.source != nullptr) { + outputs[i] = node.source->num_outputs(); + } else { + outputs[i] = 0; + } + } + + for (size_t i = 0; i < g.num_nodes(); ++i) { + const auto& node = g[i]; + const auto* source = node.source; + if (source != nullptr) { + if (source->is_variable()) { + load_index[i] = 1; + } else { + std::string op_name = source->op()->name; + if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) { + load_index[node.inputs[0].node_id] = 0; + } + } + } + } + for (size_t i = 0; i < g.num_nodes(); ++i) { + const auto& node = g[i]; + const auto* source = node.source; + if (source != nullptr) { + if (source->is_variable()) { + if (load_index[i]) { + const auto& var_name = source->attrs.name; + code += "const auto vec_" + var_name + " = op::load_index(" + + var_name + ", offset, " + var_name + "_shape);\n"; + variables[{i, 0}] = var_name; + } + CHECK_EQ(outputs[i], 1); + } else { + std::string op_name = source->op()->name; + if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) { + int node_id = node.inputs[0].node_id; + const uint32_t input_entry_id = g.entry_id(node.inputs[0]); + const auto& shape = node_shapes[input_entry_id]; + const int ndim = shape.ndim(); + const auto& var_name = g[node_id].source->attrs.name; + const auto vec_name = "vec_" + var_name + "_" + std::to_string(i); + load_index[node_id] = 0; + auto parse_tuple = [](const std::string& input, const std::string def) { + std::string out = input; + replaceString(&out, "(", "{"); + replaceString(&out, ")", "}"); + replaceString(&out, "None", def); + replaceString(&out, " ", ""); + return out; + }; + auto build_tuple = [ndim](int axis, const std::string str, const std::string def) { + std::string tuple = "{"; + for (int i = 0; i < axis; i++) { + tuple = tuple + def + ","; + } + tuple += str; + for (int i = axis + 1; i < ndim; i++) { + tuple = tuple + "," + def; + } + tuple += "}"; + return tuple; + }; + auto check_tuple = [ndim, nvec](const std::string str) { + std::vector tuple = splitStringToVector(str, "INT_MAX"); + if (tuple[ndim-1] % nvec == 0) { + return true; + } + return false; + }; + auto build_string_axis = [ndim](int axis) { + 
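+            // A negative axis counts from the last dimension; convert it to a
+            // non-negative index before emitting it into the generated code.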
if (axis < 0) { + axis = ndim + axis; + } + return std::to_string(axis); + }; + auto build_string_end = [i, ndim, var_name](std::string* code) { + std::string end_var_name = var_name + "_" + std::to_string(i) + "_end"; + *code += "op::Shape<" + std::to_string(ndim) + "> "+ end_var_name + ";\n"; + *code += end_var_name + ".set(INT_MAX);\n"; + return end_var_name; + }; + std::string begin; + std::string end; + if (op_name == "broadcast_like" || op_name == "slice_like") { + uint32_t like_id = g.entry_id(i, 0); + begin = build_tuple(0, "0", "0"); + std::string extra_var_name = "extra_" + std::to_string(like_id) + "_shape"; + if (std::find(extra_shape_args_.begin(), extra_shape_args_.end(), like_id) == + extra_shape_args_.end()) { + extra_shape_args_.push_back(like_id); + } + if (check_shapes) { + check_shapes->push_back(like_id); + check_shapes->push_back(input_entry_id); + } + end = extra_var_name; + } else { + begin = parse_tuple(source->attrs.dict.at("begin"), "0"); + end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX"); + if (op_name == "slice_axis") { + int axis = std::stoi(source->attrs.dict.at("axis")); + begin = build_tuple(axis, begin, "0"); + end = build_tuple(axis, end, "INT_MAX"); + } + if (check_shapes) { + if (check_tuple(begin) && check_tuple(end)) { + check_shapes->push_back(input_entry_id); + } else { + check_shapes_compile = false; + } + } + } + std::string slice_func = "load_slice"; + if (!check_shapes) { + slice_func = "fast_" + slice_func; + } + code += "const auto " + vec_name + " = op::" + slice_func + "(" + + var_name + ", " + var_name + "_shape," + begin + + "," + end + ", offset);\n"; + CHECK_EQ(outputs[i], 1); + variables[{i, 0}] = vec_name; + continue; + } + } + } + } + + if (!check_shapes_compile) { + check_shapes->clear(); + } + + size_t counter = 0; + for (const auto& entry : g.outputs()) { + std::string var_name = "output" + std::to_string(counter); + code += "op::VectorType vec_" + var_name + ";\n"; + ++counter; + } + + code += "for (int j = 0; j < nvec; j++ ) {\n"; + + + for (size_t i = 0; i < g.num_nodes(); ++i) { + const auto& node = g[i]; + const auto* source = node.source; + if (source != nullptr) { + std::string var_name = "temp" + std::to_string(temp_name_counter++); + if (source->is_variable()) { + if (load_index[i]) { + code += "const auto " + var_name + " = op::load(vec_" + + variables[{i, 0}] + ".x[j]);\n"; + CHECK_EQ(outputs[i], 1); + variables[{i, 0}] = var_name; + } + } else { + std::string op_name = source->op()->name; + if (fusion::ops_desc.find(op_name) != fusion::ops_desc.end()) { + const std::vector>& op_descs = + fusion::ops_desc.at(op_name); + CHECK_EQ(outputs[i], op_descs.size()); + size_t count = 0; + for (const auto& op_desc : op_descs) { + var_name = "temp" + std::to_string(temp_name_counter++); + const std::string& fmt = ParseOpDescription(op_desc, variables, node); + code += "const auto " + var_name + " = " + fmt + ";\n"; + variables[{i, count}] = var_name; + ++count; + } + continue; + } + + if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) { + code += "const auto " + var_name + " = op::load(" + variables[{i, 0}] + ".x[j]);\n"; + variables[{i, 0}] = var_name; + continue; + } + + + // Special cases with variable number + // of inputs/outputs, listed in + // fusion::variable_io_ops + if (op_name == "add_n") { + CHECK_EQ(outputs[i], 1); + const auto& arg = variables[{node.inputs[0].node_id, node.inputs[0].index}]; + code += "auto " + var_name + " = " + arg + ";\n"; + for (size_t inp = 1; inp < node.inputs.size(); 
++inp) { + const auto& temp_arg = variables[{node.inputs[inp].node_id, node.inputs[inp].index}]; + code += var_name + " = op::add(" + var_name + ", " + temp_arg + ");\n"; + } + variables[{i, 0}] = var_name; + continue; + } + + if (op_name == "_backward_Activation") { + CHECK_EQ(outputs[i], 1); + std::string act_type = node.source->attrs.dict.at("act_type"); + std::string rhs, lhs; + rhs = variables[{node.inputs[0].node_id, node.inputs[0].index}]; + if (act_type == "relu" || + act_type == "sigmoid" || + act_type == "tanh") { + lhs = variables[{node.inputs[1].node_id, node.inputs[1].index}]; + } else { + lhs = variables[{node.inputs[2].node_id, node.inputs[2].index}]; + } + code += "const auto " + var_name + " = op::backward_" + act_type + + "(" + lhs + ", " + rhs + ");\n"; + + variables[{i, 0}] = var_name; + continue; + } + + if (op_name == "amp_multicast" || op_name == "_backward_amp_multicast") { + CHECK_EQ(outputs[i], node.inputs.size()); + for (size_t counter = 0; counter < outputs[i]; ++counter) { + const auto& input = node.inputs[counter]; + var_name = "temp" + std::to_string(temp_name_counter++); + const auto& arg = variables[{input.node_id, input.index}]; + code += "const auto " + var_name + " = " + arg + ";\n"; + variables[{i, counter}] = var_name; + } + continue; + } + + if (op_name == "_backward_cast") { + CHECK_EQ(outputs[i], 1); + const int output_type = node_dtypes[g.entry_id(i, 0)]; + const auto& arg = variables[{node.inputs[0].node_id, node.inputs[0].index}]; + code += "const auto " + var_name + " = op::cast<" + mshadowTypeToString(output_type) + + ">(" + arg + ");\n"; + variables[{i, 0}] = var_name; + continue; + } + + LOG(FATAL) << "Unrecognized op " + op_name; + } + } else { + LOG(FATAL) << "Encountered node with NULL source."; + } + } + + counter = 0; + for (const auto& entry : g.outputs()) { + const std::string& var = variables[{entry.node_id, entry.index}]; + const auto var_name = "output" + std::to_string(counter); + code += "vec_" + var_name + ".x[j] = op::store("+ var +", " + var_name + ");\n"; + ++counter; + } + + code += "}\n"; + + counter = 0; + + for (const auto& entry : g.outputs()) { + const std::string& var = variables[{entry.node_id, entry.index}]; + if (req[counter] == kWriteTo || req[counter] == kWriteInplace) { + const auto var_name = "output" + std::to_string(counter); + code += "op::store_index(vec_" + var_name + ", i, " + var_name + ", " + + var_name + "_shape);\n"; + } else if (req[counter] == kAddTo) { + const auto var_name = "output" + std::to_string(counter); + code += "op::store_add_index(vec_" + var_name + ", i, " + var_name + ", " + + var_name + "_shape);\n"; + } else if (req[counter] == kNullOp) { + // NULL req, do not do anything + } else { + LOG(FATAL) << "Encountered unexpected req."; + } + ++counter; + } + + this->code_[kernel_index] = code; + + // Add boilerplate and type information + if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) { + LOG(INFO) << code_[kernel_index]; + } + std::string kernel_params = ""; + std::string tensor_params = ""; + nnvm::Symbol sym; + sym.outputs = this->subgraph_.outputs; + const std::vector input_names = sym.ListInputNames(nnvm::Symbol::kAll); + size_t num_params = in_dtypes.size() + out_dtypes.size(); + size_t i = 0; + std::string aux_code = "static const int nvec = " + std::to_string(nvec) + ";\n"; + + for (const auto &shape_id : extra_shape_args_) { + std::string shape_name = "extra_" + std::to_string(shape_id) + "_shape"; + int ndim = node_shapes[shape_id].ndim(); + kernel_params += " const op::Shape<" + 
std::to_string(ndim) + "> " + shape_name; + kernel_params += ", "; + } + for (const auto &type : in_dtypes) { + std::string type_name = mshadowTypeToString(type); + std::string dtype_var = "DType_" + input_names[i]; + std::string dim_var = "ndim_" + input_names[i]; + std::string dim_val = std::to_string(in_ndims[i]); + aux_code = "using " + dtype_var + " = " + type_name + ";\n" + aux_code; + aux_code = "static const int " + dim_var + " = " + dim_val + ";\n" + aux_code; + tensor_params += dtype_var + "* " +input_names[i]; + kernel_params += " const op::Shape<" + dim_val + "> " + input_names[i]+"_shape"; + ++i; + if (i < num_params) { + tensor_params += ", "; + } + kernel_params += ", "; + } + for (const auto &type : out_dtypes) { + std::string type_name = mshadowTypeToString(type); + std::string out_name = "output" + std::to_string(i - in_dtypes.size()); + std::string dtype_var = "DType_" + out_name; + std::string dim_var = "ndim_" + out_name; + std::string dim_val = std::to_string(out_ndims[i - in_dtypes.size()]); + aux_code = "static const int " + dim_var + " = " + dim_val + ";\n" + aux_code; + aux_code = "using " + dtype_var + " = " + type_name + ";\n" + aux_code; + tensor_params += dtype_var + "* " + out_name; + kernel_params += " const op::Shape<" + dim_val + "> " + out_name+"_shape"; + ++i; + if (i < num_params) { + tensor_params += ", "; + } + kernel_params += ", "; + } + kernel_params += tensor_params; + + code_[kernel_index] = std::string(fusion::fp16_support_string) + "\n" + + fusion::type_support_string + "\n" + + fusion::function_definitions + "\n" + + fusion::backward_function_definitions + "\n" + + aux_code + "\n" + + "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" + + "__global__ void FusedKernel_" + kernel_name + + "(size_t N, " + kernel_params + ") {\n" + + fusion::kernel_begin + "\n" + + code_[kernel_index] + "\n" + + fusion::kernel_end; +} + +void FusedOp::CompileCode(int kernel_index, const std::string &kernel_name) { + // Guard NVRTC calls + std::lock_guard lock_nvrtc(mutex_); + nvrtcProgram program; + NVRTC_CALL( + nvrtcCreateProgram(&program, // prog + &code_[kernel_index][0], // buffer + (kernel_name + "_kernel.cu").c_str(), // name + 0, // num headers + NULL, // headers + NULL)); // include names + std::string gpu_arch = "--gpu-architecture=compute_" + + std::to_string(this->cc_major_) + + std::to_string(this->cc_minor_); + + const char *opts[] = {gpu_arch.c_str(), + "--std=c++11", + "-default-device"}; + const std::string kernel_name_demangled = "FusedKernel_" + kernel_name; + NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str())); + + nvrtcResult compileResult = nvrtcCompileProgram(program, // prog + 3, // num options + opts); // options + // Obtain compilation log from the program. + size_t log_size; + NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size)); + std::string log(log_size, '\0'); + NVRTC_CALL(nvrtcGetProgramLog(program, &log[0])); + CHECK_EQ(compileResult, NVRTC_SUCCESS) + << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n" << log; + // Obtain PTX from the program. + size_t ptx_size; + NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size)); + ptx_[kernel_index].reserve(ptx_size); + NVRTC_CALL(nvrtcGetPTX(program, &ptx_[kernel_index][0])); + const char *name; + NVRTC_CALL(nvrtcGetLoweredName(program, + kernel_name_demangled.c_str(), + &name)); + kernel_name_[kernel_index] = name; + // Destroy the program. 
+  NVRTC_CALL(nvrtcDestroyProgram(&program));
+  int device;
+  CUdevice cu_device;
+  CUcontext context;
+  CUmodule module;
+  CUDA_CALL(cudaGetDevice(&device));
+  CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, device));
+  CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
+  CUDA_DRIVER_CALL(cuModuleLoadData(&module, &ptx_[kernel_index][0]));
+  CUDA_DRIVER_CALL(cuModuleGetFunction(&kernel_[kernel_index],
+                                       module,
+                                       kernel_name_[kernel_index].c_str()));
+}
+
+bool FusedOp::CheckComputeCapability(const OpContext &ctx) {
+  const int dev_id = ctx.run_ctx.ctx.dev_id;
+  const int cc_major = ComputeCapabilityMajor(dev_id);
+  const int cc_minor = ComputeCapabilityMinor(dev_id);
+
+  const bool ret = cc_major == this->cc_major_ && cc_minor == this->cc_minor_;
+  this->cc_major_ = cc_major;
+  this->cc_minor_ = cc_minor;
+  return ret;
+}
+
+void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs,
+                                  const std::vector<TBlob> &outputs,
+                                  std::vector<int> *in_dtypes,
+                                  std::vector<int> *in_ndims,
+                                  std::vector<int> *out_dtypes,
+                                  std::vector<int> *out_ndims,
+                                  int *nvec) {
+  std::vector<mxnet::TShape> in_shapes;
+  std::vector<mxnet::TShape> out_shapes;
+  CHECK_EQ(inputs.size(), inputs_.size());
+  CHECK_EQ(outputs.size(), outputs_.size());
+
+  for (size_t counter = 0; counter < inputs.size(); ++counter) {
+    const auto& blob = inputs[counter];
+    in_dtypes->push_back(blob.type_flag_);
+    in_ndims->push_back(blob.ndim());
+    in_shapes.push_back(blob.shape_);
+    initialized_ = initialized_ && blob.type_flag_ == inputs_[counter].dtype;
+    inputs_[counter].dtype = blob.type_flag_;
+    *nvec = std::max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
+  }
+
+  for (size_t counter = 0; counter < outputs.size(); ++counter) {
+    const auto& blob = outputs[counter];
+    out_dtypes->push_back(blob.type_flag_);
+    out_ndims->push_back(blob.ndim());
+    out_shapes.push_back(blob.shape_);
+    initialized_ = initialized_ && blob.type_flag_ == outputs_[counter].dtype;
+    outputs_[counter].dtype = blob.type_flag_;
+    *nvec = std::max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
+  }
+
+  for (auto it = intermediate_shapes_.begin();
+       it != intermediate_shapes_.end();
+       ++it) {
+    if (it->input_attr == in_shapes && it->output_attr == out_shapes) {
+      intermediate_shapes_.erase(intermediate_shapes_.begin(), it);
+      break;
+    }
+  }
+  for (auto it = intermediate_dtypes_.begin();
+       it != intermediate_dtypes_.end();
+       ++it) {
+    if (it->input_attr == *in_dtypes && it->output_attr == *out_dtypes) {
+      intermediate_dtypes_.erase(intermediate_dtypes_.begin(), it);
+      break;
+    }
+  }
+}
+
+template <>
+void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs,
+                           const OpContext &ctx,
+                           const std::vector<TBlob> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  std::lock_guard<std::mutex> lock(my_mutex_);
+  CHECK_GE(outputs.size(), 1) << "There needs to be at least 1 output.";
+
+  std::vector<int> in_dtypes;
+  std::vector<int> in_ndims;
+  std::vector<int> out_dtypes;
+  std::vector<int> out_ndims;
+  int nvec = 1;
+
+  CheckShapesAndTypes(inputs, outputs, &in_dtypes, &in_ndims,
+                      &out_dtypes, &out_ndims, &nvec);
+
+  const auto& node_shapes = intermediate_shapes_[0].internal_attr;
+  const auto& node_dtypes = intermediate_dtypes_[0].internal_attr;
+
+  // Check and save compute capability of the current GPU
+  if (!CheckComputeCapability(ctx)) initialized_ = false;
+
+  initialized_ = initialized_ && (req == saved_reqs_);
+  saved_reqs_ = req;
+
+  if (!initialized_) {
+    this->GenerateCode(0, req, in_dtypes, out_dtypes, in_ndims, out_ndims,
+                       node_shapes, node_dtypes, nvec, attrs.name, &check_shape_args_);
+    this->CompileCode(0, attrs.name);
+    if (check_shape_args_.size() > 0) {
+      this->GenerateCode(1, req, in_dtypes, out_dtypes, in_ndims, out_ndims,
+                         node_shapes, node_dtypes, nvec, attrs.name, NULL);
+      this->CompileCode(1, attrs.name);
+    }
+    initialized_ = true;
+  }
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  auto stream = Stream<gpu>::GetStream(s);
+  std::vector<void*> args;
+  size_t N = 0;
+  for (const auto& output : outputs) {
+    N = std::max(N, output.shape_.Size());
+  }
+  N = (N + nvec - 1)/nvec;
+  args.push_back(&N);
+
+  unsigned int num_blocks = (N + FusedOp::NTHREADS - 1) / FusedOp::NTHREADS;
+
+  std::vector<void*> ptrs;
+  std::vector<std::vector<int>> shapes;
+
+  for (const auto &shape_id : extra_shape_args_) {
+    AddShape(node_shapes[shape_id], &shapes);
+  }
+  for (const auto &data : inputs) {
+    AddPointerAndShape(data, &ptrs, &shapes, s);
+  }
+  for (const auto &data : outputs) {
+    AddPointerAndShape(data, &ptrs, &shapes, s);
+  }
+
+  for (auto &tensor_shapes : shapes) {
+    args.push_back(tensor_shapes.data());
+  }
+  for (auto &ptr : ptrs) {
+    args.push_back(reinterpret_cast<void*>(&ptr));
+  }
+  int kernel_index = 0;
+  if (check_shape_args_.size() > 0) {
+    kernel_index = 1;
+    for (const auto &shape_id : check_shape_args_) {
+      const auto& shape = node_shapes[shape_id];
+      if (shape[shape.ndim()-1] % nvec != 0) {
+        kernel_index = 0;
+      }
+    }
+  }
+  CUDA_DRIVER_CALL(
+      cuLaunchKernel(kernel_[kernel_index],
+                     num_blocks, 1, 1,          // grid dim
+                     FusedOp::NTHREADS, 1, 1,   // block dim
+                     0, stream,                 // shared mem and stream
+                     &(args[0]), 0));           // arguments
+}
+
+void FusedOpForwardGPU(const nnvm::NodeAttrs& attrs,
+                       const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  const FusedOpPtr& op = nnvm::get<FusedOpPtr>(attrs.parsed);
+  op->Forward<gpu>(attrs, ctx, inputs, req, outputs);
+}
+
+NNVM_REGISTER_OP(_FusedOp)
+.set_attr<FCompute>("FCompute<gpu>", FusedOpForwardGPU);
+
+}  // namespace mxnet
diff --git a/src/operator/fusion/fused_op.h b/src/operator/fusion/fused_op.h
new file mode 100644
index 000000000000..035e5432fca4
--- /dev/null
+++ b/src/operator/fusion/fused_op.h
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +#ifndef MXNET_OPERATOR_FUSION_FUSED_OP_H_ +#define MXNET_OPERATOR_FUSION_FUSED_OP_H_ + + +#include +#include +#include +#include +#include +#include +#include + +#if MXNET_USE_CUDA + + +namespace mxnet { + +struct FusedOpConfig : public dmlc::Parameter { + int num_inputs; + int num_outputs; + DMLC_DECLARE_PARAMETER(FusedOpConfig) { + DMLC_DECLARE_FIELD(num_inputs) + .describe("Number of inputs."); + DMLC_DECLARE_FIELD(num_outputs) + .describe("Number of outputs."); + } +}; + +struct FusedOpEntry { + FusedOpEntry() : dtype(-1) {} + int dtype; +}; + +class FusedOp { + public: + static const int NTHREADS = 512; + + explicit FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config); + ~FusedOp() {} + uint32_t num_inputs() const { + return inputs_.size(); + } + uint32_t num_outputs() const { + return outputs_.size(); + } + + template + void Forward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + bool InferShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs); + + bool InferType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs); + + template + std::tuple, + std::vector> + GetAttrs(const std::string& attr_name, + const uint32_t node_id); + + void ProvideShape(const std::vector& nodes, + const std::vector> &in_attrs, + const std::vector> &out_attrs) { + aux_nodes_ = nodes; + aux_in_shapes_ = in_attrs; + aux_out_shapes_ = out_attrs; + } + + void ProvideType(const std::vector& nodes, + const std::vector> &in_attrs, + const std::vector> &out_attrs) { + aux_nodes_ = nodes; + aux_in_types_ = in_attrs; + aux_out_types_ = out_attrs; + } + + std::tuple, + std::vector> + GetAuxShape(const int node_id) const { + return std::make_tuple(aux_nodes_[node_id], + aux_in_shapes_[node_id], + aux_out_shapes_[node_id]); + } + + std::tuple, + std::vector> + GetAuxType(const int node_id) const { + return std::make_tuple(aux_nodes_[node_id], + aux_in_types_[node_id], + aux_out_types_[node_id]); + } + + private: + void GenerateCode(int kernel_index, + const std::vector &req, + const std::vector &in_dtypes, + const std::vector &out_dtypes, + const std::vector &in_ndims, + const std::vector &out_ndims, + const mxnet::ShapeVector &node_shapes, + const std::vector &node_dtypes, + const int nvec, + const std::string& kernel_name, + std::vector *check_shapes); + void CompileCode(int kernel_index, + const std::string &kernel_name); + bool CheckComputeCapability(const OpContext &ctx); + void CheckShapesAndTypes(const std::vector &inputs, + const std::vector &outputs, + std::vector *in_dtypes, + std::vector *in_ndims, + std::vector *out_dtypes, + std::vector *out_ndims, + int *nvec); + + std::vector inputs_; + std::vector outputs_; + + std::string code_[2]; + nnvm::Graph subgraph_; + + template + struct IntermediateAttr { + std::vector input_attr; + std::vector output_attr; + std::vector internal_attr; + }; + + // Shapes and types inside the subgraph + // copied here, because a subsequent call + // to InferShape/InferType can overwrite the + // original information stored in subgraph_ + // attributes while the previous iterations + // still need them. 
+ std::vector > intermediate_shapes_; + std::vector > intermediate_dtypes_; + + std::vector aux_nodes_; + std::vector> aux_in_shapes_; + std::vector> aux_out_shapes_; + std::vector> aux_in_types_; + std::vector> aux_out_types_; + std::vector saved_reqs_; + std::vector extra_shape_args_; + std::vector check_shape_args_; + + std::string ptx_[2]; + std::string kernel_name_[2]; + CUfunction kernel_[2]; + bool initialized_; + int cc_major_; + int cc_minor_; + + static std::mutex mutex_; + std::mutex my_mutex_; +}; + +using FusedOpPtr = std::shared_ptr; + +struct FusedOpHelperParam { + FusedOpPtr op; + uint32_t node_id; + + FusedOpHelperParam(FusedOpPtr op, uint32_t node_id) : + op(op), + node_id(node_id) {} +}; + +using FusedOpHelperParamPtr = std::shared_ptr; + +} // namespace mxnet + +#endif // MXNET_USE_CUDA +#endif // MXNET_OPERATOR_FUSION_FUSED_OP_H_ diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 92655c146193..3d8d5ee6efbf 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -169,7 +169,7 @@ struct softrelu : public mxnet_op::tunable { MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a)); -MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(erfinv::Map(a)))); +MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(a))); MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a))); diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 117cfa96518a..acb81db6db11 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -905,7 +905,7 @@ Example:: )code" ADD_FILELINE) .set_attr("FCompute", UnaryOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_erfinv"}); +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_erfinv"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_erfinv) .set_attr("FCompute", diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index 6e54ddd7e52a..d9d727786613 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -106,8 +106,8 @@ class GPUPooledStorageManager final : public StorageManager { } size_t RoundAllocSize(size_t size) { - // Round up small allocs to the page_size_ to consolidate the pool lookups - size = std::max(size, page_size_); + // Round up small allocs to multiple of page_size_ to consolidate the pool lookups + size = RoundToMultiple(size, page_size_); // To ensure proper freeing under some driver variants, make sure // large allocs entirely occupy their slabs, which cannot then be // locked by smaller permanent allocations sharing the slab. diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py new file mode 100644 index 000000000000..6adf935fb29c --- /dev/null +++ b/tests/python/gpu/test_fusion.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import random +import mxnet as mx +import numpy as np +from mxnet.test_utils import * + +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../unittest')) +from common import with_seed + +def check_fused_symbol(sym, **kwargs): + inputs = sym.list_inputs() + shapes = {inp : kwargs[inp].shape for inp in inputs} + # Double identity so that there is always something to fuse + test_sym = mx.sym.Group([mx.sym.identity(mx.sym.identity(s)) for s in sym]) + rtol = {'float16' : 1e-2, + 'float32' : 1.5e-6, + 'float64' : 1.5e-6, + } + atol = {'float16' : 1e-3, + 'float32' : 1e-7, + 'float64' : 1e-7, + } + for dtype in ['float16', 'float32', 'float64']: + data = {inp : kwargs[inp].astype(dtype) for inp in inputs} + for grad_req in ['write', 'add']: + type_dict = {inp : dtype for inp in inputs} + os.environ["MXNET_USE_FUSION"] = "0" + orig_exec = test_sym.simple_bind(ctx=mx.gpu(0), grad_req=grad_req, type_dict=type_dict, **shapes) + os.environ["MXNET_USE_FUSION"] = "1" + fused_exec = test_sym.simple_bind(ctx=mx.gpu(0), grad_req=grad_req, type_dict=type_dict, **shapes) + fwd_orig = orig_exec.forward(is_train=True, **data) + out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig] + orig_exec.backward(out_grads=out_grads) + fwd_fused = fused_exec.forward(is_train=True, **data) + fused_exec.backward(out_grads=out_grads) + for orig, fused in zip(fwd_orig, fwd_fused): + np.testing.assert_allclose(orig.asnumpy(), fused.asnumpy(), rtol=rtol[dtype], atol=atol[dtype]) + for orig, fused in zip(orig_exec.grad_arrays, fused_exec.grad_arrays): + if orig is None and fused is None: + continue + assert orig is not None + assert fused is not None + np.testing.assert_allclose(orig.asnumpy(), fused.asnumpy(), rtol=rtol[dtype], atol=atol[dtype]) + +def check_unary_ops(): + unary_ops = [ + 'relu', + 'sigmoid', + 'softsign', + 'exp', + 'expm1', + 'log', + 'log10', + 'log2', + 'log1p', + 'degrees', + 'radians', + 'sin', + 'cos', + 'tan', + 'arcsin', + 'arccos', + 'arctan', + 'sinh', + 'cosh', + 'tanh', + 'arcsinh', + 'arctanh', + 'sqrt', + 'rsqrt', + 'cbrt', + 'rcbrt', + 'square', + 'squeeze', + 'zeros_like', + 'ones_like', + 'flatten', + 'round', + 'rint', + 'fix', + 'floor', + 'ceil', + 'trunc', + 'sign', + 'reciprocal', + 'abs', + 'gamma', + 'gammaln', + 'erf', + 'negative', + ] + + def announce_check(op_name): + print("Checking fusion of " + op_name) + + arr = mx.random.uniform(shape=rand_shape_2d()) + a = mx.sym.Variable('a') + for op_name in unary_ops: + announce_check(op_name) + op = getattr(mx.sym, op_name) + sym = op(a) + check_fused_symbol(sym, a=arr) + + # unary ops requiring special treatment + + # arccosh needs input to be >= 1 + arr2 = arr + 1 + announce_check('arccosh') + check_fused_symbol(mx.sym.arccosh(a), a=arr2) + + # erfinv needs -1 < input < 1, but we avoid the limits of this range where the slope nears +inf. 
+ arr2 = (arr - 0.5) * 1.99 + announce_check('erfinv') + check_fused_symbol(mx.sym.erfinv(a), a=arr2) + + # Activation requires act_type attribute + for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']: + announce_check("Activation(act_type='{}')".format(act_type)) + check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr) + + # Cast requires dtype + for dtype in ['float16', 'float32', 'float64', 'int32']: + announce_check("Cast(dtype='{}')".format(dtype)) + check_fused_symbol(mx.sym.Cast(a, dtype=dtype), a=arr) + + # reshape requires shape + announce_check('reshape') + check_fused_symbol(mx.sym.reshape(a, shape=(-1,)), a=arr) + + # expand_dims requires axis + announce_check('expand_dims') + check_fused_symbol(mx.sym.expand_dims(a, axis=1), a=arr) + + # clip requires a_min, a_max + announce_check('clip') + check_fused_symbol(mx.sym.clip(a, a_min=0.3, a_max=0.7), a=arr) + + # smooth_l1 requires a scalar + announce_check('smooth_l1') + check_fused_symbol(mx.sym.smooth_l1(a, scalar=0.3), a=arr) + +def check_binary_ops(): + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + shape = rand_shape_2d() + arr1 = mx.random.uniform(shape=shape) + arr2 = mx.random.uniform(shape=shape) + + check_fused_symbol(a+b, a=arr1, b=arr2) + check_fused_symbol(a+3, a=arr1) + check_fused_symbol(a-b, a=arr1, b=arr2) + check_fused_symbol(a-3, a=arr1) + check_fused_symbol(3-a, a=arr1) + check_fused_symbol(a*b, a=arr1, b=arr2) + check_fused_symbol(a*3, a=arr1) + check_fused_symbol(a/b, a=arr1, b=arr2) + check_fused_symbol(a/3, a=arr1) + check_fused_symbol(3/a, a=arr1) + check_fused_symbol(a**b, a=arr1, b=arr2) + check_fused_symbol(a**3, a=arr1) + check_fused_symbol(mx.sym.pow(3,a), a=arr1) + check_fused_symbol(mx.sym.maximum(a,b), a=arr1, b=arr2) + check_fused_symbol(mx.sym.minimum(a,b), a=arr1, b=arr2) + check_fused_symbol(mx.sym.hypot(a,b), a=arr1, b=arr2) + check_fused_symbol(mx.sym.hypot(a,3), a=arr1) + +def check_other_ops(): + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = mx.sym.Variable('c') + shape = rand_shape_2d() + shape = (5,) + shape + arr1 = mx.random.uniform(shape=shape) + arr2 = mx.random.uniform(shape=shape) + arr3 = mx.random.uniform(shape=shape) + + check_fused_symbol(mx.sym.add_n(a,b,c), a=arr1, b=arr2, c=arr3) + + check_fused_symbol(mx.sym.slice_axis(a, axis=0, begin=1, end=4), a=arr1) + + begin = (random.randint(0, shape[0]-1), + random.randint(0, shape[1]-1), + random.randint(0, shape[2]-1)) + end = (random.randint(begin[0]+1, shape[0]), + random.randint(begin[1]+1, shape[1]), + random.randint(begin[2]+1, shape[2])) + check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1) + + arr1 = mx.random.uniform(shape=(2,3,4,5)) + arr2 = mx.random.uniform(shape=(1,2,3)) + check_fused_symbol(mx.sym.slice_like(a,b, axes=[-2, 0]), a=arr1, b=arr2) + + arr1 = mx.random.uniform(shape=(1,1,2,3)) + arr2 = mx.random.uniform(shape=(2,2,2,3)) + check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2) + +@with_seed() +def test_fusion(): + check_unary_ops() + check_binary_ops() + check_other_ops() + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 380ce762a9f7..1363e5851b0c 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -2987,6 +2987,47 @@ def forward(self, x): shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1) block(mx.nd.ones(shape)) +@with_seed() +def 
test_reqs_switching_training_inference():
+    class Foo(gluon.HybridBlock):
+        def __init__(self, **kwargs):
+            super(Foo, self).__init__(**kwargs)
+
+        def hybrid_forward(self, F, x):
+            y = 2 * x
+            return F.sqrt(x) + F.sqrt(y)
+
+    f = Foo()
+    f.hybridize(static_alloc=True)
+    x = mx.nd.ones(shape=(10, 10))
+    x.attach_grad()
+    x2 = mx.nd.ones(shape=x.shape) * 2
+    x2.attach_grad()
+
+    # First call in training mode
+    with mx.autograd.record():
+        y = f(x)
+    y.backward()
+
+    grad1 = x.grad.asnumpy()
+
+    # Compute the gradient with some other input
+    with mx.autograd.record():
+        y = f(x2)
+    y.backward()
+
+    # Call in inference mode
+    y = f(x)
+
+    # Call in training mode again
+    with mx.autograd.record():
+        y = f(x)
+    y.backward()
+
+    grad2 = x.grad.asnumpy()
+
+    mx.test_utils.assert_almost_equal(grad1, grad2)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
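
A minimal usage sketch (an illustration, not part of the patch; it assumes a CUDA-enabled MXNet build containing the fusion pass above and a visible GPU). It follows the same toggle-and-compare pattern that check_fused_symbol() uses in tests/python/gpu/test_fusion.py, with MXNET_FUSION_VERBOSE switched on so the generated kernel source is printed:

    import os
    import mxnet as mx
    import numpy as np

    a = mx.sym.Variable('a')
    sym = mx.sym.relu(a) + mx.sym.sqrt(a)      # simple pointwise chain that can fuse
    data = mx.nd.random.uniform(shape=(3, 4))

    # Build a reference executor with fusion disabled
    os.environ["MXNET_USE_FUSION"] = "0"
    plain = sym.simple_bind(ctx=mx.gpu(0), a=data.shape)
    ref = plain.forward(is_train=False, a=data)[0]

    # Rebuild with fusion enabled and log the generated fused kernel code
    os.environ["MXNET_USE_FUSION"] = "1"
    os.environ["MXNET_FUSION_VERBOSE"] = "1"
    fused = sym.simple_bind(ctx=mx.gpu(0), a=data.shape)
    out = fused.forward(is_train=False, a=data)[0]

    np.testing.assert_allclose(ref.asnumpy(), out.asnumpy(), rtol=1e-6, atol=1e-7)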