Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ EXT(SPV_INTEL_bindless_images)
EXT(SPV_INTEL_2d_block_io)
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
EXT(SPV_KHR_bfloat16)
EXT(SPV_INTEL_bfloat16_arithmetic)
EXT(SPV_INTEL_ternary_bitwise_function)
EXT(SPV_INTEL_int4)
EXT(SPV_INTEL_function_variants)
11 changes: 11 additions & 0 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,11 @@ ParamType lastFuncParamType(StringRef MangledName) {
char Mangled = Copy.back();
std::string Mangled2 = Copy.substr(Copy.size() - 2);

std::string Mangled6 = Copy.substr(Copy.size() - 6);
if (Mangled6 == "__bf16") {
return ParamType::FLOAT;
}

if (isMangledTypeFP(Mangled) || isMangledTypeHalf(Mangled2)) {
return ParamType::FLOAT;
} else if (isMangledTypeUnsigned(Mangled)) {
Expand Down Expand Up @@ -1966,6 +1971,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
NumElems = VecTy->getNumElements();
Ty = VecTy->getElementType();
}
if (Ty->isBFloatTy() &&
BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL))
return true;
if ((!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) ||
(!BM->hasCapability(CapabilityVectorAnyINTEL) &&
((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) {
Expand All @@ -1982,6 +1990,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
NumElems = VecTy->getNumElements();
Ty = VecTy->getElementType();
}
if (Ty->isBFloatTy() &&
BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL))
return true;
if ((!Ty->isIntegerTy()) ||
(!BM->hasCapability(CapabilityVectorAnyINTEL) &&
((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) {
Expand Down
19 changes: 19 additions & 0 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3817,6 +3817,20 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
// -spirv-allow-unknown-intrinsics work correctly.
auto IID = II->getIntrinsicID();
switch (IID) {
case Intrinsic::fabs:
case Intrinsic::fma:
case Intrinsic::maxnum:
case Intrinsic::minnum:
case Intrinsic::fmuladd: {
Type *Ty = II->getType();
if (Ty->isBFloatTy())
BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
break;
}
default:
break;
}
switch (IID) {
case Intrinsic::assume: {
// llvm.assume translation is currently supported only within
// SPV_KHR_expect_assume extension, ignore it otherwise, since it's
Expand Down Expand Up @@ -4627,6 +4641,11 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
SmallVector<std::string, 2> Dec;
if (isBuiltinTransToExtInst(CI->getCalledFunction(), &ExtSetKind, &ExtOp,
&Dec)) {
if (const auto *FirstArg = F->getArg(0)) {
const auto *Type = FirstArg->getType();
if (Type->isBFloatTy())
BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
}
if (DemangledName.find("__spirv_ocl_printf") != StringRef::npos) {
auto *FormatStrPtr = cast<PointerType>(CI->getArgOperand(0)->getType());
if (FormatStrPtr->getAddressSpace() !=
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,8 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {
case CapabilityFunctionVariantsINTEL:
case CapabilitySpecConditionalINTEL:
return ExtensionID::SPV_INTEL_function_variants;
case internal::CapabilityBFloat16ArithmeticINTEL:
return ExtensionID::SPV_INTEL_bfloat16_arithmetic;
default:
return {};
}
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(CapabilityInt4CooperativeMatrixINTEL,
{CapabilityInt4TypeINTEL, CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityBFloat16ArithmeticINTEL,
{CapabilityBFloat16TypeKHR});
}

template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,8 @@ SPIRVInstruction *SPIRVModuleImpl::addBinaryInst(Op TheOpCode, SPIRVType *Type,
SPIRVValue *Op1,
SPIRVValue *Op2,
SPIRVBasicBlock *BB) {
if (Type->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot)
addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
return addInstruction(SPIRVInstTemplateBase::create(
TheOpCode, Type, getId(),
getVec(Op1->getId(), Op2->getId()), BB, this),
Expand All @@ -1721,6 +1723,8 @@ SPIRVInstruction *SPIRVModuleImpl::addUnaryInst(Op TheOpCode,
SPIRVType *TheType,
SPIRVValue *Op,
SPIRVBasicBlock *BB) {
if (TheType->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot)
addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
return addInstruction(
SPIRVInstTemplateBase::create(TheOpCode, TheType, getId(),
getVec(Op->getId()), BB, this),
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(CapabilityInt4CooperativeMatrixINTEL, "Int4CooperativeMatrixINTEL");
add(CapabilityFunctionVariantsINTEL, "FunctionVariantsINTEL");
add(CapabilitySpecConditionalINTEL, "SpecConditionalINTEL");
add(internal::CapabilityBFloat16ArithmeticINTEL, "BFloat16ArithmeticINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ enum InternalCapability {
ICapabilityHWThreadQueryINTEL = 6134,
ICapGlobalVariableDecorationsINTEL = 6146,
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
ICapabilityBFloat16ArithmeticINTEL = 6226,
ICapabilityCooperativeMatrixPrefetchINTEL = 6411,
ICapabilityComplexFloatMulDivINTEL = 6414,
ICapabilityTensorFloat32RoundingINTEL = 6425,
Expand Down Expand Up @@ -234,6 +235,8 @@ constexpr Capability CapabilityBfloat16ConversionINTEL =
static_cast<Capability>(ICapBfloat16ConversionINTEL);
constexpr Capability CapabilityGlobalVariableDecorationsINTEL =
static_cast<Capability>(ICapGlobalVariableDecorationsINTEL);
constexpr Capability CapabilityBFloat16ArithmeticINTEL =
static_cast<Capability>(ICapabilityBFloat16ArithmeticINTEL);

} // namespace internal
} // namespace spv
Expand Down
Loading
Loading