Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] Remove redundant shared_ptr for exception_ptr #13933

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
// Mark complete for write variables.
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
if (threaded_opr->opr_exception != nullptr) {
i->var_exception = threaded_opr->opr_exception;
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
Expand Down Expand Up @@ -470,21 +470,20 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}

inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
if (threaded_var->var_exception && *threaded_var->var_exception) {
std::exception_ptr tmp = *threaded_var->var_exception;
*threaded_var->var_exception = nullptr;
if (threaded_var->var_exception != nullptr) {
auto tmp = threaded_var->var_exception;
threaded_var->var_exception = nullptr;
std::rethrow_exception(tmp);
}
return;
}

void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
const dmlc::Error* error) {
const dmlc::Error *error) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
ThreadedOpr *threaded_opr = opr_block->opr;
if (error != nullptr) {
auto ex_p = std::make_exception_ptr(*error);
threaded_opr->opr_exception = std::make_shared<std::exception_ptr>(ex_p);
threaded_opr->opr_exception = std::make_exception_ptr(*error);
}
if (opr_block->profiling && threaded_opr->opr_name) {
// record operator end timestamp
Expand Down
20 changes: 9 additions & 11 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class ThreadedVar final
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::shared_ptr<std::exception_ptr> var_exception;
std::exception_ptr var_exception;

private:
// TODO(hotpxl) change this to spinlock for faster runtime
Expand Down Expand Up @@ -254,7 +254,7 @@ struct ThreadedOpr final : public Opr,
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
std::shared_ptr<std::exception_ptr> opr_exception;
std::exception_ptr opr_exception;
}; // struct ThreadedOpr

/*!
Expand Down Expand Up @@ -359,16 +359,14 @@ class ThreadedEngine : public Engine {
LOG(INFO) << "ExecuteOprFn ";
}
try {
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
if (threaded_opr->opr_exception == nullptr ||
threaded_opr->wait) {
threaded_opr->fn(run_ctx, callback);
} else {
callback();
}
} catch (dmlc::Error& e) {
threaded_opr->opr_exception =
std::make_shared<std::exception_ptr>(std::current_exception());
callback();
} catch (const dmlc::Error& e) {
callback(&e);
}
if (debug_info) {
LOG(INFO) << "Fin ExecuteOprFn ";
Expand Down Expand Up @@ -450,14 +448,14 @@ class ThreadedEngine : public Engine {
*/
inline void OnStart(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->const_vars) {
if (i->var_exception && *i->var_exception) {
if (i->var_exception != nullptr) {
threaded_opr->opr_exception = i->var_exception;
break;
}
}
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
if (threaded_opr->opr_exception == nullptr) {
for (auto&& i : threaded_opr->mutable_vars) {
if (i->var_exception && *i->var_exception) {
if (i->var_exception != nullptr) {
threaded_opr->opr_exception = i->var_exception;
break;
}
Expand All @@ -466,7 +464,7 @@ class ThreadedEngine : public Engine {
}

static void OnCompleteStatic(Engine *engine, void *threaded_opr,
const dmlc::Error* error);
const dmlc::Error *error);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
Expand Down