From 93bccbd86cd2c05f3636f7f90ce9f0766d33af46 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 19 Apr 2019 10:49:41 +0800 Subject: [PATCH 1/4] Fix cached_op --- src/executor/attach_op_execs_pass.cc | 5 +++-- src/executor/exec_pass.h | 9 ++++++++- src/imperative/cached_op.cc | 6 ++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index b04d132ee9f6..8508500010f0 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -261,7 +261,7 @@ class FComputeExExecutor : public OpExecutor { ExecType exec_type_; }; -void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { +void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) { using nnvm::DTypeVector; using mxnet::ShapeVector; using nnvm::FMutateInputs; @@ -302,6 +302,7 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { OpStatePtr state = fcreate_op_state[op]( inode.source->attrs, vctx[i], ishape, itype); + if (p_state) p_state->at(i) = state; FStatefulComputeEx fcompute_ex = common::GetFCompute( op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx @@ -359,7 +360,7 @@ Graph AttachOpExecs(Graph g) { const auto& idx = g.indexed_graph(); OpExecVector ret(idx.num_nodes()); for (size_t i = 0; i < idx.num_nodes(); ++i) { - CreateOpExecs(g, &ret, i); + CreateOpExecs(g, &ret, nullptr, i); } g.attrs["op_execs"] = std::make_shared(ret); return g; diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index dd4132301346..7e5130f4921c 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -98,6 +98,12 @@ class OpExecutor { */ using OpExecVector = std::vector >; +/*! + * \brief per node vector of operator states. + * \note stored under attribute "op_states" + */ +using OpStateVector = std::vector; + /*! * \brief per node context vector * \node stored under "context" @@ -115,9 +121,10 @@ using DevMaskVector = std::vector; * * \param g input graph * \param p_ret OpExecVector for input and output + * \param p_state OpStateVector if it has. * \param i the id of the node */ -void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i); +void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i); /*! * \brief Attach OpExecutor to the graph attributes. * diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index c9215c5c8827..0d7ebb5ad71b 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); - if (erase_result) { + if (erase_result && contain_dynamic_shape) { g.attrs.erase("shape"); g.attrs.erase("shape_inputs"); } @@ -603,7 +603,7 @@ void CachedOp::StaticInitExec( } } else { for (size_t i = start_nid; i < end_nid; ++i) { - exec::CreateOpExecs(g, &state.execs, i); + exec::CreateOpExecs(g, &state.execs, &state.op_states, i); } exec::AttachOpResources(g, state.execs, start_nid, end_nid); @@ -705,8 +705,6 @@ void CachedOp::StaticRunOps( arg_shapes.emplace_back(ndinput->shape()); arg_dtypes.emplace_back(ndinput->dtype()); } - state.op_states[i] = createop[node.source->op()]( - node.source->attrs, default_ctx, arg_shapes, arg_dtypes); Imperative::Get()->InvokeOp( default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, state.op_states[i]); From 08b082c2a9515dd596fa70f5319e8f28199a220e Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 24 Apr 2019 21:27:05 +0800 Subject: [PATCH 2/4] try to fix ci --- src/imperative/cached_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 0d7ebb5ad71b..92452c03479a 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); - if (erase_result && contain_dynamic_shape) { + if (erase_result) { g.attrs.erase("shape"); g.attrs.erase("shape_inputs"); } @@ -908,7 +908,7 @@ OpStatePtr CachedOp::Forward( OpStatePtr op_state; try { - if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) { + if (config_.is_dynamic && CheckDynamicShapeExists(default_ctx, inputs, true)) { config_.is_dynamic = true; config_.static_alloc = false; op_state = DynamicForward(default_ctx, inputs, outputs, true); From 2849633b6bcef43da1f7a093486d40883bc94ca6 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 25 Apr 2019 10:01:42 +0800 Subject: [PATCH 3/4] Fix CI --- src/executor/attach_op_execs_pass.cc | 5 ++++- src/imperative/cached_op.cc | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 8508500010f0..8f47bc29db13 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -302,7 +302,10 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, OpStatePtr state = fcreate_op_state[op]( inode.source->attrs, vctx[i], ishape, itype); - if (p_state) p_state->at(i) = state; + if (p_state) { + CHECK_GT(p_state->size(), i); + p_state->at(i) = state; + } FStatefulComputeEx fcompute_ex = common::GetFCompute( op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 92452c03479a..7a5ed21432d3 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); - if (erase_result) { + if (contain_dynamic_shape && erase_result) { g.attrs.erase("shape"); g.attrs.erase("shape_inputs"); } @@ -705,6 +705,10 @@ void CachedOp::StaticRunOps( arg_shapes.emplace_back(ndinput->shape()); arg_dtypes.emplace_back(ndinput->dtype()); } + if (!state.op_states[i]) { + state.op_states[i] = + createop[node.source->op()](node.source->attrs, default_ctx, arg_shapes, arg_dtypes); + } Imperative::Get()->InvokeOp( default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, state.op_states[i]); @@ -908,7 +912,7 @@ OpStatePtr CachedOp::Forward( OpStatePtr op_state; try { - if (config_.is_dynamic && CheckDynamicShapeExists(default_ctx, inputs, true)) { + if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) { config_.is_dynamic = true; config_.static_alloc = false; op_state = DynamicForward(default_ctx, inputs, outputs, true); From 2e38a0aed9c26b37d39669dd7ce4e722db92b9e0 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 26 Apr 2019 09:53:10 +0800 Subject: [PATCH 4/4] Fix ci --- src/imperative/imperative_utils.h | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 9d4e4bd15a37..5c9706834b2d 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -595,23 +595,21 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, *contain_unknown = false; } nnvm::Graph& g = *p_g; - if (use_inputs) { - if (g.attrs.count("shape_inputs") && - g.GetAttr("shape_inputs") == shapes) return true; - } else if (g.attrs.count("shape")) { + if (g.attrs.count("shape")) { const auto& prev_shapes = g.GetAttr("shape"); - CHECK_EQ(prev_shapes.size(), shapes.size()); - bool match = true; - for (size_t i = 0; i < shapes.size(); ++i) { - if (i == entry_range.first) { - i = entry_range.second; - if (i >= shapes.size()) break; + if (prev_shapes.size() == shapes.size()) { + bool match = true; + for (size_t i = 0; i < shapes.size(); ++i) { + if (i == entry_range.first) { + i = entry_range.second; + if (i >= shapes.size()) break; + } + if (shapes[i] == prev_shapes[i]) continue; + match = false; + break; } - if (shapes[i] == prev_shapes[i]) continue; - match = false; - break; + if (match) return true; } - if (match) return true; } g.attrs.erase("shape"); g.attrs.erase("shape_inputs");