Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Move Async engine tag to MXNET_ENGINE_TYPE
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Sep 8, 2021
1 parent 843483c commit ee5a6bb
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ inline Engine* CreateEngine() {
if (type == nullptr) type = "ThreadedEnginePerDevice";
std::string stype = type;

// The async tag is used later to determine if we use the GPU dependency engine
std::string async_engine_tag = "Async";
auto tag_pos = stype.find(async_engine_tag);
if (tag_pos != std::string::npos
&& tag_pos + async_engine_tag.length() == stype.length()) {
stype = stype.substr(0, tag_pos);
}

Engine *ret = nullptr;
#if MXNET_PREDICT_ONLY == 0
if (stype == "NaiveEngine") {
Expand Down
13 changes: 10 additions & 3 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,16 @@ static inline void AddEventHelper(
}
}

/*!
 * \brief Tell whether the engine type requested through the environment
 *        carries the "Async" tag.
 *
 * Reads the MXNET_ENGINE_TYPE environment variable (an unset variable is
 * treated as the empty string) and reports whether its value ends with the
 * literal tag "Async", e.g. "ThreadedEnginePerDeviceAsync".
 *
 * The tag is only honoured as a trailing suffix. Engine creation strips the
 * tag from the type name only when it sits at the very end of the string, so
 * matching it anywhere else here would make the two places disagree about
 * whether a given MXNET_ENGINE_TYPE value is an async one.
 *
 * The environment is re-read on every call; the callers in this file store
 * the result in a function-local static.
 *
 * \return true if MXNET_ENGINE_TYPE ends with "Async", false otherwise.
 */
static inline bool IsEngineAsync() {
  const std::string type = dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string(""));
  const std::string async_engine_tag("Async");
  // A value shorter than the tag cannot end with it; this guard also keeps
  // the offset computed below from underflowing.
  if (type.length() < async_engine_tag.length()) {
    return false;
  }
  const auto tag_pos = type.length() - async_engine_tag.length();
  return type.compare(tag_pos, async_engine_tag.length(), async_engine_tag) == 0;
}

void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block,
const dmlc::Error* error) {
static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false);
static bool use_new_dep_engine = IsEngineAsync();
if (!use_new_dep_engine) {
return;
}
Expand Down Expand Up @@ -619,7 +626,7 @@ void ThreadedEngine::OnStartCPU(Engine *engine, void *opr_block,

void ThreadedEngine::OnStartGPU(Engine *engine, void *sync_info,
const dmlc::Error* error) {
static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false);
static bool use_new_dep_engine = IsEngineAsync();
if (!use_new_dep_engine) {
return;
}
Expand Down Expand Up @@ -697,7 +704,7 @@ void ThreadedEngine::OnCompleteGPU(Engine *engine, void *sync_info,
CHECK(info->stream != nullptr);

auto *worker_stream = reinterpret_cast<mshadow::Stream<gpu> *>(info->stream);
static bool use_new_dep_engine = dmlc::GetEnv("MXNET_ASYNC_GPU_ENGINE", false);
static bool use_new_dep_engine = IsEngineAsync();

if (!use_new_dep_engine) {
worker_stream->Wait();
Expand Down
3 changes: 1 addition & 2 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1189,10 +1189,9 @@ void SetValueOp(const real_t& rhs, NDArray* out) {
} else {
ndarray::Eval(ctx.get_stream<gpu>(), rhs, ret);
}
// Wait GPU kernel to complete
ctx.get_stream<gpu>()->Wait();
break;
}
#endif
default:
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
Expand Down

0 comments on commit ee5a6bb

Please sign in to comment.