diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 0c4c1e60208f..1f115cd64ad5 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -23,6 +23,7 @@ #include "../executor/exec_pass.h" #include "../profiler/profiler.h" #include "../operator/operator_common.h" +#include "../operator/subgraph/common.h" namespace mxnet { @@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward( return op_state; } - void CachedOp::DynamicBackward( const bool retain_graph, const OpStatePtr& op_state, @@ -1067,34 +1067,153 @@ void CachedOp::Backward( Engine::Get()->set_bulk_size(prev_bulk_size); } -bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - using namespace imperative; - nnvm::Graph g(fwd_graph_); - const auto& idx = g.indexed_graph(); - const auto &outputs = idx.outputs(); +/* + * This is the operator state of CachedOp when CachedOp is used in the symbol + * executor. This is different from the OpState returned by CachedOp::Forward. + * The main reason why we need this OpState is that CachedOp and the symbol executor + * maintain OpState differently. The symbol executor generates OpState in advance + * while CachedOp generates OpState after Forward is called. We need this data + * structure to keep the OpState generated by CachedOp::Forward and pass it to + * Backward. + */ +struct CachedOpActualState { + std::shared_ptr op; + OpStatePtr forward_state; - // Prepare stypes and contexts based on inputs - StorageTypeVector storage_type_inputs; - storage_type_inputs.reserve(in_attrs->size()); - for (size_t i = 0; i < in_attrs->size(); ++i) { - storage_type_inputs.emplace_back(in_attrs->at(i)); + explicit CachedOpActualState(std::shared_ptr op) { + this->op = op; } - exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask); +}; - // Forward graph storage type inference - CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true); - // Retrieve result and set outputs - const auto& inferred_stypes = g.GetAttr("storage_type"); - for (size_t i = 0; i < out_attrs->size(); i++) { - const auto eid = idx.entry_id(outputs[i]); - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]); +/* + * This is the forward computation when CachedOp is used as an operator in + * a symbol executor. + */ +void CachedOpForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CachedOpActualState &s = state_ptr.get_state(); + std::vector in_bufs = inputs; + std::vector out_bufs = outputs; + std::vector in_ptrs(in_bufs.size()); + std::vector out_ptrs(out_bufs.size()); + for (size_t i = 0; i < in_ptrs.size(); i++) + in_ptrs[i] = &in_bufs[i]; + for (size_t i = 0; i < out_ptrs.size(); i++) + out_ptrs[i] = &out_bufs[i]; + + // Set is_recording correct for the imperative executor. + bool orig_is_record; + if (ctx.need_grad) + orig_is_record = Imperative::Get()->set_is_recording(true); + else + orig_is_record = Imperative::Get()->is_recording(); + // Set is_training correct for the imperative executor. + bool orig_is_train; + if (ctx.is_train) + orig_is_train = Imperative::Get()->set_is_training(true); + else + orig_is_train = Imperative::Get()->is_training(); + s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs); + Imperative::Get()->set_is_training(orig_is_train); + Imperative::Get()->set_is_recording(orig_is_record); + // The arrays in out_ptrs may be changed by CachedOp. + // If it is, we need to copy data back. + for (size_t i = 0; i < out_bufs.size(); i++) + if (!out_bufs[i].IsSame(outputs[i])) + CopyFromTo(out_bufs[i], outputs[i]); +} + +/* + * This is the backward computation when CachedOp is used as an operator in + * a symbol executor. + */ +void CachedOpBackward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace nnvm; + using namespace imperative; + CachedOpActualState &s = state_ptr.get_state(); + std::vector in_bufs = inputs; + std::vector out_bufs = outputs; + std::vector in_ptrs; + std::vector out_ptrs; + CHECK_EQ(s.op->num_backward_inputs(), inputs.size()); + in_ptrs.reserve(s.op->num_backward_inputs()); + out_ptrs.reserve(s.op->num_inputs()); + + const std::vector &save_inputs = s.op->save_inputs(); + const std::vector &save_outputs = s.op->save_outputs(); + size_t bwd_in_dep = s.op->num_inputs(); + size_t bwd_out_dep = s.op->num_outputs(); + CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep); + size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep; + + // Find inputs, outputs and ograds + auto ograds_begin = in_bufs.begin(); + auto ograds_end = in_bufs.begin() + bwd_ograd_dep; + auto in_begin = ograds_end; + auto in_end = in_begin + bwd_in_dep; + auto out_begin = in_end; + auto out_end = in_bufs.end(); + + for (auto it = ograds_begin; it != ograds_end; it++) + in_ptrs.push_back(&(*it)); + + CHECK_EQ(save_inputs.size(), in_end - in_begin); + CHECK_EQ(s.op->num_outputs(), out_end - out_begin); + for (auto it = in_begin; it != in_end; it++) { + auto i = it - in_begin; + if (save_inputs[i]) + in_ptrs.push_back(&(*it)); } - DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); - return true; + for (auto it = out_begin; it != out_end; it++) { + auto i = it - out_begin; + if (save_outputs[i]) + in_ptrs.push_back(&(*it)); + } + CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs()); + for (size_t i = 0; i < out_bufs.size(); i++) + out_ptrs.push_back(&out_bufs[i]); + CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs()); + // Set is_training correct for the imperative executor. + bool orig_is_train; + if (ctx.is_train) + orig_is_train = Imperative::Get()->set_is_training(true); + else + orig_is_train = Imperative::Get()->is_training(); + // TODO(zhengda) CachedOp supports recording computation when running + // the backward path. This is necessary if we want to support the second-order + // differentiation. However, MXNet operator doesn't have an interface to + // pass a flag to determine whether to record computation inside an operator. + // Let's use false here for now and design a solution when the second-order + // differentiation is supported. + s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs); + Imperative::Get()->set_is_training(orig_is_train); + + // Clean up what we recorded. + s.forward_state.reset(); + + // The arrays in out_ptrs may be changed by CachedOp. + // If it is, we need to copy data back. + // For example, when the inputs and outputs share the same NDArrays, + // the outputs will be replaced by inputs. + // https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385 + for (size_t i = 0; i < out_bufs.size(); i++) + if (!out_bufs[i].IsSame(outputs[i])) + CopyFromTo(out_bufs[i], outputs[i]); +} + +OpStatePtr CreateCachedOpState(const NodeAttrs& attrs, + Context ctx, + const std::vector& in_shapes, + const std::vector& in_types) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return OpStatePtr::Create(op); } bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, @@ -1143,6 +1262,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, return true; } +void CachedOpParamParser(nnvm::NodeAttrs* attrs) { + CachedOpConfig param; + try { + param.Init(attrs->dict); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + if (!param.subgraph.empty()) { + nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph); + CHECK(!g.outputs.empty()); + nnvm::Symbol sym; + sym.outputs = g.outputs; + std::vector > flags; + for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++) + flags.emplace_back(it->first, it->second); + attrs->parsed = CachedOpPtr(new CachedOp(sym, flags)); + } +} NNVM_REGISTER_OP(_CachedOp) .set_num_inputs([](const NodeAttrs& attrs) { @@ -1153,19 +1298,62 @@ NNVM_REGISTER_OP(_CachedOp) const CachedOpPtr& op = nnvm::get(attrs.parsed); return op->num_outputs(); }) -.set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); - }) +.set_attr_parser(CachedOpParamParser) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { const CachedOpPtr& op = nnvm::get(n->attrs.parsed); return op->Gradient(n, ograds); - }); + }) +.set_attr("FListInputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ListForwardInputNames(); + }) +.set_attr("FListOutputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ListForwardOutputNames(); + }) +.set_attr("FCreateOpState", CreateCachedOpState) +.set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + std::vector *in_shapes, + std::vector *out_shapes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes); + }) +.set_attr("FInferType", + [](const nnvm::NodeAttrs& attrs, + std::vector *in_types, + std::vector *out_types) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types); + }) +.set_attr("FInferStorageType", + [](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(), + dev_mask, dispatch_mode, + in_stypes, out_stypes); + }) +.set_attr("FStatefulComputeEx", CachedOpForward) +.set_attr("FStatefulComputeEx", CachedOpForward) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym()); + }) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym()); + }) +.set_attr("FExecType", op::DefaultSubgraphOpExecType) +.add_argument("data", "NDArray-or-Symbol[]", "input data list"); NNVM_REGISTER_OP(_backward_CachedOp) .set_num_inputs([](const NodeAttrs& attrs){ @@ -1184,6 +1372,9 @@ NNVM_REGISTER_OP(_backward_CachedOp) const CachedOpPtr& op = nnvm::get(attrs.parsed); return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); }) +.set_attr("FStatefulComputeEx", CachedOpBackward) +.set_attr("FStatefulComputeEx", CachedOpBackward) +.set_attr("FExecType", op::DefaultSubgraphOpExecType) .set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true); diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 4f4dfdcc14dd..59a793ee1b65 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter { bool static_shape; nnvm::Tuple data_indices; nnvm::Tuple param_indices; + std::string subgraph; DMLC_DECLARE_PARAMETER(CachedOpConfig) { DMLC_DECLARE_FIELD(static_alloc) .set_default(false) @@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter { DMLC_DECLARE_FIELD(param_indices) .set_default(nnvm::Tuple()) .describe("Position of parameters."); + DMLC_DECLARE_FIELD(subgraph) + .set_default(std::string("")) + .describe("JSON string of a subgraph."); } }; @@ -80,6 +84,10 @@ class CachedOp { uint32_t num_backward_inputs() const { return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size(); } + uint32_t num_backward_outputs() const { + auto &idx = fwd_graph_.indexed_graph(); + return idx.input_nodes().size() - idx.mutable_input_nodes().size(); + } std::vector& save_inputs() { return save_inputs_; } @@ -102,13 +110,6 @@ class CachedOp { const std::vector& inputs, const std::vector& reqs, const std::vector& outputs); - // forward storage type inference - bool ForwardStorageType( - const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs); // backward storage type inference bool BackwardStorageType( const nnvm::NodeAttrs& attrs, @@ -116,6 +117,19 @@ class CachedOp { DispatchMode* dispatch_mode, std::vector *in_attrs, std::vector *out_attrs); + std::vector ListForwardInputNames() const { + nnvm::Symbol sym = GetForwardSym(); + return sym.ListInputNames(nnvm::Symbol::kAll); + } + std::vector ListForwardOutputNames() const { + nnvm::Symbol sym = GetForwardSym(); + return sym.ListOutputNames(); + } + nnvm::Symbol GetForwardSym() const { + nnvm::Symbol sym; + sym.outputs = fwd_graph_.outputs; + return sym; + } private: struct GraphInfo; diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 29112939a22f..6a4c3d027075 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -221,7 +221,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) { */ #define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \ { \ - if (!shape_assign(&(shape_array)[index], TShape(shape))) { \ + if (!::mxnet::op::shape_assign(&(shape_array)[index], TShape(shape))) { \ std::ostringstream os; \ os << "Shape inconsistent, Provided = " << (shape_array)[index] << ','\ << " inferred shape=" << shape; \ @@ -238,11 +238,11 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) { */ #define TYPE_ASSIGN_CHECK(type_array, index, type) \ { \ - if (!type_assign(&(type_array)[index], type)) { \ + if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \ std::ostringstream os; \ os << "Type inconsistent, Provided = " \ - << type_string((type_array)[index]) << ',' \ - << " inferred type = " << type_string(type); \ + << ::mxnet::op::type_string((type_array)[index]) << ',' \ + << " inferred type = " << ::mxnet::op::type_string(type); \ throw ::mxnet::op::InferTypeError(os.str(), index); \ } \ } @@ -291,8 +291,8 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) { #define UNIFORM_TYPE_CHECK(type, expected, arg) \ { \ CHECK_EQ(type, expected) << "This layer requires uniform type. " \ - << "Expected '" << type_string(expected) \ - << "' v.s. given '" << type_string(type) \ + << "Expected '" << ::mxnet::op::type_string(expected) \ + << "' v.s. given '" << ::mxnet::op::type_string(type) \ << "' at '" << arg << "'"; \ } diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h index 22058d556e07..4e1cd66b8b68 100644 --- a/src/operator/subgraph/common.h +++ b/src/operator/subgraph/common.h @@ -49,11 +49,10 @@ inline std::vector DefaultSubgraphOpListOutputs(const nnvm::NodeAtt return sym.ListOutputNames(); } -inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs, - std::vector *in_shapes, - std::vector *out_shapes) { +inline bool DefaultSubgraphOpShapeHelper(const nnvm::Symbol& subgraph_sym, + std::vector *in_shapes, + std::vector *out_shapes) { using namespace exec; - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -94,10 +93,15 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs, return g.GetAttr("shape_num_unknown_nodes") == 0; } -inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs, - std::vector *in_types, - std::vector *out_types) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; +inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shapes, + std::vector *out_shapes) { + return DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes); +} + +inline bool DefaultSubgraphOpTypeHelper(const nnvm::Symbol& subgraph_sym, + std::vector *in_types, + std::vector *out_types) { nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -137,12 +141,17 @@ inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs, return g.GetAttr("dtype_num_unknown_nodes") == 0; } -inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_stypes, - std::vector* out_stypes) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; +inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_types, + std::vector *out_types) { + return DefaultSubgraphOpTypeHelper(*attrs.subgraphs[0], in_types, out_types); +} + +inline bool DefaultSubgraphOpStorageTypeHelper(const nnvm::Symbol& subgraph_sym, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -190,12 +199,21 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs, return g.GetAttr("storage_type_num_unknown_nodes") == 0; } +inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + return DefaultSubgraphOpStorageTypeHelper(*attrs.subgraphs[0], dev_mask, dispatch_mode, + in_stypes, out_stypes); +} + inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) { return ExecType::kSubgraphExec; } -inline std::vector DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; +inline std::vector DefaultSubgraphOpMutableInputsHelper( + const nnvm::Symbol& subgraph_sym) { const std::vector input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll); const std::vector immutable_input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs); @@ -217,8 +235,12 @@ inline std::vector DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr return ret; } -inline std::vector DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; +inline std::vector DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) { + return DefaultSubgraphOpMutableInputsHelper(*attrs.subgraphs[0]); +} + +inline std::vector DefaultSubgraphOpResourceRequestHelper( + const nnvm::Symbol& subgraph_sym) { static auto& fresource = Op::GetAttr("FResourceRequest"); std::set resource_types; DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) { @@ -231,6 +253,10 @@ inline std::vector DefaultSubgraphOpResourceRequest(const nnvm: return std::vector(resource_types.begin(), resource_types.end()); } +inline std::vector DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) { + return DefaultSubgraphOpResourceRequestHelper(*attrs.subgraphs[0]); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc deleted file mode 100644 index d5fb7ee2db61..000000000000 --- a/src/operator/subgraph/default_subgraph_op.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ - -#include -#include "./common.h" -#include "../../imperative/imperative_utils.h" -#include "../../imperative/cached_op.h" - -namespace mxnet { -namespace op { - -#define DEBUG_SUBGRAPH 0 - -class DefaultSubgraphOperator { - public: - explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) { - subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"}, - {"static_shape", "true"}})); - } - - void Forward(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); - void Backward(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - LOG(FATAL) << "Not implemented"; - } - - private: - nnvm::Symbol subgraph_sym_; - CachedOpPtr subgraph_exec_; -}; - -void DefaultSubgraphOperator::Forward(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - std::vector tmp_inputs = inputs; - std::vector input_ptrs; - input_ptrs.reserve(inputs.size()); - for (auto& nd : tmp_inputs) { - input_ptrs.push_back(&nd); - } - std::vector tmp_outputs = outputs; - std::vector output_ptrs; - for (auto& nd : tmp_outputs) { - output_ptrs.push_back(&nd); - } -#if DEBUG_SUBGRAPH - for (size_t i = 0; i < inputs.size(); ++i) { - LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version(); - } - for (size_t i = 0; i < outputs.size(); ++i) { - LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version(); - } -#endif - subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs); -} - -OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs, - Context ctx, - const std::vector& in_shapes, - const std::vector& in_types) { - return OpStatePtr::Create(*attrs.subgraphs[0]); -} - -void DefaultSubgraphOpForward(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - DefaultSubgraphOperator& op = state_ptr.get_state(); - op.Forward(ctx, inputs, req, outputs); -} - -NNVM_REGISTER_OP(_default_subgraph_op) -.describe(R"code(_default_subgraph_op)code" ADD_FILELINE) -.set_num_inputs(DefaultSubgraphOpNumInputs) -.set_num_outputs(DefaultSubgraphOpNumOutputs) -.set_attr("FListInputNames", DefaultSubgraphOpListInputs) -.set_attr("FListOutputNames", DefaultSubgraphOpListOutputs) -.set_attr("FCreateOpState", CreateDefaultSubgraphOpState) -.set_attr("FInferShape", DefaultSubgraphOpShape) -.set_attr("FInferType", DefaultSubgraphOpType) -.set_attr("FInferStorageType", DefaultSubgraphOpStorageType) -.set_attr("FStatefulComputeEx", DefaultSubgraphOpForward) -.set_attr("FMutateInputs", DefaultSubgraphOpMutableInputs) -.set_attr("key_var_num_args", "num_args") -.set_attr("FExecType", DefaultSubgraphOpExecType) -.add_argument("data", "NDArray-or-Symbol[]", "input data list"); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu deleted file mode 100644 index 008826b21d71..000000000000 --- a/src/operator/subgraph/default_subgraph_op.cu +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2018 by Contributors - * \file default_subgraph_op.cu - * \brief GPU Implementation of subgraph operations - */ - -#include -#include "./common.h" -#include "../../imperative/imperative_utils.h" -#include "../../imperative/cached_op.h" - -namespace mxnet { -namespace op { - -void DefaultSubgraphOpForward(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); - -NNVM_REGISTER_OP(_default_subgraph_op) -.set_attr("FStatefulComputeEx", DefaultSubgraphOpForward); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc index c8d3e9ffd438..0152344f4d43 100644 --- a/src/operator/subgraph/default_subgraph_property.cc +++ b/src/operator/subgraph/default_subgraph_property.cc @@ -21,6 +21,7 @@ #include #include "./common.h" #include "./subgraph_property.h" +#include "../../imperative/cached_op.h" namespace mxnet { namespace op { @@ -51,7 +52,7 @@ class ContainOpSelector: public SubgraphSelector { /* * This subgraph property finds a subgraph whose nodes have only operators - * within a set. The operators in the subgraph will be executed by _default_subgraph_op. + * within a set. The operators in the subgraph will be executed by _CachedOp. */ class DefaultSubgraphProperty: public SubgraphProperty { public: @@ -59,9 +60,13 @@ class DefaultSubgraphProperty: public SubgraphProperty { virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const { nnvm::NodePtr n = nnvm::Node::Create(); - n->attrs.op = Op::Get("_default_subgraph_op"); - n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id); + n->attrs.op = Op::Get("_CachedOp"); + n->attrs.name = "_CachedOp" + std::to_string(subgraph_id); n->attrs.subgraphs.push_back(std::make_shared(sym)); + + std::vector > flags{{"static_alloc", "true"}}; + n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags)); + return n; } virtual SubgraphSelectorPtr CreateSubgraphSelector() const { diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py new file mode 100644 index 000000000000..b5577d4d0ff5 --- /dev/null +++ b/tests/python/unittest/test_subgraph.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: skip-file +from __future__ import print_function +import numpy as np +import mxnet as mx +import copy +import math +import ctypes +import random +import itertools +from numpy.testing import assert_allclose, assert_array_equal +from mxnet.test_utils import * +from mxnet.base import py_str, MXNetError, _as_list, SymbolHandle, check_call, _LIB, c_handle_array, mx_uint +from common import setup_module, with_seed, teardown +import unittest +from mxnet.gluon.model_zoo.vision import get_model + +def make_subgraph(subg, *args): + js = subg.tojson() + return mx.sym._internal._CachedOp(*args, subgraph=js) + +@with_seed() +def test_make_subgraph(): + def make_subgraph1(stype): + a = mx.symbol.Variable(name='a', stype=stype) + b = mx.symbol.Variable(name='b', stype=stype) + c = a * b + d = c * 2 + + a1 = mx.symbol.Variable(name='a', stype=stype) + b1 = mx.symbol.Variable(name='b', stype=stype) + y = make_subgraph(c, a1, b1) + y = y * 2 + + s = (10, 10) + a_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s), + ctx=default_context()).tostype(stype) + b_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s), + ctx=default_context()).tostype(stype) + return (d, y, {'a': a_arr, 'b': b_arr}, {}) + + def create_weights(shapes, names): + nd_dict = {} + sym_dict = {} + assert len(shapes) == len(names) + for i in range(len(shapes)): + sym_dict[names[i]] = mx.symbol.Variable(names[i]) + nd_dict[names[i]] = mx.nd.array(np.ones(shapes[i]), ctx=default_context()) + return (nd_dict, sym_dict) + + def make_subgraph_weight(orig, shape, stype): + arg_shapes, out_shapes, aux_shapes = orig.infer_shape(data=shape) + weight_shapes = arg_shapes[1:] + weight_names = orig.list_arguments()[1:] + weight_dict, weight_sym_dict = create_weights(weight_shapes, weight_names) + aux_dict, aux_sym_dict = create_weights(aux_shapes, orig.list_auxiliary_states()) + + input_dict = copy.deepcopy(weight_sym_dict) + input_dict.update(aux_sym_dict) + input_dict['data'] = mx.symbol.Variable('data', stype=stype) + input_list = [] + for name in orig.list_inputs(): + assert name in input_dict.keys() + input_list.append(input_dict[name]) + subg = make_subgraph(orig, *input_list) + + arr = mx.nd.random.uniform(-1, 1, shape=shape, ctx=default_context()).tostype(stype) + arg_dict = weight_dict + arg_dict['data'] = arr + return (orig, subg, arg_dict, aux_dict) + + def make_subgraph2(stype, out_mean_var): + data = mx.symbol.Variable('data', stype=stype) + orig = mx.symbol.BatchNorm(data, fix_gamma=False, + output_mean_var=out_mean_var, name="batchnorm") + s = (10, 10) + return make_subgraph_weight(orig, s, stype) + + def make_subgraph3(stype): + data = mx.symbol.Variable('data', stype=stype) + conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True) + bn1 = mx.symbol.BatchNorm(conv1, fix_gamma=False, output_mean_var=False) + conv2 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True) + bn2 = mx.symbol.BatchNorm(conv2, fix_gamma=False, output_mean_var=False) + orig = bn1 + bn2 + s = (1, 3, 32, 32) + return make_subgraph_weight(orig, s, stype) + + def make_subgraph4(stype): + model = get_model('resnet18_v1') + model.hybridize() + model.initialize() + s = (1, 3, 32, 32) + data = mx.nd.random.normal(shape=s) + out = model(data) + model.export('resnet18') + orig = mx.sym.load('resnet18-symbol.json') + return make_subgraph_weight(orig, s, stype) + + make_subgraphs = [make_subgraph1, + lambda stype: make_subgraph2(stype, False), + lambda stype: make_subgraph2(stype, True), + make_subgraph3, make_subgraph4] + stypes = ['default', 'row_sparse'] + for make_subg in make_subgraphs: + for stype in stypes: + orig, subg, inputs, aux_states = make_subg(stype) + all_inputs = copy.deepcopy(inputs) + all_inputs.update(aux_states) + args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()} + e1 = orig.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad, + aux_states=all_inputs) + args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()} + e2 = subg.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad, + aux_states=all_inputs) + e1.forward(is_train=True) + e2.forward(is_train=True) + for i in range(len(e1.outputs)): + assert_almost_equal(e1.outputs[i].asnumpy(), e2.outputs[i].asnumpy(), + rtol=0.001, atol=0.0001) + + out_grads = [mx.nd.random.uniform(-1, 1, shape=out.shape, ctx=default_context()) + for out in e1.outputs] + e1.backward(out_grads) + e2.backward(out_grads) + for i in range(len(e1.grad_arrays)): + assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy(), + rtol=0.001, atol=0.0001) + + +if __name__ == '__main__': + import nose + nose.runmodule()