This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Bulked op segments to allow Variable nodes #14200

Merged: 15 commits, Mar 7, 2019
Changes from 10 commits
8 changes: 7 additions & 1 deletion docs/faq/env_var.md
@@ -115,7 +115,13 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 - If set to `1`, during training MXNet executes the computation graph as several subgraphs in bulk mode.
 * MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN
   - Values: Int ```(default=15)```
-  - The maximum number of nodes in the subgraph executed in bulk during training(not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training.
+  - The maximum number of nodes in the subgraph executed in bulk during training (not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training.
+* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD
+  - Values: Int ```(default=<value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN>)```
+  - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the forward pass.
+* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD
+  - Values: Int ```(default=<value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN>)```
+  - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass.

 ## Control the Data Communication
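These settings are read when MXNet initializes (the new test below spawns fresh processes for exactly this reason), so the variables must be set before the library is imported. A minimal sketch of an override from Python, with illustrative values only:

```python
import os

# Set before `import mxnet`; the values here are illustrative, not recommendations.
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = '15'  # forward-pass segment limit
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = '15'  # backward-pass segment limit

import mxnet as mx  # noqa: E402  (importing after env setup is intentional)
```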
23 changes: 20 additions & 3 deletions include/mxnet/imperative.h
@@ -129,14 +129,31 @@ class Imperative {
       bool create_graph);
   /*! \return AutogradRuntime singleton */
   static Imperative* Get();
+  /*! \brief Should op execution bulking be employed during inference. */
+  static bool PreferBulkExecInference() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
+  }
+  /*! \brief Should op execution bulking be employed during training. */
+  static bool PreferBulkExecTrain() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", true);
+  }
+  /*! \brief The max number of op nodes in a bulk during forward pass of training. */
+  static int BulkExecMaxNodeTrainFwd() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD",
+                        dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
+  }
+  /*! \brief The max number of op nodes in a bulk during backward pass of training. */
+  static int BulkExecMaxNodeTrainBwd() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD",
+                        dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
+  }

  private:
   friend class NDArray;
   /*! \brief make constructor protected. */
   Imperative() {
-    if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
-      backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
-    }
+    if (PreferBulkExecTrain())
+      backward_bulk_size_ = BulkExecMaxNodeTrainBwd();
   }
   /*! \brief find the input/output ndarrays that are needed for backward */
   void GetBackwardDependency(
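The two new per-pass getters fall back to the older combined variable, which in turn falls back to 15. A sketch of that resolution order in Python (hypothetical helper, mirroring the nested dmlc::GetEnv calls above):

```python
import os

def bulk_exec_max_node_train(direction):
    """Resolve the per-direction bulking limit like the header above (a sketch).

    direction is 'FWD' or 'BWD'. The per-direction variable wins, then the
    legacy MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN, then the built-in default 15.
    """
    combined = os.environ.get('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN', '15')
    return int(os.environ.get('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_' + direction, combined))

# In a clean environment both directions resolve to 15:
print(bulk_exec_max_node_train('FWD'), bulk_exec_max_node_train('BWD'))
```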
79 changes: 35 additions & 44 deletions src/executor/graph_executor.cc
@@ -1191,16 +1191,17 @@ void GraphExecutor::InitOpSegs() {
   cached_seg_opr_.resize(total_num_nodes, p);
   if (monitor_callback_) return;

+  // Symbolic bulking is set by the same environment variables as Imperative bulking.
   // Generate segments based on the graph structure
-  bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
+  bool prefer_bulk_exec_inference = Imperative::PreferBulkExecInference();
   // Whether to perform bulk exec for training
   const profiler::Profiler *prof = profiler::Profiler::Get();
-  bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)
-                          && (!prof || !prof->AggregateEnabled());
+  bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain()
+                                && (!prof || !prof->AggregateEnabled());

   bool is_training = num_forward_nodes_ != total_num_nodes;

-  if (prefer_bulk_exec && is_training) {
+  if (prefer_bulk_exec_train && is_training) {
     this->BulkTrainingOpSegs(total_num_nodes);
   }

@@ -1211,63 +1212,53 @@

 void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
-  // The maximum number of node in a segment executed in bulk
-  size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
+  // The maximum number of nodes in a segment executed in bulk (excluding variables) in fwd pass.
+  size_t segment_num_nodes_threshold_fwd = Imperative::BulkExecMaxNodeTrainFwd();
+  // The maximum number of nodes in a segment executed in bulk (excluding variables) in bwd pass.
+  size_t segment_num_nodes_threshold_bwd = Imperative::BulkExecMaxNodeTrainBwd();

   // create forward segments for training
   size_t topo_start = 0;
+  size_t segment_node_count = 0;
   for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
     auto &node = graph_.indexed_graph()[nid].source;
     auto &op_node = op_nodes_[nid];
-    // check if the segment relies on external input, or exceeds maxinum number of node,
-    // or requires async ops
-    if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
-        op_node.exec->exec_type() != ExecType::kSync) {
-      // create a new segment for the previous nodes if the current one cannot be bulked
-      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+    // Variables, such as learned weights, are ignored in the segment_node_count
+    bool ignore_node = node->is_variable();
+    if (!ignore_node)
+      segment_node_count++;
+    bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync;
+    // check if we need to create the segment based on properties of this node
+    if (!can_bulk || nid == num_forward_nodes_ - 1 ||
+        segment_node_count >= segment_num_nodes_threshold_fwd) {
+      // Create a new segment for the previous nodes - also include this node if it's bulkable
+      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid);
       topo_start = nid + 1;
+      segment_node_count = 0;
     }
   }
-  // the last segment
-  if (topo_start != num_forward_nodes_) {
-    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
-  }

   // create backward segments for training
   // get all gradient variables
   std::unordered_set<engine::VarHandle> grad_vars;
   for (auto &kv : grad_store_) {
     grad_vars.insert(kv.second.var());
   }
   auto &idx = graph_.indexed_graph();
   topo_start = num_forward_nodes_;
+  segment_node_count = 0;
   for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
+    auto &node = graph_.indexed_graph()[nid].source;
     auto &op_node = op_nodes_[nid];
-    if (op_node.skip_exec_node || op_node.exec == nullptr) {
-      continue;
-    }
-    if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
-        op_node.exec->exec_type() != ExecType::kSync) {
-      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+    // Variables, such as learned weights, are ignored in the segment_node_count,
+    // as are nodes that are not executed for various reasons.
+    bool ignore_node = node->is_variable() || op_node.skip_exec_node || op_node.exec == nullptr;
+    if (!ignore_node)
+      segment_node_count++;
+    bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync;
+    // check if we need to create the segment based on properties of this node
+    if (!can_bulk || nid == total_num_nodes - 1 ||
+        segment_node_count >= segment_num_nodes_threshold_bwd) {
+      // Create a new segment for the previous nodes - also include this node if it's bulkable
+      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid);
       topo_start = nid + 1;
-    } else {
-      // If it produces output gradient, don't include it in the segment
-      bool output_gradient = false;
-      for (auto &out_arr : op_node.exec->out_array) {
-        if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
-          output_gradient = true;
-        }
-      }
-      if (output_gradient) {
-        cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
-        topo_start = nid + 1;
-      }
+      segment_node_count = 0;
     }
   }
-  // last segment for backward
-  if (topo_start < total_num_nodes) {
-    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
-  }
 }

 void GraphExecutor::BulkInferenceOpSegs() {
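The behavioral change the PR title describes is easiest to see outside C++. Below is a minimal Python sketch of the forward-pass loop above (the `nodes` list of flags is hypothetical, not an MXNet API): variables no longer force a segment break, they are simply excluded from the per-segment op count, so a segment can now span Variable nodes.

```python
def make_segments(nodes, max_ops_per_segment):
    """Sketch of BulkTrainingOpSegs' forward loop.

    `nodes` is a list of (is_variable, is_sync) flags standing in for graph
    nodes; returns half-open (start, end) index ranges, one per segment.
    """
    segments = []
    topo_start = 0
    segment_node_count = 0
    for nid, (is_variable, is_sync) in enumerate(nodes):
        ignore_node = is_variable          # variables don't count toward the limit
        if not ignore_node:
            segment_node_count += 1
        can_bulk = ignore_node or is_sync  # async ops still break the segment
        if (not can_bulk or nid == len(nodes) - 1 or
                segment_node_count >= max_ops_per_segment):
            # close the segment; include the current node only if it is bulkable
            end = nid + 1 if can_bulk else nid
            if end > topo_start:           # simplification: skip empty segments
                segments.append((topo_start, end))
            topo_start = nid + 1
            segment_node_count = 0
    return segments

# e.g. a sync op, a variable, then three sync ops, with a limit of 2 ops:
print(make_segments([(False, True), (True, True), (False, True),
                     (False, True), (False, True)], 2))
# -> [(0, 3), (3, 5)]: the variable rides along inside the first segment
```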
11 changes: 10 additions & 1 deletion src/imperative/cached_op.cc
@@ -590,9 +590,18 @@ void CachedOp::StaticInitExec(
     SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
   }

+  // Init bulk_size for Inference mode with bulking enabled (= entire forward graph).
   size_t bulk_size = idx.num_nodes();
   if (recording || keep_fwd) {
-    bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
+    // Training mode
+    if (!Imperative::PreferBulkExecTrain())
+      bulk_size = 0;
+    else
+      bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
+  } else {
+    // Inference mode
+    if (!Imperative::PreferBulkExecInference())
+      bulk_size = 0;
   }

   CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size,
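The branching above reduces to a small decision table. A sketch in Python (the `choose_bulk_size` helper is hypothetical; its parameters stand in for the C++ config fields and env-var checks):

```python
def choose_bulk_size(num_nodes, recording, keep_fwd,
                     fwd_bulk_size, bwd_bulk_size,
                     bulk_train_enabled, bulk_inference_enabled):
    """Mirror of the bulk_size selection in CachedOp::StaticInitExec (a sketch)."""
    if recording or keep_fwd:
        # Training: respect MXNET_EXEC_BULK_EXEC_TRAIN, then the per-pass sizes.
        if not bulk_train_enabled:
            return 0
        return bwd_bulk_size if keep_fwd else fwd_bulk_size
    # Inference: bulk the entire graph unless MXNET_EXEC_BULK_EXEC_INFERENCE says no.
    return num_nodes if bulk_inference_enabled else 0

# Inference with bulking on bulks everything; training uses the configured sizes:
assert choose_bulk_size(100, False, False, 15, 10, True, True) == 100
assert choose_bulk_size(100, True, False, 15, 10, True, True) == 15
assert choose_bulk_size(100, True, True, 15, 10, False, True) == 0
```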
4 changes: 2 additions & 2 deletions src/imperative/cached_op.h
@@ -53,10 +53,10 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
     .set_default(2)
     .describe("Maximum number of operators that can be inlined.");
     DMLC_DECLARE_FIELD(forward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .set_default(Imperative::BulkExecMaxNodeTrainFwd())
     .describe("Segment size of bulk execution during forward pass.");
     DMLC_DECLARE_FIELD(backward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .set_default(Imperative::BulkExecMaxNodeTrainBwd())
     .describe("Segment size of bulk execution during backward pass.");
     DMLC_DECLARE_FIELD(data_indices)
     .set_default(nnvm::Tuple<uint32_t>())

Review discussion on the forward_bulk_size default:

Member: I noticed that env vars like MXNET_EXEC_BULK_EXEC_TRAIN/MXNET_EXEC_BULK_EXEC_INFER=0 are not respected by the cached_op. Would you have time to kindly fix it for cached op?
https://github.com/apache/incubator-mxnet/blob/54fd288c7a4bf59d37f793c26ef9a98ed40b0c40/src/imperative/cached_op.cc#L593-L596

Contributor (author): Please review the latest commits, which address this valid concern about consistency. I also consolidated all op-bulking env var references to a central place and added timing-based tests for the perf impact of bulking. I'm happy with the PR now (assuming it passes CI). Anyone else you want to pull into the review @eric-haibin-lin?
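Because the defaults now route through Imperative, env vars and per-op configs stay consistent, while individual CachedOps can still override them. A hedged usage sketch, assuming the Gluon front end of this era forwards extra hybridize kwargs to CachedOpConfig as flags (worth verifying against your MXNet version):

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(128, activation='relu'))
    net.add(nn.Dense(10))
net.initialize()

# Assumption: extra hybridize kwargs become CachedOpConfig flags, so these would
# override the env-var-derived forward/backward_bulk_size defaults for this net only.
net.hybridize(static_alloc=True, static_shape=True,
              forward_bulk_size=8, backward_bulk_size=8)
out = net(mx.nd.ones((1, 16)))
```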
78 changes: 78 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
@@ -38,6 +38,7 @@
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied
+from common import run_in_spawned_process
 from test_gluon import *
 from test_loss import *
 from test_gluon_rnn import *
@@ -408,6 +409,83 @@ def tensor_size(big_tensor_bytes):
     # Evaluate model
     net(data_in).asnumpy()

+# isolated execution bulking test function to be invoked with different env var settings
+def _test_bulking_in_process(seed, time_per_iteration):
+    # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused.
+    class Flip(gluon.HybridBlock):
+        def __init__(self, **kwargs):
+            super(Flip, self).__init__(**kwargs)
+
+        def hybrid_forward(self, F, x):
+            return F.flip(x, axis=0)
+
+    def get_net(num_ops):
+        net = nn.HybridSequential()
+        with net.name_scope():
+            for _ in range(num_ops):
+                net.add(Flip())
+        return net
+
+    data_shape = (10,)
+    num_ops = 1000
+    num_iterations = 20
+
+    # build model
+    x = mx.ndarray.zeros(data_shape)
+    x.attach_grad()
+    dy = mx.ndarray.ones(data_shape)
+    net = get_net(num_ops)
+    net.hybridize(static_alloc=True, static_shape=True)
+
+    # time a number of forward() and backward() executions after some warm-up iterations
+    warmups = 1
+    for i in range(num_iterations + warmups):
+        with autograd.record():
+            if i == warmups:
+                start = time.time()
+            y = net(x)
+        y.backward(dy)
+        x.grad.wait_to_read()
+
+    time_per_iteration.value = (time.time() - start) / num_iterations
+
+@with_seed()
+def test_bulking():
+    # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training)
+    test_cases = [(0, 0, True), (1, 1, True), (15, 15, False),
+                  (15, 0, True), (0, 15, True), (15, 15, True)]
+    times = {}
+    times_str = ''
+    for seg_sizes in test_cases:
+        # Create shared variable to return measured time from test process
+        time_per_iteration = mp.Manager().Value('d', 0.0)
+        if not run_in_spawned_process(_test_bulking_in_process,
+                                      {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0],
+                                       'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1],
+                                       'MXNET_EXEC_BULK_EXEC_TRAIN': seg_sizes[2]},
+                                      time_per_iteration):
+            # skip test since the python version can't run it properly. Warning msg was logged.
+            return
+        times[seg_sizes] = time_per_iteration.value
+        times_str += \
+            '\n    runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format(
+                seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes])
+
+    fastest_non_bulked_time = min(times[(0, 0, True)], times[(1, 1, True)], times[(15, 15, False)])
+    slowest_half_bulked_time = max(times[(0, 15, True)], times[(15, 0, True)])
+    fastest_half_bulked_time = min(times[(0, 15, True)], times[(15, 0, True)])
+    fully_bulked_time = times[(15, 15, True)]
+
+    print(times_str)
+    # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same,
+    # slower than both half-bulked times[0,15,True] and times[15,0,True]
+    assert slowest_half_bulked_time < fastest_non_bulked_time, \
+        'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \
+        .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str)
+    # The fully bulked times[15,15,True] should be faster than both half-bulked runs
+    assert fully_bulked_time < fastest_half_bulked_time, \
+        'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \
+        .format(fully_bulked_time - fastest_half_bulked_time, times_str)
+

 if __name__ == '__main__':
     import nose