
[MXNET-1324] Add NaiveRunGraph to imperative utils #14192

Merged · 16 commits · Mar 6, 2019
102 changes: 99 additions & 3 deletions src/imperative/cached_op.cc
@@ -262,6 +262,35 @@ std::vector<nnvm::NodeEntry> CachedOp::Gradient(
  return ret;
}

bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
                                       const std::vector<NDArray*>& inputs,
                                       bool erase_result) {
  using namespace nnvm;
  using namespace imperative;
  CHECK_EQ(inputs.size(), num_inputs());

  auto state_ptr = GetCachedOpState(default_ctx);
  auto& state = state_ptr.get_state<CachedOpState>();

  nnvm::Graph& g = state.info.fwd_graph;
  ShapeVector shape_inputs;
  shape_inputs.reserve(inputs.size());
  for (auto input : inputs) {
    shape_inputs.emplace_back(input->shape());
  }
  // We leverage the shape inference pass to detect whether a dynamic shape exists.
  // If so, the pass will fail with `contain_dynamic_shape = true`.
  // This method is only called once, so the overhead is negligible.
Review comment (Contributor):
Were you able to verify this? I'm afraid this could cause a slowdown.

Reply (junrushao, Member, Author):
If shape inference fails, it most likely means there is an op with a dynamic output shape (np.unique, boolean mask).

MXNet could not support this kind of op before because of a limitation of our system. Thanks to these lines, we can now support it in Gluon blocks.

I would suggest clearly mentioning in our docs that when shape inference fails, the code goes to the slow path (naive run graph, etc.).
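
A minimal sketch of such an op (using the same contrib boolean_mask operator and data as the new unit test below): the shape of the output depends on the values in `index`, not just its shape, so no symbolic pass can infer it before the kernel runs.

import mxnet as mx

# The number of rows selected by the mask is data-dependent, so the output
# shape is unknown until the operator has actually executed.
data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = mx.nd.array([0, 1, 1])
out = mx.nd.contrib.boolean_mask(data, index)
print(out.shape)  # (2, 3) for this index; (1, 3) or (3, 3) for others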

  bool contain_dynamic_shape = false;
  CheckAndInferShape(&g, std::move(shape_inputs), true,
                     {0, 0}, {0, 0},
                     &contain_dynamic_shape);
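  // Erase the (possibly partial) inference results so they are not reused later.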
  if (erase_result) {
    g.attrs.erase("shape");
    g.attrs.erase("shape_inputs");
  }
  return contain_dynamic_shape;
}

bool CachedOp::SetForwardGraph(
    GraphInfo* info,

@@ -784,9 +813,8 @@ OpStatePtr CachedOp::DynamicForward(
  auto& states = runtime.op_states;

  // Allocate entries
-  states.resize(idx.num_nodes());
  buff.resize(idx.num_node_entries());
-  states.reserve(idx.num_nodes());
+  states.resize(idx.num_nodes());
  std::vector<NDArray*> arrays;
  arrays.reserve(buff.size());
  for (auto& buffered_array : buff) {

@@ -839,6 +867,70 @@ OpStatePtr CachedOp::DynamicForward(

  return op_state;
}

OpStatePtr CachedOp::NaiveForward(
    const Context& default_ctx,
    const std::vector<NDArray*>& inputs,
    const std::vector<NDArray*>& outputs) {
  using namespace nnvm;
  using namespace imperative;
  // Initialize
  bool recording = Imperative::Get()->is_recording();
  auto op_state = OpStatePtr::Create<DynamicRuntime>();
  auto& runtime = op_state.get_state<DynamicRuntime>();
  {
    auto state_ptr = GetCachedOpState(default_ctx);
    auto& state = state_ptr.get_state<CachedOpState>();
    std::lock_guard<std::mutex> lock(state.mutex);
    SetForwardGraph(&state.info, recording, inputs);
    runtime.info.fwd_graph = state.info.fwd_graph;
  }
  // build the indexed graph
  nnvm::Graph& g = runtime.info.fwd_graph;
  const auto& idx = g.indexed_graph();
  const size_t num_inputs = idx.input_nodes().size();
  const size_t num_entries = idx.num_node_entries();
  std::vector<uint32_t> ref_count = g.GetAttr<std::vector<uint32_t> >(
      recording ? "full_ref_count" : "forward_ref_count");
  // construct `arrays`
  runtime.buff.resize(num_entries);
  std::vector<NDArray*> arrays;
  arrays.reserve(num_entries);
  for (auto& item : runtime.buff) {
    arrays.push_back(&item);
  }
  for (size_t i = 0; i < num_inputs; ++i) {
    arrays[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i];
  }
  for (size_t i = 0; i < idx.outputs().size(); ++i) {
    auto eid = idx.entry_id(idx.outputs()[i]);
    if (!arrays[eid]->is_none()) *outputs[i] = arrays[eid]->Detach();
    arrays[eid] = outputs[i];
  }
  // construct `array_reqs`
  std::vector<OpReqType> array_reqs;
  array_reqs.reserve(num_entries);
  for (size_t i = 0; i < num_entries; ++i) {
    array_reqs.push_back(ref_count[i] == 0 ? kNullOp : kWriteTo);
  }
  // other stuff
  auto& states = runtime.op_states;
  states.resize(idx.num_nodes());
  const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
  ShapeVector shapes = g.GetAttr<ShapeVector>("shape");
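  // Run the graph node by node; NaiveRunGraph records the concrete output
  // shapes it observes at execution time back into `shapes`.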
  NaiveRunGraph(false, default_ctx, idx, arrays, &shapes, 0, idx.num_nodes(),
                std::move(array_reqs), std::move(ref_count), &states,
                dispatch_modes, recording && inlining_);
  {
    auto state_ptr = GetCachedOpState(default_ctx);
    auto& state = state_ptr.get_state<CachedOpState>();
    auto copied_shape = shapes;
    std::lock_guard<std::mutex> lock(state.mutex);
    state.info.fwd_graph.attrs["shape"] = std::make_shared<dmlc::any>(std::move(copied_shape));
  }
  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
  return op_state;
}

OpStatePtr CachedOp::Forward(
    const std::shared_ptr<CachedOp>& op_ptr,
    const std::vector<NDArray*>& inputs,

@@ -863,7 +955,11 @@ OpStatePtr CachedOp::Forward(

  OpStatePtr op_state;
  try {
-    if (config_.static_alloc) {
+    if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) {
+      config_.is_dynamic = true;
+      config_.static_alloc = false;
+      op_state = NaiveForward(default_ctx, inputs, outputs);
+    } else if (config_.static_alloc) {
      op_state = StaticForward(default_ctx, inputs, outputs);
    } else {
      op_state = DynamicForward(default_ctx, inputs, outputs);
8 changes: 8 additions & 0 deletions src/imperative/cached_op.h

@@ -185,6 +185,14 @@ class CachedOp {
      const std::vector<NDArray*>& inputs,
      const std::vector<OpReqType>& reqs,
      const std::vector<NDArray*>& outputs);
  bool CheckDynamicShapeExists(
      const Context& default_ctx,
      const std::vector<NDArray*>& inputs,
      bool erase_result);
  OpStatePtr NaiveForward(
      const Context& default_ctx,
      const std::vector<NDArray*>& inputs,
      const std::vector<NDArray*>& outputs);

  CachedOpConfig config_;
  nnvm::Graph fwd_graph_;
3 changes: 2 additions & 1 deletion src/imperative/imperative.cc

@@ -442,9 +442,10 @@ std::vector<NDArray*> Imperative::Backward(

  ShapeVector shapes;
  shapes.reserve(idx.num_node_entries());
+  bool contain_unknown = false;
  for (const auto& i : arrays) shapes.emplace_back(i->shape());
  CheckAndInferShape(&graph, std::move(shapes), false,
-                     node_range, entry_range);
+                     node_range, entry_range, &contain_unknown);

  DTypeVector dtypes;
  dtypes.reserve(idx.num_node_entries());
111 changes: 111 additions & 0 deletions src/imperative/imperative_utils.cc

@@ -116,5 +116,116 @@ void RunGraph(
  }
}

void NaiveRunGraph(
    const bool retain_graph,
    const Context& default_ctx,
    const nnvm::IndexedGraph& idx,
    const std::vector<NDArray*> arrays,
    mxnet::ShapeVector *shapes,
    size_t node_start, size_t node_end,
    std::vector<OpReqType>&& array_reqs,
    std::vector<uint32_t>&& ref_count,
    std::vector<OpStatePtr> *p_states,
    const DispatchModeVector &dispatch_modes,
    bool recording) {
  using namespace nnvm;
  using namespace imperative;
  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");

  const auto imp = Imperative::Get();

  std::vector<OpStatePtr>& states = *p_states;

  for (size_t i = node_start; i < node_end; ++i) {
    const nnvm::IndexedGraph::Node& node = idx[i];
    if (node.source->op() == nullptr) {
      continue;
    }
    size_t num_outputs = node.source->num_outputs();
    // construct `ndinputs`
    std::vector<NDArray*> ndinputs;
    ndinputs.reserve(node.inputs.size());
    for (const auto& j : node.inputs) {
      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
    }
    // construct `ndoutputs` and `req`
    std::vector<NDArray*> ndoutputs;
    ndoutputs.reserve(num_outputs);
    for (size_t j = 0; j < num_outputs; ++j) {
      size_t eid = idx.entry_id(i, j);
      ndoutputs.emplace_back(arrays[eid]);
    }
    // other auxiliary data
    Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx);
    auto invoke = [&](const OpStatePtr &state) {
      DispatchMode dispatch_mode = DispatchMode::kUndefined;
      SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs, &dispatch_mode);
      std::vector<OpReqType> req;
      SetWriteInplaceReq(ndinputs, ndoutputs, &req);
      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, state);
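      // For dynamic-shape outputs the shape is only known once the kernel has
      // run: wait for the result, then copy the shape from the allocated chunk.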
      for (size_t j = 0; j < ndoutputs.size(); ++j) {
        if (ndoutputs[j]->shape().ndim() == 0) {
          ndoutputs[j]->WaitToRead();
          ndoutputs[j]->SetShapeFromChunk();
        }
        size_t eid = idx.entry_id(i, j);
        auto shape = ndoutputs[j]->shape();
        (*shapes)[eid] = shape;
      }
      if (recording) {
        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state);
      }
    };
    if (node.source->op() == bwd_cached_op) {
      // case 1: backward cached op
      std::vector<OpReqType> req;
      req.reserve(num_outputs);
      for (size_t j = 0; j < num_outputs; ++j) {
        size_t eid = idx.entry_id(i, j);
        req.push_back(array_reqs[eid]);
        CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
      }
      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
      nnvm::Node* fwd_node = node.source->control_deps[0].get();
      auto fwd_node_id = idx.node_id(fwd_node);
      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
    } else if (createop.count(node.source->op())) {
      // case 2: node is in createop
      // construct `arg_shapes` and `arg_dtypes`
      ShapeVector arg_shapes;
      DTypeVector arg_dtypes;
      arg_shapes.reserve(ndinputs.size());
      arg_dtypes.reserve(ndinputs.size());
      for (auto& ndinput : ndinputs) {
        arg_shapes.emplace_back(ndinput->shape());
        arg_dtypes.emplace_back(ndinput->dtype());
      }
      states[i] = createop[node.source->op()](node.source->attrs, ctx, arg_shapes, arg_dtypes);
      invoke(states[i]);
    } else if (is_layer_backward.get(node.source->op(), false)) {
      // case 3: backward layer
      nnvm::Node* fwd_node = node.source->control_deps[0].get();
      auto fwd_node_id = idx.node_id(fwd_node);
      invoke(states[fwd_node_id]);
    } else {
      // case default
      invoke(OpStatePtr());
    }
    for (const auto& j : node.inputs) {
      size_t eid = idx.entry_id(j);
      --ref_count[eid];
      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
    }
    for (size_t j = 0; j < ndoutputs.size(); ++j) {
      size_t eid = idx.entry_id(i, j);
      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
    }
  }
}

} // namespace imperative
} // namespace mxnet
12 changes: 12 additions & 0 deletions src/imperative/imperative_utils.h

@@ -1007,6 +1007,18 @@ void RunGraph(const bool retain_graph,
              const DispatchModeVector &dispatch_modes,
              bool recording);

void NaiveRunGraph(const bool retain_graph,
                   const Context& default_ctx,
                   const nnvm::IndexedGraph& idx,
                   const std::vector<NDArray*> arrays,
                   mxnet::ShapeVector *shapes,
                   size_t node_start, size_t node_end,
                   std::vector<OpReqType>&& array_reqs,
                   std::vector<uint32_t>&& ref_count,
                   std::vector<OpStatePtr> *p_states,
                   const DispatchModeVector &dispatch_modes,
                   bool recording);

} // namespace imperative
} // namespace mxnet

54 changes: 54 additions & 0 deletions tests/python/unittest/test_dynamic_shape.py

@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import mxnet as mx
from mxnet import gluon
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from mxnet.base import _as_list
from mxnet.attribute import AttrScope
from common import with_seed


def test_dynamic_shape():

    class _TestBlock(gluon.HybridBlock):

        def __init__(self):
            super(_TestBlock, self).__init__()

        def hybrid_forward(self, F, data, index):
            return F.contrib.boolean_mask(data, index)

    block = _TestBlock()
    block.hybridize()
    data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    index = mx.nd.array([0, 1, 1])
    data.attach_grad()
    with mx.autograd.record():
        result = block(data, index)
    result.backward()
    result_nd = np.array([[4, 5, 6], [7, 8, 9]])
    data_grad_nd = np.array([[0., 0., 0.], [1., 1., 1.], [1., 1., 1.]])
    assert_almost_equal(result.asnumpy(), result_nd)
    assert_almost_equal(data.grad.asnumpy(), data_grad_nd)


if __name__ == '__main__':
    import nose
    nose.runmodule()