Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion projects/hipblaslt/tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ class StateValues:
doPackPreSchedulingThisLoop: bool = False
doPackPreSchedulingNextLoop: bool = False
doFullPackCodePrefetch: bool = False
useCommonSgprSwap: bool = False

# Epilogue states
preloadScaleA = False
Expand Down Expand Up @@ -4921,6 +4922,15 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB):
self.states.oneBufferScheduling = (kernel["1LDSBuffer"]) or \
((kernel["DirectToLdsA"] or kernel["DirectToLdsB"]) and \
self.states.numLDSBlk == kernel["PrefetchGlobalRead"])
# common sgprSwap
# enable it for the following case.
# DTLA+B + numLDSBlk==2 + (not EPS) + (not sparse) + (MX or (StoreSwapAddr and (not CMS))
self.states.useCommonSgprSwap = False
if kernel["DirectToLds"] == 1 and self.states.numLDSBlk == 2 and \
(not kernel["ExpandPointerSwap"]) and (not kernel["ProblemType"]["Sparse"]) and \
((kernel["ProblemType"]["MXBlockA"] or kernel["ProblemType"]["MXBlockB"]) or \
kernel["StoreSwapAddr"] and not kernel["UseCustomMainLoopSchedule"]):
self.states.useCommonSgprSwap = True

# NamedTuple is immutable
class intermediateTPValues(NamedTuple):
Expand Down Expand Up @@ -6655,7 +6665,9 @@ def checkVregOverflowTF32Emu(vgprIdx, numV):
self.defineSgpr("LocalWriteAddrMXSB", 1)

# Allocate registers to swap between lds buffers
if kernel["StoreSwapAddr"]:
if self.states.useCommonSgprSwap:
self.defineSgpr("SwapCommon", 1)
elif kernel["StoreSwapAddr"]:
if kernel["LocalWriteUseSgprA"]:
self.defineSgpr("SwapA", 1)
if kernel["LocalWriteUseSgprB"]:
Expand Down
49 changes: 38 additions & 11 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,7 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module:
module.add(ValueSet("MTOffset", reductionOffsetLow32, format=1))
module.add(ValueSet("MTOffsetH32", reductionOffsetHigh32, format=1))

if self.states.IncLdsBufSwitch:
if self.states.IncLdsBufSwitch or self.states.useCommonSgprSwap:
module.addComment0("%d LDS Blocks for PGR %d"%(self.states.numLDSBlk, kernel["PrefetchGlobalRead"]))
module.add(ValueSet("LdsOneBlockSize", kernel["LdsOffsetA_Blk"]))
module.add(ValueSet("LdsBlockEndSize", kernel["LdsOffsetA_Blk"] * self.states.numLDSBlk))
Expand Down Expand Up @@ -4620,13 +4620,18 @@ def lwaFirstOffset(self, kernel, tP):
self.vgprPool.checkIn(tmpv)
self.vgprPool.checkIn(destVgpr)

