
Commit da03c88

fix conflict
Hanyonggong committed Nov 13, 2024
2 parents 57c3bbf + e97b2e1 commit da03c88
Showing 218 changed files with 4,347 additions and 1,013 deletions.
paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc (3 additions, 0 deletions)
@@ -49,6 +49,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/reduce_as_to_sum_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_assign_out_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
@@ -120,6 +121,7 @@ void ApplyPdToCinnPass(
const std::function<std::shared_ptr<::pir::PassManager>()>&
CreatePassManager) {
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
pass_manager->AddPass(cinn::dialect::ir::CreateReduceAsToSumPass());
pass_manager->AddPass(pir::CreateFusedGemmEpiloguePass());
if (FLAGS_enable_fuse_parallel_matmul_pass) {
pass_manager->AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass());
@@ -132,6 +134,7 @@ void ApplyPdToCinnPass(

pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());

// pass_manager->EnableIRPrinting();
pass_manager->Run(program);
}
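For context: the change above registers the new ReduceAsToSumPass as the first rewrite in ApplyPdToCinnPass, ahead of the fused-GEMM and parallel-matmul fusion passes. The sketch below only illustrates why registration order matters in such a pipeline; it uses hypothetical Program/Pass/PassManager types for a standalone example and is not the real pir API.

// Standalone sketch of an ordered pass pipeline (hypothetical types, illustration only).
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Program { std::string ir; };

struct Pass {
  std::string name;
  std::function<void(Program&)> run;
};

struct PassManager {
  std::vector<Pass> passes;
  void AddPass(Pass p) { passes.push_back(std::move(p)); }
  void Run(Program& program) {
    // Passes run in registration order, so a rewrite that later passes
    // rely on (here: reduce_as -> sum) has to be added first.
    for (auto& p : passes) p.run(program);
  }
};

int main() {
  Program program{"reduce_as(x, y)"};
  PassManager pm;
  pm.AddPass({"reduce_as_to_sum",
              [](Program& prog) { prog.ir = "squeeze(sum(x, axes), lead_axes)"; }});
  pm.AddPass({"dead_code_elimination", [](Program&) { /* no-op in this sketch */ }});
  pm.Run(program);
  std::cout << program.ir << "\n";  // prints the rewritten IR
}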

@@ -67,14 +67,15 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
return false;
}

// The FusionOp may have more results than the single inner op when that op
// has multiple downstream ops, including the yield_store op.
auto num_results = paddle_op.value()->num_results();
for (size_t i = 0; i * num_results < fusion_op.num_results(); i++) {
for (size_t j = 0; j < num_results; ++j) {
rewriter.ReplaceAllUsesWith(fusion_op.result(i * num_results + j),
paddle_op.value()->result(j));
}
// TODO(phlrain): support multi output
PADDLE_ENFORCE_EQ(
paddle_op.value()->num_results(),
1u,
::common::errors::PreconditionNotMet("Only support ONE output op"));

for (size_t i = 0; i < fusion_op.num_results(); ++i) {
rewriter.ReplaceAllUsesWith(fusion_op.result(i),
paddle_op.value()->result(0));
}

rewriter.EraseOp(fusion_op);
@@ -119,17 +120,48 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
return paddle_reshape;
}

pir::Operation* AssignOutOpPattern(
pir::Operation* op,
pir::PatternRewriter& rewriter) const { // NOLINT
PADDLE_ENFORCE(
op->isa<paddle::dialect::AssignOut_Op>(),
::common::errors::InvalidArgument(
"Input should be paddle::dialect::AssignOut_Op, but got %s",
op->name()));
auto assign_out_op = op->dyn_cast<paddle::dialect::AssignOut_Op>();

auto paddle_assign_out_ = rewriter.Build<paddle::dialect::AssignOut_Op>(
assign_out_op->operand_source(0), assign_out_op->operand_source(1));
return paddle_assign_out_;
}

pir::Operation* CastOpPattern(
pir::Operation* op,
pir::PatternRewriter& rewriter) const { // NOLINT
PADDLE_ENFORCE(
op->isa<paddle::dialect::CastOp>(),
::common::errors::InvalidArgument(
"Input should be paddle::dialect::CastOp, but got %s", op->name()));
auto cast_op = op->dyn_cast<paddle::dialect::CastOp>();

auto paddle_cast_op = rewriter.Build<paddle::dialect::CastOp>(
cast_op->operand_source(0), cast_op->attributes());
return paddle_cast_op;
}

const std::unordered_map<std::string, CinnOpHandler>& op_handler_map() const {
static std::unordered_map<std::string, CinnOpHandler> handler_map = {
{cinn::dialect::ReshapeOp::name(), &FusionOpPattern::ReshapeOpPattern},
{paddle::dialect::AssignOut_Op::name(),
&FusionOpPattern::AssignOutOpPattern},
{paddle::dialect::CastOp::name(), &FusionOpPattern::CastOpPattern},
};
return handler_map;
}

std::optional<pir::Operation*> FallBackOp(
pir::Operation* op,
pir::PatternRewriter& rewriter) const { // NOLINT
std::cerr << "fall back op " << op->name() << std::endl;
auto it = op_handler_map().find(op->name());
if (it == op_handler_map().end()) {
VLOG(4) << "No fallback handler for op: " << op->name();
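The two new handlers above extend FusionOpPattern's fallback table: FallBackOp looks the op name up in op_handler_map() and, if an entry exists, invokes the matching member function, which rebuilds the op (here pd_op.assign_out_ or pd_op.cast) with the plain Paddle dialect builder. A self-contained sketch of that name-to-member-function dispatch follows; Op, Rewriter, the handler bodies, and the op-name strings are stand-ins for illustration, not the real pir types.

// Standalone sketch of the string -> member-function dispatch used by FallBackOp.
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

struct Op { std::string name; };
struct Rewriter {};  // stand-in for the pattern rewriter

class FallbackTable {
 public:
  // Member-function-pointer type, mirroring the CinnOpHandler alias in the pass.
  using Handler = Op* (FallbackTable::*)(Op*, Rewriter&) const;

  std::optional<Op*> FallBack(Op* op, Rewriter& rewriter) const {
    auto it = handler_map().find(op->name);
    if (it == handler_map().end()) {
      std::cout << "No fallback handler for op: " << op->name << "\n";
      return std::nullopt;
    }
    return (this->*(it->second))(op, rewriter);  // dispatch to the matched handler
  }

 private:
  Op* HandleCast(Op* op, Rewriter&) const { return op; }       // would rebuild the cast op
  Op* HandleAssignOut(Op* op, Rewriter&) const { return op; }  // would rebuild the assign_out_ op

  const std::unordered_map<std::string, Handler>& handler_map() const {
    static const std::unordered_map<std::string, Handler> table = {
        {"pd_op.cast", &FallbackTable::HandleCast},
        {"pd_op.assign_out_", &FallbackTable::HandleAssignOut},
    };
    return table;
  }
};

int main() {
  FallbackTable fallback;
  Rewriter rewriter;
  Op cast{"pd_op.cast"}, unknown{"pd_op.exp"};
  fallback.FallBack(&cast, rewriter);     // dispatches to HandleCast
  fallback.FallBack(&unknown, rewriter);  // prints the "no handler" message
}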
paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc (0 additions, 157 deletions)
@@ -1212,162 +1212,6 @@ class GatherOpPattern
}
};

class ReduceAsOpPattern
: public pir::OpRewritePattern<paddle::dialect::ReduceAsOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ReduceAsOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::ReduceAsOp op,
pir::PatternRewriter &rewriter) const override {
auto x_shape =
phi::vectorize(op->operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims());

auto y_shape =
phi::vectorize(op->operand_source(1)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims());

size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
int64_t compare_offset = x_rank - y_rank;
std::vector<int64_t> reduce_axis;

std::vector<int64_t> squeeze_axis;
for (int64_t i = 0; i < compare_offset; ++i) {
reduce_axis.push_back(i);
squeeze_axis.push_back(i);
}

bool x_y_shape_equal = false;
std::vector<symbol::DimExpr> output_dims;
bool is_static_shape = IsStaticShape(x_shape, y_shape);
if (is_static_shape) {
x_y_shape_equal = (x_shape == y_shape);
ProcessStaticShape(x_shape, y_shape, &reduce_axis);
} else {
bool can_replace =
ProcessDynamicShape(op, &reduce_axis, &output_dims, &x_y_shape_equal);
if (!can_replace) {
return true;
}
}
if (x_y_shape_equal) {
rewriter.ReplaceAllUsesWith(op.result(0), op.operand_source(0));
return true;
}

auto pir_dtype =
op->operand_source(0).type().dyn_cast<pir::DenseTensorType>().dtype();
auto phi_dtype = paddle::dialect::TransToPhiDataType(pir_dtype);
auto sum_op = rewriter.Build<paddle::dialect::SumOp>(
op.operand_source(0), reduce_axis, phi_dtype, true);

auto new_output = sum_op.result(0);

if (!is_static_shape) {
auto &shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
shape_analysis.SetShapeOrDataForValue(
new_output,
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(output_dims)});
}

if (squeeze_axis.size() > 0) {
new_output =
rewriter.Build<paddle::dialect::SqueezeOp>(new_output, squeeze_axis)
.result(0);
if (!is_static_shape) {
auto &shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
shape_analysis.SetShapeOrDataForValue(
new_output,
shape_analysis.GetShapeOrDataForValue(op->operand_source(1)));
}
}

rewriter.ReplaceAllUsesWith(op.result(0), new_output);

rewriter.EraseOp(op);

return true;
}

private:
bool IsStaticShape(const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &y_shape) const {
bool x_has_dynamic_shape =
std::find(x_shape.begin(), x_shape.end(), -1) != x_shape.end();
bool y_has_dynamic_shape =
std::find(y_shape.begin(), y_shape.end(), -1) != y_shape.end();

return (!x_has_dynamic_shape) && (!y_has_dynamic_shape);
}

void ProcessStaticShape(const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &y_shape,
std::vector<int64_t> *reduce_axis) const {
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();

// Get the reduce axes.
int64_t compare_offset = x_rank - y_rank;

for (size_t i = 0; i < y_rank; ++i) {
if (y_shape[i] == 1 && x_shape[i + compare_offset] != 1) {
reduce_axis->push_back(compare_offset + i);
}
}
}
bool ProcessDynamicShape(paddle::dialect::ReduceAsOp op,
std::vector<int64_t> *reduce_axis,
std::vector<symbol::DimExpr> *output_dims,
bool *x_y_shape_equal) const {
auto &shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());

const auto &x_shape =
shape_analysis.GetShapeOrDataForValue(op->operand_source(0)).shape();
const auto &y_shape =
shape_analysis.GetShapeOrDataForValue(op->operand_source(1)).shape();

if (x_shape == y_shape) {
*x_y_shape_equal = true;
return true;
} else {
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();

int64_t compare_offset = x_rank - y_rank;
bool can_replace_with_sum = true;
for (int64_t i = 0; i < compare_offset; ++i) {
output_dims->push_back(symbol::DimExpr(1));
}

for (size_t i = 0; i < y_rank; ++i) {
bool x_dim_i_eq_one = x_shape[i + compare_offset].isa<int64_t>() &&
x_shape[i + compare_offset].Get<int64_t>() == 1;
bool y_dim_i_eq_one =
y_shape[i].isa<int64_t>() && y_shape[i].Get<int64_t>() == 1;
if (y_dim_i_eq_one && (!x_dim_i_eq_one)) {
reduce_axis->push_back(compare_offset + i);
output_dims->push_back(symbol::DimExpr(1));
} else if (x_shape[i + compare_offset] != y_shape[i]) {
can_replace_with_sum = false;
break;
} else {
output_dims->push_back(y_shape[i]);
}
}
return can_replace_with_sum;
}
}
};

PdOpToCinnOpPass::PdOpToCinnOpPass()
: pir::PatternRewritePass("pd_to_cinn_pass", 1) {}

@@ -1400,7 +1244,6 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<SigmoidOpPattern>(context);
ps.Add<GatherOpPattern>(context);
ps.Add<FlattenOpPattern>(context);
ps.Add<ReduceAsOpPattern>(context);

return ps;
}
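For reference, the ReduceAsOpPattern removed above rewrote reduce_as(x, target) as a keep-dim sum over the broadcast axes, followed by a squeeze of the leading axes; that rewrite presumably now lives in the standalone ReduceAsToSumPass registered in add_cinn_pass.cc. Below is a minimal sketch of the static-shape axis computation, mirroring ProcessStaticShape; the free function and shapes are hypothetical, for illustration only.

// Sketch: which axes does reduce_as(x, target) sum over when both shapes are static?
#include <cstdint>
#include <iostream>
#include <vector>

// Returns the reduction axes in x's coordinates: every leading axis that the
// lower-rank target lacks, plus each aligned axis where the target has extent 1
// but x does not. Assumes the target shape is broadcast-compatible with x.
std::vector<int64_t> ReduceAsAxes(const std::vector<int64_t>& x_shape,
                                  const std::vector<int64_t>& y_shape) {
  const int64_t offset = static_cast<int64_t>(x_shape.size()) -
                         static_cast<int64_t>(y_shape.size());
  std::vector<int64_t> axes;
  for (int64_t i = 0; i < offset; ++i) axes.push_back(i);  // leading axes, later squeezed
  for (int64_t i = 0; i < static_cast<int64_t>(y_shape.size()); ++i) {
    if (y_shape[i] == 1 && x_shape[i + offset] != 1) axes.push_back(i + offset);
  }
  return axes;
}

int main() {
  // x: [4, 8, 16], target: [1, 16] -> sum over axes {0, 1} with keepdim, then squeeze axis 0.
  for (int64_t axis : ReduceAsAxes({4, 8, 16}, {1, 16})) std::cout << axis << ' ';
  std::cout << "\n";  // prints: 0 1
}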
(Diffs for the remaining changed files are not shown.)
