[VP] Correct lowering of predicated fma and fmuladd to avoid strictfp. (#85272)

Add the missing cases in a switch that caused @llvm.vp.fma.v4f32 to get
lowered to a constrained fma intrinsic. Vector-predicated lowering to
constrained intrinsics is not currently supported, and there's no
consensus on the path forward. We certainly shouldn't be introducing
constrained intrinsics into a function that isn't strictfp.

Problem found with D146845.
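For illustration only (not part of the patch), a minimal IR sketch of the behavior being fixed; the constrained call's metadata operands are assumptions, not necessarily what the expander emitted:

%res = call <4 x float> @llvm.vp.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 4)

; Before: -expandvp could rewrite this into a constrained intrinsic even in a
; non-strictfp function, e.g. (metadata operands shown are hypothetical):
%bad = call <4 x float> @llvm.experimental.constrained.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c, metadata !"round.dynamic", metadata !"fpexcept.strict")

; After: it expands to the plain intrinsic instead:
%res = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c)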
kpneal authored Apr 17, 2024
1 parent 915c84b commit 79726ef
Showing 4 changed files with 203 additions and 11 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.h
@@ -105,6 +105,10 @@ namespace Intrinsic {
/// Map a MS builtin name to an intrinsic ID.
ID getIntrinsicForMSBuiltin(const char *Prefix, StringRef BuiltinName);

/// Returns true if the intrinsic ID is for one of the "Constrained
/// Floating-Point Intrinsics".
bool isConstrainedFPIntrinsic(ID QID);

/// This is a type descriptor which explains the type requirements of an
/// intrinsic. This is returned by getIntrinsicInfoTableEntries.
struct IITDescriptor {
12 changes: 10 additions & 2 deletions llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -340,15 +340,21 @@ Value *CachingVPExpander::expandPredicationToFPCall(
replaceOperation(*NewOp, VPI);
return NewOp;
}
case Intrinsic::fma:
case Intrinsic::fmuladd:
case Intrinsic::experimental_constrained_fma:
case Intrinsic::experimental_constrained_fmuladd: {
Value *Op0 = VPI.getOperand(0);
Value *Op1 = VPI.getOperand(1);
Value *Op2 = VPI.getOperand(2);
Function *Fn = Intrinsic::getDeclaration(
VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
Value *NewOp =
Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
Value *NewOp;
if (Intrinsic::isConstrainedFPIntrinsic(UnpredicatedIntrinsicID))
NewOp =
Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
else
NewOp = Builder.CreateCall(Fn, {Op0, Op1, Op2}, VPI.getName());
replaceOperation(*NewOp, VPI);
return NewOp;
}
@@ -731,6 +737,8 @@ Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
case Intrinsic::vp_minnum:
case Intrinsic::vp_maximum:
case Intrinsic::vp_minimum:
case Intrinsic::vp_fma:
case Intrinsic::vp_fmuladd:
return expandPredicationToFPCall(Builder, VPI,
VPI.getFunctionalIntrinsicID().value());
case Intrinsic::vp_load:
22 changes: 13 additions & 9 deletions llvm/lib/IR/Function.cpp
@@ -499,15 +499,7 @@ static MutableArrayRef<Argument> makeArgArray(Argument *Args, size_t Count) {
}

bool Function::isConstrainedFPIntrinsic() const {
switch (getIntrinsicID()) {
#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC) \
case Intrinsic::INTRINSIC:
#include "llvm/IR/ConstrainedOps.def"
return true;
#undef INSTRUCTION
default:
return false;
}
return Intrinsic::isConstrainedFPIntrinsic(getIntrinsicID());
}

void Function::clearArguments() {
@@ -1486,6 +1478,18 @@ Function *Intrinsic::getDeclaration(Module *M, ID id, ArrayRef<Type*> Tys) {
#include "llvm/IR/IntrinsicImpl.inc"
#undef GET_LLVM_INTRINSIC_FOR_MS_BUILTIN

bool Intrinsic::isConstrainedFPIntrinsic(ID QID) {
switch (QID) {
#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC) \
case Intrinsic::INTRINSIC:
#include "llvm/IR/ConstrainedOps.def"
return true;
#undef INSTRUCTION
default:
return false;
}
}

using DeferredIntrinsicMatchPair =
std::pair<Type *, ArrayRef<Intrinsic::IITDescriptor>>;

176 changes: 176 additions & 0 deletions llvm/test/CodeGen/Generic/expand-vp-fp-intrinsics.ll
@@ -0,0 +1,176 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt -expandvp -S < %s | FileCheck %s

define void @vp_fadd_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_fadd_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[RES1:%.*]] = fadd <4 x float> [[A0]], [[A1]]
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)

define void @vp_fsub_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_fsub_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = fsub <4 x float> [[A0]], [[A1]]
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)

define void @vp_fmul_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_fmul_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = fmul <4 x float> [[A0]], [[A1]]
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fmul.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fmul.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)

define void @vp_fdiv_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_fdiv_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = fdiv <4 x float> [[A0]], [[A1]]
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fdiv.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fdiv.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)

define void @vp_frem_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_frem_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = frem <4 x float> [[A0]], [[A1]]
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.frem.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.frem.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)

define void @vp_fabs_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_fabs_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> [[A0]])
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fabs.v4f32(<4 x float> %a0, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fabs.v4f32(<4 x float>, <4 x i1>, i32)

define void @vp_sqrt_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_sqrt_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = call <4 x float> @llvm.sqrt.v4f32(<4 x float> [[A0]])
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.sqrt.v4f32(<4 x float> %a0, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.sqrt.v4f32(<4 x float>, <4 x i1>, i32)

define void @vp_fneg_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
; CHECK-LABEL: define void @vp_fneg_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = fneg <4 x float> [[A0]]
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fneg.v4f32(<4 x float> %a0, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fneg.v4f32(<4 x float>, <4 x i1>, i32)

define void @vp_fma_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i4 %a5) nounwind {
; CHECK-LABEL: define void @vp_fma_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i4 [[A5:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A0]], <4 x float> [[A1]], <4 x float> [[A1]])
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fma.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 4)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fma.v4f32(<4 x float>, <4 x float>, <4 x float>, <4 x i1>, i32)

define void @vp_fmuladd_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i4 %a5) nounwind {
; CHECK-LABEL: define void @vp_fmuladd_v4f32(
; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i4 [[A5:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[RES1:%.*]] = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> [[A0]], <4 x float> [[A1]], <4 x float> [[A1]])
; CHECK-NEXT: store <4 x float> [[RES1]], ptr [[OUT]], align 16
; CHECK-NEXT: ret void
;
%res = call <4 x float> @llvm.vp.fmuladd.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 4)
store <4 x float> %res, ptr %out
ret void
}
declare <4 x float> @llvm.vp.fmuladd.v4f32(<4 x float>, <4 x float>, <4 x float>, <4 x i1>, i32)

declare <4 x float> @llvm.vp.maxnum.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
define <4 x float> @vfmax_vv_v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: define <4 x float> @vfmax_vv_v4f32(
; CHECK-SAME: <4 x float> [[VA:%.*]], <4 x float> [[VB:%.*]], <4 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
; CHECK-NEXT: [[V1:%.*]] = call <4 x float> @llvm.maxnum.v4f32(<4 x float> [[VA]], <4 x float> [[VB]])
; CHECK-NEXT: ret <4 x float> [[V1]]
;
%v = call <4 x float> @llvm.vp.maxnum.v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 %evl)
ret <4 x float> %v
}

declare <8 x float> @llvm.vp.maxnum.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32)
define <8 x float> @vfmax_vv_v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: define <8 x float> @vfmax_vv_v8f32(
; CHECK-SAME: <8 x float> [[VA:%.*]], <8 x float> [[VB:%.*]], <8 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
; CHECK-NEXT: [[V1:%.*]] = call <8 x float> @llvm.maxnum.v8f32(<8 x float> [[VA]], <8 x float> [[VB]])
; CHECK-NEXT: ret <8 x float> [[V1]]
;
%v = call <8 x float> @llvm.vp.maxnum.v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 %evl)
ret <8 x float> %v
}

declare <4 x float> @llvm.vp.minnum.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
define <4 x float> @vfmin_vv_v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: define <4 x float> @vfmin_vv_v4f32(
; CHECK-SAME: <4 x float> [[VA:%.*]], <4 x float> [[VB:%.*]], <4 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
; CHECK-NEXT: [[V1:%.*]] = call <4 x float> @llvm.minnum.v4f32(<4 x float> [[VA]], <4 x float> [[VB]])
; CHECK-NEXT: ret <4 x float> [[V1]]
;
%v = call <4 x float> @llvm.vp.minnum.v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 %evl)
ret <4 x float> %v
}

declare <8 x float> @llvm.vp.minnum.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32)
define <8 x float> @vfmin_vv_v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: define <8 x float> @vfmin_vv_v8f32(
; CHECK-SAME: <8 x float> [[VA:%.*]], <8 x float> [[VB:%.*]], <8 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
; CHECK-NEXT: [[V1:%.*]] = call <8 x float> @llvm.minnum.v8f32(<8 x float> [[VA]], <8 x float> [[VB]])
; CHECK-NEXT: ret <8 x float> [[V1]]
;
%v = call <8 x float> @llvm.vp.minnum.v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 %evl)
ret <8 x float> %v
}
