Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix for wrong reqs set after switching from training to inference #16553

Merged
merged 3 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 48 additions & 28 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ DMLC_REGISTER_PARAMETER(CachedOpConfig);

constexpr uint32_t kEidNotExist = std::numeric_limits<uint32_t>::max();

// Graph-kind prefixes. Each is passed as the first argument of AddPrefix() to
// build the name of a graph attribute. FULL is selected when `recording` is
// true and FORWARD when it is false; BACKWARD is used for the backward pass.
const char CachedOp::FULL[] = "full";
const char CachedOp::FORWARD[] = "forward";
const char CachedOp::BACKWARD[] = "backward";
// Attribute-kind suffixes. Each is passed as the second argument of
// AddPrefix(), yielding keys such as "forward_ref_count" or "full_mem_plan".
// REF_COUNT attributes are read as std::vector<uint32_t>.
const char CachedOp::REF_COUNT[] = "ref_count";
// MEM_PLAN attributes are read as MemoryPlanVector.
const char CachedOp::MEM_PLAN[] = "mem_plan";
// STORAGE_PLAN attributes are read as std::vector<int>; the attribute name is
// handed to PlanMemory() so the plan is stored per graph kind.
const char CachedOp::STORAGE_PLAN[] = "storage_plan";

namespace {

// Builds a composite name of the form "<prefix>_<s>".
//
// prefix: text placed before the underscore separator.
// s:      text placed after the underscore separator.
// Returns a freshly constructed string; neither argument is modified.
std::string AddPrefix(const std::string& prefix,
                      const std::string& s) {
  std::string joined(prefix);
  joined += "_";
  joined += s;
  return joined;
}

}  // namespace

struct CachedOp::GraphInfo {
nnvm::Graph fwd_graph;
nnvm::Graph full_graph;
Expand Down Expand Up @@ -136,7 +152,7 @@ CachedOp::CachedOp(
for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
}

fwd_graph_.attrs["forward_ref_count"] =
fwd_graph_.attrs[AddPrefix(FORWARD, REF_COUNT)] =
std::make_shared<dmlc::any>(std::move(ref_count));

inlining_ = !config_.static_alloc &&
Expand Down Expand Up @@ -201,9 +217,9 @@ CachedOp::CachedOp(
}
}

auto full_ref_count = fwd_graph_.GetAttr<std::vector<uint32_t> >("forward_ref_count");
auto full_ref_count = fwd_graph_.GetAttr<std::vector<uint32_t> >(AddPrefix(FORWARD, REF_COUNT));
for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += ref_count[i];
fwd_graph_.attrs["full_ref_count"] =
fwd_graph_.attrs[AddPrefix(FULL, REF_COUNT)] =
std::make_shared<dmlc::any>(std::move(full_ref_count));

size_t num_forward_inputs = num_inputs();
Expand Down Expand Up @@ -336,14 +352,15 @@ bool CachedOp::SetForwardGraph(

// When dynamic shape exists, it is not feasible to plan memory ahead of time
if (contain_dynamic_shape) {
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
g.attrs.erase(AddPrefix(FORWARD, MEM_PLAN));
g.attrs.erase(AddPrefix(FULL, MEM_PLAN));
return false;
}
const std::string& prefix = recording ? FULL : FORWARD;
if (!match) {
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
} else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
g.attrs.erase(AddPrefix(FORWARD, MEM_PLAN));
g.attrs.erase(AddPrefix(FULL, MEM_PLAN));
} else if (g.attrs.count(AddPrefix(prefix, MEM_PLAN))) {
return true;
}

Expand All @@ -363,9 +380,9 @@ bool CachedOp::SetForwardGraph(
}

auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >(
recording ? "full_ref_count" : "forward_ref_count"));
g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >(AddPrefix(prefix, REF_COUNT)),
AddPrefix(prefix, STORAGE_PLAN));
g.attrs[AddPrefix(prefix, MEM_PLAN)] =
std::make_shared<dmlc::any>(std::move(mem_plan));

