From cabb7c49320c0fe637cd4d22783fae1a4db1d102 Mon Sep 17 00:00:00 2001 From: CtfGo Date: Tue, 25 Jan 2022 08:24:32 +0000 Subject: [PATCH 1/7] add a graph pass to share MemOptVarInfos of external variables into subgraph --- .../details/eager_deletion_op_handle.cc | 18 ++- .../ir/memory_optimize_pass/CMakeLists.txt | 10 +- .../eager_deletion_pass.cc | 7 + .../memory_optimization_var_info.h | 9 ++ .../share_mem_opt_info_to_subgraph.cc | 145 ++++++++++++++++++ .../framework/paddle2cinn/build_cinn_pass.cc | 10 ++ .../framework/paddle2cinn/build_cinn_pass.h | 2 + 7 files changed, 195 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph.cc diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index 59614e89c1344..9db43ce2ec87b 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -107,6 +107,21 @@ void EagerDeletionOpHandle::CallOnce() { std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } +static bool CanBeErased(ir::MemOptVarInfo *var_info) { + if (var_info->IsSkippedAllMemoryOptimization() || + !var_info->DecreaseRefCnt()) { + return false; + } + auto parent_info = var_info->ParentHolder(); + if (parent_info && (parent_info->IsSkippedAllMemoryOptimization() || + !parent_info->DecreaseRefCnt())) { + VLOG(4) << "Skip eager_deletion on var:" << var_info->Name() + << " due to parent_info"; + return false; + } + return true; +} + void EagerDeletionOpHandle::RunImpl() { if (vars_.size() != var_infos_.size() || is_variant_scope_) { vars_.clear(); @@ -117,8 +132,7 @@ void EagerDeletionOpHandle::RunImpl() { std::deque> garbages; for (size_t i = 0; i < var_infos_.size(); ++i) { auto *var_info = var_infos_[i]; - if (var_info->IsSkippedAllMemoryOptimization() || - !var_info->DecreaseRefCnt()) { + if (!CanBeErased(var_info)) { VLOG(4) << "skip memory optimization with var: " << var_info->Name(); continue; } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index ee63b314adedb..f7487adeed75c 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -2,16 +2,18 @@ cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc DEPS conditional_block_op_helper graph_helper pass computation_op_handle) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) +cc_library(share_mem_opt_info_to_subgraph_pass SRCS share_mem_opt_info_to_subgraph.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) + cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) -cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) +cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper ) cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle - eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) + eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper share_mem_opt_info_to_subgraph_pass) -cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper) +cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper) cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper) -cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass) +cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass) cc_library(inplace_addto_op_pass SRCS inplace_addto_op_pass.cc DEPS memory_reuse_pass) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index 7b9b5aa623074..aa75b8ec0511f 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -285,6 +285,12 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { auto recurrent_op_eager_deletion_pass = ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass"); recurrent_op_eager_deletion_pass->Apply(graph); + + auto share_mem_opt_info_to_subgraph_pass = + ir::PassRegistry::Instance().Get("share_mem_opt_info_to_subgraph_pass"); + share_mem_opt_info_to_subgraph_pass->SetNotOwned(kMemOptVarInfoMapList, + &var_infos); + share_mem_opt_info_to_subgraph_pass->Apply(graph); } } // namespace ir @@ -300,3 +306,4 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) USE_PASS(conditional_block_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass); USE_PASS(recurrent_op_eager_deletion_pass); +USE_PASS(share_mem_opt_info_to_subgraph_pass); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h b/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h index 94842485440bd..e89734bacec36 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h +++ b/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h @@ -66,6 +66,12 @@ class MemOptVarInfo { return skip_memory_reuse_ || skip_all_memory_optimization_; } + void SetParentHolder(std::shared_ptr parent) { + parent_holder_ = parent; + } + + std::shared_ptr ParentHolder() const { return parent_holder_; } + const std::string &Name() const { return name_; } private: @@ -88,6 +94,9 @@ class MemOptVarInfo { std::atomic runtime_ref_cnt_; bool skip_memory_reuse_{false}; bool skip_all_memory_optimization_{false}; + // point to var info of the same variable in the main graph, + // used in external(input/output) variables of a subgraph + std::shared_ptr parent_holder_{nullptr}; }; using MemOptVarInfoMapList = std::vector< diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph.cc new file mode 100644 index 0000000000000..69b7d8f37a11d --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +using Name2VarInfoMap = + std::unordered_map>; + +static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp( + details::ComputationOpHandle* compute_op) { + for (auto* var : compute_op->Outputs()) { + if (!var->Node()->IsCtrlVar()) { + continue; + } + for (auto* op : var->PendingOps()) { + auto* eager_deletion_op = + dynamic_cast(op); + if (eager_deletion_op) { + return eager_deletion_op; + } + } + } + return nullptr; +} + +static void ShareVarInfoToCinnLaunch( + const MemOptVarInfoMapList& var_infos, + details::ComputationOpHandle* cinn_launch_op) { + auto* followed_eager_deletion_op = + FindFollowedEagerDeletionOp(cinn_launch_op); + if (!followed_eager_deletion_op) { + VLOG(4) << "No eager_deletion op found after this cinn_launch op"; + return; + } + auto vars_to_delete = followed_eager_deletion_op->VarsToDelete(); + if (vars_to_delete.empty()) { + VLOG(4) << "No var to be deleted after this cinn_launch op"; + return; + } + + const auto& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph( + cinn_launch_op->GetOp()->Attr(operators::kCompilationKey)); + auto& varinfo_from_maingraph = + subgraph.Get(paddle2cinn::kMemOptVarInfoFromMainGraph); + const auto& cur_place_var_infos = var_infos.at(cinn_launch_op->GetScopeIdx()); + + // collect all MemOptVarInfos of external variables that would be eager + // deleted + // after the cinn_launch subgraph executed, and store them as attribute of the + // subgraph + for (const auto& var_name : vars_to_delete) { + auto it = cur_place_var_infos.find(var_name); + PADDLE_ENFORCE_NE(it, cur_place_var_infos.end(), + platform::errors::NotFound( + "MemOptVarInfo of Var[%s] not found", var_name)); + varinfo_from_maingraph.emplace(var_name, it->second); + } +} + +static void TakeVarInfoFromMainGraph( + const Name2VarInfoMap& parent_var_infos, + const MemOptVarInfoMapList& var_infos, + details::EagerDeletionOpHandle* eager_deletion_op) { + const auto& cur_place_var_infos = + var_infos.at(eager_deletion_op->GetScopeIdx()); + for (auto&& var_name : eager_deletion_op->VarsToDelete()) { + auto cur_it = cur_place_var_infos.find(var_name); + PADDLE_ENFORCE_NE(cur_it, cur_place_var_infos.end(), + platform::errors::NotFound( + "MemOptVarInfo of Var[%s] not found", var_name)); + auto parent_it = parent_var_infos.find(var_name); + if (parent_it != parent_var_infos.end()) { + VLOG(4) << "Var[" << var_name << "] set parent holder"; + cur_it->second->SetParentHolder(parent_it->second); + } + } +} + +// This pass will be applied on both the main graph and all subgraphs, and +// it distinguishs them according to whether the graph has the +// kMemOptVarInfoFromMainGraph attribute or not. +// On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos +// to their subgraphs. +// On a subgraph, it iterates each variable that will be deleted by a +// eager_deletion op, +// and take the MemOptVarInfo from the main graph if such one found. +class ShareMemOptInfoToSubGraphPass : public ir::Pass { + protected: + void ApplyImpl(ir::Graph* graph) const override { + auto all_ops = ir::FilterByNodeWrapper(*graph); + const auto& var_infos = Get(kMemOptVarInfoMapList); + + // the main graph + if (!graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) { + for (auto* op : all_ops) { + auto compute_op = dynamic_cast(op); + if (compute_op && compute_op->Name() == "cinn_launch") { + ShareVarInfoToCinnLaunch(var_infos, compute_op); + } + } + } else { // a subgraph + const auto& parent_var_infos = + graph->Get(paddle2cinn::kMemOptVarInfoFromMainGraph); + for (auto* op : all_ops) { + auto eager_deletion_op = + dynamic_cast(op); + if (eager_deletion_op) { + TakeVarInfoFromMainGraph(parent_var_infos, var_infos, + eager_deletion_op); + } + } + } + } +}; + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(share_mem_opt_info_to_subgraph, + paddle::framework::ir::ShareMemOptInfoToSubGraphPass) + .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList); diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 4abe3a55b298f..ab259a0fc85ab 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -44,6 +44,11 @@ DECLARE_string(deny_cinn_ops); namespace paddle { namespace framework { + +namespace ir { +class MemOptVarInfo; +} // namespace ir + namespace paddle2cinn { using framework::ir::Graph; @@ -369,6 +374,11 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, ExtractNoNeedBufferFeeds(cluster, cluster_inputs)); subgraph->Set>( kNoNeedBufferFeeds, no_need_buffer_feeds.release()); + // initialize empty map for kMemOptVarInfoFromMainGraph attribute, + // it will be filled on the share_mem_opt_info_to_subgraph pass + subgraph->GetOrInit>>( + kMemOptVarInfoFromMainGraph); return subgraph; } diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 10d12f93f8bd8..9bb25b6b52e54 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -22,6 +22,8 @@ namespace paddle2cinn { constexpr char kCinnLaunchOp[] = "cinn_launch"; constexpr char kNoNeedBufferFeeds[] = "no_need_buffer_feeds"; +constexpr char kMemOptVarInfoFromMainGraph[] = + "mem_opt_var_info_from_main_graph"; // A pass named BuildCinnPass, the function of this pass is: // From 5ea672378640f8420e828c5fd8e54fff80675e12 Mon Sep 17 00:00:00 2001 From: CtfGo Date: Tue, 25 Jan 2022 10:48:36 +0000 Subject: [PATCH 2/7] update pass name --- paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt | 2 +- ...fo_to_subgraph.cc => share_mem_opt_info_to_subgraph_pass.cc} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename paddle/fluid/framework/ir/memory_optimize_pass/{share_mem_opt_info_to_subgraph.cc => share_mem_opt_info_to_subgraph_pass.cc} (100%) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index f7487adeed75c..a4f8e6f6475c6 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -2,7 +2,7 @@ cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc DEPS conditional_block_op_helper graph_helper pass computation_op_handle) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) -cc_library(share_mem_opt_info_to_subgraph_pass SRCS share_mem_opt_info_to_subgraph.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) +cc_library(share_mem_opt_info_to_subgraph_pass SRCS share_mem_opt_info_to_subgraph_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper ) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc similarity index 100% rename from paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph.cc rename to paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc From a4d9234f7cf55e86ef892c5f883912b81ad313e0 Mon Sep 17 00:00:00 2001 From: CtfGo Date: Thu, 27 Jan 2022 04:48:46 +0000 Subject: [PATCH 3/7] fix compile failed --- .../memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc index 69b7d8f37a11d..c5355b8dcd18e 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc @@ -140,6 +140,6 @@ class ShareMemOptInfoToSubGraphPass : public ir::Pass { } // namespace framework } // namespace paddle -REGISTER_PASS(share_mem_opt_info_to_subgraph, +REGISTER_PASS(share_mem_opt_info_to_subgraph_pass, paddle::framework::ir::ShareMemOptInfoToSubGraphPass) .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList); From cf457653738e999fb346b1600ec07048852fdd9e Mon Sep 17 00:00:00 2001 From: CtfGo Date: Fri, 28 Jan 2022 10:28:39 +0000 Subject: [PATCH 4/7] add share_mem_opt_info_to_subgraph_pass test --- .../ir/memory_optimize_pass/CMakeLists.txt | 15 +- .../eager_deletion_pass.cc | 4 + .../share_mem_opt_info_to_subgraph_pass.cc | 15 +- ...hare_mem_opt_info_to_subgraph_pass_test.cc | 157 ++++++++++++++++++ .../framework/paddle2cinn/cinn_compiler.h | 22 ++- 5 files changed, 193 insertions(+), 20 deletions(-) create mode 100644 paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index a4f8e6f6475c6..294e93ef051b1 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -2,14 +2,21 @@ cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc DEPS conditional_block_op_helper graph_helper pass computation_op_handle) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) -cc_library(share_mem_opt_info_to_subgraph_pass SRCS share_mem_opt_info_to_subgraph_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) + +if (WITH_CINN) + cc_library(share_mem_opt_info_to_subgraph_pass SRCS share_mem_opt_info_to_subgraph_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) + cc_test(share_mem_opt_info_to_subgraph_pass_test SRCS share_mem_opt_info_to_subgraph_pass_test.cc DEPS share_mem_opt_info_to_subgraph_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op) + cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle + eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper share_mem_opt_info_to_subgraph_pass) +else() + cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle + eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) +endif() + cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper ) -cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle - eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper share_mem_opt_info_to_subgraph_pass) - cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper) cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index aa75b8ec0511f..52325e7c5610a 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -286,11 +286,13 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass"); recurrent_op_eager_deletion_pass->Apply(graph); +#ifdef PADDLE_WITH_CINN auto share_mem_opt_info_to_subgraph_pass = ir::PassRegistry::Instance().Get("share_mem_opt_info_to_subgraph_pass"); share_mem_opt_info_to_subgraph_pass->SetNotOwned(kMemOptVarInfoMapList, &var_infos); share_mem_opt_info_to_subgraph_pass->Apply(graph); +#endif } } // namespace ir @@ -306,4 +308,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) USE_PASS(conditional_block_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass); USE_PASS(recurrent_op_eager_deletion_pass); +#ifdef PADDLE_WITH_CINN USE_PASS(share_mem_opt_info_to_subgraph_pass); +#endif diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc index c5355b8dcd18e..ec70bdaf32cfd 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc @@ -67,10 +67,9 @@ static void ShareVarInfoToCinnLaunch( subgraph.Get(paddle2cinn::kMemOptVarInfoFromMainGraph); const auto& cur_place_var_infos = var_infos.at(cinn_launch_op->GetScopeIdx()); - // collect all MemOptVarInfos of external variables that would be eager - // deleted - // after the cinn_launch subgraph executed, and store them as attribute of the - // subgraph + // collect all MemOptVarInfos of external variables + // that would be eager deleted after the cinn_launch subgraph executed, + // and store them as attribute of the subgraph for (const auto& var_name : vars_to_delete) { auto it = cur_place_var_infos.find(var_name); PADDLE_ENFORCE_NE(it, cur_place_var_infos.end(), @@ -99,14 +98,14 @@ static void TakeVarInfoFromMainGraph( } } -// This pass will be applied on both the main graph and all subgraphs, and -// it distinguishs them according to whether the graph has the +// This pass will be applied on both the main graph and all subgraphs, +// and it distinguishs them according to whether the graph has the // kMemOptVarInfoFromMainGraph attribute or not. // On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos // to their subgraphs. // On a subgraph, it iterates each variable that will be deleted by a -// eager_deletion op, -// and take the MemOptVarInfo from the main graph if such one found. +// eager_deletion op, and take the MemOptVarInfo from the main graph +// if such one found. class ShareMemOptInfoToSubGraphPass : public ir::Pass { protected: void ApplyImpl(ir::Graph* graph) const override { diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc new file mode 100644 index 0000000000000..8ba55908c8552 --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc @@ -0,0 +1,157 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/framework/parallel_executor.h" +#include "paddle/fluid/framework/program_desc.h" + +USE_OP(mul); +USE_OP(cinn_launch); +USE_OP(elementwise_add); +namespace paddle { +namespace framework { + +using Name2VarInfoMap = + std::unordered_map>; + +static ProgramDesc BuildProgramInsideCinnLaunchOp() { + ProgramDesc program; + auto* block = program.MutableBlock(0); + block->Var("var1"); + block->Var("var2"); + block->Var("var3"); + block->Var("var4"); + block->Var("var5"); + + auto add_op = std::unique_ptr( + new OpDesc("elementwise_add", {{"X", {"var1"}}, {"Y", {"var2"}}}, + {{"Out", {"var3"}}}, {})); + block->AppendAllocatedOp(std::move(add_op)); + auto mul_op = std::unique_ptr(new OpDesc( + "mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {})); + block->AppendAllocatedOp(std::move(mul_op)); + return program; +} + +static ProgramDesc BuildProgramWithCinnLaunchOp( + const std::string& compilation_key) { + // create a cinn_launch op + ProgramDesc program; + auto* block = program.MutableBlock(0); + block->Var("var1"); + block->Var("var2"); + block->Var("var4"); + block->Var("var5"); + + auto cinn_launch_op = std::unique_ptr( + new OpDesc("cinn_launch", {{"X", {"var1", "var2", "var4"}}}, + {{"Out", {"var5"}}}, {{"compilation_key", compilation_key}})); + block->AppendAllocatedOp(std::move(cinn_launch_op)); + return program; +} + +struct TestPassContext { + explicit TestPassContext(const ProgramDesc& program) { + graph = std::make_unique(program); + details::BuildStrategy build_strategy; + details::ExecutionStrategy exec_strategy; + exec_strategy.use_device_ = paddle::platform::kCUDA; + executor.reset(new ParallelExecutor(platform::CUDAPlace(0), &scope, + exec_strategy, build_strategy, + graph.get())); + } + + Scope scope; + std::unique_ptr graph; + std::unique_ptr executor; +}; + +TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_var_info) { + // add a subgraph to CinnCompiler + auto subgraph = std::make_unique(BuildProgramInsideCinnLaunchOp()); + subgraph->GetOrInit( + paddle2cinn::kMemOptVarInfoFromMainGraph); + auto compilation_key = + paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph)); + + // build test data and apply pass + auto context = std::make_unique( + BuildProgramWithCinnLaunchOp(compilation_key)); + LOG(INFO) << "context constructed"; + ir::MemOptVarInfoMapList mem_opt_var_infos(1); + auto& var_infos = mem_opt_var_infos.front(); + var_infos = { + {"var1", std::make_shared("var1", 1)}, + {"var2", std::make_shared("var2", 2)}, + {"var4", std::make_shared("var4", 2)}, + {"var5", std::make_shared("var5", 1)}, + }; + auto share_pass = + ir::PassRegistry::Instance().Get("share_mem_opt_info_to_subgraph_pass"); + share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos); + share_pass->Apply(context->graph.get()); + + // check result + const auto& result_subgraph = + paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key); + const auto& shared_var_infos = result_subgraph.Get( + paddle2cinn::kMemOptVarInfoFromMainGraph); + ASSERT_EQ(shared_var_infos.size(), 4); + EXPECT_EQ(shared_var_infos.count("var1"), 1); + EXPECT_EQ(shared_var_infos.count("var5"), 1); + EXPECT_EQ(var_infos.at("var1").get(), shared_var_infos.at("var1").get()); + EXPECT_EQ(var_infos.at("var5").get(), shared_var_infos.at("var5").get()); +} + +TEST(ShareMemInfoToSubGraphPassTest, test_sub_graph_take_var_info) { + // build test data and apply pass + auto context = + std::make_unique(BuildProgramInsideCinnLaunchOp()); + auto& var_infos_shared = context->graph->GetOrInit( + paddle2cinn::kMemOptVarInfoFromMainGraph); + var_infos_shared = { + {"var1", std::make_shared("var1", 1)}, + {"var2", std::make_shared("var2", 2)}, + }; + + ir::MemOptVarInfoMapList mem_opt_var_infos(1); + auto& var_infos = mem_opt_var_infos.front(); + var_infos = {{"var1", std::make_shared("var1", 1)}, + {"var2", std::make_shared("var2", 1)}, + {"var3", std::make_shared("var3", 1)}, + {"var4", std::make_shared("var4", 1)}, + {"var5", std::make_shared("var5", 1)}}; + auto share_pass = + ir::PassRegistry::Instance().Get("share_mem_opt_info_to_subgraph_pass"); + share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos); + share_pass->Apply(context->graph.get()); + + // check result + ASSERT_NE(var_infos.at("var1")->ParentHolder(), nullptr); + ASSERT_NE(var_infos.at("var2")->ParentHolder(), nullptr); + ASSERT_EQ(var_infos.at("var3")->ParentHolder(), nullptr); + ASSERT_EQ(var_infos.at("var4")->ParentHolder(), nullptr); + ASSERT_EQ(var_infos.at("var5")->ParentHolder(), nullptr); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h index 024dd26747b8e..91a7b4e5a11f0 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h @@ -20,9 +20,6 @@ #include #include #include - -#include "cinn/common/target.h" -#include "cinn/hlir/framework/graph_compiler.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h" @@ -30,13 +27,22 @@ #include "paddle/fluid/platform/macros.h" #include "paddle/pten/core/utils/rw_lock.h" -namespace paddle { +namespace cinn { +namespace common { +class Target; +} // namespace common -namespace operators { -namespace details { +namespace hlir::framework { +class GraphCompiler; +class Program; +class Scope; +} // namespace hlir::framework +} // namespace cinn + +namespace paddle { +namespace operators::details { class CinnLaunchContext; -} // namespace details -} // namespace operators +} // namespace operators::details namespace framework { namespace paddle2cinn { From 11aaa61d5adaa42b692238269b5cdfb28c080f75 Mon Sep 17 00:00:00 2001 From: CtfGo Date: Mon, 7 Feb 2022 10:53:53 +0000 Subject: [PATCH 5/7] share_mem_opt_info_to_subgraph_pass_test pass --- .../details/eager_deletion_op_handle.cc | 12 +++++------- .../share_mem_opt_info_to_subgraph_pass.cc | 10 +++++++--- .../share_mem_opt_info_to_subgraph_pass_test.cc | 17 ++--------------- 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index e638c571263a6..9a619a67e41a8 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -112,14 +112,12 @@ static bool CanBeErased(ir::MemOptVarInfo *var_info) { !var_info->DecreaseRefCnt()) { return false; } - auto parent_info = var_info->ParentHolder(); - if (parent_info && (parent_info->IsSkippedAllMemoryOptimization() || - !parent_info->DecreaseRefCnt())) { - VLOG(4) << "Skip eager_deletion on var:" << var_info->Name() - << " due to parent_info"; - return false; + auto parent_holder = var_info->ParentHolder(); + // if parent_holder exists, it should meet deletion condition too. + if (!parent_holder || CanBeErased(parent_holder.get())) { + return true; } - return true; + return false; } void EagerDeletionOpHandle::RunImpl() { diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc index ec70bdaf32cfd..fb5a7c668c450 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace framework { @@ -60,6 +61,9 @@ static void ShareVarInfoToCinnLaunch( VLOG(4) << "No var to be deleted after this cinn_launch op"; return; } + VLOG(4) << "Variables would be deleted by eager_deletion_op" + << " following the cinn_launch:" + << paddle::string::join_strings(vars_to_delete, ','); const auto& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph( cinn_launch_op->GetOp()->Attr(operators::kCompilationKey)); @@ -74,7 +78,7 @@ static void ShareVarInfoToCinnLaunch( auto it = cur_place_var_infos.find(var_name); PADDLE_ENFORCE_NE(it, cur_place_var_infos.end(), platform::errors::NotFound( - "MemOptVarInfo of Var[%s] not found", var_name)); + "MemOptVarInfo of var[%s] not found", var_name)); varinfo_from_maingraph.emplace(var_name, it->second); } } @@ -89,10 +93,10 @@ static void TakeVarInfoFromMainGraph( auto cur_it = cur_place_var_infos.find(var_name); PADDLE_ENFORCE_NE(cur_it, cur_place_var_infos.end(), platform::errors::NotFound( - "MemOptVarInfo of Var[%s] not found", var_name)); + "MemOptVarInfo of var[%s] not found", var_name)); auto parent_it = parent_var_infos.find(var_name); if (parent_it != parent_var_infos.end()) { - VLOG(4) << "Var[" << var_name << "] set parent holder"; + VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder"; cur_it->second->SetParentHolder(parent_it->second); } } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc index 8ba55908c8552..8b9a2a9f00ef4 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc @@ -96,19 +96,6 @@ TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_var_info) { // build test data and apply pass auto context = std::make_unique( BuildProgramWithCinnLaunchOp(compilation_key)); - LOG(INFO) << "context constructed"; - ir::MemOptVarInfoMapList mem_opt_var_infos(1); - auto& var_infos = mem_opt_var_infos.front(); - var_infos = { - {"var1", std::make_shared("var1", 1)}, - {"var2", std::make_shared("var2", 2)}, - {"var4", std::make_shared("var4", 2)}, - {"var5", std::make_shared("var5", 1)}, - }; - auto share_pass = - ir::PassRegistry::Instance().Get("share_mem_opt_info_to_subgraph_pass"); - share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos); - share_pass->Apply(context->graph.get()); // check result const auto& result_subgraph = @@ -118,8 +105,8 @@ TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_var_info) { ASSERT_EQ(shared_var_infos.size(), 4); EXPECT_EQ(shared_var_infos.count("var1"), 1); EXPECT_EQ(shared_var_infos.count("var5"), 1); - EXPECT_EQ(var_infos.at("var1").get(), shared_var_infos.at("var1").get()); - EXPECT_EQ(var_infos.at("var5").get(), shared_var_infos.at("var5").get()); + EXPECT_EQ(shared_var_infos.at("var1").use_count(), 2); + EXPECT_EQ(shared_var_infos.at("var5").use_count(), 2); } TEST(ShareMemInfoToSubGraphPassTest, test_sub_graph_take_var_info) { From 792ffed19a73b4af4014d18cae000e612ca2f104 Mon Sep 17 00:00:00 2001 From: CtfGo Date: Wed, 9 Feb 2022 09:34:17 +0000 Subject: [PATCH 6/7] modify some codes for better style and more robust --- .../details/eager_deletion_op_handle.cc | 10 ++- .../ir/memory_optimize_pass/CMakeLists.txt | 18 ++--- .../eager_deletion_pass.cc | 11 ++- ...ass.cc => share_varinfo_into_cinn_pass.cc} | 73 +++++++++---------- ...c => share_varinfo_into_cinn_pass_test.cc} | 58 +++++++-------- 5 files changed, 83 insertions(+), 87 deletions(-) rename paddle/fluid/framework/ir/memory_optimize_pass/{share_mem_opt_info_to_subgraph_pass.cc => share_varinfo_into_cinn_pass.cc} (69%) rename paddle/fluid/framework/ir/memory_optimize_pass/{share_mem_opt_info_to_subgraph_pass_test.cc => share_varinfo_into_cinn_pass_test.cc} (71%) diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index 9a619a67e41a8..c760e7a98614c 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -112,12 +112,14 @@ static bool CanBeErased(ir::MemOptVarInfo *var_info) { !var_info->DecreaseRefCnt()) { return false; } - auto parent_holder = var_info->ParentHolder(); +#ifdef PADDLE_WITH_CINN // if parent_holder exists, it should meet deletion condition too. - if (!parent_holder || CanBeErased(parent_holder.get())) { - return true; + std::shared_ptr parent_holder = var_info->ParentHolder(); + if (parent_holder && !CanBeErased(parent_holder.get())) { + return false; } - return false; +#endif + return true; } void EagerDeletionOpHandle::RunImpl() { diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index 294e93ef051b1..8343291816396 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -3,19 +3,17 @@ cc_library(conditional_block_op_eager_deletion_pass SRCS conditional_block_op_ea cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) +cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) +cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper ) + +SET(EAGER_DELETETION_PASS_DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) if (WITH_CINN) - cc_library(share_mem_opt_info_to_subgraph_pass SRCS share_mem_opt_info_to_subgraph_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) - cc_test(share_mem_opt_info_to_subgraph_pass_test SRCS share_mem_opt_info_to_subgraph_pass_test.cc DEPS share_mem_opt_info_to_subgraph_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op) - cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle - eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper share_mem_opt_info_to_subgraph_pass) -else() - cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle - eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) + cc_library(share_varinfo_into_cinn_pass SRCS share_varinfo_into_cinn_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler) + cc_test(share_varinfo_into_cinn_pass_test SRCS share_varinfo_into_cinn_pass_test.cc DEPS share_varinfo_into_cinn_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op) + list(APPEND EAGER_DELETETION_PASS_DEPS share_varinfo_into_cinn_pass) endif() - -cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) -cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper ) +cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS ${EAGER_DELETETION_PASS_DEPS}) cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index 52325e7c5610a..af1a65f7a6c3b 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -287,11 +287,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { recurrent_op_eager_deletion_pass->Apply(graph); #ifdef PADDLE_WITH_CINN - auto share_mem_opt_info_to_subgraph_pass = - ir::PassRegistry::Instance().Get("share_mem_opt_info_to_subgraph_pass"); - share_mem_opt_info_to_subgraph_pass->SetNotOwned(kMemOptVarInfoMapList, - &var_infos); - share_mem_opt_info_to_subgraph_pass->Apply(graph); + auto share_varinfo_into_cinn_pass = + ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass"); + share_varinfo_into_cinn_pass->SetNotOwned(kMemOptVarInfoMapList, &var_infos); + share_varinfo_into_cinn_pass->Apply(graph); #endif } @@ -309,5 +308,5 @@ USE_PASS(conditional_block_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass); USE_PASS(recurrent_op_eager_deletion_pass); #ifdef PADDLE_WITH_CINN -USE_PASS(share_mem_opt_info_to_subgraph_pass); +USE_PASS(share_varinfo_into_cinn_pass); #endif diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc similarity index 69% rename from paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc rename to paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc index fb5a7c668c450..1b2a62695fb13 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc @@ -23,20 +23,18 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/string_helper.h" -namespace paddle { -namespace framework { -namespace ir { +namespace paddle::framework::ir { using Name2VarInfoMap = std::unordered_map>; static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp( details::ComputationOpHandle* compute_op) { - for (auto* var : compute_op->Outputs()) { + for (details::VarHandleBase* var : compute_op->Outputs()) { if (!var->Node()->IsCtrlVar()) { continue; } - for (auto* op : var->PendingOps()) { + for (details::OpHandleBase* op : var->PendingOps()) { auto* eager_deletion_op = dynamic_cast(op); if (eager_deletion_op) { @@ -48,90 +46,93 @@ static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp( } static void ShareVarInfoToCinnLaunch( - const MemOptVarInfoMapList& var_infos, + const MemOptVarInfoMapList& varinfo_maps, details::ComputationOpHandle* cinn_launch_op) { - auto* followed_eager_deletion_op = + details::EagerDeletionOpHandle* followed_eager_deletion_op = FindFollowedEagerDeletionOp(cinn_launch_op); if (!followed_eager_deletion_op) { VLOG(4) << "No eager_deletion op found after this cinn_launch op"; return; } - auto vars_to_delete = followed_eager_deletion_op->VarsToDelete(); + + std::vector vars_to_delete = + followed_eager_deletion_op->VarsToDelete(); if (vars_to_delete.empty()) { VLOG(4) << "No var to be deleted after this cinn_launch op"; return; } - VLOG(4) << "Variables would be deleted by eager_deletion_op" + VLOG(4) << "Variables would be deleted by the eager_deletion_op" << " following the cinn_launch:" << paddle::string::join_strings(vars_to_delete, ','); - const auto& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph( + const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph( cinn_launch_op->GetOp()->Attr(operators::kCompilationKey)); - auto& varinfo_from_maingraph = + auto& dst_varinfo_map = subgraph.Get(paddle2cinn::kMemOptVarInfoFromMainGraph); - const auto& cur_place_var_infos = var_infos.at(cinn_launch_op->GetScopeIdx()); + const Name2VarInfoMap& src_varinfo_map = + varinfo_maps.at(cinn_launch_op->GetScopeIdx()); // collect all MemOptVarInfos of external variables // that would be eager deleted after the cinn_launch subgraph executed, // and store them as attribute of the subgraph for (const auto& var_name : vars_to_delete) { - auto it = cur_place_var_infos.find(var_name); - PADDLE_ENFORCE_NE(it, cur_place_var_infos.end(), + auto it = src_varinfo_map.find(var_name); + PADDLE_ENFORCE_NE(it, src_varinfo_map.end(), platform::errors::NotFound( "MemOptVarInfo of var[%s] not found", var_name)); - varinfo_from_maingraph.emplace(var_name, it->second); + dst_varinfo_map.emplace(var_name, it->second); } } static void TakeVarInfoFromMainGraph( - const Name2VarInfoMap& parent_var_infos, - const MemOptVarInfoMapList& var_infos, + const Name2VarInfoMap& src_varinfo_map, + const MemOptVarInfoMapList& varinfo_maps, details::EagerDeletionOpHandle* eager_deletion_op) { - const auto& cur_place_var_infos = - var_infos.at(eager_deletion_op->GetScopeIdx()); + const Name2VarInfoMap& dst_varinfo_map = + varinfo_maps.at(eager_deletion_op->GetScopeIdx()); for (auto&& var_name : eager_deletion_op->VarsToDelete()) { - auto cur_it = cur_place_var_infos.find(var_name); - PADDLE_ENFORCE_NE(cur_it, cur_place_var_infos.end(), + auto dst_it = dst_varinfo_map.find(var_name); + PADDLE_ENFORCE_NE(dst_it, dst_varinfo_map.end(), platform::errors::NotFound( "MemOptVarInfo of var[%s] not found", var_name)); - auto parent_it = parent_var_infos.find(var_name); - if (parent_it != parent_var_infos.end()) { + auto src_it = src_varinfo_map.find(var_name); + if (src_it != src_varinfo_map.end()) { VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder"; - cur_it->second->SetParentHolder(parent_it->second); + dst_it->second->SetParentHolder(src_it->second); } } } -// This pass will be applied on both the main graph and all subgraphs, +// This pass will be applied on both the main graph and all cinn subgraphs, // and it distinguishs them according to whether the graph has the // kMemOptVarInfoFromMainGraph attribute or not. // On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos // to their subgraphs. -// On a subgraph, it iterates each variable that will be deleted by a +// On a cinn subgraph, it iterates each variable that will be deleted by a // eager_deletion op, and take the MemOptVarInfo from the main graph // if such one found. class ShareMemOptInfoToSubGraphPass : public ir::Pass { protected: void ApplyImpl(ir::Graph* graph) const override { auto all_ops = ir::FilterByNodeWrapper(*graph); - const auto& var_infos = Get(kMemOptVarInfoMapList); + const auto& varinfo_maps = Get(kMemOptVarInfoMapList); // the main graph if (!graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) { - for (auto* op : all_ops) { + for (details::OpHandleBase* op : all_ops) { auto compute_op = dynamic_cast(op); if (compute_op && compute_op->Name() == "cinn_launch") { - ShareVarInfoToCinnLaunch(var_infos, compute_op); + ShareVarInfoToCinnLaunch(varinfo_maps, compute_op); } } - } else { // a subgraph - const auto& parent_var_infos = + } else { // a cinn subgraph + const auto& parent_varinfo_map = graph->Get(paddle2cinn::kMemOptVarInfoFromMainGraph); - for (auto* op : all_ops) { + for (details::OpHandleBase* op : all_ops) { auto eager_deletion_op = dynamic_cast(op); if (eager_deletion_op) { - TakeVarInfoFromMainGraph(parent_var_infos, var_infos, + TakeVarInfoFromMainGraph(parent_varinfo_map, varinfo_maps, eager_deletion_op); } } @@ -139,10 +140,8 @@ class ShareMemOptInfoToSubGraphPass : public ir::Pass { } }; -} // namespace ir -} // namespace framework -} // namespace paddle +} // namespace paddle::framework::ir -REGISTER_PASS(share_mem_opt_info_to_subgraph_pass, +REGISTER_PASS(share_varinfo_into_cinn_pass, paddle::framework::ir::ShareMemOptInfoToSubGraphPass) .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc similarity index 71% rename from paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc rename to paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc index 8b9a2a9f00ef4..abed6a5bd4bc4 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_mem_opt_info_to_subgraph_pass_test.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc @@ -27,8 +27,7 @@ USE_OP(mul); USE_OP(cinn_launch); USE_OP(elementwise_add); -namespace paddle { -namespace framework { +namespace paddle::framework { using Name2VarInfoMap = std::unordered_map>; @@ -85,12 +84,12 @@ struct TestPassContext { std::unique_ptr executor; }; -TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_var_info) { +TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) { // add a subgraph to CinnCompiler auto subgraph = std::make_unique(BuildProgramInsideCinnLaunchOp()); subgraph->GetOrInit( paddle2cinn::kMemOptVarInfoFromMainGraph); - auto compilation_key = + std::string compilation_key = paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph)); // build test data and apply pass @@ -98,47 +97,46 @@ TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_var_info) { BuildProgramWithCinnLaunchOp(compilation_key)); // check result - const auto& result_subgraph = + const ir::Graph& result_subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key); - const auto& shared_var_infos = result_subgraph.Get( + const auto& dst_varinfo_map = result_subgraph.Get( paddle2cinn::kMemOptVarInfoFromMainGraph); - ASSERT_EQ(shared_var_infos.size(), 4); - EXPECT_EQ(shared_var_infos.count("var1"), 1); - EXPECT_EQ(shared_var_infos.count("var5"), 1); - EXPECT_EQ(shared_var_infos.at("var1").use_count(), 2); - EXPECT_EQ(shared_var_infos.at("var5").use_count(), 2); + ASSERT_EQ(dst_varinfo_map.size(), 4); + EXPECT_EQ(dst_varinfo_map.count("var1"), 1); + EXPECT_EQ(dst_varinfo_map.count("var5"), 1); + EXPECT_EQ(dst_varinfo_map.at("var1").use_count(), 2); + EXPECT_EQ(dst_varinfo_map.at("var5").use_count(), 2); } -TEST(ShareMemInfoToSubGraphPassTest, test_sub_graph_take_var_info) { +TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) { // build test data and apply pass auto context = std::make_unique(BuildProgramInsideCinnLaunchOp()); - auto& var_infos_shared = context->graph->GetOrInit( + auto& varinfo_map_shared = context->graph->GetOrInit( paddle2cinn::kMemOptVarInfoFromMainGraph); - var_infos_shared = { + varinfo_map_shared = { {"var1", std::make_shared("var1", 1)}, {"var2", std::make_shared("var2", 2)}, }; - ir::MemOptVarInfoMapList mem_opt_var_infos(1); - auto& var_infos = mem_opt_var_infos.front(); - var_infos = {{"var1", std::make_shared("var1", 1)}, - {"var2", std::make_shared("var2", 1)}, - {"var3", std::make_shared("var3", 1)}, - {"var4", std::make_shared("var4", 1)}, - {"var5", std::make_shared("var5", 1)}}; + ir::MemOptVarInfoMapList varinfo_maps(1); + auto& dst_varinfo_map = varinfo_maps.front(); + dst_varinfo_map = {{"var1", std::make_shared("var1", 1)}, + {"var2", std::make_shared("var2", 1)}, + {"var3", std::make_shared("var3", 1)}, + {"var4", std::make_shared("var4", 1)}, + {"var5", std::make_shared("var5", 1)}}; auto share_pass = - ir::PassRegistry::Instance().Get("share_mem_opt_info_to_subgraph_pass"); - share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos); + ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass"); + share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps); share_pass->Apply(context->graph.get()); // check result - ASSERT_NE(var_infos.at("var1")->ParentHolder(), nullptr); - ASSERT_NE(var_infos.at("var2")->ParentHolder(), nullptr); - ASSERT_EQ(var_infos.at("var3")->ParentHolder(), nullptr); - ASSERT_EQ(var_infos.at("var4")->ParentHolder(), nullptr); - ASSERT_EQ(var_infos.at("var5")->ParentHolder(), nullptr); + ASSERT_NE(dst_varinfo_map.at("var1")->ParentHolder(), nullptr); + ASSERT_NE(dst_varinfo_map.at("var2")->ParentHolder(), nullptr); + ASSERT_EQ(dst_varinfo_map.at("var3")->ParentHolder(), nullptr); + ASSERT_EQ(dst_varinfo_map.at("var4")->ParentHolder(), nullptr); + ASSERT_EQ(dst_varinfo_map.at("var5")->ParentHolder(), nullptr); } -} // namespace framework -} // namespace paddle +} // namespace paddle::framework From f5e975b0672e08946e72a48ea8c57039f918f611 Mon Sep 17 00:00:00 2001 From: CtfGo Date: Wed, 9 Feb 2022 09:40:47 +0000 Subject: [PATCH 7/7] update cmake --- paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index 8343291816396..25b07ddf41414 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -2,9 +2,8 @@ cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc DEPS conditional_block_op_helper graph_helper pass computation_op_handle) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) - cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) -cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper ) +cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) SET(EAGER_DELETETION_PASS_DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) if (WITH_CINN)