if kernel["StoreSwapAddr"]:
if kernel["StoreSwapAddr"] or self.states.useCommonSgprSwap:
if kernel["LocalWriteUseSgpr%s"%tc]:
# needed for the VReadfirstlaneB32 in the prior code block
if self.states.archCaps["CrosslaneWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
module.add(SAddU32(dst=sgpr("Swap%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), src1=kernel["LdsOffsetA_Blk"], comment="Calculate starting lds addr of second buffer"))
module.add(SXorB32(dst=sgpr("Swap%s"%tc), src0=sgpr("Swap%s"%tc), src1=sgpr("LocalWriteAddr%s"%tc), comment="xor both lds buffer offsets to enable swapping"))
if self.states.useCommonSgprSwap:
# Need only once. Generate the code for "A" only
if tc == "A":
module.add(SMovB32(dst=sgpr("SwapCommon"), src=0, comment="Initialize SwapCommon"))
else:
# needed for the VReadfirstlaneB32 in the prior code block
if self.states.archCaps["CrosslaneWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
module.add(SAddU32(dst=sgpr("Swap%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), src1=kernel["LdsOffsetA_Blk"], comment="Calculate starting lds addr of second buffer"))
module.add(SXorB32(dst=sgpr("Swap%s"%tc), src0=sgpr("Swap%s"%tc), src1=sgpr("LocalWriteAddr%s"%tc), comment="xor both lds buffer offsets to enable swapping"))
else:
module.add(VAddU32(dst=vgpr("LocalWriteSwapAddr%s"%tc), src0=kernel["LdsOffsetA_Blk"], src1=vgpr("LocalWriteAddr%s"%tc), \
comment="starting lds addr of second buffer" ))
Expand Down Expand Up @@ -9392,6 +9397,9 @@ def directToLdsM0Update(self, kernel, mode, tP, skipWait = False):
if self.states.IncLdsBufSwitch:
DtldsModule.add(SAddU32(dst=mgpr(0), src0=sgpr("LocalWriteAddr%s"%tc), \
src1=sgpr("LDSBufferWriteInc"), comment="m0 <- LDS write address (base + inc)"))
elif self.states.useCommonSgprSwap:
DtldsModule.add(SAddU32(dst=mgpr(0), src0=sgpr("LocalWriteAddr%s"%tc), \
src1=sgpr("SwapCommon"), comment="m0 <- LDS write address (base + inc)"))
elif kernel["ExpandPointerSwap"]:
DtldsModule.add(SAddU32(dst=mgpr(0), src0=sgpr("LocalWriteAddr%s"%tc), \
src1=tP["localWriteSwapByteOffset"], comment="m0 <- LDS write address"))
Expand Down Expand Up @@ -9828,6 +9836,15 @@ def localWriteAddRound(tc):
module.add(SCmpEQU32(src0=sgpr("LDSBufferWriteInc"), src1="LdsBlockEndSize", comment="LDSBufferWriteInc == End ?"))
module.add(SCMovB32(dst=sgpr("LDSBufferWriteInc"), src=0, comment="LDSBufferWriteInc loop back to 0"))

def localWriteSwapCommon(tc):
is1st = tc == "A" # so far, A is always first
if is1st:
module.add(SXorB32(
dst=sgpr("SwapCommon"), \
src0="LdsOneBlockSize", \
src1=sgpr("SwapCommon"), \
comment="xor LDS block size"))

def getSrc0Val(tc):
src0Val = None
if kernel["StoreSwapAddr"]:
Expand All @@ -9853,6 +9870,10 @@ def getSrc0Val(tc):
# 3 or more LDS block case, we do not use xor. Instead, use add and max check for round back
# (numLDSBlk>=3 is for DTL (and LocalWriteUseSgpr) only)
localWriteAddRound(tc)
elif self.states.useCommonSgprSwap:
# commonSwap case, need only 1 swap
# (generate at "A" only)
localWriteSwapCommon(tc)
else:
src0Val = getSrc0Val(tc)
numLwa = self.states.a.numVgprLocalWriteAddr if tP["isA"] else self.states.b.numVgprLocalWriteAddr
Expand Down Expand Up @@ -9902,8 +9923,7 @@ def localWriteResetOffsets(self, kernel, internalPointerSwap, tP):
module = Module("localWriteResetOffsets")
if needReset:
resetMask = hex(kernel["LdsOffsetA_Blk"]-1 | self.consts.ldsOOB)
# MXSA/MXSB do not use swap addresses, use else branch instead
useSwapAddr = (internalPointerSwap or kernel["StoreSwapAddr"]) and tc not in ("MXSA", "MXSB")
useSwapAddr = (internalPointerSwap or (kernel["StoreSwapAddr"] and not self.states.useCommonSgprSwap))
if useSwapAddr:
if internalPointerSwap:
tP["localWriteSwapByteOffset"] = 0
Expand Down Expand Up @@ -9938,6 +9958,14 @@ def localWriteResetOffsets(self, kernel, internalPointerSwap, tP):
dst=sgpr("LDSBufferWriteInc"), \
src=0, \
comment="reset incSgpr"))
elif self.states.useCommonSgprSwap:
# commonSwap case, back to 0
# (generate at "A" only)
if tc == "A":
module.add(SMovB32(
dst=sgpr("SwapCommon"), \
src=0, \
comment="reset swapCommon"))
else:
if kernel["LocalWriteUseSgpr%s"%tc]:
module.add(SAndB32(
Expand Down Expand Up @@ -10919,8 +10947,7 @@ def localReadResetOffsets(self, kernel, tP):
dst=vgpr("LocalReadAddr%s"%(tc)), \
src=vgpr("LocalReadAddrOrig%s"%(tc)), \
comment="set LocalReadAddrOrig to LocalReadAddr"))
# MXSA/MXSB do not use swap addresses
elif kernel["StoreSwapAddr"] and tc not in ("MXSA", "MXSB"):
elif kernel["StoreSwapAddr"]:
# Reset offset, by picking smaller of the two
tmpvgpr = self.vgprPool.checkOut(1) # contains other offsets
module.add(VXorB32(
Expand Down
Loading