return false;
Expand Down Expand Up @@ -432,7 +449,7 @@ bool CachedOp::SetBackwardGraph(
size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();

if (!g.attrs.count("backward_ref_count")) {
if (!g.attrs.count(AddPrefix(BACKWARD, REF_COUNT))) {
std::vector<uint32_t> ref_count(idx.num_node_entries(), 0);
for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
Expand All @@ -443,7 +460,7 @@ bool CachedOp::SetBackwardGraph(
}
}
for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
g.attrs["backward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count));
g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared<dmlc::any>(std::move(ref_count));
}

auto shapes = info->fwd_graph.GetAttr<mxnet::ShapeVector>("shape");
Expand Down Expand Up @@ -476,8 +493,8 @@ bool CachedOp::SetBackwardGraph(
false, node_range, entry_range);

if (!match) {
g.attrs.erase("backward_mem_plan");
} else if (g.attrs.count("backward_mem_plan")) {
g.attrs.erase(AddPrefix(BACKWARD, MEM_PLAN));
} else if (g.attrs.count(AddPrefix(BACKWARD, MEM_PLAN))) {
return true;
}

Expand All @@ -491,11 +508,13 @@ bool CachedOp::SetBackwardGraph(
for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;

auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
&g, std::move(storage),
g.GetAttr<std::vector<uint32_t> >(AddPrefix(BACKWARD, REF_COUNT)),
AddPrefix(BACKWARD, STORAGE_PLAN),
{num_forward_nodes, idx.num_nodes()},
{num_forward_entries, idx.num_node_entries()},
detect_inplace_addto);
g.attrs["backward_mem_plan"] = std::make_shared<dmlc::any>(std::move(mem_plan));
g.attrs[AddPrefix(BACKWARD, MEM_PLAN)] = std::make_shared<dmlc::any>(std::move(mem_plan));

return false;
}
Expand Down Expand Up @@ -526,9 +545,10 @@ void CachedOp::StaticAllocMemory(
const auto& default_ctx = state.context;
nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
const auto& idx = g.indexed_graph();
const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
const std::string& graph_type = keep_fwd ? BACKWARD : (recording ? FULL : FORWARD);
const auto& storage_plan_attr = AddPrefix(graph_type, STORAGE_PLAN);
const auto& storage_plan = g.GetAttr<std::vector<int> >(storage_plan_attr);
const auto& mem_plan = g.GetAttr<MemoryPlanVector>(AddPrefix(graph_type, MEM_PLAN));
std::vector<int> addto_entry;
if (g.attrs.count("addto_entry")) {
addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
Expand Down Expand Up @@ -558,9 +578,9 @@ void CachedOp::StaticAllocMemory(
for (size_t i = start_eid; i < end_eid; ++i) {
if (addto_entry.size() && addto_entry[i]) {
state.array_reqs[i] = kAddTo;
} else if (vstorage_inplace[i] >= 0) {
} else if (storage_plan[i] >= 0) {
state.array_reqs[i] = kWriteInplace;
} else if (vstorage_inplace[i] == -2) {
} else if (storage_plan[i] == -2) {
// -2 indicates that the entry is never referenced.
state.array_reqs[i] = kNullOp;
} else {
Expand Down Expand Up @@ -862,17 +882,17 @@ OpStatePtr CachedOp::DynamicForward(
}

// Allocate NDArrays
std::vector<uint32_t> ref_count = g.GetAttr<std::vector<uint32_t> >(
recording ? "full_ref_count" : "forward_ref_count");
const std::string& graph_type = recording ? FULL : FORWARD;
std::vector<uint32_t> ref_count =
g.GetAttr<std::vector<uint32_t> >(AddPrefix(graph_type, REF_COUNT));

std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
if (!use_naive_run) {
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(
recording ? "full_mem_plan" : "forward_mem_plan");
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(AddPrefix(graph_type, MEM_PLAN));
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
Expand Down Expand Up @@ -1011,7 +1031,7 @@ void CachedOp::DynamicBackward(
}

// Allocate NDArrays
auto ref_count = g.GetAttr<std::vector<uint32_t> >("backward_ref_count");
auto ref_count = g.GetAttr<std::vector<uint32_t> >(AddPrefix(BACKWARD, REF_COUNT));
if (retain_graph) {
for (size_t i = 0; i < num_forward_entries; ++i) ++ref_count[i];
}
Expand All @@ -1027,7 +1047,7 @@ void CachedOp::DynamicBackward(
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}

const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(AddPrefix(BACKWARD, MEM_PLAN));
AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);

Expand Down
7 changes: 7 additions & 0 deletions src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ class CachedOp {
void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all = false);

// Name fragments used to compose graph attribute keys.
// FULL / FORWARD / BACKWARD identify the kind of graph the attribute refers to.
static const char FULL[];
static const char FORWARD[];
static const char BACKWARD[];
// REF_COUNT / MEM_PLAN / STORAGE_PLAN identify the kind of attribute.
static const char REF_COUNT[];
static const char MEM_PLAN[];
static const char STORAGE_PLAN[];

private:
struct GraphInfo;
struct DynamicRuntime;
Expand Down
2 changes: 2 additions & 0 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ inline MemoryPlanVector PlanMemory(
nnvm::Graph* p_g,
nnvm::StorageVector&& storage,
const std::vector<uint32_t>& ref_count,
const std::string& storage_plan,
const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
bool detect_inplace_addto = false) {
Expand All @@ -831,6 +832,7 @@ inline MemoryPlanVector PlanMemory(
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
g.attrs[storage_plan] = std::make_shared<any>(storage_inplace);
const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
uint32_t entry_start = entry_range.first;
uint32_t entry_end =
Expand Down