Skip to content

Commit

Permalink
Improve static cached_op optimization (apache#15187)
Browse files Browse the repository at this point in the history
* Fix cached op

Change-Id: If90c6f0997548ffd5daa67cc18bab7405f24213b

* Fix UT

* trigger
  • Loading branch information
xinyu-intel authored and haohuw committed Jun 23, 2019
1 parent f9f37ee commit 1adf79c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
CheckAndInferShape(&g, std::move(shape_inputs), true,
{0, 0}, {0, 0},
&contain_dynamic_shape);
if (contain_dynamic_shape && erase_result) {
if (!config_.static_shape && erase_result) {
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");
}
Expand Down
5 changes: 4 additions & 1 deletion src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,10 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes,
*contain_unknown = false;
}
nnvm::Graph& g = *p_g;
if (g.attrs.count("shape")) {
if (use_inputs) {
if (g.attrs.count("shape_inputs") && g.GetAttr<mxnet::ShapeVector>("shape_inputs") == shapes)
return true;
} else if (g.attrs.count("shape")) {
const auto& prev_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
if (prev_shapes.size() == shapes.size()) {
bool match = true;
Expand Down

0 comments on commit 1adf79c

Please sign in to comment.