From f04686a586220d90630c8aecd4a36364d0f2f660 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 5 Jul 2018 00:54:35 +0000 Subject: [PATCH 01/16] extend _CachedOp a regular operator. --- src/imperative/cached_op.cc | 342 ++++++++++++++++++++++++- src/imperative/cached_op.h | 31 +++ src/operator/operator_common.h | 12 +- tests/python/unittest/test_subgraph.py | 149 +++++++++++ 4 files changed, 518 insertions(+), 16 deletions(-) create mode 100644 tests/python/unittest/test_subgraph.py diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 0c4c1e60208f..defe9df22c7d 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -874,7 +874,6 @@ OpStatePtr CachedOp::Forward( return op_state; } - void CachedOp::DynamicBackward( const bool retain_graph, const OpStatePtr& op_state, @@ -1067,6 +1066,130 @@ void CachedOp::Backward( Engine::Get()->set_bulk_size(prev_bulk_size); } +struct CachedOpActualState { + std::shared_ptr op; + OpStatePtr forward_state; + + explicit CachedOpActualState(std::shared_ptr op) { + this->op = op; + } +}; + +void CachedOpForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CachedOpActualState &s = state_ptr.get_state(); + std::vector in_bufs = inputs; + std::vector out_bufs = outputs; + std::vector in_ptrs(in_bufs.size()); + std::vector out_ptrs(out_bufs.size()); + for (size_t i = 0; i < in_ptrs.size(); i++) + in_ptrs[i] = &in_bufs[i]; + for (size_t i = 0; i < out_ptrs.size(); i++) + out_ptrs[i] = &out_bufs[i]; + + // Set is_recording correct for the imperative executor. + bool orig_is_record; + if (ctx.need_grad) + orig_is_record = Imperative::Get()->set_is_recording(true); + else + orig_is_record = Imperative::Get()->is_recording(); + // Set is_training correct for the imperative executor. + bool orig_is_train; + if (ctx.is_train) + orig_is_train = Imperative::Get()->set_is_training(true); + else + orig_is_train = Imperative::Get()->is_training(); + s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs); + Imperative::Get()->set_is_training(orig_is_train); + Imperative::Get()->set_is_recording(orig_is_record); + // The arrays in out_ptrs may be changed by CachedOp. + // If it is, we need to copy data back. 
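+  // CopyFromTo only runs when CachedOp handed back a different array than the
+  // one the executor provided; the IsSame() check keeps the common path copy-free.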
+ for (size_t i = 0; i < out_bufs.size(); i++) + if (!out_bufs[i].IsSame(outputs[i])) + CopyFromTo(out_bufs[i], outputs[i]); +} + +void CachedOpBackward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace nnvm; + using namespace imperative; + CachedOpActualState &s = state_ptr.get_state(); + std::vector in_bufs = inputs; + std::vector out_bufs = outputs; + std::vector in_ptrs; + std::vector out_ptrs; + CHECK_EQ(s.op->num_backward_inputs(), inputs.size()); + in_ptrs.reserve(s.op->num_backward_inputs()); + out_ptrs.reserve(s.op->num_inputs()); + + const std::vector &save_inputs = s.op->save_inputs(); + const std::vector &save_outputs = s.op->save_outputs(); + size_t bwd_in_dep = s.op->num_inputs(); + size_t bwd_out_dep = s.op->num_outputs(); + CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep); + size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep; + + // Find inputs, outputs and ograds + auto ograds_begin = in_bufs.begin(); + auto ograds_end = in_bufs.begin() + bwd_ograd_dep; + auto in_begin = ograds_end; + auto in_end = in_begin + bwd_in_dep; + auto out_begin = in_end; + auto out_end = in_bufs.end(); + + for (auto it = ograds_begin; it != ograds_end; it++) + in_ptrs.push_back(&(*it)); + + CHECK_EQ(save_inputs.size(), in_end - in_begin); + CHECK_EQ(s.op->num_outputs(), out_end - out_begin); + for (auto it = in_begin; it != in_end; it++) { + auto i = it - in_begin; + if (save_inputs[i]) + in_ptrs.push_back(&(*it)); + } + for (auto it = out_begin; it != out_end; it++) { + auto i = it - out_begin; + if (save_outputs[i]) + in_ptrs.push_back(&(*it)); + } + CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs()); + for (size_t i = 0; i < out_bufs.size(); i++) + out_ptrs.push_back(&out_bufs[i]); + CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs()); + // Set is_training correct for the imperative executor. + bool orig_is_train; + if (ctx.is_train) + orig_is_train = Imperative::Get()->set_is_training(true); + else + orig_is_train = Imperative::Get()->is_training(); + // TODO(zhengda) is it right to use false here? + s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs); + Imperative::Get()->set_is_training(orig_is_train); + + // Clean up what we recorded. + s.forward_state.reset(); + + // The arrays in out_ptrs may be changed by CachedOp. + // If it is, we need to copy data back. + for (size_t i = 0; i < out_bufs.size(); i++) + if (!out_bufs[i].IsSame(outputs[i])) + CopyFromTo(out_bufs[i], outputs[i]); +} + +OpStatePtr CreateCachedOpState(const NodeAttrs& attrs, + Context ctx, + const std::vector& in_shapes, + const std::vector& in_types) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return OpStatePtr::Create(op); +} + bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -1143,6 +1266,155 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, return true; } +bool CachedOp::ForwardInferShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shapes, + std::vector *out_shapes) { + using namespace exec; + nnvm::Graph g(fwd_graph_); + const auto& idx_g = g.indexed_graph(); + CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size()); + CHECK_EQ(idx_g.outputs().size(), out_shapes->size()); + + // TODO(zhengda) we can cache the shape vector. + // Put the input and output shapes to the shape vector. 
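+  // Seed a per-entry shape vector with the shapes the caller already knows,
+  // then let the graph-level shape pass propagate them through the subgraph.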
+ nnvm::ShapeVector shapes(idx_g.num_node_entries()); + const auto &input_nids = idx_g.input_nodes(); + CHECK_EQ(input_nids.size(), in_shapes->size()); + for (size_t i = 0; i < in_shapes->size(); i++) { + auto eid = idx_g.entry_id(input_nids[i], 0); + shapes[eid] = in_shapes->at(i); + } + CHECK_EQ(g.outputs.size(), out_shapes->size()); + for (size_t i = 0; i < out_shapes->size(); i++) { + auto eid = idx_g.entry_id(g.outputs[i]); + shapes[eid] = out_shapes->at(i); + } + + // Infer shape of the graph. + g.attrs["shape"] = std::make_shared(std::move(shapes)); + g = exec::InferShape(std::move(g)); + + // Copy the inferred shape back to the input shapes and the output shapes. + shapes = g.GetAttr("shape"); + // assign to in_shapes + for (size_t i = 0; i < in_shapes->size(); ++i) { + const auto eid = idx_g.entry_id(input_nids[i], 0); + SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]); + } + // assign to out_shapes + for (size_t i = 0; i < g.outputs.size(); ++i) { + const auto eid = idx_g.entry_id(g.outputs[i]); + SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]); + } + // Check if we have inferred the shapes correctly. + return g.GetAttr("shape_num_unknown_nodes") == 0; +} + +bool CachedOp::ForwardInferType(const nnvm::NodeAttrs& attrs, + std::vector *in_types, + std::vector *out_types) { + nnvm::Graph g(fwd_graph_); + const auto& idx_g = g.indexed_graph(); + CHECK_EQ(idx_g.input_nodes().size(), in_types->size()); + CHECK_EQ(idx_g.outputs().size(), out_types->size()); + + // TODO(zhengda) we can cache the shape vector. + // Put the input and output data types to the dtype vector. + nnvm::DTypeVector types(idx_g.num_node_entries(), -1); + const auto &input_nids = idx_g.input_nodes(); + CHECK_EQ(input_nids.size(), in_types->size()); + for (size_t i = 0; i < in_types->size(); i++) { + auto eid = idx_g.entry_id(input_nids[i], 0); + types[eid] = in_types->at(i); + } + CHECK_EQ(g.outputs.size(), out_types->size()); + for (size_t i = 0; i < out_types->size(); i++) { + auto eid = idx_g.entry_id(g.outputs[i]); + types[eid] = out_types->at(i); + } + + // Infer data type of the graph. + g.attrs["dtype"] = std::make_shared(std::move(types)); + g = exec::InferType(std::move(g)); + + types = g.GetAttr("dtype"); + // assign to in_types + for (size_t i = 0; i < in_types->size(); ++i) { + const auto eid = idx_g.entry_id(input_nids[i], 0); + TYPE_ASSIGN_CHECK(*in_types, i, types[eid]); + } + // assign to out_types + for (size_t i = 0; i < g.outputs.size(); ++i) { + const auto eid = idx_g.entry_id(g.outputs[i]); + TYPE_ASSIGN_CHECK(*out_types, i, types[eid]); + } + // Check if we have inferred the dtypes correctly. 
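+  // A zero count means no entry in the subgraph was left with an unknown dtype.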
+ return g.GetAttr("dtype_num_unknown_nodes") == 0; +} + +std::vector CachedOp::MutableInputs() const { + nnvm::Symbol sym = GetForwardSym(); + const std::vector input_names = sym.ListInputNames(nnvm::Symbol::kAll); + const std::vector immutable_input_names = + sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + const std::vector mutable_input_names = + sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size()); + std::vector ret; + size_t i1 = 0, i2 = 0; + for (size_t i = 0; i < input_names.size(); ++i) { + if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) { + ++i1; + } else { + CHECK(i2 < mutable_input_names.size()); + CHECK_EQ(input_names[i], mutable_input_names[i2]); + ++i2; + ret.push_back(i); + } + } + return ret; +} + +std::vector CachedOp::GetResourceRequest() const { + nnvm::Symbol sym = GetForwardSym(); + static auto& fresource = Op::GetAttr("FResourceRequest"); + std::set resource_types; + DFSVisit(sym.outputs, [&](const nnvm::NodePtr& node) { + if (!node->is_variable() && fresource.count(node->op())) { + for (ResourceRequest& r : fresource[node->op()](node->attrs)){ + resource_types.insert(r.type); + } + } + }); + return std::vector(resource_types.begin(), resource_types.end()); +} + +void CachedOpParamParser(nnvm::NodeAttrs* attrs) { + CachedOpConfig param; + try { + param.Init(attrs->dict); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + if (!param.subgraph.empty()) { + nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph); + CHECK(!g.outputs.empty()); + nnvm::Symbol sym; + sym.outputs = g.outputs; + std::vector > flags; + for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++) + flags.emplace_back(it->first, it->second); + attrs->parsed = CachedOpPtr(new CachedOp(sym, flags)); + } +} NNVM_REGISTER_OP(_CachedOp) .set_num_inputs([](const NodeAttrs& attrs) { @@ -1153,19 +1425,63 @@ NNVM_REGISTER_OP(_CachedOp) const CachedOpPtr& op = nnvm::get(attrs.parsed); return op->num_outputs(); }) -.set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); - }) +.set_attr_parser(CachedOpParamParser) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { const CachedOpPtr& op = nnvm::get(n->attrs.parsed); return op->Gradient(n, ograds); - }); + }) +.set_attr("FListInputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ListForwardInputNames(); + }) +.set_attr("FListOutputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ListForwardOutputNames(); + }) +.set_attr("FCreateOpState", CreateCachedOpState) +.set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + std::vector *in_shapes, + std::vector *out_shapes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ForwardInferShape(attrs, in_shapes, out_shapes); + }) +.set_attr("FInferType", + [](const nnvm::NodeAttrs& attrs, + std::vector *in_types, 
+ std::vector *out_types) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ForwardInferType(attrs, in_types, out_types); + }) +.set_attr("FInferStorageType", + [](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_stypes, out_stypes); + }) +.set_attr("FStatefulComputeEx", CachedOpForward) +.set_attr("FStatefulComputeEx", CachedOpForward) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->MutableInputs(); + }) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->GetResourceRequest(); + }) +.set_attr("FExecType", + [](const nnvm::NodeAttrs& attrs) { + return ExecType::kSubgraphExec; + }) +.add_argument("data", "NDArray-or-Symbol[]", "input data list"); NNVM_REGISTER_OP(_backward_CachedOp) .set_num_inputs([](const NodeAttrs& attrs){ @@ -1184,6 +1500,12 @@ NNVM_REGISTER_OP(_backward_CachedOp) const CachedOpPtr& op = nnvm::get(attrs.parsed); return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); }) +.set_attr("FStatefulComputeEx", CachedOpBackward) +.set_attr("FStatefulComputeEx", CachedOpBackward) +.set_attr("FExecType", + [](const nnvm::NodeAttrs& attrs) { + return ExecType::kSubgraphExec; + }) .set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true); diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 4f4dfdcc14dd..138e0a38a017 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter { bool static_shape; nnvm::Tuple data_indices; nnvm::Tuple param_indices; + std::string subgraph; DMLC_DECLARE_PARAMETER(CachedOpConfig) { DMLC_DECLARE_FIELD(static_alloc) .set_default(false) @@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter { DMLC_DECLARE_FIELD(param_indices) .set_default(nnvm::Tuple()) .describe("Position of parameters."); + DMLC_DECLARE_FIELD(subgraph) + .set_default(std::string("")) + .describe("JSON string of a subgraph."); } }; @@ -80,6 +84,10 @@ class CachedOp { uint32_t num_backward_inputs() const { return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size(); } + uint32_t num_backward_outputs() const { + auto &idx = fwd_graph_.indexed_graph(); + return idx.input_nodes().size() - idx.mutable_input_nodes().size(); + } std::vector& save_inputs() { return save_inputs_; } @@ -116,6 +124,24 @@ class CachedOp { DispatchMode* dispatch_mode, std::vector *in_attrs, std::vector *out_attrs); + bool ForwardInferShape( + const nnvm::NodeAttrs& attrs, + std::vector *in_shapes, + std::vector *out_shapes); + bool ForwardInferType( + const nnvm::NodeAttrs& attrs, + std::vector *in_types, + std::vector *out_types); + std::vector ListForwardInputNames() const { + nnvm::Symbol sym = GetForwardSym(); + return sym.ListInputNames(nnvm::Symbol::kAll); + } + std::vector ListForwardOutputNames() const { + nnvm::Symbol sym = GetForwardSym(); + return sym.ListOutputNames(); + } + std::vector MutableInputs() const; + std::vector GetResourceRequest() const; private: struct GraphInfo; @@ -167,6 +193,11 @@ class CachedOp { const std::vector& inputs, const std::vector& reqs, const std::vector& outputs); + nnvm::Symbol GetForwardSym() const { + nnvm::Symbol sym; + sym.outputs = 
fwd_graph_.outputs; + return sym; + } CachedOpConfig config_; nnvm::Graph fwd_graph_; diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 29112939a22f..6a4c3d027075 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -221,7 +221,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) { */ #define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \ { \ - if (!shape_assign(&(shape_array)[index], TShape(shape))) { \ + if (!::mxnet::op::shape_assign(&(shape_array)[index], TShape(shape))) { \ std::ostringstream os; \ os << "Shape inconsistent, Provided = " << (shape_array)[index] << ','\ << " inferred shape=" << shape; \ @@ -238,11 +238,11 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) { */ #define TYPE_ASSIGN_CHECK(type_array, index, type) \ { \ - if (!type_assign(&(type_array)[index], type)) { \ + if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \ std::ostringstream os; \ os << "Type inconsistent, Provided = " \ - << type_string((type_array)[index]) << ',' \ - << " inferred type = " << type_string(type); \ + << ::mxnet::op::type_string((type_array)[index]) << ',' \ + << " inferred type = " << ::mxnet::op::type_string(type); \ throw ::mxnet::op::InferTypeError(os.str(), index); \ } \ } @@ -291,8 +291,8 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) { #define UNIFORM_TYPE_CHECK(type, expected, arg) \ { \ CHECK_EQ(type, expected) << "This layer requires uniform type. " \ - << "Expected '" << type_string(expected) \ - << "' v.s. given '" << type_string(type) \ + << "Expected '" << ::mxnet::op::type_string(expected) \ + << "' v.s. given '" << ::mxnet::op::type_string(type) \ << "' at '" << arg << "'"; \ } diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py new file mode 100644 index 000000000000..338d3ae781f4 --- /dev/null +++ b/tests/python/unittest/test_subgraph.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
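# These tests serialize a subgraph to JSON, wrap it in a _CachedOp node via
# mx.sym._internal._CachedOp(*args, subgraph=js), and check that the wrapped
# graph produces the same outputs and gradients as the original symbol.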
+ +# pylint: skip-file +from __future__ import print_function +import numpy as np +import mxnet as mx +import copy +import math +import ctypes +import random +import itertools +from numpy.testing import assert_allclose, assert_array_equal +from mxnet.test_utils import * +from mxnet.base import py_str, MXNetError, _as_list, SymbolHandle, check_call, _LIB, c_handle_array, mx_uint +from common import setup_module, with_seed, teardown +import unittest +from mxnet.gluon.model_zoo.vision import get_model + +def make_subgraph(subg, *args): + js = subg.tojson() + return mx.sym._internal._CachedOp(*args, subgraph=js) + +@with_seed() +def test_make_subgraph(): + def make_subgraph1(stype): + a = mx.symbol.Variable(name='a', stype=stype) + b = mx.symbol.Variable(name='b', stype=stype) + c = a * b + d = c * 2 + + a1 = mx.symbol.Variable(name='a', stype=stype) + b1 = mx.symbol.Variable(name='b', stype=stype) + y = make_subgraph(c, a1, b1) + y = y * 2 + + s = (10, 10) + a_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s), + ctx=default_context()).tostype(stype) + b_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s), + ctx=default_context()).tostype(stype) + return (d, y, {'a': a_arr, 'b': b_arr}, {}) + + def create_weights(shapes, names): + nd_dict = {} + sym_dict = {} + assert len(shapes) == len(names) + for i in range(len(shapes)): + sym_dict[names[i]] = mx.symbol.Variable(names[i]) + nd_dict[names[i]] = mx.nd.array(np.ones(shapes[i]), ctx=default_context()) + return (nd_dict, sym_dict) + + def make_subgraph_weight(orig, shape, stype): + arg_shapes, out_shapes, aux_shapes = orig.infer_shape(data=shape) + weight_shapes = arg_shapes[1:] + weight_names = orig.list_arguments()[1:] + weight_dict, weight_sym_dict = create_weights(weight_shapes, weight_names) + aux_dict, aux_sym_dict = create_weights(aux_shapes, orig.list_auxiliary_states()) + + input_dict = copy.deepcopy(weight_sym_dict) + input_dict.update(aux_sym_dict) + input_dict['data'] = mx.symbol.Variable('data', stype=stype) + input_list = [] + for name in orig.list_inputs(): + assert name in input_dict.keys() + input_list.append(input_dict[name]) + subg = make_subgraph(orig, *input_list) + + arr = mx.nd.random.uniform(-1, 1, shape=shape, ctx=default_context()).tostype(stype) + arg_dict = weight_dict + arg_dict['data'] = arr + return (orig, subg, arg_dict, aux_dict) + + def make_subgraph2(stype, out_mean_var): + data = mx.symbol.Variable('data', stype=stype) + orig = mx.symbol.BatchNorm(data, fix_gamma=False, + output_mean_var=out_mean_var, name="batchnorm") + s = (10, 10) + return make_subgraph_weight(orig, s, stype) + + def make_subgraph3(stype): + data = mx.symbol.Variable('data', stype=stype) + conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True) + bn1 = mx.symbol.BatchNorm(conv1, fix_gamma=False, output_mean_var=False) + conv2 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True) + bn2 = mx.symbol.BatchNorm(conv2, fix_gamma=False, output_mean_var=False) + orig = bn1 + bn2 + s = (1, 3, 32, 32) + return make_subgraph_weight(orig, s, stype) + + def make_subgraph4(stype): + model = get_model('resnet18_v1') + model.hybridize() + model.initialize() + s = (1, 3, 32, 32) + data = mx.nd.random.normal(shape=s) + out = model(data) + model.export('resnet18') + orig = mx.sym.load('resnet18-symbol.json') + return make_subgraph_weight(orig, s, stype) + + make_subgraphs = [make_subgraph1, + lambda stype: make_subgraph2(stype, False), + lambda stype: make_subgraph2(stype, True), + make_subgraph3] + 
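    # For every graph builder and storage type below, bind both the original
    # symbol and its _CachedOp-wrapped version, then check that the outputs and
    # gradients of the two executors match.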
stypes = ['default', 'row_sparse'] + for make_subg in make_subgraphs: + for stype in stypes: + orig, subg, inputs, aux_states = make_subg(stype) + all_inputs = copy.deepcopy(inputs) + all_inputs.update(aux_states) + args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()} + e1 = orig.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad, + aux_states=all_inputs) + args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()} + e2 = subg.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad, + aux_states=all_inputs) + e1.forward() + e2.forward() + for i in range(len(e1.outputs)): + assert_almost_equal(e1.outputs[i].asnumpy(), e2.outputs[i].asnumpy(), + rtol=0.001, atol=0.0001) + + out_grads = [mx.nd.random.uniform(-1, 1, shape=out.shape, ctx=default_context()) + for out in e1.outputs] + e1.backward(out_grads) + e2.backward(out_grads) + for i in range(len(e1.grad_arrays)): + assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy(), + rtol=0.001, atol=0.0001) + + +if __name__ == '__main__': + import nose + nose.runmodule() From b7f1066ef41bc90a64189723f106a7a1bc415662 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 31 Aug 2018 11:28:22 -0700 Subject: [PATCH 02/16] use default subgraph infer. --- src/imperative/cached_op.cc | 90 +--------------------------------- src/imperative/cached_op.h | 8 --- src/operator/subgraph/common.h | 30 +++++++++--- 3 files changed, 26 insertions(+), 102 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index defe9df22c7d..4202d2ceca36 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1266,92 +1266,6 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, return true; } -bool CachedOp::ForwardInferShape(const nnvm::NodeAttrs& attrs, - std::vector *in_shapes, - std::vector *out_shapes) { - using namespace exec; - nnvm::Graph g(fwd_graph_); - const auto& idx_g = g.indexed_graph(); - CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size()); - CHECK_EQ(idx_g.outputs().size(), out_shapes->size()); - - // TODO(zhengda) we can cache the shape vector. - // Put the input and output shapes to the shape vector. - nnvm::ShapeVector shapes(idx_g.num_node_entries()); - const auto &input_nids = idx_g.input_nodes(); - CHECK_EQ(input_nids.size(), in_shapes->size()); - for (size_t i = 0; i < in_shapes->size(); i++) { - auto eid = idx_g.entry_id(input_nids[i], 0); - shapes[eid] = in_shapes->at(i); - } - CHECK_EQ(g.outputs.size(), out_shapes->size()); - for (size_t i = 0; i < out_shapes->size(); i++) { - auto eid = idx_g.entry_id(g.outputs[i]); - shapes[eid] = out_shapes->at(i); - } - - // Infer shape of the graph. - g.attrs["shape"] = std::make_shared(std::move(shapes)); - g = exec::InferShape(std::move(g)); - - // Copy the inferred shape back to the input shapes and the output shapes. - shapes = g.GetAttr("shape"); - // assign to in_shapes - for (size_t i = 0; i < in_shapes->size(); ++i) { - const auto eid = idx_g.entry_id(input_nids[i], 0); - SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]); - } - // assign to out_shapes - for (size_t i = 0; i < g.outputs.size(); ++i) { - const auto eid = idx_g.entry_id(g.outputs[i]); - SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]); - } - // Check if we have inferred the shapes correctly. 
- return g.GetAttr("shape_num_unknown_nodes") == 0; -} - -bool CachedOp::ForwardInferType(const nnvm::NodeAttrs& attrs, - std::vector *in_types, - std::vector *out_types) { - nnvm::Graph g(fwd_graph_); - const auto& idx_g = g.indexed_graph(); - CHECK_EQ(idx_g.input_nodes().size(), in_types->size()); - CHECK_EQ(idx_g.outputs().size(), out_types->size()); - - // TODO(zhengda) we can cache the shape vector. - // Put the input and output data types to the dtype vector. - nnvm::DTypeVector types(idx_g.num_node_entries(), -1); - const auto &input_nids = idx_g.input_nodes(); - CHECK_EQ(input_nids.size(), in_types->size()); - for (size_t i = 0; i < in_types->size(); i++) { - auto eid = idx_g.entry_id(input_nids[i], 0); - types[eid] = in_types->at(i); - } - CHECK_EQ(g.outputs.size(), out_types->size()); - for (size_t i = 0; i < out_types->size(); i++) { - auto eid = idx_g.entry_id(g.outputs[i]); - types[eid] = out_types->at(i); - } - - // Infer data type of the graph. - g.attrs["dtype"] = std::make_shared(std::move(types)); - g = exec::InferType(std::move(g)); - - types = g.GetAttr("dtype"); - // assign to in_types - for (size_t i = 0; i < in_types->size(); ++i) { - const auto eid = idx_g.entry_id(input_nids[i], 0); - TYPE_ASSIGN_CHECK(*in_types, i, types[eid]); - } - // assign to out_types - for (size_t i = 0; i < g.outputs.size(); ++i) { - const auto eid = idx_g.entry_id(g.outputs[i]); - TYPE_ASSIGN_CHECK(*out_types, i, types[eid]); - } - // Check if we have inferred the dtypes correctly. - return g.GetAttr("dtype_num_unknown_nodes") == 0; -} - std::vector CachedOp::MutableInputs() const { nnvm::Symbol sym = GetForwardSym(); const std::vector input_names = sym.ListInputNames(nnvm::Symbol::kAll); @@ -1447,14 +1361,14 @@ NNVM_REGISTER_OP(_CachedOp) std::vector *in_shapes, std::vector *out_shapes) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ForwardInferShape(attrs, in_shapes, out_shapes); + return DefaultSubgraphOpShape(op.GetForwardGraph(), in_shapes, out_shapes); }) .set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, std::vector *in_types, std::vector *out_types) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ForwardInferType(attrs, in_types, out_types); + return DefaultSubgraphOpType(op.GetForwardGraph(), in_types, out_types); }) .set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 138e0a38a017..600060c94f50 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -124,14 +124,6 @@ class CachedOp { DispatchMode* dispatch_mode, std::vector *in_attrs, std::vector *out_attrs); - bool ForwardInferShape( - const nnvm::NodeAttrs& attrs, - std::vector *in_shapes, - std::vector *out_shapes); - bool ForwardInferType( - const nnvm::NodeAttrs& attrs, - std::vector *in_types, - std::vector *out_types); std::vector ListForwardInputNames() const { nnvm::Symbol sym = GetForwardSym(); return sym.ListInputNames(nnvm::Symbol::kAll); diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h index 22058d556e07..fdd6a1e0791b 100644 --- a/src/operator/subgraph/common.h +++ b/src/operator/subgraph/common.h @@ -49,11 +49,10 @@ inline std::vector DefaultSubgraphOpListOutputs(const nnvm::NodeAtt return sym.ListOutputNames(); } -inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs, +inline bool DefaultSubgraphOpShape(const nnvm::Symbol& subgraph_sym, std::vector *in_shapes, std::vector *out_shapes) { using namespace exec; - const nnvm::Symbol& 
subgraph_sym = *attrs.subgraphs[0]; nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -94,10 +93,15 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs, return g.GetAttr("shape_num_unknown_nodes") == 0; } -inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs, +inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shapes, + std::vector *out_shapes) { + return DefaultSubgraphOpShape(*attrs.subgraphs[0], in_shapes, out_shapes); +} + +inline bool DefaultSubgraphOpType(const nnvm::Symbol& subgraph_sym, std::vector *in_types, std::vector *out_types) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -137,12 +141,17 @@ inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs, return g.GetAttr("dtype_num_unknown_nodes") == 0; } -inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs, +inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_types, + std::vector *out_types) { + return DefaultSubgraphOpType(*attrs.subgraphs[0], in_types, out_types); +} + +inline bool DefaultSubgraphOpStorageType(const nnvm::Symbol& subgraph_sym, const int dev_mask, DispatchMode* dispatch_mode, std::vector* in_stypes, std::vector* out_stypes) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -190,6 +199,15 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs, return g.GetAttr("storage_type_num_unknown_nodes") == 0; } +inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + return DefaultSubgraphOpStorageType(*attrs.subgraphs[0], dev_mask, dispatch_mode, + in_stypes, out_stypes); +} + inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) { return ExecType::kSubgraphExec; } From 33895d1f05771411ec08b511f1bfacf630c7a65c Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 31 Aug 2018 11:47:15 -0700 Subject: [PATCH 03/16] fix test. --- tests/python/unittest/test_subgraph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py index 338d3ae781f4..b5577d4d0ff5 100644 --- a/tests/python/unittest/test_subgraph.py +++ b/tests/python/unittest/test_subgraph.py @@ -116,7 +116,7 @@ def make_subgraph4(stype): make_subgraphs = [make_subgraph1, lambda stype: make_subgraph2(stype, False), lambda stype: make_subgraph2(stype, True), - make_subgraph3] + make_subgraph3, make_subgraph4] stypes = ['default', 'row_sparse'] for make_subg in make_subgraphs: for stype in stypes: @@ -129,8 +129,8 @@ def make_subgraph4(stype): args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()} e2 = subg.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad, aux_states=all_inputs) - e1.forward() - e2.forward() + e1.forward(is_train=True) + e2.forward(is_train=True) for i in range(len(e1.outputs)): assert_almost_equal(e1.outputs[i].asnumpy(), e2.outputs[i].asnumpy(), rtol=0.001, atol=0.0001) From 9b0771c04c4ca08e8f1aca0da34963bb87acd198 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 31 Aug 2018 20:14:02 +0000 Subject: [PATCH 04/16] fix compilation error. 
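
Give the Symbol-based helpers in operator/subgraph/common.h a distinct "1"
suffix instead of overloading the NodeAttrs versions, qualify the calls in
cached_op.cc with the op:: namespace, and make CachedOp::GetForwardSym() public
so the operator registration can use it.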
--- src/imperative/cached_op.cc | 5 +++-- src/imperative/cached_op.h | 10 +++++----- src/operator/subgraph/common.h | 30 +++++++++++++++--------------- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 4202d2ceca36..3159662fe79b 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -23,6 +23,7 @@ #include "../executor/exec_pass.h" #include "../profiler/profiler.h" #include "../operator/operator_common.h" +#include "../operator/subgraph/common.h" namespace mxnet { @@ -1361,14 +1362,14 @@ NNVM_REGISTER_OP(_CachedOp) std::vector *in_shapes, std::vector *out_shapes) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return DefaultSubgraphOpShape(op.GetForwardGraph(), in_shapes, out_shapes); + return op::DefaultSubgraphOpShape1(op->GetForwardSym(), in_shapes, out_shapes); }) .set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, std::vector *in_types, std::vector *out_types) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return DefaultSubgraphOpType(op.GetForwardGraph(), in_types, out_types); + return op::DefaultSubgraphOpType1(op->GetForwardSym(), in_types, out_types); }) .set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 600060c94f50..d93d3d6e1bff 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -134,6 +134,11 @@ class CachedOp { } std::vector MutableInputs() const; std::vector GetResourceRequest() const; + nnvm::Symbol GetForwardSym() const { + nnvm::Symbol sym; + sym.outputs = fwd_graph_.outputs; + return sym; + } private: struct GraphInfo; @@ -185,11 +190,6 @@ class CachedOp { const std::vector& inputs, const std::vector& reqs, const std::vector& outputs); - nnvm::Symbol GetForwardSym() const { - nnvm::Symbol sym; - sym.outputs = fwd_graph_.outputs; - return sym; - } CachedOpConfig config_; nnvm::Graph fwd_graph_; diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h index fdd6a1e0791b..b13f0eae7d68 100644 --- a/src/operator/subgraph/common.h +++ b/src/operator/subgraph/common.h @@ -49,9 +49,9 @@ inline std::vector DefaultSubgraphOpListOutputs(const nnvm::NodeAtt return sym.ListOutputNames(); } -inline bool DefaultSubgraphOpShape(const nnvm::Symbol& subgraph_sym, - std::vector *in_shapes, - std::vector *out_shapes) { +inline bool DefaultSubgraphOpShape1(const nnvm::Symbol& subgraph_sym, + std::vector *in_shapes, + std::vector *out_shapes) { using namespace exec; nnvm::Graph g; g.outputs = subgraph_sym.outputs; @@ -96,12 +96,12 @@ inline bool DefaultSubgraphOpShape(const nnvm::Symbol& subgraph_sym, inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs, std::vector *in_shapes, std::vector *out_shapes) { - return DefaultSubgraphOpShape(*attrs.subgraphs[0], in_shapes, out_shapes); + return DefaultSubgraphOpShape1(*attrs.subgraphs[0], in_shapes, out_shapes); } -inline bool DefaultSubgraphOpType(const nnvm::Symbol& subgraph_sym, - std::vector *in_types, - std::vector *out_types) { +inline bool DefaultSubgraphOpType1(const nnvm::Symbol& subgraph_sym, + std::vector *in_types, + std::vector *out_types) { nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -144,14 +144,14 @@ inline bool DefaultSubgraphOpType(const nnvm::Symbol& subgraph_sym, inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs, std::vector *in_types, std::vector *out_types) { - return DefaultSubgraphOpType(*attrs.subgraphs[0], 
in_types, out_types); + return DefaultSubgraphOpType1(*attrs.subgraphs[0], in_types, out_types); } -inline bool DefaultSubgraphOpStorageType(const nnvm::Symbol& subgraph_sym, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_stypes, - std::vector* out_stypes) { +inline bool DefaultSubgraphOpStorageType1(const nnvm::Symbol& subgraph_sym, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { nnvm::Graph g; g.outputs = subgraph_sym.outputs; const auto& idx_g = g.indexed_graph(); @@ -204,8 +204,8 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs, DispatchMode* dispatch_mode, std::vector* in_stypes, std::vector* out_stypes) { - return DefaultSubgraphOpStorageType(*attrs.subgraphs[0], dev_mask, dispatch_mode, - in_stypes, out_stypes); + return DefaultSubgraphOpStorageType1(*attrs.subgraphs[0], dev_mask, dispatch_mode, + in_stypes, out_stypes); } inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) { From f253adeb9c537d1212c54457399203c2e566a507 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 31 Aug 2018 20:20:56 +0000 Subject: [PATCH 05/16] use default subgraph stuff. --- src/imperative/cached_op.cc | 41 ++-------------------------------- src/imperative/cached_op.h | 2 -- src/operator/subgraph/common.h | 15 +++++++++---- 3 files changed, 13 insertions(+), 45 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 3159662fe79b..b42aa2eeced0 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1267,43 +1267,6 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, return true; } -std::vector CachedOp::MutableInputs() const { - nnvm::Symbol sym = GetForwardSym(); - const std::vector input_names = sym.ListInputNames(nnvm::Symbol::kAll); - const std::vector immutable_input_names = - sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs); - const std::vector mutable_input_names = - sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates); - CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size()); - std::vector ret; - size_t i1 = 0, i2 = 0; - for (size_t i = 0; i < input_names.size(); ++i) { - if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) { - ++i1; - } else { - CHECK(i2 < mutable_input_names.size()); - CHECK_EQ(input_names[i], mutable_input_names[i2]); - ++i2; - ret.push_back(i); - } - } - return ret; -} - -std::vector CachedOp::GetResourceRequest() const { - nnvm::Symbol sym = GetForwardSym(); - static auto& fresource = Op::GetAttr("FResourceRequest"); - std::set resource_types; - DFSVisit(sym.outputs, [&](const nnvm::NodePtr& node) { - if (!node->is_variable() && fresource.count(node->op())) { - for (ResourceRequest& r : fresource[node->op()](node->attrs)){ - resource_types.insert(r.type); - } - } - }); - return std::vector(resource_types.begin(), resource_types.end()); -} - void CachedOpParamParser(nnvm::NodeAttrs* attrs) { CachedOpConfig param; try { @@ -1385,12 +1348,12 @@ NNVM_REGISTER_OP(_CachedOp) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->MutableInputs(); + return DefaultSubgraphOpMutableInputs1(op->GetForwardSym()); }) .set_attr("FResourceRequest", [](const nnvm::NodeAttrs& attrs) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->GetResourceRequest(); + return DefaultSubgraphOpResourceRequest1(op->GetForwardSym()); }) .set_attr("FExecType", [](const 
nnvm::NodeAttrs& attrs) { diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index d93d3d6e1bff..8a0795818fe0 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -132,8 +132,6 @@ class CachedOp { nnvm::Symbol sym = GetForwardSym(); return sym.ListOutputNames(); } - std::vector MutableInputs() const; - std::vector GetResourceRequest() const; nnvm::Symbol GetForwardSym() const { nnvm::Symbol sym; sym.outputs = fwd_graph_.outputs; diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h index b13f0eae7d68..259a6bb88023 100644 --- a/src/operator/subgraph/common.h +++ b/src/operator/subgraph/common.h @@ -212,8 +212,7 @@ inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) { return ExecType::kSubgraphExec; } -inline std::vector DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; +inline std::vector DefaultSubgraphOpMutableInputs1(const nnvm::Symbol& subgraph_sym) { const std::vector input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll); const std::vector immutable_input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs); @@ -235,8 +234,12 @@ inline std::vector DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr return ret; } -inline std::vector DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) { - const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0]; +inline std::vector DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) { + return DefaultSubgraphOpMutableInputs1(*attrs.subgraphs[0]); +} + +inline std::vector DefaultSubgraphOpResourceRequest1( + const nnvm::Symbol& subgraph_sym) { static auto& fresource = Op::GetAttr("FResourceRequest"); std::set resource_types; DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) { @@ -249,6 +252,10 @@ inline std::vector DefaultSubgraphOpResourceRequest(const nnvm: return std::vector(resource_types.begin(), resource_types.end()); } +inline std::vector DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) { + return DefaultSubgraphOpResourceRequest1(*attrs.subgraphs[0]); +} + } // namespace op } // namespace mxnet From daf38010d4e1fcb998ac4d92b67db8b264a6f21c Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 31 Aug 2018 20:27:01 +0000 Subject: [PATCH 06/16] add comments. --- src/imperative/cached_op.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index b42aa2eeced0..797afa5ae2ec 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1067,6 +1067,10 @@ void CachedOp::Backward( Engine::Get()->set_bulk_size(prev_bulk_size); } +/* + * This is the operator state of CachedOp when CachedOp is used in the symbol + * executor. This is different from the OpState returned by CachedOp::Forward. + */ struct CachedOpActualState { std::shared_ptr op; OpStatePtr forward_state; @@ -1076,6 +1080,10 @@ struct CachedOpActualState { } }; +/* + * This is the forward computation when CachedOp is used as an operator in + * a symbol executor. + */ void CachedOpForward(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, @@ -1113,6 +1121,10 @@ void CachedOpForward(const OpStatePtr& state_ptr, CopyFromTo(out_bufs[i], outputs[i]); } +/* + * This is the backward computation when CachedOp is used as an operator in + * a symbol executor. 
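+ * It unpacks the flat input list into ograds, saved inputs and saved outputs,
+ * passes them to CachedOp::Backward together with the forward state recorded
+ * by CachedOpForward, and releases that state afterwards.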
+ */ void CachedOpBackward(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, From 15751a4f2f41bc649249923b5cfa01ae5b6feb65 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Tue, 4 Sep 2018 18:30:49 +0000 Subject: [PATCH 07/16] fix. --- src/imperative/cached_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 797afa5ae2ec..8a29edf078ec 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1360,12 +1360,12 @@ NNVM_REGISTER_OP(_CachedOp) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return DefaultSubgraphOpMutableInputs1(op->GetForwardSym()); + return op::DefaultSubgraphOpMutableInputs1(op->GetForwardSym()); }) .set_attr("FResourceRequest", [](const nnvm::NodeAttrs& attrs) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return DefaultSubgraphOpResourceRequest1(op->GetForwardSym()); + return op::DefaultSubgraphOpResourceRequest1(op->GetForwardSym()); }) .set_attr("FExecType", [](const nnvm::NodeAttrs& attrs) { From 7c175c56ceade9cbe74d084e992f42e82d4ed9b2 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 10 Sep 2018 18:20:57 +0000 Subject: [PATCH 08/16] use a more general InferStorage. --- src/imperative/cached_op.cc | 34 +++------------------------------- src/imperative/cached_op.h | 7 ------- 2 files changed, 3 insertions(+), 38 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 8a29edf078ec..ca6be8693e81 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1203,36 +1203,6 @@ OpStatePtr CreateCachedOpState(const NodeAttrs& attrs, return OpStatePtr::Create(op); } -bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - using namespace imperative; - nnvm::Graph g(fwd_graph_); - const auto& idx = g.indexed_graph(); - const auto &outputs = idx.outputs(); - - // Prepare stypes and contexts based on inputs - StorageTypeVector storage_type_inputs; - storage_type_inputs.reserve(in_attrs->size()); - for (size_t i = 0; i < in_attrs->size(); ++i) { - storage_type_inputs.emplace_back(in_attrs->at(i)); - } - exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask); - - // Forward graph storage type inference - CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true); - // Retrieve result and set outputs - const auto& inferred_stypes = g.GetAttr("storage_type"); - for (size_t i = 0; i < out_attrs->size(); i++) { - const auto eid = idx.entry_id(outputs[i]); - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]); - } - DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); - return true; -} - bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -1353,7 +1323,9 @@ NNVM_REGISTER_OP(_CachedOp) std::vector* in_stypes, std::vector* out_stypes) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_stypes, out_stypes); + return op::DefaultSubgraphOpStorageType1(op->GetForwardSym(), + dev_mask, dispatch_mode, + in_stypes, out_stypes); }) .set_attr("FStatefulComputeEx", CachedOpForward) .set_attr("FStatefulComputeEx", CachedOpForward) diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 8a0795818fe0..59a793ee1b65 100644 --- 
a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -110,13 +110,6 @@ class CachedOp { const std::vector& inputs, const std::vector& reqs, const std::vector& outputs); - // forward storage type inference - bool ForwardStorageType( - const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs); // backward storage type inference bool BackwardStorageType( const nnvm::NodeAttrs& attrs, From 8fa5ac8363706c014c74fc49671fc6a03787ba98 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 10 Sep 2018 18:21:25 +0000 Subject: [PATCH 09/16] use cachedOp as default subgraph operator. --- src/operator/subgraph/default_subgraph_property.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc index c8d3e9ffd438..9109a97e192c 100644 --- a/src/operator/subgraph/default_subgraph_property.cc +++ b/src/operator/subgraph/default_subgraph_property.cc @@ -21,6 +21,7 @@ #include #include "./common.h" #include "./subgraph_property.h" +#include "../../imperative/cached_op.h" namespace mxnet { namespace op { @@ -59,9 +60,13 @@ class DefaultSubgraphProperty: public SubgraphProperty { virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const { nnvm::NodePtr n = nnvm::Node::Create(); - n->attrs.op = Op::Get("_default_subgraph_op"); - n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id); + n->attrs.op = Op::Get("_CachedOp"); + n->attrs.name = "_CachedOp" + std::to_string(subgraph_id); n->attrs.subgraphs.push_back(std::make_shared(sym)); + + std::vector > flags{{"static_alloc", "true"}}; + n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags)); + return n; } virtual SubgraphSelectorPtr CreateSubgraphSelector() const { From ded2cc2a1967158350d1db67bb3cd2bb1c81ab4e Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 10 Sep 2018 18:25:05 +0000 Subject: [PATCH 10/16] remove default subgraph op. --- src/operator/subgraph/default_subgraph_op.cc | 112 ------------------ src/operator/subgraph/default_subgraph_op.cu | 44 ------- .../subgraph/default_subgraph_property.cc | 2 +- 3 files changed, 1 insertion(+), 157 deletions(-) delete mode 100644 src/operator/subgraph/default_subgraph_op.cc delete mode 100644 src/operator/subgraph/default_subgraph_op.cu diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc deleted file mode 100644 index d5fb7ee2db61..000000000000 --- a/src/operator/subgraph/default_subgraph_op.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. 
-*/ - -#include -#include "./common.h" -#include "../../imperative/imperative_utils.h" -#include "../../imperative/cached_op.h" - -namespace mxnet { -namespace op { - -#define DEBUG_SUBGRAPH 0 - -class DefaultSubgraphOperator { - public: - explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) { - subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"}, - {"static_shape", "true"}})); - } - - void Forward(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); - void Backward(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - LOG(FATAL) << "Not implemented"; - } - - private: - nnvm::Symbol subgraph_sym_; - CachedOpPtr subgraph_exec_; -}; - -void DefaultSubgraphOperator::Forward(const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - std::vector tmp_inputs = inputs; - std::vector input_ptrs; - input_ptrs.reserve(inputs.size()); - for (auto& nd : tmp_inputs) { - input_ptrs.push_back(&nd); - } - std::vector tmp_outputs = outputs; - std::vector output_ptrs; - for (auto& nd : tmp_outputs) { - output_ptrs.push_back(&nd); - } -#if DEBUG_SUBGRAPH - for (size_t i = 0; i < inputs.size(); ++i) { - LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version(); - } - for (size_t i = 0; i < outputs.size(); ++i) { - LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version(); - } -#endif - subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs); -} - -OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs, - Context ctx, - const std::vector& in_shapes, - const std::vector& in_types) { - return OpStatePtr::Create(*attrs.subgraphs[0]); -} - -void DefaultSubgraphOpForward(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - DefaultSubgraphOperator& op = state_ptr.get_state(); - op.Forward(ctx, inputs, req, outputs); -} - -NNVM_REGISTER_OP(_default_subgraph_op) -.describe(R"code(_default_subgraph_op)code" ADD_FILELINE) -.set_num_inputs(DefaultSubgraphOpNumInputs) -.set_num_outputs(DefaultSubgraphOpNumOutputs) -.set_attr("FListInputNames", DefaultSubgraphOpListInputs) -.set_attr("FListOutputNames", DefaultSubgraphOpListOutputs) -.set_attr("FCreateOpState", CreateDefaultSubgraphOpState) -.set_attr("FInferShape", DefaultSubgraphOpShape) -.set_attr("FInferType", DefaultSubgraphOpType) -.set_attr("FInferStorageType", DefaultSubgraphOpStorageType) -.set_attr("FStatefulComputeEx", DefaultSubgraphOpForward) -.set_attr("FMutateInputs", DefaultSubgraphOpMutableInputs) -.set_attr("key_var_num_args", "num_args") -.set_attr("FExecType", DefaultSubgraphOpExecType) -.add_argument("data", "NDArray-or-Symbol[]", "input data list"); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu deleted file mode 100644 index 008826b21d71..000000000000 --- a/src/operator/subgraph/default_subgraph_op.cu +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2018 by Contributors - * \file default_subgraph_op.cu - * \brief GPU Implementation of subgraph operations - */ - -#include -#include "./common.h" -#include "../../imperative/imperative_utils.h" -#include "../../imperative/cached_op.h" - -namespace mxnet { -namespace op { - -void DefaultSubgraphOpForward(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); - -NNVM_REGISTER_OP(_default_subgraph_op) -.set_attr("FStatefulComputeEx", DefaultSubgraphOpForward); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc index 9109a97e192c..0152344f4d43 100644 --- a/src/operator/subgraph/default_subgraph_property.cc +++ b/src/operator/subgraph/default_subgraph_property.cc @@ -52,7 +52,7 @@ class ContainOpSelector: public SubgraphSelector { /* * This subgraph property finds a subgraph whose nodes have only operators - * within a set. The operators in the subgraph will be executed by _default_subgraph_op. + * within a set. The operators in the subgraph will be executed by _CachedOp. */ class DefaultSubgraphProperty: public SubgraphProperty { public: From 7e722a9b9e6975a6643350bc09d8b803115239f7 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 10 Sep 2018 18:29:24 +0000 Subject: [PATCH 11/16] fix. --- src/imperative/cached_op.cc | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index ca6be8693e81..a7218e9bca64 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1339,10 +1339,7 @@ NNVM_REGISTER_OP(_CachedOp) const CachedOpPtr& op = nnvm::get(attrs.parsed); return op::DefaultSubgraphOpResourceRequest1(op->GetForwardSym()); }) -.set_attr("FExecType", - [](const nnvm::NodeAttrs& attrs) { - return ExecType::kSubgraphExec; - }) +.set_attr("FExecType", DefaultSubgraphOpExecType) .add_argument("data", "NDArray-or-Symbol[]", "input data list"); NNVM_REGISTER_OP(_backward_CachedOp) @@ -1364,10 +1361,7 @@ NNVM_REGISTER_OP(_backward_CachedOp) }) .set_attr("FStatefulComputeEx", CachedOpBackward) .set_attr("FStatefulComputeEx", CachedOpBackward) -.set_attr("FExecType", - [](const nnvm::NodeAttrs& attrs) { - return ExecType::kSubgraphExec; - }) +.set_attr("FExecType", DefaultSubgraphOpExecType) .set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true); From f7b63b40efe52bee8bc4a054f875e612c389dd2a Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 10 Sep 2018 19:05:13 +0000 Subject: [PATCH 12/16] fix. 
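
Qualify DefaultSubgraphOpExecType with the op:: namespace in cached_op.cc, as
was already done for the other subgraph helpers.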
--- src/imperative/cached_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index a7218e9bca64..bda9e750caed 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1339,7 +1339,7 @@ NNVM_REGISTER_OP(_CachedOp) const CachedOpPtr& op = nnvm::get(attrs.parsed); return op::DefaultSubgraphOpResourceRequest1(op->GetForwardSym()); }) -.set_attr("FExecType", DefaultSubgraphOpExecType) +.set_attr("FExecType", op::DefaultSubgraphOpExecType) .add_argument("data", "NDArray-or-Symbol[]", "input data list"); NNVM_REGISTER_OP(_backward_CachedOp) @@ -1361,7 +1361,7 @@ NNVM_REGISTER_OP(_backward_CachedOp) }) .set_attr("FStatefulComputeEx", CachedOpBackward) .set_attr("FStatefulComputeEx", CachedOpBackward) -.set_attr("FExecType", DefaultSubgraphOpExecType) +.set_attr("FExecType", op::DefaultSubgraphOpExecType) .set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true); From 58d485127767ac78aa6baccf1b0d501e8cfc9186 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 14 Sep 2018 11:17:32 -0700 Subject: [PATCH 13/16] rename. --- src/imperative/cached_op.cc | 14 ++++++------ src/operator/subgraph/common.h | 39 +++++++++++++++++----------------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index bda9e750caed..d9bf927b7692 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -1307,14 +1307,14 @@ NNVM_REGISTER_OP(_CachedOp) std::vector *in_shapes, std::vector *out_shapes) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpShape1(op->GetForwardSym(), in_shapes, out_shapes); + return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes); }) .set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, std::vector *in_types, std::vector *out_types) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpType1(op->GetForwardSym(), in_types, out_types); + return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types); }) .set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, @@ -1323,21 +1323,21 @@ NNVM_REGISTER_OP(_CachedOp) std::vector* in_stypes, std::vector* out_stypes) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpStorageType1(op->GetForwardSym(), - dev_mask, dispatch_mode, - in_stypes, out_stypes); + return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(), + dev_mask, dispatch_mode, + in_stypes, out_stypes); }) .set_attr("FStatefulComputeEx", CachedOpForward) .set_attr("FStatefulComputeEx", CachedOpForward) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpMutableInputs1(op->GetForwardSym()); + return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym()); }) .set_attr("FResourceRequest", [](const nnvm::NodeAttrs& attrs) { const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpResourceRequest1(op->GetForwardSym()); + return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym()); }) .set_attr("FExecType", op::DefaultSubgraphOpExecType) .add_argument("data", "NDArray-or-Symbol[]", "input data list"); diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h index 259a6bb88023..4e1cd66b8b68 100644 --- a/src/operator/subgraph/common.h +++ b/src/operator/subgraph/common.h @@ -49,9 +49,9 @@ inline std::vector 
   return sym.ListOutputNames();
 }
 
-inline bool DefaultSubgraphOpShape1(const nnvm::Symbol& subgraph_sym,
-                                    std::vector<TShape> *in_shapes,
-                                    std::vector<TShape> *out_shapes) {
+inline bool DefaultSubgraphOpShapeHelper(const nnvm::Symbol& subgraph_sym,
+                                         std::vector<TShape> *in_shapes,
+                                         std::vector<TShape> *out_shapes) {
   using namespace exec;
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
@@ -96,12 +96,12 @@ inline bool DefaultSubgraphOpShape1(const nnvm::Symbol& subgraph_sym,
 inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
                                    std::vector<TShape> *in_shapes,
                                    std::vector<TShape> *out_shapes) {
-  return DefaultSubgraphOpShape1(*attrs.subgraphs[0], in_shapes, out_shapes);
+  return DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes);
 }
 
-inline bool DefaultSubgraphOpType1(const nnvm::Symbol& subgraph_sym,
-                                   std::vector<int> *in_types,
-                                   std::vector<int> *out_types) {
+inline bool DefaultSubgraphOpTypeHelper(const nnvm::Symbol& subgraph_sym,
+                                        std::vector<int> *in_types,
+                                        std::vector<int> *out_types) {
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -144,14 +144,14 @@ inline bool DefaultSubgraphOpType1(const nnvm::Symbol& subgraph_sym,
 inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int> *in_types,
                                   std::vector<int> *out_types) {
-  return DefaultSubgraphOpType1(*attrs.subgraphs[0], in_types, out_types);
+  return DefaultSubgraphOpTypeHelper(*attrs.subgraphs[0], in_types, out_types);
 }
 
-inline bool DefaultSubgraphOpStorageType1(const nnvm::Symbol& subgraph_sym,
-                                          const int dev_mask,
-                                          DispatchMode* dispatch_mode,
-                                          std::vector<int>* in_stypes,
-                                          std::vector<int>* out_stypes) {
+inline bool DefaultSubgraphOpStorageTypeHelper(const nnvm::Symbol& subgraph_sym,
+                                               const int dev_mask,
+                                               DispatchMode* dispatch_mode,
+                                               std::vector<int>* in_stypes,
+                                               std::vector<int>* out_stypes) {
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -204,15 +204,16 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int>* in_stypes,
                                          std::vector<int>* out_stypes) {
-  return DefaultSubgraphOpStorageType1(*attrs.subgraphs[0], dev_mask, dispatch_mode,
-                                       in_stypes, out_stypes);
+  return DefaultSubgraphOpStorageTypeHelper(*attrs.subgraphs[0], dev_mask, dispatch_mode,
+                                            in_stypes, out_stypes);
 }
 
 inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
   return ExecType::kSubgraphExec;
 }
 
-inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs1(const nnvm::Symbol& subgraph_sym) {
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputsHelper(
+    const nnvm::Symbol& subgraph_sym) {
   const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
   const std::vector<std::string> immutable_input_names =
       subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
@@ -235,10 +236,10 @@ inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs1(const nnvm::Symbol&
 }
 
 inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
-  return DefaultSubgraphOpMutableInputs1(*attrs.subgraphs[0]);
+  return DefaultSubgraphOpMutableInputsHelper(*attrs.subgraphs[0]);
 }
 
-inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest1(
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequestHelper(
     const nnvm::Symbol& subgraph_sym) {
   static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
   std::set<ResourceRequest::Type> resource_types;
@@ -253,7 +254,7 @@ inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest1(
 }
 
 inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
-  return DefaultSubgraphOpResourceRequest1(*attrs.subgraphs[0]);
+  return DefaultSubgraphOpResourceRequestHelper(*attrs.subgraphs[0]);
 }
 
 }  // namespace op
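
One practical effect of the renames in PATCH 13 is that the *Helper functions in common.h are keyed on a plain nnvm::Symbol, so any subgraph-style operator can reuse them: _default_subgraph_op passes *attrs.subgraphs[0], while _CachedOp passes op->GetForwardSym(). As a hedged illustration (the operator name, function name, and include path below are hypothetical and not part of this patch series), a third consumer could forward its shape inference the same way:

    #include <vector>
    #include <nnvm/op_attr_types.h>
    #include "src/operator/subgraph/common.h"  // include path is an assumption for the sketch

    // Hypothetical subgraph-style op that keeps its body in attrs.subgraphs[0].
    inline bool MySubgraphOpShape(const nnvm::NodeAttrs& attrs,
                                  std::vector<mxnet::TShape>* in_shapes,
                                  std::vector<mxnet::TShape>* out_shapes) {
      // Reuse the shared helper instead of re-implementing graph-based inference.
      return mxnet::op::DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes);
    }

    NNVM_REGISTER_OP(_my_subgraph_op)
    .set_attr<nnvm::FInferShape>("FInferShape", MySubgraphOpShape);
    // (the other required attributes of the hypothetical op are omitted in this sketch)
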
From 9f1aa26367713c2f58129738cb598f14287bc0d2 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Fri, 14 Sep 2018 11:28:54 -0700
Subject: [PATCH 14/16] add comment.

---
 src/imperative/cached_op.cc | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index d9bf927b7692..4c1b84cedf1e 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -1070,6 +1070,11 @@ void CachedOp::Backward(
 /*
  * This is the operator state of CachedOp when CachedOp is used in the symbol
  * executor. This is different from the OpState returned by CachedOp::Forward.
+ * The main reason we need this OpState is that CachedOp and the symbol executor
+ * maintain OpState differently: the symbol executor generates OpState in advance,
+ * while CachedOp generates OpState only after Forward is called. We need this data
+ * structure to keep the OpState generated by CachedOp::Forward and pass it to
+ * Backward.
  */
 struct CachedOpActualState {
   std::shared_ptr<CachedOp> op;

From d4c95d3f0e5b1971f1212cb99da9f6de195c1b09 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Fri, 14 Sep 2018 14:44:03 -0700
Subject: [PATCH 15/16] retrigger

From fc001dc2063c712302e3b820af3122212c7ec672 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Sat, 22 Sep 2018 16:54:19 -0700
Subject: [PATCH 16/16] add comments.

---
 src/imperative/cached_op.cc | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 4c1b84cedf1e..1f115cd64ad5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -1186,7 +1186,12 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
     orig_is_train = Imperative::Get()->set_is_training(true);
   else
     orig_is_train = Imperative::Get()->is_training();
-  // TODO(zhengda) is it right to use false here?
+  // TODO(zhengda) CachedOp supports recording computation when running
+  // the backward path. This is necessary if we want to support second-order
+  // differentiation. However, MXNet operators don't have an interface to
+  // pass a flag that determines whether to record computation inside an operator.
+  // Let's use false here for now and design a solution when second-order
+  // differentiation is supported.
   s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
   Imperative::Get()->set_is_training(orig_is_train);
 
@@ -1195,6 +1200,9 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
 
   // The arrays in out_ptrs may be changed by CachedOp.
   // If it is, we need to copy data back.
+  // For example, when the inputs and outputs share the same NDArrays,
+  // the outputs will be replaced by the inputs.
+  // https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385
   for (size_t i = 0; i < out_bufs.size(); i++)
     if (!out_bufs[i].IsSame(outputs[i]))
       CopyFromTo(out_bufs[i], outputs[i]);
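
Read together, the comments added in PATCH 14 and PATCH 16 describe a simple hand-off: the executor-side state is created once, the forward pass fills in the CachedOp-side forward_state, and the backward pass consumes and then clears it. The fragment below is only an illustrative driver written under that assumption; the loop and the fwd_*/bwd_* variables are placeholders, not MXNet API, and only the three functions named in this patch series are real.

    // Illustrative call order, not real executor code.
    OpStatePtr state = CreateCachedOpState(attrs, ctx, in_shapes, in_types);  // once, at bind time
    for (int iter = 0; iter < num_iters; ++iter) {
      // Fills CachedOpActualState::forward_state as a side effect.
      CachedOpForward(state, fwd_ctx, fwd_inputs, fwd_req, fwd_outputs);
      // Uses forward_state to run CachedOp::Backward, then resets it.
      CachedOpBackward(state, bwd_ctx, bwd_inputs, bwd_req, bwd_outputs);
    }
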