Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions tensilelite/Tensile/Components/GlobalWriteBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ class GlobalWriteBatchComponent(GlobalWriteComponents):
def __call__(self, kernel: Solution, tPA, tPB, activation: ActivationModule, ss: StoreState, \
batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleDVec, addrScaleAlphaVec, biasLocalBarrierInit: bool, \
tmpVgpr, bf16CVTVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
packdata, parentWriter) -> Module:
return GlobalWriteBatchWriter(kernel, tPA, tPB, activation, ss, batchIdx, applyAlpha, \
beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleDVec, addrScaleAlphaVec, biasLocalBarrierInit, \
tmpVgpr, bf16CVTVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
packdata, parentWriter).emit()

class GlobalWriteBatchWriter:
def __init__(self, kernel: Solution, tPA, tPB, activation: ActivationModule, ss: StoreState, \
batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleDVec, addrScaleAlphaVec, biasLocalBarrierInit: bool, \
tmpVgpr, bf16CVTVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, \
packdata, parentWriter):
self.kernel = kernel
self.tPA = tPA
Expand All @@ -77,7 +77,7 @@ def __init__(self, kernel: Solution, tPA, tPB, activation: ActivationModule, ss:
self.activationSetPCStruct = activationSetPCStruct
self.activationTypeStr = activationTypeStr
self.tmpVgpr = tmpVgpr
self.bf16CVTVgprStruct = bf16CVTVgprStruct
self.cvtVgprStruct = cvtVgprStruct
self.batchElementSgprs = batchElementSgprs
self.tmpSgpr = tmpSgpr
self.codeAccVgprRead = codeAccVgprRead
Expand Down Expand Up @@ -582,9 +582,13 @@ def _emitNonatomicAdd(self, module: Module):
activationCDataType = self.kernel["ProblemType"]["ActivationComputeDataType"]

if self.kernel["ProblemType"]["DestDataType"].isBFloat16() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.bf16CVTVgprStruct.vgprBf16Mask), "0xffff0000", "mask for pack two bfloat16 element to 32bit" ))
module.add(VMovB32(vgpr(self.bf16CVTVgprStruct.vgprFp32Nan), "0x7fff0000", "fp32 Nan" ))
module.add(VMovB32(vgpr(self.bf16CVTVgprStruct.vgprBf16Inc), "0x7fff", "rounding bias for bfloat16" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBf16Mask), "0xffff0000", "mask for pack two bfloat16 element to 32bit" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp32Nan), "0x7fff0000", "fp32 Nan" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBf16Inc), "0x7fff", "rounding bias for bfloat16" ))
elif self.kernel["ProblemType"]["DestDataType"].isFloat8() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8NanInf), "0x207", "Nan and +/- inf" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Max), "0x43700000", "Max 240" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Min), "0xc3700000", "Min -240" ))

storeCode = Module("GroupLoadStore")
waitCnter = [self.loadsBetaIssued + self.loadsEIssued + self.loadsScaleDVecIssued + self.loadsScaleAlphaVecIssued, self.localLoadsBiasIssued]
Expand Down Expand Up @@ -749,7 +753,7 @@ def _emitNonatomicAdd(self, module: Module):
module.add(scaleAlphaVecModule)

