Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 134 additions & 93 deletions projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,12 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
tc = tP["tensorChar"]
if tc == "A":
writer.states.localReadDoCntA += 1
subTileIdx = writer.states.SubTileIdxA
elif tc == "Metadata":
writer.states.localReadDoCntMetadata += 1
else:
writer.states.localReadDoCntB += 1
subTileIdx = writer.states.SubTileIdxB
tile01 = tP["tile01Idx"]
instruction = tP["localReadInstruction"]
bpr = 4 # bytes/register
Expand All @@ -175,6 +177,8 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):

vectorWidth = kernel["VectorWidth%s"%tc]

numSubTiles = kernel["numSubTiles%s"%tc]

MIWaveGroupShape = [ kernel["MatrixInstM"] * kernel["MatrixInstBM"] * kernel["MIWaveGroup"][0] * kernel["VectorWidthA"], \
kernel["MatrixInstN"] * kernel["MatrixInstBN"] * kernel["MIWaveGroup"][1] * kernel["VectorWidthB"]]

Expand Down Expand Up @@ -257,20 +261,59 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
ds = DSModifiers(na=1, offset=paramList[1])
localReadCode.add(LocalReadX(dst=destVgpr, src=vgpr("LocalReadAddr%s"%tc), ds=ds, comment=comment))
else:
for vIdx in range(0, numVectorsPerTile):
for eIdx in range(0, numReadsPerVector):
eIdxCnt = numReadsPerVector
eIdxStart = 0
vIdxStart = 0
vIdxCnt = numVectorsPerTile
valufIdx = 0
if numVectorsPerTile == 1:
eIdxCnt = numReadsPerVector//numSubTiles
eIdxStart = subTileIdx * (numReadsPerVector//numSubTiles)
else:
vIdxStart = subTileIdx * (numVectorsPerTile//numSubTiles)
vIdxCnt = numVectorsPerTile//numSubTiles

# Calculate total number of local read instructions needed
totalLocalReads = numVectorsPerTile * numReadsPerVector
localReadsPerSubTile = totalLocalReads // numSubTiles
remainderReads = totalLocalReads % numSubTiles

# Adjust for this subTile
if subTileIdx < remainderReads:
localReadsForThisSubTile = localReadsPerSubTile + 1
else:
localReadsForThisSubTile = localReadsPerSubTile

# Calculate starting position for this subTile
startReadIdx = subTileIdx * localReadsPerSubTile + min(subTileIdx, remainderReads)

# Use while loop to generate the correct number of local read instructions
readCount = 0
while readCount < localReadsForThisSubTile:
# Calculate current vIdx and eIdx from the global read index
globalReadIdx = startReadIdx + readCount
vIdx = globalReadIdx // numReadsPerVector
eIdx = globalReadIdx % numReadsPerVector
# Calculate valufIdx based on current vIdx and eIdx
if numVectorsPerTile == 1:
valufIdx = eIdx * blockWidth * numReadsPerUnroll
else:
valufIdx = (vIdx * numReadsPerVector + eIdx) * blockWidth * numReadsPerUnroll

valuiIdx = int(valufIdx)
baseValuiIdx = valuiIdx
localReadCode = imod.add(Module("LocalRead%s Valu%u"%(tc,valuiIdx)))
readCount += 1
if needPack or numSplitMetadata:
packCode = pack.add(Module("packCode"))
for rIdx in range(0, numReadsPerUnroll):
valuiIdx = int(valufIdx)
localReadCode = imod.add(Module("LocalRead%s Valu%u"%(tc,valuiIdx)))
if needPack or numSplitMetadata:
packCode = pack.add(Module("packCode"))
for rIdx in range(0, numReadsPerUnroll):
valuiIdx = int(valufIdx)
baseLRVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx), numVgpr)
destVgpr = baseLRVgpr
highBitsForHalf = (blockWidth == 0.5) and ((rIdx % 2) == 1) # rIdx = 1
isHigh16Bits = (blockWidth == 0.25) and ( ((rIdx % 4) //2) == 1) # 2,3
baseLRVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx), numVgpr)
destVgpr = baseLRVgpr
highBitsForHalf = (blockWidth == 0.5) and ((rIdx % 2) == 1) # rIdx = 1
isHigh16Bits = (blockWidth == 0.25) and ( ((rIdx % 4) //2) == 1) # 2,3

if needPack or numSplitMetadata:
if needPack or numSplitMetadata:
if kernel["ConvertAfterDS"] and (tP["bpe"] != tP["bpeDS"]):
highBitsForHalf = False
isHigh16Bits = False
Expand Down Expand Up @@ -547,26 +590,26 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
if isHigh8Bits and isHigh16Bits:
packCode.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=highVgpr, shiftHex=hex(0x8), src1=baseLRVgpr, comment="pack two int8x2 Vgpr to one Vgpr"))

