From 191fdc690edddb201e2d3cbaf1236a12ea073507 Mon Sep 17 00:00:00 2001 From: Xinyu Chen Date: Thu, 13 Jun 2019 11:08:15 +0800 Subject: [PATCH] Improve static cached_op optimization (#15187) * Fix cached op Change-Id: If90c6f0997548ffd5daa67cc18bab7405f24213b * Fix UT * trigger --- src/imperative/cached_op.cc | 2 +- src/imperative/imperative_utils.h | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 07c7871c6045..b49cad4eb77b 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -290,7 +290,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); - if (contain_dynamic_shape && erase_result) { + if (!config_.static_shape && erase_result) { g.attrs.erase("shape"); g.attrs.erase("shape_inputs"); } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 5cb805c5abcb..4e63e4d2b3be 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -595,7 +595,10 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, *contain_unknown = false; } nnvm::Graph& g = *p_g; - if (g.attrs.count("shape")) { + if (use_inputs) { + if (g.attrs.count("shape_inputs") && g.GetAttr("shape_inputs") == shapes) + return true; + } else if (g.attrs.count("shape")) { const auto& prev_shapes = g.GetAttr("shape"); if (prev_shapes.size() == shapes.size()) { bool match = true;