Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,14 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {

} // namespace

static bool isUniformShape(Value *V) {
static bool isShapePreserving(Value *V) {
Instruction *I = dyn_cast<Instruction>(V);
if (!I)
return true;

if (isa<SelectInst>(I))
return true;

if (I->isBinaryOp())
return true;

Expand Down Expand Up @@ -300,6 +303,16 @@ static bool isUniformShape(Value *V) {
}
}

/// Return an iterator over the operands of \p I that should share shape
/// information with \p I.
static iterator_range<Use *> getShapedOperandsForInst(Instruction *I) {
assert(isShapePreserving(I) &&
"Can't retrieve shaped operands for an instruction that does not "
"preserve shape information");
auto Ops = I->operands();
return isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
}

/// Return the ShapeInfo for the result of \p I, it it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
Expand Down Expand Up @@ -329,9 +342,8 @@ computeShapeInfoForInst(Instruction *I,
return OpShape->second;
}

if (isUniformShape(I) || isa<SelectInst>(I)) {
auto Ops = I->operands();
auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
if (isShapePreserving(I)) {
auto ShapedOps = getShapedOperandsForInst(I);
// Find the first operand that has a known shape and use that.
for (auto &Op : ShapedOps) {
auto OpShape = ShapeMap.find(Op.get());
Expand Down Expand Up @@ -710,10 +722,9 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
return true;
default:
return isUniformShape(II);
break;
}
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
isa<SelectInst>(V);
return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
}

/// Propagate the shape information of instructions to their users.
Expand Down Expand Up @@ -800,9 +811,8 @@ class LowerMatrixIntrinsics {
} else if (isa<StoreInst>(V)) {
// Nothing to do. We forward-propagated to this so we would just
// backward propagate to an instruction with an already known shape.
} else if (isUniformShape(V) || isa<SelectInst>(V)) {
auto Ops = cast<Instruction>(V)->operands();
auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
} else if (isShapePreserving(V)) {
auto ShapedOps = getShapedOperandsForInst(cast<Instruction>(V));
// Propagate to all operands.
ShapeInfo Shape = ShapeMap[V];
for (Use &U : ShapedOps) {
Expand Down