Skip to content

Commit

Permalink
[new-exec] add completion_nofifier (#38447)
Browse files Browse the repository at this point in the history
* add completion_nofifier

* fix bug

* unregist event waiter
  • Loading branch information
zhiqiu authored Dec 28, 2021
1 parent 1db61c3 commit 404a4a6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
22 changes: 14 additions & 8 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);

constexpr const char* kExceptionCaught = "ExceptionCaught";
constexpr const char* kTaskCompletion = "TaskCompletion";

namespace paddle {
namespace framework {
Expand All @@ -49,6 +50,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
gc_.reset(new InterpreterCoreGarbageCollector());

exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);

create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) {
Expand All @@ -69,6 +71,9 @@ InterpreterCore::~InterpreterCore() {
// cancle gc's thread
gc_.reset(nullptr);

exception_notifier_->UnregisterEvent();
completion_notifier_->UnregisterEvent();

async_work_queue_.reset(nullptr);
}

Expand Down Expand Up @@ -417,7 +422,7 @@ void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) {
async_work_queue_->PrepareAtomicDeps(dependecy_count_);
async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
op_run_number_ = 0;
unfinished_op_numer_ = vec_instr.size();

exception_holder_.Clear();

Expand All @@ -436,12 +441,6 @@ void InterpreterCore::ExecuteInstructionList(
async_work_queue_->Cancel();
exception_holder_.ReThrow();
}

PADDLE_ENFORCE_EQ(
op_run_number_.load(), vec_instr.size(),
platform::errors::Fatal(
"Required op_run_number == %d, but received op_run_number = %d.",
vec_instr.size(), op_run_number_.load()));
}

void InterpreterCore::RunNextInstructions(
Expand Down Expand Up @@ -539,8 +538,15 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
return;
}

VLOG(4) << "unfinished_op_numer_: " << unfinished_op_numer_;
if (UNLIKELY(unfinished_op_numer_.fetch_sub(1, std::memory_order_relaxed) ==
1)) {
if (completion_notifier_ != nullptr) {
completion_notifier_->NotifyEvent();
}
}

interpreter::RecordEvent(instr_node, place_);
op_run_number_.fetch_add(1, std::memory_order_relaxed);

RunNextInstructions(instr_node, &ready_ops);
}
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,15 @@ class InterpreterCore {
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode

std::vector<size_t> dependecy_count_;
std::atomic<size_t> op_run_number_{0};
std::atomic<size_t> unfinished_op_numer_{0};
std::vector<std::vector<size_t>> input_var2op_info_;

StreamAnalyzer stream_analyzer_;
EventsWaiter main_thread_blocker_;
std::unique_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};

std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ class AsyncWorkQueue {
// for execute host Kernel
group_options.emplace_back(/*num_threads*/ host_num_threads,
/*allow_spinning*/ true,
/*track_task*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for launch device Kernel
group_options.emplace_back(/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options);
Expand Down

0 comments on commit 404a4a6

Please sign in to comment.