From 8756ea3bbbb0433b8dbcf7e00f50bc8c179ce38a Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Mon, 11 Mar 2019 14:20:23 -0700
Subject: [PATCH] Initial commit

---
 src/operator/control_flow.cc | 198 +----------------------------------
 1 file changed, 5 insertions(+), 193 deletions(-)

diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index ac6fea7c143b..039424f800f0 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -538,6 +538,7 @@ class WhileLoopState: public LoopState {
                  n_iterations(0U),
                  cond_op(LoopState::MakeSharedOp(cond)),
                  oi_map(params.func_var_locs.ndim(), -1) {
+    CHECK_EQ(params.num_out_data, 0) << "unsupported";
     const nnvm::Tuple<dim_t> &func_input_locs = params.func_input_locs;
     const nnvm::Tuple<dim_t> &func_var_locs = params.func_var_locs;
     const nnvm::Tuple<dim_t> &cond_input_locs = params.cond_input_locs;
@@ -602,14 +603,13 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
     }
     // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory
     for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
-      func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
+      func_outputs[i] = NDArray(outputs[i].ctx(), outputs[i].dtype());
     }
     state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
     // func_inputs on the next step:
     // the output (new_loop_vars) will become the new inputs (loop_vars)
     for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
       size_t j = params.func_var_locs[i - params.num_out_data];
-      CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape());
       func_inputs[j] = func_outputs[i];
       int k = state.oi_map[i - params.num_out_data];
       if (k != -1) {
@@ -627,6 +627,9 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
   // therefore, we copy func_inputs[:] to outputs[num_out_data: ]
   for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
     size_t j = params.func_var_locs[i - params.num_out_data];
+    if (outputs[i].shape().ndim() == 0) {
+      const_cast<NDArray &>(outputs[i]).Init(func_inputs[j].shape());
+    }
     mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
   }
 }
@@ -726,108 +729,6 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
   state.Cleanup();
 }
 
-static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
-                           mxnet::ShapeVector *in_shape,
-                           mxnet::ShapeVector *out_shape) {
-  using mxnet::ShapeVector;
-  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
-  static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
-  // sanity checks
-  CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args);
-  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
-  CHECK_EQ(attrs.subgraphs.size(), 2U);
-  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
-  // infer shape for cond and func
-  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
-                                                   ShapeVector *_subg_out,
-                                                   const nnvm::Tuple<dim_t> &input_locs,
-                                                   int num_out_data,
-                                                   bool fill_out_shape) {
-    // create subg_in
-    ShapeVector subg_in;
-    ShapeVector &subg_out = *_subg_out;
-    extract_by_loc(*in_shape, input_locs, &subg_in);
-    // create an indexed graph
-    nnvm::Graph g;
-    g.outputs = subg->outputs;
-    const auto& idx = g.indexed_graph();
-    // get input nodes
-    const auto &input_nids = idx.input_nodes();
-    // sanity checks
-    CHECK_EQ(input_nids.size(), subg_in.size());
-    CHECK_EQ(g.outputs.size(), subg_out.size());
-    CHECK_EQ(idx.input_nodes().size(), subg_in.size());
-    CHECK_EQ(idx.outputs().size(), subg_out.size());
-    // create empty shapes for inference
-    ShapeVector shapes(idx.num_node_entries());
-    // copy subg_in into shapes
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      shapes[eid] = subg_in[i];
-    }
-    // copy subg_out into shapes
-    // note that ndim of out_data is not increased
-    // because subg is only one step
-    for (size_t i = 0; i < subg_out.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      shapes[eid] = subg_out[i];
-    }
-    // copy done, call InferShape
-    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
-    g = exec::InferShape(std::move(g));
-    // now `shapes' won't be used anymore, use new_shapes instead
-    const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-    // copy subg_in back to in_shape
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      auto g_out_shape = new_shapes[eid];
-      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
-    }
-    if (!fill_out_shape) {
-      return true;
-    }
-    // copy subg_out back to out_shape
-    // for results in [0, num_out_data), ndim should increase by 1
-    for (int i = 0; i < num_out_data; ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      auto g_out_shape = new_shapes[eid];
-      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      auto out = mxnet::TShape(g_out_shape.ndim() + 1);
-      out[0] = params.max_iterations;
-      for (size_t i = 1; i < out.ndim(); i++)
-        out[i] = g_out_shape[i - 1];
-      SHAPE_ASSIGN_CHECK(*out_shape, i, out);
-    }
-    // for results in [num_out_data, ...), ndim does not change
-    for (size_t i = num_out_data; i < g.outputs.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      auto g_out_shape = new_shapes[eid];
-      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
-    }
-    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
-  };
-  mxnet::ShapeVector cond_out_shape{mxnet::TShape(1U)};  // this means: [(1, )]
-  mxnet::ShapeVector func_out_shape(params.num_outputs);
-  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
-  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false);
-  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
-  bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \
-                           params.func_input_locs, params.num_out_data, true);
-  CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
-  return succ_0 && succ_1;
-}
-
 static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
                           std::vector<int> *in_type, std::vector<int> *out_type) {
   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
@@ -1033,93 +934,6 @@ static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
   loop_state.Cleanup();
 }
 
-static bool CondShape(const nnvm::NodeAttrs& attrs,
-                      mxnet::ShapeVector *in_shape,
-                      mxnet::ShapeVector *out_shape) {
-  using mxnet::ShapeVector;
-  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
-  static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
-  // sanity checks
-  CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
-  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
-  CHECK_EQ(attrs.subgraphs.size(), 3U);
-  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
-  CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
-  // infer shape for cond, then and else
-  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
-                                                   ShapeVector *_subg_out,
-                                                   const nnvm::Tuple<dim_t> &input_locs,
-                                                   bool fill_out_shape) {
-    // create subg_in
-    mxnet::ShapeVector subg_in;
-    mxnet::ShapeVector &subg_out = *_subg_out;
-    extract_by_loc(*in_shape, input_locs, &subg_in);
-    // create an indexed graph
-    nnvm::Graph g;
-    g.outputs = subg->outputs;
-    const auto& idx = g.indexed_graph();
-    // get input nodes
-    const auto &input_nids = idx.input_nodes();
-    // sanity checks
-    CHECK_EQ(input_nids.size(), subg_in.size());
-    CHECK_EQ(g.outputs.size(), subg_out.size());
-    CHECK_EQ(idx.input_nodes().size(), subg_in.size());
-    CHECK_EQ(idx.outputs().size(), subg_out.size());
-    // create empty shapes for inference
-    mxnet::ShapeVector shapes(idx.num_node_entries());
-    // copy subg_in into shapes
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      shapes[eid] = subg_in[i];
-    }
-    // copy subg_out into shapes
-    for (size_t i = 0; i < subg_out.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      shapes[eid] = subg_out[i];
-    }
-    // copy done, call InferShape
-    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
-    g = exec::InferShape(std::move(g));
-    // now `shapes' won't be used anymore, use new_shapes instead
-    const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-    // copy subg_in back to in_shape
-    for (size_t i = 0; i < subg_in.size(); ++i) {
-      auto eid = idx.entry_id(input_nids[i], 0);
-      auto g_out_shape = new_shapes[eid];
-      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
-    }
-    if (!fill_out_shape) {
-      return true;
-    }
-    // copy subg_out back to out_shape
-    for (size_t i = 0; i < g.outputs.size(); ++i) {
-      auto eid = idx.entry_id(g.outputs[i]);
-      auto g_out_shape = new_shapes[eid];
-      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
-        // when the shape is not fully inferred
-        continue;
-      }
-      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
-    }
-    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
-  };
-  ShapeVector cond_out_shape{mxnet::TShape(1U)};  // this means: [(1, )]
-  ShapeVector then_out_shape(params.num_outputs);
-  ShapeVector else_out_shape(params.num_outputs);
-  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \
-                           params.cond_input_locs, false);
-  bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape, \
-                           params.then_input_locs, true);
-  bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \
-                           params.else_input_locs, true);
-  sync_out_out(&then_out_shape, &else_out_shape, is_udf);
-  return succ_0 && succ_1 && succ_2;
-}
-
 static bool CondType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_type,
                      std::vector<int> *out_type) {
@@ -1342,7 +1156,6 @@ NNVM_REGISTER_OP(_while_loop)
 })
 .set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
 .set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
-.set_attr<mxnet::FInferShape>("FInferShape", WhileLoopShape)
 .set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
@@ -1405,7 +1218,6 @@ NNVM_REGISTER_OP(_cond)
 })
 .set_attr<nnvm::FGradient>("FGradient", CondGradient)
 .set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
-.set_attr<mxnet::FInferShape>("FInferShape", CondShape)
 .set_attr<nnvm::FInferType>("FInferType", CondType)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {