Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

share MemOptVarInfos of external variables into cinn_launch subgraph #39209

Merged
merged 14 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions paddle/fluid/framework/details/eager_deletion_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ void EagerDeletionOpHandle::CallOnce() {

std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }

// Decides whether the variable described by `var_info` may be erased now.
//
// Returns false when the variable is excluded from all memory optimization,
// or when DecreaseRefCnt() reports false. Note that DecreaseRefCnt() is
// invoked as part of the check, so this is not a pure query (presumably it
// updates the runtime reference count -- confirm against MemOptVarInfo).
//
// When built with PADDLE_WITH_CINN and the var info has a parent holder, the
// parent must satisfy the same condition; the check is applied recursively
// along the chain of parent holders.
static bool CanBeErased(ir::MemOptVarInfo *var_info) {
  if (var_info->IsSkippedAllMemoryOptimization()) {
    return false;
  }
  if (!var_info->DecreaseRefCnt()) {
    return false;
  }
#ifdef PADDLE_WITH_CINN
  // If a parent holder exists, it must meet the deletion condition too.
  // Keep a shared_ptr copy so the parent stays alive during the check.
  std::shared_ptr<ir::MemOptVarInfo> parent_holder = var_info->ParentHolder();
  if (parent_holder) {
    return CanBeErased(parent_holder.get());
  }
#endif
  return true;
}

