diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index e549a56a6f960..9820a91291fdb 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -66,11 +66,10 @@ class IntegerRangeAnalysis /// function calls `InferIntRangeInterface` to provide values for block /// arguments or tries to reduce the range on loop induction variables with /// known bounds. - void - visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, - ValueRange successorInputs, - ArrayRef argLattices, - unsigned firstIndex) override; + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ValueRange nonSuccessorInputs, + ArrayRef nonSuccessorInputLattices) override; }; /// Succeeds if an op can be converted to its unsigned equivalent without diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 02f699de06f99..df50d8d193aeb 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -215,8 +215,8 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { /// of loops). virtual void visitNonControlFlowArgumentsImpl( Operation *op, const RegionSuccessor &successor, - ValueRange successorInputs, ArrayRef argLattices, - unsigned firstIndex) = 0; + ValueRange nonSuccessorInputs, + ArrayRef nonSuccessorInputLattices) = 0; /// Get the lattice element of a value. virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; @@ -322,19 +322,17 @@ class SparseForwardDataFlowAnalysis } /// Given an operation with possible region control-flow, the lattices of the - /// operands, and a region successor, compute the lattice values for block - /// arguments that are not accounted for by the branching control flow (ex. - /// the bounds of loops). By default, this method marks all such lattice - /// elements as having reached a pessimistic fixpoint. `firstIndex` is the - /// index of the first element of `argLattices` that is set by control-flow. - virtual void visitNonControlFlowArguments(Operation *op, - const RegionSuccessor &successor, - ValueRange successorInputs, - ArrayRef argLattices, - unsigned firstIndex) { - setAllToEntryStates(argLattices.take_front(firstIndex)); - setAllToEntryStates( - argLattices.drop_front(firstIndex + successorInputs.size())); + /// operands, and a region successor, compute the lattice values for + /// non-successor-inputs (ex. loop induction variables) of a given region + /// successor. By default, this method marks all lattice elements as having + /// reached a pessimistic fixpoint. + virtual void + visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, + ValueRange nonSuccessorInputs, + ArrayRef nonSuccessorInputLattices) { + assert(nonSuccessorInputs.size() == nonSuccessorInputLattices.size() && + "size mismatch"); + setAllToEntryStates(nonSuccessorInputLattices); } protected: @@ -385,14 +383,14 @@ class SparseForwardDataFlowAnalysis } void visitNonControlFlowArgumentsImpl( Operation *op, const RegionSuccessor &successor, - ValueRange successorInputs, ArrayRef argLattices, - unsigned firstIndex) override { + ValueRange nonSuccessorInputs, + ArrayRef nonSuccessorInputLattices) override { visitNonControlFlowArguments( - op, successor, successorInputs, - {reinterpret_cast(argLattices.begin()), - argLattices.size()}, - firstIndex); + op, successor, nonSuccessorInputs, + {reinterpret_cast(nonSuccessorInputLattices.begin()), + nonSuccessorInputLattices.size()}); } + void setToEntryState(AbstractSparseLattice *lattice) override { return setToEntryState(reinterpret_cast(lattice)); } diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 012d8384d3098..7b567f043577a 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -138,8 +138,11 @@ LogicalResult IntegerRangeAnalysis::visitOperation( } void IntegerRangeAnalysis::visitNonControlFlowArguments( - Operation *op, const RegionSuccessor &successor, ValueRange successorInputs, - ArrayRef argLattices, unsigned firstIndex) { + Operation *op, const RegionSuccessor &successor, + ValueRange nonSuccessorInputs, + ArrayRef nonSuccessorInputLattices) { + assert(nonSuccessorInputs.size() == nonSuccessorInputLattices.size() && + "size mismatch"); if (auto inferrable = dyn_cast(op)) { LDBG() << "Inferring ranges for " << OpWithFlags(op, OpPrintingFlags().skipRegions()); @@ -156,7 +159,11 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( return; LDBG() << "Inferred range " << attrs; - IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; + auto it = llvm::find(successor.getSuccessor()->getArguments(), arg); + unsigned nonSuccessorInputIdx = + std::distance(successor.getSuccessor()->getArguments().begin(), it); + IntegerValueRangeLattice *lattice = + nonSuccessorInputLattices[nonSuccessorInputIdx]; IntegerValueRange oldRange = lattice->getValue(); ChangeResult changed = lattice->join(attrs); @@ -208,7 +215,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( loop.getLoopInductionVars(); if (!maybeIvs) { return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( - op, successor, successorInputs, argLattices, firstIndex); + op, successor, nonSuccessorInputs, nonSuccessorInputLattices); } // This shouldn't be returning nullopt if there are indunction variables. SmallVector lowerBounds = *loop.getLoopLowerBounds(); @@ -246,5 +253,5 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( } return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( - op, successor, successorInputs, argLattices, firstIndex); + op, successor, nonSuccessorInputs, nonSuccessorInputLattices); } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 7dad9676e7e53..90f2a588d1ca4 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -185,10 +185,10 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { block->getParent(), argLattices); } - // Otherwise, we can't reason about the data-flow. - return visitNonControlFlowArgumentsImpl( - block->getParentOp(), RegionSuccessor(block->getParent()), ValueRange(), - argLattices, /*firstIndex=*/0); + // All block arguments are non-successor-inputs. + return visitNonControlFlowArgumentsImpl(block->getParentOp(), + RegionSuccessor(block->getParent()), + block->getArguments(), argLattices); } // Iterate over the predecessors of the non-entry block. @@ -309,23 +309,30 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( assert(inputs.size() == operands->size() && "expected the same number of successor inputs as operands"); + auto valueToLattices = [&](Value v) { return getLatticeElement(v); }; unsigned firstIndex = 0; if (inputs.size() != lattices.size()) { if (!point->isBlockStart()) { if (!inputs.empty()) firstIndex = cast(inputs.front()).getResultNumber(); - visitNonControlFlowArgumentsImpl( - branch, RegionSuccessor::parent(), - branch->getResults().slice(firstIndex, inputs.size()), lattices, - firstIndex); + SmallVector nonSuccessorInputs = + branch.getNonSuccessorInputs(RegionSuccessor::parent()); + SmallVector nonSuccessorInputLattices = + llvm::map_to_vector(nonSuccessorInputs, valueToLattices); + visitNonControlFlowArgumentsImpl(branch, RegionSuccessor::parent(), + nonSuccessorInputs, + nonSuccessorInputLattices); } else { if (!inputs.empty()) firstIndex = cast(inputs.front()).getArgNumber(); Region *region = point->getBlock()->getParent(); - visitNonControlFlowArgumentsImpl( - branch, RegionSuccessor(region), - region->getArguments().slice(firstIndex, inputs.size()), lattices, - firstIndex); + SmallVector nonSuccessorInputs = + branch.getNonSuccessorInputs(RegionSuccessor(region)); + SmallVector nonSuccessorInputLattices = + llvm::map_to_vector(nonSuccessorInputs, valueToLattices); + visitNonControlFlowArgumentsImpl(branch, RegionSuccessor(region), + nonSuccessorInputs, + nonSuccessorInputLattices); } }