Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
- Adds Image Channel Data Type definitions for RAW10 and RAW12 image formats.
* - ``SPV_ALTERA_arbitrary_precision_floating_point``
- Adds instructions for arbitrary precision floating-point arithmetic. The extension works without SPV_ALTERA_arbitrary_precision_integers, but together they allow greater flexibility in representing arbitrary precision data types.
* - ``SPV_KHR_fma``
- Adds a core fused-multiply-add (fma) instruction to replace the different variants that have existed in extended instruction sets.


SPIR-V representation in LLVM IR
Expand Down
23 changes: 16 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,13 +1253,22 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
SmallVector<Register> Arguments =
getBuiltinCallArguments(Call, Number, MIRBuilder, GR);

// Build extended instruction.
auto MIB =
MIRBuilder.buildInstr(SPIRV::OpExtInst)
.addDef(Call->ReturnRegister)
.addUse(ReturnTypeId)
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
.addImm(Number);
MachineInstrBuilder MIB;
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma) &&
Number == SPIRV::OpenCLExtInst::fma) {
// Use the SPIR-V fma instruction instead of the OpenCL extended
// instruction if the extension is available.
MIB = MIRBuilder.buildInstr(SPIRV::OpFmaKHR)
.addDef(Call->ReturnRegister)
.addUse(ReturnTypeId);
} else {
// Build extended instruction.
MIB = MIRBuilder.buildInstr(SPIRV::OpExtInst)
.addDef(Call->ReturnRegister)
.addUse(ReturnTypeId)
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
.addImm(Number);
}

for (Register Argument : Arguments)
MIB.addUse(Argument);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ static const std::map<StringRef, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_KHR_integer_dot_product},
{"SPV_KHR_linkonce_odr",
SPIRV::Extension::Extension::SPV_KHR_linkonce_odr},
{"SPV_KHR_fma", SPIRV::Extension::Extension::SPV_KHR_fma},
Copy link
Member

@michalpaszkowski michalpaszkowski Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also update the SPIRVUsage file. Otherwise, the change looks good to me.

{"SPV_INTEL_inline_assembly",
SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
{"SPV_INTEL_bindless_images",
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,9 @@ def OpUDotAccSat: Op<4454, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2,
def OpSUDotAccSat: Op<4455, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, ID:$acc, variable_ops),
"$res = OpSUDotAccSat $type $vec1 $vec2 $acc">;

def OpFmaKHR: Op<6034, (outs ID:$res), (ins TYPE:$type, ID:$a, ID:$b, ID:$c),
"$res = OpFmaKHR $type $a $b $c">;

// 3.42.14 Bit Instructions

defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>;
Expand Down
14 changes: 13 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,8 +946,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
return MIB.constrainAllUses(TII, TRI, RBI);
}
case TargetOpcode::G_STRICT_FMA:
case TargetOpcode::G_FMA:
case TargetOpcode::G_FMA: {
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpFmaKHR))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(1).getReg())
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg())
.setMIFlags(I.getFlags());
return MIB.constrainAllUses(TII, TRI, RBI);
}
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
}

case TargetOpcode::G_STRICT_FLDEXP:
return selectExtInst(ResVReg, ResType, I, CL::ldexp);
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,12 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
}
break;
case SPIRV::OpFmaKHR:
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_fma);
Reqs.addCapability(SPIRV::Capability::FmaKHR);
}
break;
case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ defm SPV_NV_shader_atomic_fp16_vector
: ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>;
defm SPV_EXT_image_raw10_raw12 :ExtensionOperand<133, [EnvOpenCL, EnvVulkan]>;
defm SPV_ALTERA_arbitrary_precision_floating_point: ExtensionOperand<134, [EnvOpenCL]>;
defm SPV_KHR_fma : ExtensionOperand<135, [EnvVulkan, EnvOpenCL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
// SymbolicOperand entries with string mnemonics, versioning, extensions, and
Expand Down Expand Up @@ -570,6 +572,7 @@ defm FloatControls2
: CapabilityOperand<6029, 0x10200, 0, [SPV_KHR_float_controls2], []>;
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm FmaKHR : CapabilityOperand<6035, 0, 0, [SPV_KHR_fma], []>;
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_fma %s -o - | FileCheck %s
; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown < %s | FileCheck --check-prefix=CHECK-NO-EXT %s
; TODO: Add spirv-val validation once the extension is supported.

; CHECK: OpCapability FmaKHR
; CHECK: OpExtension "SPV_KHR_fma"
; CHECK: %[[#TYPE_FLOAT:]] = OpTypeFloat 32
; CHECK: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4
; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]]
; CHECK: OpFmaKHR %[[#TYPE_VEC]] %[[#]]
; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]]

; CHECK-NO-EXT-NOT: OpCapability FmaKHR
; CHECK-NO-EXT-NOT: OpExtension "SPV_KHR_fma"
; CHECK-NO-EXT: %[[#TYPE_FLOAT:]] = OpTypeFloat 32
; CHECK-NO-EXT: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4
; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma
; CHECK-NO-EXT: OpExtInst %[[#TYPE_VEC]] %[[#]] fma
; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma

define spir_func float @test_fma_scalar(float %a, float %b, float %c) {
entry:
%result = call float @llvm.fma.f32(float %a, float %b, float %c)
ret float %result
}

define spir_func <4 x float> @test_fma_vector(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
entry:
%result = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c)
ret <4 x float> %result
}

; Case to test fma translation via OCL builtins.
define spir_func float @test_fma_ocl_scalar(float %a, float %b, float %c) {
entry:
%result = call spir_func float @_Z15__spirv_ocl_fmafff(float %a, float %b, float %c)
ret float %result
}

declare float @llvm.fma.f32(float, float, float)
declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
declare spir_func float @_Z15__spirv_ocl_fmafff(float, float, float)