Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5481,13 +5481,19 @@ def readWriteVectors(mat, vw, kernel):
self.states.a.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprA"] else 1 * self.states.rpla
if kernel["ProblemType"]["MXBlockA"]:
self.states.mxsa.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprMXSA"] else 1 * self.states.rpla
self.states.mxsa.numVgprLocalWriteSwapAddr = 0
self.states.b.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprB"] else 1 * self.states.rpla
if kernel["ProblemType"]["MXBlockB"]:
self.states.mxsb.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprMXSB"] else 1 * self.states.rpla
self.states.mxsb.numVgprLocalWriteSwapAddr = 0
self.states.m.numVgprLocalWriteAddr = 0 if kernel["ProblemType"]["Sparse"] and kernel["LocalWriteUseSgprMetadata"] else 1 * self.states.rpla
self.states.a.numVgprLocalReadSwapAddr = 0
self.states.b.numVgprLocalReadSwapAddr = 0
self.states.m.numVgprLocalReadSwapAddr = 0
if kernel["ProblemType"]["MXBlockA"]:
self.states.mxsa.numVgprLocalReadSwapAddr = 0
if kernel["ProblemType"]["MXBlockB"]:
self.states.mxsb.numVgprLocalReadSwapAddr = 0
self.states.a.numVgprLocalWriteSwapAddr = 0
self.states.b.numVgprLocalWriteSwapAddr = 0
self.states.m.numVgprLocalWriteSwapAddr = 0
Expand Down Expand Up @@ -5559,8 +5565,14 @@ def readWriteVectors(mat, vw, kernel):
if kernel["StoreSwapAddr"]:
if self.states.a.numVgprLocalReadAddr > 0:
self.states.a.numVgprLocalReadSwapAddr = 1
if kernel["ProblemType"]["MXBlockA"]:
if self.states.mxsa.numVgprLocalReadAddr > 0:
self.states.mxsa.numVgprLocalReadSwapAddr = 1
if self.states.b.numVgprLocalReadAddr > 0:
self.states.b.numVgprLocalReadSwapAddr = 1
if kernel["ProblemType"]["MXBlockB"]:
if self.states.mxsb.numVgprLocalReadAddr > 0:
self.states.mxsb.numVgprLocalReadSwapAddr = 1
if self.states.m.numVgprLocalReadAddr > 0:
self.states.m.numVgprLocalReadSwapAddr = 1
if not kernel["LocalWriteUseSgprA"] and self.states.a.numVgprLocalWriteAddr > 0:
Expand All @@ -5569,6 +5581,12 @@ def readWriteVectors(mat, vw, kernel):
self.states.b.numVgprLocalWriteSwapAddr = 1
if kernel["ProblemType"]["Sparse"] and not kernel["LocalWriteUseSgprMetadata"] and self.states.m.numVgprLocalWriteAddr > 0:
self.states.m.numVgprLocalWriteSwapAddr = 1
if kernel["ProblemType"]["MXBlockA"]:
if not kernel["LocalWriteUseSgprMXSA"] and self.states.mxsa.numVgprLocalWriteAddr > 0:
self.states.mxsa.numVgprLocalWriteSwapAddr = 1
if kernel["ProblemType"]["MXBlockB"]:
if not kernel["LocalWriteUseSgprMXSB"] and self.states.mxsb.numVgprLocalWriteAddr > 0:
self.states.mxsb.numVgprLocalWriteSwapAddr = 1
# Note: MXSA/MXSB swap-address VGPR counts are set above when MXBlockA/MXBlockB is enabled