void EagerDeletionOpHandle::RunImpl() {
if (vars_.size() != var_infos_.size() || is_variant_scope_) {
vars_.clear();
Expand All @@ -117,8 +132,7 @@ void EagerDeletionOpHandle::RunImpl() {
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (size_t i = 0; i < var_infos_.size(); ++i) {
auto *var_info = var_infos_[i];
if (var_info->IsSkippedAllMemoryOptimization() ||
!var_info->DecreaseRefCnt()) {
if (!CanBeErased(var_info)) {
VLOG(4) << "skip memory optimization with var: " << var_info->Name();
continue;
}
Expand Down
14 changes: 10 additions & 4 deletions paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@ cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pas
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)

cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
SET(EAGER_DELETETION_PASS_DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
if (WITH_CINN)
cc_library(share_varinfo_into_cinn_pass SRCS share_varinfo_into_cinn_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler)
cc_test(share_varinfo_into_cinn_pass_test SRCS share_varinfo_into_cinn_pass_test.cc DEPS share_varinfo_into_cinn_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op)
list(APPEND EAGER_DELETETION_PASS_DEPS share_varinfo_into_cinn_pass)
endif()

cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS ${EAGER_DELETETION_PASS_DEPS})

cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)

cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)

cc_library(inplace_addto_op_pass SRCS inplace_addto_op_pass.cc DEPS memory_reuse_pass)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,13 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto recurrent_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass");
recurrent_op_eager_deletion_pass->Apply(graph);

#ifdef PADDLE_WITH_CINN
auto share_varinfo_into_cinn_pass =
ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
share_varinfo_into_cinn_pass->SetNotOwned(kMemOptVarInfoMapList, &var_infos);
share_varinfo_into_cinn_pass->Apply(graph);
#endif
}

} // namespace ir
Expand All @@ -300,3 +307,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
USE_PASS(conditional_block_op_eager_deletion_pass);
USE_PASS(while_op_eager_deletion_pass);
USE_PASS(recurrent_op_eager_deletion_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(share_varinfo_into_cinn_pass);
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class MemOptVarInfo {
return skip_memory_reuse_ || skip_all_memory_optimization_;
}

// Stores `parent` as the parent holder of this var info, replacing any
// previously stored one.
void SetParentHolder(std::shared_ptr<MemOptVarInfo> parent) {
  // `parent` is already a by-value copy owned by this call, so exchanging it
  // with the member installs the new holder; the previous holder (if any) is
  // released when `parent` goes out of scope.
  parent_holder_.swap(parent);
}

// Returns the stored parent holder; the result is empty (nullptr) when no
// parent has been set.
std::shared_ptr<MemOptVarInfo> ParentHolder() const {
  return parent_holder_;
}

const std::string &Name() const { return name_; }

private:
Expand All @@ -88,6 +94,9 @@ class MemOptVarInfo {
std::atomic<size_t> runtime_ref_cnt_;
bool skip_memory_reuse_{false};
bool skip_all_memory_optimization_{false};
// Points to the var info of the same variable in the main graph;
// used for the external (input/output) variables of a subgraph.
std::shared_ptr<MemOptVarInfo> parent_holder_{nullptr};
};

using MemOptVarInfoMapList = std::vector<
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle::framework::ir {

using Name2VarInfoMap =
std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>;

// Searches the ops pending on the control-variable outputs of `compute_op`
// and returns the first one that is an EagerDeletionOpHandle.
// Returns nullptr when no such op exists.
static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp(
    details::ComputationOpHandle* compute_op) {
  for (details::VarHandleBase* out_var : compute_op->Outputs()) {
    // Only control variables are inspected; other outputs are ignored.
    if (out_var->Node()->IsCtrlVar()) {
      for (details::OpHandleBase* pending_op : out_var->PendingOps()) {
        if (auto* found =
                dynamic_cast<details::EagerDeletionOpHandle*>(pending_op)) {
          return found;
        }
      }
    }
  }
  return nullptr;
}

static void ShareVarInfoToCinnLaunch(
const MemOptVarInfoMapList& varinfo_maps,
details::ComputationOpHandle* cinn_launch_op) {
details::EagerDeletionOpHandle* followed_eager_deletion_op =
FindFollowedEagerDeletionOp(cinn_launch_op);
if (!followed_eager_deletion_op) {
VLOG(4) << "No eager_deletion op found after this cinn_launch op";
return;
}

std::vector<std::string> vars_to_delete =
followed_eager_deletion_op->VarsToDelete();
if (vars_to_delete.empty()) {
VLOG(4) << "No var to be deleted after this cinn_launch op";
return;
}
VLOG(4) << "Variables would be deleted by the eager_deletion_op"
<< " following the cinn_launch:"
<< paddle::string::join_strings(vars_to_delete, ',');

const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph(
cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey));
auto& dst_varinfo_map =
subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
const Name2VarInfoMap& src_varinfo_map =
varinfo_maps.at(cinn_launch_op->GetScopeIdx());

// collect all MemOptVarInfos of external variables
// that would be eager deleted after the cinn_launch subgraph executed,
// and store them as attribute of the subgraph
for (const auto& var_name : vars_to_delete) {
auto it = src_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(it, src_varinfo_map.end(),
platform::errors::NotFound(
"MemOptVarInfo of var[%s] not found", var_name));
dst_varinfo_map.emplace(var_name, it->second);
}
}

static void TakeVarInfoFromMainGraph(
const Name2VarInfoMap& src_varinfo_map,
const MemOptVarInfoMapList& varinfo_maps,
details::EagerDeletionOpHandle* eager_deletion_op) {
const Name2VarInfoMap& dst_varinfo_map =
varinfo_maps.at(eager_deletion_op->GetScopeIdx());
for (auto&& var_name : eager_deletion_op->VarsToDelete()) {
auto dst_it = dst_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(dst_it, dst_varinfo_map.end(),
platform::errors::NotFound(
"MemOptVarInfo of var[%s] not found", var_name));
auto src_it = src_varinfo_map.find(var_name);
if (src_it != src_varinfo_map.end()) {
VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder";
dst_it->second->SetParentHolder(src_it->second);
}
}
}

// This pass is applied on both the main graph and all cinn subgraphs.
// It distinguishes them by whether the graph has the
// kMemOptVarInfoFromMainGraph attribute or not.
// On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos
// with their subgraphs.
// On a cinn subgraph, it iterates over each variable that will be deleted by
// an eager_deletion op, and takes the MemOptVarInfo from the main graph
// if one is found.
class ShareMemOptInfoToSubGraphPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph* graph) const override {
    auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
    const auto& varinfo_maps = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);

    if (graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) {
      // A cinn subgraph: attach the shared var infos as parent holders.
      const auto& parent_varinfo_map = graph->Get<Name2VarInfoMap>(
          paddle2cinn::kMemOptVarInfoFromMainGraph);
      for (details::OpHandleBase* op : all_ops) {
        if (auto* gc_op = dynamic_cast<details::EagerDeletionOpHandle*>(op)) {
          TakeVarInfoFromMainGraph(parent_varinfo_map, varinfo_maps, gc_op);
        }
      }
      return;
    }

    // The main graph: share var infos with every cinn_launch subgraph.
    for (details::OpHandleBase* op : all_ops) {
      auto* compute_op = dynamic_cast<details::ComputationOpHandle*>(op);
      if (compute_op == nullptr || compute_op->Name() != "cinn_launch") {
        continue;
      }
      ShareVarInfoToCinnLaunch(varinfo_maps, compute_op);
    }
  }
};

} // namespace paddle::framework::ir

REGISTER_PASS(share_varinfo_into_cinn_pass,
paddle::framework::ir::ShareMemOptInfoToSubGraphPass)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList);
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"