if kernel["ConvertAfterDS"] and kernel["UnrollMajorLDS%s"%tc]:
if kernel["ConvertAfterDS"] and kernel["UnrollMajorLDS%s"%tc]:
valufIdx += blockWidth * (tP["bpe"] // tP["bpeDS"]) if (not tP["isM"]) else 1
else:
valufIdx += blockWidth if (not tP["isM"]) else 1
else:
valufIdx += blockWidth if (not tP["isM"]) else 1

# load read instruction
paramList = []
# load read instruction
paramList = []

for oIdx in range(0, numOffsets):
offset_val = (eIdx + (vIdx * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride
for oIdx in range(0, numOffsets):
offset_val = (eIdx + (vIdx * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride

if kernel["ProblemType"]["Sparse"] != 0:
if blocksPerTGroupSMFMA > 1:
blockId = (rIdx * numElementPerRead) // elementsPerBlockSMFMA #block 0 or block 1
if kernel["UnrollMajorLDS%s"%(tc)]:
offset_val = offset_val + (blockOffsetSMFMA * blockId)
else:
offset_val = offset_val + (blockOffsetSMFMA * blockId) * UnrollStride
offset_val = (rIdx * numElementPerRead * UnrollStride + offset_val + tP["localReadOffset"]) * tP["bpeDS"]
elif kernel["ProblemType"]["DataType"].is8bitFloat() and kernel["MatrixInstK"] > 32:
if kernel["ProblemType"]["Sparse"] != 0:
if blocksPerTGroupSMFMA > 1:
blockId = (rIdx * numElementPerRead) // elementsPerBlockSMFMA #block 0 or block 1
if kernel["UnrollMajorLDS%s"%(tc)]:
offset_val = offset_val + (blockOffsetSMFMA * blockId)
else:
offset_val = offset_val + (blockOffsetSMFMA * blockId) * UnrollStride
offset_val = (rIdx * numElementPerRead * UnrollStride + offset_val + tP["localReadOffset"]) * tP["bpeDS"]
elif kernel["ProblemType"]["DataType"].is8bitFloat() and kernel["MatrixInstK"] > 32:
incOffset = 0
midIdx = numReadsPerUnroll // 2
if rIdx >= midIdx:
Expand All @@ -583,84 +626,82 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP):
incOffset = 48
incOffset = rIdx * numElementPerRead * UnrollStride + incOffset
offset_val = (incOffset + offset_val + tP["localReadOffset"]) * tP["bpeDS"]
else:
offset_val = (rIdx * numElementPerRead * UnrollStride + offset_val + tP["localReadOffset"]) * tP["bpeDS"]

if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0):
offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"]
offset_val = offset_val + tP["localReadSwapByteOffset"]
if (kernel["DirectToLds%s" % tc] and \
kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4):
# another address conversion for DirectToLds + NumLoadsCoalesced > 1
dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val)
else:
offset_val = (rIdx * numElementPerRead * UnrollStride + offset_val + tP["localReadOffset"]) * tP["bpeDS"]

paramList.append(int(offset_val))
if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0):
offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"]
offset_val = offset_val + tP["localReadSwapByteOffset"]
if (kernel["DirectToLds%s" % tc] and \
kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4):
# another address conversion for DirectToLds + NumLoadsCoalesced > 1
dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val)

comment = "L -> Reg lro=%d swapByteOffset=%u ti=%u vIdx=%u eIdx=%u rIdx=%u oIdx=%u buffer=%u iui=%u" \
% (tP["localReadOffset"], tP["localReadSwapByteOffset"], MIWaveGroupShape[tile01], vIdx, eIdx, rIdx, oIdx, bufferIdx, iui)
paramList.append(int(offset_val))
comment = "L -> Reg lro=%d swapByteOffset=%u ti=%u vIdx=%u eIdx=%u rIdx=%u oIdx=%u buffer=%u iui=%u" \
% (tP["localReadOffset"], tP["localReadSwapByteOffset"], MIWaveGroupShape[tile01], vIdx, eIdx, rIdx, oIdx, bufferIdx, iui)

highBits = 0 if writer.states.archCaps["DSLow16NotPreserve"] else highBitsForHalf or isHigh16Bits
highBits = 0 if writer.states.archCaps["DSLow16NotPreserve"] else highBitsForHalf or isHigh16Bits

if(paramList[0] >=131072):
paramList[0] = paramList[0] -131072
srcAddr=vgpr("LocalReadAddr%s+2"%tc)
elif (paramList[0] >=65536):
paramList[0] = paramList[0] -65536
srcAddr=vgpr("LocalReadAddr%s+1"%tc)
else:
srcAddr=vgpr("LocalReadAddr%s"%tc)

