Skip to content

Commit

Permalink
[IR] Support GC and TraceRun for NewIr InterpreterCore (#55772)
Browse files Browse the repository at this point in the history
* add interface

* add code

* add code

* add code

* add code

* fix bug

* fix bug

* add var prefix

* add code

* add code

* add code

* fix compile bug

* fix bug

* refine code

* refine code

* refine code

* refine code

* fix bug

* add code

* add code

* fix bug

* add code

* add code

* refine code

* refine code

* fix bug
  • Loading branch information
zhangbo9674 authored Jul 31, 2023
1 parent 6f53d3b commit dc96ebc
Show file tree
Hide file tree
Showing 12 changed files with 621 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ InterpreterCoreEventGarbageCollector::InterpreterCoreEventGarbageCollector(
}
}

InterpreterCoreEventGarbageCollector::InterpreterCoreEventGarbageCollector(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction) {
WorkQueueOptions options(/*name*/ "GarbageCollector",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ false);
queue_ = CreateSingleThreadedWorkQueue(options);
for (auto& instruc : vec_instruction) {
gc_event_.emplace_back(instruc->DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag());
}
}

InterpreterCoreEventGarbageCollector::~InterpreterCoreEventGarbageCollector() {
queue_.reset(nullptr);
}
Expand All @@ -53,6 +66,18 @@ void InterpreterCoreEventGarbageCollector::Add(Variable* var,
Add(var, &gc_event_.at(instr.Id()), &instr.DeviceContext());
}

void InterpreterCoreEventGarbageCollector::Add(Variable* var,
const InstructionBase* instr) {
PADDLE_ENFORCE_LT(instr->Id(),
gc_event_.size(),
platform::errors::OutOfRange(
"The index should be less than the size of gc event "
", but got index is %d and size is %d",
instr->Id(),
gc_event_.size()));
Add(var, &gc_event_.at(instr->Id()), &instr->DeviceContext());
}

void InterpreterCoreEventGarbageCollector::Add(
Variable* var,
platform::DeviceEvent* event,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ class InterpreterCoreEventGarbageCollector
public:
InterpreterCoreEventGarbageCollector(
const std::vector<Instruction>& vec_instruction);

InterpreterCoreEventGarbageCollector(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction);

~InterpreterCoreEventGarbageCollector();

void Add(Variable* var, const Instruction& instruction) override;

void Add(Variable* var, const InstructionBase* instruction) override;

private:
void Add(Variable* var,
platform::DeviceEvent* event,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ void InterpreterCoreFastGarbageCollector::Add(Variable* var,
Add(var);
}

void InterpreterCoreFastGarbageCollector::Add(Variable* var,
const InstructionBase*) {
Add(var);
}

void InterpreterCoreFastGarbageCollector::Add(Variable* var) {
if (UNLIKELY(max_memory_size_ < 0) || var == nullptr) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class InterpreterCoreFastGarbageCollector
public:
void Add(Variable* var, const Instruction& instr) override;

void Add(Variable* var, const InstructionBase* instr) override;

private:
void Add(Variable* var);
void Add(Garbage garbage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,36 @@ InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() {
cur_memory_size_ = 0;
}

std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction) {
if (platform::is_gpu_place(place)) {
if (IsInterpretercoreFastGCEnabled()) {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreFastGarbageCollector());
} else {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreEventGarbageCollector(vec_instruction));
}
} else if (platform::is_xpu_place(place)) {
// Because there is no multi-stream on XPU device, fast GC can
// be used.
// Previously, XPU used no_event GC. But `Wait` in no_event GC
// may cause GC delayed, causing no enough memory problem.
// TODO(pangyoki): Multi-stream allocator and multi-stream GC
// are needed to be adapted for XPU.
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreFastGarbageCollector());
} else if (platform::is_ipu_place(place)) {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreNoEventGarbageCollector());
} else {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreEventGarbageCollector(vec_instruction));
}
}

std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <queue>

#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/device_event.h"
Expand All @@ -34,6 +35,8 @@ class InterpreterCoreGarbageCollector {

virtual void Add(Variable* var, const Instruction& instruction) = 0;

virtual void Add(Variable* var, const InstructionBase* instruction) = 0;

DISABLE_COPY_AND_ASSIGN(InterpreterCoreGarbageCollector);

protected:
Expand All @@ -50,5 +53,10 @@ CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
const std::vector<Instruction>& vec_instruction);

std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
const std::vector<std::unique_ptr<InstructionBase>>& vec_instruction);

} // namespace framework
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ void InterpreterCoreNoEventGarbageCollector::Add(Variable* var,
Add(var, &instr.DeviceContext());
}

void InterpreterCoreNoEventGarbageCollector::Add(Variable* var,
const InstructionBase* instr) {
Add(var, &instr->DeviceContext());
}

void InterpreterCoreNoEventGarbageCollector::Add(
Variable* var, const platform::DeviceContext* ctx) {
if (UNLIKELY(max_memory_size_ < 0) || var == nullptr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class InterpreterCoreNoEventGarbageCollector
~InterpreterCoreNoEventGarbageCollector();
void Add(Variable* var, const Instruction& instr) override;

void Add(Variable* var, const InstructionBase* instr) override;

private:
void Add(Variable* var, const platform::DeviceContext* ctx);
void Add(Garbage garbage, const platform::DeviceContext* ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ PhiKernelInstruction::PhiKernelInstruction(
kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())));
VLOG(6) << "finish process kernel context";

SetDeviceContext(
ParseDeviceContext(op,
phi::DeviceContextPool::Instance().Get(
Expand Down
Loading

0 comments on commit dc96ebc

Please sign in to comment.