USE_OP(mul);
USE_OP(cinn_launch);
USE_OP(elementwise_add);
namespace paddle::framework {

using Name2VarInfoMap =
std::unordered_map<std::string, std::shared_ptr<ir::MemOptVarInfo>>;

// Builds the program that plays the role of a cinn_launch subgraph in the
// tests: var3 = elementwise_add(var1, var2); var5 = mul(var3, var4).
static ProgramDesc BuildProgramInsideCinnLaunchOp() {
  ProgramDesc program;
  auto* block = program.MutableBlock(0);

  // Declare every variable referenced by the two ops below.
  const char* const var_names[] = {"var1", "var2", "var3", "var4", "var5"};
  for (const char* var_name : var_names) {
    block->Var(var_name);
  }

  std::unique_ptr<OpDesc> add_op(
      new OpDesc("elementwise_add", {{"X", {"var1"}}, {"Y", {"var2"}}},
                 {{"Out", {"var3"}}}, {}));
  block->AppendAllocatedOp(std::move(add_op));

  std::unique_ptr<OpDesc> mul_op(new OpDesc(
      "mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {}));
  block->AppendAllocatedOp(std::move(mul_op));

  return program;
}

// Builds a main program made of a single cinn_launch op that reads
// var1/var2/var4, writes var5, and carries `compilation_key` as its
// "compilation_key" attribute.
static ProgramDesc BuildProgramWithCinnLaunchOp(
    const std::string& compilation_key) {
  ProgramDesc program;
  auto* block = program.MutableBlock(0);

  // Declare the external variables of the cinn_launch op.
  const char* const var_names[] = {"var1", "var2", "var4", "var5"};
  for (const char* var_name : var_names) {
    block->Var(var_name);
  }

  std::unique_ptr<OpDesc> launch_op(
      new OpDesc("cinn_launch", {{"X", {"var1", "var2", "var4"}}},
                 {{"Out", {"var5"}}}, {{"compilation_key", compilation_key}}));
  block->AppendAllocatedOp(std::move(launch_op));

  return program;
}

struct TestPassContext {
explicit TestPassContext(const ProgramDesc& program) {
graph = std::make_unique<ir::Graph>(program);
details::BuildStrategy build_strategy;
details::ExecutionStrategy exec_strategy;
exec_strategy.use_device_ = paddle::platform::kCUDA;
executor.reset(new ParallelExecutor(platform::CUDAPlace(0), &scope,
exec_strategy, build_strategy,
graph.get()));
}

Scope scope;
std::unique_ptr<ir::Graph> graph;
std::unique_ptr<ParallelExecutor> executor;
};

// Main-graph side: after a ParallelExecutor is built over a program holding
// a cinn_launch op, the registered subgraph is expected to have received the
// MemOptVarInfos of the external variables.
TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) {
  // Register a subgraph (with an empty shared var info map) in CinnCompiler.
  auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp());
  subgraph->GetOrInit<Name2VarInfoMap>(
      paddle2cinn::kMemOptVarInfoFromMainGraph);
  auto compiler = paddle2cinn::CinnCompiler::GetInstance();
  std::string compilation_key = compiler->AddGraph(std::move(subgraph));

  // Building the context constructs the executor over the main program.
  TestPassContext context(BuildProgramWithCinnLaunchOp(compilation_key));

  // Inspect what ended up in the attribute of the registered subgraph.
  const ir::Graph& result_subgraph = compiler->FindGraph(compilation_key);
  const auto& shared_infos = result_subgraph.Get<Name2VarInfoMap>(
      paddle2cinn::kMemOptVarInfoFromMainGraph);
  ASSERT_EQ(shared_infos.size(), 4);
  EXPECT_EQ(shared_infos.count("var1"), 1);
  EXPECT_EQ(shared_infos.count("var5"), 1);
  // Each shared entry is expected to have exactly two owners.
  EXPECT_EQ(shared_infos.at("var1").use_count(), 2);
  EXPECT_EQ(shared_infos.at("var5").use_count(), 2);
}

// Subgraph side: when the graph carries the kMemOptVarInfoFromMainGraph
// attribute, applying the pass is expected to set a parent holder on exactly
// those var infos whose names appear in that attribute.
TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) {
  // Build the graph and mark it as a subgraph by attaching shared var infos.
  TestPassContext context(BuildProgramInsideCinnLaunchOp());
  auto& varinfo_map_shared = context.graph->GetOrInit<Name2VarInfoMap>(
      paddle2cinn::kMemOptVarInfoFromMainGraph);
  varinfo_map_shared = {
      {"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)},
      {"var2", std::make_shared<ir::MemOptVarInfo>("var2", 2)},
  };

  // Var infos owned by the subgraph itself: one per variable, ref count 1.
  ir::MemOptVarInfoMapList varinfo_maps(1);
  auto& local_infos = varinfo_maps.front();
  const char* const var_names[] = {"var1", "var2", "var3", "var4", "var5"};
  for (const char* var_name : var_names) {
    local_infos.emplace(var_name,
                        std::make_shared<ir::MemOptVarInfo>(var_name, 1));
  }

  auto share_pass =
      ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
  share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps);
  share_pass->Apply(context.graph.get());

  // Only var1 and var2 were shared, so only they get a parent holder.
  ASSERT_NE(local_infos.at("var1")->ParentHolder(), nullptr);
  ASSERT_NE(local_infos.at("var2")->ParentHolder(), nullptr);
  ASSERT_EQ(local_infos.at("var3")->ParentHolder(), nullptr);
  ASSERT_EQ(local_infos.at("var4")->ParentHolder(), nullptr);
  ASSERT_EQ(local_infos.at("var5")->ParentHolder(), nullptr);
}

} // namespace paddle::framework
Loading