if self.beta:
module.add(self._addSumAlphaWithCBeta(self.kernel, self.ss, self.gwvw, elementIdx, vc0, self.tmpVgpr, self.bf16CVTVgprStruct))
module.add(self._addSumAlphaWithCBeta(self.kernel, self.ss, self.gwvw, elementIdx, vc0, self.tmpVgpr, self.cvtVgprStruct))
elif ((self.parentWriter.states.useBias == DataDirection.READ) or self.kernel["ActivationFuncCall"]) and not self.applyAlpha: # case of alpha=1 and beta=0
if (self.kernel["ProblemType"]["DestDataType"].isInt8() or self.kernel["ProblemType"]["DestDataType"].isInt32() or (self.kernel["ProblemType"]["DataType"].isInt8() and self.kernel["ProblemType"]["DestDataType"].isHalf())) and self.kernel["ProblemType"]["ComputeDataType"].isSingle():
module.add(convertData(self.gwvw, self.ss.elementSumIdx[elementIdx], cvtType=CvtType.CVT_I32_to_F32, \
Expand Down Expand Up @@ -865,7 +869,10 @@ def _emitNonatomicAdd(self, module: Module):
if self.kernel["ProblemType"]["DestDataType"].isHalf():
packModule = self.packdata(self.gwvw, destIdx, self.ss.elementSumIdx[elementIdx], inputPrefix="ValuC+", prefixOffset=self.parentWriter.states.c.startVgprValu)
elif self.kernel["ProblemType"]["DestDataType"].isBFloat16():
packModule = self.packdata(self.gwvw, destIdx, self.ss.elementSumIdx[elementIdx], bf16CVTVgprStruct=self.bf16CVTVgprStruct,
packModule = self.packdata(self.gwvw, destIdx, self.ss.elementSumIdx[elementIdx], bf16CVTVgprStruct=self.cvtVgprStruct,
tmpS01=self.tmpS01, laneSGPRC=self.laneSGPRC, inputPrefix="ValuC+", prefixOffset=self.parentWriter.states.c.startVgprValu)
elif self.kernel["ProblemType"]["DestDataType"].isFloat8():
packModule = self.packdata(self.gwvw, destIdx, self.ss.elementSumIdx[elementIdx], fp8CVTVgprStruct=self.cvtVgprStruct, \
tmpS01=self.tmpS01, laneSGPRC=self.laneSGPRC, inputPrefix="ValuC+", prefixOffset=self.parentWriter.states.c.startVgprValu)
elif self.kernel["ProblemType"]["DestDataType"].isInt32():
if self.kernel["ProblemType"]["ComputeDataType"].isSingle() and ((self.parentWriter.states.useBias == DataDirection.READ) or self.kernel["ActivationFuncCall"] or self.applyAlpha or self.beta):
Expand Down Expand Up @@ -1398,7 +1405,7 @@ def _applyAlpha(self, kernel, gwvw, elementSumIdx, elementIdx, tmpS01):
self.parentWriter.vgprPool.checkIn(vtmp2)
return module

def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, bf16CVTVgprStruct):
def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtVgprStruct):
module = Module("addSumAlphaWithCBeta #elementIdx%u, vc0 %u"%(elementIdx, vc0))
for vi in range(0, gwvw):
dataV = ss.elementData[elementIdx] + int(vi*ss.cfg.numVgprsPerDataPerVI)
Expand Down Expand Up @@ -1447,7 +1454,7 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, bf16
# src2 = sumIdxV = f32 = opsel 00
dataCExternal = ss.elementData[elementIdx] + vi//2
if (vi%2) == 1:
module.add(VAndB32(dst=vgpr(tmpVgpr), src0=vgpr(dataCExternal), src1=vgpr(bf16CVTVgprStruct.vgprBf16Mask), comment="convert bf16 to fp32"))
module.add(VAndB32(dst=vgpr(tmpVgpr), src0=vgpr(dataCExternal), src1=vgpr(cvtVgprStruct.vgprBf16Mask), comment="convert bf16 to fp32"))
else:
module.add(VLShiftLeftB32(dst=vgpr(tmpVgpr), shiftHex=16, src=vgpr(dataCExternal), comment="convert bf16 to fp32" ))
newSumIdxV = sumIdxV - self.parentWriter.states.c.startVgprValu
Expand Down Expand Up @@ -1509,6 +1516,29 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, bf16
module.add(VFmaF64(dst=vgpr("ValuC+%u"%(newSumIdxV+0),2), src0=vgpr(dataV+2,2), src1=sgpr("Beta+2",2), src2=vgpr("ValuC+%u"%(newSumIdxV+0),2), comment="c.real -= a.imag * b.imag"))
module.add(VFmaF64(dst=vgpr("ValuC+%u"%(newSumIdxV+2),2), src0=vgpr(dataV+0,2), src1=sgpr("Beta+2",2), src2=vgpr("ValuC+%u"%(newSumIdxV+2),2), comment="c.imag += a.real * b.imag"))
module.add(VFmaF64(dst=vgpr("ValuC+%u"%(newSumIdxV+2),2), src0=vgpr(dataV+2,2), src1=sgpr("Beta+0",2), src2=vgpr("ValuC+%u"%(newSumIdxV+2),2), comment="c.imag += a.imag * b.real"))

# float8 precision
elif kernel["ProblemType"]["DestDataType"].isFloat8():
if kernel["ProblemType"]["HighPrecisionAccumulate"]:
newSumIdxV = sumIdxV - self.parentWriter.states.c.startVgprValu
# Generate single f32 code if edge is detected.
isPK = False
if ((vi + 1) == self.gwvw) and ((self.gwvw % 2) == 1):
sb = SelectBit.BYTE_0 if self.gwvw == 1 else SelectBit.BYTE_2
module.add(VCvtFP8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb)))
# Original packed route
elif vi%2 == 1:
continue
else:
isPK = True
sb = SelectBit.WORD_0 if vi == 0 else SelectBit.WORD_1
module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb)))
module.add(SNop(waitState=0))
if kernel["ProblemType"]["ComputeDataType"].isSingle():
module.add(VMacF32(dst=vgpr("ValuC+%u"%newSumIdxV), src0=vgpr(tmpVgpr), src1=sgpr("Beta"), comment="finalSum = sum*alpha + C*beta"))
if isPK:
module.add(VMacF32(dst=vgpr("ValuC+%u"%(newSumIdxV+1)), src0=vgpr(tmpVgpr+1), src1=sgpr("Beta"), comment="finalSum = sum*alpha + C*beta (PK)"))

