diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 957621c16bf2b..f30678f8cd664 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -20,6 +20,8 @@ // terminator operands of region branch ops, and, // (D) Removes simple and region branch ops that have all non-live results and // don't affect memory in any way, +// (E) Replaces dead operands of branch ops with `ub.poison`, relying on the +// canonicalizer to remove the corresponding block arguments. // // iff // @@ -101,24 +103,11 @@ struct OperandsToCleanup { bool replaceWithPoison = false; }; -struct BlockArgsToCleanup { - Block *b; - BitVector nonLiveArgs; -}; - -struct SuccessorOperandsToCleanup { - BranchOpInterface branch; - unsigned successorIndex; - BitVector nonLiveOperands; -}; - struct RDVFinalCleanupList { SmallVector operations; SmallVector functions; SmallVector operands; SmallVector results; - SmallVector blocks; - SmallVector successorOperands; }; // Some helper functions... @@ -476,11 +465,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, /// /// Otherwise, iterate through each successor block of `branchOp`. /// (1) For each successor block, gather all operands from all successors. -/// (2) Fetch their associated liveness analysis data and collect for future -/// removal. -/// (3) Identify and collect the dead operands from the successor block -/// as well as their corresponding arguments. - +/// (2) Determine which operands are dead using liveness analysis. +/// (3) Replace dead successor operands with ub.poison instead of erasing them. +/// Block arguments are left intact — the canonicalizer will remove them +/// once it sees all incoming operands are poison. static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet &nonLiveSet, RDVFinalCleanupList &cl) { @@ -490,7 +478,8 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, BitVector deadNonForwardedOperands = markLives(branchOp->getOperands(), nonLiveSet, la).flip(); unsigned numSuccessors = branchOp->getNumSuccessors(); - for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { + + for (unsigned succIdx : llvm::seq(0, numSuccessors)) { SuccessorOperands successorOperands = branchOp.getSuccessorOperands(succIdx); // Remove all non-forwarded operands from the bit vector. @@ -502,28 +491,20 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, return; } - for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { - Block *successorBlock = branchOp->getSuccessor(succIdx); - - // Do (1) - SuccessorOperands successorOperands = - branchOp.getSuccessorOperands(succIdx); - SmallVector operandValues; - for (unsigned operandIdx = 0; operandIdx < successorOperands.size(); - ++operandIdx) { - operandValues.push_back(successorOperands[operandIdx]); + // For each successor, find dead forwarded operands and + // schedule them for replacement with ub.poison. + BitVector opNonLive(branchOp->getNumOperands(), false); + for (unsigned succIdx : llvm::seq(0, numSuccessors)) { + for (OpOperand &opOperand : + branchOp.getSuccessorOperands(succIdx).getMutableForwardedOperands()) { + if (!hasLive(opOperand.get(), nonLiveSet, la)) + opNonLive.set(opOperand.getOperandNumber()); } - - // Do (2) - BitVector successorNonLive = - markLives(operandValues, nonLiveSet, la).flip(); - collectNonLiveValues(nonLiveSet, successorBlock->getArguments(), - successorNonLive); - - // Do (3) - cl.blocks.push_back({successorBlock, successorNonLive}); - cl.successorOperands.push_back({branchOp, succIdx, successorNonLive}); } + + if (opNonLive.any()) + cl.operands.push_back({branchOp.getOperation(), opNonLive, + /*callee=*/nullptr, /*replaceWithPoison=*/true}); } /// Create ub.poison ops for the given values. If a value has no uses, return @@ -564,56 +545,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { TrackingListener listener; IRRewriter rewriter(ctx, &listener); - // 1. Blocks, We must remove the block arguments and successor operands before - // deleting the operation, as they may reside in the region operation. - LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; - for (auto &b : list.blocks) { - // blocks that are accessed via multiple codepaths processed once - if (b.b->getNumArguments() != b.nonLiveArgs.size()) - continue; - LDBG_OS([&](raw_ostream &os) { - os << "Erasing non-live arguments ["; - llvm::interleaveComma(b.nonLiveArgs.set_bits(), os); - os << "] from block #" << b.b->computeBlockNumber() << " in region #" - << b.b->getParent()->getRegionNumber() << " of operation " - << OpWithFlags(b.b->getParent()->getParentOp(), - OpPrintingFlags().skipRegions().printGenericOpForm()); - }); - // Note: Iterate from the end to make sure that that indices of not yet - // processes arguments do not change. - for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { - if (!b.nonLiveArgs[i]) - continue; - b.b->getArgument(i).dropAllUses(); - b.b->eraseArgument(i); - } - } - - // 2. Successor Operands - LDBG() << "Cleaning up " << list.successorOperands.size() - << " successor operand lists"; - for (auto &op : list.successorOperands) { - SuccessorOperands successorOperands = - op.branch.getSuccessorOperands(op.successorIndex); - // blocks that are accessed via multiple codepaths processed once - if (successorOperands.size() != op.nonLiveOperands.size()) - continue; - LDBG_OS([&](raw_ostream &os) { - os << "Erasing non-live successor operands ["; - llvm::interleaveComma(op.nonLiveOperands.set_bits(), os); - os << "] from successor " << op.successorIndex << " of branch: " - << OpWithFlags(op.branch.getOperation(), - OpPrintingFlags().skipRegions().printGenericOpForm()); - }); - // it iterates backwards because erase invalidates all successor indexes - for (int i = successorOperands.size() - 1; i >= 0; --i) { - if (!op.nonLiveOperands[i]) - continue; - successorOperands.erase(i); - } - } - - // 3. Functions + // 1. Functions LDBG() << "Cleaning up " << list.functions.size() << " functions"; // Record which function arguments were erased so we can shrink call-site // argument segments for CallOpInterface operations (e.g. ops using @@ -647,7 +579,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { (void)f.funcOp.eraseResults(f.nonLiveRets); } - // 4. Operands + // 2. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperandsToCleanup &o : list.operands) { // Handle call-specific cleanup only when we have a cached callee reference. @@ -705,7 +637,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { } } - // 5. Results + // 3. Results LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { LDBG_OS([&](raw_ostream &os) { @@ -718,7 +650,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { dropUsesAndEraseResults(rewriter, r.op, r.nonLive); } - // 6. Operations + // 4. Operations LDBG() << "Cleaning up " << list.operations.size() << " operations"; for (Operation *op : list.operations) { LDBG() << "Erasing operation: " @@ -755,7 +687,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { rewriter.eraseOp(op); } - // 7. Remove all dead poison ops. + // 5. Remove all dead poison ops. for (ub::PoisonOp poisonOp : listener.poisonOps) { if (poisonOp.use_empty()) poisonOp.erase(); @@ -808,20 +740,17 @@ void RemoveDeadValues::runOnOperation() { if (!canonicalize) return; - // Canonicalize all region branch ops. - SmallVector opsToCanonicalize; - module->walk([&](RegionBranchOpInterface regionBranchOp) { - opsToCanonicalize.push_back(regionBranchOp.getOperation()); - }); - // Collect all canonicalization patterns for region branch ops. + // Canonicalize all region branch ops and branch ops. RewritePatternSet owningPatterns(context); DenseSet populatedPatterns; - for (Operation *op : opsToCanonicalize) + module->walk([&](Operation *op) { + if (!isa(op)) + return; if (std::optional info = op->getRegisteredInfo()) if (populatedPatterns.insert(*info).second) info->getCanonicalizationPatterns(owningPatterns, context); - if (failed(applyOpPatternsGreedily(opsToCanonicalize, - std::move(owningPatterns)))) { + }); + if (failed(applyPatternsGreedily(module, std::move(owningPatterns)))) { module->emitError("greedy pattern rewrite failed to converge"); signalPassFailure(); } diff --git a/mlir/test/Dialect/SPIRV/IR/return-ops.mlir b/mlir/test/Dialect/SPIRV/IR/return-ops.mlir index 2f945b24d24fd..b12cba5f7a074 100644 --- a/mlir/test/Dialect/SPIRV/IR/return-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/return-ops.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s --remove-dead-values | FileCheck %s +// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE // Make sure that the return value op is considered as a return-like op and // remains live. @@ -8,6 +9,9 @@ // CHECK-NEXT: %[[BITCAST0:.*]] = spirv.Bitcast %[[ARG1]] : vector<2xi32> to vector<2xf32> // CHECK-NEXT: %[[BITCAST1:.*]] = spirv.Bitcast %[[BITCAST0]] : vector<2xf32> to vector<2xi32> // CHECK-NEXT: spirv.ReturnValue %[[BITCAST1]] : vector<2xi32> +// CHECK-CANONICALIZE-LABEL: @preserve_return_value +// CHECK-CANONICALIZE-SAME: (%[[ARG0:.*]]: vector<2xi32>, %[[ARG1:.*]]: vector<2xi32>) -> vector<2xi32> +// CHECK-CANONICALIZE-NEXT: spirv.ReturnValue %[[ARG1]] : vector<2xi32> spirv.func @preserve_return_value(%arg0: vector<2xi32>, %arg1: vector<2xi32>) -> vector<2xi32> "None" { %0 = spirv.Bitcast %arg0 : vector<2xi32> to vector<2xf32> diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 19bc6b2fddd66..22e4d66ef0ea5 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -35,20 +35,24 @@ module @named_module_acceptable { func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) { %non_live = arith.constant 0 : i32 // CHECK-NOT: arith.constant + // CHECK-CANONICALIZE-NOT: arith.constant cf.br ^bb1(%non_live : i32) - // CHECK: cf.br ^[[BB1:bb[0-9]+]] + // CHECK: cf.br ^[[BB1:bb[0-9]+]](%{{.*}} : i32) + // CHECK-CANONICALIZE: cf.br ^[[BB1:bb[0-9]+]] ^bb1(%non_live_1 : i32): - // CHECK: ^[[BB1]]: + // CHECK: ^[[BB1]](%{{.*}}: i32): + // CHECK-CANONICALIZE: ^[[BB1]]: %non_live_5 = arith.constant 1 : i32 cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32) - // CHECK: cf.br ^[[BB3:bb[0-9]+]] - // CHECK-NOT: i32 + // CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}}, %{{.*}} : i32, i32) + // CHECK-CANONICALIZE: cf.cond_br %arg0, ^[[BB1]], ^[[BB2:bb[0-9]+]] ^bb3(%non_live_2 : i32, %non_live_6 : i32): - // CHECK: ^[[BB3]]: + // CHECK: ^[[BB3]](%{{.*}}: i32, %{{.*}}: i32): cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32) - // CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]] + // CHECK: cf.cond_br %arg0, ^[[BB1]](%{{.*}} : i32), ^[[BB4:bb[0-9]+]](%{{.*}} : i32) ^bb4(%non_live_4 : i32): - // CHECK: ^[[BB4]]: + // CHECK: ^[[BB4]](%{{.*}}: i32): + // CHECK-CANONICALIZE: ^[[BB2]]: return } @@ -345,9 +349,9 @@ func.func private @identity(%arg1 : i32) -> (i32) { // Note that this cleanup cannot be done by the `canonicalize` pass. // // CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref) { +// CHECK-CANONICALIZE: %[[c10:.*]] = arith.constant 10 // CHECK-CANONICALIZE-NEXT: scf.index_switch %[[arg0]] // CHECK-CANONICALIZE-NEXT: case 1 { -// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10 // CHECK-CANONICALIZE-NEXT: memref.store %[[c10]], %[[arg1]][] // CHECK-CANONICALIZE: scf.yield // CHECK-CANONICALIZE-NEXT: } @@ -476,6 +480,47 @@ func.func @kernel(%arg0: memref<18xf32>) { // ----- +// Test that RemoveDeadValues does not crash when gpu.launch appears in a block +// with multiple predecessors. The dead branch operand (%c20) must be replaced +// with ub.poison, gpu.launch and its grid/block size operands must be +// preserved, and the live block argument must remain intact. +// +// CHECK-LABEL: func.func @gpu_launch_in_multi_predecessor_block +// CHECK: arith.constant true +// CHECK: cf.cond_br +// CHECK: arith.constant 10 +// CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}} : i64) +// CHECK: ub.poison +// CHECK: cf.br ^[[BB3]](%{{.*}} : i64) +// CHECK: ^[[BB3]](%{{.*}}: i64): +// CHECK: return +// CHECK-NOT: arith.constant 20 +// +// CHECK-CANONICALIZE-LABEL: func.func @gpu_launch_in_multi_predecessor_block +// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10 : i64 +// CHECK-CANONICALIZE-NEXT: return %[[c10]] +func.func @gpu_launch_in_multi_predecessor_block() -> i64 { + %cond = arith.constant true + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: + %c10 = arith.constant 10 : i64 + cf.br ^bb3(%c10 : i64) +^bb2: + %c20 = arith.constant 20 : i64 + cf.br ^bb3(%c20 : i64) +^bb3(%arg0: i64): + %c1 = arith.constant 1 : index + gpu.launch + blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1) + threads(%tx, %ty, %tz) in (%bsx = %c1, %bsy = %c1, %bsz = %c1) { + %blk_x = gpu.block_id x + %thr_x = gpu.thread_id x + gpu.terminator + } + func.return %arg0 : i64 +} + +// ----- // CHECK-LABEL: llvm_unreachable // CHECK-LABEL: @fn_with_llvm_unreachable @@ -768,6 +813,7 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() { // CHECK: return %[[while]]#0 // CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args() +// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32 // CHECK-CANONICALIZE: %[[c5:.*]] = arith.constant 5 : i32 // CHECK-CANONICALIZE: %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 { // CHECK-CANONICALIZE: vector.print %[[arg0]] @@ -775,7 +821,6 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() { // CHECK-CANONICALIZE: scf.condition(%[[cmpi]]) %[[arg0]] // CHECK-CANONICALIZE: } do { // CHECK-CANONICALIZE: ^bb0(%[[arg1:.*]]: i32): -// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32 // CHECK-CANONICALIZE: scf.yield %[[p0]] // CHECK-CANONICALIZE: } // CHECK-CANONICALIZE: return %[[while]] @@ -799,7 +844,13 @@ func.func @scf_while_dead_iter_args() -> i32 { // ----- -// CHECK-LABEL: func.func @replace_dead_operation_results_with_poison +// Check that this prevents a crash in the canonicalization phase which +// happens after the dead value removal phase. Also check that only used +// results of an erased op are replaced with ub.poison. + +// CHECK-CANONICALIZE-LABEL: func.func @replace_dead_operation_results_with_poison +// CHECK-CANONICALIZE-NEXT: %[[p:.*]] = ub.poison : vector<1xindex> +// CHECK-CANONICALIZE-NEXT: return %[[p]] func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> vector<1xindex> { %1 = scf.while (%arg0 = %0) : (vector<1xindex>) -> vector<1xindex> { %cond = arith.constant true @@ -809,15 +860,6 @@ func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> ve scf.yield %arg0 : vector<1xindex> } %2 = scf.while (%arg0 = %1) : (vector<1xindex>) -> vector<1xindex> { - // Check that the binary value in condition is replaced with poison, and - // the condition itself is well-formed IR. This prevents a crash in the - // canonicalization phase which happens after the dead value removal phase. - // Also check that only used results of an erased op are replaced with ub.poison. - // CHECK-CANONICALIZE: %[[COND:.*]] = ub.poison : i1 - // CHECK-CANONICALIZE-NEXT: %[[NEXT:.*]] = ub.poison : vector<1xindex> - // CHECK-CANONICALIZE-NEXT: scf.condition(%[[COND]]) %[[NEXT]] - // CHECK-CANONICALIZE-NOT: ub.poison : i32 - // CHECK-CANONICALIZE-NOT: "test.three" %cond, %unused, %next = "test.three"(%1) : (vector<1xindex>) -> (i1, i32, vector<1xindex>) scf.condition(%cond) %next : vector<1xindex> } do {