diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst index d31a5d89a7a30..1159b4b908078 100644 --- a/llvm/docs/SPIRVUsage.rst +++ b/llvm/docs/SPIRVUsage.rst @@ -253,6 +253,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e - Adds Image Channel Data Type definitions for RAW10 and RAW12 image formats. * - ``SPV_ALTERA_arbitrary_precision_floating_point`` - Adds instructions for arbitrary precision floating-point arithmetic. The extension works without SPV_ALTERA_arbitrary_precision_integers, but together they allow greater flexibility in representing arbitrary precision data types. + * - ``SPV_KHR_fma`` + - Adds a core fused-multiply-add (fma) instruction to replace the different variants that have existed in extended instruction sets. SPIR-V representation in LLVM IR diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 6af62328d9ddc..f54fab3f381ee 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -1253,13 +1253,22 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call, SmallVector Arguments = getBuiltinCallArguments(Call, Number, MIRBuilder, GR); - // Build extended instruction. - auto MIB = - MIRBuilder.buildInstr(SPIRV::OpExtInst) - .addDef(Call->ReturnRegister) - .addUse(ReturnTypeId) - .addImm(static_cast(SPIRV::InstructionSet::OpenCL_std)) - .addImm(Number); + MachineInstrBuilder MIB; + if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma) && + Number == SPIRV::OpenCLExtInst::fma) { + // Use the SPIR-V fma instruction instead of the OpenCL extended + // instruction if the extension is available. + MIB = MIRBuilder.buildInstr(SPIRV::OpFmaKHR) + .addDef(Call->ReturnRegister) + .addUse(ReturnTypeId); + } else { + // Build extended instruction. + MIB = MIRBuilder.buildInstr(SPIRV::OpExtInst) + .addDef(Call->ReturnRegister) + .addUse(ReturnTypeId) + .addImm(static_cast(SPIRV::InstructionSet::OpenCL_std)) + .addImm(Number); + } for (Register Argument : Arguments) MIB.addUse(Argument); diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 5a70c95bc6fd3..8de1c7188a80f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -117,6 +117,7 @@ static const std::map SPIRV::Extension::Extension::SPV_KHR_integer_dot_product}, {"SPV_KHR_linkonce_odr", SPIRV::Extension::Extension::SPV_KHR_linkonce_odr}, + {"SPV_KHR_fma", SPIRV::Extension::Extension::SPV_KHR_fma}, {"SPV_INTEL_inline_assembly", SPIRV::Extension::Extension::SPV_INTEL_inline_assembly}, {"SPV_INTEL_bindless_images", diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 7ca1384abc950..811a1273ed6e0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -527,6 +527,9 @@ def OpUDotAccSat: Op<4454, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, def OpSUDotAccSat: Op<4455, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, ID:$acc, variable_ops), "$res = OpSUDotAccSat $type $vec1 $vec2 $acc">; +def OpFmaKHR: Op<6034, (outs ID:$res), (ins TYPE:$type, ID:$a, ID:$b, ID:$c), + "$res = OpFmaKHR $type $a $b $c">; + // 3.42.14 Bit Instructions defm OpShiftRightLogical: BinOpTypedGen<"OpShiftRightLogical", 194, srl, 0, 1>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 98b5bfd678135..7bdb8d8beae82 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -946,8 +946,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, return MIB.constrainAllUses(TII, TRI, RBI); } case TargetOpcode::G_STRICT_FMA: - case TargetOpcode::G_FMA: + case TargetOpcode::G_FMA: { + if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) { + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpFmaKHR)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()) + .addUse(I.getOperand(3).getReg()) + .setMIFlags(I.getFlags()); + return MIB.constrainAllUses(TII, TRI, RBI); + } return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma); + } case TargetOpcode::G_STRICT_FLDEXP: return selectExtInst(ResVReg, ResType, I, CL::ldexp); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index d944a018ba60d..dec5ed23e304c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1768,6 +1768,12 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR); } break; + case SPIRV::OpFmaKHR: + if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) { + Reqs.addExtension(SPIRV::Extension::SPV_KHR_fma); + Reqs.addCapability(SPIRV::Capability::FmaKHR); + } + break; case SPIRV::OpPtrCastToCrossWorkgroupINTEL: case SPIRV::OpCrossWorkgroupCastToPtrINTEL: if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) { diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 6f50f6a6421e1..1a2d3d140ba39 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -395,6 +395,8 @@ defm SPV_NV_shader_atomic_fp16_vector : ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>; defm SPV_EXT_image_raw10_raw12 :ExtensionOperand<133, [EnvOpenCL, EnvVulkan]>; defm SPV_ALTERA_arbitrary_precision_floating_point: ExtensionOperand<134, [EnvOpenCL]>; +defm SPV_KHR_fma : ExtensionOperand<135, [EnvVulkan, EnvOpenCL]>; + //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time // SymbolicOperand entries with string mnemonics, versioning, extensions, and @@ -570,6 +572,7 @@ defm FloatControls2 : CapabilityOperand<6029, 0x10200, 0, [SPV_KHR_float_controls2], []>; defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>; defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>; +defm FmaKHR : CapabilityOperand<6035, 0, 0, [SPV_KHR_fma], []>; defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>; defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>; defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>; diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll new file mode 100644 index 0000000000000..684e3104c05d8 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_fma/fma.ll @@ -0,0 +1,42 @@ +; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_fma %s -o - | FileCheck %s +; RUN: llc -verify-machineinstrs -mtriple=spirv64-unknown-unknown < %s | FileCheck --check-prefix=CHECK-NO-EXT %s +; TODO: Add spirv-val validation once the extension is supported. + +; CHECK: OpCapability FmaKHR +; CHECK: OpExtension "SPV_KHR_fma" +; CHECK: %[[#TYPE_FLOAT:]] = OpTypeFloat 32 +; CHECK: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4 +; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]] +; CHECK: OpFmaKHR %[[#TYPE_VEC]] %[[#]] +; CHECK: OpFmaKHR %[[#TYPE_FLOAT]] %[[#]] + +; CHECK-NO-EXT-NOT: OpCapability FmaKHR +; CHECK-NO-EXT-NOT: OpExtension "SPV_KHR_fma" +; CHECK-NO-EXT: %[[#TYPE_FLOAT:]] = OpTypeFloat 32 +; CHECK-NO-EXT: %[[#TYPE_VEC:]] = OpTypeVector %[[#TYPE_FLOAT]] 4 +; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma +; CHECK-NO-EXT: OpExtInst %[[#TYPE_VEC]] %[[#]] fma +; CHECK-NO-EXT: OpExtInst %[[#TYPE_FLOAT]] %[[#]] fma + +define spir_func float @test_fma_scalar(float %a, float %b, float %c) { +entry: + %result = call float @llvm.fma.f32(float %a, float %b, float %c) + ret float %result +} + +define spir_func <4 x float> @test_fma_vector(<4 x float> %a, <4 x float> %b, <4 x float> %c) { +entry: + %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) + ret <4 x float> %result +} + +; Case to test fma translation via OCL builtins. +define spir_func float @test_fma_ocl_scalar(float %a, float %b, float %c) { +entry: + %result = call spir_func float @_Z15__spirv_ocl_fmafff(float %a, float %b, float %c) + ret float %result +} + +declare float @llvm.fma.f32(float, float, float) +declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>) +declare spir_func float @_Z15__spirv_ocl_fmafff(float, float, float)