####################################
Expand Down Expand Up @@ -6147,6 +6165,14 @@ def GNLCOInit(tc):
if self.states.b.numVgprLocalReadSwapAddr > 0:
self.states.b.startVgprLocalReadSwapAddr = vgprIdx
vgprIdx += 1
if kernel["ProblemType"]["MXBlockA"]:
if self.states.mxsa.numVgprLocalReadSwapAddr > 0:
self.states.mxsa.startVgprLocalReadSwapAddr = vgprIdx
vgprIdx += 1
if kernel["ProblemType"]["MXBlockB"]:
if self.states.mxsb.numVgprLocalReadSwapAddr > 0:
self.states.mxsb.startVgprLocalReadSwapAddr = vgprIdx
vgprIdx += 1
if self.states.a.numVgprLocalWriteSwapAddr > 0:
self.states.a.startVgprLocalWriteSwapAddr = vgprIdx
vgprIdx += 1
Expand All @@ -6156,6 +6182,14 @@ def GNLCOInit(tc):
if self.states.b.numVgprLocalWriteSwapAddr > 0:
self.states.b.startVgprLocalWriteSwapAddr = vgprIdx
vgprIdx += 1
if kernel["ProblemType"]["MXBlockA"]:
if self.states.mxsa.numVgprLocalWriteSwapAddr > 0:
self.states.mxsa.startVgprLocalWriteSwapAddr = vgprIdx
vgprIdx += 1
if kernel["ProblemType"]["MXBlockB"]:
if self.states.mxsb.numVgprLocalWriteSwapAddr > 0:
self.states.mxsb.startVgprLocalWriteSwapAddr = vgprIdx
vgprIdx += 1

# X32F Emulation initializations
# meaning of variables
Expand Down Expand Up @@ -6583,6 +6617,10 @@ def checkVregOverflowTF32Emu(vgprIdx, numV):
self.defineSgpr("SwapA", 1)
if kernel["LocalWriteUseSgprB"]:
self.defineSgpr("SwapB", 1)
if kernel["ProblemType"]["MXBlockA"] and kernel["LocalWriteUseSgprMXSA"]:
self.defineSgpr("SwapMXSA", 1)
if kernel["ProblemType"]["MXBlockB"] and kernel["LocalWriteUseSgprMXSB"]:
self.defineSgpr("SwapMXSB", 1)
if kernel["ProblemType"]["Sparse"] and kernel["LocalWriteUseSgprMetadata"]:
self.defineSgpr("SwapMetadata", 1)

Expand Down
56 changes: 39 additions & 17 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,14 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module:
if self.states.m.numVgprLocalReadSwapAddr > 0:
module.add(RegSet("v", "vgprLocalReadSwapAddrMetadata", \
self.states.m.startVgprLocalReadSwapAddr))
if kernel["ProblemType"]["MXBlockA"]:
if self.states.mxsa.numVgprLocalReadSwapAddr > 0:
module.add(RegSet("v", "vgprLocalReadSwapAddrMXSA", \
self.states.mxsa.startVgprLocalReadSwapAddr))
if kernel["ProblemType"]["MXBlockB"]:
if self.states.mxsb.numVgprLocalReadSwapAddr > 0:
module.add(RegSet("v", "vgprLocalReadSwapAddrMXSB", \
self.states.mxsb.startVgprLocalReadSwapAddr))
if self.states.a.numVgprLocalWriteSwapAddr > 0:
module.add(RegSet("v", "vgprLocalWriteSwapAddrA", \
self.states.a.startVgprLocalWriteSwapAddr))
Expand All @@ -1180,6 +1188,14 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module:
if self.states.m.numVgprLocalWriteSwapAddr > 0:
module.add(RegSet("v", "vgprLocalWriteSwapAddrMetadata", \
self.states.m.startVgprLocalWriteSwapAddr))
if kernel["ProblemType"]["MXBlockA"]:
if self.states.mxsa.numVgprLocalWriteSwapAddr > 0:
module.add(RegSet("v", "vgprLocalWriteSwapAddrMXSA", \
self.states.mxsa.startVgprLocalWriteSwapAddr))
if kernel["ProblemType"]["MXBlockB"]:
if self.states.mxsb.numVgprLocalWriteSwapAddr > 0:
module.add(RegSet("v", "vgprLocalWriteSwapAddrMXSB", \
self.states.mxsb.startVgprLocalWriteSwapAddr))

if kernel["ProblemType"]["OutputAmaxD"]:
module.add(RegSet("v", "vgprAmaxOut", self.startVgprAmaxOut))
Expand Down Expand Up @@ -1771,10 +1787,16 @@ def localReadAddresses(self, kernel, tPA, tPB, tPM):
if self.states.a.numVgprLocalReadAddr > 0:
module.add(self.lraSwapAddressesForDTLPad(kernel, tPA))
module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPA, False, True))
if kernel["ProblemType"]["MXBlockA"]:
module.add(self.lraSwapAddressesForDTLPad(kernel, tPA["MX"]))
module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPA["MX"], False, True))

