Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix for wrong reqs set after switching from training to inference (#16553)
Browse files Browse the repository at this point in the history

* Debugging reqs

* Move literal strings to const static members

* Fix lint
  • Loading branch information
ptrendx authored and apeforest committed Nov 6, 2019
1 parent 579b9dd commit c4580ae
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 25 deletions.
70 changes: 45 additions & 25 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ DMLC_REGISTER_PARAMETER(CachedOpConfig);

// Sentinel entry id meaning "this entry does not exist" (max uint32_t).
constexpr uint32_t kEidNotExist = std::numeric_limits<uint32_t>::max();

// Definitions of the literal string components declared in cached_op.h.
// They are combined (graph-type prefix + attribute suffix) by AddPrefix()
// into graph attribute keys such as "forward_mem_plan" or
// "backward_ref_count".
const char CachedOp::FULL[] = "full";
const char CachedOp::FORWARD[] = "forward";
const char CachedOp::BACKWARD[] = "backward";
const char CachedOp::REF_COUNT[] = "ref_count";
const char CachedOp::MEM_PLAN[] = "mem_plan";
const char CachedOp::STORAGE_PLAN[] = "storage_plan";

namespace {

/*!
 * \brief Build a graph attribute key of the form "<prefix>_<s>",
 *        e.g. AddPrefix(CachedOp::FORWARD, CachedOp::MEM_PLAN)
 *        yields "forward_mem_plan".
 * \param prefix graph-type component (e.g. CachedOp::FULL / FORWARD / BACKWARD).
 * \param s      attribute component (e.g. CachedOp::MEM_PLAN / REF_COUNT).
 * \return the concatenated attribute key.
 */
std::string AddPrefix(const std::string& prefix,
                      const std::string& s) {
  std::string key;
  // Reserve the exact final size so the three appends below perform a
  // single allocation, instead of the up-to-two allocations incurred by
  // the chained operator+ form `prefix + "_" + s`.
  key.reserve(prefix.size() + 1 + s.size());
  key += prefix;
  key += '_';
  key += s;
  return key;
}

}  // namespace

struct CachedOp::GraphInfo {
nnvm::Graph fwd_graph;
nnvm::Graph grad_graph;
Expand Down Expand Up @@ -427,14 +443,15 @@ bool CachedOp::SetForwardGraph(

// When dynamic shape exists, it is not feasible to plan memory ahead of time
if (contain_dynamic_shape) {
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
g.attrs.erase(AddPrefix(FORWARD, MEM_PLAN));
g.attrs.erase(AddPrefix(FULL, MEM_PLAN));
return false;
}
const std::string& prefix = recording ? FULL : FORWARD;
if (!match) {
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
} else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
g.attrs.erase(AddPrefix(FORWARD, MEM_PLAN));
g.attrs.erase(AddPrefix(FULL, MEM_PLAN));
} else if (g.attrs.count(AddPrefix(prefix, MEM_PLAN))) {
return true;
}

Expand All @@ -454,9 +471,9 @@ bool CachedOp::SetForwardGraph(
}

auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >(
recording ? "full_ref_count" : "forward_ref_count"));
g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >(AddPrefix(prefix, REF_COUNT)),
AddPrefix(prefix, STORAGE_PLAN));
g.attrs[AddPrefix(prefix, MEM_PLAN)] =
std::make_shared<dmlc::any>(std::move(mem_plan));

