Skip to content

Commit

Permalink
[MXNET-1315] Add checks for dynamic-shaped operators in CachedOp (apa…
Browse files Browse the repository at this point in the history
…che#14018)

* Initial commit

* Try this

* Boy next door!

* Add comments per discussion with Da

* Try this

* boy try this

* change the boss of the gym
  • Loading branch information
junrushao authored and vdantu committed Mar 31, 2019
1 parent 3dc29b5 commit 85a083a
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 7 deletions.
10 changes: 9 additions & 1 deletion src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,20 @@ bool CachedOp::SetForwardGraph(
}

bool match = true;
match &= CheckAndInferShape(&g, std::move(shape_inputs), true);
bool contain_dynamic_shape = false;
match &= CheckAndInferShape(&g, std::move(shape_inputs), true,
{0, 0}, {0, 0}, &contain_dynamic_shape);
match &= CheckAndInferType(&g, std::move(dtype_inputs), true);
exec::DevMaskVector dev_mask(g.indexed_graph().num_nodes(), inputs[0]->ctx().dev_mask());
match &= CheckAndInferStorageType(&g, std::move(dev_mask),
std::move(storage_type_inputs), true);

// When dynmaic shape exists, it is not feasible to plan memory ahead of time
if (contain_dynamic_shape) {
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
return false;
}
if (!match) {
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
Expand Down
4 changes: 4 additions & 0 deletions src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
uint32_t backward_bulk_size;
bool static_alloc;
bool static_shape;
bool is_dynamic;
nnvm::Tuple<uint32_t> data_indices;
nnvm::Tuple<uint32_t> param_indices;
std::string subgraph;
Expand Down Expand Up @@ -66,6 +67,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
DMLC_DECLARE_FIELD(subgraph)
.set_default(std::string(""))
.describe("JSON string of a subgraph.");
DMLC_DECLARE_FIELD(is_dynamic)
.set_default(false)
.describe("Whether the graph contains dynamic shape operators.");
}
};

Expand Down
15 changes: 11 additions & 4 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ inline void SetShapeType(const Context& ctx,

for (size_t i = 0; i < outputs.size(); ++i) {
NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
if (outputs[i]->is_none()) {
if (outputs[i]->is_none() || outputs[i]->shape().ndim() == 0) {
if (is_dynamic_shape_existing) {
// once there is dynamic shape somewhere, we could not pre-determine the shape.
*outputs[i] = NDArray(ctx, out_types[i]);
Expand Down Expand Up @@ -566,8 +566,12 @@ inline void PushOperator(const OpStatePtr& state,
inline bool CheckAndInferShape(nnvm::Graph* p_g, nnvm::ShapeVector&& shapes,
bool use_inputs,
std::pair<uint32_t, uint32_t> node_range = {0, 0},
std::pair<uint32_t, uint32_t> entry_range = {0, 0}) {
std::pair<uint32_t, uint32_t> entry_range = {0, 0},
bool *contain_unknown = nullptr) {
using namespace nnvm;
if (contain_unknown != nullptr) {
*contain_unknown = false;
}
nnvm::Graph& g = *p_g;
if (use_inputs) {
if (g.attrs.count("shape_inputs") &&
Expand Down Expand Up @@ -601,8 +605,11 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, nnvm::ShapeVector&& shapes,
g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
g = exec::InferShape(std::move(g));
}
CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0U);

if (contain_unknown == nullptr) {
CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0U);
} else {
*contain_unknown = g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U;
}
return false;
}

Expand Down
2 changes: 2 additions & 0 deletions src/operator/contrib/boolean_mask.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ which stands for the rows in x where the corresonding element in index is non-ze
.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"})
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "index"};})
.add_argument("data", "NDArray-or-Symbol", "Data")
.add_argument("index", "NDArray-or-Symbol", "Mask")
.add_arguments(BooleanMaskParam::__FIELDS__());
Expand Down
8 changes: 7 additions & 1 deletion src/operator/subgraph_op_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,14 @@ void LoopState::Forward(int iter_no,
// If an input and an output share the array, the output array will be changed
// by CachedOp. We need to copy data to the real output.
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(coutputs[i]))
if (!out_bufs[i].IsSame(coutputs[i])) {
// The line below checks whether dynamic shape exists.
// If so, re-initialize the shape.
if (coutputs[i].shape().ndim() == 0) {
const_cast<NDArray &>(coutputs[i]).Init(out_bufs[i].shape());
}
CopyFromTo(out_bufs[i], coutputs[i]);
}
if (is_recording) {
all_inputs.push_back(cinputs);
all_outputs.push_back(coutputs);
Expand Down
3 changes: 2 additions & 1 deletion src/operator/subgraph_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class LoopState {
// only static_alloc supports nested call of CachedOp.
std::vector<std::pair<std::string, std::string> > kwargs = {
{"inline_limit", "0"},
{"static_alloc", "1"}
{"static_alloc", "1"},
{"is_dynamic", "1"}
};
return std::make_shared<CachedOp>(sym, kwargs);
}
Expand Down

0 comments on commit 85a083a

Please sign in to comment.