diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 58ec4e65b846..8dd0a4deaac3 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -285,12 +285,20 @@ bool CachedOp::SetForwardGraph(
   }
 
   bool match = true;
-  match &= CheckAndInferShape(&g, std::move(shape_inputs), true);
+  bool contain_dynamic_shape = false;
+  match &= CheckAndInferShape(&g, std::move(shape_inputs), true,
+                              {0, 0}, {0, 0}, &contain_dynamic_shape);
   match &= CheckAndInferType(&g, std::move(dtype_inputs), true);
   exec::DevMaskVector dev_mask(g.indexed_graph().num_nodes(),
                                inputs[0]->ctx().dev_mask());
   match &= CheckAndInferStorageType(&g, std::move(dev_mask),
                                     std::move(storage_type_inputs), true);
+  // When a dynamic shape exists, it is not feasible to plan memory ahead of time.
+  if (contain_dynamic_shape) {
+    g.attrs.erase("forward_mem_plan");
+    g.attrs.erase("full_mem_plan");
+    return false;
+  }
   if (!match) {
     g.attrs.erase("forward_mem_plan");
     g.attrs.erase("full_mem_plan");
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 59a793ee1b65..3b173c8654a4 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -35,6 +35,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
   uint32_t backward_bulk_size;
   bool static_alloc;
   bool static_shape;
+  bool is_dynamic;
   nnvm::Tuple<uint32_t> data_indices;
   nnvm::Tuple<uint32_t> param_indices;
   std::string subgraph;
@@ -66,6 +67,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
     DMLC_DECLARE_FIELD(subgraph)
     .set_default(std::string(""))
     .describe("JSON string of a subgraph.");
+    DMLC_DECLARE_FIELD(is_dynamic)
+    .set_default(false)
+    .describe("Whether the graph contains dynamic shape operators.");
   }
 };
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 6446c37627a9..7113cb2063c1 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -179,7 +179,7 @@ inline void SetShapeType(const Context& ctx,
   for (size_t i = 0; i < outputs.size(); ++i) {
     NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
-    if (outputs[i]->is_none()) {
+    if (outputs[i]->is_none() || outputs[i]->shape().ndim() == 0) {
      if (is_dynamic_shape_existing) {
        // once there is a dynamic shape somewhere, we cannot pre-determine the shape.
        *outputs[i] = NDArray(ctx, out_types[i]);
@@ -566,8 +566,12 @@ inline void PushOperator(const OpStatePtr& state,
 inline bool CheckAndInferShape(nnvm::Graph* p_g, nnvm::ShapeVector&& shapes,
                                bool use_inputs,
                                std::pair<uint32_t, uint32_t> node_range = {0, 0},
-                               std::pair<uint32_t, uint32_t> entry_range = {0, 0}) {
+                               std::pair<uint32_t, uint32_t> entry_range = {0, 0},
+                               bool *contain_unknown = nullptr) {
   using namespace nnvm;
+  if (contain_unknown != nullptr) {
+    *contain_unknown = false;
+  }
   nnvm::Graph& g = *p_g;
   if (use_inputs) {
     if (g.attrs.count("shape_inputs") &&
@@ -601,8 +605,11 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, nnvm::ShapeVector&& shapes,
     g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
     g = exec::InferShape(std::move(g));
   }
-  CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0U);
-
+  if (contain_unknown == nullptr) {
+    CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0U);
+  } else {
+    *contain_unknown = g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U;
+  }
   return false;
 }
diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc
index 2dcafb6b9494..7fd66bc321bf 100644
--- a/src/operator/contrib/boolean_mask.cc
+++ b/src/operator/contrib/boolean_mask.cc
@@ -98,6 +98,8 @@ which stands for the rows in x where the corresponding element in index is non-zero.
 .set_attr<FComputeEx>("FComputeEx", BooleanMaskForward<cpu>)
 .set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"})
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "index"};})
 .add_argument("data", "NDArray-or-Symbol", "Data")
 .add_argument("index", "NDArray-or-Symbol", "Mask")
 .add_arguments(BooleanMaskParam::__FIELDS__());
diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc
index 7a99aedb8602..4b8f63abd4ce 100644
--- a/src/operator/subgraph_op_common.cc
+++ b/src/operator/subgraph_op_common.cc
@@ -222,8 +222,14 @@ void LoopState::Forward(int iter_no,
   // If an input and an output share the array, the output array will be changed
   // by CachedOp. We need to copy data to the real output.
   for (size_t i = 0; i < out_bufs.size(); i++)
-    if (!out_bufs[i].IsSame(coutputs[i]))
+    if (!out_bufs[i].IsSame(coutputs[i])) {
+      // An output whose shape has ndim == 0 has a dynamic shape that is still
+      // unknown here; if so, initialize it from the computed buffer's shape.
+      if (coutputs[i].shape().ndim() == 0) {
+        const_cast<NDArray &>(coutputs[i]).Init(out_bufs[i].shape());
+      }
       CopyFromTo(out_bufs[i], coutputs[i]);
+    }
   if (is_recording) {
     all_inputs.push_back(cinputs);
     all_outputs.push_back(coutputs);
diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h
index ebf727f0f5a8..c316fca91d95 100644
--- a/src/operator/subgraph_op_common.h
+++ b/src/operator/subgraph_op_common.h
@@ -161,7 +161,8 @@ class LoopState {
     // only static_alloc supports nested calls of CachedOp.
     std::vector<std::pair<std::string, std::string> > kwargs = {
       {"inline_limit", "0"},
-      {"static_alloc", "1"}
+      {"static_alloc", "1"},
+      {"is_dynamic", "1"}
     };
     return std::make_shared<CachedOp>(sym, kwargs);
   }
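
Note on the change (a minimal sketch, not MXNet code): the crux of the patch is that CheckAndInferShape no longer hard-fails on graphs whose shapes cannot be fully inferred. When the caller passes the new contain_unknown out-parameter, unknown shapes are reported back instead of tripping the CHECK_EQ, and SetForwardGraph uses that signal to erase "forward_mem_plan"/"full_mem_plan" and skip ahead-of-time memory planning. The standalone C++ sketch below illustrates only this out-parameter contract; the names InferResult and check_and_infer are invented for illustration and are not MXNet APIs.

#include <cassert>
#include <cstddef>
#include <iostream>

// Stand-in for the graph attribute "shape_num_unknown_nodes" that shape
// inference leaves behind.
struct InferResult {
  size_t num_unknown_nodes;
};

// Without the out-parameter, unknown shapes are a hard error (the old
// behavior). With it, the caller is merely informed and can defer memory
// planning to run time, as SetForwardGraph does in the patch above.
void check_and_infer(const InferResult& res, bool* contain_unknown = nullptr) {
  if (contain_unknown == nullptr) {
    assert(res.num_unknown_nodes == 0 && "graph has shapes that cannot be inferred");
  } else {
    *contain_unknown = res.num_unknown_nodes != 0;
  }
}

int main() {
  bool contain_dynamic_shape = false;
  check_and_infer(InferResult{2}, &contain_dynamic_shape);
  std::cout << "dynamic shape present: " << contain_dynamic_shape << '\n';  // prints 1
  return 0;
}

The consumer of this path is a CachedOp constructed with {"is_dynamic", "1"}, which LoopState now passes for loop bodies: when a dynamic-shape operator such as boolean_mask is present, SetForwardGraph returns false instead of producing a memory plan, and outputs whose shape is still unknown (ndim() == 0) are initialized lazily, as in LoopState::Forward above, once the actual shapes have been computed.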