Commit cae5138

Allow input reordering during Gluon / CachedOp graph transformations (apache#17949)

* Initial commit of input reordering in Gluon

* Add test for Gluon input reorder

* Fix backward in CachedOp for input reordering

* Fix test_input_reorder for backward pass

* Fix merge error in NaiveCachedOp

* Include correct header for std::iota

Co-authored-by: Vladimir Cherepanov <[email protected]>
2 people authored and AntiZpvoh committed Jul 6, 2020
1 parent f74290b commit cae5138
Showing 5 changed files with 94 additions and 40 deletions.
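
The core idea of the patch: graph transformations (pointwise fusion in particular) may reorder the inputs of the transformed graph relative to the order in which the Gluon / CachedOp caller passes them. The commit adds an `input_map` that records, for each input slot of the transformed graph, the index of the corresponding input in the caller's original list, and every loop over graph inputs now goes through it. The sketch below is not part of the commit; it only illustrates the gather pattern the diff applies in several places (the helper name `GatherByInputMap` is hypothetical).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// input_map[i] = original position of the transformed graph's i-th input.
template <typename T>
std::vector<T> GatherByInputMap(const std::vector<T>& original_inputs,
                                const std::vector<size_t>& input_map) {
  std::vector<T> reordered(input_map.size());
  for (size_t i = 0; i < input_map.size(); ++i) {
    // new slot i is fed from original slot input_map[i]
    reordered[i] = original_inputs[input_map[i]];
  }
  return reordered;
}

int main() {
  // Caller passes (x, y, z); suppose fusion reordered the graph inputs to (y, z, x).
  std::vector<char> original = {'x', 'y', 'z'};
  std::vector<size_t> input_map = {1, 2, 0};  // new index -> original index
  auto reordered = GatherByInputMap(original, input_map);
  assert(reordered[0] == 'y' && reordered[1] == 'z' && reordered[2] == 'x');
  return 0;
}
```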
58 changes: 31 additions & 27 deletions src/imperative/cached_op.cc
@@ -147,10 +147,9 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
auto& state = state_ptr.get_state<CachedOpState>();

nnvm::Graph& g = state.info.fwd_graph;
ShapeVector shape_inputs;
shape_inputs.reserve(inputs.size());
for (auto input : inputs) {
shape_inputs.emplace_back(input->shape());
ShapeVector shape_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
shape_inputs[i] = inputs[state.info.input_map[i]]->shape();
}
// We leverage the shape inference pass to detect whether dynamic shape exists.
// If so, the pass will fail with `contain_dynamic_shape = true`,
@@ -176,16 +175,13 @@ bool CachedOp::SetForwardGraph(
CHECK_EQ(inputs.size(), num_inputs());
nnvm::Graph& g = info->fwd_graph;

ShapeVector shape_inputs;
DTypeVector dtype_inputs;
StorageTypeVector storage_type_inputs;
shape_inputs.reserve(inputs.size());
dtype_inputs.reserve(inputs.size());
storage_type_inputs.reserve(inputs.size());
for (auto input : inputs) {
shape_inputs.emplace_back(input->shape());
dtype_inputs.emplace_back(input->dtype());
storage_type_inputs.emplace_back(input->storage_type());
ShapeVector shape_inputs(inputs.size());
DTypeVector dtype_inputs(inputs.size());
StorageTypeVector storage_type_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
shape_inputs[i] = inputs[info->input_map[i]]->shape();
dtype_inputs[i] = inputs[info->input_map[i]]->dtype();
storage_type_inputs[i] = inputs[info->input_map[i]]->storage_type();
}

bool match = true;
@@ -321,9 +317,10 @@ bool CachedOp::SetBackwardGraph(
if (info->bwd_input_eid[i] == kEidNotExist) {
continue;
}
shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
size_t oi = BwdOriginalInput(info->input_map, i);
shapes[info->bwd_input_eid[i]] = inputs[oi]->shape();
dtypes[info->bwd_input_eid[i]] = inputs[oi]->dtype();
stypes[info->bwd_input_eid[i]] = inputs[oi]->storage_type();
}

std::pair<uint32_t, uint32_t> node_range, entry_range;
@@ -649,22 +646,22 @@ OpStatePtr CachedOp::StaticForward(
if (config_.static_shape) {
for (auto i : config_.param_indices) {
auto nid = idx.input_nodes()[i];
if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[state.info.input_map[i]])) {
match = false;
auto ptr = &state.buff[idx.entry_id(nid, 0)];
CHECK_EQ(arrays[idx.entry_id(nid, 0)], ptr);
*arrays[idx.entry_id(nid, 0)] = *inputs[i];
*arrays[idx.entry_id(nid, 0)] = *inputs[state.info.input_map[i]];
state.dynamic_entries[idx.entry_id(nid, 0)] = false;
}
}
for (auto i : config_.data_indices) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
arrays[eid] = inputs[i];
arrays[eid] = inputs[state.info.input_map[i]];
}
} else {
for (size_t i = 0; i < num_inputs(); ++i) {
auto nid = idx.input_nodes()[i];
arrays[idx.entry_id(nid, 0)] = inputs[i];
arrays[idx.entry_id(nid, 0)] = inputs[state.info.input_map[i]];
}
}

@@ -714,6 +711,7 @@ OpStatePtr CachedOp::DynamicForward(
std::lock_guard<std::mutex> lock(state.mutex);
SetForwardGraph(default_ctx, &state.info, recording, inputs);
runtime.info.fwd_graph = state.info.fwd_graph;
runtime.info.input_map = state.info.input_map;
}
nnvm::Graph& g = runtime.info.fwd_graph;
const auto& idx = g.indexed_graph();
@@ -736,7 +734,7 @@
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
CollectInputOutputNDRefs(g, inputs, outputs, &arrays);
CollectInputOutputNDRefs(g, inputs, runtime.info.input_map, outputs, &arrays);

if (!use_naive_run) {
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(AddPrefix(graph_type, MEM_PLAN));
@@ -853,6 +851,7 @@ void CachedOp::DynamicBackward(
auto& state = state_ptr.get_state<CachedOpState>();
std::lock_guard<std::mutex> lock(state.mutex);
state.info.fwd_graph = runtime.info.fwd_graph;
state.info.input_map = runtime.info.input_map;
SetBackwardGraph(&state.info, reqs, inputs);
runtime.info.full_graph = state.info.full_graph;
runtime.info.bwd_input_eid = state.info.bwd_input_eid;
@@ -875,7 +874,7 @@
if (runtime.info.bwd_input_eid[i] == kEidNotExist) {
continue;
}
arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
arrays[runtime.info.bwd_input_eid[i]] = inputs[BwdOriginalInput(runtime.info.input_map, i)];
}
for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
if (reqs[i] == kNullOp) continue;
@@ -952,10 +951,8 @@ void CachedOp::StaticBackward(
auto& arrays = state.arrays_with_in_out;
for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
auto eid = state.info.bwd_input_eid[i];
if (eid == kEidNotExist) {
continue;
}
if (state.dynamic_entries[eid]) arrays[eid] = inputs[i];
if (eid == kEidNotExist || !state.dynamic_entries[eid]) continue;
arrays[eid] = inputs[BwdOriginalInput(state.info.input_map, i)];
}

if (config_.static_shape) {
@@ -1293,6 +1290,13 @@ void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
}
}

size_t CachedOp::BwdOriginalInput(const std::vector<size_t>& input_map, size_t new_i) {
CHECK_GE(input_map.size(), bwd_in_dep_.size());
if (new_i >= bwd_ograd_dep_.size() && new_i < bwd_ograd_dep_.size() + bwd_in_dep_.size())
return bwd_ograd_dep_.size() + input_map[new_i - bwd_ograd_dep_.size()];
return new_i;
}

NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
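
CachedOp::BwdOriginalInput above remaps only the middle segment of the backward input list: the entries contributed by forward inputs (bwd_in_dep_), which follow the output-gradient entries (bwd_ograd_dep_); every other slot keeps its index. A self-contained sketch of the same arithmetic with the dependency counts passed explicitly and made-up sizes (`BwdOriginalInputSketch` is not a function in the codebase):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same index arithmetic as CachedOp::BwdOriginalInput, but with the
// dependency counts given as arguments instead of read from class members.
size_t BwdOriginalInputSketch(const std::vector<size_t>& input_map,
                              size_t num_ograd_dep, size_t num_in_dep, size_t new_i) {
  if (new_i >= num_ograd_dep && new_i < num_ograd_dep + num_in_dep)
    return num_ograd_dep + input_map[new_i - num_ograd_dep];
  return new_i;  // slots outside the forward-input segment are left unchanged
}

int main() {
  // Assume 2 output-gradient entries followed by 3 forward-input entries,
  // and the forward inputs were reordered so that input_map = {2, 0, 1}.
  std::vector<size_t> input_map = {2, 0, 1};
  assert(BwdOriginalInputSketch(input_map, 2, 3, 0) == 0);  // ograd slot: unchanged
  assert(BwdOriginalInputSketch(input_map, 2, 3, 2) == 4);  // graph's 1st fwd input came from original input 2
  assert(BwdOriginalInputSketch(input_map, 2, 3, 4) == 3);  // graph's 3rd fwd input came from original input 1
  assert(BwdOriginalInputSketch(input_map, 2, 3, 5) == 5);  // slot past the fwd-input segment: unchanged
  return 0;
}
```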
34 changes: 23 additions & 11 deletions src/imperative/cached_op.h
@@ -22,6 +22,7 @@

#include <mxnet/imperative.h>
#include <vector>
#include <numeric>
#include <atomic>
#include <utility>
#include <string>
@@ -133,16 +134,18 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {

void CollectInputOutputNDRefs(const nnvm::Graph& g,
const std::vector<NDArray*>& inputs,
const std::vector<size_t>& input_map,
const std::vector<NDArray*>& outputs,
std::vector<NDArray*>* arrays) DMLC_ATTRIBUTE_UNUSED;
void CollectInputOutputNDRefs(const nnvm::Graph& g,
const std::vector<NDArray*>& inputs,
const std::vector<size_t>& input_map,
const std::vector<NDArray*>& outputs,
std::vector<NDArray*>* arrays) {
const auto& idx = g.indexed_graph();
size_t num_inputs = idx.input_nodes().size();
for (size_t i = 0; i < num_inputs; ++i) {
(*arrays)[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i];
(*arrays)[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[input_map[i]];
}
for (size_t i = 0; i < idx.outputs().size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
@@ -322,8 +325,11 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) {
std::make_shared<dmlc::any>(std::move(full_ref_count));
}

void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * fwd_graph, nnvm::Graph * grad_graph,
const Context& context, size_t num_forward_outputs, const bool inlining) {
void OptimizeGraph(nnvm::Graph* full_graph, nnvm::Graph* fwd_graph, nnvm::Graph* grad_graph,
std::vector<size_t>* input_map, const Context& context,
size_t num_forward_outputs, const bool inlining) {
input_map->resize(full_graph->indexed_graph().input_nodes().size());
std::iota(input_map->begin(), input_map->end(), 0);
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
if (context.dev_mask() == kGPU &&
!inlining &&
@@ -336,7 +342,7 @@ void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * fwd_graph, nnvm::Grap
*full_graph = exec::FusePointwiseForward(std::move(*full_graph));
full_graph->attrs["num_forward_outputs"] = std::make_shared<nnvm::any>(num_forward_outputs);
*full_graph = exec::FusePointwiseBackward(std::move(*full_graph));
// Check the topological order of inputs
// Fill in input_map - mapping from the new to the original input indices.
const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes();
const auto &new_inputs = full_graph->indexed_graph().input_nodes();
if (original_inputs.size() != new_inputs.size()) {
@@ -345,13 +351,17 @@ void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * fwd_graph, nnvm::Grap
<< "This is most probably a bug. Disabling fusion for this run.";
*full_graph = unoptimized_graph;
} else {
std::unordered_map<std::string, size_t> original_input_map;
for (size_t i = 0; i < original_inputs.size(); ++i) {
auto r = original_input_map.insert(std::make_pair(
unoptimized_graph.indexed_graph()[original_inputs[i]].source->attrs.name, i));
CHECK(r.second);
}
for (size_t i = 0; i < new_inputs.size(); ++i) {
if (unoptimized_graph.indexed_graph()[original_inputs[i]].source->attrs.name !=
full_graph->indexed_graph()[new_inputs[i]].source->attrs.name) {
LOG(WARNING) << "Disabling fusion due to altered topological order of inputs.";
*full_graph = unoptimized_graph;
break;
}
auto it = original_input_map.find(
full_graph->indexed_graph()[new_inputs[i]].source->attrs.name);
CHECK(it != original_input_map.end());
(*input_map)[i] = it->second;
}
}
} else {
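
In the else-branch above, the post-fusion input order is matched back to the original order by node name. A minimal standalone sketch of that name-based mapping, using plain strings in place of indexed-graph nodes (`BuildInputMap` is a hypothetical helper, not part of the commit):

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Map each post-transformation input name back to its position in the original input list.
std::vector<size_t> BuildInputMap(const std::vector<std::string>& original_names,
                                  const std::vector<std::string>& new_names) {
  std::unordered_map<std::string, size_t> original_pos;
  for (size_t i = 0; i < original_names.size(); ++i) {
    bool inserted = original_pos.emplace(original_names[i], i).second;
    assert(inserted);  // input names are expected to be unique
  }
  std::vector<size_t> input_map(new_names.size());
  for (size_t i = 0; i < new_names.size(); ++i) {
    auto it = original_pos.find(new_names[i]);
    assert(it != original_pos.end());  // every new input must also be an original input
    input_map[i] = it->second;
  }
  return input_map;
}

int main() {
  std::vector<std::string> original = {"data", "weight", "bias"};
  std::vector<std::string> fused    = {"weight", "bias", "data"};
  assert(BuildInputMap(original, fused) == (std::vector<size_t>{1, 2, 0}));
  return 0;
}
```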
@@ -524,6 +534,7 @@ class CachedOp {
nnvm::Graph fwd_graph;
nnvm::Graph grad_graph;
nnvm::Graph full_graph;
std::vector<size_t> input_map; // the original index of an input
std::vector<nnvm::NodeEntry> ograd_entries;
std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output;
std::vector<OpReqType> bwd_output_reqs;
@@ -540,7 +551,7 @@ class CachedOp {
&info.full_graph, &info.ograd_entries,
&info.fwd_input_to_grad_output);

OptimizeGraph(&info.full_graph, &info.fwd_graph, &info.grad_graph,
OptimizeGraph(&info.full_graph, &info.fwd_graph, &info.grad_graph, &info.input_map,
context_, fwd_graph_.outputs.size(), inlining_);

size_t max_nodes = info.full_graph.indexed_graph().num_nodes();
@@ -638,6 +649,7 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
size_t BwdOriginalInput(const std::vector<size_t>& input_map, size_t new_i);

CachedOpConfig config_;
nnvm::Graph fwd_graph_;
4 changes: 3 additions & 1 deletion src/imperative/cached_op_threadsafe.cc
@@ -130,7 +130,9 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx,

const MemoryPlanVector& mem_plan = g.GetAttr<MemoryPlanVector>("forward_mem_plan");
// Collect input output pointers to ndarray into the arrays data structure
CollectInputOutputNDRefs(g, inputs, outputs, &arrays);
std::vector<size_t> input_map(inputs.size());
std::iota(input_map.begin(), input_map.end(), 0);
CollectInputOutputNDRefs(g, inputs, input_map, outputs, &arrays);
// The SetForwardGraph call in DynamicForward runs the memory planning phase
// and allocates storage for intermediate and final outputs of the graph
// We need to still create NDArrays (pointer data structure), based on this
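
Here the thread-safe cached op builds an identity `input_map` with `std::iota`, so `CollectInputOutputNDRefs` reads the inputs in their original order. A trivial sketch of what that call produces, assuming four inputs:

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  // With no input-reordering transformation applied, the map is the identity:
  // the graph's i-th input is still the caller's i-th argument.
  std::vector<size_t> input_map(4);
  std::iota(input_map.begin(), input_map.end(), 0);
  assert(input_map == (std::vector<size_t>{0, 1, 2, 3}));
  return 0;
}
```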
3 changes: 2 additions & 1 deletion src/imperative/naive_cached_op.cc
@@ -62,6 +62,7 @@ OpStatePtr NaiveCachedOp::Forward(
std::lock_guard<std::mutex> lock(state.mutex);
SetForwardGraph(default_ctx, &state.info, recording, inputs);
runtime.info.fwd_graph = state.info.fwd_graph;
runtime.info.input_map = state.info.input_map;
}
nnvm::Graph& g = runtime.info.fwd_graph;
const auto& idx = g.indexed_graph();
@@ -84,7 +85,7 @@
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
CollectInputOutputNDRefs(g, inputs, outputs, &arrays);
CollectInputOutputNDRefs(g, inputs, runtime.info.input_map, outputs, &arrays);

mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
imperative::NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
35 changes: 35 additions & 0 deletions tests/python/gpu/test_fusion.py
@@ -19,6 +19,7 @@
import random
import mxnet as mx
import numpy as np
from mxnet import autograd, gluon
from mxnet.test_utils import *

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
@@ -311,3 +312,37 @@ def hybrid_forward(self, F, x):
out = foo(mx.nd.ones((10,10), ctx=mx.gpu()))
assert np.all(out.asnumpy() == np.ones((10,10)))
assert out.shape == (10,10,1)

@with_seed()
def test_input_reorder():
class Block(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Block, self).__init__(**kwargs)

def hybrid_forward(self, F, x, y, z):
s = x * 2
s2 = s + z
s = F.broadcast_add(s, y * y)
return F.dot(s, s2)

for static_alloc in (False, True):
arg_shapes = [(10, 10), (10, 1), (10, 10)]
arg_data = [mx.random.uniform(shape=s) for s in arg_shapes]

arrays = {}
for use_fusion in ('0', '1'):
os.environ['MXNET_USE_FUSION'] = use_fusion
arrays[use_fusion] = {}
n = Block()
n.hybridize(static_alloc=static_alloc)
args = [arg.copyto(mx.gpu()) for arg in arg_data]
for arg in args:
arg.attach_grad()
with autograd.record():
r = n(*args)
arrays[use_fusion]['result'] = r
r.backward()
for i, arg in enumerate(args):
arrays[use_fusion][i] = arg.grad
for key in ['result'] + list(range(len(arg_data))):
assert_allclose(arrays['0'][key].asnumpy(), arrays['1'][key].asnumpy())
