Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Mar 11, 2019
1 parent a4b9802 commit 8756ea3
Showing 1 changed file with 5 additions and 193 deletions.
198 changes: 5 additions & 193 deletions src/operator/control_flow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ class WhileLoopState: public LoopState {
n_iterations(0U),
cond_op(LoopState::MakeSharedOp(cond)),
oi_map(params.func_var_locs.ndim(), -1) {
CHECK_EQ(params.num_out_data, 0) << "unsupported";
const nnvm::Tuple<dim_t> &func_input_locs = params.func_input_locs;
const nnvm::Tuple<dim_t> &func_var_locs = params.func_var_locs;
const nnvm::Tuple<dim_t> &cond_input_locs = params.cond_input_locs;
Expand Down Expand Up @@ -602,14 +603,13 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
}
// func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory
for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
func_outputs[i] = NDArray(outputs[i].ctx(), outputs[i].dtype());
}
state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
// func_inputs on the next step:
// the output (new_loop_vars) will become the new inputs (loop_vars)
for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
size_t j = params.func_var_locs[i - params.num_out_data];
CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape());
func_inputs[j] = func_outputs[i];
int k = state.oi_map[i - params.num_out_data];
if (k != -1) {
Expand All @@ -627,6 +627,9 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
// therefore, we copy func_inputs[:] to outputs[num_out_data: ]
for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
size_t j = params.func_var_locs[i - params.num_out_data];
if (outputs[i].shape().ndim() == 0) {
const_cast<NDArray &>(outputs[i]).Init(func_inputs[j].shape());
}
mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
}
}
Expand Down Expand Up @@ -726,108 +729,6 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
state.Cleanup();
}

static bool WhileLoopShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape) {
using mxnet::ShapeVector;
const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
// sanity checks
CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args);
CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
CHECK_EQ(attrs.subgraphs.size(), 2U);
CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
// infer shape for cond and func
auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
ShapeVector *_subg_out,
const nnvm::Tuple<dim_t> &input_locs,
int num_out_data,
bool fill_out_shape) {
// create subg_in
ShapeVector subg_in;
ShapeVector &subg_out = *_subg_out;
extract_by_loc(*in_shape, input_locs, &subg_in);
// create an indexed graph
nnvm::Graph g;
g.outputs = subg->outputs;
const auto& idx = g.indexed_graph();
// get input nodes
const auto &input_nids = idx.input_nodes();
// sanity checks
CHECK_EQ(input_nids.size(), subg_in.size());
CHECK_EQ(g.outputs.size(), subg_out.size());
CHECK_EQ(idx.input_nodes().size(), subg_in.size());
CHECK_EQ(idx.outputs().size(), subg_out.size());
// create empty shapes for inference
ShapeVector shapes(idx.num_node_entries());
// copy subg_in into shapes
for (size_t i = 0; i < subg_in.size(); ++i) {
auto eid = idx.entry_id(input_nids[i], 0);
shapes[eid] = subg_in[i];
}
// copy subg_out into shapes
// note that ndim of out_data is not increased
// because subg is only one step
for (size_t i = 0; i < subg_out.size(); ++i) {
auto eid = idx.entry_id(g.outputs[i]);
shapes[eid] = subg_out[i];
}
// copy done, call InferShape
g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
g = exec::InferShape(std::move(g));
// now `shapes' won't be used anymore, use new_shapes instead
const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
// copy subg_in back to in_shape
for (size_t i = 0; i < subg_in.size(); ++i) {
auto eid = idx.entry_id(input_nids[i], 0);
auto g_out_shape = new_shapes[eid];
if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
// when the shape is not fully inferred
continue;
}
SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
}
if (!fill_out_shape) {
return true;
}
// copy subg_out back to out_shape
// for results in [0, num_out_data), ndim should increase by 1
for (int i = 0; i < num_out_data; ++i) {
auto eid = idx.entry_id(g.outputs[i]);
auto g_out_shape = new_shapes[eid];
if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
// when the shape is not fully inferred
continue;
}
auto out = mxnet::TShape(g_out_shape.ndim() + 1);
out[0] = params.max_iterations;
for (size_t i = 1; i < out.ndim(); i++)
out[i] = g_out_shape[i - 1];
SHAPE_ASSIGN_CHECK(*out_shape, i, out);
}
// for results in [num_out_data, ...), ndim does not change
for (size_t i = num_out_data; i < g.outputs.size(); ++i) {
auto eid = idx.entry_id(g.outputs[i]);
auto g_out_shape = new_shapes[eid];
if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
// when the shape is not fully inferred
continue;
}
SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
}
return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
};
mxnet::ShapeVector cond_out_shape{mxnet::TShape(1U)}; // this means: [(1, )]
mxnet::ShapeVector func_out_shape(params.num_outputs);
CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false);
CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \
params.func_input_locs, params.num_out_data, true);
CHECK(params.sync_in_out(in_shape, out_shape, is_udf));
return succ_0 && succ_1;
}

