Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Better Exception Handling for Operators #9681

Merged
merged 22 commits into from
Feb 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,15 @@ class MXNET_API Engine {
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation
* \return The new operator allocated.
*/
virtual OprHandle NewOperator(AsyncFn fn,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr,
bool wait = false) = 0;
/*!
* \brief Delete the given operator.
* \param op The operator to delete.
Expand Down Expand Up @@ -176,13 +178,15 @@ class MXNET_API Engine {
* \param prop Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation
*/
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr,
bool wait = false) = 0;
/*!
* \brief Schedule the deletion of a variable.
*
Expand Down
6 changes: 4 additions & 2 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr,
bool wait = false) override {
NaiveOpr *opr = new NaiveOpr();
opr->fn = fn;
opr->const_vars = const_vars;
Expand Down Expand Up @@ -125,7 +126,8 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr,
bool wait = false) override {
CallbackOnComplete callback = CreateCallback(
NaiveEngine::OnComplete, nullptr);
this->req_completed_ = false;
Expand Down
39 changes: 29 additions & 10 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,15 @@ ThreadedOpr* ThreadedEngine::NewOperator(
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
const char* opr_name) {
const char* opr_name,
bool wait) {
auto ret = ThreadedOpr::New();
ret->opr_name = opr_name;
ret->fn = std::move(fn);
ret->prop = prop;
ret->const_vars.resize(const_vars.size());
ret->mutable_vars.resize(mutable_vars.size());
ret->wait = wait;
std::transform(const_vars.begin(), const_vars.end(),
ret->const_vars.begin(), ThreadedVar::CastFromBase);
std::transform(mutable_vars.begin(), mutable_vars.end(),
Expand Down Expand Up @@ -305,9 +307,10 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
int priority,
const char* opr_name) {
const char* opr_name,
bool wait) {
BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
opr->temporary = true;
#if MXNET_USE_PROFILER
Profiler *profiler = Profiler::Get();
Expand Down Expand Up @@ -356,7 +359,10 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
void ThreadedEngine::WaitForVar(VarHandle var) {
BulkFlush();
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
if (threaded_var->ready_to_read()) return;
if (threaded_var->ready_to_read()) {
ThrowException(threaded_var);
return;
}
if (engine_info_) {
LOG(INFO) << "Wait for " << threaded_var;
debug_wait_var_ = threaded_var;
Expand All @@ -376,13 +382,15 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
}
on_complete();
}, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
PROFILER_MESSAGE("WaitForVar"));
PROFILER_MESSAGE("WaitForVar"), true);
{
std::unique_lock<std::mutex> lock{finished_m_};
finished_cv_.wait(lock, [this, &done]() {
return done.load() || kill_.load();
});
}

ThrowException(threaded_var);
}

void ThreadedEngine::WaitForAll() {
Expand All @@ -397,18 +405,20 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
bool is_temporary_opr = threaded_opr->temporary;
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
i->CompleteReadDependency([this](OprBlock* opr) {
this->PushToExecute(opr, false);
});
i->CompleteReadDependency(
[this](OprBlock* opr) { this->PushToExecute(opr, false); });
}
// Mark complete for write variables.
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
i->var_exception = threaded_opr->opr_exception;
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (debug_info) {
LOG(INFO) << "Complete write dep for " << i;
}
const bool to_delete = i->CompleteWriteDependency(
[this, debug_info](OprBlock* opr) {
const bool to_delete =
i->CompleteWriteDependency([this, debug_info](OprBlock* opr) {
if (debug_info) {
LOG(INFO) << "PushToExecute " << opr;
debug_push_opr_ = opr;
Expand Down Expand Up @@ -443,6 +453,15 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
}

/*!
 * \brief Rethrow the exception recorded on a variable, if there is one.
 *
 * Does nothing when the variable has no exception slot or the slot is empty.
 * Otherwise the stored exception_ptr is copied out, the slot is reset to
 * nullptr, and the copy is rethrown to the caller.
 *
 * \param threaded_var the variable whose recorded exception should be raised.
 */
inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
  // Guard: no slot allocated, or the slot currently holds no exception.
  if (!threaded_var->var_exception || !*threaded_var->var_exception) {
    return;
  }
  // Take a copy first; the slot is cleared before the rethrow so that a
  // later call on the same slot does not raise the same exception again.
  std::exception_ptr pending = *threaded_var->var_exception;
  *threaded_var->var_exception = nullptr;
  std::rethrow_exception(pending);
}

void ThreadedEngine::OnCompleteStatic(
Engine *engine, void *opr_block_) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
Expand Down
94 changes: 72 additions & 22 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::shared_ptr<std::exception_ptr> var_exception;

