Refine colreduction fusion strategy of kStitch. (alibaba#1257)
yunzhongOvO authored Oct 9, 2023
1 parent da990f4 commit c46084d
Showing 5 changed files with 44 additions and 76 deletions.
27 changes: 0 additions & 27 deletions .github/workflows/pytorch_pre_cpu.yml

This file was deleted.

28 changes: 0 additions & 28 deletions .github/workflows/pytorch_pre_gpu.yml

This file was deleted.

28 changes: 20 additions & 8 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -1475,16 +1475,28 @@ bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
                                     FusionPattern& lhs, FusionPattern& rhs,
                                     FusionPattern& target) {
   // TODO(Yancey): support fusion with different reduction type
-  bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
-    return isRank2RowReduction(op);
-  });
-  bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
-    return isRank2ColReduction(op);
-  });
-
-  if (has_row_reduction && has_col_reduciton) {
+  bool has_rank2_row_reduction =
+      llvm::any_of(target.getOpList(),
+                   [](Operation* op) { return isRank2RowReduction(op); });
+  bool has_rank2_col_reduction =
+      llvm::any_of(target.getOpList(),
+                   [](Operation* op) { return isRank2ColReduction(op); });
+
+  if (has_rank2_row_reduction && has_rank2_col_reduction) {
     return false;
   }
 
+  if (has_rank2_col_reduction) {
+    const auto& results = target.getResults();
+    auto ref_shape = getEffectiveShape(target, results[0]);
+    if (llvm::any_of(results, [&](Value result) {
+          auto op = target.findLastWriter(result);
+          return isa<lmhlo::TransposeOp>(op);
+        })) {
+      return false;
+    }
+  }
+
   return BaseFusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
 }
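The guard added here rejects a candidate fusion whose op list mixes rank-2 row and column reductions, and additionally rejects a rank-2 column-reduction fusion when any of its results is last written by an lmhlo::TransposeOp. Below is a minimal, self-contained sketch of that two-step guard, with stand-in op kinds in place of the real isRank2RowReduction / isRank2ColReduction helpers and MLIR operations (illustrative only, not code from this repository):

#include <algorithm>
#include <vector>

// Stand-ins for the op categories the real helpers distinguish.
enum class OpKind { kRowReduce2D, kColReduce2D, kTranspose, kElementwise };

// Mirrors the structure of the guard: reject mixed rank-2 row/col reductions,
// and reject col-reduction patterns whose results are written by a transpose.
bool tryFuseGuard(const std::vector<OpKind>& ops,
                  const std::vector<OpKind>& result_writers) {
  bool has_rank2_row_reduction =
      std::any_of(ops.begin(), ops.end(),
                  [](OpKind k) { return k == OpKind::kRowReduce2D; });
  bool has_rank2_col_reduction =
      std::any_of(ops.begin(), ops.end(),
                  [](OpKind k) { return k == OpKind::kColReduce2D; });

  if (has_rank2_row_reduction && has_rank2_col_reduction) return false;

  if (has_rank2_col_reduction &&
      std::any_of(result_writers.begin(), result_writers.end(),
                  [](OpKind k) { return k == OpKind::kTranspose; }))
    return false;

  return true;  // in the real code, fall through to the base strategy
}

The llvm::any_of used in the diff is LLVM's range-based wrapper with the same semantics as std::any_of above.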

33 changes: 22 additions & 11 deletions tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
@@ -59,16 +59,28 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
                                       FusionPattern& lhs, FusionPattern& rhs,
                                       FusionPattern& target) {
   // TODO(Yancey): support fusion with different reduction type
-  bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
-    return isRank2RowReduction(op);
-  });
-  bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
-    return isRank2ColReduction(op);
-  });
-
-  if (has_row_reduction && has_col_reduciton) {
+  bool has_rank2_row_reduction =
+      llvm::any_of(target.getOpList(),
+                   [](Operation* op) { return isRank2RowReduction(op); });
+  bool has_rank2_col_reduction =
+      llvm::any_of(target.getOpList(),
+                   [](Operation* op) { return isRank2ColReduction(op); });
+
+  if (has_rank2_row_reduction && has_rank2_col_reduction) {
     return false;
   }
 
+  if (has_rank2_col_reduction) {
+    const auto& results = target.getResults();
+    auto ref_shape = getEffectiveShape(target, results[0]);
+    if (llvm::any_of(results, [&](Value result) {
+          auto op = target.findLastWriter(result);
+          return isa<lmhlo::TransposeOp>(op);
+        })) {
+      return false;
+    }
+  }
+
   return FusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
 }

@@ -428,9 +440,8 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot(
           return true;
         }
         Value shape = getEffectiveShape(fusion_pattern, result);
-        return isRank2ColReduction(op)
-                   ? shapeAnalysis.isShapeEqual(ref_shape, shape)
-                   : shapeAnalysis.isSameNumElements(ref_shape, shape);
+        return isRank2ColReduction(op) &&
+               shapeAnalysis.isShapeEqual(ref_shape, shape);
       })) {
     return false;
   }
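For context on the tightened check above: previously, a result whose last writer was not a rank-2 column reduction could still pass with merely the same number of elements as the reference shape; after this change, a result reaching this check passes only if its last writer is a rank-2 column reduction and its shape exactly equals the reference. A small sketch of why element-count equality is the weaker condition (shapes invented for illustration, helpers not from this repository):

#include <cstdint>
#include <vector>

// Same element count does not imply the same shape.
bool sameNumElements(const std::vector<int64_t>& a,
                     const std::vector<int64_t>& b) {
  int64_t na = 1, nb = 1;
  for (int64_t d : a) na *= d;
  for (int64_t d : b) nb *= d;
  return na == nb;
}

bool shapeEqual(const std::vector<int64_t>& a,
                const std::vector<int64_t>& b) {
  return a == b;
}

// sameNumElements({128, 256}, {256, 128}) == true (32768 elements each),
// but shapeEqual({128, 256}, {256, 128}) == false.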
@@ -650,7 +650,7 @@ LogicalResult lowerWithScheduleColReduction(
 
   SmallVector<Value, 4> yield_values_for_if;
 
-  ValueRange load_index({row_index, col_index});
+  SmallVector<Value, 2> load_index({row_index, col_index});
   b.setInsertionPointToStart(&if_row_valid_op.getThenRegion().front());
   int col_red_root_op_idx = 0;
   for (auto* root_op : root_ops) {
@@ -924,7 +924,7 @@ LogicalResult lowerWithScheduleColReductionTileH(
 
   SmallVector<Value, 4> yield_values_for_if;
 
-  ValueRange load_index({row_index, col_index});
+  SmallVector<Value, 2> load_index({row_index, col_index});
   b.setInsertionPointToStart(&if_row_valid_op.getThenRegion().front());
   int col_red_root_op_idx = 0;
   for (auto* root_op : root_ops) {
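The ValueRange-to-SmallVector change in these two hunks looks like a lifetime fix: mlir::ValueRange, like llvm::ArrayRef, is a non-owning view, so a named load_index built from a temporary braced list ends up referring to storage that is destroyed at the end of the declaration statement, whereas SmallVector<Value, 2> owns copies of the two index Values and can still be passed wherever a ValueRange is expected. A minimal sketch of the same pitfall using ArrayRef<int>, which has the same view semantics (illustrative only; the motivation is inferred, not stated in the commit):

#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>

void sketch() {
  int row_index = 0, col_index = 1;  // stand-ins for the mlir::Value indices

  // The backing array of the braced list is a temporary; `bad` refers to
  // destroyed storage as soon as this statement finishes.
  llvm::ArrayRef<int> bad({row_index, col_index});

  // SmallVector owns its elements, so `good` stays valid for later uses
  // and converts to a view on demand.
  llvm::SmallVector<int, 2> good({row_index, col_index});
  llvm::ArrayRef<int> view = good;  // fine: views storage owned by `good`

  (void)bad;
  (void)view;
}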