return module

def copyData(computeDataType, elementSumIdx, gwvw, vgprStart, direction=0):
Expand Down
39 changes: 34 additions & 5 deletions tensilelite/Tensile/Components/PackData.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@

from ..TensileInstructions import Module, SDWAModifiers, SelectBit, UnusedBit, \
SaturateCastType, VSaturateCastInt, \
VAdd3U32, VCvtF32toF16, VCvtF32toI32, VLShiftRightB32, \
VCmpUF32, VCndMaskB32, \
VOrB32, VPackF16toB32, \
VAndOrB32, VBfeU32, VLShiftLeftB16, \
VRndneF32, SNop, SMovkI32, VMovB32, VMed3I32, \
VAdd3U32, VCvtF32toF16, VLShiftRightB32, \
VCmpUF32, VCndMaskB32, VCvtPkF32toFP8, VOP3PModifiers, \
VCmpClassF32, VOrB32, VPackF16toB32, \
VAndOrB32, VBfeU32, VLShiftLeftB16, SNop, VMed3F32, \
vgpr, sgpr, DataType, TensileInstructions
from ..Component import PackData
from ..Common import globalParameters
Expand Down Expand Up @@ -90,6 +89,36 @@ def __call__(self, gwvw, destIdx, elementSumIdx, bf16CVTVgprStruct, tmpS01, lane
module.add(VAndOrB32(dst=vgpr(d), src0=vgpr(formatVgpr), src1=vgpr(vgprBf16Mask), src2=vgpr(formatting(sumIdxV-1, inputPrefix, prefixOffset)), comment="pack two bf16 to dword"))
return module

class PackData_FLOAT8(PackData):
kernel = {"ProblemType": {"ComputeDataType": DataType(DataType.single), "DestDataType": DataType(DataType.float8)}}
def __call__(self, gwvw, destIdx, elementSumIdx, fp8CVTVgprStruct, tmpS01, laneSGPRC, inputPrefix="", prefixOffset=0):
vgprFp8NanInf = fp8CVTVgprStruct.vgprFp8NanInf
vgprFp8Temp = fp8CVTVgprStruct.vgprFp8Temp
vgprFp8Min = fp8CVTVgprStruct.vgprFp8Min
vgprFp8Max = fp8CVTVgprStruct.vgprFp8Max

module = Module("PackData float8")
pos = 0
for vi in range(0, gwvw):
sumIdxV = elementSumIdx + vi
formatVgpr = formatting(sumIdxV, inputPrefix, prefixOffset)
d = destIdx + vi//4
if (vi + 1 >= gwvw) and (gwvw % 2 == 1):
module.add(VCmpClassF32(dst=sgpr(tmpS01,laneSGPRC), src0=vgpr(formatVgpr), src1=vgpr(vgprFp8NanInf), comment="Nan and +/- inf"))
module.add(VMed3F32(dst=vgpr(vgprFp8Temp), src0=vgpr(formatVgpr), src1= vgpr(vgprFp8Min),src2=vgpr(vgprFp8Max)))
module.add(VCndMaskB32(dst=vgpr(formatVgpr), src0=vgpr(vgprFp8Temp), src1=vgpr(formatVgpr), src2=sgpr(tmpS01,laneSGPRC)))
module.add(VCvtPkF32toFP8(dst=vgpr(d), src0=vgpr(formatVgpr), src1=vgpr(formatVgpr), vop3=VOP3PModifiers(op_sel=[0,0,0])))
if vi%2 == 1:
module.add(VCmpClassF32(dst=sgpr(tmpS01,laneSGPRC), src0=vgpr(formatting(sumIdxV-1, inputPrefix, prefixOffset)), src1=vgpr(vgprFp8NanInf), comment="Nan and +/- inf"))
module.add(VMed3F32(dst=vgpr(vgprFp8Temp), src0=vgpr(formatting(sumIdxV-1, inputPrefix, prefixOffset)), src1= vgpr(vgprFp8Min),src2=vgpr(vgprFp8Max)))
module.add(VCndMaskB32(dst=vgpr(formatting(sumIdxV-1, inputPrefix, prefixOffset)), src0=vgpr(vgprFp8Temp), src1=vgpr(formatting(sumIdxV-1, inputPrefix, prefixOffset)), src2=sgpr(tmpS01,laneSGPRC)))
module.add(VCmpClassF32(dst=sgpr(tmpS01,laneSGPRC), src0=vgpr(formatVgpr), src1=vgpr(vgprFp8NanInf), comment="Nan and +/- inf"))
module.add(VMed3F32(dst=vgpr(vgprFp8Temp), src0=vgpr(formatVgpr), src1= vgpr(vgprFp8Min),src2=vgpr(vgprFp8Max)))
module.add(VCndMaskB32(dst=vgpr(formatVgpr), src0=vgpr(vgprFp8Temp), src1=vgpr(formatVgpr), src2=sgpr(tmpS01,laneSGPRC)))
module.add(VCvtPkF32toFP8(dst=vgpr(d), src0=vgpr(formatting(sumIdxV-1, inputPrefix, prefixOffset)), src1=vgpr(formatVgpr), vop3=VOP3PModifiers(op_sel=[0,0,pos])))
pos = ~pos
return module

class PackData_INT8(PackData):
kernel = {"ProblemType": {"ComputeDataType": DataType(DataType.int32), "DestDataType": DataType(DataType.int8)}}
def __call__(self, gwvw, destIdx, elementSumIdx, tmpVgpr, tmpS01, SaturateTypeInt8 = SaturateCastType.NORMAL, inputPrefix="", prefixOffset=0):
Expand Down
33 changes: 21 additions & 12 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -7792,6 +7792,12 @@ class BF16CVTVgprStruct(NamedTuple): # class for bf16 vgprs
vgprFp32Nan: int = -1
vgprBf16Inc: int = -1

class FP8CVTVgprStruct(NamedTuple):
vgprFp8NanInf: int = -1
vgprFp8Temp: int = -1
vgprFp8Min: int = -1
vgprFp8Max: int = -1

class ActivationSetPCStruct(NamedTuple):
sgprOffsetActivation: int = -1
sgprOffsetBack: int = -1
Expand Down Expand Up @@ -8061,14 +8067,17 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths, elements,
numTmpVgpr = max(numTmpVgpr, actPCMaxTempVgpr + actPCGwvwVgpr)
tmpVgpr = self.vgprPool.checkOutAligned(numTmpVgpr, maxAlign, "store tmps")

bf16CVTVgprStruct = None
bf16CVTVgpr = None
cvtVgprStruct = None
cvtVgpr = None
if kernel["ProblemType"]["DestDataType"].isBFloat16() and kernel["ProblemType"]["HighPrecisionAccumulate"]:
bf16CVTVgpr = self.vgprPool.checkOut(4)
bf16CVTVgprStruct = self.BF16CVTVgprStruct(vgprBf16Temp=bf16CVTVgpr, vgprBf16Mask=(bf16CVTVgpr+1), \
vgprFp32Nan=(bf16CVTVgpr+2), vgprBf16Inc=(bf16CVTVgpr+3))

if kernel["ProblemType"]["DestDataType"].isFloat8() or kernel["ProblemType"]["DestDataType"].isBFloat8():
cvtVgpr = self.vgprPool.checkOut(4)
cvtVgprStruct = self.BF16CVTVgprStruct(vgprBf16Temp=cvtVgpr, vgprBf16Mask=(cvtVgpr+1), \
vgprFp32Nan=(cvtVgpr+2), vgprBf16Inc=(cvtVgpr+3))
elif kernel["ProblemType"]["DestDataType"].isFloat8() and kernel["ProblemType"]["HighPrecisionAccumulate"]:
cvtVgpr = self.vgprPool.checkOut(4)
cvtVgprStruct = self.FP8CVTVgprStruct(vgprFp8Temp=cvtVgpr, vgprFp8NanInf=(cvtVgpr+1), \
vgprFp8Min=(cvtVgpr+2), vgprFp8Max=(cvtVgpr+3))
elif kernel["ProblemType"]["DestDataType"].isBFloat8():
assert(0) #TODO

activationSetPCStruct = None
Expand Down Expand Up @@ -8335,7 +8344,7 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths, elements,
actLoopModule.add(self.globalWriteBatch(kernel, tPA, tPB, activation, ss, batchIdx, \
applyAlpha, beta, edge, atomic, gwvw, atomicW, \
elementsThisBatch, self.vgprs.addrE, self.vgprs.addrD, self.vgprs.addrC, self.vgprs.addrBias, self.vgprs.addrScaleDVec, self.vgprs.addrScaleAlphaVec, \
biasLocalBarrierInit, tmpVgpr, bf16CVTVgprStruct, activationSetPCStruct, \
biasLocalBarrierInit, tmpVgpr, cvtVgprStruct, activationSetPCStruct, \
activationTypeStr, elementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha))
biasLocalBarrierInit = True

Expand Down Expand Up @@ -8441,8 +8450,8 @@ def findInstCount(module, targetItem, count):
# End label
module.add(endLabel)
self.vgprPool.checkIn(tmpVgpr)
if bf16CVTVgpr is not None:
self.vgprPool.checkIn(bf16CVTVgpr)
if cvtVgpr is not None:
self.vgprPool.checkIn(cvtVgpr)
return module

##############################################################################
Expand Down Expand Up @@ -8890,14 +8899,14 @@ def readInput(self, kernel, ss, tc: str, dataType, addrCalc, vc0, data, gwvw, ad
def globalWriteBatch(self, kernel, tPA, tPB, activation, ss: StoreState, batchIdx, \
applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleDVec, addrScaleAlphaVec, biasLocalBarrierInit: bool, \
tmpVgpr, bf16CVTVgprStruct, activationSetPCStruct, activationTypeStr, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, \
batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha) -> Module:
packdata = Component.PackData.find(self)
gwriter = Component.GlobalWriteComponents.find(self)
return gwriter(kernel, tPA, tPB, activation, ss, \
batchIdx, applyAlpha, beta, edge, atomic, gwvw, atomicW, \
batchElements, addrE, addrD, addrC, addrBias, addrScaleDVec, addrScaleAlphaVec, biasLocalBarrierInit, \
tmpVgpr, bf16CVTVgprStruct, activationSetPCStruct, activationTypeStr, \
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationTypeStr, \
batchElementSgprs, tmpSgpr, codeAccVgprRead, codeMulAlpha, packdata, self)

##############################################################################
Expand Down
3 changes: 3 additions & 0 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,9 @@ def assignDerivedParameters(state):

# GlobalSplitU doesn't work with some other things:
if state["GlobalSplitU"] > 1:
if state["ProblemType"]["DestDataType"].isFloat8():
reject(state, "GlobalSplitU currently does not support GSU > 1.")
return
# added GSU support for DGEMM
supported = \
(state["ProblemType"]["DataType"].isSingle()) or \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1891,6 +1891,13 @@ namespace Tensile
return std::isinf(static_cast<float>(value));
}

template <>
inline bool DataInitialization::isBadOutput<Float8>(Float8 value)
{
return std::isinf(static_cast<float>(value));
}


template <>
inline bool DataInitialization::isBadOutput<BFloat16>(BFloat16 value)
{
Expand Down
21 changes: 21 additions & 0 deletions tensilelite/Tensile/Source/client/include/Reference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,27 @@ namespace Tensile
return absDiff / (absA + absB + 1) < 0.01;
}

template <>
inline bool AlmostEqual(Float8 a, Float8 b)
{
Float8 absA = (a > 0) ? a : (a * -1.f);
Float8 absB = (b > 0) ? b : (b * -1.f);
// this avoids NaN when inf is compared against inf in the alternative code
// path
if(static_cast<float>(absA) == std::numeric_limits<float>::infinity()
|| // numeric_limits is yet to
// support _Float16 type
// properly;
static_cast<float>(absB)
== std::numeric_limits<float>::infinity()) // however promoting it to
// float works just as fine
{
return a == b;
}
Float8 absDiff = (a - b > 0) ? a - b : b - a;
return absDiff / (absA + absB + 1.f) < 0.01f;
}

template <>
inline bool AlmostEqual(BFloat16 a, BFloat16 b)
{
Expand Down
Loading