private:
// TODO(hotpxl) change this to spinlock for faster runtime
Expand Down Expand Up @@ -236,6 +238,10 @@ struct ThreadedOpr final : public Opr,
* that can be deleted right after the operation completed.
*/
bool temporary{false};
/*!
* \brief Whether this is a WaitForVar operation
*/
bool wait{false};
/*!
* \brief Cast a Opr pointer to ThreadedOpr pointer
* \param ptr pointer from base.
Expand All @@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
std::shared_ptr<std::exception_ptr> opr_exception;
}; // struct ThreadedOpr

/*!
Expand All @@ -265,15 +273,17 @@ class ThreadedEngine : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr,
bool wait = false) override;
void DeleteOperator(OprHandle op) override;
void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr,
bool wait = false) override;
void PushSync(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
Expand Down Expand Up @@ -321,50 +331,63 @@ class ThreadedEngine : public Engine {
* \param run_ctx runtime context used to execute the function.
* \param opr_block the opr_block to be executed and deleted.
*/
void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) {
ThreadedOpr* threaded_opr = opr_block->opr;
#if MXNET_USE_PROFILER
if (opr_block->profiling && threaded_opr->opr_name) {
const Context& ctx = opr_block->ctx;
opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
opr_block->opr_stat =
Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id());
opr_block->opr_stat->thread_id = id;
strncpy(opr_block->opr_stat->opr_name,
threaded_opr->opr_name,
sizeof(opr_block->opr_stat->opr_name) - 1);
strncpy(opr_block->opr_stat->opr_name, threaded_opr->opr_name,
sizeof(opr_block->opr_stat->opr_name) - 1);
// record operator start timestamp
SetOprStart(opr_block->opr_stat);
}
#endif
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnCompleteStatic, opr_block);
bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
CallbackOnComplete callback =
this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
if (debug_info) {
LOG(INFO) << "ExecuteOprBlock " << opr_block
<< "shutdown_phase=" << shutdown_phase_;
}
if (!shutdown_phase_) {
try {
OnStart(threaded_opr);
if (debug_info) {
LOG(INFO) << "ExecuteOprFn ";
}
threaded_opr->fn(run_ctx, callback);
try {
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
threaded_opr->wait) {
threaded_opr->fn(run_ctx, callback);
} else {
callback();
}
} catch (dmlc::Error& e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why catch only dmlc::Error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I catch only dmlc::Error is that the guards in the c_api (API_BEGIN and API_END/API_END_HANDLE_ERROR) currently catch only dmlc::Error and propagate it to the frontend.

Copy link
Member Author

@anirudh2290 anirudh2290 Feb 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I catch dmlc::Error in the inner block and use std::exception in the outer block to catch other stdlib exceptions. Currently, it catches only dmlc::Error; for other exceptions (std::exception) the process will be terminated. I will open another PR to handle std::exception and to change the c_api guards and frontend code.

threaded_opr->opr_exception =
std::make_shared<std::exception_ptr>(std::current_exception());
callback();
}
if (debug_info) {
LOG(INFO) << "Fin ExecuteOprFn ";
}
} catch(dmlc::Error &e) {
} catch (std::exception& e) {
std::string what = e.what();
if (what.find("driver shutting down") == std::string::npos &&
!shutdown_phase_) {
LOG(FATAL) << e.what() << "\n" <<
"A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
LOG(FATAL)
<< e.what() << "\n"
<< "A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
}
}
} else {
Expand Down Expand Up @@ -414,7 +437,34 @@ class ThreadedEngine : public Engine {
* On operation completion, this will trigger subsequent operations.
*/
inline void OnComplete(ThreadedOpr* threaded_opr);
// callback to the threaded engine
/*!
* \brief rethrow caught exception in WaitForVar
* \param threaded_var the var that we are waiting to read
*/
inline void ThrowException(ThreadedVar* threaded_var);
/*!
* \brief Mark exceptions before operation execution.
*
* Will mark the operator as a failure and associate exception_ptr
* if any of the read dependencies have exception associated.
*/
inline void OnStart(ThreadedOpr* threaded_opr) {
  // True when the shared slot exists and currently holds an exception.
  const auto holds_exception =
      [](const std::shared_ptr<std::exception_ptr>& slot) -> bool {
        return slot && *slot;
      };

  // Read dependencies are inspected first: the first one carrying an
  // exception hands its slot to the operator.
  for (auto&& read_var : threaded_opr->const_vars) {
    if (holds_exception(read_var->var_exception)) {
      threaded_opr->opr_exception = read_var->var_exception;
      break;
    }
  }

  // The operator is already marked as failed; the write dependencies do not
  // need to be examined.
  if (holds_exception(threaded_opr->opr_exception)) {
    return;
  }

  // Otherwise fall back to the first write dependency carrying an exception.
  for (auto&& write_var : threaded_opr->mutable_vars) {
    if (holds_exception(write_var->var_exception)) {
      threaded_opr->opr_exception = write_var->var_exception;
      break;
    }
  }
}

static void OnCompleteStatic(Engine *engine, void *threaded_opr);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
Expand Down
4 changes: 2 additions & 2 deletions src/storage/cpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
void* ptr;
#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) throw std::bad_alloc();
if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory";
#else
int ret = posix_memalign(&ptr, alignment_, size);
if (ret != 0) throw std::bad_alloc();
if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory";
#endif
return ptr;
}
Expand Down
2 changes: 1 addition & 1 deletion src/storage/gpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
throw std::bad_alloc();
LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
Expand Down
1 change: 1 addition & 0 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from test_random import *
from test_gluon import *
from test_loss import *
from test_exc_handling import *
#from test_rnn import *
from test_gluon_rnn import *
from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
Expand Down
Loading