if self.states.b.numVgprLocalReadAddr > 0:
module.add(self.lraSwapAddressesForDTLPad(kernel, tPB))
module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPB, False, True))
if kernel["ProblemType"]["MXBlockA"]:
module.add(self.lraSwapAddressesForDTLPad(kernel, tPB["MX"]))
module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPB["MX"], False, True))
if self.states.m.numVgprLocalReadAddr > 0:
module.add(self.lraSwapAddressesForDTLPad(kernel, tPM))

Expand Down Expand Up @@ -4598,8 +4620,7 @@ def lwaFirstOffset(self, kernel, tP):
self.vgprPool.checkIn(tmpv)
self.vgprPool.checkIn(destVgpr)

# MXSA/MXSB do not use swap addresses (per mx_tony implementation)
if kernel["StoreSwapAddr"] and tc not in ("MXSA", "MXSB"):
if kernel["StoreSwapAddr"]:
if kernel["LocalWriteUseSgpr%s"%tc]:
# needed for the VReadfirstlaneB32 in the prior code block
if self.states.archCaps["CrosslaneWait"]:
Expand Down Expand Up @@ -4868,8 +4889,7 @@ def lraSwapAddressesForDTLPad(self, kernel, tP):
module = Module("lraSwapAddressesForDTLPad")

