diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py index e6af4a5abcf5..f98c80a20a9b 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py @@ -23,14 +23,14 @@ ################################################################################ from rocisa.code import Module -from rocisa.container import DSModifiers, vgpr, sgpr, SDWAModifiers, VOP3PModifiers +from rocisa.container import DSModifiers, vgpr, sgpr, SDWAModifiers, VOP3PModifiers, ContinuousRegister from rocisa.enum import SelectBit from rocisa.instruction import SMovB32, SWaitCnt, VOrB32, VPermB32, VLShiftLeftOrB32, \ - VMovB32, VLShiftRightB32, VCvtPkFP8toF32, VCvtF32toF16, VCvtFP8toF32,VCvtScaleFP8toF16,VCvtScalePkFP8toF16, \ + VMovB32, VMovB64,VLShiftRightB32, VCvtPkFP8toF32, VCvtF32toF16, VCvtFP8toF32,VCvtScaleFP8toF16,VCvtScalePkFP8toF16, \ VCvtPkF32toBF16, VCvtBF16toFP32, PVCvtBF16toFP32, VDot2CF32BF16, SNop, VSubF32, VSwapB32 from ..Component import LocalRead - + from math import ceil class LocalReadVALU(LocalRead): @@ -149,15 +149,15 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): class LocalReadMFMA(LocalRead): kernel = {"EnableMatrixInstruction": True} - - # LDS size is increased on gfx950. const offset is still 16-bit. + + # LDS size is increased on gfx950. const offset is still 16-bit. # this function handles both LDS size < 64K and LDS size >= 64K def cal_offset_srcAddr(self, maxLDSConstOffset, tc, offset): num = offset // maxLDSConstOffset offset_val = offset - num * maxLDSConstOffset srcAddr = vgpr("LocalReadAddr%s+%u" %(tc, num)) return offset_val, srcAddr - + """ Local Read: Do It A/B @@ -245,7 +245,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): blockStride = elementsPerBlockSMFMA * threadGroups blockOffsetSMFMA = blockStride - elementsPerBlockSMFMA - maxLDSConstOffset = writer.states.regCaps["maxLDSConstOffset"] + maxLDSConstOffset = writer.states.regCaps["maxLDSConstOffset"] valufIdx = 0 if enableLDSTr: numberMTilesPerWave = kernel["MIWaveTile"][tile01] @@ -275,59 +275,80 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): localReadCode = imod.add(Module("LocalRead%s Valu%u"%(tc,valuiIdx))) if needPack or numSplitMetadata: packCode = pack.add(Module("packCode")) + + tmpvgprHI = [] for rIdx in range(0, numReadsPerUnroll): valuiIdx = int(valufIdx) baseLRVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx), numVgpr) destVgpr = baseLRVgpr highBitsForHalf = (blockWidth == 0.5) and ((rIdx % 2) == 1) # rIdx = 1 isHigh16Bits = (blockWidth == 0.25) and ( ((rIdx % 4) //2) == 1) # 2,3 - + if needPack or numSplitMetadata: if kernel["UseF32XEmulation"]: if valuiIdx % 4 == 0: - tmpvgpr01 = writer.vgprPool.checkOutAligned(2, 2) - tmpvgpr = writer.vgprPool.checkOut(1) + tmpvgprIDx = (valuiIdx % 8) // 4 + tmpvgprHI.append(writer.vgprPool.checkOutAligned(2, 2)) + numTmpForCVTSubTF32 = 4 + tmpvgpr =[] + for i in range(numTmpForCVTSubTF32): + tmpvgpr.append(writer.vgprPool.checkOut(1)) v0 = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx)) v1 = vgpr("Valu%s_X%u_I%u+%u+1"%(tc, bufferIdx, iui, valuiIdx)) v2 = vgpr("Valu%s_X%u_I%u+%u+2"%(tc, bufferIdx, iui, valuiIdx)) v3 = vgpr("Valu%s_X%u_I%u+%u+3"%(tc, bufferIdx, iui, valuiIdx)) - # For the TF32 emu we need 4x cvt+sub pairs to compute the "LO" parts. Here we use - # dot2 to replace one instance of cvt+sub pair. - packCode.add(VCvtPkF32toBF16(dst=vgpr(tmpvgpr01), src0=v0, src1=v2)) - packCode.add(VDot2CF32BF16(dst=v0, src0=hex(0x8000bf80), src1=vgpr(tmpvgpr01))) # v0_lo - packCode.add(VDot2CF32BF16(dst=v2, src0=hex(0xbf800000), src1=vgpr(tmpvgpr01))) # v2_lo - packCode.add(VCvtPkF32toBF16(dst=vgpr(tmpvgpr01+1), src0=v1, src1=v3)) + packCode.add(VCvtPkF32toBF16(dst=vgpr(tmpvgprHI[tmpvgprIDx]), src0=v0, src1=v1)) - packCode.add(VDot2CF32BF16(dst=v1, src0=hex(0x8000bf80), src1=vgpr(tmpvgpr01+1))) # v1_lo + packCode.add(PVCvtBF16toFP32(dst=vgpr(tmpvgpr[0]), src=vgpr(tmpvgprHI[tmpvgprIDx]))) + packCode.add(VSubF32(dst=v0, src0=v0, src1=vgpr(tmpvgpr[0]))) + packCode.add(VCvtBF16toFP32(dst=vgpr(tmpvgpr[1]), src=vgpr(tmpvgprHI[tmpvgprIDx]), vgprMask=None, vi=1)) + packCode.add(VSubF32(dst=v1, src0=v1, src1=vgpr(tmpvgpr[1]))) - # Note: v_dot2c_f32_bf16 needs 4 wait states for inst dependencies. Using 4 dot2 would require - # 2 extra dummy waits for the last dot2 (we don't have enough extra work to avoid dummy waits), - # so here we don't use dot2 and just use a cvt+sub pair. We use one extra instruction by - # not using dot2, but save on the 2 extra dummy waits - net savings of one instruction. - packCode.add(VCvtBF16toFP32(dst=vgpr(tmpvgpr), src=vgpr(tmpvgpr01+1), vgprMask=None, vi=1)) - packCode.add(VSubF32(dst=v3, src0=v3, src1=vgpr(tmpvgpr))) + packCode.add(VCvtPkF32toBF16(dst=vgpr(tmpvgprHI[tmpvgprIDx]+1), src0=v2, src1=v3)) + packCode.add(PVCvtBF16toFP32(dst=vgpr(tmpvgpr[2]), src=vgpr(tmpvgprHI[tmpvgprIDx]+1))) + packCode.add(VSubF32(dst=v2, src0=v2, src1=vgpr(tmpvgpr[2]))) + packCode.add(VCvtBF16toFP32(dst=vgpr(tmpvgpr[2]), src=vgpr(tmpvgprHI[tmpvgprIDx]+1), vgprMask=None, vi=1)) + packCode.add(VSubF32(dst=v3, src0=v3, src1=vgpr(tmpvgpr[2]))) - packCode.add(VCvtPkF32toBF16(dst=v2, src0=v0, src1=v2)) - packCode.add(VMovB32(dst=v0, src=vgpr(tmpvgpr01))) - packCode.add(VCvtPkF32toBF16(dst=v3, src0=v1, src1=v3)) - packCode.add(VMovB32(dst=v1, src=vgpr(tmpvgpr01+1))) + for i in range(numTmpForCVTSubTF32): + writer.vgprPool.checkIn(tmpvgpr[i]) - # layout: - # Val+0: bf16 high (0,2) - # Val+1: bf16 high (1,3) - # Val+2: bf16 low (0,2) - # Val+3: bf16 low (1,3) + if rIdx == numReadsPerUnroll - 1: # Last iteration + if not (kernel["MatrixInstM"] == 16 and kernel["MatrixInstK"] == 16): + v0 = vgpr("Valu%s_X%u_I%u+%u+0"%(tc, bufferIdx, iui, baseValuiIdx)) + v1 = vgpr("Valu%s_X%u_I%u+%u+1"%(tc, bufferIdx, iui, baseValuiIdx)) + v2 = vgpr("Valu%s_X%u_I%u+%u+2"%(tc, bufferIdx, iui, baseValuiIdx)) + v3 = vgpr("Valu%s_X%u_I%u+%u+3"%(tc, bufferIdx, iui, baseValuiIdx)) + v4 = vgpr("Valu%s_X%u_I%u+%u+4"%(tc, bufferIdx, iui, baseValuiIdx)) + v5 = vgpr("Valu%s_X%u_I%u+%u+5"%(tc, bufferIdx, iui, baseValuiIdx)) + v6 = vgpr("Valu%s_X%u_I%u+%u+6"%(tc, bufferIdx, iui, baseValuiIdx)) + v7 = vgpr("Valu%s_X%u_I%u+%u+7"%(tc, bufferIdx, iui, baseValuiIdx)) + packCode.add(VCvtPkF32toBF16(dst=v7, src0=v6, src1=v7)) + packCode.add(VCvtPkF32toBF16(dst=v6, src0=v4, src1=v5)) + packCode.add(VCvtPkF32toBF16(dst=v5, src0=v2, src1=v3)) + packCode.add(VCvtPkF32toBF16(dst=v4, src0=v0, src1=v1)) + tmpvgprHI064 = vgpr(tmpvgprHI[0], 2) + tmpvgprHI164 = vgpr(tmpvgprHI[1], 2) + valuvgprHI064 = vgpr("Valu%s_X%u_I%u+%u+0"%(tc, bufferIdx, iui, baseValuiIdx), 2) + valuvgprHI164 = vgpr("Valu%s_X%u_I%u+%u+2"%(tc, bufferIdx, iui, baseValuiIdx), 2) + packCode.add(VMovB64(dst=valuvgprHI064, src=tmpvgprHI064)) + packCode.add(VMovB64(dst=valuvgprHI164, src=tmpvgprHI164)) + for i in range(len(tmpvgprHI)): + writer.vgprPool.checkIn(tmpvgprHI[i]) - writer.vgprPool.checkIn(tmpvgpr01) - writer.vgprPool.checkIn(tmpvgpr) + # layout: + # Val+0: bf16 high (0,1) + # Val+1: bf16 high (2,3) + # Val+2: bf16 high (4,5) + # Val+3: bf16 high (6,7) + # Val+4: bf16 low (0,1) + # Val+5: bf16 low (2,3) + # Val+6: bf16 low (4,5) + # Val+7: bf16 low (6,7) - if rIdx == numReadsPerUnroll - 1: - if not (kernel["MatrixInstM"] == 16 and kernel["MatrixInstK"] == 16): - packCode.add(VSwapB32(dst=vgpr("Valu%s_X%u_I%u+%u+2"%(tc, bufferIdx, iui, baseValuiIdx)), src=vgpr("Valu%s_X%u_I%u+%u+4"%(tc, bufferIdx, iui, baseValuiIdx)))) - packCode.add(VSwapB32(dst=vgpr("Valu%s_X%u_I%u+%u+3"%(tc, bufferIdx, iui, baseValuiIdx)), src=vgpr("Valu%s_X%u_I%u+%u+5"%(tc, bufferIdx, iui, baseValuiIdx)))) if kernel["ConvertAfterDS"] and (tP["bpe"] != tP["bpeDS"]): highBitsForHalf = False isHigh16Bits = False @@ -381,7 +402,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=destVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) packCode.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) packCode.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - + if rIdx == numReadsPerUnroll-1: for i in range(0, numVgpr): vgprIdx = (vIdx * numVgpr + i) * tP["bpe"] * kernel["MIInputPerThread%s"%tc] // writer.states.bpr * min(writer.states.bpr // tP["bpe"], vectorWidth) @@ -410,7 +431,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) packCode.add(VCvtF32toF16(dst=cvtDestVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) packCode.add(VCvtF32toF16(dst=cvtDestVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - + if rIdx == numReadsPerUnroll-1: for i in range(0, numVgpr*2): vgprIdx = (2 * vIdx * numVgpr + i) * tP["bpe"] * kernel["MIInputPerThread%s"%tc] // writer.states.bpr * min(writer.states.bpr // tP["bpe"], vectorWidth) @@ -432,7 +453,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): cvtDestVgpr2 = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), 2*vIdx*numVgpr+2), 1) cvtDestVgpr3 = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), 2*vIdx*numVgpr+3), 1) packCode.add(VLShiftRightB32(dst=cvtDestVgpr3, shiftHex=16, src=cvtDestVgpr1, comment="shift 2 element to vgpr+3")) - + packCode.add(VMovB32(dst=cvtDestVgpr2, src=cvtDestVgpr1)) packCode.add(VLShiftRightB32(dst=cvtDestVgpr1, shiftHex=16, src=cvtDestVgpr0, comment="shift 2 element to vgpr+1")) if writer.states.asmCaps["Hascvtf16_fp8"]: @@ -453,7 +474,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr3, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) packCode.add(VCvtF32toF16(dst=cvtDestVgpr3, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) packCode.add(VCvtF32toF16(dst=cvtDestVgpr3, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - + if rIdx == numReadsPerUnroll-1: for i in range(0, numVgpr*2): vgprIdx = (2 * vIdx * numVgpr + i) * tP["bpe"] * kernel["MIInputPerThread%s"%tc] // writer.states.bpr * min(writer.states.bpr // tP["bpe"], vectorWidth) @@ -517,7 +538,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): destVgpr_ = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), vIdx*numVgpr + i)) bitShift = 0 for elementIdx in range(0, numSplitMetadata+1): - # go to next vgpr + # go to next vgpr if elementIdx >= writer.states.bpr: break comment_ = "another VGPR storing lshr %d-bit value %d %d" %(bitShift, vgprIdx, elementIdx) if bitShift != 0 else "" @@ -553,7 +574,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): comment="select K=%u%u for vector=%u"%(elementIdx*4+2, elementIdx*4+3, vectorIdx))) packCode.add(VLShiftLeftOrB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx + vgprOffset)), src0=vgpr("PackTemp"), shiftHex=16, src1=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx + vgprOffset)), comment="pack two half Vgpr to one Vgpr")) vgprOffset += 1 - + else: isHigh8Bits = (blockWidth == 0.25) and ( ((rIdx % 4) % 2) == 1) # 1,3 # pack for blockWidth 0.5 type @@ -603,18 +624,18 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): destVgpr = highVgpr if isHigh8Bits and isHigh16Bits: packCode.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=highVgpr, shiftHex=hex(0x8), src1=baseLRVgpr, comment="pack two int8x2 Vgpr to one Vgpr")) - + if kernel["ConvertAfterDS"] and kernel["UnrollMajorLDS%s"%tc]: valufIdx += blockWidth * (tP["bpe"] // tP["bpeDS"]) if (not tP["isM"]) else 1 else: valufIdx += blockWidth if (not tP["isM"]) else 1 - + # load read instrution paramList = [] - + for oIdx in range(0, numOffsets): offset_val = (eIdx + (vIdx * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride - + if kernel["ProblemType"]["Sparse"] != 0: if blocksPerTGroupSMFMA > 1: blockId = (rIdx * numElementPerRead) // elementsPerBlockSMFMA #block 0 or block 1 @@ -662,7 +683,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): offset_val = (incOffset + offset_val + tP["localReadOffset"]) * tP["bpeDS"] else: offset_val = (rIdx * numElementPerRead * UnrollStride + offset_val + tP["localReadOffset"]) * tP["bpeDS"] - + if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0): offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"] offset_val = offset_val + tP["localReadSwapByteOffset"] @@ -670,15 +691,15 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4): # another address conversion for DirectToLds + NumLoadsCoalesced > 1 dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val) - + paramList.append(int(offset_val)) - + comment = "L -> Reg lro=%d swapByteOffset=%u ti=%u vIdx=%u eIdx=%u rIdx=%u oIdx=%u buffer=%u iui=%u" \ % (tP["localReadOffset"], tP["localReadSwapByteOffset"], MIWaveGroupShape[tile01], vIdx, eIdx, rIdx, oIdx, bufferIdx, iui) - + highBits = 0 if writer.states.archCaps["DSLow16NotPreserve"] else highBitsForHalf or isHigh16Bits - - + + if(paramList[0] >=131072): paramList[0] = paramList[0] -131072 srcAddr=vgpr("LocalReadAddr%s+2"%tc) @@ -687,7 +708,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): srcAddr=vgpr("LocalReadAddr%s+1"%tc) else: srcAddr=vgpr("LocalReadAddr%s"%tc) - + if numOffsets == 1: ds = DSModifiers(na=1, offset=paramList[0]) else: @@ -698,7 +719,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): with writer.allocTmpSgpr(1) as tmpSgprInfo: tmpSgpr = tmpSgprInfo.idx if writer.db["CheckValue1%s"%tc] and not writer.inTailLoop: - + dbgVgpr = destVgpr dbgVgprList = destVgpr.split("v[") if len(dbgVgprList) == 1: # vIdx, no [] @@ -708,39 +729,37 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): # TODO: Handle vector, but need to take care the last one dbgVgprList = (dbgVgprList[1].split("]")[0]).split(':') dbgVgpr = "v[%s]"%dbgVgprList[0] - localReadCode.add(SWaitCnt(dscnt=0, vscnt=0, comment="CheckValue1 wait for LDS read")) - + if kernel["ProblemType"]["DataType"].isHalf(): hexValue = hex(0x3c003c00) # packed 1s if needPack: hexValue = hex(0x3c000000) if highBitsForHalf else hex(0x00003c00) localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: FP16")) localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) - + elif kernel["ProblemType"]["DataType"].isBFloat16(): hexValue = hex(0x3f803f80) # packed 1s if needPack: hexValue = hex(0x3f800000) if highBitsForHalf else hex(0x00003f80) localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: BF16")) localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) - + if kernel["ProblemType"]["DataType"].isInt8(): if needPack: hexValue = hex(0x00010000) if isHigh16Bits else hex(0x00000001) localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: INT8")) localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) - + # TODO - Check if this works. But need this? MFMA would use INT8 elif kernel["ProblemType"]["DataType"].isInt8x4(): localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(0x01010101), comment="CheckValue1: INT8x4")) localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) - + elif kernel["ProblemType"]["DataType"].isSingle(): localReadCode.add(writer.assert_eq( dbgVgpr, 1.0) ) - # DTV case, do not return local read code. Return pack code only. if (tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]: imod = Module("LocalReadDo%s_I%s (Empty)" % (tP["tensorChar"],iui)) - + return imod, pack diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 30e9399a2049..61ae7e4d6022 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -948,9 +948,9 @@ def _makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, point scheduleTF32Emu = kernel["UseF32XEmulation"] if scheduleTF32Emu: - # 24 is the instruction count for the TF32 emulation sequence in LocalRead.py - instPerPackA = 24#len(packAItems) - instPerPackB = 24#len(packBItems) + # 26 is the instruction count for the TF32 emulation sequence in LocalRead.py + instPerPackA = 26#len(packAItems) + instPerPackB = 26#len(packBItems) while packAItems or packBItems: for n in range(instPerPackA): if packAItems: