diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 9db235bca532..61dfb9c9423c 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -262,6 +262,35 @@ std::vector CachedOp::Gradient( return ret; } +bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, + const std::vector& inputs, + bool erase_result) { + using namespace nnvm; + using namespace imperative; + CHECK_EQ(inputs.size(), num_inputs()); + + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); + + nnvm::Graph& g = state.info.fwd_graph; + ShapeVector shape_inputs; + shape_inputs.reserve(inputs.size()); + for (auto input : inputs) { + shape_inputs.emplace_back(input->shape()); + } + // We leverage the shape inference pass to detect whether dynamic shape exists. + // If so, the pass will fail with `contain_dynamic_shape = true`, + // This method is only called once, so the overhead is negligible. + bool contain_dynamic_shape = false; + CheckAndInferShape(&g, std::move(shape_inputs), true, + {0, 0}, {0, 0}, + &contain_dynamic_shape); + if (erase_result) { + g.attrs.erase("shape"); + g.attrs.erase("shape_inputs"); + } + return contain_dynamic_shape; +} bool CachedOp::SetForwardGraph( GraphInfo* info, @@ -762,7 +791,8 @@ OpStatePtr CachedOp::StaticForward( OpStatePtr CachedOp::DynamicForward( const Context& default_ctx, const std::vector& inputs, - const std::vector& outputs) { + const std::vector& outputs, + bool use_naive_run) { using namespace nnvm; using namespace imperative; @@ -784,9 +814,8 @@ OpStatePtr CachedOp::DynamicForward( auto& states = runtime.op_states; // Allocate entries - states.resize(idx.num_nodes()); buff.resize(idx.num_node_entries()); - states.reserve(idx.num_nodes()); + states.resize(idx.num_nodes()); std::vector arrays; arrays.reserve(buff.size()); for (auto& buffered_array : buff) { @@ -809,33 +838,42 @@ OpStatePtr CachedOp::DynamicForward( for (size_t i = 0; i < idx.num_node_entries(); ++i) { if (ref_count[i] == 0) array_reqs[i] = kNullOp; } - - const auto& mem_plan = g.GetAttr( - recording ? "full_mem_plan" : "forward_mem_plan"); - AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), - mem_plan, arrays, &array_reqs); - - const auto& dtypes = g.GetAttr("dtype"); - const auto& shapes = g.GetAttr("shape"); - const auto& stypes = g.GetAttr("storage_type"); - - for (size_t i = 0; i < outputs.size(); ++i) { - auto eid = idx.entry_id(idx.outputs()[i]); - arrays[eid] = outputs[i]; - if (!outputs[i]->is_none()) continue; - *outputs[i] = NDArray(static_cast(stypes[eid]), - shapes[eid], default_ctx, true, dtypes[eid]); - } - const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - - // If CachedOp is running in the inline mode, it uses RunGraph to record - // computation; otherwise, CachedOp records computation itself. - // So if it's not the inline mode, we disable recording. - RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), - std::move(ref_count), &states, dispatch_modes, - recording && inlining_); - + if (!use_naive_run) { + const auto& mem_plan = g.GetAttr( + recording ? "full_mem_plan" : "forward_mem_plan"); + AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), + mem_plan, arrays, &array_reqs); + const auto& dtypes = g.GetAttr("dtype"); + const auto& shapes = g.GetAttr("shape"); + const auto& stypes = g.GetAttr("storage_type"); + for (size_t i = 0; i < outputs.size(); ++i) { + auto eid = idx.entry_id(idx.outputs()[i]); + arrays[eid] = outputs[i]; + if (!outputs[i]->is_none()) continue; + *outputs[i] = NDArray(static_cast(stypes[eid]), + shapes[eid], default_ctx, true, dtypes[eid]); + } + // If CachedOp is running in the inline mode, it uses RunGraph to record + // computation; otherwise, CachedOp records computation itself. + // So if it's not the inline mode, we disable recording. + RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), + std::move(ref_count), &states, dispatch_modes, + recording && inlining_); + } else { + mxnet::ShapeVector shapes = g.GetAttr("shape"); + NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(), + std::move(array_reqs), std::move(ref_count), &states, + dispatch_modes, recording && inlining_, &shapes); + { + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); + auto copied_shape = shapes; + std::lock_guard lock(state.mutex); + state.info.fwd_graph.attrs["shape"] = std::make_shared(std::move(copied_shape)); + } + g.attrs["shape"] = std::make_shared(std::move(shapes)); + } return op_state; } @@ -863,10 +901,14 @@ OpStatePtr CachedOp::Forward( OpStatePtr op_state; try { - if (config_.static_alloc) { + if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) { + config_.is_dynamic = true; + config_.static_alloc = false; + op_state = DynamicForward(default_ctx, inputs, outputs, true); + } else if (config_.static_alloc) { op_state = StaticForward(default_ctx, inputs, outputs); } else { - op_state = DynamicForward(default_ctx, inputs, outputs); + op_state = DynamicForward(default_ctx, inputs, outputs, false); } } catch (const dmlc::Error& e) { Engine::Get()->set_bulk_size(prev_bulk_size); diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 3b173c8654a4..5a0351ad6fa7 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -153,7 +153,8 @@ class CachedOp { OpStatePtr DynamicForward( const Context& default_ctx, const std::vector& inputs, - const std::vector& outputs); + const std::vector& outputs, + bool use_naive_run = false); void DynamicBackward( const bool retain_graph, const OpStatePtr& op_state, @@ -185,6 +186,10 @@ class CachedOp { const std::vector& inputs, const std::vector& reqs, const std::vector& outputs); + bool CheckDynamicShapeExists( + const Context& default_ctx, + const std::vector& inputs, + bool erase_result); CachedOpConfig config_; nnvm::Graph fwd_graph_; diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 8d1f65518565..3e5b3987522c 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -442,9 +442,10 @@ std::vector Imperative::Backward( ShapeVector shapes; shapes.reserve(idx.num_node_entries()); + bool contain_unknown = false; for (const auto& i : arrays) shapes.emplace_back(i->shape()); CheckAndInferShape(&graph, std::move(shapes), false, - node_range, entry_range); + node_range, entry_range, &contain_unknown); DTypeVector dtypes; dtypes.reserve(idx.num_node_entries()); diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index 1a676e62000d..6cb4a70324b5 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -22,6 +22,106 @@ namespace mxnet { namespace imperative { + +inline std::vector NodeInputs(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector arrays) { + const nnvm::IndexedGraph::Node& node = idx[node_idx]; + const size_t num_inputs = node.inputs.size(); + std::vector ndinputs; + ndinputs.reserve(num_inputs); + for (const auto& j : node.inputs) { + size_t eid = idx.entry_id(j); + ndinputs.emplace_back(arrays[eid]); + } + return ndinputs; +} + +inline std::vector NodeOutputs(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector arrays) { + const nnvm::IndexedGraph::Node& node = idx[node_idx]; + const size_t num_outputs = node.source->num_outputs(); + std::vector ndoutputs; + ndoutputs.reserve(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + size_t eid = idx.entry_id(node_idx, j); + ndoutputs.emplace_back(arrays[eid]); + } + return ndoutputs; +} + +inline std::vector NodeReq(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector array_reqs) { + const nnvm::IndexedGraph::Node& node = idx[node_idx]; + const size_t num_outputs = node.source->num_outputs(); + std::vector req; + req.reserve(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + size_t eid = idx.entry_id(node_idx, j); + req.push_back(array_reqs[eid]); + } + return req; +} + +inline void InvokeOperator(const nnvm::IndexedGraph& idx, + const int node_idx, + const bool retain_graph, + const std::vector arrays, + Context ctx, + std::vector* p_states, + std::vector ndinputs, + std::vector ndoutputs, + std::vector *p_req, + std::vector *p_ref_count, + std::function invoke) { + static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); + static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); + std::vector& states = *p_states; + std::vector &req = *p_req; + std::vector &ref_count = *p_ref_count; + + const nnvm::IndexedGraph::Node& node = idx[node_idx]; + if (node.source->op() == bwd_cached_op) { + const auto& cached_op = dmlc::get(node.source->attrs.parsed); + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); + cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); + } else if (createop.count(node.source->op())) { + mxnet::ShapeVector arg_shapes; + nnvm::DTypeVector arg_dtypes; + arg_shapes.reserve(ndinputs.size()); + arg_dtypes.reserve(ndinputs.size()); + for (auto& ndinput : ndinputs) { + arg_shapes.emplace_back(ndinput->shape()); + arg_dtypes.emplace_back(ndinput->dtype()); + } + states[node_idx] = createop[node.source->op()](node.source->attrs, ctx, arg_shapes, arg_dtypes); + invoke(states[node_idx]); + } else if (is_layer_backward.get(node.source->op(), false)) { + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); + invoke(states[fwd_node_id]); + } else { + invoke(OpStatePtr()); + } + for (const auto& j : node.inputs) { + size_t eid = idx.entry_id(j); + --ref_count[eid]; + if (ref_count[eid] == 0) { + *arrays[eid] = NDArray(); + } + } + for (size_t j = 0; j < ndoutputs.size(); ++j) { + size_t eid = idx.entry_id(node_idx, j); + if (ref_count[eid] == 0) { + *arrays[eid] = NDArray(); + } + } +} + void RunGraph( const bool retain_graph, const nnvm::IndexedGraph& idx, @@ -31,88 +131,75 @@ void RunGraph( std::vector&& ref_count, std::vector *p_states, const DispatchModeVector &dispatch_modes, - bool recording) { - using namespace nnvm; - using namespace imperative; - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); - static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); - - const auto imp = Imperative::Get(); - - std::vector& states = *p_states; - - std::vector ndinputs, ndoutputs; - ShapeVector arg_shapes; - DTypeVector arg_dtypes; - std::vector req; - + bool recording, + mxnet::ShapeVector *shapes) { + CHECK(shapes == nullptr); for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; - if (node.source->op() == nullptr) continue; - auto num_outputs = node.source->num_outputs(); - ndinputs.clear(); - ndinputs.reserve(node.inputs.size()); - for (const auto& j : node.inputs) { - ndinputs.emplace_back(arrays[idx.entry_id(j)]); - CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index; - } - ndoutputs.clear(); - ndoutputs.reserve(num_outputs); - req.clear(); - req.reserve(num_outputs); - for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(i, j); - ndoutputs.emplace_back(arrays[eid]); - req.push_back(array_reqs[eid]); - CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none()); + if (node.source->op() == nullptr) { + continue; } - const Context& ctx = ndoutputs[0]->ctx(); - const DispatchMode dispatch_mode = dispatch_modes[i]; - if (node.source->op() == bwd_cached_op) { - const auto& cached_op = dmlc::get(node.source->attrs.parsed); - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); - } else if (createop.count(node.source->op())) { - arg_shapes.clear(); - arg_dtypes.clear(); - arg_shapes.reserve(ndinputs.size()); - arg_dtypes.reserve(ndinputs.size()); - for (auto& ndinput : ndinputs) { - arg_shapes.emplace_back(ndinput->shape()); - arg_dtypes.emplace_back(ndinput->dtype()); - } - states[i] = createop[node.source->op()]( - node.source->attrs, ctx, arg_shapes, arg_dtypes); - imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]); + std::vector ndinputs = NodeInputs(idx, i, arrays); + std::vector ndoutputs = NodeOutputs(idx, i, arrays); + std::vector req = NodeReq(idx, i, array_reqs); + Context ctx = ndoutputs[0]->ctx(); + auto invoke = [&](const OpStatePtr &state) { + const nnvm::IndexedGraph::Node& node = idx[i]; + DispatchMode dispatch_mode = dispatch_modes[i]; + Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, + req, dispatch_mode, state); if (recording) { - imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]); + Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state); } - } else if (is_layer_backward.get(node.source->op(), false)) { - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, - req, dispatch_mode, states[fwd_node_id]); - if (recording) { - imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]); + }; + InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, + &req, &ref_count, invoke); + } +} + +void NaiveRunGraph( + const bool retain_graph, + const Context& default_ctx, + const nnvm::IndexedGraph& idx, + const std::vector arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector &dispatch_modes, + bool recording, + mxnet::ShapeVector *shapes) { + for (size_t i = node_start; i < node_end; ++i) { + const nnvm::IndexedGraph::Node& node = idx[i]; + if (node.source->op() == nullptr) { + continue; + } + std::vector ndinputs = NodeInputs(idx, i, arrays); + std::vector ndoutputs = NodeOutputs(idx, i, arrays); + std::vector req; + Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx); + auto invoke = [&](const OpStatePtr &state) { + const nnvm::IndexedGraph::Node& node = idx[i]; + DispatchMode dispatch_mode = DispatchMode::kUndefined; + SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs, &dispatch_mode); + SetWriteInplaceReq(ndinputs, ndoutputs, &req); + Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, + req, dispatch_mode, state); + for (size_t j = 0; j < ndoutputs.size(); ++j) { + if (ndoutputs[j]->shape().ndim() == 0) { + ndoutputs[j]->WaitToRead(); + ndoutputs[j]->SetShapeFromChunk(); + } + size_t eid = idx.entry_id(i, j); + auto shape = ndoutputs[j]->shape(); + (*shapes)[eid] = shape; } - } else { - imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); if (recording) { - imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs); + Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state); } - } - - for (const auto& j : node.inputs) { - size_t eid = idx.entry_id(j); - --ref_count[eid]; - if (ref_count[eid] == 0) *arrays[eid] = NDArray(); - } - for (size_t j = 0; j < ndoutputs.size(); ++j) { - size_t eid = idx.entry_id(i, j); - if (ref_count[eid] == 0) *arrays[eid] = NDArray(); - } + }; + InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, + &req, &ref_count, invoke); } } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 5eecfe8c6f23..071f4fa9dd0b 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -1005,7 +1005,20 @@ void RunGraph(const bool retain_graph, std::vector&& ref_count, std::vector *p_states, const DispatchModeVector &dispatch_modes, - bool recording); + bool recording, + mxnet::ShapeVector *shapes = nullptr); + +void NaiveRunGraph(const bool retain_graph, + const Context& default_ctx, + const nnvm::IndexedGraph& idx, + const std::vector arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector &dispatch_modes, + bool recording, + mxnet::ShapeVector *shapes); } // namespace imperative } // namespace mxnet diff --git a/tests/python/unittest/test_dynamic_shape.py b/tests/python/unittest/test_dynamic_shape.py new file mode 100644 index 000000000000..1b043c73256d --- /dev/null +++ b/tests/python/unittest/test_dynamic_shape.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import mxnet as mx +from mxnet import gluon +from numpy.testing import assert_allclose, assert_array_equal +from mxnet.test_utils import * +from mxnet.base import _as_list +from mxnet.attribute import AttrScope +from common import with_seed + + +def test_dynamic_shape(): + + class _TestBlock(gluon.HybridBlock): + + def __init__(self): + super(_TestBlock, self).__init__() + + def hybrid_forward(self, F, data, index): + return F.contrib.boolean_mask(data, index) + + block = _TestBlock() + block.hybridize() + data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) + index = mx.nd.array([0, 1, 1]) + data.attach_grad() + with mx.autograd.record(): + result = block(data, index) + result.backward() + result_nd = np.array([[4, 5, 6], [7, 8, 9]]) + data_grad_nd = np.array([[0., 0., 0.], [1., 1., 1.], [1., 1., 1.]]) + assert_almost_equal(result.asnumpy(), result_nd) + assert_almost_equal(data.grad.asnumpy(), data_grad_nd) + + +if __name__ == '__main__': + import nose + nose.runmodule()