Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #80 from tqchen/master
Browse files Browse the repository at this point in the history
fix threaded engine per device
  • Loading branch information
tqchen committed Sep 15, 2015
2 parents dcacd67 + 7c6eb9a commit f4207b5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
6 changes: 0 additions & 6 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,6 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
API_END();
}

// C API entry point: invokes WaitForAll() on the engine instance returned by
// Engine::Get(). Presumably this blocks until all pending engine work has
// finished -- TODO confirm against the Engine interface.
// NOTE: there is no explicit return statement; the return value is added in
// API_END.
int MXEngineWaitAll() {
API_BEGIN();
Engine::Get()->WaitForAll();
API_END();
}

// NOTE: return value is added in API_END
int MXNDArrayCreateNone(NDArrayHandle *out) {
API_BEGIN();
Expand Down
8 changes: 6 additions & 2 deletions src/engine/threaded_engine_perdevice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@ class ThreadedEnginePerDevice : public ThreadedEngine {

protected:
void PushToExecute(OprBlock *opr_block, bool pusher_thread) override {
const Context& ctx = opr_block->ctx;
if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) {
CHECK_EQ(opr_block->ctx.dev_mask, cpu::kDevMask);
if (ctx.dev_mask == gpu::kDevMask) {
#if MXNET_USE_CUDA
mshadow::SetDevice<gpu>(ctx.dev_id);
#endif
}
RunContext run_ctx;
run_ctx.stream = nullptr;
this->ExecuteOprBlock(run_ctx, opr_block);
} else {
const Context& ctx = opr_block->ctx;
if (ctx.dev_mask == cpu::kDevMask) {
cpu_worker_.task_queue.Push(opr_block);
} else {
Expand Down

0 comments on commit f4207b5

Please sign in to comment.