Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,5 @@ EXT(SPV_INTEL_shader_atomic_bfloat16)
EXT(SPV_EXT_float8)
EXT(SPV_INTEL_predicated_io)
EXT(SPV_INTEL_sigmoid)
EXT(SPV_INTEL_float4)
EXT(SPV_INTEL_fp_conversions)
142 changes: 112 additions & 30 deletions lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ enum FPEncodingWrap {
BF16 = FPEncoding::FPEncodingBFloat16KHR,
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
E2M1 = internal::FPEncodingFloat4E2M1INTEL,
};

// Structure describing non-trivial conversions (FP8 and int4)
Expand Down Expand Up @@ -1077,36 +1078,117 @@ typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;

// clang-format off
template <> inline void FPConvertToEncodingMap::init() {
// 8-bit conversions
add("ConvertE4M3ToFP16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE5M2ToFP16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE4M3ToBF16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
add("ConvertE5M2ToBF16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
add("ConvertFP16ToE4M3EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertFP16ToE5M2EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
add("ConvertBF16ToE4M3EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertBF16ToE5M2EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});

add("ConvertInt4ToE4M3INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
add("ConvertInt4ToE5M2INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
add("ConvertInt4ToFP16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
add("ConvertInt4ToBF16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
add("ConvertFP16ToInt4INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
add("ConvertBF16ToInt4INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
// 4-bit conversions
add("ConvertE2M1ToE4M3INTEL",
{FPEncodingWrap::E2M1, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertE2M1ToE5M2INTEL",
{FPEncodingWrap::E2M1, FPEncodingWrap::E5M2, OpFConvert});
add("ConvertE2M1ToFP16INTEL",
{FPEncodingWrap::E2M1, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE2M1ToBF16INTEL",
{FPEncodingWrap::E2M1, FPEncodingWrap::BF16, OpFConvert});

add("ConvertInt4ToE4M3INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
add("ConvertInt4ToE5M2INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
add("ConvertInt4ToFP16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
add("ConvertInt4ToBF16INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
add("ConvertInt4ToInt8INTEL",
{FPEncodingWrap::Integer, FPEncodingWrap::Integer, OpSConvert});

add("ConvertFP16ToE2M1INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1, OpFConvert});
add("ConvertBF16ToE2M1INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1, OpFConvert});
add("ConvertFP16ToInt4INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
add("ConvertBF16ToInt4INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});

// 8-bit conversions
add("ConvertE4M3ToFP16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE5M2ToFP16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
add("ConvertE4M3ToBF16EXT",
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
add("ConvertE5M2ToBF16EXT",
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
add("ConvertFP16ToE4M3EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertFP16ToE5M2EXT",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
add("ConvertBF16ToE4M3EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
add("ConvertBF16ToE5M2EXT",
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});

// SPV_INTEL_fp_conversions
add("ClampConvertFP16ToE2M1INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
internal::OpClampConvertFToFINTEL});
add("ClampConvertBF16ToE2M1INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
internal::OpClampConvertFToFINTEL});
add("ClampConvertFP16ToE4M3INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
internal::OpClampConvertFToFINTEL});
add("ClampConvertBF16ToE4M3INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
internal::OpClampConvertFToFINTEL});
add("ClampConvertFP16ToE5M2INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
internal::OpClampConvertFToFINTEL});
add("ClampConvertBF16ToE5M2INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
internal::OpClampConvertFToFINTEL});
add("ClampConvertFP16ToInt4INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
internal::OpClampConvertFToSINTEL});
add("ClampConvertBF16ToInt4INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
internal::OpClampConvertFToSINTEL});

add("StochasticRoundFP16ToE5M2INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
internal::OpStochasticRoundFToFINTEL});
add("StochasticRoundFP16ToE4M3INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
internal::OpStochasticRoundFToFINTEL});
add("StochasticRoundBF16ToE5M2INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
internal::OpStochasticRoundFToFINTEL});
add("StochasticRoundBF16ToE4M3INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
internal::OpStochasticRoundFToFINTEL});
add("StochasticRoundFP16ToE2M1INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
internal::OpStochasticRoundFToFINTEL});
add("StochasticRoundBF16ToE2M1INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
internal::OpStochasticRoundFToFINTEL});
add("ClampStochasticRoundFP16ToInt4INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
internal::OpClampStochasticRoundFToSINTEL});
add("ClampStochasticRoundBF16ToInt4INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
internal::OpClampStochasticRoundFToSINTEL});

add("ClampStochasticRoundFP16ToE5M2INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
internal::OpClampStochasticRoundFToFINTEL});
add("ClampStochasticRoundFP16ToE4M3INTEL",
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
internal::OpClampStochasticRoundFToFINTEL});
add("ClampStochasticRoundBF16ToE5M2INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
internal::OpClampStochasticRoundFToFINTEL});
add("ClampStochasticRoundBF16ToE4M3INTEL",
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
internal::OpClampStochasticRoundFToFINTEL});
}

// clang-format on
Expand Down
68 changes: 61 additions & 7 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,11 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {

Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
switch (T->getFloatBitWidth()) {
case 4:
// No LLVM IR counter part for FP4 - map it on i4.
return Type::getIntNTy(*Context, 4);
case 8:
// No LLVM IR counter part for FP8 - map it on i8
// No LLVM IR counter part for FP8 - map it on i8.
return Type::getIntNTy(*Context, 8);
case 16:
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
Expand Down Expand Up @@ -1064,11 +1067,12 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
return FPEncodingWrap::IEEE754;
};

auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
auto IsFP4OrFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
return Encoding == FPEncodingWrap::E4M3 ||
Encoding == FPEncodingWrap::E5M2 || Encoding == FPEncodingWrap::E2M1;
};

switch (BC->getOpCode()) {
switch (static_cast<unsigned>(BC->getOpCode())) {
case OpPtrCastToGeneric:
case OpGenericCastToPtr:
case OpPtrCastToCrossWorkgroupINTEL:
Expand All @@ -1089,6 +1093,11 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
case OpUConvert:
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
break;
case internal::OpClampConvertFToFINTEL:
case internal::OpClampConvertFToSINTEL:
case internal::OpStochasticRoundFToFINTEL:
case internal::OpClampStochasticRoundFToFINTEL:
case internal::OpClampStochasticRoundFToSINTEL:
case OpConvertSToF:
case OpConvertFToS:
case OpConvertUToF:
Expand All @@ -1113,7 +1122,7 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,

FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
if (IsFP4OrFP8Encoding(SrcEnc) || IsFP4OrFP8Encoding(DstEnc) ||
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
Expand All @@ -1123,13 +1132,47 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
std::string BuiltinName =
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
BuiltinFuncMangleInfo Info;
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
std::string MangledName;
// Translate additional Ops for stochastic conversions.
if (OC == internal::OpStochasticRoundFToFINTEL ||
OC == internal::OpClampStochasticRoundFToFINTEL ||
OC == internal::OpClampStochasticRoundFToSINTEL) {
// Seed.
Ops.emplace_back(transValue(SPVOps[1], F, BB, true));
OpsTys.emplace_back(Ops[1]->getType());
constexpr unsigned MaxOpsSize = 3;
if (SPVOps.size() == MaxOpsSize) {
// New Seed.
Ops.emplace_back(transValue(SPVOps[2], F, BB, true));

// The following mess is needed to create a function with correct
// mangling.
SPIRVType *PtrTy = SPVOps[2]->getType();
const unsigned AS =
SPIRSPIRVAddrSpaceMap::rmap(PtrTy->getPointerStorageClass());
Type *ElementTy = transType(PtrTy->getPointerElementType());
OpsTys.emplace_back(TypedPointerType::get(ElementTy, AS));
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
// But to create function itself we need untyped pointer type.
OpsTys[2] = opaquifyType(OpsTys[2]);
}
}

if (MangledName.empty())
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);

FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
return CallInst::Create(Func, Ops, "", BB);
}
}
// These conversions can be done without __builtin_spirv prefixed functions
// as their operand and result types have native representation in LLVM IR.
if (OC == internal::OpClampConvertFToFINTEL ||
OC == internal::OpStochasticRoundFToFINTEL ||
OC == internal::OpClampStochasticRoundFToFINTEL)
return mapValue(BV, transSPIRVBuiltinFromInst(
static_cast<SPIRVInstruction *>(BV), BB));

if (OC == OpFConvert) {
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
Expand Down Expand Up @@ -3053,7 +3096,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
OutMatrixElementTy->isTypeFloat(
4, internal::FPEncodingFloat4E2M1INTEL) ||
InMatrixElementTy->isTypeFloat(4,
internal::FPEncodingFloat4E2M1INTEL))
Inst = transConvertInst(BV, F, BB);
else
Inst = transSPIRVBuiltinFromInst(BI, BB);
Expand All @@ -3062,6 +3109,8 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
}
return mapValue(BV, Inst);
}
if (isIntelCvtOpCode(OC))
return mapValue(BV, transConvertInst(BV, F, BB));
return mapValue(
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB));
}
Expand Down Expand Up @@ -3878,6 +3927,11 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
case internal::OpTaskSequenceCreateINTEL:
case internal::OpConvertHandleToImageINTEL:
case internal::OpConvertHandleToSampledImageINTEL:
case internal::OpClampConvertFToFINTEL:
case internal::OpClampConvertFToSINTEL:
case internal::OpStochasticRoundFToFINTEL:
case internal::OpClampStochasticRoundFToFINTEL:
case internal::OpClampStochasticRoundFToSINTEL:
AddRetTypePostfix = true;
break;
default: {
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/SPIRVToOCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ void SPIRVToOCLBase::visitCastInst(CastInst &Cast) {
DstVecTy->getScalarSizeInBits() == 1)
return;

// We don't have OpenCL builtins for 4-bit conversions.
if (DstVecTy->getScalarSizeInBits() == 4 || SrcTy->getScalarSizeInBits() == 4)
return;

// Assemble built-in name -> convert_gentypeN
std::string CastBuiltInName(kOCLBuiltinName::ConvertPrefix);
// Check if this is 'floating point -> unsigned integer' cast
Expand Down
Loading
Loading