Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.4.x] Support populating errors back to MXNet engine in callback #13932

Merged
merged 2 commits into from
Jan 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ class CallbackOnComplete {
public:
// use implicit copy and assign
/*! \brief involve the callback */
inline void operator()() const {
(*callback_)(engine_, param_);
inline void operator()(const dmlc::Error* error = nullptr) const {
(*callback_)(engine_, param_, error);
}

private:
/*! \brief engine can see content of callback */
friend class ::mxnet::Engine;
/*! \brief the real callback */
void (*callback_)(Engine *, void *);
void (*callback_)(Engine *, void *, const dmlc::Error *);
/*! \brief the engine class passed to callback */
Engine* engine_;
/*! \brief the parameter set on callback */
Expand Down Expand Up @@ -275,7 +275,7 @@ class MXNET_API Engine {
* \param param the paramter passed to callback.
*/
inline CallbackOnComplete CreateCallback(
void (*callback)(Engine *, void *), void *param) {
void (*callback)(Engine *, void *, const dmlc::Error *), void *param) {
CallbackOnComplete ret;
ret.callback_ = callback;
ret.engine_ = this;
Expand Down
3 changes: 2 additions & 1 deletion src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ class NaiveEngine final : public Engine {

private:
// callback to oncomplete
static void OnComplete(Engine *engine, void *param) {
static void OnComplete(Engine *engine, void *param,
const dmlc::Error* error) {
static_cast<NaiveEngine*>(engine)->req_completed_ = true;
}
// whether action is completed
Expand Down
8 changes: 6 additions & 2 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,14 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
return;
}

void ThreadedEngine::OnCompleteStatic(
Engine *engine, void *opr_block_) {
void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
const dmlc::Error* error) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
ThreadedOpr *threaded_opr = opr_block->opr;
if (error != nullptr) {
auto ex_p = std::make_exception_ptr(*error);
threaded_opr->opr_exception = std::make_shared<std::exception_ptr>(ex_p);
}
if (opr_block->profiling && threaded_opr->opr_name) {
// record operator end timestamp
opr_block->opr_profile->stop();
Expand Down
3 changes: 2 additions & 1 deletion src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ class ThreadedEngine : public Engine {
}
}

static void OnCompleteStatic(Engine *engine, void *threaded_opr);
static void OnCompleteStatic(Engine *engine, void *threaded_opr,
const dmlc::Error* error);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
Expand Down