Skip to content
This repository has been archived by the owner on Feb 9, 2021. It is now read-only.

Commit

Permalink
Support populating errors back to MXNet engine in callback (apache#13922
Browse files Browse the repository at this point in the history
)

* add an optional error_msg in engine on_complete callbcak

* use dmlc::Error struct to make error population extendable
  • Loading branch information
yuxihu authored and Gordon Reid committed Jan 27, 2019
1 parent e6e62f9 commit e2ff96d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
8 changes: 4 additions & 4 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ class CallbackOnComplete {
public:
// use implicit copy and assign
/*! \brief involve the callback */
inline void operator()() const {
(*callback_)(engine_, param_);
inline void operator()(const dmlc::Error* error = nullptr) const {
(*callback_)(engine_, param_, error);
}

private:
/*! \brief engine can see content of callback */
friend class ::mxnet::Engine;
/*! \brief the real callback */
void (*callback_)(Engine *, void *);
void (*callback_)(Engine *, void *, const dmlc::Error *);
/*! \brief the engine class passed to callback */
Engine* engine_;
/*! \brief the parameter set on callback */
Expand Down Expand Up @@ -275,7 +275,7 @@ class MXNET_API Engine {
* \param param the paramter passed to callback.
*/
inline CallbackOnComplete CreateCallback(
void (*callback)(Engine *, void *), void *param) {
void (*callback)(Engine *, void *, const dmlc::Error *), void *param) {
CallbackOnComplete ret;
ret.callback_ = callback;
ret.engine_ = this;
Expand Down
3 changes: 2 additions & 1 deletion src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ class NaiveEngine final : public Engine {

private:
// callback to oncomplete
static void OnComplete(Engine *engine, void *param) {
static void OnComplete(Engine *engine, void *param,
const dmlc::Error* error) {
static_cast<NaiveEngine*>(engine)->req_completed_ = true;
}
// whether action is completed
Expand Down
8 changes: 6 additions & 2 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,14 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
return;
}

void ThreadedEngine::OnCompleteStatic(
Engine *engine, void *opr_block_) {
void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
const dmlc::Error* error) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
ThreadedOpr *threaded_opr = opr_block->opr;
if (error != nullptr) {
auto ex_p = std::make_exception_ptr(*error);
threaded_opr->opr_exception = std::make_shared<std::exception_ptr>(ex_p);
}
if (opr_block->profiling && threaded_opr->opr_name) {
// record operator end timestamp
opr_block->opr_profile->stop();
Expand Down
3 changes: 2 additions & 1 deletion src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ class ThreadedEngine : public Engine {
}
}

static void OnCompleteStatic(Engine *engine, void *threaded_opr);
static void OnCompleteStatic(Engine *engine, void *threaded_opr,
const dmlc::Error* error);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
Expand Down

0 comments on commit e2ff96d

Please sign in to comment.