Skip to content

Commit

Permalink
[InstCombine] Fix miscompilation in PR83947 (llvm#83993)
Browse files Browse the repository at this point in the history
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.

(cherry picked from commit a1a590e)
  • Loading branch information
dtcxzyw authored and llvmbot committed Mar 5, 2024
1 parent 461274b commit 5645cfa
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 5 deletions.
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask);
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of i1, Return true if any of the elements of this
/// predicate mask are known to be true or undef. That is, return true if at
/// least one lane can be assumed active.
bool maskContainsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/Analysis/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) {
return true;
}

bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
assert(isa<VectorType>(Mask->getType()) &&
isa<IntegerType>(Mask->getType()->getScalarType()) &&
cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
1 &&
"Mask must be a vector of i1");

auto *ConstMask = dyn_cast<Constant>(Mask);
if (!ConstMask)
return false;
if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
return true;
if (isa<ScalableVectorType>(ConstMask->getType()))
return false;
for (unsigned
I = 0,
E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
I != E; ++I) {
if (auto *MaskElt = ConstMask->getAggregateElement(I))
if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
return true;
}
return false;
}

/// TODO: This is a lot like known bits, but for
/// vectors. Is there something we can common this with?
APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
Expand Down
13 changes: 8 additions & 5 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S =
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
S->copyMetadata(II);
return S;
if (maskContainsAllOneOrUndef(ConstMask)) {
Align Alignment =
cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
Alignment);
S->copyMetadata(II);
return S;
}
}
// scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
// lastlane), ptr
Expand Down
67 changes: 67 additions & 0 deletions llvm/test/Transforms/InstCombine/pr83947.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt -S -passes=instcombine < %s | FileCheck %s

@c = global i32 0, align 4
@b = global i32 0, align 4

define void @masked_scatter1() {
; CHECK-LABEL: define void @masked_scatter1() {
; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
; CHECK-NEXT: ret void
;
call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
ret void
}

define void @masked_scatter2() {
; CHECK-LABEL: define void @masked_scatter2() {
; CHECK-NEXT: store i32 0, ptr @c, align 4
; CHECK-NEXT: ret void
;
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true))
ret void
}

define void @masked_scatter3() {
; CHECK-LABEL: define void @masked_scatter3() {
; CHECK-NEXT: store i32 0, ptr @c, align 4
; CHECK-NEXT: ret void
;
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
ret void
}

define void @masked_scatter4() {
; CHECK-LABEL: define void @masked_scatter4() {
; CHECK-NEXT: ret void
;
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false))
ret void
}

define void @masked_scatter5() {
; CHECK-LABEL: define void @masked_scatter5() {
; CHECK-NEXT: store i32 0, ptr @c, align 4
; CHECK-NEXT: ret void
;
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>)
ret void
}

define void @masked_scatter6() {
; CHECK-LABEL: define void @masked_scatter6() {
; CHECK-NEXT: store i32 0, ptr @c, align 4
; CHECK-NEXT: ret void
;
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)
ret void
}

define void @masked_scatter7() {
; CHECK-LABEL: define void @masked_scatter7() {
; CHECK-NEXT: call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>)
; CHECK-NEXT: ret void
;
call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
ret void
}

0 comments on commit 5645cfa

Please sign in to comment.