From 4cd5384698110d158d70121245fd8a1d6bfc2527 Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Mon, 6 May 2024 16:22:56 +0800 Subject: [PATCH] [Dy2St][PIR] Hold backward program in GradNode (#63694) Co-authored-by: xiongkun Co-authored-by: Nyakku Shigure --- .../eager/to_static/run_program_op_func.h | 11 +-- .../eager/to_static/run_program_op_node.h | 70 +++++++++---------- paddle/fluid/framework/op_desc.cc | 4 ++ paddle/fluid/framework/type_defs.cc | 3 +- paddle/fluid/framework/type_defs.h | 4 +- paddle/fluid/pybind/op_function_common.cc | 19 +++-- paddle/fluid/pybind/pir.cc | 2 +- .../jit/dy2static/pir_partial_program.py | 8 +-- test/cpp/prim/CMakeLists.txt | 2 +- test/dygraph_to_static/test_no_gradient.py | 3 +- 10 files changed, 68 insertions(+), 58 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h index b6bdb28380736..c6c24ae47a7d2 100644 --- a/paddle/fluid/eager/to_static/run_program_op_func.h +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -296,16 +296,7 @@ inline void pir_run_program_ad_func( grad_node->SetStepScope(step_scope); // just for set useable. - // Set Grad out rank as same as fwd input and set stop gradient to bwd - // NOTE(@xiongkun): Not every tensor in x(list of tensor) is required - // gradient. for example: x[1] is not used for output, the x[1] is ignored. - - std::vector x_require_grad; - for (size_t i = 0; i < x.size(); ++i) { - x_require_grad.push_back(&x[i]); - } - - grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0); + grad_node->SetGradOutMeta(x, /*slot id*/ 0); grad_node->SetGradOutMeta(params, /*slot id*/ 1); // TODO(@xiongkun): rewrite by new ir representation. diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 75d812bf66e5e..853a0c445797c 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -467,21 +467,16 @@ inline void PirRunProgramAPI( auto param_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp")); - auto *forward_global_block = - PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block")); - auto *backward_global_block = - PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block")); - - auto *forward_program = - forward_global_block->GetParentOp()->GetParentProgram(); + std::shared_ptr<::pir::Program> forward_program = PADDLE_GET_CONST( + std::shared_ptr<::pir::Program>, attrs.at("forward_program")); + std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST( + std::shared_ptr<::pir::Program>, attrs.at("backward_program")); if (FLAGS_print_ir) { std::ostringstream print_stream; print_stream << "ForwardProgram is :\n"; forward_program->Print(print_stream); if (!is_test) { - auto *backward_program = - backward_global_block->GetParentOp()->GetParentProgram(); print_stream << "BackwardProgram is:\n"; backward_program->Print(print_stream); } else { @@ -509,12 +504,12 @@ inline void PirRunProgramAPI( << program_id; // Step 1. share input_vars & parameters into scope details::ShareTensorsIntoScopeByValue( - forward_global_block, x, input_values, global_inner_scope); + forward_program->block(), x, input_values, global_inner_scope); details::ShareTensorsIntoScopeByValue( - forward_global_block, params, param_values, global_inner_scope); + forward_program->block(), params, param_values, global_inner_scope); // Step 2. 
create new interpretercore auto passed_kernel_program = - paddle::framework::ApplyIrPass(forward_program, place); + paddle::framework::ApplyIrPass(forward_program.get(), place); if (FLAGS_print_ir) { std::ostringstream print_stream; print_stream << "LoweredProgram( AfterPass ) is :\n"; @@ -535,22 +530,22 @@ inline void PirRunProgramAPI( // update interpretercore skip_gc_var auto skip_names = details::GetNameFromValue( - forward_global_block, middle_values, false, true); + forward_program->block(), middle_values, false, true); auto skip_names_set = std::set(skip_names.begin(), skip_names.end()); auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("no_need_buffers")); auto no_need_buffer_names = details::GetNameFromValue( - forward_global_block, no_need_buffer_values, false, true); + forward_program->block(), no_need_buffer_values, false, true); for (auto &name : no_need_buffer_names) { VLOG(4) << "Find no need buffer vars with name:" << name; skip_names_set.erase(name); } skip_names = details::GetNameFromValue( - forward_global_block, output_values, false, true); + forward_program->block(), output_values, false, true); skip_names_set.insert(skip_names.begin(), skip_names.end()); skip_names = details::GetNameFromValue( - forward_global_block, input_values, true, false); + forward_program->block(), input_values, true, false); skip_names_set.insert(skip_names.begin(), skip_names.end()); details::print_collection(skip_names_set); interpreter_core->SetSkipGcVars(skip_names_set); @@ -576,9 +571,9 @@ inline void PirRunProgramAPI( interpreter_core = cached_value.core_; // Step 2. update scope for cache interpretercore details::ShareTensorsIntoScopeByValue( - forward_global_block, x, input_values, global_inner_scope); + forward_program->block(), x, input_values, global_inner_scope); details::ShareTensorsIntoScopeByValue( - forward_global_block, params, param_values, global_inner_scope); + forward_program->block(), params, param_values, global_inner_scope); // TODO(xiongkun): new ir how to build scope. // if (interpreter_core->GetVariableScope()->GetMutableScope() != // global_inner_scope) { @@ -589,7 +584,7 @@ inline void PirRunProgramAPI( } // interpretercore run - if (!forward_global_block->empty()) { + if (!forward_program->block()->empty()) { paddle::platform::RecordEvent record_event( "interpreter_core_run", paddle::platform::TracerEventType::UserDefined, @@ -602,7 +597,7 @@ inline void PirRunProgramAPI( "fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1); // Get Output, and Middle Outputs details::ShareTensorsFromScopeByValue( - forward_global_block, out, output_values, global_inner_scope); + forward_program->block(), out, output_values, global_inner_scope); VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); @@ -1041,10 +1036,8 @@ inline void PirRunProgramGradAPI( VLOG(4) << "global_inner_scope:" << global_inner_scope; - auto *backward_global_block = - PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block")); - auto *backward_program = - backward_global_block->GetParentOp()->GetParentProgram(); + std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST( + std::shared_ptr<::pir::Program>, attrs.at("backward_program")); auto output_grad_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g")); @@ -1064,8 +1057,10 @@ inline void PirRunProgramGradAPI( details::Trans2ContiguousTensorsInplace(out_grad); // share x, param, middles, output_grads, out into scope. 
-  details::ShareTensorsIntoScopeByValue(
-      backward_global_block, out_grad, output_grad_values, global_inner_scope);
+  details::ShareTensorsIntoScopeByValue(backward_program->block(),
+                                        out_grad,
+                                        output_grad_values,
+                                        global_inner_scope);

   auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance();
   std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
@@ -1082,7 +1077,7 @@ inline void PirRunProgramGradAPI(
     VLOG(2) << "No interpretercore cache, so create a new interpretercore";
     // Step 1. share input_vars & parameters into scope
     auto passed_kernel_program =
-        paddle::framework::ApplyIrPass(backward_program, place);
+        paddle::framework::ApplyIrPass(backward_program.get(), place);

     const auto &new_block = passed_kernel_program->block();
     passed_kernel_program = paddle::framework::ApplyRemoveShadowFeedPass(
@@ -1124,10 +1119,10 @@ inline void PirRunProgramGradAPI(
     // get all eager gc vars
     std::set<std::string> skip_eager_delete_vars;
     auto skip_names = details::GetNameFromValue(
-        backward_global_block, x_grad_values, false, true);
+        backward_program->block(), x_grad_values, false, true);
     skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
     skip_names = details::GetNameFromValue(
-        backward_global_block, p_grad_values, false, true);
+        backward_program->block(), p_grad_values, false, true);
     skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
     interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
     cache.UpdateSkipEagerDeleteVars(program_id,
@@ -1160,7 +1155,7 @@ inline void PirRunProgramGradAPI(
     }
   }

-  if (!backward_global_block->empty()) {
+  if (!backward_program->block()->empty()) {
     paddle::platform::RecordEvent record_event(
         "interpreter_core_run",
         paddle::platform::TracerEventType::UserDefined,
@@ -1175,9 +1170,11 @@ inline void PirRunProgramGradAPI(
         "fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
     // Step 4. get outputs
     details::ShareTensorsFromScopeByValue(
-        backward_global_block, x_grad, x_grad_values, global_inner_scope);
-    details::ShareTensorsFromScopeByValue(
-        backward_global_block, params_grad, p_grad_values, global_inner_scope);
+        backward_program->block(), x_grad, x_grad_values, global_inner_scope);
+    details::ShareTensorsFromScopeByValue(backward_program->block(),
+                                          params_grad,
+                                          p_grad_values,
+                                          global_inner_scope);
     VLOG(4) << "after backward gc all vars";
     global_inner_scope->SetCanReused(true);
     details::GcScope(global_inner_scope);
@@ -1316,8 +1313,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
       if (x[i].is_dense_tensor()) {
         x_grad->emplace_back(std::make_shared<phi::DenseTensor>());
       } else if (x[i].is_selected_rows()) {
-        auto selected_row = std::make_shared<phi::SelectedRows>();
-        x_grad->emplace_back(selected_row);
+        x_grad->emplace_back(std::make_shared<phi::SelectedRows>());
       }
       x_grad->back().set_name(x_grad_names[i]);
     }
@@ -1446,6 +1442,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
     VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram";
     *executed_ = true;
+    egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad,
+                                                        this->OutputMeta()[0]);
+    egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&params_grad,
+                                                        this->OutputMeta()[1]);
     return {x_grad, params_grad};
   }
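Note on the hunks above: the backward graph now reaches PirRunProgramGradAPI and PirGradNodeRunProgram as a std::shared_ptr<::pir::Program> attribute instead of a raw ::pir::Block*, so the grad node co-owns the backward program and it cannot be destroyed before backward actually runs. The stand-alone sketch below (plain C++ with illustrative stand-in types named Program and GradNode, not Paddle code) shows the lifetime guarantee this ownership change provides:

// Minimal sketch of shared ownership keeping a program alive for backward.
#include <cassert>
#include <memory>
#include <string>

struct Program {                    // stand-in for ::pir::Program
  std::string ir_text;
};

struct GradNode {                   // stand-in for PirGradNodeRunProgram
  std::shared_ptr<Program> backward_program;
  const std::string &Run() const { return backward_program->ir_text; }
};

int main() {
  GradNode node;
  {
    auto bwd = std::make_shared<Program>();
    bwd->ir_text = "backward ops";
    node.backward_program = bwd;    // the attribute shares ownership
  }                                 // the creating scope ends here
  assert(node.Run() == "backward ops");  // program is still alive for backward
  return 0;
}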
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 32c520711d978..c131c93909705 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/ops_extra_info.h"
 #include "paddle/phi/common/complex.h"
 #include "paddle/pir/include/core/block.h"
+#include "paddle/pir/include/core/program.h"
 #include "paddle/pir/include/core/value.h"
 #include "paddle/utils/blank.h"
@@ -977,6 +978,9 @@ struct SetAttrDescVisitor {
   void operator()(const std::vector<::pir::Value> &v) const {
     // just do nothing.
   }
+  void operator()(const std::shared_ptr<::pir::Program> &v) const {
+    // just do nothing.
+  }
   void operator()(const std::vector<VarDesc *> &v) const {
     std::vector<std::string> var_names;
     for (auto var : v) {
diff --git a/paddle/fluid/framework/type_defs.cc b/paddle/fluid/framework/type_defs.cc
index d8a6546ea718d..6d350f1fe1e6c 100644
--- a/paddle/fluid/framework/type_defs.cc
+++ b/paddle/fluid/framework/type_defs.cc
@@ -39,7 +39,8 @@ template class variant<paddle::blank,
                        ::pir::Block*,
-                       std::vector<::pir::Value>>;
+                       std::vector<::pir::Value>,
+                       std::shared_ptr<::pir::Program>>;
 }  // namespace paddle
 REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap);
 REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute);
diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h
index 61f133ceb082a..919da60601555 100644
--- a/paddle/fluid/framework/type_defs.h
+++ b/paddle/fluid/framework/type_defs.h
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/pir/include/core/block.h"
+#include "paddle/pir/include/core/program.h"
 #include "paddle/pir/include/core/value.h"
 #include "paddle/utils/blank.h"
 #include "paddle/utils/small_vector.h"
@@ -67,7 +68,8 @@ using Attribute = paddle::variant<paddle::blank,
                                   ::pir::Block*,
-                                  std::vector<::pir::Value>>;
+                                  std::vector<::pir::Value>,
+                                  std::shared_ptr<::pir::Program>>;

 using AttributeMap = std::unordered_map<std::string, Attribute>;

 using OpCreator =
diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc
index f8f1424ded243..cc484f74ab22f 100644
--- a/paddle/fluid/pybind/op_function_common.cc
+++ b/paddle/fluid/pybind/op_function_common.cc
@@ -858,6 +858,17 @@ void CastPyArg2AttrIRBlock(PyObject* obj,
   attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]);
 }

+void CastPyArg2AttrIRProgram(PyObject* obj,
+                             paddle::framework::AttributeMap& attrs,  // NOLINT
+                             const std::string& key,
+                             const std::string& op_type,
+                             ssize_t arg_pos) {
+  VLOG(1) << "After Process pir::Program*";
+  const std::shared_ptr<::pir::Program> program =
+      ::py::handle(obj).cast<std::shared_ptr<::pir::Program>>();
+  attrs[key] = program;
+}
+
 void CastPyArg2AttrValues(PyObject* obj,
                           paddle::framework::AttributeMap& attrs,  // NOLINT
                           const std::string& key,
@@ -1020,11 +1031,11 @@ void ConstructAttrMapForRunProgram(
     if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
       CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
-    } else if (std::set<std::string>({"global_block",
-                                      "forward_global_block",
-                                      "backward_global_block"})
-                   .count(key)) {
+    } else if (std::set<std::string>({"global_block"}).count(key)) {
       CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
+    } else if (std::set<std::string>({"forward_program", "backward_program"})
+                   .count(key)) {
+      CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos);
     } else if (std::set<std::string>({"is_test", "use_interpretorcore"})
                    .count(key)) {
       CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
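The framework-side changes above extend the Attribute variant with a std::shared_ptr<::pir::Program> alternative and route the new "forward_program" / "backward_program" keys to CastPyArg2AttrIRProgram. A minimal, self-contained sketch of that pattern (standard C++ with std::variant; the Block, Program, Attribute and AttributeMap names here are stand-ins, not Paddle's type_defs.h):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <variant>

struct Block {};                       // stand-in for ::pir::Block
struct Program { std::string name; };  // stand-in for ::pir::Program

// The new alternative lets an attribute own a whole program, not just a block.
using Attribute = std::variant<int, std::string, Block *, std::shared_ptr<Program>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

int main() {
  AttributeMap attrs;
  attrs["forward_program"] = std::make_shared<Program>(Program{"fwd"});
  attrs["backward_program"] = std::make_shared<Program>(Program{"bwd"});

  // Dispatch on the attribute key, mirroring ConstructAttrMapForRunProgram.
  for (const std::string key : {"forward_program", "backward_program"}) {
    const auto &program = std::get<std::shared_ptr<Program>>(attrs.at(key));
    std::cout << key << " holds program " << program->name << "\n";
  }
  return 0;
}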
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 85ce4abcda94d..86c4f6539c3a5 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -287,7 +287,7 @@ void BindProgram(py::module *m) {
   )DOC");
   program
       .def(py::init([]() {
-        return std::make_unique<Program>(pir::IrContext::Instance());
+        return std::make_shared<Program>(pir::IrContext::Instance());
       }))
       .def("__str__",
           [](const std::shared_ptr<Program> &self) {
diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py
index 8a2eaee72fd01..6cce148f13f1c 100644
--- a/python/paddle/jit/dy2static/pir_partial_program.py
+++ b/python/paddle/jit/dy2static/pir_partial_program.py
@@ -914,10 +914,10 @@ def _prune_unused_params(self, program):

     def _prepare_attributes(self):
         attrs = [
-            'forward_global_block',
-            self.program.forward_program.global_block(),
-            'backward_global_block',
-            self.program.backward_program.global_block(),
+            'forward_program',
+            self.program.forward_program,
+            'backward_program',
+            self.program.backward_program,
             'is_test',
             not self.training,
             'program_id',
diff --git a/test/cpp/prim/CMakeLists.txt b/test/cpp/prim/CMakeLists.txt
index 7f5b3af052588..51b5bb70a6e22 100644
--- a/test/cpp/prim/CMakeLists.txt
+++ b/test/cpp/prim/CMakeLists.txt
@@ -23,7 +23,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
   cc_library(init_env_utils SRCS init_env_utils.cc)
   target_compile_definitions(init_env_utils PUBLIC PADDLE_DLL_EXPORT)

-  paddle_test(test_comp_eager SRCS test_eager_prim.cc DEPS init_env_utils)
+  paddle_test(test_comp_eager SRCS test_eager_prim.cc init_env_utils.cc)
 endif()

 # skip win32 since wget is not installed by default on windows machine.
diff --git a/test/dygraph_to_static/test_no_gradient.py b/test/dygraph_to_static/test_no_gradient.py
index 1bd3a02f54ede..84f7b032c2f4a 100644
--- a/test/dygraph_to_static/test_no_gradient.py
+++ b/test/dygraph_to_static/test_no_gradient.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy
-from dygraph_to_static_utils import Dy2StTestBase
+from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir

 import paddle

@@ -33,6 +33,7 @@ def main_func(x, index):


 class TestNoGradientCase(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
     def test_no_gradient(self):
         paddle.disable_static()
         x = paddle.randn([10, 3])
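BindProgram now constructs the Python-side Program with std::make_shared, which is what lets the run_program attributes hand the very same instance back to C++ as a std::shared_ptr that the grad node keeps alive. A hypothetical, self-contained pybind11 module sketching that holder pattern (the module name holder_demo, the Program stand-in, and use_count_after_cast are illustrative, not Paddle's pir.cc):

#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Program {    // stand-in for ::pir::Program
  int id = 0;
};

PYBIND11_MODULE(holder_demo, m) {
  // A shared_ptr holder lets Python and C++ co-own the same Program object.
  py::class_<Program, std::shared_ptr<Program>>(m, "Program")
      .def(py::init([]() { return std::make_shared<Program>(); }));

  // Mirrors the idea behind CastPyArg2AttrIRProgram: pull the shared_ptr out
  // of the Python object so a C++ attribute map (or grad node) can store it.
  m.def("use_count_after_cast", [](py::handle obj) {
    std::shared_ptr<Program> program = obj.cast<std::shared_ptr<Program>>();
    return program.use_count();  // >= 2 while Python still holds the object
  });
}

From Python, holder_demo.use_count_after_cast(holder_demo.Program()) would report the extra reference taken by the C++ side during the cast, which is the mechanism the GradNode relies on to keep the backward program alive.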