Skip to content

Commit

Permalink
[MXNET-1352] Allow dynamic shape in while_loop and if conditionals (a…
Browse files Browse the repository at this point in the history
…pache#14393)

* Initial commit

* Rebase

* WIP for fixing rebase issues

* WIP for fixing rebase issues

* fix wip

* wip fix

* wip fix

* wip fix

* wip fix

* wip fix

* wip fix

* should be good to go

* wip remove debug info

* wip remove debug info

* linter

* linter

* Retrigger

* Address comments from Da
  • Loading branch information
junrushao authored and haohuw committed Jun 23, 2019
1 parent 3d2f62a commit fce7baf
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 238 deletions.
4 changes: 3 additions & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ class NDArray {
* \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
*/
void SetShapeFromChunk() {
  // Adopt the chunk's storage_shape as this array's shape, except when the
  // storage_shape is one-dimensional with a leading extent of 0. In that case
  // shape_ is left untouched.
  // NOTE(review): a 1-D shape of {0} presumably marks a chunk whose real shape
  // has not been produced yet - confirm against the chunk allocation code.
  const bool chunk_shape_is_placeholder =
      ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0;
  if (!chunk_shape_is_placeholder) {
    shape_ = ptr_->storage_shape;
  }
}
/*
* This indicates whether an array is a view of another array (created by
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def forward(self, is_train=False, **kwargs):
check_call(_LIB.MXExecutorForward(
self.handle,
ctypes.c_int(int(is_train))))

self.outputs = self._get_outputs()
return self.outputs

def backward(self, out_grads=None, is_train=True):
Expand Down
215 changes: 192 additions & 23 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using namespace mxnet::common;
GraphExecutor::GraphExecutor() {
  // Plain state flags: no gradient graph and no dynamic-shape handling
  // until something later turns them on.
  this->need_grad_ = false;
  this->is_dynamic_ = false;
  // Settings read from the environment, each with the default used when the
  // variable is not set.
  this->log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
  this->subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
  // Hold the shared reference handed out by the engine.
  this->engine_ref_ = Engine::_GetSharedRef();
}
Expand Down Expand Up @@ -76,20 +77,77 @@ void GraphExecutor::PartialForward(bool is_train, int step, int *step_left) {
}

// Run the backward part of the graph.
//
// \param head_grads gradients supplied by the caller for the graph outputs;
//        entry i is required whenever head_grad_array_[i] is not none.
// \param is_train   forwarded unchanged to RunOps.
//
// Steps:
//   1. copy the supplied head gradients into head_grad_array_;
//   2. when the executor is dynamic, re-run shape inference and give every
//      still-unshaped input/output array of every executable node a shape;
//   3. execute nodes [num_forward_nodes_, idx.num_nodes()) through RunOps.
void GraphExecutor::Backward(const std::vector<NDArray>& head_grads, bool is_train) {
  {
    // `idx` is confined to this scope because graph_ is moved from and
    // reassigned below when the executor is dynamic.
    // NOTE(review): presumably that would leave this reference stale - confirm.
    const auto& idx = graph_.indexed_graph();
    if (num_forward_inputs_ != idx.input_nodes().size()) {
      for (size_t i = 0; i < head_grad_array_.size(); ++i) {
        if (!head_grad_array_[i].is_none()) {
          CHECK(i < head_grads.size() && !head_grads[i].is_none())
            << "Because the last operator is not Loss function, "
            << "head_gradient is required when calling backward. "
            << "If you are attempting to minimize the output as "
            << "an objective, please modify your network and "
            << "pass it through the make_loss symbol.";
          const NDArray &from = head_grads[i];
          NDArray &to = head_grad_array_[i];
          if (this->is_dynamic_) {
            // The destination may still have no known shape; take the shape of
            // the incoming gradient before copying into it.
            to.WaitToRead();
            if (!shape_is_known(to.shape())) {
              to.Init(from.shape());
            }
          }
          CopyFromTo(from, &to);
        }
      }
    }
  }
  if (this->is_dynamic_) {
    // Re-infer shapes on the graph, then take the resulting "shape" attribute
    // out of it so it can be consulted here and stored back afterwards.
    graph_ = InferShape(std::move(graph_), {}, "");
    mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
    const auto& idx = graph_.indexed_graph();
    for (size_t nid = 0; nid < idx.num_nodes(); ++nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      OpNode& opnode = op_nodes_[nid];
      if (opnode.skip_exec_node) continue;
      // First pass over the inputs: pick up shapes already stored in the chunk.
      // NOTE(review): the next loop repeats these exact steps for every input,
      // so this pass looks redundant - confirm before removing.
      for (NDArray &array : opnode.exec->in_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
      }
      // Second pass over the inputs: if the chunk did not provide a shape,
      // fall back to the inferred shape of the corresponding graph entry.
      int i = 0;
      for (NDArray &array : opnode.exec->in_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
        if (!shape_is_known(array.shape())) {
          mxnet::TShape shape = rshape[idx.entry_id(inode.inputs[i])];
          if (shape_is_known(shape)) {
            array.ReshapeAndAlloc(shape);
          }
        }
        ++i;
      }
      // Same treatment for the outputs, indexed by (nid, output position).
      i = 0;
      for (NDArray &array : opnode.exec->out_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
        if (!shape_is_known(array.shape())) {
          mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
          if (shape_is_known(shape)) {
            array.ReshapeAndAlloc(shape);
          }
        }
        ++i;
      }
    }
    // Put the shape vector back on the graph.
    graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
  }
  const auto& idx = graph_.indexed_graph();
  RunOps(is_train, num_forward_nodes_, idx.num_nodes());
}

Expand Down Expand Up @@ -119,6 +177,14 @@ void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback, bool mon
}

const std::vector<NDArray>& GraphExecutor::outputs() const {
if (this->is_dynamic_) {
for (const NDArray &array : output_arrays_) {
array.WaitToRead();
if (!shape_is_known(array.shape())) {
const_cast<NDArray &>(array).SetShapeFromChunk();
}
}
}
return output_arrays_;
}

Expand Down Expand Up @@ -381,8 +447,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
arg_shapes.resize(idx.input_nodes().size(), mxnet::TShape());
g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<mxnet::ShapeVector>("shape"));
this->is_dynamic_ = true;
}

arg_dtypes.resize(idx.input_nodes().size(), -1);
Expand Down Expand Up @@ -821,8 +886,7 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
}
g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<mxnet::ShapeVector>("shape"));
this->is_dynamic_ = true;
}
const mxnet::ShapeVector& shape_vec = g.GetAttr<mxnet::ShapeVector>("shape");
std::vector<OpReqType> grad_req_types;
Expand Down Expand Up @@ -977,14 +1041,16 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
uint32_t oid = head_grad_map_.at(idx[nid].source);
uint32_t eid = idx.entry_id(idx.outputs()[oid]);
NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid];
CHECK(mxnet::shape_is_known(vshape[eid]));
bool unknown_shape = !shape_is_known(vshape[eid]);
CHECK_NE(vdtype[eid], -1);
auto data_eid = idx.entry_id(nid, 0);
// initialize based on storage_type
if (stype != kDefaultStorage) {
data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]);
} else {
} else if (!unknown_shape) {
data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]);
} else {
data_entry_[data_eid] = NDArray(data_context[eid], vdtype[eid]);
}
if (log_verbose_) {
LOG(INFO) << "\tinit head_grad entry\t" << data_eid << "\tas "
Expand All @@ -994,7 +1060,11 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
// get maximum bytes in each pool
for (size_t i = 0; i < vshape.size(); ++i) {
if (!data_entry_[i].is_none()) continue;
size_t bytes = vshape[i].Size() * mshadow::mshadow_sizeof(vdtype[i]);
size_t shape_size = 0;
if (shape_is_known(vshape[i])) {
shape_size = vshape[i].Size();
}
size_t bytes = shape_size * mshadow::mshadow_sizeof(vdtype[i]);
int storage_id = vstorage[i];
// skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID
if (storage_id < 0) continue;
Expand All @@ -1013,7 +1083,10 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
std::multimap<size_t, NDArray> free_pool;
if (shared_pool != nullptr) {
for (const NDArray& nd : *shared_pool) {
size_t bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
size_t bytes = 0;
if (shape_is_known(nd.shape())) {
bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
}
free_pool.insert(std::make_pair(bytes, nd));
}
}
Expand Down Expand Up @@ -1067,9 +1140,13 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
int storage_id = vstorage[i];
auto storage_type = (NDArrayStorageType) vstorage_type[i];
if (storage_type == kDefaultStorage) {
CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
const NDArray& src = data_pool_.at(storage_id);
data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
if (!shape_is_known(vshape[i])) {
data_entry_[i] = NDArray(data_context[i], vdtype[i]);
} else {
CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
const NDArray& src = data_pool_.at(storage_id);
data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
}
} else {
data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
true, vdtype[i]);
Expand Down Expand Up @@ -1209,7 +1286,10 @@ void GraphExecutor::InitOpSegs() {
const profiler::Profiler *prof = profiler::Profiler::Get();
bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain()
&& (!prof || !prof->AggregateEnabled());

if (this->is_dynamic_) {
prefer_bulk_exec_inference = false;
prefer_bulk_exec_train = false;
}
bool is_training = num_forward_nodes_ != total_num_nodes;

if (prefer_bulk_exec_train && is_training) {
Expand Down Expand Up @@ -1300,6 +1380,8 @@ void GraphExecutor::ExecuteMonOutputCallback(size_t nid) {
}

void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
static auto& finfer_shape = nnvm::Op::GetAttr<mxnet::FInferShape>("FInferShape");
static auto& is_backward = Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
// Update context
const auto& idx = graph_.indexed_graph();
for (size_t nid = topo_start; nid < topo_end; ++nid) {
Expand All @@ -1311,6 +1393,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
opnode.exec->op_ctx.need_grad = need_grad_;
}

mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
// Push Ops
for (size_t nid = topo_start; nid < topo_end; ++nid) {
auto seg_op = cached_seg_opr_[nid];
Expand All @@ -1323,13 +1406,78 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
}
// Normal mode
const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) continue;
OpNode& opnode = op_nodes_[nid];
if (op_nodes_[nid].skip_exec_node) continue;
// Monitor callbacks
if (monitor_callback_ && monitor_all_) {
ExecuteMonInputCallback(nid);
}
if (this->is_dynamic_) {
const auto &op = inode.source->op();
{
for (NDArray &array : opnode.exec->in_array) {
array.WaitToRead();
if (!shape_is_known(array.shape())) {
array.SetShapeFromChunk();
}
}
int i = 0;
for (NDArray &array : opnode.exec->out_array) {
array.WaitToRead();
if (!shape_is_known(array.shape())) {
array.SetShapeFromChunk();
}
if (!shape_is_known(array.shape())) {
mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
if (shape_is_known(shape)) {
array.ReshapeAndAlloc(shape);
}
}
++i;
}
}
if (finfer_shape.count(op)) {
mxnet::ShapeVector in_shapes;
mxnet::ShapeVector out_shapes;
for (NDArray &array : opnode.exec->in_array) {
in_shapes.push_back(array.shape());
}
for (NDArray &array : opnode.exec->out_array) {
out_shapes.push_back(array.shape());
}
auto finfer = finfer_shape[op];
try {
bool success = finfer(inode.source->attrs, &in_shapes, &out_shapes);
CHECK(success) << "InferShape failed in operator " << inode.source->attrs.name;
} catch (const std::exception& e) {
throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
}
int n_out = out_shapes.size();
for (int i = 0; i < n_out; ++i) {
NDArray &array = opnode.exec->out_array[i];
if (!shape_is_known(array.shape())) {
array.Init(out_shapes[i]);
}
}
} else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
CHECK_GE(inode.control_deps.size(), 1U) <<
"BackwardOp need to have control_deps to its forward op";
uint32_t fid = inode.control_deps[0];
const OpNode& fopnode = op_nodes_[fid];
CHECK_EQ(fopnode.exec->in_array.size(), opnode.exec->out_array.size());
int nelem = fopnode.exec->in_array.size();
std::vector<NDArray> &from = fopnode.exec->in_array;
std::vector<NDArray> &to = opnode.exec->out_array;
for (int i = 0; i < nelem; ++i) {
if (!shape_is_known(to[i].shape())) {
to[i].Init(from[i].shape());
}
}
}
}
opnode.exec->op_ctx.is_train = is_train;
opnode.exec->op_ctx.need_grad = need_grad_;
if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
Expand All @@ -1343,14 +1491,35 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
} else if (opnode.cached_opr != nullptr) {
bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
if (this->is_dynamic_) {
for (NDArray &array : opnode.exec->out_array) {
array.WaitToRead();
if (!shape_is_known(array.shape())) {
array.SetShapeFromChunk();
}
}
}
} else {
LOG(FATAL) << "Not accessed";
}
for (uint32_t i = 0; i < num_inputs; ++i) {
int eid = idx.entry_id(inode.inputs[i]);
if (!shape_is_known(rshape[eid])) {
rshape[eid] = opnode.exec->in_array[i].shape();
}
}
for (uint32_t i = 0; i < num_outputs; ++i) {
int eid = idx.entry_id(nid, i);
if (!shape_is_known(rshape[eid])) {
rshape[eid] = opnode.exec->out_array[i].shape();
}
}
// Monitor callbacks
if (monitor_callback_) {
ExecuteMonOutputCallback(nid);
}
}
graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
}

GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, size_t topo_end) {
Expand Down
2 changes: 2 additions & 0 deletions src/executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ class GraphExecutor : public Executor {
void ExecuteMonOutputCallback(size_t nid);
// perform bulking and segmentation on the region [from_node, up_to_node) of a graph
void BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max);
// Set when shape inference fails; enables the fallback that ensures dynamic-shaped operators are executed correctly.
bool is_dynamic_;
// indicate whether there is a backward graph for gradients.
bool need_grad_;
// internal graph
Expand Down
2 changes: 1 addition & 1 deletion src/nnvm/plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
// only request memory for kBadStorageID
if (storage[eid] == GraphAllocator::kBadStorageID) {
auto &eshape = shape_vec[eid];
size_t esize = eshape.Size();
size_t esize = ndim_is_known(shape_vec[eid]) ? eshape.Size() : 0;
eids.insert(std::make_pair(esize, eid));
}
}
Expand Down
Loading

0 comments on commit fce7baf

Please sign in to comment.