Skip to content

Commit

Permalink
[PIR] Add unittest for Operation::Clone and Group::Clone (#60577)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored Jan 5, 2024
1 parent a9712d1 commit 75e62a2
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 8 deletions.
12 changes: 6 additions & 6 deletions paddle/cinn/hlir/framework/pir/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,22 @@ struct Group {
// Mapper from original to new ops.
std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper;
::pir::CloneOptions clone_options(false, true);
for (auto* op : this->ops_set) {
for (auto* op : ops) {
VLOG(4) << "clone op :" << op->name();
auto* new_op = op->Clone(ir_mapping, clone_options);
// NOTE(dev): Must call MoveTo to deal with ownership, otherwise it
// NOTE(dev): Must call block.insert to deal with ownership, otherwise it
// will lead memory-leak.
new_op->MoveTo(target_block, target_block->end());
target_block->insert(target_block->end(), new_op);
new_ops.push_back(new_op);
ops_mapper[op] = new_op;
}
// Construct Base information for new Group
auto new_group = std::make_shared<Group>(new_ops);
this->CollectOps();
for (auto& iter : this->input_ops) {
new_group->input_ops[ops_mapper[iter.first]] = iter.second;
new_group->input_ops[ops_mapper.at(iter.first)] = iter.second;
}
for (auto* op : this->output_ops) {
new_group->output_ops.insert(ops_mapper[op]);
new_group->output_ops.insert(ops_mapper.at(op));
}

return new_group;
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
IR_ENFORCE(options.IsCloneRegions() || num_regions_ > 0,
"Operation CloneOperands is unimplemented currently.");
IR_ENFORCE(!options.IsCloneRegions() || num_regions_ <= 0,
"Operation CloneRegions is unimplemented currently.");
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Expand Down
110 changes: 110 additions & 0 deletions test/cpp/pir/cinn/group_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h"
#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand Down Expand Up @@ -243,3 +244,112 @@ TEST(GroupOp, CINNLowering) {
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}

class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
public:
using pir::OpRewritePattern<cinn::dialect::GroupOp>::OpRewritePattern;
using Group = cinn::hlir::framework::pir::Group;

bool MatchAndRewrite(cinn::dialect::GroupOp group_op,
pir::PatternRewriter& rewriter) const override {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
auto* program = group_op->GetParentProgram();
::pir::Builder builder = ::pir::Builder(ctx, program->block());
VLOG(4) << "Before GroupOpPattern: " << *program;
std::vector<::pir::Operation*> group_ops = group_op.ops();
auto yeild_op = group_ops.back();
std::vector<::pir::Type> output_type{yeild_op->operand_source(0).type()};

// construct hlir::Group
Group group({group_ops.begin(), group_ops.end() - 1});
group.input_ops[group_ops[0]] = 0; // first tan
auto last_op_idx = group_ops.size() - 2;
group.output_ops.insert(group_ops[last_op_idx]); // last relu

// clone group and sync their op into new GroupOp
builder.SetInsertionPointAfter(group_op.operation());
auto new_group_op = builder.Build<cinn::dialect::GroupOp>(output_type);

// prepare IrMapping
::pir::IrMapping ir_mapping;
auto depend_value = group_ops[0]->operand_source(0);
ir_mapping.Add(depend_value, depend_value);
std::shared_ptr<Group> new_group =
group.Clone(new_group_op.block(), ir_mapping);

EXPECT_EQ(new_group->ops.size(), group.ops.size());
EXPECT_EQ(new_group->input_ops.size(), group.input_ops.size());
EXPECT_EQ(new_group->output_ops.size(), group.output_ops.size());

// Add yield op
builder.SetInsertionPointToBlockEnd(new_group_op.block());
std::vector<::pir::Value> yield_inputs{
new_group_op.ops().back()->result(0)};
builder.Build<::pir::YieldOp>(yield_inputs);
EXPECT_EQ(new_group_op.ops().size(), group_ops.size());

// replace result UD between GroupOp
rewriter.ReplaceAllUsesWith(group_op->result(0), new_group_op->result(0));
rewriter.EraseOp(group_op);
VLOG(4) << "After GroupOpPattern.EraseOp: " << *program;
return true;
}
};

class TestGroupClonePass : public pir::PatternRewritePass {
public:
TestGroupClonePass() : pir::PatternRewritePass("test_group_clone", 1) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<GroupOpPattern>(context);

return ps;
}

bool CanApplyOn(pir::Operation* op) const override {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}
};

std::shared_ptr<::pir::Program> BuildSingleGroupProgram() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>();

auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());
const std::vector<int64_t> shape = {64, 128};
// full op
auto full_x = builder.Build<paddle::dialect::FullOp>(
shape, 0.5, phi::DataType::FLOAT32, phi::GPUPlace());

// group op
auto group_op = builder.Build<cinn::dialect::GroupOp>(
CreateDenseTensorTypes(common::make_ddim(shape)));
pir::Block* block = group_op.block();
builder.SetInsertionPointToBlockEnd(block);

auto tan_op_x = builder.Build<paddle::dialect::TanOp>(full_x->result(0));
auto relu_op_x = builder.Build<paddle::dialect::ReluOp>(tan_op_x->result(0));
auto tan_op_y = builder.Build<paddle::dialect::TanOp>(relu_op_x->result(0));
auto relu_op_y = builder.Build<paddle::dialect::ReluOp>(tan_op_y->result(0));
builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()});

// tan op
builder.SetInsertionPointToBlockEnd(program->block());
auto final_op = builder.Build<paddle::dialect::TanOp>(group_op->result(0));

return program;
}

TEST(Group, Clone) {
// Step 1: Construct pir::Program
std::shared_ptr<::pir::Program> program = BuildSingleGroupProgram();
::pir::IrContext* ctx = ::pir::IrContext::Instance();
::pir::PassManager pm(ctx);
// Step 2: Run TestGroupClonePass
pm.AddPass(std::make_unique<TestGroupClonePass>());
pm.Run(program.get());
}

0 comments on commit 75e62a2

Please sign in to comment.