Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,28 @@ def ceilDivide(numerator, denominator):

def roundUpToNearestMultiple(numerator, denominator):
return ceilDivide(numerator,denominator)*int(denominator)


# Given a divisor, this routine computes the corresponding multiplicative constant
# and required post shifts.
#
# Algorithm based on: https://dl.acm.org/doi/pdf/10.1145/178243.178249
#
# Inputs:
# d: divisor
# N: Number of bits integers are represented in
# p: precision in bits (usually N = P)
#
# Output:
# mhigh: multiplicative constant
# shPost: amount to right shift after multiplication
def choose_multiplier(d, N, p):
l = int(math.ceil(math.log(d, 2)))
shPost = l
mlow = 2**(N+l) // d
mhigh = (2**(N+l) + 2 ** (N + l - p )) // d
while ((mlow // 2) < (mhigh // 2)) and shPost > 0:
mlow //= 2
mhigh //= 2
shPost -=1
return mhigh, shPost, l
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def addToStream(key, indexList, InstructionList):
InstStreams = convOptToStream(opt1)

macro = Macro("MAINLOOP", ["ID", "useGR=1", "usePLR=1", "useGRInc=1", "useLoop=1"])
#module.add(SBarrier(comment="debug"))
#macro.add(SBarrier(comment="debug"))

lastIter = numLoopIter - 1

Expand Down Expand Up @@ -472,43 +472,43 @@ def hasCustomSchedule(kernel):
optSchedule = dict()
syncCode = []
if isNN and useLDSTr and TLDS==1:
# TODO: This schedule can be improved when BC are resolved for MT192
# Note: A/B Global read orders are swapped
# i.e. GRA contains GR for B
kernel["SwapGlobalReadOrder"] = True
optSchedule = {
'SYNC' : [[20,21,23,25,27,29,31,33,46,57,58,94],
[20,21,24,26,28,30,32,34,47,58,58,94]],
'GRIncA' : [[0,1,2,3,4,5,6,7,8]],
'GRIncB' : [[9,10,11,12,13,14,15,16,17]],
'LRB0' : [[0,0,1,1,2,2,6,8],
[3,3,4,4,5,5,7,9]],
'LRA0' : [[10,12,14,16,18,23,35,37,39,41,43,45],
[11,13,15,17,19,22,36,38,40,42,44,46]],
'LWA' : [[23,25,27,29,31,33],
[24,26,28,30,32,34]],
'GRA' : [[22,22,24,24,26,26,28,28,30,30, 42,42,43,43,45,45],
[23,23,25,25,27,27,29,29,31,31, 43,43,44,44,46,46]],
'GRB' : [[54, 56, 58, 60, 62, 64],
[55, 57, 59, 61, 63, 65]],
'LRSA' : [[47]],
'LRSB' : [[37]],
'LWSB' : [[47]], # For B
'LWSA' : [[52]], # For A
'LRB1' : [[59,59,61,61,63,63,65,67],
[60,60,62,62,64,64,66,68]],
'LRA1' : [[69,71,73,75,77,79,81,83,85,85,87,87],
[70,72,74,76,78,80,82,84,86,86,88,88]],
'SYNC' : [[12,13, 47,48,49,50,51, 52,53, 56,56, 94]],
'GRIncB' : [[0,1,2,3,4,5,6,7,8]],
'GRIncA' : [[9,10,11,12,13,14,15,16,17]],
'LRB0' : [[0,0,1,1,2,2,6,8],
[3,3,4,4,5,5,7,9]],
# These local reads have BC
'LRA0' : [[10, 15,17,19,21,23, 25,27,29,33,37,39],
[11, 14,16,18,20,22, 24,26,28,32,36,38]],
'GRA' : [[14,14, 16,16, 18,18, 20,20, 22,22, 34,34,36,36,38,38],
[15,15, 17,17, 19,19, 21,21, 23,23, 35,35,37,37,39,39]],
'GRB' : [[54,54, 56,56, 58,58, 60,60, 62,62, 64,64],
[55,55, 57,57, 59,59, 61,61, 63,63, 65,65]],
'LRSA' : [[40]],
'LRSB' : [[40]],
'LWSB' : [[41]], # For B
'LWSA' : [[66]], # For A
'LRB1' : [[57,57,59,59,61,61,63,65],
[58,58,60,60,62,62,64,64]],
'LRA1' : [[67,71,73,75,77,79,81,85,87,89,91,93],
[68,72,74,76,78,80,82,86,88,90,92,94]],
'LCC' : [[95, 95]],
}
syncCode = [SWaitCnt(dscnt=5, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"),
syncCode = [SWaitCnt(dscnt=1, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"),
SBarrier(comment=""),
SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=-1, vlcnt=5, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=-1, vlcnt=10, vscnt=-1, comment="Wait for LRB0 to complete"),
SWaitCnt(dscnt=10, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"),
SWaitCnt(dscnt=8, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"),
SWaitCnt(dscnt=6, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"),
SWaitCnt(dscnt=4, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"),
SWaitCnt(dscnt=2, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"),
SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for LRA0 to complete"),
SBarrier(comment=""),
SWaitCnt(dscnt=-1, vlcnt=9, vscnt=-1, comment="Wait for LRB0 to complete"),
SBarrier(comment=""),
SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for LRB0 to complete"),]

Expand Down
30 changes: 25 additions & 5 deletions projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
numElementPerRead = 1 if kernel["ConvertAfterDS"] and not kernel["UseF32XEmulation"] else (int(blockWidth * bpr) // tP['bpe'] // lrvwTile)
inputPerThread = kernel["LocalReadVectorWidth"] if not writer.states.inTailLoop else kernel["MIInputPerThread%s"%tc]

abmatrixinfo = writer.states.a if tc == 'A' else writer.states.b
perpStride = abmatrixinfo.gNLCPerpStride

# pack register
if writer.states.archCaps["HasEccHalf"] or not writer.states.asmCaps["HasWMMA_V1"]:
needPack = tP["bpeDS"] < 4 and not kernel["UnrollMajorLDS%s"%tc] and not tP["isM"]
Expand Down Expand Up @@ -305,16 +308,27 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
LocalReadX = instruction.getInst(highBits)

offset_val = (tP["localReadOffset"]+MIWaveGroupShape[tile01]*tIdx) * tP["bpeDS"] + tP["localReadSwapByteOffset"]
if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0):
offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"]

def applyPad(offset_val):
if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0):
offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"]
return offset_val

for oIdx in range(0,numOffsetsPerLoad):

offset, srcAddr = self.cal_offset_srcAddr(maxLDSConstOffset, tc, offset_val)
offset = applyPad(offset)
ds = DSModifiers(na=1, offset=offset)
destVgpr = vgpr("Valu%s_X%u_I%u+%u+%u"%(tc,bufferIdx,iui, 4*tIdx, oIdx * 2), 2)
localReadCode = Module("LocalRead%s Valu%u"%(tc,valuiIdx))
localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment))
offset_val += UnrollStride*inputPerThread
if perpStride == 1:
offset_val += UnrollStride*inputPerThread
else:
permBlock = kernel["MatrixInstK"]
perpStrideInv = permBlock // perpStride
inv4K = perpStrideInv * (4 % perpStride) + 4 // perpStride
offset_val += inv4K * kernel["MacroTile%s"%tc] * tP["bpeDS"]
if ((subTileIdx == 0 and subIterLoadCount < totalLoads // numSubTiles) \
or (subTileIdx == 1 and subIterLoadCount >= totalLoads // numSubTiles) \
or numSubTiles == 1) or writer.states.inTailLoop:
Expand Down Expand Up @@ -762,7 +776,12 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
paramList = []

for oIdx in range(0, numOffsets):
offset_val = (eIdx + (vIdx * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride
if perpStride > 1 and kernel["ProblemType"]["TLU%s"%tc] == 0:
permBlock = kernel["MatrixInstK"] if kernel["ProblemType"]["TLU%s"%tc] == 1 else kernel["VectorWidth%s"%tc] * kernel["MatrixInstM"]
perpStrideInv = permBlock // perpStride
offset_val = (eIdx * (perpStrideInv) + ((vIdx) * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride
else:
offset_val = (eIdx + (vIdx * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride

if kernel["ProblemType"]["Sparse"] != 0:
if blocksPerTGroupSMFMA > 1:
Expand Down Expand Up @@ -815,8 +834,9 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0):
offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"]
offset_val = offset_val + tP["localReadSwapByteOffset"]
# TODO: Add NLC>1 offset calcs here?
if (kernel["DirectToLds%s" % tc] and \
kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4):
kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]:
# another address conversion for DirectToLds + NumLoadsCoalesced > 1
dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
#
################################################################################

from rocisa.code import Module
from rocisa.code import Module, Label
from rocisa.container import vgpr, ContinuousRegister
from rocisa.instruction import VAddU32
from rocisa.instruction import VAddU32, VAndB32, VLShiftLeftB32, VLShiftRightB32
from rocisa.functions import vectorStaticRemainder, \
vectorStaticDivideAndRemainder, vectorStaticDivide, vectorStaticMultiply, \
vectorStaticMultiplyAdd

from ..Component import LraTileAssignment, LraTileProperties
from ..Common import roundUp, log2, ceilDivide
from dataclasses import dataclass

@dataclass
Expand Down Expand Up @@ -155,7 +156,6 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi
mReg = writer.vgprPool.checkOut(1,"mReg") # remainder

isWmma_v1 = writer.states.asmCaps["HasWMMA_V1"]

# get constant parameter
tc = tP["tensorChar"]
tile01 = tP["tile01Idx"]
Expand Down Expand Up @@ -191,6 +191,9 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi
dividedForWaveId = dividedForWaveId, \
vectorWidth=vectorWidth, \
maxKId=maxKId)
abmatrixinfo = writer.states.a if tc == 'A' else writer.states.b
perpStride = abmatrixinfo.gNLCPerpStride
permBlock = abmatrixinfo.gNLCPermBlock

# strider for each type of index
umlds = kernel["UnrollMajorLDS%s" % tc]
Expand All @@ -204,6 +207,8 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi

strideK = inputPerThread if umlds else (mt + LdsPad) * inputPerThread
if enableLDSTr:
if kernel["UseGeneralizedNLCOne%s"%tc] and perpStride > 1:
strideK = 8
strideK1 = mt+LdsPad

# FIXME SPARSE
Expand Down Expand Up @@ -234,6 +239,25 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi
if isDTVAB:
strideTile = 1 # DTV case. Actual stride will be applied later.

def perpPerm(vgprReg):
reMap0 = writer.vgprPool.checkOut(1)
reMap1 = writer.vgprPool.checkOut(1)
perpStrideInv = permBlock // perpStride

module.addComment0("Computing strided(%u) perp indicies"%perpStrideInv)
module.add(VAndB32(dst=vgpr(reMap0), src0=(permBlock // perpStrideInv - 1), src1=vgpr(vgprReg), comment="r0 = I %% (%u // %u)"%(permBlock, perpStrideInv)))
module.add(VLShiftLeftB32(dst=vgpr(reMap0), shiftHex=log2(perpStrideInv), src=vgpr(reMap0), comment="r0 = %u * r0"%(perpStrideInv)))
module.addComment0("Computing r1 = (I %% %u) // (%u // %u)"%(permBlock, permBlock, perpStrideInv))
module.add(VAndB32(dst=vgpr(reMap1), src0=(permBlock - 1), src1=vgpr(vgprReg), comment="r1 = I %% (%u)"%(permBlock)))
module.add(VLShiftRightB32(dst=vgpr(reMap1), shiftHex=log2(permBlock // perpStrideInv), src=vgpr(reMap1), comment="r1 = (r1) // (%u // %u)"%(permBlock, perpStrideInv)))
module.add(VAddU32(dst=vgpr(reMap0), src0=vgpr(reMap0), src1=vgpr(reMap1), comment="r0 = r0 + r1" ))

module.add(VLShiftRightB32(dst=vgpr(reMap1), shiftHex=log2(permBlock), src=vgpr(vgprReg), comment="r1 = I // %u"%(permBlock)))
module.add(vectorStaticMultiplyAdd(vgpr(vgprReg), vgpr(reMap1), permBlock, vgpr(reMap0), None))

module.addComment0("Done computing strided(%u) perp indices"%perpStrideInv)
writer.vgprPool.checkIn(reMap0)
writer.vgprPool.checkIn(reMap1)

with writer.allocTmpSgpr(1) as tmpSgprInfo:
# tile offset
Expand All @@ -245,13 +269,21 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi
module.add(vectorStaticRemainder(dummy, sReg, kReg, dividendForKId, tmpVgprRes, tmpSgprInfo, \
"1. N offset: nIdx = wtid %% MI_M(%d)"%dividendForKId))
module.add(vectorStaticDivide(sReg, sReg, 16, tmpVgprRes, \
"1. thread id in wave: k1Idx = mtid // 4"))
"1. thread id in wave: k1Idx = mtid // 16"))
module.add(vectorStaticMultiply(vgpr(sReg), vgpr(sReg), 16, tmpSgprInfo, \
"1. K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1)))

else:
module.add(vectorStaticRemainder(dummy, tReg, kReg, kernel["MatrixInstN"], tmpVgprRes, tmpSgprInfo, \
"1. N offset: nIdx = wtid %% MI_N(%u)" % kernel["MatrixInstN"]))

applyVWCalcEarly = perpStride > 1 and kernel["ProblemType"]["TLU%s"%tc] == 0
if applyVWCalcEarly:
# Apply vector width calc before we apply permutation to perp dim
module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), vectorWidth, tmpSgprInfo, \
"1. apply VectorWidth: bnOffset = bnOffset * vw(%u)" % vectorWidth))
perpPerm(tReg)

module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), strideTile, tmpSgprInfo, \
"1. N offset: nOffset = nIdx * nStride(%u)" % strideTile))
if enableLDSTr:
Expand All @@ -268,8 +300,9 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi
else:
module.addComment0("Skip. 2. block offset: bnOffset = 0 when num1DBlocks = 1")

module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), vectorWidth, tmpSgprInfo, \
"4. apply VectorWidth: bnOffset = bnOffset * vw(%u)" % vectorWidth))
if not applyVWCalcEarly:
module.add(vectorStaticMultiply(vgpr(tReg), vgpr(tReg), vectorWidth, tmpSgprInfo, \
"4. apply VectorWidth: bnOffset = bnOffset * vw(%u)" % vectorWidth))

# unroll offset
#if isMfma and (dividendForKId != waveWidth):
Expand All @@ -285,13 +318,24 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi
module.add(vectorStaticDivide(kReg, kReg, dividendForKId, tmpVgprRes, \
"5. K offset: kIdx = wtid / (MIN(%u) * MIBB(%u))" % (kernel["MatrixInstN"], kernel["MatrixInstB"])))
if (dividendForKId != waveWidth) and (not isDTVAB):

if enableLDSTr:
module.add(vectorStaticMultiply(vgpr(kReg), vgpr(kReg), strideK, tmpSgprInfo, \
"5. K offset: lrKOffset = kIdx * mStride(%u)" % (strideK)))
module.add(vectorStaticMultiply(vgpr(mReg), vgpr(mReg), strideK1, tmpSgprInfo, \
"5.1 K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1)))
module.add(VAddU32(dst=vgpr(kReg), src0=vgpr(mReg), src1=vgpr(kReg), \
comment="5.2 offset in wave: lrOffset = bnOffset + lrKOffset"))

if perpStride == 1:
module.add(vectorStaticMultiply(vgpr(mReg), vgpr(mReg), strideK1, tmpSgprInfo, \
"5.1 K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1)))
module.add(VAddU32(dst=vgpr(kReg), src0=vgpr(mReg), src1=vgpr(kReg), \
comment="5.1 offset in wave: lrOffset = bnOffset + lrKOffset"))
else:
module.add(VAddU32(dst=vgpr(kReg), src0=vgpr(mReg), src1=vgpr(kReg), \
comment="5.1 offset in wave: lrOffset = bnOffset + lrKOffset"))
# Apply permutation to perpendicular dim
if perpStride > 1:
perpPerm(kReg)
module.add(vectorStaticMultiply(vgpr(kReg), vgpr(kReg), strideK1, tmpSgprInfo, \
"5.2 K1 offset: lrK1Offset = k1Idx * mStride(%u)" % (strideK1)))
module.add(VAddU32(dst=vgpr(tReg), src0=vgpr(kReg), src1=vgpr(tReg), \
comment="6. offset in wave: lrOffset = bnOffset + lrKOffset"))
else:
Expand Down
Loading
Loading