diff --git a/llvm/include/llvm/IR/VectorTypeUtils.h b/llvm/include/llvm/IR/VectorTypeUtils.h index e3d7fadad6089..3db5c4a2b6576 100644 --- a/llvm/include/llvm/IR/VectorTypeUtils.h +++ b/llvm/include/llvm/IR/VectorTypeUtils.h @@ -14,6 +14,11 @@ namespace llvm { +/// Returns true if \p IID is a vector intrinsic that returns a struct with a +/// scalar element at index \p EleIdx. +LLVM_ABI bool isVectorIntrinsicWithStructReturnScalarAtField(unsigned IID, + unsigned EleIdx); + /// A helper function for converting Scalar types to vector types. If /// the incoming type is void, we return void. If the EC represents a /// scalar, we return the scalar type. @@ -31,7 +36,11 @@ inline Type *toVectorTy(Type *Scalar, unsigned VF) { /// Note: /// - If \p EC is scalar, \p StructTy is returned unchanged /// - Only unpacked literal struct types are supported -LLVM_ABI Type *toVectorizedStructTy(StructType *StructTy, ElementCount EC); +/// vector types. +/// - If IID (Intrinsic ID) is provided, only fields that are vector types +/// are widened. +LLVM_ABI Type *toVectorizedStructTy(StructType *StructTy, ElementCount EC, + unsigned IID = 0); /// A helper for converting structs of vector types to structs of scalar types. /// Note: Only unpacked literal struct types are supported. @@ -52,9 +61,9 @@ LLVM_ABI bool canVectorizeStructTy(StructType *StructTy); /// - If the incoming type is void, we return void /// - If \p EC is scalar, \p Ty is returned unchanged /// - Only unpacked literal struct types are supported -inline Type *toVectorizedTy(Type *Ty, ElementCount EC) { +inline Type *toVectorizedTy(Type *Ty, ElementCount EC, unsigned IID = 0) { if (StructType *StructTy = dyn_cast(Ty)) - return toVectorizedStructTy(StructTy, EC); + return toVectorizedStructTy(StructTy, EC, IID); return toVectorTy(Ty, EC); } diff --git a/llvm/lib/IR/VectorTypeUtils.cpp b/llvm/lib/IR/VectorTypeUtils.cpp index 62e39aab90079..5b0b62dac1473 100644 --- a/llvm/lib/IR/VectorTypeUtils.cpp +++ b/llvm/lib/IR/VectorTypeUtils.cpp @@ -8,12 +8,21 @@ #include "llvm/IR/VectorTypeUtils.h" #include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/IR/Intrinsics.h" using namespace llvm; +bool llvm::isVectorIntrinsicWithStructReturnScalarAtField(unsigned IID, + unsigned EleIdx) { + if (IID == Intrinsic::vp_load_ff) + return EleIdx == 1; + return false; +} + /// A helper for converting structs of scalar types to structs of vector types. /// Note: Only unpacked literal struct types are supported. -Type *llvm::toVectorizedStructTy(StructType *StructTy, ElementCount EC) { +Type *llvm::toVectorizedStructTy(StructType *StructTy, ElementCount EC, + unsigned IID) { if (EC.isScalar()) return StructTy; assert(isUnpackedStructLiteral(StructTy) && @@ -22,7 +31,10 @@ Type *llvm::toVectorizedStructTy(StructType *StructTy, ElementCount EC) { "expected all element types to be valid vector element types"); return StructType::get( StructTy->getContext(), - map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * { + map_to_vector(enumerate(StructTy->elements()), [&](auto It) -> Type * { + Type *ElTy = It.value(); + if (isVectorIntrinsicWithStructReturnScalarAtField(IID, It.index())) + return ElTy; return VectorType::get(ElTy, EC); })); } diff --git a/llvm/unittests/IR/VectorTypeUtilsTest.cpp b/llvm/unittests/IR/VectorTypeUtilsTest.cpp index c77f183e921de..5d4d30afa8fb0 100644 --- a/llvm/unittests/IR/VectorTypeUtilsTest.cpp +++ b/llvm/unittests/IR/VectorTypeUtilsTest.cpp @@ -8,6 +8,7 @@ #include "llvm/IR/VectorTypeUtils.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "gtest/gtest.h" @@ -24,6 +25,7 @@ TEST(VectorTypeUtilsTest, TestToVectorizedTy) { Type *FTy = Type::getFloatTy(C); Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy); Type *MixedStructTy = StructType::get(FTy, ITy); + Type *FFLoadRetTy = StructType::get(ITy, ITy); Type *VoidTy = Type::getVoidTy(C); for (ElementCount VF : @@ -54,6 +56,11 @@ TEST(VectorTypeUtilsTest, TestToVectorizedTy) { VectorType::get(ITy, VF)); EXPECT_EQ(toVectorizedTy(VoidTy, VF), VoidTy); + Type *WidenFFLoadRetTy = + toVectorizedTy(FFLoadRetTy, VF, Intrinsic::vp_load_ff); + EXPECT_EQ(cast(WidenFFLoadRetTy)->getElementType(0), + VectorType::get(ITy, VF)); + EXPECT_EQ(cast(WidenFFLoadRetTy)->getElementType(1), ITy); } ElementCount ScalarVF = ElementCount::getFixed(1);