Skip to content

Commit

Permalink
Improve bulking in Gluon (apache#13890)
Browse files Browse the repository at this point in the history
* Improve bulking in Gluon

* Trigger CI
  • Loading branch information
ptrendx authored and stephenrawls committed Feb 16, 2019
1 parent a433e8f commit 3cb98a7
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 17 deletions.
4 changes: 4 additions & 0 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ struct RunContext {
* \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
*/
void *stream;
/*!
* \brief indicator of whether this execution is run in bulk mode
*/
bool is_bulk;
/*!
* \brief get mshadow stream from Context
* \return the mshadow stream
Expand Down
10 changes: 6 additions & 4 deletions src/engine/stream_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr};
ret = RunContext{ctx, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
Expand All @@ -85,7 +85,9 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
use_counter = counter;
counter = (counter + 1) % kStreams;
}
ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter)};
ret = RunContext{ctx,
gpu_streams_.at(ctx.dev_id).at(use_counter),
false};
break;
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
Expand All @@ -103,7 +105,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr};
ret = RunContext{ctx, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
Expand All @@ -114,7 +116,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
}
}
ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id)};
ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), false};
break;
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
Expand Down
6 changes: 6 additions & 0 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,13 @@ class ThreadedEngine : public Engine {
DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
SyncFn fn = std::move(bulk_status.fn);
this->PushAsync([fn](RunContext ctx, CallbackOnComplete on_complete) {
ctx.is_bulk = true;
fn(ctx);
ctx.is_bulk = false;
bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask;
if (is_gpu) {
ctx.get_stream<gpu>()->Wait();
}
on_complete();
}, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars,
FnProperty::kNormal, 0, "ImperativeBulk");
Expand Down
5 changes: 1 addition & 4 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,14 +583,11 @@ void CachedOp::StaticInitExec(
}

size_t bulk_size = idx.num_nodes();
std::unordered_set<uint32_t> excludes;
if (recording || keep_fwd) {
bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
}

CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size,
state.execs, skip_plus_node, &state.opr_segs);
}

Expand Down
10 changes: 1 addition & 9 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ inline void PushFCompute(const FCompute& fn,
fn(attrs, opctx, input_blobs, tmp_req, output_blobs);
// post-fcompute fallback, cast to original storage type
CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
if (is_gpu) {
if (is_gpu && !rctx.is_bulk) {
rctx.get_stream<gpu>()->Wait();
}
}, ctx, read_vars, write_vars, FnProperty::kNormal,
Expand Down Expand Up @@ -928,7 +928,6 @@ inline void CreateEngineOpSeg(
const size_t start_nid,
const size_t end_nid,
const size_t bulk_size,
const std::unordered_set<uint32_t>& excludes,
const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
const std::vector<int> skip_plus_node,
std::vector<EngineOprSeg> *opr_segs) {
Expand All @@ -944,13 +943,6 @@ inline void CreateEngineOpSeg(

// Stop at async nodes and invalid node (due to input/output is not allocated)
bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
}
auto num_outputs = node.source->num_outputs();
for (size_t i = 0; i < num_outputs && !stop; ++i) {
if (excludes.count(idx.entry_id(nid, i))) stop = true;
}

// Create opr segment for previous nodes.
if (stop && nid > seg_start) {
Expand Down

0 comments on commit 3cb98a7

Please sign in to comment.