Skip to content

Commit

Permalink
[flang] Do not inline SUM with invalid DIM argument. (#118911)
Browse files Browse the repository at this point in the history
Such SUMs might appear in dead code after constant propagation.
They do not have to be inlined.
  • Loading branch information
vzakhari authored Dec 9, 2024
1 parent 1ca3927 commit 084451c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
16 changes: 13 additions & 3 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
assert(dimVal > 0 && "DIM must be present and a positive constant");
mlir::Value resultShape, dimExtent;
std::tie(resultShape, dimExtent) =
genResultShape(loc, builder, array, dimVal);
Expand Down Expand Up @@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();

Expand Down Expand Up @@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
// would avoid creating a temporary for the elemental array expression.
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (mlir::Value dim = sum.getDim()) {
if (fir::getIntIfConstant(dim)) {
if (auto dimVal = fir::getIntIfConstant(dim)) {
if (!fir::isa_trivial(sum.getType())) {
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
// It is only legal when X is 1, and it should probably be
// canonicalized into SUM(a).
return false;
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(
sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
// CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
// CHECK: return
// CHECK: }

// negative: invalid dim==0
func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) {
%cst = arith.constant 0 : i32
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
return
}
// CHECK-LABEL: func.func @sum_invalid_dim0(
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

// negative: invalid dim>rank
func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) {
%cst = arith.constant 3 : i32
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
return
}
// CHECK-LABEL: func.func @sum_invalid_dim_big(
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

0 comments on commit 084451c

Please sign in to comment.