Skip to content

Commit 03fc921

Browse files
committed
[InstCombine] Minor Tweaks
1 parent 96ed7cb commit 03fc921

File tree

2 files changed

+41
-38
lines changed

2 files changed

+41
-38
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 29 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -804,25 +804,16 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
804804

805805
bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
806806
// Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
807-
auto *BinOp = dyn_cast<BinaryOperator>(&I);
808-
if (!BinOp || !BinOp->isBitwiseLogicOp())
807+
Value *LHSSrc, *RHSSrc;
808+
if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
809+
m_BitCast(m_Value(RHSSrc)))))
809810
return false;
810811

811-
Value *LHS = BinOp->getOperand(0);
812-
Value *RHS = BinOp->getOperand(1);
813-
814-
// Both operands must be bitcasts
815-
auto *LHSCast = dyn_cast<BitCastInst>(LHS);
816-
auto *RHSCast = dyn_cast<BitCastInst>(RHS);
817-
if (!LHSCast || !RHSCast)
818-
return false;
819-
820-
Value *LHSSrc = LHSCast->getOperand(0);
821-
Value *RHSSrc = RHSCast->getOperand(0);
822-
823812
// Source types must match
824813
if (LHSSrc->getType() != RHSSrc->getType())
825814
return false;
815+
if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
816+
return false;
826817

827818
// Only handle vector types
828819
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
@@ -831,15 +822,30 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
831822
return false;
832823

833824
// Same total bit width
834-
if (SrcVecTy->getPrimitiveSizeInBits() != DstVecTy->getPrimitiveSizeInBits())
835-
return false;
825+
assert(SrcVecTy->getPrimitiveSizeInBits() ==
826+
DstVecTy->getPrimitiveSizeInBits() &&
827+
"Bitcast should preserve total bit width");
828+
829+
// Cost Check :
830+
// OldCost = bitlogic + 2*bitcasts
831+
// NewCost = bitlogic + bitcast
832+
auto *BinOp = cast<BinaryOperator>(&I);
833+
InstructionCost OldCost =
834+
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
835+
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
836+
TTI::CastContextHint::None) +
837+
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
838+
TTI::CastContextHint::None);
839+
InstructionCost NewCost =
840+
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
841+
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
842+
TTI::CastContextHint::None);
836843

837-
// Cost check: prefer operations on narrower element types
838-
unsigned SrcEltBits = SrcVecTy->getScalarSizeInBits();
839-
unsigned DstEltBits = DstVecTy->getScalarSizeInBits();
844+
LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
845+
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
846+
<< "\n");
840847

841-
// Prefer smaller element sizes (more elements, finer granularity)
842-
if (SrcEltBits > DstEltBits)
848+
if (NewCost > OldCost)
843849
return false;
844850

845851
// Create the operation on the source type
@@ -848,6 +854,8 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
848854
if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
849855
NewBinOp->copyIRFlags(BinOp);
850856

857+
Worklist.pushValue(NewOp);
858+
851859
// Bitcast the result back
852860
Value *Result = Builder.CreateBitCast(NewOp, I.getType());
853861
replaceValue(I, *Result);

llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll

Lines changed: 12 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,8 @@ define i32 @test_and(<16 x i32> %a, ptr %b) {
77
; CHECK-LABEL: @test_and(
88
; CHECK-NEXT: entry:
99
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
10-
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
11-
; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
12-
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
10+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
11+
; CHECK-NEXT: [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A:%.*]]
1312
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
1413
; CHECK-NEXT: ret i32 [[TMP3]]
1514
;
@@ -26,9 +25,8 @@ define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
2625
; CHECK-NEXT: entry:
2726
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
2827
; CHECK-NEXT: [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], splat (i32 16)
29-
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
30-
; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
31-
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
28+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
29+
; CHECK-NEXT: [[TMP2:%.*]] = or <16 x i32> [[TMP0]], [[A_MASKED]]
3230
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
3331
; CHECK-NEXT: ret i32 [[TMP3]]
3432
;
@@ -47,15 +45,13 @@ define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
4745
; CHECK-NEXT: [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], splat (i32 255)
4846
; CHECK-NEXT: [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], splat (i32 255)
4947
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
50-
; CHECK-NEXT: [[TMP0:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
51-
; CHECK-NEXT: [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
52-
; CHECK-NEXT: [[TMP2:%.*]] = or <16 x i8> [[TMP0]], [[TMP1]]
53-
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[TMP2]] to <16 x i32>
54-
; CHECK-NEXT: [[TMP4:%.*]] = and <16 x i8> [[WIDE_LOAD]], splat (i8 15)
55-
; CHECK-NEXT: [[TMP5:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
56-
; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i8> [[TMP4]], [[TMP5]]
48+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
49+
; CHECK-NEXT: [[TMP6:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
5750
; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
58-
; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP7]]
51+
; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i32> [[TMP7]], [[V_MASKED]]
52+
; CHECK-NEXT: [[TMP4:%.*]] = and <16 x i32> [[TMP0]], splat (i32 15)
53+
; CHECK-NEXT: [[TMP5:%.*]] = or <16 x i32> [[TMP4]], [[U_MASKED]]
54+
; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP5]]
5955
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
6056
; CHECK-NEXT: ret i32 [[TMP9]]
6157
;
@@ -81,9 +77,8 @@ define i32 @phi_bug(<16 x i32> %a, ptr %b) {
8177
; CHECK: vector.body:
8278
; CHECK-NEXT: [[A_PHI:%.*]] = phi <16 x i32> [ [[A:%.*]], [[ENTRY:%.*]] ]
8379
; CHECK-NEXT: [[WIDE_LOAD_PHI:%.*]] = phi <16 x i8> [ [[WIDE_LOAD]], [[ENTRY]] ]
84-
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_PHI]] to <16 x i8>
85-
; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD_PHI]], [[TMP0]]
86-
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
80+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD_PHI]] to <16 x i32>
81+
; CHECK-NEXT: [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A_PHI]]
8782
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
8883
; CHECK-NEXT: ret i32 [[TMP3]]
8984
;

0 commit comments

Comments (0)