[Dy2St][PIR] Hold backward program in GradNode #63694

Merged 25 commits on May 6, 2024

Changes from 11 commits
Commits (25)
3223492
add log
gouzil Apr 18, 2024
c1398d1
add log
gouzil Apr 18, 2024
2a7f67d
add log
gouzil Apr 19, 2024
733ac03
add handle
gouzil Apr 19, 2024
5c90eac
remove breakpoint
gouzil Apr 19, 2024
64a56b0
use `std::shared_ptr<::pir::Program>`
gouzil Apr 20, 2024
b48eaa0
Merge branch 'develop' of github.com:gouzil/Paddle into fix_backward_…
gouzil Apr 20, 2024
5534c70
Merge branch 'develop' of github.com:gouzil/Paddle into fix_backward_…
gouzil Apr 23, 2024
8aba5ef
copy #58180
gouzil Apr 24, 2024
23bbd58
clean log and open test
gouzil Apr 24, 2024
31be3b9
Rollback execution sequence
gouzil Apr 24, 2024
e005597
rm log
gouzil Apr 24, 2024
6277bc2
rm `class ProgramDesc`
gouzil Apr 24, 2024
4972090
rm `test[key]`
gouzil Apr 24, 2024
4bcbcc6
rm include
gouzil Apr 24, 2024
f5c6e0b
Merge branch 'develop' of github.com:gouzil/Paddle into fix_backward_…
gouzil Apr 25, 2024
414fbda
fix jit.load for pir
gouzil Apr 25, 2024
1e68e62
[ci][test] ignore test
gouzil Apr 25, 2024
cb6cdab
[ci][test] open `CompareOperantsTest`
gouzil Apr 25, 2024
a9946ee
[ci][test] ignore hook_utils
gouzil Apr 25, 2024
d83476b
[ci][test] ignore `dygraph_functions` and open `hook_utils`
gouzil Apr 25, 2024
0bf8a6a
Merge branch 'develop' of github.com:gouzil/Paddle into fix_backward_…
gouzil May 1, 2024
462cdf7
fix test_comp_eager
gouzil May 1, 2024
c7e9825
rollback `test_eager_prim.cc`
gouzil May 1, 2024
de20f12
fix
gouzil May 1, 2024
11 changes: 1 addition & 10 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -326,16 +326,7 @@ inline void pir_run_program_ad_func(

grad_node->SetStepScope(step_scope); // just for set useable.

// Set Grad out rank as same as fwd input and set stop gradient to bwd
// NOTE(@xiongkun): Not every tensor in x(list of tensor) is required
// gradient. for example: x[1] is not used for output, the x[1] is ignored.

std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
x_require_grad.push_back(&x[i]);
}

grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0);
grad_node->SetGradOutMeta(x, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);

// TODO(@xiongkun): rewrite by new ir representation.
64 changes: 35 additions & 29 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -467,21 +467,18 @@ inline void PirRunProgramAPI(
auto param_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));

auto *forward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
auto forward_program = PADDLE_GET_CONST(std::shared_ptr<::pir::Program>,
attrs.at("forward_program"));
auto backward_program = PADDLE_GET_CONST(std::shared_ptr<::pir::Program>,
attrs.at("backward_program"));

auto *forward_program =
forward_global_block->GetParentOp()->GetParentProgram();
::pir::Block *forward_global_block = forward_program->block();

if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "ForwardProgram is :\n";
forward_program->Print(print_stream);
if (!is_test) {
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
} else {
@@ -514,7 +511,7 @@ inline void PirRunProgramAPI(
forward_global_block, params, param_values, global_inner_scope);
// Step 2. create new interpretercore
auto passed_kernel_program =
paddle::framework::ApplyIrPass(forward_program, place);
paddle::framework::ApplyIrPass(forward_program.get(), place);
if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "LoweredProgram( AfterPass ) is :\n";
@@ -1046,10 +1043,8 @@ inline void PirRunProgramGradAPI(

VLOG(4) << "global_inner_scope:" << global_inner_scope;

auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
auto backward_program = PADDLE_GET_CONST(std::shared_ptr<::pir::Program>,
attrs.at("backward_program"));

auto output_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g"));
@@ -1069,26 +1064,32 @@
details::Trans2ContiguousTensorsInplace(out_grad);

// share x, param, middles, output_grads, out into scope.
details::ShareTensorsIntoScopeByValue(
    backward_global_block, out_grad, output_grad_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
    backward_global_block, x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_global_block,
                                      middles,
                                      forward_middle_values,
                                      global_inner_scope);
details::ShareTensorsIntoScopeByValue(
    backward_global_block, out, forward_output_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
    backward_global_block, params, parameter_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
                                      out_grad,
                                      output_grad_values,
                                      global_inner_scope);
details::ShareTensorsIntoScopeByValue(
    backward_program->block(), x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
                                      middles,
                                      forward_middle_values,
                                      global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
                                      out,
                                      forward_output_values,
                                      global_inner_scope);
details::ShareTensorsIntoScopeByValue(
    backward_program->block(), params, parameter_values, global_inner_scope);

// Clear out and middles to avoid hold memory until backward finish.
out.clear();
middles.clear();
VLOG(1) << "out and middles clear end";

Member commented: Did you forget to clean this up?

auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr;

if (!cache.Has(program_id,
global_inner_scope,
place_hash_key,
@@ -1101,7 +1102,7 @@
VLOG(2) << "No interpretercore cache, so create a new interpretercore";
// Step 1. share input_vars & parameters into scope
auto passed_kernel_program =
paddle::framework::ApplyIrPass(backward_program, place);
paddle::framework::ApplyIrPass(backward_program.get(), place);

const auto &new_block = passed_kernel_program->block();
passed_kernel_program = paddle::framework::ApplyRemoveShadowFeedPass(
@@ -1143,10 +1144,10 @@
// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
auto skip_names = details::GetNameFromValue(
backward_global_block, x_grad_values, false, true);
backward_program->block(), x_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
skip_names = details::GetNameFromValue(
backward_global_block, p_grad_values, false, true);
backward_program->block(), p_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
cache.UpdateSkipEagerDeleteVars(program_id,
@@ -1179,7 +1180,7 @@
}
}

if (!backward_global_block->empty()) {
if (!backward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -1194,9 +1195,11 @@
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Step 4. get outputs
details::ShareTensorsFromScopeByValue(
backward_global_block, x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
backward_global_block, params_grad, p_grad_values, global_inner_scope);
backward_program->block(), x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(backward_program->block(),
params_grad,
p_grad_values,
global_inner_scope);
VLOG(4) << "after backward gc all vars";
global_inner_scope->SetCanReused(true);
details::GcScope(global_inner_scope);
@@ -1335,8 +1338,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
if (x[i].is_dense_tensor()) {
x_grad->emplace_back(std::make_shared<phi::DenseTensor>());
} else if (x[i].is_selected_rows()) {
auto selected_row = std::make_shared<phi::SelectedRows>();
x_grad->emplace_back(selected_row);
x_grad->emplace_back(std::make_shared<phi::SelectedRows>());
}
x_grad->back().set_name(x_grad_names[i]);
}
@@ -1471,6 +1473,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram";

*executed_ = true;
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad,
this->OutputMeta()[0]);
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&params_grad,
this->OutputMeta()[1]);
return {x_grad, params_grad};
}

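The recurring pattern in this file is that the op attributes now hold the whole program rather than a raw block pointer: forward_program and backward_program are std::shared_ptr<::pir::Program> attributes, and each block is derived on demand via program->block(), so whoever holds the attributes (for example the grad node) shares ownership of the backward program. A condensed sketch of that retrieval pattern follows; the helper name is hypothetical and the includes are illustrative, not part of the PR.

#include <memory>

#include "paddle/fluid/framework/type_defs.h"   // paddle::framework::AttributeMap
#include "paddle/pir/include/core/program.h"    // ::pir::Program

// Hypothetical helper: pulls the backward program out of the attribute map
// the same way PirRunProgramGradAPI does, and derives its block from it.
// PADDLE_GET_CONST comes from Paddle's attribute utilities (header omitted here).
inline std::shared_ptr<::pir::Program> SketchGetBackwardProgram(
    const paddle::framework::AttributeMap &attrs) {
  auto backward_program = PADDLE_GET_CONST(std::shared_ptr<::pir::Program>,
                                           attrs.at("backward_program"));
  // The block no longer travels as a bare ::pir::Block*; it is always owned
  // by the program retrieved above, so it cannot outlive it.
  ::pir::Block *backward_block = backward_program->block();
  (void)backward_block;
  return backward_program;
}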
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"

@@ -977,6 +978,9 @@ struct SetAttrDescVisitor {
void operator()(const std::vector<pir::Block *> &v) const {
// just do nothing.
}
void operator()(const std::shared_ptr<pir::Program> &v) const {
// just do nothing.
}
void operator()(const std::vector<VarDesc *> &v) const {
std::vector<std::string> var_names;
for (auto var : v) {
3 changes: 2 additions & 1 deletion paddle/fluid/framework/type_defs.cc
@@ -39,7 +39,8 @@ template class variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
} // namespace paddle
REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap);
REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute);
5 changes: 4 additions & 1 deletion paddle/fluid/framework/type_defs.h
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"
#include "paddle/utils/small_vector.h"
@@ -40,6 +41,7 @@ class InferShapeContext;
class InferVarTypeContext;
class VarDesc;
class BlockDesc;
class ProgramDesc;

Member commented: Is this necessary? Only pir::Program is added below, so why forward-declare the legacy IR ProgramDesc here?

class Variable;
class InferNoNeedBufferVarsFN;

@@ -67,7 +69,8 @@ using Attribute = paddle::variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

using OpCreator =
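Taken together with the op_desc.cc hunk above, these changes follow from Attribute being a variant: adding std::shared_ptr<::pir::Program> as a new alternative means every exhaustive visitor over Attribute (such as SetAttrDescVisitor) needs a handler for it, even a no-op one. Below is a standalone sketch of that constraint using simplified stand-in types; it is not Paddle's real visitor.

#include <iostream>
#include <memory>
#include <string>
#include <variant>

struct Program {};  // stand-in for ::pir::Program

// Simplified analogue of paddle::framework::Attribute.
using Attribute = std::variant<int, std::string, std::shared_ptr<Program>>;

struct AttrVisitor {
  void operator()(int v) const { std::cout << "int: " << v << '\n'; }
  void operator()(const std::string &v) const { std::cout << "str: " << v << '\n'; }
  // Without this overload, the visit below stops compiling once the
  // shared_ptr<Program> alternative is added; this mirrors the no-op
  // operator() added to SetAttrDescVisitor.
  void operator()(const std::shared_ptr<Program> &) const { /* intentionally empty */ }
};

int main() {
  Attribute attr = std::make_shared<Program>();
  std::visit(AttrVisitor{}, attr);
  return 0;
}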
21 changes: 17 additions & 4 deletions paddle/fluid/pybind/op_function_common.cc
@@ -36,6 +36,7 @@
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/op_result.h"
#include "paddle/pir/include/core/region.h"
#include "paddle/pir/include/core/value.h"

namespace paddle {
@@ -858,6 +859,17 @@ void CastPyArg2AttrIRBlock(PyObject* obj,
attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]);
}

void CastPyArg2AttrIRProgram(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
VLOG(1) << "After Process pir::Program*";
const std::shared_ptr<::pir::Program> program =
::py::handle(obj).cast<std::shared_ptr<::pir::Program>>();
attrs[key] = program;
}

void CastPyArg2AttrValues(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
@@ -998,6 +1010,7 @@ void ConstructAttrMapForRunProgram(
attr_end));

PyObject* obj = nullptr;
attrs["testkey"] = std::string("testvalue");

Member commented: Forgot to remove this?

Member (author) replied: Done

for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) {
VLOG(1) << "Start Process " << arg_pos;
Py_ssize_t key_len = 0;
Expand All @@ -1020,11 +1033,11 @@ void ConstructAttrMapForRunProgram(

if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"global_block",
"forward_global_block",
"backward_global_block"})
.count(key)) {
} else if (std::set<std::string>({"global_block"}).count(key)) {
CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"forward_program", "backward_program"})
.count(key)) {
CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"is_test", "use_interpretorcore"})
.count(key)) {
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
@@ -255,7 +255,7 @@ void BindProgram(py::module *m) {
)DOC");
program
.def(py::init([]() {
return std::make_unique<Program>(pir::IrContext::Instance());
return std::make_shared<Program>(pir::IrContext::Instance());
}))
.def("__str__",
[](const std::shared_ptr<Program> &self) {
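The switch from make_unique to make_shared here works together with the new CastPyArg2AttrIRProgram above: assuming Program is bound with a std::shared_ptr holder (as BindProgram appears to do, given the shared_ptr parameter in its __str__ lambda), C++ can re-acquire shared ownership from a Python handle, which is what lets the grad node keep the backward program alive after Python-side references go away. A rough pybind11 sketch of that pattern with a stand-in type; this is not the actual binding code.

#include <memory>

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Program {};  // stand-in for pir::Program

PYBIND11_MODULE(sketch_pir, m) {
  // shared_ptr holder: Python objects and C++ callers share ownership.
  py::class_<Program, std::shared_ptr<Program>>(m, "Program")
      .def(py::init([]() { return std::make_shared<Program>(); }));

  // Analogue of CastPyArg2AttrIRProgram: the cast yields a shared_ptr that
  // keeps the Program alive even if the Python reference is dropped.
  m.def("hold", [](py::handle obj) {
    auto program = obj.cast<std::shared_ptr<Program>>();
    return static_cast<int>(program.use_count());
  });
}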
8 changes: 4 additions & 4 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -899,10 +899,10 @@ def _prune_unused_params(self, program):

def _prepare_attributes(self):
attrs = [
'forward_global_block',
self.program.forward_program.global_block(),
'backward_global_block',
self.program.backward_program.global_block(),
'forward_program',
self.program.forward_program,
'backward_program',
self.program.backward_program,
'is_test',
not self.training,
'program_id',
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_no_gradient.py
@@ -15,7 +15,7 @@
import unittest

import numpy
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir

import paddle

Expand All @@ -33,6 +33,7 @@ def main_func(x, index):


class TestNoGradientCase(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_no_gradient(self):
paddle.disable_static()
x = paddle.randn([10, 3])