5 changes: 3 additions & 2 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3249,10 +3249,11 @@ class LLVM_ABI TargetLoweringBase {
/// Return true on success. Currently only supports
/// llvm.vector.deinterleave{2,3,5,7}
///
/// \p LI is the accompanying load instruction.
/// \p Load is the accompanying load instruction. Can be either a plain load
/// instruction or a vp.load intrinsic.
/// \p DeinterleaveValues contains the deinterleaved values.
virtual bool
lowerDeinterleaveIntrinsicToLoad(LoadInst *LI,
lowerDeinterleaveIntrinsicToLoad(Instruction *Load, Value *Mask,
ArrayRef<Value *> DeinterleaveValues) const {
return false;
}
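To make the new contract concrete, here is a minimal sketch of a target override under the widened interface (hypothetical naming, not part of the patch): a plain `LoadInst` arrives with `Mask == nullptr`, while a vp.load arrives with the per-field mask already recovered by the pass.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include <cassert>
using namespace llvm;

// Sketch of an override under the widened interface.
static bool lowerDeinterleaveToLoadSketch(Instruction *Load, Value *Mask,
                                          ArrayRef<Value *> Deinterleaved) {
  if (auto *LI = dyn_cast<LoadInst>(Load)) {
    // Plain loads never carry a mask; the pass passes nullptr here.
    assert(!Mask && "unexpected mask on a plain load");
    // ... emit an unmasked segmented load from LI->getPointerOperand() ...
    (void)LI;
  } else {
    // Otherwise the caller guarantees a vp.load whose per-field mask has
    // already been extracted from the wide mask operand.
    auto *VPLoad = cast<VPIntrinsic>(Load);
    assert(Mask && "vp.load requires a mask");
    // ... emit a masked segmented load bounded by
    //     VPLoad->getVectorLengthParam() ...
    (void)VPLoad;
  }
  (void)Deinterleaved;
  return false; // Sketch only: returning false falls back to the generic path.
}
```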
19 changes: 7 additions & 12 deletions llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -634,37 +634,32 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
if (!LastFactor)
return false;

Value *Mask = nullptr;
if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadedVal)) {
if (VPLoad->getIntrinsicID() != Intrinsic::vp_load)
return false;
// Check mask operand. Handle both all-true/false and interleaved mask.
Value *WideMask = VPLoad->getOperand(1);
Value *Mask =
getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
Mask = getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
if (!Mask)
return false;

LLVM_DEBUG(dbgs() << "IA: Found a vp.load with deinterleave intrinsic "
<< *DI << " and factor = " << Factor << "\n");

// Since lowerInterleaveLoad expects Shuffles and LoadInst, use special
// TLI function to emit target-specific interleaved instruction.
if (!TLI->lowerInterleavedVPLoad(VPLoad, Mask, DeinterleaveValues))
return false;

} else {
auto *LI = cast<LoadInst>(LoadedVal);
if (!LI->isSimple())
return false;

LLVM_DEBUG(dbgs() << "IA: Found a load with deinterleave intrinsic " << *DI
<< " and factor = " << Factor << "\n");

// Try and match this with target specific intrinsics.
if (!TLI->lowerDeinterleaveIntrinsicToLoad(LI, DeinterleaveValues))
return false;
}

// Try and match this with target specific intrinsics.
if (!TLI->lowerDeinterleaveIntrinsicToLoad(cast<Instruction>(LoadedVal), Mask,
DeinterleaveValues))
return false;

for (Value *V : DeinterleaveValues)
if (V)
DeadInsts.insert(cast<Instruction>(V));
7 changes: 6 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17476,12 +17476,17 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
}

bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
LoadInst *LI, ArrayRef<Value *> DeinterleavedValues) const {
Instruction *Load, Value *Mask,
ArrayRef<Value *> DeinterleavedValues) const {
unsigned Factor = DeinterleavedValues.size();
if (Factor != 2 && Factor != 4) {
LLVM_DEBUG(dbgs() << "Matching ld2 and ld4 patterns failed\n");
return false;
}
auto *LI = dyn_cast<LoadInst>(Load);
if (!LI)
return false;
assert(!Mask && "Unexpected mask on a load\n");

Value *FirstActive = *llvm::find_if(DeinterleavedValues,
[](Value *V) { return V != nullptr; });
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -219,7 +219,8 @@ class AArch64TargetLowering : public TargetLowering {
unsigned Factor) const override;

bool lowerDeinterleaveIntrinsicToLoad(
LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
Instruction *Load, Value *Mask,
ArrayRef<Value *> DeinterleaveValues) const override;

bool lowerInterleaveIntrinsicToStore(
StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;
3 changes: 2 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -438,7 +438,8 @@ class RISCVTargetLowering : public TargetLowering {
unsigned Factor) const override;

bool lowerDeinterleaveIntrinsicToLoad(
LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
Instruction *Load, Value *Mask,
ArrayRef<Value *> DeinterleaveValues) const override;

bool lowerInterleaveIntrinsicToStore(
StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;
110 changes: 68 additions & 42 deletions llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp
@@ -234,53 +234,100 @@ bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
return true;
}

static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
assert(N);
if (N == 1)
return true;

using namespace PatternMatch;
// Right now we're only recognizing the simplest pattern.
uint64_t C;
if (match(V, m_CombineOr(m_ConstantInt(C),
m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
Collaborator: Not related to this patch since this was just moved from another location. Do we need to use m_c_Mul instead of m_Mul here? Constants should have been canonicalized to the right-hand side by this point.
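For illustration, a sketch of what the non-commutative matcher would look like (assuming, per the comment above, that constant mul operands are already canonicalized to the right-hand side):

```cpp
#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace PatternMatch;

// m_Mul only matches `mul %x, C` with the constant on the RHS; m_c_Mul
// additionally tries the swapped order `mul C, %x`, which instcombine-style
// canonicalization should have rewritten away by this point.
static bool isConstantMultipleOfN(const Value *V, unsigned N) {
  uint64_t C;
  return match(V, m_CombineOr(m_ConstantInt(C),
                              m_Mul(m_Value(), m_ConstantInt(C)))) &&
         C && C % N == 0;
}
```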

C && C % N == 0)
return true;

if (isPowerOf2_32(N)) {
KnownBits KB = llvm::computeKnownBits(V, DL);
return KB.countMinTrailingZeros() >= Log2_32(N);
}

return false;
}

bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const {
Instruction *Load, Value *Mask,
ArrayRef<Value *> DeinterleaveValues) const {
const unsigned Factor = DeinterleaveValues.size();
if (Factor > 8)
return false;

assert(LI->isSimple());
IRBuilder<> Builder(LI);
IRBuilder<> Builder(Load);

Value *FirstActive =
*llvm::find_if(DeinterleaveValues, [](Value *V) { return V != nullptr; });
VectorType *ResVTy = cast<VectorType>(FirstActive->getType());

const DataLayout &DL = LI->getDataLayout();
const DataLayout &DL = Load->getDataLayout();
auto *XLenTy = Type::getIntNTy(Load->getContext(), Subtarget.getXLen());

if (!isLegalInterleavedAccessType(ResVTy, Factor, LI->getAlign(),
LI->getPointerAddressSpace(), DL))
Value *Ptr, *VL;
Align Alignment;
if (auto *LI = dyn_cast<LoadInst>(Load)) {
assert(LI->isSimple());
Ptr = LI->getPointerOperand();
Alignment = LI->getAlign();
assert(!Mask && "Unexpected mask on a load\n");
Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
VL = isa<FixedVectorType>(ResVTy)
? Builder.CreateElementCount(XLenTy, ResVTy->getElementCount())
: Constant::getAllOnesValue(XLenTy);
} else {
auto *VPLoad = cast<VPIntrinsic>(Load);
assert(VPLoad->getIntrinsicID() == Intrinsic::vp_load &&
"Unexpected intrinsic");
Ptr = VPLoad->getMemoryPointerParam();
Alignment = VPLoad->getPointerAlignment().value_or(
DL.getABITypeAlign(ResVTy->getElementType()));

assert(Mask && "vp.load needs a mask!");

Value *WideEVL = VPLoad->getVectorLengthParam();
// Conservatively check if EVL is a multiple of factor, otherwise some
// (trailing) elements might be lost after the transformation.
if (!isMultipleOfN(WideEVL, Load->getDataLayout(), Factor))
return false;

VL = Builder.CreateZExt(
Builder.CreateUDiv(WideEVL,
ConstantInt::get(WideEVL->getType(), Factor)),
Collaborator: Should this be an exact udiv?

Collaborator (author): It probably should be, but doing this revealed a deeper problem in the EVL recognition. We're not checking for overflow in the multiply check, and I believe that to be unsound. I am going to return to this, but want to get a few things off my queue first.
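A sketch of the exact-udiv variant under discussion (valid only if the multiple-of-Factor check is sound, which the overflow caveat above calls into question):

```cpp
// Since isMultipleOfN claims WideEVL % Factor == 0, the division would be
// remainder-free and could carry the `exact` flag, giving later folds more
// freedom. Hypothetical drop-in replacement for the CreateUDiv call above.
VL = Builder.CreateZExt(
    Builder.CreateExactUDiv(WideEVL,
                            ConstantInt::get(WideEVL->getType(), Factor)),
    XLenTy);
```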

XLenTy);
}

Type *PtrTy = Ptr->getType();
unsigned AS = PtrTy->getPointerAddressSpace();
if (!isLegalInterleavedAccessType(ResVTy, Factor, Alignment, AS, DL))
return false;

Value *Return;
Type *PtrTy = LI->getPointerOperandType();
Type *XLenTy = Type::getIntNTy(LI->getContext(), Subtarget.getXLen());

if (isa<FixedVectorType>(ResVTy)) {
Value *VL = Builder.CreateElementCount(XLenTy, ResVTy->getElementCount());
Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
Return = Builder.CreateIntrinsic(FixedVlsegIntrIds[Factor - 2],
{ResVTy, PtrTy, XLenTy},
{LI->getPointerOperand(), Mask, VL});
{ResVTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
} else {
unsigned SEW = DL.getTypeSizeInBits(ResVTy->getElementType());
unsigned NumElts = ResVTy->getElementCount().getKnownMinValue();
Type *VecTupTy = TargetExtType::get(
LI->getContext(), "riscv.vector.tuple",
ScalableVectorType::get(Type::getInt8Ty(LI->getContext()),
Load->getContext(), "riscv.vector.tuple",
ScalableVectorType::get(Type::getInt8Ty(Load->getContext()),
NumElts * SEW / 8),
Factor);
Value *VL = Constant::getAllOnesValue(XLenTy);
Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());

Function *VlsegNFunc = Intrinsic::getOrInsertDeclaration(
LI->getModule(), ScalableVlsegIntrIds[Factor - 2],
Load->getModule(), ScalableVlsegIntrIds[Factor - 2],
{VecTupTy, PtrTy, Mask->getType(), VL->getType()});

Value *Operands[] = {
PoisonValue::get(VecTupTy),
LI->getPointerOperand(),
Ptr,
Mask,
VL,
ConstantInt::get(XLenTy,
@@ -290,7 +337,7 @@ bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
CallInst *Vlseg = Builder.CreateCall(VlsegNFunc, Operands);

SmallVector<Type *, 2> AggrTypes{Factor, ResVTy};
Return = PoisonValue::get(StructType::get(LI->getContext(), AggrTypes));
Return = PoisonValue::get(StructType::get(Load->getContext(), AggrTypes));
for (unsigned i = 0; i < Factor; ++i) {
Value *VecExtract = Builder.CreateIntrinsic(
Intrinsic::riscv_tuple_extract, {ResVTy, VecTupTy},
@@ -370,27 +417,6 @@ bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(
return true;
}

static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
assert(N);
if (N == 1)
return true;

using namespace PatternMatch;
// Right now we're only recognizing the simplest pattern.
uint64_t C;
if (match(V, m_CombineOr(m_ConstantInt(C),
m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
C && C % N == 0)
return true;

if (isPowerOf2_32(N)) {
KnownBits KB = llvm::computeKnownBits(V, DL);
return KB.countMinTrailingZeros() >= Log2_32(N);
}

return false;
}

/// Lower an interleaved vp.load into a vlsegN intrinsic.
///
/// E.g. Lower an interleaved vp.load (Factor = 2):