[SPIR-V] Implement SPV_KHR_fma extension#173057
Merged
michalpaszkowski merged 4 commits intollvm:mainfrom Jan 30, 2026
Merged
Conversation
The extension adds support for the `OpFmaKHR` instruction, which provides a native SPIR-V instruction for fused multiply-add operations as an alternative to using OpenCL.std::Fma extended instruction. Translate both LLVM fma intrinsics as well as OCL builtins to `OpFmaKHR` if the extension is available. Specification: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_fma.html
Member
|
@llvm/pr-subscribers-backend-spir-v Author: Viktoria Maximova (vmaksimo) ChangesThe extension adds support for the Translate both LLVM fma intrinsics as well as OCL builtins to Specification: Full diff: https://github.com/llvm/llvm-project/pull/173057.diff 7 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 0b3fa1ccd4510..e617164f15453 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1253,13 +1253,22 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
SmallVector<Register> Arguments =
getBuiltinCallArguments(Call, Number, MIRBuilder, GR);
- // Build extended instruction.
- auto MIB =
- MIRBuilder.buildInstr(SPIRV::OpExtInst)
- .addDef(Call->ReturnRegister)
- .addUse(ReturnTypeId)
- .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
- .addImm(Number);
+ MachineInstrBuilder MIB;
+ if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma) &&
+ Number == SPIRV::OpenCLExtInst::fma) {
+ // Use the SPIR-V fma instruction instead of the OpenCL extended
+ // instruction if the extension is available.
+ MIB = MIRBuilder.buildInstr(SPIRV::OpFmaKHR)
+ .addDef(Call->ReturnRegister)
+ .addUse(ReturnTypeId);
+ } else {
+ // Build extended instruction.
+ MIB = MIRBuilder.buildInstr(SPIRV::OpExtInst)
+ .addDef(Call->ReturnRegister)
+ .addUse(ReturnTypeId)
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
+ .addImm(Number);
+ }
for (Register Argument : Arguments)
MIB.addUse(Argument);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 74bef26984089..ddf5fa0b40faf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -114,6 +114,7 @@ static const std::map<StringRef, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
{"SPV_KHR_linkonce_odr",
SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
+ {"SPV_KHR_fma", SPIRV::Extension::Extension::SPV_KHR_fma},
{"SPV_INTEL_inline_assembly",
SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
{"SPV_INTEL_bindless_images",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 815d2d7ed854b..c41f8bccc0890 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -527,6 +527,9 @@ def OpUDotAccSat: Op<4454, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2,
def OpSUDotAccSat: Op<4455, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, ID:$acc, variable_ops),
"$res = OpSUDotAccSat $type $vec1 $vec2 $acc">;
+def OpFmaKHR: Op<6034, (outs ID:$res), (ins TYPE:$type, ID:$a, ID:$b, ID:$c),
+ "$res = OpFmaKHR $type $a $b $c">;
+
// 3.42.14 Bit Instructions
defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f991938c14dfe..a2a4fb8f276b9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -946,8 +946,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
return MIB.constrainAllUses(TII, TRI, RBI);
}
case TargetOpcode::G_STRICT_FMA:
- case TargetOpcode::G_FMA:
+ case TargetOpcode::G_FMA: {
+ if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
+ MachineBasicBlock &BB = *I.getParent();
+ auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpFmaKHR))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(1).getReg())
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .setMIFlags(I.getFlags());
+ return MIB.constrainAllUses(TII, TRI, RBI);
+ }
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
+ }
case TargetOpcode::G_STRICT_FLDEXP:
return selectExtInst(ResVReg, ResType, I, CL::ldexp);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index babce22d4f583..6f53f15d91773 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1768,6 +1768,12 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
}
break;
+ case SPIRV::OpFmaKHR:
+ if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_fma);
+ Reqs.addCapability(SPIRV::Capability::FmaKHR);
+ }
+ break;
case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 7335882d7af4f..d3e84797a3fa2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -394,6 +394,7 @@ defm SPV_ALTERA_arbitrary_precision_fixed_point : ExtensionOperand<131, [EnvOpen
defm SPV_NV_shader_atomic_fp16_vector
: ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>;
defm SPV_EXT_image_raw10_raw12 :ExtensionOperand<133, [EnvOpenCL, EnvVulkan]>;
+defm SPV_KHR_fma : ExtensionOperand<134, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -570,6 +571,7 @@ defm FloatControls2
: CapabilityOperand<6029, 0x10200, 0, [SPV_KHR_float_controls2], []>;
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
+defm FmaKHR : CapabilityOperand<6035, 0, 0, [SPV_KHR_fma], []>;
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll
new file mode 100644
index 0000000000000..684e3104c05d8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll
@@ -0,0 +1,42 @@
+; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_fma %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown < %s | FileCheck --check-prefix=CHECK-NO-EXT %s
+; TODO: Add spirv-val validation once the extension is supported.
+
+; CHECK: OpCapability FmaKHR
+; CHECK: OpExtension "SPV_KHR_fma"
+; CHECK: %[[#TYPE_FLOAT:]] = OpTypeFloat 32
+; CHECK: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4
+; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]]
+; CHECK: OpFmaKHR %[[#TYPE_VEC]] %[[#]]
+; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]]
+
+; CHECK-NO-EXT-NOT: OpCapability FmaKHR
+; CHECK-NO-EXT-NOT: OpExtension "SPV_KHR_fma"
+; CHECK-NO-EXT: %[[#TYPE_FLOAT:]] = OpTypeFloat 32
+; CHECK-NO-EXT: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4
+; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma
+; CHECK-NO-EXT: OpExtInst %[[#TYPE_VEC]] %[[#]] fma
+; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma
+
+define spir_func float @test_fma_scalar(float %a, float %b, float %c) {
+entry:
+ %result = call float @llvm.fma.f32(float %a, float %b, float %c)
+ ret float %result
+}
+
+define spir_func <4 x float> @test_fma_vector(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+entry:
+ %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c)
+ ret <4 x float> %result
+}
+
+; Case to test fma translation via OCL builtins.
+define spir_func float @test_fma_ocl_scalar(float %a, float %b, float %c) {
+entry:
+ %result = call spir_func float @_Z15__spirv_ocl_fmafff(float %a, float %b, float %c)
+ ret float %result
+}
+
+declare float @llvm.fma.f32(float, float, float)
+declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
+declare spir_func float @_Z15__spirv_ocl_fmafff(float, float, float)
|
Contributor
Author
|
@MrSidims @michalpaszkowski could you please take a look? Thanks! |
| SPIRV::Extension::Extension::SPV_KHR_integer_dot_product}, | ||
| {"SPV_KHR_linkonce_odr", | ||
| SPIRV::Extension::Extension::SPV_KHR_linkonce_odr}, | ||
| {"SPV_KHR_fma", SPIRV::Extension::Extension::SPV_KHR_fma}, |
Member
There was a problem hiding this comment.
Please also update the SPIRVUsage file. Otherwise, the change looks good to me.
michalpaszkowski
approved these changes
Jan 28, 2026
MrSidims
approved these changes
Jan 30, 2026
honeygoyal
pushed a commit
to honeygoyal/llvm-project
that referenced
this pull request
Jan 30, 2026
The extension adds support for the `OpFmaKHR` instruction, which provides a native SPIR-V instruction for fused multiply-add operations as an alternative to using OpenCL.std::Fma extended instruction. Translate both LLVM fma intrinsics as well as OCL builtins to `OpFmaKHR` if the extension is available. Specification: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_fma.html
sshrestha-aa
pushed a commit
to sshrestha-aa/llvm-project
that referenced
this pull request
Feb 4, 2026
The extension adds support for the `OpFmaKHR` instruction, which provides a native SPIR-V instruction for fused multiply-add operations as an alternative to using OpenCL.std::Fma extended instruction. Translate both LLVM fma intrinsics as well as OCL builtins to `OpFmaKHR` if the extension is available. Specification: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_fma.html
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The extension adds support for the
OpFmaKHRinstruction, which provides a native SPIR-V instruction for fused multiply-add operations as an alternative to using OpenCL.std::Fma extended instruction.Translate both LLVM fma intrinsics as well as OCL builtins to
OpFmaKHRif the extension is available.Specification:
https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_fma.html