From 336cb90619b6d5dc8f866d51960bc7767d644c2e Mon Sep 17 00:00:00 2001 From: Dmitry Sidorov Date: Mon, 11 Aug 2025 13:07:43 +0200 Subject: [PATCH 1/4] Implement SPV_INTEL_bfloat16_arithmetic (#3290) The extension relaxes rules for bf16 type allowing to use it in some arithmetic operations. Spec is available here: https://github.com/intel/llvm/pull/18352 Co-authered by: Michael Aziz --------- Signed-off-by: Sidorov, Dmitry --- include/LLVMSPIRVExtensions.inc | 1 + lib/SPIRV/SPIRVUtil.cpp | 11 + lib/SPIRV/SPIRVWriter.cpp | 19 ++ lib/SPIRV/libSPIRV/SPIRVEntry.h | 2 + lib/SPIRV/libSPIRV/SPIRVEnum.h | 2 + lib/SPIRV/libSPIRV/SPIRVModule.cpp | 4 + lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h | 1 + lib/SPIRV/libSPIRV/spirv_internal.hpp | 3 + .../INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll | 271 ++++++++++++++++++ 9 files changed, 314 insertions(+) create mode 100644 test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll diff --git a/include/LLVMSPIRVExtensions.inc b/include/LLVMSPIRVExtensions.inc index 173c7b18be..d10287a0ab 100644 --- a/include/LLVMSPIRVExtensions.inc +++ b/include/LLVMSPIRVExtensions.inc @@ -75,6 +75,7 @@ EXT(SPV_INTEL_bindless_images) EXT(SPV_INTEL_2d_block_io) EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate) EXT(SPV_KHR_bfloat16) +EXT(SPV_INTEL_bfloat16_arithmetic) EXT(SPV_INTEL_ternary_bitwise_function) EXT(SPV_INTEL_int4) EXT(SPV_INTEL_function_variants) diff --git a/lib/SPIRV/SPIRVUtil.cpp b/lib/SPIRV/SPIRVUtil.cpp index 33811e62d2..f37e40164b 100644 --- a/lib/SPIRV/SPIRVUtil.cpp +++ b/lib/SPIRV/SPIRVUtil.cpp @@ -551,6 +551,11 @@ ParamType lastFuncParamType(StringRef MangledName) { char Mangled = Copy.back(); std::string Mangled2 = Copy.substr(Copy.size() - 2); + std::string Mangled6 = Copy.substr(Copy.size() - 6); + if (Mangled6 == "__bf16") { + return ParamType::FLOAT; + } + if (isMangledTypeFP(Mangled) || isMangledTypeHalf(Mangled2)) { return ParamType::FLOAT; } else if (isMangledTypeUnsigned(Mangled)) { @@ -1966,6 +1971,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) { NumElems = VecTy->getNumElements(); Ty = VecTy->getElementType(); } + if (Ty->isBFloatTy() && + BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL)) + return true; if ((!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) || (!BM->hasCapability(CapabilityVectorAnyINTEL) && ((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) { @@ -1982,6 +1990,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) { NumElems = VecTy->getNumElements(); Ty = VecTy->getElementType(); } + if (Ty->isBFloatTy() && + BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL)) + return true; if ((!Ty->isIntegerTy()) || (!BM->hasCapability(CapabilityVectorAnyINTEL) && ((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) { diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index b6d478c7af..116f936334 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -3817,6 +3817,20 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II, // -spirv-allow-unknown-intrinsics work correctly. auto IID = II->getIntrinsicID(); switch (IID) { + case Intrinsic::fabs: + case Intrinsic::fma: + case Intrinsic::maxnum: + case Intrinsic::minnum: + case Intrinsic::fmuladd: { + Type *Ty = II->getType(); + if (Ty->isBFloatTy()) + BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL); + break; + } + default: + break; + } + switch (IID) { case Intrinsic::assume: { // llvm.assume translation is currently supported only within // SPV_KHR_expect_assume extension, ignore it otherwise, since it's @@ -4627,6 +4641,11 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI, SmallVector Dec; if (isBuiltinTransToExtInst(CI->getCalledFunction(), &ExtSetKind, &ExtOp, &Dec)) { + if (const auto *FirstArg = F->getArg(0)) { + const auto *Type = FirstArg->getType(); + if (Type->isBFloatTy()) + BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL); + } if (DemangledName.find("__spirv_ocl_printf") != StringRef::npos) { auto *FormatStrPtr = cast(CI->getArgOperand(0)->getType()); if (FormatStrPtr->getAddressSpace() != diff --git a/lib/SPIRV/libSPIRV/SPIRVEntry.h b/lib/SPIRV/libSPIRV/SPIRVEntry.h index 574a9f91eb..f2cc1c8902 100644 --- a/lib/SPIRV/libSPIRV/SPIRVEntry.h +++ b/lib/SPIRV/libSPIRV/SPIRVEntry.h @@ -909,6 +909,8 @@ class SPIRVCapability : public SPIRVEntryNoId { case CapabilityFunctionVariantsINTEL: case CapabilitySpecConditionalINTEL: return ExtensionID::SPV_INTEL_function_variants; + case internal::CapabilityBFloat16ArithmeticINTEL: + return ExtensionID::SPV_INTEL_bfloat16_arithmetic; default: return {}; } diff --git a/lib/SPIRV/libSPIRV/SPIRVEnum.h b/lib/SPIRV/libSPIRV/SPIRVEnum.h index 6e0318dcfb..4a35ede709 100644 --- a/lib/SPIRV/libSPIRV/SPIRVEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVEnum.h @@ -216,6 +216,8 @@ template <> inline void SPIRVMap::init() { {CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR}); ADD_VEC_INIT(CapabilityInt4CooperativeMatrixINTEL, {CapabilityInt4TypeINTEL, CapabilityCooperativeMatrixKHR}); + ADD_VEC_INIT(internal::CapabilityBFloat16ArithmeticINTEL, + {CapabilityBFloat16TypeKHR}); } template <> inline void SPIRVMap::init() { diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/lib/SPIRV/libSPIRV/SPIRVModule.cpp index 24bdb01491..fbe19dceac 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.cpp +++ b/lib/SPIRV/libSPIRV/SPIRVModule.cpp @@ -1698,6 +1698,8 @@ SPIRVInstruction *SPIRVModuleImpl::addBinaryInst(Op TheOpCode, SPIRVType *Type, SPIRVValue *Op1, SPIRVValue *Op2, SPIRVBasicBlock *BB) { + if (Type->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot) + addCapability(internal::CapabilityBFloat16ArithmeticINTEL); return addInstruction(SPIRVInstTemplateBase::create( TheOpCode, Type, getId(), getVec(Op1->getId(), Op2->getId()), BB, this), @@ -1721,6 +1723,8 @@ SPIRVInstruction *SPIRVModuleImpl::addUnaryInst(Op TheOpCode, SPIRVType *TheType, SPIRVValue *Op, SPIRVBasicBlock *BB) { + if (TheType->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot) + addCapability(internal::CapabilityBFloat16ArithmeticINTEL); return addInstruction( SPIRVInstTemplateBase::create(TheOpCode, TheType, getId(), getVec(Op->getId()), BB, this), diff --git a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index 70180f8cae..d820ed3b2e 100644 --- a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -664,6 +664,7 @@ template <> inline void SPIRVMap::init() { add(CapabilityInt4CooperativeMatrixINTEL, "Int4CooperativeMatrixINTEL"); add(CapabilityFunctionVariantsINTEL, "FunctionVariantsINTEL"); add(CapabilitySpecConditionalINTEL, "SpecConditionalINTEL"); + add(internal::CapabilityBFloat16ArithmeticINTEL, "BFloat16ArithmeticINTEL"); } SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap) diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index a05f2a74d2..b0b364569c 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -102,6 +102,7 @@ enum InternalCapability { ICapabilityHWThreadQueryINTEL = 6134, ICapGlobalVariableDecorationsINTEL = 6146, ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192, + ICapabilityBFloat16ArithmeticINTEL = 6226, ICapabilityCooperativeMatrixPrefetchINTEL = 6411, ICapabilityComplexFloatMulDivINTEL = 6414, ICapabilityTensorFloat32RoundingINTEL = 6425, @@ -234,6 +235,8 @@ constexpr Capability CapabilityBfloat16ConversionINTEL = static_cast(ICapBfloat16ConversionINTEL); constexpr Capability CapabilityGlobalVariableDecorationsINTEL = static_cast(ICapGlobalVariableDecorationsINTEL); +constexpr Capability CapabilityBFloat16ArithmeticINTEL = + static_cast(ICapabilityBFloat16ArithmeticINTEL); } // namespace internal } // namespace spv diff --git a/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll new file mode 100644 index 0000000000..b7ddcc750d --- /dev/null +++ b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll @@ -0,0 +1,271 @@ +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_bfloat16 --spirv-ext=+SPV_INTEL_bfloat16_arithmetic -o %t.spv +; RUN: llvm-spirv %t.spv -to-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV + +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM + +; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_bfloat16 2>&1 >/dev/null | FileCheck %s --check-prefix=CHECK-ERROR +; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_bfloat16_arithmetic 2>&1 >/dev/null | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension: + +source_filename = "bfloat16.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spirv64-unknown-unknown" + +; CHECK-SPIRV: Capability BFloat16TypeKHR +; CHECK-SPIRV: Capability BFloat16ArithmeticINTEL +; CHECK-SPIRV: Extension "SPV_INTEL_bfloat16_arithmetic" +; CHECK-SPIRV: Extension "SPV_KHR_bfloat16" +; CHECK-SPIRV: 4 TypeFloat [[BFLOAT:[0-9]+]] 16 0 +; CHECK-SPIRV: 5 Function [[#]] [[#]] [[#]] [[#]] +; CHECK-SPIRV: 7 Phi [[BFLOAT]] [[#]] [[#]] [[#]] [[#]] [[#]] +; CHECK-SPIRV: 2 ReturnValue [[#]] +; CHECK-SPIRV: 4 Variable [[#]] [[ADDR1:[0-9]+]] +; CHECK-SPIRV: 4 Variable [[#]] [[ADDR2:[0-9]+]] +; CHECK-SPIRV: 4 Variable [[#]] [[ADDR3:[0-9]+]] +; CHECK-SPIRV: 6 Load [[BFLOAT]] [[DATA1:[0-9]+]] [[ADDR1]] +; CHECK-SPIRV: 6 Load [[BFLOAT]] [[DATA2:[0-9]+]] [[ADDR2]] +; CHECK-SPIRV: 6 Load [[BFLOAT]] [[DATA3:[0-9]+]] [[ADDR3]] +; Undef +; Constant +; ConstantComposite +; ConstantNull +; SpecConstant +; SpecConstantComposite +; CHECK-SPIRV: 4 ConvertFToU [[#]] [[#]] [[DATA1]] +; CHECK-SPIRV: 4 ConvertFToS [[#]] [[#]] [[DATA1]] +; CHECK-SPIRV: 4 ConvertSToF [[BFLOAT]] [[#]] [[#]] +; CHECK-SPIRV: 4 ConvertUToF [[BFLOAT]] [[#]] [[#]] +; Bitcast +; CHECK-SPIRV: 4 FNegate [[BFLOAT]] [[#]] [[DATA1]] +; CHECK-SPIRV: 5 FAdd [[BFLOAT]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FSub [[BFLOAT]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FMul [[BFLOAT]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FDiv [[BFLOAT]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FRem [[BFLOAT]] [[#]] [[DATA1]] [[DATA2]] +; FMod +; VectorTimesScalar +; CHECK-SPIRV: 4 IsNan [[#]] [[#]] [[DATA1]] +; CHECK-SPIRV: 4 IsInf [[#]] [[#]] [[DATA1]] +; IsFinite +; CHECK-SPIRV: 4 IsNormal [[#]] [[#]] [[DATA1]] +; CHECK-SPIRV: 5 Ordered [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 Unordered [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 6 Select [[BFLOAT]] [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FOrdEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FUnordEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FOrdNotEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FUnordNotEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FOrdLessThan [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FUnordLessThan [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FOrdGreaterThan [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FUnordGreaterThan [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FOrdLessThanEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FUnordLessThanEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FOrdGreaterThanEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 5 FUnordGreaterThanEqual [[#]] [[#]] [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] fabs [[DATA1]] +; CHECK-SPIRV: 8 ExtInst [[BFLOAT]] [[#]] [[#]] fclamp [[DATA1]] [[DATA2]] [[DATA3]] +; CHECK-SPIRV: 8 ExtInst [[BFLOAT]] [[#]] [[#]] fma [[DATA1]] [[DATA2]] [[DATA3]] +; CHECK-SPIRV: 7 ExtInst [[BFLOAT]] [[#]] [[#]] fmax [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 7 ExtInst [[BFLOAT]] [[#]] [[#]] fmin [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 8 ExtInst [[BFLOAT]] [[#]] [[#]] mad [[DATA1]] [[DATA2]] [[DATA3]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] nan [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_cos [[DATA1]] +; CHECK-SPIRV: 7 ExtInst [[BFLOAT]] [[#]] [[#]] native_divide [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_exp [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_exp10 [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_exp2 [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_log [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_log10 [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_log2 [[DATA1]] +; CHECK-SPIRV: 7 ExtInst [[BFLOAT]] [[#]] [[#]] native_powr [[DATA1]] [[DATA2]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_recip [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_rsqrt [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_sin [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_sqrt [[DATA1]] +; CHECK-SPIRV: 6 ExtInst [[BFLOAT]] [[#]] [[#]] native_tan [[DATA1]] + +; CHECK-LLVM: define spir_func void @OpPhi(bfloat %data1, bfloat %data2) +; CHECK-LLVM: %OpPhi = phi bfloat [ %data1, %blockA ], [ %data2, %blockB ] +; CHECK-LLVM: ret bfloat %OpReturnValue +; CHECK-LLVM: [[ADDR1:[%a-z0-9]+]] = alloca bfloat +; CHECK-LLVM: [[ADDR2:[%a-z0-9]+]] = alloca bfloat +; CHECK-LLVM: [[ADDR3:[%a-z0-9]+]] = alloca bfloat +; CHECK-LLVM: [[DATA1:[%a-z0-9]+]] = load bfloat, ptr [[ADDR1]] +; CHECK-LLVM: [[DATA2:[%a-z0-9]+]] = load bfloat, ptr [[ADDR2]] +; CHECK-LLVM: [[DATA3:[%a-z0-9]+]] = load bfloat, ptr [[ADDR3]] +; %OpUndef +; %OpConstant +; %OpConstantComposite +; %OpConstantNull +; %OpSpecConstant +; %OpSpecConstantComposite +; CHECK-LLVM: %OpConvertFToU = fptoui bfloat [[DATA1]] to i32 +; CHECK-LLVM: %OpConvertFToS = fptosi bfloat [[DATA1]] to i32 +; CHECK-LLVM: %OpConvertSToF = sitofp i32 0 to bfloat +; CHECK-LLVM: %OpConvertUToF = uitofp i32 0 to bfloat +; %OpBitcast +; CHECK-LLVM: %OpFNegate = fneg bfloat [[DATA1]] +; CHECK-LLVM: %OpFAdd = fadd bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFSub = fsub bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFMul = fmul bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFDiv = fdiv bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFRem = frem bfloat [[DATA1]], [[DATA2]] +; %OpFMod +; %OpVectorTimesScalar +; CHECK-LLVM: %[[#]] = call spir_func i32 @_Z5isnanu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %[[#]] = call spir_func i32 @_Z5isinfu6__bf16(bfloat [[DATA1]]) +; %OpIsFinite +; CHECK-LLVM: %[[#]] = call spir_func i32 @_Z8isnormalu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %OpOrdered = fcmp ord bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpUnordered = fcmp uno bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpSelect = select i1 true, bfloat [[DATA1]], bfloat [[DATA2]] +; CHECK-LLVM: %OpFOrdEqual = fcmp oeq bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFUnordEqual = fcmp ueq bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFOrdNotEqual = fcmp one bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFUnordNotEqual = fcmp une bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFOrdLessThan = fcmp olt bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFUnordLessThan = fcmp ult bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFOrdGreaterThan = fcmp ogt bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFUnordGreaterThan = fcmp ugt bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFOrdLessThanEqual = fcmp ole bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFUnordLessThanEqual = fcmp ule bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFOrdGreaterThanEqual = fcmp oge bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %OpFUnordGreaterThanEqual = fcmp uge bfloat [[DATA1]], [[DATA2]] +; CHECK-LLVM: %fabs = call spir_func bfloat @_Z4fabsu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %fclamp = call spir_func bfloat @_Z5clampu6__bf16u6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]]) +; CHECK-LLVM: %fma = call spir_func bfloat @_Z3fmau6__bf16u6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]]) +; CHECK-LLVM: %fmax = call spir_func bfloat @_Z4fmaxu6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]]) +; CHECK-LLVM: %fmin = call spir_func bfloat @_Z4fminu6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]]) +; CHECK-LLVM: %mad = call spir_func bfloat @_Z3madu6__bf16u6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]]) +; CHECK-LLVM: %nan = call spir_func bfloat @_Z3nanu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_cos = call spir_func bfloat @_Z10native_cosu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_divide = call spir_func bfloat @_Z13native_divideu6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]]) +; CHECK-LLVM: %native_exp = call spir_func bfloat @_Z10native_expu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_exp10 = call spir_func bfloat @_Z12native_exp10u6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_exp2 = call spir_func bfloat @_Z11native_exp2u6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_log = call spir_func bfloat @_Z10native_logu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_log10 = call spir_func bfloat @_Z12native_log10u6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_log2 = call spir_func bfloat @_Z11native_log2u6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_powr = call spir_func bfloat @_Z11native_powru6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]]) +; CHECK-LLVM: %native_recip = call spir_func bfloat @_Z12native_recipu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_rsqrt = call spir_func bfloat @_Z12native_rsqrtu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_sin = call spir_func bfloat @_Z10native_sinu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_sqrt = call spir_func bfloat @_Z11native_sqrtu6__bf16(bfloat [[DATA1]]) +; CHECK-LLVM: %native_tan = call spir_func bfloat @_Z10native_tanu6__bf16(bfloat [[DATA1]]) + +declare spir_func bfloat @_Z5clampu6__bf16u6__bf16u6__bf16(bfloat, bfloat, bfloat) +declare spir_func bfloat @_Z3nanu6__bf16(bfloat) +declare spir_func bfloat @_Z10native_cosu6__bf16(bfloat) +declare spir_func bfloat @_Z13native_divideu6__bf16u6__bf16(bfloat, bfloat) +declare spir_func bfloat @_Z10native_expu6__bf16(bfloat) +declare spir_func bfloat @_Z12native_exp10u6__bf16(bfloat) +declare spir_func bfloat @_Z11native_exp2u6__bf16(bfloat) +declare spir_func bfloat @_Z10native_logu6__bf16(bfloat) +declare spir_func bfloat @_Z12native_log10u6__bf16(bfloat) +declare spir_func bfloat @_Z11native_log2u6__bf16(bfloat) +declare spir_func bfloat @_Z11native_powru6__bf16u6__bf16(bfloat, bfloat) +declare spir_func bfloat @_Z12native_recipu6__bf16(bfloat) +declare spir_func bfloat @_Z12native_rsqrtu6__bf16(bfloat) +declare spir_func bfloat @_Z10native_sinu6__bf16(bfloat) +declare spir_func bfloat @_Z11native_sqrtu6__bf16(bfloat) +declare spir_func bfloat @_Z10native_tanu6__bf16(bfloat) + +define spir_func void @OpPhi(bfloat %data1, bfloat %data2) { + br label %blockA +blockA: + br label %phi +blockB: + br label %phi +phi: + %OpPhi = phi bfloat [ %data1, %blockA ], [ %data2, %blockB ] + ret void +} + +define spir_func bfloat @OpReturnValue(bfloat %OpReturnValue) { + ret bfloat %OpReturnValue +} + +define spir_kernel void @testMath() { +entry: + %addr1 = alloca bfloat + %addr2 = alloca bfloat + %addr3 = alloca bfloat + %data1 = load bfloat, ptr %addr1 + %data2 = load bfloat, ptr %addr2 + %data3 = load bfloat, ptr %addr3 + ; %OpUndef + ; %OpConstant + ; %OpConstantComposite + ; %OpConstantNull + ; %OpSpecConstant + ; %OpSpecConstantComposite + %OpConvertFToU = fptoui bfloat %data1 to i32 + %OpConvertFToS = fptosi bfloat %data1 to i32 + %OpConvertSToF = sitofp i32 0 to bfloat + %OpConvertUToF = uitofp i32 0 to bfloat + ; %OpBitcast + %OpFNegate = fneg bfloat %data1 + %OpFAdd = fadd bfloat %data1, %data2 + %OpFSub = fsub bfloat %data1, %data2 + %OpFMul = fmul bfloat %data1, %data2 + %OpFDiv = fdiv bfloat %data1, %data2 + %OpFRem = frem bfloat %data1, %data2 + ; %OpFMod + ; %OpVectorTimesScalar + %OpIsNan = call i1 @llvm.is.fpclass.bfloat(bfloat %data1, i32 3) + %OpIsInf = call i1 @llvm.is.fpclass.bfloat(bfloat %data1, i32 516) + ; %OpIsFinite + %OpIsNormal = call i1 @llvm.is.fpclass.bfloat(bfloat %data1, i32 264) + %OpOrdered = fcmp ord bfloat %data1, %data2 + %OpUnordered = fcmp uno bfloat %data1, %data2 + %OpSelect = select i1 true, bfloat %data1, bfloat %data2 + %OpFOrdEqual = fcmp oeq bfloat %data1, %data2 + %OpFUnordEqual = fcmp ueq bfloat %data1, %data2 + %OpFOrdNotEqual = fcmp one bfloat %data1, %data2 + %OpFUnordNotEqual = fcmp une bfloat %data1, %data2 + %OpFOrdLessThan = fcmp olt bfloat %data1, %data2 + %OpFUnordLessThan = fcmp ult bfloat %data1, %data2 + %OpFOrdGreaterThan = fcmp ogt bfloat %data1, %data2 + %OpFUnordGreaterThan = fcmp ugt bfloat %data1, %data2 + %OpFOrdLessThanEqual = fcmp ole bfloat %data1, %data2 + %OpFUnordLessThanEqual = fcmp ule bfloat %data1, %data2 + %OpFOrdGreaterThanEqual = fcmp oge bfloat %data1, %data2 + %OpFUnordGreaterThanEqual = fcmp uge bfloat %data1, %data2 + %fabs = call bfloat @llvm.fabs.bfloat(bfloat %data1) + %fclamp = call spir_func bfloat @_Z5clampu6__bf16u6__bf16u6__bf16(bfloat %data1, bfloat %data2, bfloat %data3) + %fma = call bfloat @llvm.fma.bfloat(bfloat %data1, bfloat %data2, bfloat %data3) + %fmax = call bfloat @llvm.maxnum.bfloat(bfloat %data1, bfloat %data2) + %fmin = call bfloat @llvm.minnum.bfloat(bfloat %data1, bfloat %data2) + %mad = call bfloat @llvm.fmuladd.bfloat(bfloat %data1, bfloat %data2, bfloat %data3) + %nan = call spir_func bfloat @_Z3nanu6__bf16(bfloat %data1) + %native_cos = call spir_func bfloat @_Z10native_cosu6__bf16(bfloat %data1) + %native_divide = call spir_func bfloat @_Z13native_divideu6__bf16u6__bf16(bfloat %data1, bfloat %data2) + %native_exp = call spir_func bfloat @_Z10native_expu6__bf16(bfloat %data1) + %native_exp10 = call spir_func bfloat @_Z12native_exp10u6__bf16(bfloat %data1) + %native_exp2 = call spir_func bfloat @_Z11native_exp2u6__bf16(bfloat %data1) + %native_log = call spir_func bfloat @_Z10native_logu6__bf16(bfloat %data1) + %native_log10 = call spir_func bfloat @_Z12native_log10u6__bf16(bfloat %data1) + %native_log2 = call spir_func bfloat @_Z11native_log2u6__bf16(bfloat %data1) + %native_powr = call spir_func bfloat @_Z11native_powru6__bf16u6__bf16(bfloat %data1, bfloat %data2) + %native_recip = call spir_func bfloat @_Z12native_recipu6__bf16(bfloat %data1) + %native_rsqrt = call spir_func bfloat @_Z12native_rsqrtu6__bf16(bfloat %data1) + %native_sin = call spir_func bfloat @_Z10native_sinu6__bf16(bfloat %data1) + %native_sqrt = call spir_func bfloat @_Z11native_sqrtu6__bf16(bfloat %data1) + %native_tan = call spir_func bfloat @_Z10native_tanu6__bf16(bfloat %data1) + ret void +} + +!opencl.enable.FP_CONTRACT = !{} +!opencl.spir.version = !{!0} +!opencl.ocl.version = !{!1} +!opencl.used.extensions = !{!2} +!opencl.used.optional.core.features = !{!3} +!opencl.compiler.options = !{!3} + +!0 = !{i32 1, i32 2} +!1 = !{i32 2, i32 0} +!2 = !{!"cl_khr_fp16"} +!3 = !{} From 4fb437ed82cb1d77f85428d7249946327d13fbb0 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Tue, 26 Aug 2025 06:15:06 -0700 Subject: [PATCH 2/4] adjust test --- .../INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll index b7ddcc750d..5eb885356f 100644 --- a/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll +++ b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll @@ -258,6 +258,18 @@ entry: ret void } +declare i1 @llvm.is.fpclass.bfloat(bfloat, i32) + +declare bfloat @llvm.fabs.bfloat(bfloat) + +declare bfloat @llvm.fma.bfloat(bfloat, bfloat, bfloat) + +declare bfloat @llvm.maxnum.bfloat(bfloat, bfloat) + +declare bfloat @llvm.minnum.bfloat(bfloat, bfloat) + +declare bfloat @llvm.fmuladd.bfloat(bfloat, bfloat, bfloat) + !opencl.enable.FP_CONTRACT = !{} !opencl.spir.version = !{!0} !opencl.ocl.version = !{!1} From 6359e392511f0ee146cdc58b4b05df5ced749826 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Wed, 27 Aug 2025 05:14:13 -0700 Subject: [PATCH 3/4] remove fpclass --- test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll index 5eb885356f..09ce620825 100644 --- a/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll +++ b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll @@ -46,10 +46,6 @@ target triple = "spirv64-unknown-unknown" ; CHECK-SPIRV: 5 FRem [[BFLOAT]] [[#]] [[DATA1]] [[DATA2]] ; FMod ; VectorTimesScalar -; CHECK-SPIRV: 4 IsNan [[#]] [[#]] [[DATA1]] -; CHECK-SPIRV: 4 IsInf [[#]] [[#]] [[DATA1]] -; IsFinite -; CHECK-SPIRV: 4 IsNormal [[#]] [[#]] [[DATA1]] ; CHECK-SPIRV: 5 Ordered [[#]] [[#]] [[DATA1]] [[DATA2]] ; CHECK-SPIRV: 5 Unordered [[#]] [[#]] [[DATA1]] [[DATA2]] ; CHECK-SPIRV: 6 Select [[BFLOAT]] [[#]] [[#]] [[DATA1]] [[DATA2]] @@ -215,10 +211,6 @@ entry: %OpFRem = frem bfloat %data1, %data2 ; %OpFMod ; %OpVectorTimesScalar - %OpIsNan = call i1 @llvm.is.fpclass.bfloat(bfloat %data1, i32 3) - %OpIsInf = call i1 @llvm.is.fpclass.bfloat(bfloat %data1, i32 516) - ; %OpIsFinite - %OpIsNormal = call i1 @llvm.is.fpclass.bfloat(bfloat %data1, i32 264) %OpOrdered = fcmp ord bfloat %data1, %data2 %OpUnordered = fcmp uno bfloat %data1, %data2 %OpSelect = select i1 true, bfloat %data1, bfloat %data2 From e6459bc17ff30404896e7bbd55c795c462a22266 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Wed, 27 Aug 2025 05:19:04 -0700 Subject: [PATCH 4/4] adjust test Signed-off-by: Sidorov, Dmitry --- test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll index 09ce620825..9ca5da2c11 100644 --- a/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll +++ b/test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll @@ -111,10 +111,6 @@ target triple = "spirv64-unknown-unknown" ; CHECK-LLVM: %OpFRem = frem bfloat [[DATA1]], [[DATA2]] ; %OpFMod ; %OpVectorTimesScalar -; CHECK-LLVM: %[[#]] = call spir_func i32 @_Z5isnanu6__bf16(bfloat [[DATA1]]) -; CHECK-LLVM: %[[#]] = call spir_func i32 @_Z5isinfu6__bf16(bfloat [[DATA1]]) -; %OpIsFinite -; CHECK-LLVM: %[[#]] = call spir_func i32 @_Z8isnormalu6__bf16(bfloat [[DATA1]]) ; CHECK-LLVM: %OpOrdered = fcmp ord bfloat [[DATA1]], [[DATA2]] ; CHECK-LLVM: %OpUnordered = fcmp uno bfloat [[DATA1]], [[DATA2]] ; CHECK-LLVM: %OpSelect = select i1 true, bfloat [[DATA1]], bfloat [[DATA2]]