Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Add unittest for Operation::Clone and Group::Clone #60577

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions paddle/cinn/hlir/framework/pir/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,22 @@ struct Group {
// Mapper from original to new ops.
std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper;
::pir::CloneOptions clone_options(false, true);
for (auto* op : this->ops_set) {
for (auto* op : ops) {
VLOG(4) << "clone op :" << op->name();
auto* new_op = op->Clone(ir_mapping, clone_options);
// NOTE(dev): Must call MoveTo to deal with ownership, otherwise it
// NOTE(dev): Must call block.insert to deal with ownership, otherwise it
// will lead memory-leak.
new_op->MoveTo(target_block, target_block->end());
target_block->insert(target_block->end(), new_op);
new_ops.push_back(new_op);
ops_mapper[op] = new_op;
}
// Construct Base information for new Group
auto new_group = std::make_shared<Group>(new_ops);
this->CollectOps();
for (auto& iter : this->input_ops) {
new_group->input_ops[ops_mapper[iter.first]] = iter.second;
new_group->input_ops[ops_mapper.at(iter.first)] = iter.second;
}
for (auto* op : this->output_ops) {
new_group->output_ops.insert(ops_mapper[op]);
new_group->output_ops.insert(ops_mapper.at(op));
}

return new_group;
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
IR_ENFORCE(options.IsCloneRegions() || num_regions_ > 0,
"Operation CloneOperands is unimplemented currently.");
IR_ENFORCE(!options.IsCloneRegions() || num_regions_ <= 0,
"Operation CloneRegions is unimplemented currently.");
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Expand Down
110 changes: 110 additions & 0 deletions test/cpp/pir/cinn/group_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h"
#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand Down Expand Up @@ -243,3 +244,112 @@ TEST(GroupOp, CINNLowering) {
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}

class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
public:
using pir::OpRewritePattern<cinn::dialect::GroupOp>::OpRewritePattern;
using Group = cinn::hlir::framework::pir::Group;

bool MatchAndRewrite(cinn::dialect::GroupOp group_op,
pir::PatternRewriter& rewriter) const override {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
auto* program = group_op->GetParentProgram();
::pir::Builder builder = ::pir::Builder(ctx, program->block());
VLOG(4) << "Before GroupOpPattern: " << *program;
std::vector<::pir::Operation*> group_ops = group_op.ops();
auto yeild_op = group_ops.back();
std::vector<::pir::Type> output_type{yeild_op->operand_source(0).type()};

// construct hlir::Group
Group group({group_ops.begin(), group_ops.end() - 1});
group.input_ops[group_ops[0]] = 0; // first tan
auto last_op_idx = group_ops.size() - 2;
group.output_ops.insert(group_ops[last_op_idx]); // last relu

// clone group and sync their op into new GroupOp
builder.SetInsertionPointAfter(group_op.operation());
auto new_group_op = builder.Build<cinn::dialect::GroupOp>(output_type);

// prepare IrMapping
::pir::IrMapping ir_mapping;
auto depend_value = group_ops[0]->operand_source(0);
ir_mapping.Add(depend_value, depend_value);
std::shared_ptr<Group> new_group =
group.Clone(new_group_op.block(), ir_mapping);

EXPECT_EQ(new_group->ops.size(), group.ops.size());
EXPECT_EQ(new_group->input_ops.size(), group.input_ops.size());
EXPECT_EQ(new_group->output_ops.size(), group.output_ops.size());

// Add yield op
builder.SetInsertionPointToBlockEnd(new_group_op.block());
std::vector<::pir::Value> yield_inputs{
new_group_op.ops().back()->result(0)};
builder.Build<::pir::YieldOp>(yield_inputs);
EXPECT_EQ(new_group_op.ops().size(), group_ops.size());

// replace result UD between GroupOp
rewriter.ReplaceAllUsesWith(group_op->result(0), new_group_op->result(0));
rewriter.EraseOp(group_op);
VLOG(4) << "After GroupOpPattern.EraseOp: " << *program;
return true;
}
};

class TestGroupClonePass : public pir::PatternRewritePass {
public:
TestGroupClonePass() : pir::PatternRewritePass("test_group_clone", 1) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<GroupOpPattern>(context);

return ps;
}

bool CanApplyOn(pir::Operation* op) const override {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}
};

std::shared_ptr<::pir::Program> BuildSingleGroupProgram() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>();

auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());
const std::vector<int64_t> shape = {64, 128};
// full op
auto full_x = builder.Build<paddle::dialect::FullOp>(
shape, 0.5, phi::DataType::FLOAT32, phi::GPUPlace());

// group op
auto group_op = builder.Build<cinn::dialect::GroupOp>(
CreateDenseTensorTypes(common::make_ddim(shape)));
pir::Block* block = group_op.block();
builder.SetInsertionPointToBlockEnd(block);

auto tan_op_x = builder.Build<paddle::dialect::TanOp>(full_x->result(0));
auto relu_op_x = builder.Build<paddle::dialect::ReluOp>(tan_op_x->result(0));
auto tan_op_y = builder.Build<paddle::dialect::TanOp>(relu_op_x->result(0));
auto relu_op_y = builder.Build<paddle::dialect::ReluOp>(tan_op_y->result(0));
builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()});

// tan op
builder.SetInsertionPointToBlockEnd(program->block());
auto final_op = builder.Build<paddle::dialect::TanOp>(group_op->result(0));

return program;
}

TEST(Group, Clone) {
// Step 1: Construct pir::Program
std::shared_ptr<::pir::Program> program = BuildSingleGroupProgram();
::pir::IrContext* ctx = ::pir::IrContext::Instance();
::pir::PassManager pm(ctx);
// Step 2: Run TestGroupClonePass
pm.AddPass(std::make_unique<TestGroupClonePass>());
pm.Run(program.get());
}