diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py b/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py index 11ce15e4ff1f..6f89b38658d4 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py @@ -369,3 +369,28 @@ def ceilDivide(numerator, denominator): def roundUpToNearestMultiple(numerator, denominator): return ceilDivide(numerator,denominator)*int(denominator) + + +# Given a divisor, this routine computes the corresponding multiplicative constant +# and required post shifts. +# +# Algorithm based on: https://dl.acm.org/doi/pdf/10.1145/178243.178249 +# +# Inputs: +# d: divisor +# N: Number of bits integers are represented in +# p: precision in bits (usually N = P) +# +# Output: +# mhigh: multiplicative constant +# shPost: amount to right shift after multiplication +def choose_multiplier(d, N, p): + l = int(math.ceil(math.log(d, 2))) + shPost = l + mlow = 2**(N+l) // d + mhigh = (2**(N+l) + 2 ** (N + l - p )) // d + while ((mlow // 2) < (mhigh // 2)) and shPost > 0: + mlow //= 2 + mhigh //= 2 + shPost -=1 + return mhigh, shPost, l diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py b/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py index 65eb09acf855..25803e45345a 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py @@ -158,7 +158,7 @@ def addToStream(key, indexList, InstructionList): InstStreams = convOptToStream(opt1) macro = Macro("MAINLOOP", ["ID", "useGR=1", "usePLR=1", "useGRInc=1", "useLoop=1"]) - #module.add(SBarrier(comment="debug")) + #macro.add(SBarrier(comment="debug")) lastIter = numLoopIter - 1 @@ -472,43 +472,43 @@ def hasCustomSchedule(kernel): optSchedule = dict() syncCode = [] if isNN and useLDSTr and TLDS==1: + # TODO: This schedule can be improved when BC are resolved for MT192 # Note: A/B 
Global read orders are swapped # i.e. GRA contains GR for B + kernel["SwapGlobalReadOrder"] = True optSchedule = { - 'SYNC' : [[20,21,23,25,27,29,31,33,46,57,58,94], - [20,21,24,26,28,30,32,34,47,58,58,94]], - 'GRIncA' : [[0,1,2,3,4,5,6,7,8]], - 'GRIncB' : [[9,10,11,12,13,14,15,16,17]], - 'LRB0' : [[0,0,1,1,2,2,6,8], - [3,3,4,4,5,5,7,9]], - 'LRA0' : [[10,12,14,16,18,23,35,37,39,41,43,45], - [11,13,15,17,19,22,36,38,40,42,44,46]], - 'LWA' : [[23,25,27,29,31,33], - [24,26,28,30,32,34]], - 'GRA' : [[22,22,24,24,26,26,28,28,30,30, 42,42,43,43,45,45], - [23,23,25,25,27,27,29,29,31,31, 43,43,44,44,46,46]], - 'GRB' : [[54, 56, 58, 60, 62, 64], - [55, 57, 59, 61, 63, 65]], - 'LRSA' : [[47]], - 'LRSB' : [[37]], - 'LWSB' : [[47]], # For B - 'LWSA' : [[52]], # For A - 'LRB1' : [[59,59,61,61,63,63,65,67], - [60,60,62,62,64,64,66,68]], - 'LRA1' : [[69,71,73,75,77,79,81,83,85,85,87,87], - [70,72,74,76,78,80,82,84,86,86,88,88]], + 'SYNC' : [[12,13, 47,48,49,50,51, 52,53, 56,56, 94]], + 'GRIncB' : [[0,1,2,3,4,5,6,7,8]], + 'GRIncA' : [[9,10,11,12,13,14,15,16,17]], + 'LRB0' : [[0,0,1,1,2,2,6,8], + [3,3,4,4,5,5,7,9]], + # These local reads have BC + 'LRA0' : [[10, 15,17,19,21,23, 25,27,29,33,37,39], + [11, 14,16,18,20,22, 24,26,28,32,36,38]], + 'GRA' : [[14,14, 16,16, 18,18, 20,20, 22,22, 34,34,36,36,38,38], + [15,15, 17,17, 19,19, 21,21, 23,23, 35,35,37,37,39,39]], + 'GRB' : [[54,54, 56,56, 58,58, 60,60, 62,62, 64,64], + [55,55, 57,57, 59,59, 61,61, 63,63, 65,65]], + 'LRSA' : [[40]], + 'LRSB' : [[40]], + 'LWSB' : [[41]], # For B + 'LWSA' : [[66]], # For A + 'LRB1' : [[57,57,59,59,61,61,63,65], + [58,58,60,60,62,62,64,64]], + 'LRA1' : [[67,71,73,75,77,79,81,85,87,89,91,93], + [68,72,74,76,78,80,82,86,88,90,92,94]], 'LCC' : [[95, 95]], } - syncCode = [SWaitCnt(dscnt=5, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"), + syncCode = [SWaitCnt(dscnt=1, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"), SBarrier(comment=""), - SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, 
comment="Wait for LRB0 to complete"), - SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"), - SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"), - SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"), - SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"), - SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"), - SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"), - SWaitCnt(dscnt=-1, vlcnt=10, vscnt=-1, comment="Wait for LRB0 to complete"), + SWaitCnt(dscnt=10, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"), + SWaitCnt(dscnt=8, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"), + SWaitCnt(dscnt=6, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"), + SWaitCnt(dscnt=4, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"), + SWaitCnt(dscnt=2, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"), + SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"), + SBarrier(comment=""), + SWaitCnt(dscnt=-1, vlcnt=9, vscnt=-1, comment="Wait for LRB0 to complete"), SBarrier(comment=""), SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"),] diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py index dfa9102498ec..1dd82301c52b 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py @@ -259,6 +259,9 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): numElementPerRead = 1 if kernel["ConvertAfterDS"] and not kernel["UseF32XEmulation"] else (int(blockWidth * bpr) // tP['bpe'] // lrvwTile) inputPerThread = kernel["LocalReadVectorWidth"] if not writer.states.inTailLoop else kernel["MIInputPerThread%s"%tc] + abmatrixinfo = writer.states.a if tc == 'A' else writer.states.b + perpStride = 
abmatrixinfo.gNLCPerpStride + # pack register if writer.states.archCaps["HasEccHalf"] or not writer.states.asmCaps["HasWMMA_V1"]: needPack = tP["bpeDS"] < 4 and not kernel["UnrollMajorLDS%s"%tc] and not tP["isM"] @@ -305,16 +308,27 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): LocalReadX = instruction.getInst(highBits) offset_val = (tP["localReadOffset"]+MIWaveGroupShape[tile01]*tIdx) * tP["bpeDS"] + tP["localReadSwapByteOffset"] - if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0): - offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"] + + def applyPad(offset_val): + if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0): + offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"] + return offset_val for oIdx in range(0,numOffsetsPerLoad): + offset, srcAddr = self.cal_offset_srcAddr(maxLDSConstOffset, tc, offset_val) + offset = applyPad(offset) ds = DSModifiers(na=1, offset=offset) destVgpr = vgpr("Valu%s_X%u_I%u+%u+%u"%(tc,bufferIdx,iui, 4*tIdx, oIdx * 2), 2) localReadCode = Module("LocalRead%s Valu%u"%(tc,valuiIdx)) localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment)) - offset_val += UnrollStride*inputPerThread + if perpStride == 1: + offset_val += UnrollStride*inputPerThread + else: + permBlock = kernel["MatrixInstK"] + perpStrideInv = permBlock // perpStride + inv4K = perpStrideInv * (4 % perpStride) + 4 // perpStride + offset_val += inv4K * kernel["MacroTile%s"%tc] * tP["bpeDS"] if ((subTileIdx == 0 and subIterLoadCount < totalLoads // numSubTiles) \ or (subTileIdx == 1 and subIterLoadCount >= totalLoads // numSubTiles) \ or numSubTiles == 1) or writer.states.inTailLoop: @@ -762,7 +776,12 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): paramList = [] for oIdx in range(0, numOffsets): - offset_val = (eIdx + (vIdx * numOffsets+oIdx) * 
MIWaveGroupShape[tile01]) * tileStride + if perpStride > 1 and kernel["ProblemType"]["TLU%s"%tc] == 0: + permBlock = kernel["MatrixInstK"] if kernel["ProblemType"]["TLU%s"%tc] == 1 else kernel["VectorWidth%s"%tc] * kernel["MatrixInstM"] + perpStrideInv = permBlock // perpStride + offset_val = (eIdx * (perpStrideInv) + ((vIdx) * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride + else: + offset_val = (eIdx + (vIdx * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride if kernel["ProblemType"]["Sparse"] != 0: if blocksPerTGroupSMFMA > 1: @@ -815,8 +834,9 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0): offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"] offset_val = offset_val + tP["localReadSwapByteOffset"] + # TODO: Add NLC>1 offset calcs here? if (kernel["DirectToLds%s" % tc] and \ - kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4): + kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]: # another address conversion for DirectToLds + NumLoadsCoalesced > 1 dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val) diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py b/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py index 7fcf92228819..240ba6996042 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py @@ -22,14 +22,15 @@ # ################################################################################ -from rocisa.code import Module +from rocisa.code import Module, Label from rocisa.container import vgpr, ContinuousRegister -from rocisa.instruction import VAddU32 +from rocisa.instruction import VAddU32, VAndB32, VLShiftLeftB32, VLShiftRightB32 from rocisa.functions 
import vectorStaticRemainder, \ vectorStaticDivideAndRemainder, vectorStaticDivide, vectorStaticMultiply, \ vectorStaticMultiplyAdd from ..Component import LraTileAssignment, LraTileProperties +from ..Common import roundUp, log2, ceilDivide from dataclasses import dataclass @dataclass @@ -155,7 +156,6 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi mReg = writer.vgprPool.checkOut(1,"mReg") # remainder isWmma_v1 = writer.states.asmCaps["HasWMMA_V1"] - # get constant parameter tc = tP["tensorChar"] tile01 = tP["tile01Idx"] @@ -191,6 +191,9 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi dividedForWaveId = dividedForWaveId, \ vectorWidth=vectorWidth, \ maxKId=maxKId) + abmatrixinfo = writer.states.a if tc == 'A' else writer.states.b + perpStride = abmatrixinfo.gNLCPerpStride + permBlock = abmatrixinfo.gNLCPermBlock # strider for each type of index umlds = kernel["UnrollMajorLDS%s" % tc] @@ -204,6 +207,8 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi strideK = inputPerThread if umlds else (mt + LdsPad) * inputPerThread if enableLDSTr: + if kernel["UseGeneralizedNLCOne%s"%tc] and perpStride > 1: + strideK = 8 strideK1 = mt+LdsPad # FIXME SPARSE @@ -234,6 +239,25 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi if isDTVAB: strideTile = 1 # DTV case. Actual stride will be applied later. 
+ def perpPerm(vgprReg): + reMap0 = writer.vgprPool.checkOut(1) + reMap1 = writer.vgprPool.checkOut(1) + perpStrideInv = permBlock // perpStride + + module.addComment0("Computing strided(%u) perp indicies"%perpStrideInv) + module.add(VAndB32(dst=vgpr(reMap0), src0=(permBlock // perpStrideInv - 1), src1=vgpr(vgprReg), comment="r0 = I %% (%u // %u)"%(permBlock, perpStrideInv))) + module.add(VLShiftLeftB32(dst=vgpr(reMap0), shiftHex=log2(perpStrideInv), src=vgpr(reMap0), comment="r0 = %u * r0"%(perpStrideInv))) + module.addComment0("Computing r1 = (I %% %u) // (%u // %u)"%(permBlock, permBlock, perpStrideInv)) + module.add(VAndB32(dst=vgpr(reMap1), src0=(permBlock - 1), src1=vgpr(vgprReg), comment="r1 = I %% (%u)"%(permBlock))) + module.add(VLShiftRightB32(dst=vgpr(reMap1), shiftHex=log2(permBlock // perpStrideInv), src=vgpr(reMap1), comment="r1 = (r1) // (%u // %u)"%(permBlock, perpStrideInv))) + module.add(VAddU32(dst=vgpr(reMap0), src0=vgpr(reMap0), src1=vgpr(reMap1), comment="r0 = r0 + r1" )) + + module.add(VLShiftRightB32(dst=vgpr(reMap1), shiftHex=log2(permBlock), src=vgpr(vgprReg), comment="r1 = I // %u"%(permBlock))) + module.add(vectorStaticMultiplyAdd(vgpr(vgprReg), vgpr(reMap1), permBlock, vgpr(reMap0), None)) + + module.addComment0("Done computing strided(%u) perp indices"%perpStrideInv) + writer.vgprPool.checkIn(reMap0) + writer.vgprPool.checkIn(reMap1) with writer.allocTmpSgpr(1) as tmpSgprInfo: # tile offset @@ -245,13 +269,21 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi module.add(vectorStaticRemainder(dummy, sReg, kReg, dividendForKId, tmpVgprRes, tmpSgprInfo, \ "1. N offset: nIdx = wtid %% MI_M(%d)"%dividendForKId)) module.add(vectorStaticDivide(sReg, sReg, 16, tmpVgprRes, \ - "1. thread id in wave: k1Idx = mtid // 4")) + "1. thread id in wave: k1Idx = mtid // 16")) module.add(vectorStaticMultiply(vgpr(sReg), vgpr(sReg), 16, tmpSgprInfo, \ "1. 
K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1))) else: module.add(vectorStaticRemainder(dummy, tReg, kReg, kernel["MatrixInstN"], tmpVgprRes, tmpSgprInfo, \ "1. N offset: nIdx = wtid %% MI_N(%u)" % kernel["MatrixInstN"])) + + applyVWCalcEarly = perpStride > 1 and kernel["ProblemType"]["TLU%s"%tc] == 0 + if applyVWCalcEarly: + # Apply vector width calc before we apply permutation to perp dim + module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), vectorWidth, tmpSgprInfo, \ + "1. apply VectorWidth: bnOffset = bnOffset * vw(%u)" % vectorWidth)) + perpPerm(tReg) + module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), strideTile, tmpSgprInfo, \ "1. N offset: nOffset = nIdx * nStride(%u)" % strideTile)) if enableLDSTr: @@ -268,8 +300,9 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi else: module.addComment0("Skip. 2. block offset: bnOffset = 0 when num1DBlocks = 1") - module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), vectorWidth, tmpSgprInfo, \ - "4. apply VectorWidth: bnOffset = bnOffset * vw(%u)" % vectorWidth)) + if not applyVWCalcEarly: + module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), vectorWidth, tmpSgprInfo, \ + "4. apply VectorWidth: bnOffset = bnOffset * vw(%u)" % vectorWidth)) # unroll offset #if isMfma and (dividendForKId != waveWidth): @@ -285,13 +318,24 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi module.add(vectorStaticDivide(kReg, kReg, dividendForKId, tmpVgprRes, \ "5. K offset: kIdx = wtid / (MIN(%u) * MIBB(%u))" % (kernel["MatrixInstN"], kernel["MatrixInstB"]))) if (dividendForKId != waveWidth) and (not isDTVAB): + if enableLDSTr: module.add(vectorStaticMultiply(vgpr(kReg), vgpr(kReg), strideK, tmpSgprInfo, \ "5. 
K offset: lrKOffset = kIdx * mStride(%u)" % (strideK))) - module.add(vectorStaticMultiply(vgpr(mReg), vgpr(mReg), strideK1, tmpSgprInfo, \ - "5.1 K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1))) - module.add(VAddU32(dst=vgpr(kReg), src0=vgpr(mReg), src1=vgpr(kReg), \ - comment="5.2 offset in wave: lrOffset = bnOffset + lrKOffset")) + + if perpStride == 1: + module.add(vectorStaticMultiply(vgpr(mReg), vgpr(mReg), strideK1, tmpSgprInfo, \ + "5.1 K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1))) + module.add(VAddU32(dst=vgpr(kReg), src0=vgpr(mReg), src1=vgpr(kReg), \ + comment="5.1 offset in wave: lrOffset = bnOffset + lrKOffset")) + else: + module.add(VAddU32(dst=vgpr(kReg), src0=vgpr(mReg), src1=vgpr(kReg), \ + comment="5.1 offset in wave: lrOffset = bnOffset + lrKOffset")) + # Apply permutation to perpendicular dim + if perpStride > 1: + perpPerm(kReg) + module.add(vectorStaticMultiply(vgpr(kReg), vgpr(kReg), strideK1, tmpSgprInfo, \ + "5.2 K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1))) module.add(VAddU32(dst=vgpr(tReg), src0=vgpr(kReg), src1=vgpr(tReg), \ comment="6. offset in wave: lrOffset = bnOffset + lrKOffset")) else: diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 34e5d4b877b6..9be7086fc306 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -102,6 +102,9 @@ class ABMatrixInfo(MatrixInfo): startVgprLocalWriteSwapAddr: int= -1 numSgprGlobalReadIncs: int = -1 + gNLCPermBlock: int = -1 + gNLCPerpStride: int = -1 + # States @dataclass class StateValues: @@ -1956,7 +1959,7 @@ def setupNewTile(self, kernel, tensorParametersA, tensorParametersB, isOptNLL=Fa # If this occurs we need to 'unshift' the C values (see shiftVectorComponents) # BufferLoad does support this shifting, but if GuaranteeNoPartial=1 then # it can be guaranteed that no shifting is required. 
- if not (kernel["BufferLoad"] and kernel["GuaranteeNoPartialA"]) and not forceNoTileCode: + if not (kernel["BufferLoad"] and kernel["GuaranteeNoPartialA"]) and not forceNoTileCode and not kernel["UseGeneralizedNLCOneA"]: module.addComment1("global read addresses: shift a") module.add(self.graShift(kernel, tensorParametersA)) if tensorParametersA["is_sparse"] and kernel["DirectToVgprSparseMetadata"]: @@ -1969,7 +1972,7 @@ def setupNewTile(self, kernel, tensorParametersA, tensorParametersB, isOptNLL=Fa # Using A's margin to instead Metadata's margin module.add(self.graShift(kernel, tPM, tPMRef["glvw"] if tPMRef["rtv"] else 1)) - if not (kernel["BufferLoad"] and kernel["GuaranteeNoPartialB"]) and not forceNoTileCode: + if not (kernel["BufferLoad"] and kernel["GuaranteeNoPartialB"]) and not forceNoTileCode and not kernel["UseGeneralizedNLCOneB"]: module.addComment1("global read addresses: shift b") module.add(self.graShift(kernel, tensorParametersB)) if tensorParametersB["is_sparse"] and kernel["DirectToVgprSparseMetadata"]: @@ -4371,6 +4374,35 @@ def readWriteVectors(mat, vw, kernel): else: numVgprGlobalReadIncsMetadata = 0 + + def GNLCOInit(tc): + abmatrixinfo = self.states.a if tc == 'A' else self.states.b + if kernel["DirectToLds%s"%tc] and kernel["UseGeneralizedNLCOne%s"%tc]: + ntpl = kernel["NumTotalPackedLoads%s"%tc] + # TODOBS: Determine logic to calculate best permStride.. + if kernel["ProblemType"]["TLU%s"%tc] == 1 and not kernel["enableLDSTr%s"%tc]: + usePerpPerm = False + elif kernel["ProblemType"]["TLU%s"%tc] == 1 and kernel["enableLDSTr%s"%tc]: + usePerpPerm = (ntpl & (ntpl-1)) == 0 + else: + # Currently only VW=1,2 is supported due to how the local read offset + # is currently computed. Supporting VW=1,2 only required small modifications + # to the offset calc. 
+ # TODO: Add support for VW=4,8, this will require more changes in LR offset + # calculations + usePerpPerm = False if kernel["VectorWidth%s"%tc] > 2 else True + + permBlock = kernel["MatrixInstK"] if kernel["ProblemType"]["TLU%s"%tc] == 1 \ + else kernel["VectorWidth%s"%tc] * kernel["MatrixInstM"] + abmatrixinfo.gNLCPermBlock = permBlock + abmatrixinfo.gNLCPerpStride = min([8, 2**int(math.log(ntpl, 2)), permBlock]) if usePerpPerm else 1 + else: + abmatrixinfo.gNLCPerpStride = 1 + abmatrixinfo.gNLCPermBlock = 1 + + GNLCOInit('A') + GNLCOInit('B') + numVgprAddressDbg = self.states.rpga if self.debugConfig.debugKernel else 0 #################################### @@ -5340,6 +5372,7 @@ def getTensorParameters(self, tP, kernel, itP, cM): tP["nru"] = itP[cM].numReadsUnroll # number of reads along unroll dimension tP["nrc"] = kernel["NumLoadsCoalesced%s"%cM] # number of reads along coalesced dimension tP["nrcv"] = itP[cM].numReadsCoalVecComp # number of vector components along coalesced dimension + tP["ntpl"] = kernel["NumTotalPackedLoads%s"%cM] tP["nrp"] = kernel["NumLoadsPerpendicular%s"%cM] # number of reads along perpendicular dimension tP["nrpv"] = itP[cM].numReadsPerpVecComp # number of vector components along perpendicular dimension tP["nwcv"] = itP[cM].numWritesCoalVecComp # number of vector component writes along coalesced dimension diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index 9fd8cc1c7305..b79510694df3 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -65,7 +65,7 @@ VCvtF32toF16, VCvtFP8toF32, VCvtInstruction, VCvtPkF32toBF16, VCvtPkF32toBF8, \ VCvtPkF32toFP8, VCvtPkFP8toF32, VCvtSRF32toBF8, VCvtSRF32toFP8, VCvtScaleFP8toF16, \ VCvtScalePkF16toBF8, VCvtScalePkF16toFP8, VCvtScalePkFP8toF16, VLShiftLeftB32, \ - VLShiftLeftB64, VLShiftRightB32, VMadU32U24, VMaxF32, 
VMinI32, VMovB32, VMovB64, VMulF32, \ + VLShiftLeftB64, VLShiftRightB32, VLShiftRightB64, VMadU32U24, VMaxF32, VMinI32, VMovB32, VMovB64, VMulF32, \ VMulHIU32, VMulLOU32, VMulPKF32S, VMulU32U24, VNotB32, VOrB32, VPackF16toB32, \ VPrngB32, VReadfirstlaneB32, VSubF32, VSubI32, VSubU32, VXorB32, GlobalLoadTR8B64, GlobalLoadTR16B128 @@ -75,7 +75,7 @@ from .AsmStoreState import StoreState, VectorDataTypes from .Activation import ActivationType from .CustomKernels import isCustomKernelConfig -from .Common import roundUp, log2, ceilDivide +from .Common import roundUp, log2, ceilDivide, choose_multiplier from Tensile.Common import print2, printExit, printWarning, INDEX_CHARS, DebugConfig, DataDirection from Tensile.Common.DataType import DataType from Tensile.Common.RegisterPool import RegisterPool, allocTmpGpr, allocTmpGprList @@ -2913,13 +2913,19 @@ def graFinalOffsets(self, kernel, tP): module.add(singleModule) # DTVA/B always go this way, including swizzled elif (not swapPerpPara): - for perp in range(0, tP["nrp"]): - for sPerp in range(0, tP["nrpv"]): - for para in range(0, tP["nrc"]): - for sPara in range(0, tP["nrcv"]//tP["nrcvpi"]): - # single loop - singleModule, graIdx = self.graFinalOffsetsSingleLoop(kernel, tP, tc, tmp, graIdx, perp, sPerp, para, sPara) - module.add(singleModule) + #module.add(self.graFinalOffsetsSingleLoopGNLC(kernel, tP, tc)) + if kernel["UseGeneralizedNLCOne%s"%tc] and not self.states.inTailLoop: + module.add(self.graFinalOffsetsSingleLoopGNLC(kernel, tP, tc)) + else: + module.addComment0("=============================================================") + for perp in range(0, tP["nrp"]): + for sPerp in range(0, tP["nrpv"]): + for para in range(0, tP["nrc"]): + for sPara in range(0, tP["nrcv"]//tP["nrcvpi"]): + # single loop + singleModule, graIdx = self.graFinalOffsetsSingleLoop(kernel, tP, tc, tmp, graIdx, perp, sPerp, para, sPara) + module.add(singleModule) + 
module.addComment0("=============================================================") else: # swap para and perp for para in range(0, tP["nrc"]): @@ -2963,6 +2969,142 @@ def graFinalOffsets(self, kernel, tP): return Module("Global Read Addresses: Final Offsets A/B (Empty)") if self.dontAppendCode else module + ############################################################################## + # Global Read Addresses: Final Offsets A/B (single loop) + ############################################################################## + def graFinalOffsetsSingleLoopGNLC(self, kernel, tP, tc, margin = -1): + module = Module() + + if margin == -1: + margin = tP["glvw"] if tP["rtv"] else 1 + + module.addComment("Using GLNC for %s"%tc) + groVgpr0 = "GlobalReadOffset%s+%u" % (tc, 0) + parDimSize = kernel["MacroTile%s"%tc] if kernel["ProblemType"]["TLU%s"%tc] == 1 else kernel["DepthU"] + numThreadsCoalesced = (parDimSize // kernel["GlobalReadVectorWidth%s"%tc]) + + numThreadsPerMI = max((kernel["MatrixInstM"] if kernel["ProblemType"]["TLU%s"%tc] == 1 \ + else kernel["MatrixInstK"]) // kernel["GlobalReadVectorWidth%s"%tc],\ + 1) + numThreadGroupsPerParDim = numThreadsCoalesced // numThreadsPerMI + + module.addComment0("NumThreadsCoalesced%s = %u, %u total threads, %u thread groups"%( \ + tc, numThreadsCoalesced, kernel["NumThreads"], numThreadGroupsPerParDim)) + + + module.add(VMovB32(dst=vgpr(groVgpr0), src=vgpr("Serial"))) + for perp in range(1, tP["ntpl"]): + groVgpr = "GlobalReadOffset%s+%u" % (tc, perp) + groVgprPrev = "GlobalReadOffset%s+%u" % (tc, perp - 1) + strideLoad = kernel["NumThreads"] # stride between consecutive loads + module.add(VAddU32(dst=vgpr(groVgpr), src0=strideLoad, src1=vgpr(groVgprPrev), comment=" = vgprSerial + %u * %u"%(perp, strideLoad))) + + + # TODOBS: Only use perperm for pow2 cases for now.. 
+ abMatrixInfo = self.states.a if tc == 'A' else self.states.b + perpStride = abMatrixInfo.gNLCPerpStride + permBlock = abMatrixInfo.gNLCPermBlock + usePerpPerm = perpStride > 1 + + tmpv = self.vgprPool.checkOutAligned(2,2) + tmpv2 = self.vgprPool.checkOut(1) + tmps = self.sgprPool.checkOut(1) + tmps2 = self.sgprPool.checkOut(1) + divsor = numThreadsCoalesced + + useMagicDiv = divsor > 1 and (divsor & (divsor-1)) != 0 + if useMagicDiv: + maxDividend = kernel["NumThreads"] * tP["ntpl"] + # Verified limits for magic div algo + assert divsor < 512 and maxDividend < 64 * 1024 + Nbits = 32 + + # If divsor is even, compute largest odd value that is multiple of divsor. + # divsor = 2^ee * divsor2 + divsor2, cc,ee = divsor, 1, 0 + while divsor2 % 2 == 0: + divsor2 //= 2 + cc *= 2 + ee += 1 + mm, shPost, l = choose_multiplier(divsor2, Nbits - ee, 16) + + module.add(SMovB32(dst=sgpr(tmps), src=mm, comment="Used in magic div algo, multiplicative constant for 1/%u"%divsor2)) + module.add(SMovB32(dst=sgpr(tmps2), src=divsor2)) + + if kernel["EdgeType"] == "ShiftPtr" and kernel["ProblemType"]["TLU%s"%tc] == 1: + tmpSgpr = self.sgprPool.checkOut(1) + module.add(SMulI32(dst=sgpr(tmpSgpr), src0=sgpr(tP["wg"]), src1=kernel[tP["mt"]], comment="WorkGroup[01] * MT")) + module.add(SSubU32(dst=sgpr(tmpSgpr), src0=self.sizeRef(tP["idx"]), src1=sgpr(tmpSgpr), \ + comment="edge = Size%s - WG*MT"%(tP["tileChar"]))) + module.add(SSubU32(dst=sgpr(tmpSgpr), src0=sgpr(tmpSgpr), src1=margin, comment="edge -= margin(%u)"%(margin))) + + for perp in range(0, tP["ntpl"]): + strideChar = 'L' if tc == 'A' else 'K' + grov = "GlobalReadOffset%s+%u" % (tc, perp) + # Compute division + if useMagicDiv: + if ee > 0: + module.add(VLShiftRightB32(dst=vgpr(tmpv2), shiftHex=ee, src=vgpr(grov), comment="division")) + + module.add(VMulLOU32(dst=vgpr(tmpv), src0=sgpr(tmps), src1=vgpr(tmpv2), \ + comment="division" )) + module.add(VMulHIU32(dst=vgpr(tmpv+1), src0=sgpr(tmps), src1=vgpr(tmpv2), \ + comment="division" 
)) + module.add(VLShiftRightB64(dst=vgpr(tmpv,2), shiftHex=(Nbits - ee + shPost), src=vgpr(tmpv,2), comment="division")) + else: + module.add(VLShiftRightB32(dst=vgpr(tmpv), shiftHex=log2(divsor), src=vgpr(grov), comment="division")) + + # Compute remainder + if useMagicDiv: + module.add(VMulLOU32(dst=vgpr(tmpv2), src0=sgpr(tmps2), src1=vgpr(tmpv))) + module.add(VLShiftLeftB32(dst=vgpr(tmpv2), shiftHex=ee, src=vgpr(tmpv2), comment="remainder")) + module.add(VSubU32(dst=vgpr(tmpv2), src0=vgpr(grov), src1=vgpr(tmpv2))) + else: + module.add(VAndB32(dst=vgpr(tmpv2), src0=hex(divsor - 1), src1=vgpr(grov))) + + # Permute logic perp dim + if usePerpPerm: + reMap0 = self.vgprPool.checkOut(1) + reMap1 = self.vgprPool.checkOut(1) + module.addComment0("Computing strided(%u) perp indicies"%perpStride) + module.add(VAndB32(dst=vgpr(reMap0), src0=(permBlock // perpStride - 1), src1=vgpr(tmpv), comment="r0 = I %% (%u // %u)"%(permBlock, perpStride))) + module.add(VLShiftLeftB32(dst=vgpr(reMap0), shiftHex=log2(perpStride), src=vgpr(reMap0), comment="r0 = %u * r0"%(perpStride))) + module.addComment0("Computing r1 = (I %% %u) // (%u // %u)"%(permBlock, permBlock, perpStride)) + module.add(VAndB32(dst=vgpr(reMap1), src0=(permBlock - 1), src1=vgpr(tmpv), comment="r1 = I %% (%u)"%(permBlock))) + module.add(VLShiftRightB32(dst=vgpr(reMap1), shiftHex=log2(permBlock // perpStride), src=vgpr(reMap1), comment="r1 = (r1) // (%u // %u)"%(permBlock, perpStride))) + module.add(VAddU32(dst=vgpr(reMap0), src0=vgpr(reMap0), src1=vgpr(reMap1), comment="r0 = r0 + r1" )) + + module.add(VLShiftRightB32(dst=vgpr(reMap1), shiftHex=log2(permBlock), src=vgpr(tmpv), comment="r1 = I // %u"%(permBlock))) + module.add(vectorStaticMultiplyAdd(vgpr(tmpv), vgpr(reMap1), permBlock, vgpr(reMap0), None)) + + module.addComment0("Done computing strided(%u) perp indices"%perpStride) + self.vgprPool.checkIn(reMap0) + self.vgprPool.checkIn(reMap1) + + stride = "Strides%s"%(tc) + module.add(VLShiftLeftB32(dst=vgpr(grov), 
shiftHex=log2(kernel["GlobalReadVectorWidth%c"%tc]), src=vgpr(tmpv2))) + module.add(VMulLOU32(dst=vgpr(tmpv), src0=sgpr(stride), src1=vgpr(tmpv))) + + if kernel["EdgeType"] == "ShiftPtr" and kernel["ProblemType"]["TLU%s"%tc] == 1: + module.add(VMinI32(dst=vgpr(grov), src0=sgpr(tmpSgpr), src1=vgpr(grov), comment="")) + + module.add(VAddU32(dst=vgpr(grov), src0=vgpr(tmpv), src1=vgpr(grov), \ + comment="final" )) + module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(tP["bpeGR"]), src=vgpr(grov))) + module.add(VAddU32(dst=vgpr(grov), src0=self.states.srdShiftLeft[tc] * tP["bpeGR"] , src1=vgpr(grov), \ + comment="ptr-shift" )) + + self.vgprPool.checkIn(tmpv) + self.vgprPool.checkIn(tmpv2) + self.sgprPool.checkIn(tmps) + self.sgprPool.checkIn(tmps2) + + if kernel["EdgeType"] == "ShiftPtr" and kernel["ProblemType"]["TLU%s"%tc] == 1: + self.sgprPool.checkIn(tmpSgpr) + + return module + + ############################################################################## # Global Read Addresses: Final Offsets A/B (single loop) ############################################################################## @@ -3875,8 +4017,9 @@ def lwaFirstOffset(self, kernel, tP): validBytesPerLoad *= (kernel["MacroTile%s"%tc] // kernel["NumLoadsPerpendicular%s"%tc] // (kernel["NumThreads"] // kernel["WavefrontSize"])) isDTVAB = ((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) - assert (validBytesPerLoad <= maxBytesPerLoad) or isDTVAB - assert (kernel[tP["lsc"]] * kernel[tP["lsp"]] % tP["glvw"] == 0) or isDTVAB + # For GNLC we don't need to check these asserts since num threads coalesced may not divide num threads + assert (validBytesPerLoad <= maxBytesPerLoad) or isDTVAB or kernel["UseGeneralizedNLCOne%s"%tc] + assert (kernel[tP["lsc"]] * kernel[tP["lsp"]] % tP["glvw"] == 0) or isDTVAB or kernel["UseGeneralizedNLCOne%s"%tc] if validBytesPerLoad != maxBytesPerLoad: with self.allocTmpSgpr(1) as tmpSgprInfo: @@ -3898,13 +4041,29 @@ def lwaFirstOffset(self, kernel, tP): 
self.vgprPool.checkIn(tmpVgpr) if kernel["LocalWriteUseSgpr%s"%tc]: - # TODO: Can refactor code above to Compute this directly: - if self.states.archCaps["CrosslaneWait"]: - module.add(SNop(waitState=0, comment="1 wait states required before reading vgpr by lane")) - module.add(VReadfirstlaneB32( + if not kernel["UseGeneralizedNLCOne%s"%tc]: + # TODO: Can refactor code above to Compute this directly: + if self.states.archCaps["CrosslaneWait"]: + module.add(SNop(waitState=0, comment="1 wait states required before reading vgpr by lane")) + module.add(VReadfirstlaneB32( dst=sgpr("LocalWriteAddr%s"%tc), \ src=vgpr(destVgpr), \ comment="Copy lds write address VGPR to SGPR")) + else: + tmpv = self.vgprPool.checkOut(1) + module.add(VLShiftRightB32(dst=vgpr(tmpv), shiftHex=log2(kernel["WavefrontSize"]), src=vgpr("Serial"), comment="Compute waveID")) + if self.states.archCaps["CrosslaneWait"]: + module.add(SNop(waitState=0, comment="1 wait states required before reading vgpr by lane")) + module.add(VReadfirstlaneB32( + dst=sgpr("LocalWriteAddr%s"%tc), \ + src=vgpr(tmpv), \ + comment="Copy lds write address VGPR to SGPR")) + module.add(SMulI32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), \ + src1=((kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%c"%tc]+kernel["LdsPad%s"%tc]) * tP["bpeGR"]) )) + if tc == 'B': + module.add(SAddU32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), \ + src1=kernel["LdsOffsetB"] )) + self.vgprPool.checkIn(tmpv) self.vgprPool.checkIn(destVgpr) if kernel["StoreSwapAddr"]: @@ -5213,11 +5372,6 @@ def generateReLoadLoop(tc): imod.addComment1("global read for tail done") imod.add(tailGlobalLoadEndLabel) - if (doA and kernel["DirectToLds%s"%tPA["tensorChar"]]) or \ - (doB and kernel["DirectToLds%s"%tPB["tensorChar"]]): - imod.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ - comment="Restore LDS clamp at %u bytes HERE"%(kernel["LdsNumBytes"]))) - if doA or doB: self.vgprPool.checkIn(tmpVgpr) for 
s in singSgprList: @@ -6127,9 +6281,8 @@ def fixPreloadOffset(offset, sgpxIdxVec, numStoreSgprToLoad): #module.add(SNop(waitState=instCycles)) module.addComment1("Mapping of Acc register -> C Vgpr register") self.codes.accVgprRead = mapAcctoArchRegs(kernel, self.states.maxLimitAgprs, write=False) - if kernel["StreamK"] > 0 and kernel["StreamKAtomic"] == 0: - self.codes.accVgprWrite = mapAcctoArchRegs(kernel, self.states.maxLimitAgprs, write=True) - if kernel["GlobalSplitUAlgorithm"] == "MultipleBufferSingleKernel": + if (kernel["StreamK"] > 0 and kernel["StreamKAtomic"] == 0) or \ + (kernel["GlobalSplitUAlgorithm"] == "MultipleBufferSingleKernel"): self.codes.accVgprWrite = mapAcctoArchRegs(kernel, self.states.maxLimitAgprs, write=True) if kernel["MIArchVgpr"]: module.addComment1("Multiply MI out register with Alpha -> C Vgpr register") @@ -6798,8 +6951,8 @@ def findSparseOffset(isA:bool): numSubTiles = kernel["numSubTiles"] if numSubTiles > 1 and not self.states.inTailLoop: # iter (idxOuter_start, idxOuter_stop) (idxInner_start, idxInner_stop) MFMA - # 0 (0,4) (0,4) MFMA(A0,B0) - # 1 (0,4) (4,8) MFMA(A1,B0) + # 0 (0,4) (0,4) MFMA(A0,B0) + # 1 (0,4) (4,8) MFMA(A1,B0) # 2 (4,8) (0,4) MFMA(A0,B1) # 3 (4,8) (4,8) MFMA(A1,B1) outerBy2=(kernel["MIWaveTile"][outer]//numSubTiles) @@ -6808,8 +6961,8 @@ def findSparseOffset(isA:bool): innerMod2=(kernel["MIWaveTile"][inner]%numSubTiles) idxHalfO = u//numSubTiles idxHalfI = u % numSubTiles - idxOuter_start = (outerBy2 + outerMod2)*idxHalfO - idxInner_start = (innerBy2 + innerMod2)*idxHalfI + idxOuter_start = (outerBy2 + outerMod2)*idxHalfO + idxInner_start = (innerBy2 + innerMod2)*idxHalfI idxOuter_stop = kernel["MIWaveTile"][outer] - (1-idxHalfO)* outerBy2 idxInner_stop = kernel["MIWaveTile"][inner] - (1-idxHalfI)* innerBy2 @@ -8172,11 +8325,6 @@ def globalReadGuardKBody(tP, optParams = None): module.add(SWaitCnt(dscnt=0, vlcnt=0, vscnt=0, comment="")) module.add(SBarrier(comment="debug")) - # TODO - can remove one of these 
m0 restores if A and B both TLU - if kernel["DirectToLds%s"%tP["tensorChar"]]: - module.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ - comment="Restore LDS clamp at %u bytes HERE"%(kernel["LdsNumBytes"]))) - if isTr: self.vgprPool.checkIn(maxGroVgpr) elif not kernel["BufferLoad"]: @@ -8617,17 +8765,6 @@ def globalReadBody(tP): imod.footer.add(SBarrier(comment="debug")) #module.add(self.getCmpAssert(self.asmAssert.lt, vgpr("Serial"), 64)) # examine second wavefront - # TODO - can remove one of these m0 restores if A and B both TLU - if kernel["DirectToLds%s"%tP["tensorChar"]] and not (mode == 1 and kernel["PrefetchGlobalRead"]==2): - dst = mgpr(0) - src = hex(kernel["LdsNumBytes"]) - comment = "Restore LDS clamp at %u bytes"%(kernel["LdsNumBytes"]) - # PGR=2 case, footer is located before global read. To avoid setting clamp before global read, store lds clamp code in middle - if kernel["PrefetchGlobalRead"] == 2: - imod.middle.add(SMovB32(dst=dst, src=src, comment=comment)) - else: - imod.footer.add(SMovB32(dst=dst, src=src, comment=comment)) - return imod @@ -8911,7 +9048,10 @@ def calculateLdsWriteOffset(self, perp, para, sPerp, sPara, kernel, tP): # print("2lspaOffset", lspaOffset) # print("2lscaOffset", lscaOffset) - offsetElements = (lspaOffset + lscaOffset) + if kernel["UseGeneralizedNLCOne%s"%tc]: + offsetElements = perp * kernel["GlobalReadVectorWidth%s"%tc] * kernel["WavefrontSize"] * kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1] + else: + offsetElements = (lspaOffset + lscaOffset) # print("offsetElements", offsetElements) offsetBytes = offsetElements*tP["bpeDS"] @@ -9774,12 +9914,23 @@ def localReadInc(self, kernel, iui, tP): elif tc == "Metadata": inc //= 8 + padd = 0 + # Apply additional padding if needed when cumulative incs add up to LdsBlockSizePerPad if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0): - inc = inc + (inc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"] + ldsBSPad = 
kernel["LdsBlockSizePerPad%s"%tc] + if ( (iui+1) * inc) % kernel["LdsBlockSizePerPad%s"%tc] == 0: + totalIncForNextLR = (iui+1) * inc + totalIncForCurrLR = iui * inc + # Calculate the amount of padding needed for the inc used in the next LR, and subtract + # any padding added previously to avoid double counting. + padd = (totalIncForNextLR // ldsBSPad - totalIncForCurrLR // ldsBSPad ) \ + * kernel["LdsPad%s"%tc] * tP["bpeDS"] + module.addComment0("Adding additional %u pad since cumulative inc has reached %u"\ + %(padd, kernel["LdsBlockSizePerPad%s"%tc])) with self.allocTmpSgpr(1) as tmpSgprInfo: tmpSgpr = tmpSgprInfo.idx - module.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(inc), comment="inc")) + module.add(SMovB32(dst=sgpr(tmpSgpr), src=(inc + padd), comment="inc")) numLra = 0 if tP["isA"]: numLra = self.states.a.numVgprLocalReadAddr diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index f052e0386a8b..5b9c3e6ff193 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -533,6 +533,20 @@ def setGlobalLoadTileDimClassic(state, tc, numLoads, totalVectorsCoalesced, tota if (tc == "A" or tc == "B") and state["enableGLTr%s"%tc]: state["NumLoadsCoalesced%s"%tc], state["NumLoadsPerpendicular%s"%tc] = state["NumLoadsPerpendicular%s"%tc], state["NumLoadsCoalesced%s"%tc] + # Generalized nlc = 1 case + idxWG = 0 if tc == 'A' else 1 + totalLoadsNeeded = (state["MacroTile%s"%tc] * depthU) // (state["GlobalReadVectorWidth%s"%tc] * state["WavefrontSize"]) + MT = state["MacroTile%s"%tc] + numWaves = state["NumThreads"] // state["WavefrontSize"] + if state["UseGeneralizedNLCOne%s"%tc]: + if (totalLoadsNeeded % numWaves) != 0: + reject(state, printRejectionReason, "GNLC%s: totalLoadsNeeded (%u) %% numWaves (%u) != 0"%(tc, totalLoadsNeeded, numWaves)) + state["NumTotalPackedLoads%s"%tc] = totalLoadsNeeded 
// numWaves + state["NumLoadsPerpendicular%s"%tc] = state["NumTotalPackedLoads%s"%tc] + state["NumLoadsCoalesced%s"%tc] = 1 + else: + state["NumTotalPackedLoads%s"%tc] = -1 + if state["ProblemType"]["TLU%s"%tc]: state["LSC%s"%tc] = state["MacroTile%s"%tc] // state["NumLoadsCoalesced%s"%tc] state["LSP%s"%tc] = int(math.ceil(float(depthU) / state["NumLoadsPerpendicular%s"%tc])) @@ -763,7 +777,6 @@ def isDirectToVgprDoable(state, tc, printRejectionReason: bool, isaInfoMap: Dict # determine can we use DirectToLds @staticmethod def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool): - numBytes = state["ProblemType"]["DataType"].numBytes() isa = state["ISA"] @@ -784,7 +797,7 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool): MT = state["MacroTile0"] if tc == 'A' else state["MacroTile1"] - if (MT & (MT-1)) != 0: # Check of MT not power of 2 + if (MT & (MT-1)) != 0 and not state["UseGeneralizedNLCOne%s"%tc]: # Check of MT not power of 2 # so far, numBytesAB<4 case, TLU=False only (continue with False) if (numBytesAB < 4 or state["UseF32XEmulation"]) and state["ProblemType"]["TLU%c"%tc]: return False @@ -820,7 +833,7 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool): return False if state["ProblemType"]["TLU%c"%tc] == state["UnrollMajorLDS%c" % tc]: - reject(state, printRejectionReason, "can't use DirectToLds for TLU%c == UnrollMajorLDS%c"%(tc, tc)) + printWarning("can't use DirectToLds for TLU%c == UnrollMajorLDS%c, using nonDirectToLds version instead"%(tc, tc)) return False # avoid picking x2&x4 for precisions < f32/f64 in [ProblemType][TLU] == TRUE @@ -833,7 +846,10 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool): if state["LSC%c"%tc] * state["LSP%c"%tc] * numBytesAB != state["WavefrontSize"] * state["GlobalReadVectorWidth%c"%tc] * numBytesAB: reject(state, printRejectionReason, "can't use DirectToLds for LSC%c and LSP%c * bpe!= WavefrontSize * 
GlobalReadVectorWidth%c * bpe%c > 4"%(tc, tc, tc, tc)) return False - else: + if state["WaveSeparateGlobalRead%c" % tc] == 2 and state["TransposeLDS"] != 2: + reject(state, printRejectionReason, "can't use DirectToLds for WSGR%s = 2 and TLDS != 2"%(tc)) + return False + elif not state["UseGeneralizedNLCOne%s"%tc]: if state["LSC%c"%tc] * state["LSP%c"%tc] * numBytesAB != state["NumThreads"] * state["GlobalReadVectorWidth%c"%tc] * numBytesAB: reject(state, printRejectionReason, "can't use DirectToLds for LSC%c and LSP%c * bpe != NumThreads * GlobalReadVectorWidth%c * bpe%c > 4"%(tc, tc, tc, tc)) return False @@ -1196,7 +1212,7 @@ def assignDerivedParameters( if state["StreamKXCCMapping"] != 0: reject(state, printRejectionReason, "Cannot use auto WGMXCC with SKXCC.") return False - + if state["WorkGroupMapping"] == 0: if state["WorkGroupMappingXCC"] == -1: if state["StreamK"] == 0: @@ -1282,6 +1298,15 @@ def assignDerivedParameters( #print("PackedC0IdxChars", state["PackedC0IdxChars"]) #print("PackedC1IdxChars", state["PackedC1IdxChars"]) + # + # UnrollMajorLDS{A,B} + # 0: means M/N is contiguous in LDS + # 1: means K is contiguous in LDS + # + # TLU{A,B} + # 1: means M/N is contiguous in global memory + # 0: means K is contiguous in global memory + # if state["EnableMatrixInstruction"]: if state["TransposeLDS"] == -1: if state["ProblemType"]["TLUA"] and state["ProblemType"]["TLUB"]: @@ -1623,7 +1648,7 @@ def calcLdsPad(lrvw: int, isaInfoMap: Dict[str, IsaInfo]) -> int: ldsPadA = 16 // state["ProblemType"]["DataType"].numBytes() if state["DirectToLdsA"]: # TODO: Check if there are cases which benefit from padding, currently set to zero by default - ldsPadA = 0 + ldsPadA = state["MatrixInstM"] if state["enableLDSTrA"] else 0 else: # mac instruction if state["ProblemType"]["TLUA"]: ldsPadA = 0 @@ -1646,7 +1671,7 @@ def calcLdsPad(lrvw: int, isaInfoMap: Dict[str, IsaInfo]) -> int: ldsPadB = 16 // state["ProblemType"]["DataType"].numBytes() if state["DirectToLdsB"]: # 
TODO: Check if there are cases which benefit from padding, currently set to zero by default - ldsPadB = 0 + ldsPadB = state["MatrixInstM"] if state["enableLDSTrB"] else 0 else: if state["ProblemType"]["TLUB"]: ldsPadB = 0 @@ -1676,10 +1701,21 @@ def calcLdsPad(lrvw: int, isaInfoMap: Dict[str, IsaInfo]) -> int: ldsPadM = 0 assert(ldsPadM >= 0) - if state["DirectToLdsA"] and state["ProblemType"]["TLUA"]: - ldsPadA = 0 - if state["DirectToLdsB"] and state["ProblemType"]["TLUB"]: - ldsPadB = 0 + def removeLdsPadLogicForDTL(tc, ldsPad): + ret = ldsPad + miwt = state["MIWaveTile%s"%tc] + # If TLU = 1 and not using LDSTR, lds read is contiguous so no padding needed + if state["ProblemType"]["TLU%s"%tc] and (not state["enableLDSTr%s"%tc]): + ret = 0 + if state["ProblemType"]["TLU%s"%tc] and state["enableLDSTr%s"%tc] \ + and (miwt & (miwt-1)) != 0 and state["UseGeneralizedNLCOne%s"%tc]: + ret = 0 + return ret + + if state["DirectToLdsA"]: + ldsPadA = removeLdsPadLogicForDTL('A', ldsPadA) + if state["DirectToLdsB"]: + ldsPadB = removeLdsPadLogicForDTL('B', ldsPadB) # set ldsPadA,B=0 for DirectToVgpr if state["DirectToVgprA"]: ldsPadA = 0 @@ -1730,13 +1766,19 @@ def calcLdsBlockSizePerPad(lrvw: int) -> int: bpeA = state["ProblemType"]["DataTypeA"].numBytes() # For DTL lds padding must be a multiple of the instruction load size (in bytes) MinLdsBlockSizePerPadA = (state[f"GlobalReadVectorWidthA"] * bpeA) * state["WavefrontSize"] - LdsBlockSizePerPadA = roundUpToNearestMultiple(LdsBlockSizePerPadA, MinLdsBlockSizePerPadA) + if state["UseGeneralizedNLCOneA"]: + LdsBlockSizePerPadA = MinLdsBlockSizePerPadA + else: + LdsBlockSizePerPadA = roundUpToNearestMultiple(LdsBlockSizePerPadA, MinLdsBlockSizePerPadA) if state["DirectToLdsB"]: bpeB = state["ProblemType"]["DataTypeB"].numBytes() # For DTL lds padding must be a multiple of the instruction load size (in bytes) MinLdsBlockSizePerPadB = (state[f"GlobalReadVectorWidthB"] * bpeB) * state["WavefrontSize"] - LdsBlockSizePerPadB = 
roundUpToNearestMultiple(LdsBlockSizePerPadB, MinLdsBlockSizePerPadB) + if state["UseGeneralizedNLCOneB"]: + LdsBlockSizePerPadB = MinLdsBlockSizePerPadB + else: + LdsBlockSizePerPadB = roundUpToNearestMultiple(LdsBlockSizePerPadB, MinLdsBlockSizePerPadB) return LdsBlockSizePerPadA, LdsBlockSizePerPadB @@ -1974,6 +2016,28 @@ def calSwizzlePackK(state, tc): if (state["GlobalReadVectorWidthB"] * state["ProblemType"]["DataTypeB"].numBytes()) > 16 and not state["UseF32XEmulation"]: reject(state, printRejectionReason, "GRVWB * DataTypeB.numBytes() > 16") + disableGNLC = False # Set to true to disable GNLC if needed + isMixedPrec = (state["ProblemType"]["DataTypeA"].numBytes() != state["ProblemType"]["DataTypeB"].numBytes()) + if state["DirectToLds"] and state["LocalSplitU"] == 1 \ + and not isMixedPrec and not state["ProblemType"]["Sparse"] \ + and state["MatrixInstB"] == 1 \ + and not disableGNLC: + + for tc in ['A', 'B']: + # Check if we are requesting b64 loads for A/B - these are not compatible with DTL + grwidth = state["GlobalReadVectorWidth%s"%tc] * state["ProblemType"]["DataType%s"%tc].numBytes() + # Check that GR layout is the same as LDS layout for A/B + sameLayout = state["ProblemType"]["TLU%s"%tc] != state["UnrollMajorLDS%s"%tc] + state["UseGeneralizedNLCOne%s"%tc] = grwidth != 8 and sameLayout \ + and state["WaveSeparateGlobalRead%s"%tc] == 0 and not state["DirectToVgpr%s"%tc] + + state["UseGeneralizedNLCOneMetadata"] = False + state["_UseSgprForGRO"] = 0 + else: + state["UseGeneralizedNLCOneA"] = False + state["UseGeneralizedNLCOneB"] = False + state["UseGeneralizedNLCOneMetadata"] = False + ######################################## # Search DepthU # Inputs: @@ -2022,13 +2086,14 @@ def calSwizzlePackK(state, tc): tva = totalElementsA // state["GlobalReadVectorWidthA"] if not Solution.setGlobalReadVectorWidth(state, "A", tva, state["GlobalReadVectorWidthA"], printRejectionReason): validDepthU = False + tvb = totalElementsB // state["GlobalReadVectorWidthB"] 
if not Solution.setGlobalReadVectorWidth(state, "B", tvb, state["GlobalReadVectorWidthB"], printRejectionReason): validDepthU = False if state["EnableMatrixInstruction"] and state["GlobalReadVectorWidthA"]: partialA = state["ProblemType"]["TLUA"] and (state["AssertFree0ElementMultiple"] % state["GlobalReadVectorWidthA"] != 0) - if partialA: + if partialA and not state["UseGeneralizedNLCOneA"]: glvwAlimit = 16 // state["ProblemType"]["DataType"].numBytes() if state["SourceSwap"]: matrixInstM = (state["MatrixInstM"] * state["MatrixInstBM"]) if (state["MatrixInstM"] == 4) else state["MatrixInstM"] @@ -2047,7 +2112,7 @@ def calSwizzlePackK(state, tc): if state["EnableMatrixInstruction"] and state["GlobalReadVectorWidthB"]: partialB = state["ProblemType"]["TLUB"] and (state["AssertFree1ElementMultiple"] % state["GlobalReadVectorWidthB"] != 0) - if partialB: + if partialB and not state["UseGeneralizedNLCOneB"]: glvwBlimit = 16 // state["ProblemType"]["DataType"].numBytes() if state["SourceSwap"]: matrixInstM = (state["MatrixInstM"] * state["MatrixInstBM"]) if (state["MatrixInstM"] == 4) else state["MatrixInstM"] @@ -2448,15 +2513,15 @@ def calSwizzlePackK(state, tc): # No longer support loadX2/loadx4 . 
if state["DirectToLds"]: - if (not state["DirectToVgprA"]) and Solution.isDirectToLdsDoable(state, 'A', isaInfoMap, printRejectionReason): - state['tailLoopOptA'] = False - state["DirectToLdsA"] = True - state["LocalWriteUseSgprA"] = True - - if (not state["DirectToVgprB"]) and Solution.isDirectToLdsDoable(state, 'B', isaInfoMap, printRejectionReason): - state['tailLoopOptB'] = False - state["DirectToLdsB"] = True - state["LocalWriteUseSgprB"] = True + for tc in ['A', 'B']: + isDtlDoable = Solution.isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason) + if (not state["DirectToVgpr%s"%tc]) and isDtlDoable: + state['tailLoopOpt%s'%tc] = False + state["DirectToLds%s"%tc] = True + state["LocalWriteUseSgpr%s"%tc] = True + elif not isDtlDoable: + if state["UseGeneralizedNLCOne%s"%tc]: + reject(state, printRejectionReason, "DirectToLds%s not doable, but GNLC%s enabled, rejecting"%(tc, tc)) # Update parent variable so kernel display is accurate state["DirectToLds"] = state["DirectToLdsA"] or state["DirectToLdsB"] @@ -2490,7 +2555,7 @@ def calSwizzlePackK(state, tc): state["LdsBlockSizePerPadMetadata"] = state["LdsBlockSizePerPadA"] if state["EnableMatrixInstruction"]: - if state["LdsBlockSizePerPadA"]: + if state["LdsBlockSizePerPadA"] and not state["UseGeneralizedNLCOneA"]: if state["UnrollMajorLDSA"]: if state["LdsBlockSizePerPadA"] % (state["_DepthUA"] * state["ProblemType"]["DataTypeA"].numBytes()) != 0: reject(state, printRejectionReason, "reject: LdsBlockSizePerPadA %u %% depthU %u x bpeA != 0" % (state["LdsBlockSizePerPadA"],state["_DepthUA"])) @@ -2498,7 +2563,7 @@ def calSwizzlePackK(state, tc): state["LSPA"] % (state["LdsBlockSizePerPadA"] // (state["_DepthUA"] * state["ProblemType"]["DataType"].numBytes())) != 0: reject(state, printRejectionReason, "can't pad by addrVgpr or instOffset") - if state["LdsBlockSizePerPadB"]: + if state["LdsBlockSizePerPadB"] and not state["UseGeneralizedNLCOneB"]: if state["UnrollMajorLDSB"]: if 
state["LdsBlockSizePerPadB"] % state["_DepthUB"] * state["ProblemType"]["DataTypeB"].numBytes() != 0: reject(state, printRejectionReason, "reject: LdsBlockSizePerPadB %u %% depthU %u x bpeB != 0" % (state["LdsBlockSizePerPadB"],state["_DepthUB"])) @@ -2654,7 +2719,7 @@ def subCheckLdsBlockSizePerPad(tc, idx): idx = 0 if tc == "A" else 1 auto_LdsBlockSizePerPad_for_mix = auto_LdsBlockSizePerPadA_for_mix if tc == "A" else auto_LdsBlockSizePerPadB_for_mix - if not subCheckLdsBlockSizePerPad(tc, idx): + if not subCheckLdsBlockSizePerPad(tc, idx) and not state["UseGeneralizedNLCOne%s"%tc]: if auto_LdsBlockSizePerPad_for_mix: printWarning("Padded address is inconisstent, set LdsBlockSizePerPad%s=0."%tc) state["LdsBlockSizePerPad%s"%tc] = 0 diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/dtl.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/dtl.yaml index 65de6ddbb222..842b51e5def9 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/dtl.yaml +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/dtl.yaml @@ -109,6 +109,9 @@ BenchmarkProblems: - MatrixInstruction: - [16, 16, 4, 1, 1, 1, 1, 1, 1] - [16, 16, 4, 1, 1, 2, 2, 2, 2] + - [16, 16, 4, 1, 1, 3, 3, 2, 2] + - [16, 16, 4, 1, 1, 5, 3, 2, 2] + - [16, 16, 4, 1, 1, 7, 7, 2, 2] - DepthU: [ 32 ] - LdsPadA: [-1] - LdsPadB: [-1] @@ -219,10 +222,14 @@ BenchmarkProblems: - MatrixInstruction: - [16, 16, 128, 1, 1, 1, 1, 1,1 ] - [32, 32, 64, 1, 1, 1, 1, 1,1 ] + - [32, 32, 64, 1, 1, 5, 1, 1,1 ] + - [32, 32, 64, 1, 1, 3, 7, 4,1 ] - [16, 16, 128, 1, 1, 4, 4, 2,2 ] + - [16, 16, 128, 1, 1, 6, 5, 2,2 ] - [32, 32, 64, 1, 1, 2, 2, 2,2 ] - [32, 32, 64, 1, 1, 3, 2, 2,2 ] - [32, 32, 64, 1, 1, 2, 3, 2,2 ] + - [32, 32, 64, 1, 1, 3, 3, 2,2 ] - [16, 16, 128, 1, 1, 3, 2, 2,2 ] - [16, 16, 128, 1, 1, 2, 3, 2,2 ] - WorkGroup: @@ -273,6 +280,7 @@ BenchmarkProblems: - MatrixInstruction: - [16, 16, 128, 1, 1, 1, 1, 1,1 ] - [32, 32, 64, 1, 1, 1, 1, 1,1 ] + - [32, 
32, 64, 1, 1, 3, 3, 1,1 ] - WorkGroup: - [16,16,1] - GlobalReadVectorWidthA: [16] @@ -446,8 +454,8 @@ BenchmarkProblems: - [16, 16, 128, 1, 1, 4, 4, 2,2 ] - [32, 32, 64, 1, 1, 2, 2, 2,2 ] - [32, 32, 64, 1, 1, 3, 2, 2,2 ] - - [32, 32, 64, 1, 1, 2, 3, 2,2 ] - - [16, 16, 128, 1, 1, 3, 2, 2,2 ] + - [32, 32, 64, 1, 1, 5, 3, 2,2 ] + - [16, 16, 128, 1, 1, 3, 7, 2,2 ] - [16, 16, 128, 1, 1, 2, 3, 2,2 ] - WorkGroup: - [16,16,1] @@ -516,6 +524,7 @@ BenchmarkProblems: - [16, 16, 32, 1, 1, 8, 8, 2, 2 ] - [16, 16, 32, 1, 1, 1, 1, 2, 2 ] - [16, 16, 32, 1, 1, 2, 3, 2, 2 ] + - [16, 16, 32, 1, 1, 7, 5, 2, 2 ] - [32, 32, 16, 1, 1, 3, 2, 2, 2 ] - PrefetchGlobalRead: [1, 2] - PrefetchLocalRead: [1, 2] @@ -633,6 +642,7 @@ BenchmarkProblems: - [16, 16, 32, 1, 1, 8, 8, 2, 2 ] - [16, 16, 32, 1, 1, 1, 1, 2, 2 ] - [16, 16, 32, 1, 1, 2, 3, 2, 2 ] + - [16, 16, 32, 1, 1, 6, 5, 2, 2 ] - [32, 32, 16, 1, 1, 3, 2, 2, 2 ] - PrefetchGlobalRead: [1, 2] - PrefetchLocalRead: [1, 2] @@ -649,6 +659,7 @@ BenchmarkProblems: - LdsPadB: [0, 8] - 1LDSBuffer: [-1] - GlobalSplitU: [1, 2] + - LDSTrInst: [0,1] - SourceSwap: [1] BenchmarkJoinParameters: BenchmarkFinalParameters: @@ -902,6 +913,7 @@ BenchmarkProblems: - MatrixInstruction: - [32, 32, 16, 1, 1, 4, 4, 2, 2 ] - [16, 16, 32, 1, 1, 8, 8, 2, 2 ] + - [16, 16, 32, 1, 1, 7, 7, 2, 2 ] - PrefetchGlobalRead: [1, 2] - PrefetchLocalRead: [1, 2] - DepthU: [64] @@ -916,6 +928,7 @@ BenchmarkProblems: - LdsPadB: [8] - 1LDSBuffer: [-1] - SourceSwap: [1] + - LDSTrInst: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: @@ -996,6 +1009,7 @@ BenchmarkProblems: - 1LDSBuffer: [-1] - GlobalSplitU: [1] - SourceSwap: [1] + - LDSTrInst: [0,1] BenchmarkJoinParameters: BenchmarkFinalParameters: - ProblemSizes: @@ -1051,3 +1065,58 @@ BenchmarkProblems: BenchmarkFinalParameters: - ProblemSizes: - Range: [[4096], [8192], [1], [64,64,192]] + + + ######################################## + # BBS TN - checks tail loop padding logic + 
######################################## + - + - # ProblemType + OperationType: GEMM + DataType: b + DestDataType: b + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + ActivationFuncCall: True + - # BenchmarkProblemSizeGroup - Standard - All problem + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 2, 1, 4, 1 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DepthU: [256] + - ScheduleIterAlg: [3] + - ExpandPointerSwap: [0] + - TransposeLDS: [2] #0,1 + - LocalReadVectorWidth: [8] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [2] + - ClusterLocalRead: [0]#[0,1] + - DirectToLds: [1] + - StreamK: [3] + - LdsPadA: [0] #[-1] + - LdsPadB: [8] #[-1] + - StaggerU: [0] + - WorkGroupMapping: [16] + - WorkGroupMappingXCC: [2] + - 1LDSBuffer: [0] + - NonTemporalD: [4] + - SourceSwap: [1] + - UseSgprForGRO: [0] + - UseCustomMainLoopSchedule: [0] + - LDSTrInst: [1] + - VectorWidthA: [2] + - VectorWidthB: [1] + - WaveSeparateGlobalReadA: [0] + - WaveSeparateGlobalReadB: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [128, 128, 1, 255] \ No newline at end of file