static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
Expand Down Expand Up @@ -1033,93 +934,6 @@ static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
loop_state.Cleanup();
}

static bool CondShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape) {
using mxnet::ShapeVector;
const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
static const std::function<bool(const mxnet::TShape &)> is_udf = is_shape_udf;
// sanity checks
CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
CHECK_EQ(attrs.subgraphs.size(), 3U);
CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
// infer shape for cond, then and else
auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
ShapeVector *_subg_out,
const nnvm::Tuple<dim_t> &input_locs,
bool fill_out_shape) {
// create subg_in
mxnet::ShapeVector subg_in;
mxnet::ShapeVector &subg_out = *_subg_out;
extract_by_loc(*in_shape, input_locs, &subg_in);
// create an indexed graph
nnvm::Graph g;
g.outputs = subg->outputs;
const auto& idx = g.indexed_graph();
// get input nodes
const auto &input_nids = idx.input_nodes();
// sanity checks
CHECK_EQ(input_nids.size(), subg_in.size());
CHECK_EQ(g.outputs.size(), subg_out.size());
CHECK_EQ(idx.input_nodes().size(), subg_in.size());
CHECK_EQ(idx.outputs().size(), subg_out.size());
// create empty shapes for inference
mxnet::ShapeVector shapes(idx.num_node_entries());
// copy subg_in into shapes
for (size_t i = 0; i < subg_in.size(); ++i) {
auto eid = idx.entry_id(input_nids[i], 0);
shapes[eid] = subg_in[i];
}
// copy subg_out into shapes
for (size_t i = 0; i < subg_out.size(); ++i) {
auto eid = idx.entry_id(g.outputs[i]);
shapes[eid] = subg_out[i];
}
// copy done, call InferShape
g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
g = exec::InferShape(std::move(g));
// now `shapes' won't be used anymore, use new_shapes instead
const auto& new_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
// copy subg_in back to in_shape
for (size_t i = 0; i < subg_in.size(); ++i) {
auto eid = idx.entry_id(input_nids[i], 0);
auto g_out_shape = new_shapes[eid];
if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
// when the shape is not fully inferred
continue;
}
SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
}
if (!fill_out_shape) {
return true;
}
// copy subg_out back to out_shape
for (size_t i = 0; i < g.outputs.size(); ++i) {
auto eid = idx.entry_id(g.outputs[i]);
auto g_out_shape = new_shapes[eid];
if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
// when the shape is not fully inferred
continue;
}
SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
}
return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
};
ShapeVector cond_out_shape{mxnet::TShape(1U)}; // this means: [(1, )]
ShapeVector then_out_shape(params.num_outputs);
ShapeVector else_out_shape(params.num_outputs);
bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \
params.cond_input_locs, false);
bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape, \
params.then_input_locs, true);
bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \
params.else_input_locs, true);
sync_out_out(&then_out_shape, &else_out_shape, is_udf);
return succ_0 && succ_1 && succ_2;
}

static bool CondType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
Expand Down Expand Up @@ -1342,7 +1156,6 @@ NNVM_REGISTER_OP(_while_loop)
})
.set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
.set_attr<mxnet::FInferShape>("FInferShape", WhileLoopShape)
.set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
Expand Down Expand Up @@ -1405,7 +1218,6 @@ NNVM_REGISTER_OP(_cond)
})
.set_attr<nnvm::FGradient>("FGradient", CondGradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
.set_attr<mxnet::FInferShape>("FInferShape", CondShape)
.set_attr<nnvm::FInferType>("FInferType", CondType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
Expand Down

0 comments on commit 8756ea3

Please sign in to comment.