diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 07c7871c6045..b49cad4eb77b 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -290,7 +290,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); - if (contain_dynamic_shape && erase_result) { + if (!config_.static_shape && erase_result) { g.attrs.erase("shape"); g.attrs.erase("shape_inputs"); } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 5cb805c5abcb..4e63e4d2b3be 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -595,7 +595,10 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, *contain_unknown = false; } nnvm::Graph& g = *p_g; - if (g.attrs.count("shape")) { + if (use_inputs) { + if (g.attrs.count("shape_inputs") && g.GetAttr("shape_inputs") == shapes) + return true; + } else if (g.attrs.count("shape")) { const auto& prev_shapes = g.GetAttr("shape"); if (prev_shapes.size() == shapes.size()) { bool match = true;