tc = tP["tensorChar"]
# MXSA/MXSB do not use swap addresses (per mx_tony implementation)
if kernel["StoreSwapAddr"] and tc not in ("MXSA", "MXSB"):
if kernel["StoreSwapAddr"]:
module.add(VAddU32(dst=vgpr("LocalReadSwapAddr%s"%tc), src0=kernel["LdsOffsetA_Blk"], src1=vgpr("LocalReadAddr%s"%tc), \
comment="Calculate starting lds addr of second buffer" ))
module.add(VXorB32(dst=vgpr("LocalReadSwapAddr%s"%tc), \
Expand Down Expand Up @@ -9808,6 +9828,18 @@ def localWriteAddRound(tc):
module.add(SCmpEQU32(src0=sgpr("LDSBufferWriteInc"), src1="LdsBlockEndSize", comment="LDSBufferWriteInc == End ?"))
module.add(SCMovB32(dst=sgpr("LDSBufferWriteInc"), src=0, comment="LDSBufferWriteInc loop back to 0"))

def getSrc0Val(tc):
    """Return the src0 operand used to swap the local-write address of tensor `tc`.

    - StoreSwapAddr off: the LdsOffsetA_Blk value formatted with hex().
    - StoreSwapAddr on and LocalWriteUseSgpr<tc> set: the "Swap<tc>" sgpr.
    - StoreSwapAddr on otherwise: the "LocalWriteSwapAddr<tc>" vgpr.
    """
    if not kernel["StoreSwapAddr"]:
        # Using inlined constants
        return hex(kernel["LdsOffsetA_Blk"])
    if kernel["LocalWriteUseSgpr%s"%tc]:
        return sgpr("Swap%s"%tc)
    return vgpr("LocalWriteSwapAddr%s"%tc)

if needSwap:
#fixme-iui need to use wrapping increment for double or triple buffering:

Expand All @@ -9822,7 +9854,7 @@ def localWriteAddRound(tc):
# (numLDSBlk>=3 is for DTL (and LocalWriteUseSgpr) only)
localWriteAddRound(tc)
else:
src0Val = None
src0Val = getSrc0Val(tc)
if kernel["StoreSwapAddr"]:
if kernel["LocalWriteUseSgpr%s"%tc]:
src0Val = sgpr("Swap%s"%tc)
Expand All @@ -9834,13 +9866,10 @@ def localWriteAddRound(tc):

numLwa = self.states.a.numVgprLocalWriteAddr if tP["isA"] else self.states.b.numVgprLocalWriteAddr
localWriteSwapXOR(tc, src0Val, numLwa)
#TODO:
if "MX" in tP:
tc = tP["MX"]["tensorChar"]
src0Val = hex(kernel["LdsOffsetA_Blk"])
#TODO:check to use isMXSA instead of isA
#numLwa = self.states.mxsa.numVgprLocalWriteAddr if tP["isA"] else self.states.mxsb.numVgprLocalWriteAddr
numLwa = self.states.mxsa.numVgprLocalWriteAddr if tP["isMXSA"] else self.states.mxsb.numVgprLocalWriteAddr
src0Val = getSrc0Val(tc)
numLwa = self.states.mxsa.numVgprLocalWriteAddr if tP["MX"]["isMXSA"] else self.states.mxsb.numVgprLocalWriteAddr
localWriteSwapXOR(tc, src0Val, numLwa)

# This used to control where to store the metadata
Expand Down Expand Up @@ -10843,13 +10872,6 @@ def localReadSwapOffsets(self, kernel, internalPointerSwap, tP):
if not kernel["StoreSwapAddr"]:
tP["localReadSwapByteOffset"] = 0 if tP["localReadSwapByteOffset"] else kernel["LdsOffsetA_Blk"]
module.addComment1("local read swap internal offset -> %u" % tP["localReadSwapByteOffset"])
elif tc in ("MXSA", "MXSB"):
#TODO: MXSA/MXSB do not use swap addresses, use fixed offset instead
module.add(VXorB32(
dst=vgpr("LocalReadAddr%s"%tc), \
src0=hex(kernel["LdsOffsetA_Blk"]), \
src1=vgpr("LocalReadAddr%s"%tc), \
comment="swap Red Blk"))
else:
module.add(VXorB32(
dst=vgpr("LocalReadAddr%s"%tc), \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2321,6 +2321,10 @@ def calSwizzlePackK(state, tc):
curGRVW *= 2
if state["ProblemType"]["MXBlockA"]:
state["GlobalReadVectorWidthMXSA"] = max(state["MacroTile0"] * state["_DepthUMXSA"] // state["NumThreads"], 1)
# workaround for DTL
# use 32bit load for DTL if GlobalReadVectorWidthMXS is 8
if state["DirectToLdsMXSA"] and state["GlobalReadVectorWidthMXSA"] == 8:
Comment thread
msujon-AMD marked this conversation as resolved.
state["GlobalReadVectorWidthMXSA"] = 4
else:
# dot2
if state["UseDotInstruction"]:
Expand Down Expand Up @@ -2365,6 +2369,10 @@ def calSwizzlePackK(state, tc):
curGRVW *= 2
if state["ProblemType"]["MXBlockB"]:
state["GlobalReadVectorWidthMXSB"] = max(state["MacroTile1"] * state["_DepthUMXSB"] // state["NumThreads"], 1)
# workaround for DTL
# use 32bit load for DTL if GlobalReadVectorWidthMXS is 8
if state["DirectToLdsMXSB"] and state["GlobalReadVectorWidthMXSB"] == 8:
state["GlobalReadVectorWidthMXSB"] = 4
else:
# dot2
if state["UseDotInstruction"]:
Expand Down Expand Up @@ -3462,13 +3470,10 @@ def setLdsOffsets(offsetBlk, numLdsBlk, ldsNumBytesB):
# TODO:
# Disable StoreSwapAddr to ensure LdsOffsetA_Blk is always a power of 2
# This is consistent with the reference implementation, which doesn't have StoreSwapAddr
if state["ProblemType"]["MXBlockA"]:
state["StoreSwapAddr"] = False
else:
# Original logic:
state["StoreSwapAddr"] = (state["PrefetchGlobalRead"] == 2) and \
(state["1LDSBuffer"] == 0) and numLdsBlk == 2 and \
(offsetBlk + roundupOffsetBlk) > state["MaxLDS"]
# Original logic:
state["StoreSwapAddr"] = (state["PrefetchGlobalRead"] == 2) and \
(state["1LDSBuffer"] == 0) and numLdsBlk == 2 and \
(offsetBlk + roundupOffsetBlk) > state["MaxLDS"]

if offsetBlk > 0 and not state["StoreSwapAddr"] and numLdsBlk == 2:
# Rounds offsetBlk to a power of two to enable inlining {s,v}_xor constants for swapping offsets
Expand Down
Loading