Skip to content
25 changes: 23 additions & 2 deletions llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
// and checks that all vector operands match the previously found EC.
SmallVector<Type *, 8> ScalarArgTypes;
SmallVector<Type *, 8> ScalarArgTypes, OrigArgTypes;
std::string ScalarName;
Function *FuncToReplace = nullptr;
if (auto *CI = dyn_cast<CallInst>(&I)) {
auto *CI = dyn_cast<CallInst>(&I);
if (CI) {
FuncToReplace = CI->getCalledFunction();
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
for (auto Arg : enumerate(CI->args())) {
auto *ArgTy = Arg.value()->getType();
OrigArgTypes.push_back(ArgTy);
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
Expand Down Expand Up @@ -168,12 +170,31 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
if (!OptInfo)
return false;

// There is no guarantee that the vectorized instructions followed the VFABI
// specification when being created, this is why we need to add extra check to
// make sure that the operands of the vector function obtained via VFABI match
// the operands of the original vector instruction.
if (CI) {
for (auto VFParam : OptInfo->Shape.Parameters) {
if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
continue;
Type *OrigTy = OrigArgTypes[VFParam.ParamPos];
if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE
<< ": Will not replace: wrong type at index: "
<< VFParam.ParamPos << ": " << *OrigTy << "\n");
return false;
}
}
}

FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
if (!VectorFTy)
return false;

Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);

replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ set(ANALYSIS_TEST_SOURCES
PluginInlineAdvisorAnalysisTest.cpp
PluginInlineOrderAnalysisTest.cpp
ProfileSummaryInfoTest.cpp
ReplaceWithVecLibTest.cpp
ScalarEvolutionTest.cpp
VectorFunctionABITest.cpp
SparsePropagation.cpp
Expand Down
89 changes: 89 additions & 0 deletions llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ReplaceWithVeclib.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

using namespace llvm;

namespace {

static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
SMDiagnostic Err;
std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
if (!Mod)
Err.print("ReplaceWithVecLibTest", errs());
return Mod;
}

/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This
/// allows checking that the pass won't crash when the function to replace (from
/// the input IR) does not match the replacement function (derived from the
/// VecDesc mapping).
class ReplaceWithVecLibTest : public ::testing::Test {
protected:
LLVMContext Ctx;

/// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib
/// pass. The pass should not crash even when the replacement function
/// (derived from the \p VD mapping) does not match the function to be
/// replaced (from the input \p IR).
bool run(const VecDesc &VD, const char *IR) {
// Create TLII and register it with FAM so it's preserved when
// ReplaceWithVeclib pass runs.
TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
TLII.addVectorizableFunctions({VD});
FunctionAnalysisManager FAM;
FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });

// Register and run the pass on the 'foo' function from the input IR.
FunctionPassManager FPM;
FPM.addPass(ReplaceWithVeclib());
std::unique_ptr<Module> M = parseIR(Ctx, IR);
PassBuilder PB;
PB.registerFunctionAnalyses(FAM);
FPM.run(*M->getFunction("foo"), FAM);

return true;
}
};

} // end anonymous namespace

static const char *IR = R"IR(
define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
%call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
ret <vscale x 4 x float> %call
}
declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
)IR";

// The VFABI prefix in TLI describes signature which is matching the powi
// intrinsic declaration.
TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
ElementCount::getScalable(4), /*Masked*/ true,
"_ZGVsMxvu"};
EXPECT_TRUE(run(CorrectVD, IR));
}

// The VFABI prefix in TLI describes signature which is not matching the powi
// intrinsic declaration.
TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
ElementCount::getScalable(4), /*Masked*/ true,
"_ZGVsMxvv"};
EXPECT_TRUE(run(IncorrectVD, IR));
}