Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix for wrong reqs set after switching from training to inference #16553

Merged
merged 3 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ bool CachedOp::SetForwardGraph(

auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >(
recording ? "full_ref_count" : "forward_ref_count"));
recording ? "full_ref_count" : "forward_ref_count"),
recording ? "full" : "forward");
g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
std::make_shared<dmlc::any>(std::move(mem_plan));

Expand Down Expand Up @@ -492,6 +493,7 @@ bool CachedOp::SetBackwardGraph(

auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
"backward",
{num_forward_nodes, idx.num_nodes()},
{num_forward_entries, idx.num_node_entries()},
detect_inplace_addto);
Expand Down Expand Up @@ -526,9 +528,10 @@ void CachedOp::StaticAllocMemory(
const auto& default_ctx = state.context;
nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
const auto& idx = g.indexed_graph();
const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
const std::string mem_plan_type = keep_fwd ? "backward" : (recording ? "full" : "forward");
const auto& storage_inplace_attr = "storage_inplace_index_" + mem_plan_type;
const auto& vstorage_inplace = g.GetAttr<std::vector<int> >(storage_inplace_attr);
const auto& mem_plan = g.GetAttr<MemoryPlanVector>(mem_plan_type + "_mem_plan");
std::vector<int> addto_entry;
if (g.attrs.count("addto_entry")) {
addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
Expand Down
2 changes: 2 additions & 0 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ inline MemoryPlanVector PlanMemory(
nnvm::Graph* p_g,
nnvm::StorageVector&& storage,
const std::vector<uint32_t>& ref_count,
const std::string& mem_plan_type,
const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
bool detect_inplace_addto = false) {
Expand All @@ -831,6 +832,7 @@ inline MemoryPlanVector PlanMemory(
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
g.attrs["storage_inplace_index_" + mem_plan_type] = std::make_shared<any>(storage_inplace);
const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
uint32_t entry_start = entry_range.first;
uint32_t entry_end =
Expand Down