return false;
Expand Down Expand Up @@ -523,7 +540,7 @@ bool CachedOp::SetBackwardGraph(
size_t num_forward_nodes = info->fwd_graph.indexed_graph().num_nodes();
size_t num_forward_entries = info->fwd_graph.indexed_graph().num_node_entries();

if (!g.attrs.count("backward_ref_count")) {
if (!g.attrs.count(AddPrefix(BACKWARD, REF_COUNT))) {
std::vector<uint32_t> ref_count(idx.num_node_entries(), 0);
for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
Expand All @@ -534,7 +551,7 @@ bool CachedOp::SetBackwardGraph(
}
}
for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
g.attrs["backward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count));
g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared<dmlc::any>(std::move(ref_count));
}

auto shapes = info->fwd_graph.GetAttr<mxnet::ShapeVector>("shape");
Expand Down Expand Up @@ -567,8 +584,8 @@ bool CachedOp::SetBackwardGraph(
false, node_range, entry_range);

if (!match) {
g.attrs.erase("backward_mem_plan");
} else if (g.attrs.count("backward_mem_plan")) {
g.attrs.erase(AddPrefix(BACKWARD, MEM_PLAN));
} else if (g.attrs.count(AddPrefix(BACKWARD, MEM_PLAN))) {
return true;
}

Expand All @@ -582,11 +599,13 @@ bool CachedOp::SetBackwardGraph(
for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;

auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
&g, std::move(storage),
g.GetAttr<std::vector<uint32_t> >(AddPrefix(BACKWARD, REF_COUNT)),
AddPrefix(BACKWARD, STORAGE_PLAN),
{num_forward_nodes, idx.num_nodes()},
{num_forward_entries, idx.num_node_entries()},
detect_inplace_addto);
g.attrs["backward_mem_plan"] = std::make_shared<dmlc::any>(std::move(mem_plan));
g.attrs[AddPrefix(BACKWARD, MEM_PLAN)] = std::make_shared<dmlc::any>(std::move(mem_plan));

return false;
}
Expand Down Expand Up @@ -618,9 +637,10 @@ void CachedOp::StaticAllocMemory(
const auto& default_ctx = state.context;
nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
const auto& idx = g.indexed_graph();
const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
const std::string& graph_type = keep_fwd ? BACKWARD : (recording ? FULL : FORWARD);
const auto& storage_plan_attr = AddPrefix(graph_type, STORAGE_PLAN);
const auto& storage_plan = g.GetAttr<std::vector<int> >(storage_plan_attr);
const auto& mem_plan = g.GetAttr<MemoryPlanVector>(AddPrefix(graph_type, MEM_PLAN));
std::vector<int> addto_entry;
if (g.attrs.count("addto_entry")) {
addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
Expand Down Expand Up @@ -650,9 +670,9 @@ void CachedOp::StaticAllocMemory(
for (size_t i = start_eid; i < end_eid; ++i) {
if (addto_entry.size() && addto_entry[i]) {
state.array_reqs[i] = kAddTo;
} else if (vstorage_inplace[i] >= 0) {
} else if (storage_plan[i] >= 0) {
state.array_reqs[i] = kWriteInplace;
} else if (vstorage_inplace[i] == -2) {
} else if (storage_plan[i] == -2) {
// -2 indicate that the entry is never referenced.
state.array_reqs[i] = kNullOp;
} else {
Expand Down Expand Up @@ -954,17 +974,17 @@ OpStatePtr CachedOp::DynamicForward(
}

// Allocate NDArrays
std::vector<uint32_t> ref_count = g.GetAttr<std::vector<uint32_t> >(
recording ? "full_ref_count" : "forward_ref_count");
const std::string& graph_type = recording ? FULL : FORWARD;
std::vector<uint32_t> ref_count =
g.GetAttr<std::vector<uint32_t> >(AddPrefix(graph_type, REF_COUNT));

std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
if (!use_naive_run) {
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(
recording ? "full_mem_plan" : "forward_mem_plan");
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(AddPrefix(graph_type, MEM_PLAN));
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
Expand Down Expand Up @@ -1105,7 +1125,7 @@ void CachedOp::DynamicBackward(
}

// Allocate NDArrays
auto ref_count = g.GetAttr<std::vector<uint32_t> >("backward_ref_count");
auto ref_count = g.GetAttr<std::vector<uint32_t> >(AddPrefix(BACKWARD, REF_COUNT));
if (retain_graph) {
for (size_t i = 0; i < num_forward_entries; ++i) ++ref_count[i];
}
Expand All @@ -1121,7 +1141,7 @@ void CachedOp::DynamicBackward(
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}

const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(AddPrefix(BACKWARD, MEM_PLAN));
AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);

Expand Down
7 changes: 7 additions & 0 deletions src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ class CachedOp {
void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all = false);

// Literal string components used to build per-graph-type attribute keys
// (e.g. "forward_mem_plan", "backward_ref_count"); defined in cached_op.cc.
static const char FULL[];
static const char FORWARD[];
static const char BACKWARD[];
static const char REF_COUNT[];
static const char MEM_PLAN[];
static const char STORAGE_PLAN[];

private:
struct GraphInfo;
struct DynamicRuntime;
Expand Down
2 changes: 2 additions & 0 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ inline MemoryPlanVector PlanMemory(
nnvm::Graph* p_g,
nnvm::StorageVector&& storage,
const std::vector<uint32_t>& ref_count,
const std::string& storage_plan,
const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
bool detect_inplace_addto = false) {
Expand All @@ -831,6 +832,7 @@ inline MemoryPlanVector PlanMemory(
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
g.attrs[storage_plan] = std::make_shared<any>(storage_inplace);
const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
uint32_t entry_start = entry_range.first;
uint32_t entry_end =
Expand Down

0 comments on commit c4580ae

Please sign in to comment.