diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index e02b995d6857..408a70a5feed 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -74,15 +74,15 @@ class CallbackOnComplete { public: // use implicit copy and assign /*! \brief involve the callback */ - inline void operator()() const { - (*callback_)(engine_, param_); + inline void operator()(const dmlc::Error* error = nullptr) const { + (*callback_)(engine_, param_, error); } private: /*! \brief engine can see content of callback */ friend class ::mxnet::Engine; /*! \brief the real callback */ - void (*callback_)(Engine *, void *); + void (*callback_)(Engine *, void *, const dmlc::Error *); /*! \brief the engine class passed to callback */ Engine* engine_; /*! \brief the parameter set on callback */ @@ -275,7 +275,7 @@ class MXNET_API Engine { * \param param the paramter passed to callback. */ inline CallbackOnComplete CreateCallback( - void (*callback)(Engine *, void *), void *param) { + void (*callback)(Engine *, void *, const dmlc::Error *), void *param) { CallbackOnComplete ret; ret.callback_ = callback; ret.engine_ = this; diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index daff53066949..05b72d2a6fde 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -208,7 +208,8 @@ class NaiveEngine final : public Engine { private: // callback to oncomplete - static void OnComplete(Engine *engine, void *param) { + static void OnComplete(Engine *engine, void *param, + const dmlc::Error* error) { static_cast(engine)->req_completed_ = true; } // whether action is completed diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 3a7587fef13f..6a6004011db1 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -478,10 +478,14 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) { return; } -void ThreadedEngine::OnCompleteStatic( - Engine *engine, void *opr_block_) { +void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_, + const dmlc::Error* error) { OprBlock *opr_block = static_cast(opr_block_); ThreadedOpr *threaded_opr = opr_block->opr; + if (error != nullptr) { + auto ex_p = std::make_exception_ptr(*error); + threaded_opr->opr_exception = std::make_shared(ex_p); + } if (opr_block->profiling && threaded_opr->opr_name) { // record operator end timestamp opr_block->opr_profile->stop(); diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index ccfd09d64ea8..fae120db915b 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -465,7 +465,8 @@ class ThreadedEngine : public Engine { } } - static void OnCompleteStatic(Engine *engine, void *threaded_opr); + static void OnCompleteStatic(Engine *engine, void *threaded_opr, + const dmlc::Error* error); /*! \brief append an operator to bulk */ inline void BulkAppend(SyncFn exec_fn, Context exec_ctx, std::vector const& const_vars,