Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ def applyPad(offset_val):
offset_val = offset_val + tP["localReadSwapByteOffset"]
# TODO: Add NLC>1 offset calcs here?
if (kernel["DirectToLds%s" % tc] and \
kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]:
kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]:
# another address conversion for DirectToLds + NumLoadsCoalesced > 1
dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val)

Expand Down
4 changes: 4 additions & 0 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5756,6 +5756,10 @@ def GNLCOInit(tc):
self.defineSgpr("LocalWriteAddrA", 1)
if kernel["LocalWriteUseSgprB"]:
self.defineSgpr("LocalWriteAddrB", 1)
if kernel["ProblemType"]["MXBlockA"] and kernel["LocalWriteUseSgprMXSA"]:
self.defineSgpr("LocalWriteAddrMXSA", 1)
if kernel["ProblemType"]["MXBlockB"] and kernel["LocalWriteUseSgprMXSB"]:
self.defineSgpr("LocalWriteAddrMXSB", 1)

# Allocate registers to swap between lds buffers
if kernel["StoreSwapAddr"]:
Expand Down
26 changes: 14 additions & 12 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3287,7 +3287,7 @@ def graFinalOffsetsSingleLoopGNLC(self, kernel, tP, tc, margin = -1):
self.vgprPool.checkIn(reMap1)

stride = "Strides%s"%(tc)
module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(kernel["GlobalReadVectorWidth%c"%tc]), src=vgpr(tmpv2)))
module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(kernel["GlobalReadVectorWidth%s"%tc]), src=vgpr(tmpv2)))
module.add(VMulLOU32(dst=vgpr(tmpv), src0=sgpr(stride), src1=vgpr(tmpv)))