if(paramList[0] >=131072):
paramList[0] = paramList[0] -131072
srcAddr=vgpr("LocalReadAddr%s+2"%tc)
elif (paramList[0] >=65536):
paramList[0] = paramList[0] -65536
srcAddr=vgpr("LocalReadAddr%s+1"%tc)
else:
srcAddr=vgpr("LocalReadAddr%s"%tc)
if numOffsets == 1:
ds = DSModifiers(na=1, offset=paramList[0])
else:
ds = DSModifiers(na=2, offset0=paramList[0], offset1=paramList[1])
LocalReadX = instruction.getInst(highBits)
localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment))
# TODO - handle vector-load
with writer.allocTmpSgpr(1) as tmpSgprInfo:
tmpSgpr = tmpSgprInfo.idx
if writer.db["CheckValue1%s"%tc] and not writer.inTailLoop:

if numOffsets == 1:
ds = DSModifiers(na=1, offset=paramList[0])
else:
ds = DSModifiers(na=2, offset0=paramList[0], offset1=paramList[1])
LocalReadX = instruction.getInst(highBits)
localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment))
# TODO - handle vector-load
with writer.allocTmpSgpr(1) as tmpSgprInfo:
tmpSgpr = tmpSgprInfo.idx
if writer.db["CheckValue1%s"%tc] and not writer.inTailLoop:
dbgVgpr = destVgpr
dbgVgprList = destVgpr.split("v[")
if len(dbgVgprList) == 1: # vIdx, no []
dbgVgpr = dbgVgprList[0]
else:
# We only check the first one now
# TODO: Handle vector, but need to take care the last one
dbgVgprList = (dbgVgprList[1].split("]")[0]).split(':')
dbgVgpr = "v[%s]"%dbgVgprList[0]

dbgVgpr = destVgpr
dbgVgprList = destVgpr.split("v[")
if len(dbgVgprList) == 1: # vIdx, no []
dbgVgpr = dbgVgprList[0]
else:
# We only check the first one now
# TODO: Handle vector, but need to take care the last one
dbgVgprList = (dbgVgprList[1].split("]")[0]).split(':')
dbgVgpr = "v[%s]"%dbgVgprList[0]
localReadCode.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="CheckValue1 wait for LDS read"))

localReadCode.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="CheckValue1 wait for LDS read"))
if kernel["ProblemType"]["DataType"].isHalf():
hexValue = hex(0x3c003c00) # packed 1s
if needPack:
hexValue = hex(0x3c000000) if highBitsForHalf else hex(0x00003c00)
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: FP16"))
localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr)))

if kernel["ProblemType"]["DataType"].isHalf():
hexValue = hex(0x3c003c00) # packed 1s
if needPack:
hexValue = hex(0x3c000000) if highBitsForHalf else hex(0x00003c00)
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: FP16"))
localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr)))
elif kernel["ProblemType"]["DataType"].isBFloat16():
hexValue = hex(0x3f803f80) # packed 1s
if needPack:
hexValue = hex(0x3f800000) if highBitsForHalf else hex(0x00003f80)
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: BF16"))
localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr)))

elif kernel["ProblemType"]["DataType"].isBFloat16():
hexValue = hex(0x3f803f80) # packed 1s
if needPack:
hexValue = hex(0x3f800000) if highBitsForHalf else hex(0x00003f80)
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: BF16"))
if kernel["ProblemType"]["DataType"].isInt8():
if needPack:
hexValue = hex(0x00010000) if isHigh16Bits else hex(0x00000001)
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: INT8"))
localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr)))

if kernel["ProblemType"]["DataType"].isInt8():
if needPack:
hexValue = hex(0x00010000) if isHigh16Bits else hex(0x00000001)
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: INT8"))
localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr)))

# TODO - Check if this works. But need this? MFMA would use INT8
elif kernel["ProblemType"]["DataType"].isInt8x4():
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(0x01010101), comment="CheckValue1: INT8x4"))
localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr)))
# TODO - Check if this works. But need this? MFMA would use INT8
elif kernel["ProblemType"]["DataType"].isInt8x4():
localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(0x01010101), comment="CheckValue1: INT8x4"))
localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr)))

elif kernel["ProblemType"]["DataType"].isSingle():
localReadCode.add(writer.assert_eq( dbgVgpr, 1.0) )
elif kernel["ProblemType"]["DataType"].isSingle():
localReadCode.add(writer.assert_eq( dbgVgpr, 1.0) )

# DTV case, do not return local read code. Return pack code only.
if (tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]:
Expand Down
Loading
Loading