[SPIR-V] Implement SPV_KHR_fma extension #173057

Merged
michalpaszkowski merged 4 commits into llvm:main from vmaksimo:SPV_KHR_fma on Jan 30, 2026

@vmaksimo
Contributor

The SPV_KHR_fma extension adds the `OpFmaKHR` instruction, a native SPIR-V instruction for fused multiply-add that serves as an alternative to the OpenCL.std `fma` extended instruction.
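
For readers less familiar with the operation: a fused multiply-add computes a*b + c with a single rounding, whereas a separate multiply followed by an add rounds twice. The sketch below is not part of the PR. It uses the C++ standard library's `std::fma` on the host purely as an illustration, and the function names `fusedMulAdd` and `unfusedMulAdd` are made up for this example. The precise semantics of `OpFmaKHR` are defined by the extension specification, not by this snippet.

```cpp
#include <cmath>

// Single rounding: the exact product a*b is added to c, then rounded once.
double fusedMulAdd(double a, double b, double c) { return std::fma(a, b, c); }

// Two roundings: the product is rounded to double first, then the sum is.
// The volatile temporary keeps the compiler from contracting the
// expression into a fused operation on its own.
double unfusedMulAdd(double a, double b, double c) {
  volatile double product = a * b;
  return product + c;
}
```

On an IEEE-754 binary64 platform with round-to-nearest, take a = 1 + 2^-30, b = 1 - 2^-30 and c = -1. The exact product is 1 - 2^-60, so `fusedMulAdd` returns -2^-60, while `unfusedMulAdd` first rounds the product to 1 and returns 0.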

When the extension is enabled, both the LLVM `llvm.fma` intrinsics and the OpenCL `fma` builtins are translated to `OpFmaKHR`; otherwise the existing `OpExtInst` lowering is kept.

Specification:
https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_fma.html
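
The selection rule described above can be modelled in a few lines. This is a simplified, self-contained sketch and not the backend's code: `ToySubtarget`, `LoweredFma`, and `lowerFma` are hypothetical names, and extensions are plain strings here rather than the `SPIRV::Extension` enum used in the patch. It only mirrors the decision the diff makes: emit `OpFmaKHR` and record the `SPV_KHR_fma` extension and `FmaKHR` capability when the extension is enabled, otherwise fall back to the `OpExtInst` form.

```cpp
#include <set>
#include <string>

// Hypothetical stand-in for the subtarget's extension query.
struct ToySubtarget {
  std::set<std::string> EnabledExtensions;
  bool canUseExtension(const std::string &Name) const {
    return EnabledExtensions.count(Name) != 0;
  }
};

// What the toy lowering reports for a single fma operation.
struct LoweredFma {
  std::string Instruction; // textual opcode
  std::string Extension;   // required OpExtension, empty if none
  std::string Capability;  // required OpCapability, empty if none
};

LoweredFma lowerFma(const ToySubtarget &ST) {
  if (ST.canUseExtension("SPV_KHR_fma"))
    return {"OpFmaKHR", "SPV_KHR_fma", "FmaKHR"};
  // Fallback: extended-instruction form. This sketch tracks no extra
  // extension or capability for it.
  return {"OpExtInst fma", "", ""};
}
```

The same query guards both entry points in the patch (the builtin path and the `G_FMA` selection path), so a module compiled without `--spirv-ext=+SPV_KHR_fma` keeps its previous output.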

@vmaksimo vmaksimo marked this pull request as ready for review January 9, 2026 16:42
@llvmbot
Member

llvmbot commented Jan 9, 2026

@llvm/pr-subscribers-backend-spir-v

Author: Viktoria Maximova (vmaksimo)

Full diff: https://github.com/llvm/llvm-project/pull/173057.diff

7 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+16-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+13-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+2)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll (+42)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 0b3fa1ccd4510..e617164f15453 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1253,13 +1253,22 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
   SmallVector<Register> Arguments =
       getBuiltinCallArguments(Call, Number, MIRBuilder, GR);
 
-  // Build extended instruction.
-  auto MIB =
-      MIRBuilder.buildInstr(SPIRV::OpExtInst)
-          .addDef(Call->ReturnRegister)
-          .addUse(ReturnTypeId)
-          .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
-          .addImm(Number);
+  MachineInstrBuilder MIB;
+  if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma) &&
+      Number == SPIRV::OpenCLExtInst::fma) {
+    // Use the SPIR-V fma instruction instead of the OpenCL extended
+    // instruction if the extension is available.
+    MIB = MIRBuilder.buildInstr(SPIRV::OpFmaKHR)
+              .addDef(Call->ReturnRegister)
+              .addUse(ReturnTypeId);
+  } else {
+    // Build extended instruction.
+    MIB = MIRBuilder.buildInstr(SPIRV::OpExtInst)
+              .addDef(Call->ReturnRegister)
+              .addUse(ReturnTypeId)
+              .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
+              .addImm(Number);
+  }
 
   for (Register Argument : Arguments)
     MIB.addUse(Argument);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 74bef26984089..ddf5fa0b40faf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -114,6 +114,7 @@ static const std::map<StringRef, SPIRV::Extension::Extension>
          SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
         {"SPV_KHR_linkonce_odr",
          SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
+        {"SPV_KHR_fma", SPIRV::Extension::Extension::SPV_KHR_fma},
         {"SPV_INTEL_inline_assembly",
          SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
         {"SPV_INTEL_bindless_images",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 815d2d7ed854b..c41f8bccc0890 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -527,6 +527,9 @@ def OpUDotAccSat: Op<4454, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2,
 def OpSUDotAccSat: Op<4455, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, ID:$acc, variable_ops),
                   "$res = OpSUDotAccSat $type $vec1 $vec2 $acc">;
 
+def OpFmaKHR: Op<6034, (outs ID:$res), (ins TYPE:$type, ID:$a, ID:$b, ID:$c),
+                  "$res = OpFmaKHR $type $a $b $c">;
+
 // 3.42.14 Bit Instructions
 
 defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f991938c14dfe..a2a4fb8f276b9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -946,8 +946,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
   case TargetOpcode::G_STRICT_FMA:
-  case TargetOpcode::G_FMA:
+  case TargetOpcode::G_FMA: {
+    if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
+      MachineBasicBlock &BB = *I.getParent();
+      auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpFmaKHR))
+                     .addDef(ResVReg)
+                     .addUse(GR.getSPIRVTypeID(ResType))
+                     .addUse(I.getOperand(1).getReg())
+                     .addUse(I.getOperand(2).getReg())
+                     .addUse(I.getOperand(3).getReg())
+                     .setMIFlags(I.getFlags());
+      return MIB.constrainAllUses(TII, TRI, RBI);
+    }
     return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
+  }
 
   case TargetOpcode::G_STRICT_FLDEXP:
     return selectExtInst(ResVReg, ResType, I, CL::ldexp);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index babce22d4f583..6f53f15d91773 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1768,6 +1768,12 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
     }
     break;
+  case SPIRV::OpFmaKHR:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_KHR_fma);
+      Reqs.addCapability(SPIRV::Capability::FmaKHR);
+    }
+    break;
   case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
   case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 7335882d7af4f..d3e84797a3fa2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -394,6 +394,7 @@ defm SPV_ALTERA_arbitrary_precision_fixed_point : ExtensionOperand<131, [EnvOpen
 defm SPV_NV_shader_atomic_fp16_vector
     : ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>;
 defm SPV_EXT_image_raw10_raw12 :ExtensionOperand<133, [EnvOpenCL, EnvVulkan]>;
+defm SPV_KHR_fma : ExtensionOperand<134, [EnvVulkan, EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -570,6 +571,7 @@ defm FloatControls2
     : CapabilityOperand<6029, 0x10200, 0, [SPV_KHR_float_controls2], []>;
 defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
 defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
+defm FmaKHR : CapabilityOperand<6035, 0, 0, [SPV_KHR_fma], []>;
 defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
 defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
 defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll
new file mode 100644
index 0000000000000..684e3104c05d8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll
@@ -0,0 +1,42 @@
+; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_fma %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown < %s | FileCheck --check-prefix=CHECK-NO-EXT %s
+; TODO: Add spirv-val validation once the extension is supported.
+
+; CHECK: OpCapability FmaKHR
+; CHECK: OpExtension "SPV_KHR_fma"
+; CHECK: %[[#TYPE_FLOAT:]] = OpTypeFloat 32
+; CHECK: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4
+; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]]
+; CHECK: OpFmaKHR %[[#TYPE_VEC]] %[[#]]
+; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]]
+
+; CHECK-NO-EXT-NOT: OpCapability FmaKHR
+; CHECK-NO-EXT-NOT: OpExtension "SPV_KHR_fma"
+; CHECK-NO-EXT: %[[#TYPE_FLOAT:]] = OpTypeFloat 32
+; CHECK-NO-EXT: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4
+; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma
+; CHECK-NO-EXT: OpExtInst %[[#TYPE_VEC]] %[[#]] fma
+; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma
+
+define spir_func float @test_fma_scalar(float %a, float %b, float %c) {
+entry:
+  %result = call float @llvm.fma.f32(float %a, float %b, float %c)
+  ret float %result
+}
+
+define spir_func <4 x float> @test_fma_vector(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+entry:
+  %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c)
+  ret <4 x float> %result
+}
+
+; Case to test fma translation via OCL builtins.
+define spir_func float @test_fma_ocl_scalar(float %a, float %b, float %c) {
+entry:
+  %result = call spir_func float @_Z15__spirv_ocl_fmafff(float %a, float %b, float %c)
+  ret float %result
+}
+
+declare float @llvm.fma.f32(float, float, float)
+declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
+declare spir_func float @_Z15__spirv_ocl_fmafff(float, float, float)

@vmaksimo
Contributor Author

vmaksimo commented Jan 9, 2026

@MrSidims @michalpaszkowski could you please take a look? Thanks!

SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
{"SPV_KHR_linkonce_odr",
SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
{"SPV_KHR_fma", SPIRV::Extension::Extension::SPV_KHR_fma},
Member

@michalpaszkowski michalpaszkowski Jan 14, 2026

Please also update the SPIRVUsage file. Otherwise, the change looks good to me.

@michalpaszkowski michalpaszkowski merged commit 8f8dfbf into llvm:main Jan 30, 2026
13 checks passed
honeygoyal pushed a commit to honeygoyal/llvm-project that referenced this pull request Jan 30, 2026
sshrestha-aa pushed a commit to sshrestha-aa/llvm-project that referenced this pull request Feb 4, 2026
4 participants