if kernel["EdgeType"] == "ShiftPtr" and kernel["ProblemType"]["TLU%s"%tc] == 1:
Expand Down Expand Up @@ -3387,11 +3387,11 @@ def graFinalOffsetsSingleLoop(self, kernel, tP, tc, tmp, graIdx, perp, sPerp, pa
module.add(SMovB32(dst=sgpr(tmpSgpr), src=self.buff_load_inst_offset_max))
module.add(VAddU32(dst=vgpr(groVgpr), src0=vgpr(groVgpr), src1=sgpr(tmpSgpr), comment="shift for UseInstOffsetForGRO"))

ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"]
ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"]
if kernel["LdsBlockSizePerPad%s"%tc] != 0:
ldsInc += (ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"]
else:
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr
ldsInc += (ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"]

# buffer_load only support 12 bit instruction offset
Expand Down Expand Up @@ -3447,11 +3447,11 @@ def computeScalarGroImpl(scalarGro):
# add room for instruction offset
module.add(SAddU32(dst=sgpr(scalarGro), src0=sgpr(scalarGro), src1=self.buff_load_inst_offset_max, comment="shift for UseInstOffsetForGRO"))

ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"]
ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"]
if kernel["LdsBlockSizePerPad%s"%tc] != 0:
ldsInc += (ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"]
else:
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr
ldsInc += (ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"]

# buffer_load only support 12 bit instruction offset
Expand Down Expand Up @@ -4260,7 +4260,7 @@ def lwaFirstOffset(self, kernel, tP):
dst=sgpr("LocalWriteAddr%s"%tc), \
src=vgpr(tmpv), \
comment="Copy lds write address VGPR to SGPR"))
lwastride = int((kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%c"%tc]+kernel["LdsPad%s"%tc]) * tP["bpeGR"])
lwastride = int((kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%s"%tc]+kernel["LdsPad%s"%tc]) * tP["bpeGR"])
module.add(SMulI32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), src1=lwastride ))
if tc == 'B':
module.add(SAddU32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), \
Expand Down Expand Up @@ -8405,13 +8405,13 @@ def globalReadGuardKBody(tP, optParams = None):
# need to increment ldsInc only once per each loopCnt
# this is pre count up, so increment it at r == 0
if r == 0:
ldsInc = int((self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"])
ldsInc = int((self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"])
else:
ldsInc = 0
if kernel["LdsBlockSizePerPad%s"%tc] != 0:
ldsInc += int((ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"])
else:
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr
ldsInc += int((ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"])
if kernel["UseInstOffsetForGRO"]:
# buffer_load only support 12 bit instruction offset
Expand Down Expand Up @@ -8889,8 +8889,6 @@ def directToLdsM0Update(self, kernel, mode, tP, skipWait = False):
DtldsModule.addComment0("before DirectToLds load, ensure prior ds_reads have finished")
DtldsModule.add(SWaitCnt(dscnt=0, comment=""))
DtldsModule.add(SBarrier())
if "MX" in tP:
imod.add(self.directToLdsM0Update(kernel, 0, tP["MX"]))

return imod

Expand Down Expand Up @@ -8980,6 +8978,10 @@ def globalReadBody(tP):
else:
destVgprPrefix = "G2L%s"%(tc)

# add m0 init code for MX here
if tc == "MXSA" or tc == "MXSB":
imod.middle.add(self.directToLdsM0Update(kernel, mode, tP, skipWait=True))

loopCnt = -1
for perp in range(0, tP["nrp"]):
for sPerp in range(0, tP["nrpv"]):
Expand Down Expand Up @@ -9035,11 +9037,11 @@ def globalReadBody(tP):

if kernel["DirectToLds%s"%tc]:
# use bpe with GlobalReadVectorWidth
ldsInc = int((self.states.kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%c"%tc] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"] * kernel["GlobalReadVectorWidth%c"%tc]) * tP["bpeGR"])
ldsInc = int((self.states.kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%s"%tc] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"] * kernel["GlobalReadVectorWidth%s"%tc]) * tP["bpeGR"])
if kernel["LdsBlockSizePerPad%s"%tc] != 0:
ldsInc += int((ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"])
else:
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr
padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr
ldsInc += int((ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"])

if kernel["UseInstOffsetForGRO"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,16 +829,12 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool):
#TN
# use for all precisions with TransposeLDS=1

numBytesAB = state["ProblemType"]["DataType%s"%tc].numBytes()
if tc in ["MXSA", "MXSB"]:
numBytesAB = 1
else:
numBytesAB = state["ProblemType"]["DataType%s"%tc].numBytes()
numBytesPerLoad = state["GlobalReadVectorWidth%s"%tc] * numBytesAB

MT = state["MacroTile0"] if tc == 'A' else state["MacroTile1"]

if (MT & (MT-1)) != 0 and not state["UseGeneralizedNLCOne%s"%tc]: # Check of MT not power of 2
# so far, numBytesAB<=4 case, TLU=False only (continue with False)
if (numBytesAB <= 4) and state["ProblemType"]["TLU%c"%tc]:
return False

# x2 DTL is not supported
if numBytesPerLoad == 8:
printWarning("can't use DirectToLds with b64 buffer load, using non DirectToLds version instead")
Expand All @@ -852,6 +848,17 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool):
reject(state, printRejectionReason, "DirectToLds not supported for loads less than 32bits")
return False

if tc in ["MXSA", "MXSB"]:
# MXSA/B case, check numBytesPerLoad only
return True

MT = state["MacroTile0"] if tc == 'A' else state["MacroTile1"]

if (MT & (MT-1)) != 0 and not state["UseGeneralizedNLCOne%s"%tc]: # Check of MT not power of 2
# so far, numBytesAB<=4 case, TLU=False only (continue with False)
if (numBytesAB <= 4) and state["ProblemType"]["TLU%c"%tc]:
return False

# so far MFMA only (TODO: enable non MFMA case)
if not state["EnableMatrixInstruction"]:
reject(state, printRejectionReason, "DirectToLds is for MatrixInstruction only for now (tentative)")
Expand Down Expand Up @@ -1240,11 +1247,10 @@ def assignDerivedParameters(
# set True for DTL
state["UseGeneralizedNLCOneA"] = state["DirectToLdsA"]
state["UseGeneralizedNLCOneB"] = state["DirectToLdsB"]
# MX block does not use DTL, so set to False
if state["ProblemType"]["MXBlockA"]:
state["UseGeneralizedNLCOneMXSA"] = False
state["UseGeneralizedNLCOneMXSA"] = False #state["DirectToLdsA"]
if state["ProblemType"]["MXBlockB"]:
state["UseGeneralizedNLCOneMXSB"] = False
state["UseGeneralizedNLCOneMXSB"] = False #state["DirectToLdsB"]

state["LocalWriteUseSgprA"] = False
state["LocalWriteUseSgprB"] = False
Expand Down Expand Up @@ -1464,8 +1470,8 @@ def assignDerivedParameters(
state["WaveSeparateGlobalReadMXSA"] = state["WaveSeparateGlobalReadA"]
state["NumLoadsCoalescedMXSA"] = state["NumLoadsCoalescedA"]
Solution.checkAndAssignWaveSeparateGlobalRead(state, 'MXSA', printRejectionReason)
state["DirectToLdsMXSA"] = False
state["LocalWriteUseSgprMXSA"] = False
state["DirectToLdsMXSA"] = state["DirectToLdsA"]
state["LocalWriteUseSgprMXSA"] = state["DirectToLdsMXSA"]
state["ProblemType"]["MirrorDimsMXSA"] = list(state["ProblemType"]["MirrorDimsA"])
state["VectorWidthMXSA"] = state["VectorWidthA"]
state["MIWaveTileMXSA"] = state["MIWaveTileA"]
Expand All @@ -1478,8 +1484,8 @@ def assignDerivedParameters(
state["WaveSeparateGlobalReadMXSB"] = state["WaveSeparateGlobalReadB"]
state["NumLoadsCoalescedMXSB"] = state["NumLoadsCoalescedB"]
Solution.checkAndAssignWaveSeparateGlobalRead(state, 'MXSB', printRejectionReason)
state["DirectToLdsMXSB"] = False
state["LocalWriteUseSgprMXSB"] = False
state["DirectToLdsMXSB"] = state["DirectToLdsB"]
state["LocalWriteUseSgprMXSB"] = state["DirectToLdsMXSB"]
state["ProblemType"]["MirrorDimsMXSB"] = list(state["ProblemType"]["MirrorDimsB"])
state["VectorWidthMXSB"] = state["VectorWidthB"]
state["MIWaveTileMXSB"] = state["MIWaveTileB"]
Expand Down Expand Up @@ -1758,10 +1764,10 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int:
else:
ldsPadA = state["VectorWidthA"]
else:
ldsPadA = max(state["GlobalReadVectorWidthA"],optPadA)
## turn-off padding for directToLds
if state["DirectToLdsA"]:
ldsPadA = 0
ldsPadA = max(lrvwA, optPadA) if not state["ProblemType"]["TLUA"] else 0
else:
ldsPadA = max(state["GlobalReadVectorWidthA"],optPadA)
assert(ldsPadA >= 0)

if ldsPadB == -1:
Expand All @@ -1781,9 +1787,10 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int:
else:
ldsPadB = state["VectorWidthB"]
else:
ldsPadB = max(state["GlobalReadVectorWidthB"],optPadB)
if state["DirectToLdsB"]:
ldsPadB = 0
ldsPadB = max(lrvwB, optPadB) if not state["ProblemType"]["TLUB"] else 0
else:
ldsPadB = max(state["GlobalReadVectorWidthB"],optPadB)
assert(ldsPadB >= 0)

if state["ProblemType"]["Sparse"] and not state["DirectToVgprSparseMetadata"]:
Expand Down Expand Up @@ -2817,13 +2824,21 @@ def calSwizzlePackK(state, tc):
# LDS (load size coalesced) * LSPA must load some multiple of 256 bytes.
# No longer support loadX2/loadx4 .
for tc in ['A', 'B']:
tcmx = "MXS%s"%tc
if state["DirectToLds%s"%tc]:
isDtlDoable = Solution.isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason)
if (not state["DirectToVgpr%s"%tc]) and isDtlDoable:
state["DirectToLds%s"%tc] = True
state["LocalWriteUseSgpr%s"%tc] = True
# MX case
if state["ProblemType"]["MXBlock%s"%tc]:
isDtlMxDoable = Solution.isDirectToLdsDoable(state, tcmx, isaInfoMap, printRejectionReason)
state["DirectToLds%s"%tcmx] = isDtlMxDoable
else:
state["DirectToLds%s"%tc] = False
# MX case
if state["ProblemType"]["MXBlock%s"%tc]:
state["DirectToLds%s"%tcmx] = False
if not isDtlDoable:
if state["UseGeneralizedNLCOne%s"%tc]:
reject(state, printRejectionReason, "DirectToLds%s not doable, but GNLC%s enabled, rejecting"%(tc, tc))
Expand All @@ -2840,6 +2855,10 @@ def calSwizzlePackK(state, tc):
if state["1LDSBuffer"] == -1 and state["DirectToLds"]:
#1LDS buffer must be 0 for DirectToLdsA
state["1LDSBuffer"] = 0
# MX case
if state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]:
if state["DirectToLdsA"] != state["DirectToLdsMXSA"] or state["DirectToLdsB"] != state["DirectToLdsMXSB"]:
reject(state, printRejectionReason, "DirectToLdsA/B and DirectToLdsMXSA/B should match")

# does not work with UnrollLoopSwapGlobalReadOrder
if (state["DirectToLds"] == 2 or state["DirectToLds"] == 3) and state["UnrollLoopSwapGlobalReadOrder"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,36 @@ BenchmarkProblems:
- BiasTypeArgs: ['s']
- ActivationArgs:
- [Enum: none]
- # BenchmarkProblemSizeGroup - DTL
InitialSolutionParameters:
BenchmarkCommonParameters:
- KernelLanguage: ["Assembly"]
ForkParameters:
- MatrixInstruction:
- [16, 16, 128, 1, 1, 8,8, 2,2] # 64x64
- [16, 16, 128, 1, 1, 8,2, 2,2] # 128x64
- [16, 16, 128, 1, 1, 2,8, 2,2] # 64x128
- [32, 32, 64, 1, 1, 4,4, 2,2] # 128x128
- DepthU: [128]
- AssertSummationElementMultiple: [32]
- LocalReadVectorWidth: [32]
- PrefetchGlobalRead: [2]
- PrefetchLocalRead: [1]
- ScheduleIterAlg: [3]
- 1LDSBuffer: [0]
- DirectToLds: [1,2,3]
BenchmarkJoinParameters:
BenchmarkFinalParameters:
- ProblemSizes:
- Exact: [32, 16, 1, 1024]
- Exact: [256, 256, 1, 128]
- Exact: [1025, 513, 1, 2048]
- Exact: [127, 127, 1, 640] #special cleanup case
- Exact: [129, 129, 1, 640]
- Exact: [128, 128, 1, 512]
- BiasTypeArgs: ['s']
- ActivationArgs:
- [Enum: none]

# LibraryLogic:
# ScheduleName: "aquavanjaram"
Expand Down
Loading