Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
ff4360b
[hipBLASLt] Add initial changes for Navi SK support
stsokolo Jul 2, 2025
e943ae9
[hipBLASLt] Update gfx1100/gfx1201 logic yaml
stsokolo Jul 16, 2025
fe4b474
[hipBLASLt] Change Tensile version
stsokolo Jul 16, 2025
c9bb8eb
[hipBLASLt] Add cache scope modifier
stsokolo Jul 17, 2025
6a0d070
[hipBLASLt] Add simple hgemm and bgemm SK test configs
stsokolo Aug 8, 2025
d4c7cc6
[hipBLASLt] Fix SWaitCnt
stsokolo Aug 8, 2025
89c1f07
[hipBLASLt] Swap global store with buffer store instruction
stsokolo Aug 14, 2025
69c9868
[hipBLASLt] Fix typo
stsokolo Aug 14, 2025
2058396
[hipBLASLt] Minor change to MUBUFModifiers in rocisa
stsokolo Aug 15, 2025
03f64b2
[hipBLASLt] Add dlc modifier to GlobalWrite
stsokolo Sep 9, 2025
5a08f17
[hipBLASLt] Add off option to vgpr in rocisa
stsokolo Sep 9, 2025
c66710c
[hipBLASLt] Update container tests in rocisa
stsokolo Sep 11, 2025
8a4325c
[hipBLASLt] Remove gfx1100 and gfx1201 SK logic libraries
stsokolo Sep 11, 2025
caed531
[hipBLASLt] Update SK test configs for gfx1100 and gfx1201
stsokolo Sep 11, 2025
38491a3
[hipBLASLt] Fix gfx1100 SK test config
stsokolo Sep 11, 2025
1f50ddb
[hipBLASLt] Remove v_swmmac matrix instructions for gfx1201
stsokolo Sep 12, 2025
c044d4a
[hipBLASLt] Add s_store_b32 case in hardware caps
stsokolo Sep 16, 2025
e6bf83f
[hipBLASLt] Fix hardware caps s_store instruction
stsokolo Sep 23, 2025
a3f5f02
[hipBLASLt] Solution.py cleanup
stsokolo Oct 1, 2025
93a8907
[hipBLASLt] Fix modifiers for gfx942 and gfx950 in rocisa
stsokolo Oct 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 53 additions & 12 deletions projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
################################################################################

from rocisa.enum import CacheScope
from rocisa.code import Module, Label
from rocisa.container import vgpr, sgpr, SMEMModifiers, replaceHolder, EXEC,\
from rocisa.container import vgpr, sgpr, SMEMModifiers, MUBUFModifiers, replaceHolder, EXEC,\
VOP3PModifiers, ContinuousRegister
from rocisa.instruction import SAddCU32, SAddI32, SAddU32, SAndB32, SBarrier, \
SBranch, SCBranchSCC0, SCBranchSCC1, SCMovB32, SCSelectB32, SCmpEQU32, SCmpEQU64, \
SCmpGtU32, SCmpLeU32, SCmpLtU32, SCmpGeU32, SLShiftLeftB32, SLShiftLeftB64, SLShiftRightB32, SLoadB32, \
SMinU32, SMovB32, SMovB64, SMulI32, SNop, SSleep, SStoreB32, SSubU32, \
SWaitCnt, VAddF32, VAddF64, VAddPKF16, VAddU32, VLShiftRightB32, VMovB32, \
VReadfirstlaneB32, VCvtBF16toFP32
VReadfirstlaneB32, VCvtBF16toFP32, BufferStoreB32
from rocisa.functions import scalarStaticDivideAndRemainder, sMagicDiv2, \
vectorStaticMultiply, BranchIfNotZero, scalarUInt32DivideAndRemainder

Expand Down Expand Up @@ -228,7 +229,7 @@ def computeStoreSrdStartCommon(self, writer, kernel):
tmpSgpr0 = tmpSgprInfo.idx+1
tmpSgpr1 = tmpSgprInfo.idx+2
tmpSgpr2 = tmpSgprInfo.idx+0
tmpSgpr3 = tmpSgprInfo.idx+3
tmpSgpr3 = tmpSgprInfo.idx+3
module.addComment("Split Output Buffer offset: Free0 + (Free1-1)*StrideC1J + (Free2-1)*StrideCK * SplitIdx * bpe%s")
# PartialIdx was saved in sgprBeta for re-use
module.addModuleAsFlatItems(writer.s_mul_u64_u32(sgpr(tmpSgpr0), sgpr(tmpSgpr1), sgpr("SizesFree+0"), sgpr("SkPartialIdx"), comment="Free0"))
Expand Down Expand Up @@ -427,7 +428,7 @@ def storeBranchesCommon(self, writer, kernel, skPartialsLabel, vectorWidths, ele
module.add(SLShiftLeftB32(dst=sgpr(tmpSgpr), src=sgpr(sFlagIdx), shiftHex=log2(4), comment="flag offset based on wg index"))

module.add(skFixupWaitForFlag) # loop to wait for flag
module.add(SLoadB32(dst=sgpr(tmpSgpr+1), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="get flag"))
module.add(SLoadB32(dst=sgpr(tmpSgpr+1), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True, dlc=True, scope=CacheScope.SCOPE_DEV), comment="get flag"))

module.add(SWaitCnt(kmcnt=0, comment="wait for flag load"))
if kernel["DebugStreamK"] & 2 == 0: # Don't wait for partials if not being written
Expand Down Expand Up @@ -489,7 +490,7 @@ def storeBranchesCommon(self, writer, kernel, skPartialsLabel, vectorWidths, ele

# Check flag
module.add(SLShiftLeftB32(dst=sgpr(tmpSgpr), src=sgpr(sCtaIdx), shiftHex=log2(4), comment="flag offset based on CTA index"))
module.add(SLoadB32(dst=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="get flag"))
module.add(SLoadB32(dst=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True, dlc=True, scope=CacheScope.SCOPE_DEV), comment="get flag"))

module.add(SWaitCnt(kmcnt=0, comment="wait for flag load"))
if kernel["DebugStreamK"] & 2 == 0:
Expand All @@ -502,8 +503,12 @@ def storeBranchesCommon(self, writer, kernel, skPartialsLabel, vectorWidths, ele
module.add(VReadfirstlaneB32(dst=sgpr(tmpSgpr+2), src=vgpr("Serial"), comment="Wave 0 updates flags"))
module.add(SCmpEQU32(src0=sgpr(tmpSgpr+2), src1=0, comment="Check for wave 0"))
module.add(SCBranchSCC0(labelName=skipFlagReset.getLabelName(), comment="Skip flag reset"))
# (tmpSgpr+2) contains a value of 0, use it to reset the flag
module.add(SStoreB32(src=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="reset flag"))
if writer.states.asmCaps["HasScalarStore"]:
# (tmpSgpr+2) contains a value of 0, use it to reset the flag
module.add(SStoreB32(src=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="reset flag"))
else:
module.add(VMovB32(dst=vgpr(tmpVgpr), src=0, comment="move 0 to tmpVgpr"))
Comment thread
daineAMD marked this conversation as resolved.
module.add(self.setFlagValue(writer, src=vgpr(tmpVgpr), soffset=sgpr(tmpSgpr), comment="reset flag"))
module.add(skipFlagReset)
writer.sgprPool.checkIn(tmpSgpr)

Expand Down Expand Up @@ -806,8 +811,12 @@ def partialsWriteProcedure(self, writer, kernel, vectorWidths, elements, alpha,
module.add(VReadfirstlaneB32(dst=sgpr(flagSgpr), src=vgpr("Serial"), comment="Wave 0 updates flags"))
module.add(SCmpEQU32(src0=sgpr(flagSgpr), src1=0, comment="Check for wave 0"))
module.add(SCBranchSCC0(labelName=skipFlagSet.getLabelName(), comment="Skip flag set"))
module.add(SMovB32(dst=sgpr(flagSgpr), src=1, comment="flag data"))
module.add(SStoreB32(src=sgpr(flagSgpr), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="set flag"))
if writer.states.asmCaps["HasScalarStore"]:
module.add(SMovB32(dst=sgpr(flagSgpr), src=1, comment="flag data"))
module.add(SStoreB32(src=sgpr(flagSgpr), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="set flag"))
else:
module.add(VMovB32(dst=vgpr(tmpVgpr), src=1, comment="move 1 to tmpVgpr"))
module.add(self.setFlagValue(writer, src=vgpr(tmpVgpr), soffset=sgpr(tmpSgpr), comment="set flag"))
module.add(skipFlagSet)
module.add(SWaitCnt(kmcnt=0, comment="wait for flag")) # TODO just for testing

Expand All @@ -818,6 +827,20 @@ def partialsWriteProcedure(self, writer, kernel, vectorWidths, elements, alpha,

return module

def setFlagValue(self, writer, src, soffset, comment=""):
    """Build a module that writes a flag value with a buffer store.

    The call sites visible in this file use this helper instead of
    SStoreB32 when writer.states.asmCaps["HasScalarStore"] is false.

    Parameters:
        writer:  kernel writer; only its sgprPool is used here, to borrow
                 and return temporary SGPRs.
        src:     register operand holding the value to store (the callers
                 pass a vgpr that was just loaded with 0 or 1).
        soffset: scalar offset operand forwarded to the buffer store (the
                 callers pass the per-workgroup flag offset).
        comment: text attached to the emitted buffer store instruction.

    Returns:
        A Module named "Buffer Store Flag Value" containing, in order, the
        three moves that fill the temporary SGPRs, the buffer store, and a
        wait on vscnt=0.
    """
    flagStore = Module("Buffer Store Flag Value")

    # Borrow temporary SGPRs; registers srdBase .. srdBase+3 are written below.
    # NOTE(review): both positional arguments are 4 - presumably size and
    # alignment; confirm against the register pool API.
    srdBase = writer.sgprPool.checkOutAligned(4, 4, preventOverflow=False)

    # Instructions are constructed and appended in exactly this order.
    # NOTE(review): the four SGPRs look like a buffer resource descriptor
    # rooted at AddressFlags - confirm against the ISA documentation.
    emitted = [
        # Low two registers: copy of the two-register "AddressFlags" value.
        SMovB64(dst=sgpr(srdBase, 2), src=sgpr("AddressFlags", 2)),
        # Third and fourth registers: the "BufferOOB" and "Srd127_96" symbols.
        SMovB32(dst=sgpr(srdBase + 2), src="BufferOOB"),
        SMovB32(dst=sgpr(srdBase + 3), src="Srd127_96"),
        # Store src with vaddr switched off, addressed by the four SGPRs
        # plus soffset, using glc/dlc and CacheScope.SCOPE_DEV modifiers.
        BufferStoreB32(
            src=src,
            vaddr=vgpr("off", isOff=True),
            saddr=sgpr(srdBase, 4),
            soffset=soffset,
            mubuf=MUBUFModifiers(glc=True, dlc=True, scope=CacheScope.SCOPE_DEV),
            comment=comment,
        ),
        # TODO: See if this wait is necessary
        SWaitCnt(vscnt=0, comment="wait for data store"),
    ]
    for inst in emitted:
        flagStore.add(inst)

    # The temporaries are only needed by the instructions emitted above.
    writer.sgprPool.checkIn(srdBase)

    return flagStore

def partialsWriteBatch(self, writer, kernel, ss, batchIdx, applyAlpha, beta, edge, gwvw, atomicW, \
batchElements, addrD, addrC, \
tmpVgpr, cvtVgprStruct, batchElementSgprs, tmpSgpr, codeAccVgprRead):
Expand Down Expand Up @@ -1753,6 +1776,12 @@ def preLoop(self, writer, kernel):
xccMapping = Component.XCCMapping.find(writer)
module.add(xccMapping(writer, kernel))

# Workaround for gfx12
if writer.states.archCaps["WorkGroupIdFromTTM"]:
module.add(SMovB32(dst=sgpr("WorkGroup0"), src="ttmp9", comment="workaround"))
module.add(SAndB32(dst=sgpr("WorkGroup1"), src0=hex(0xFFFF), src1="ttmp7", comment="workaround"))
module.add(SLShiftRightB32(dst=sgpr("WorkGroup2"), shiftHex=hex(0x10), src="ttmp7", comment="workaround"))

module.add(SMovB32(dst=sgpr("StreamKIdx"), src=sgpr("WorkGroup0"), comment="Save original StreamK index"))
# Basic SK
module.add(SMulI32(dst=sgpr("StreamKIter"), src0=sgpr("StreamKIdx"), src1=sgpr("SKItersPerWG"), comment="StreamK starting iteration"))
Expand Down Expand Up @@ -1831,6 +1860,12 @@ def preLoop(self, writer, kernel):
xccMapping = Component.XCCMapping.find(writer)
module.add(xccMapping(writer, kernel))

# Workaround for gfx12
if writer.states.archCaps["WorkGroupIdFromTTM"]:
module.add(SMovB32(dst=sgpr("WorkGroup0"), src="ttmp9", comment="workaround"))
module.add(SAndB32(dst=sgpr("WorkGroup1"), src0=hex(0xFFFF), src1="ttmp7", comment="workaround"))
module.add(SLShiftRightB32(dst=sgpr("WorkGroup2"), shiftHex=hex(0x10), src="ttmp7", comment="workaround"))

module.add(SMovB32(dst=sgpr("StreamKIdx"), src=sgpr("WorkGroup0"), comment="Save original StreamK index"))
# Two-tile SK (SK first)
# iter count after all extra iters have been distributed
Expand Down Expand Up @@ -1950,6 +1985,12 @@ def preLoop(self, writer, kernel):
xccMapping = Component.XCCMapping.find(writer)
module.add(xccMapping(writer, kernel))

# Workaround for gfx12
if writer.states.archCaps["WorkGroupIdFromTTM"]:
module.add(SMovB32(dst=sgpr("WorkGroup0"), src="ttmp9", comment="workaround"))
module.add(SAndB32(dst=sgpr("WorkGroup1"), src0=hex(0xFFFF), src1="ttmp7", comment="workaround"))
module.add(SLShiftRightB32(dst=sgpr("WorkGroup2"), shiftHex=hex(0x10), src="ttmp7", comment="workaround"))

module.add(SMovB32(dst=sgpr("StreamKIdx"), src=sgpr("WorkGroup0"), comment="Save original StreamK index"))
# Two-tile SK (DP first)
# Do DP tiles before SK
Expand Down Expand Up @@ -1979,13 +2020,13 @@ def preLoop(self, writer, kernel):
# if (partialIdx < extraIters) then (skIter = partialIdx * (itersPerWG + 1)) else (skIter = partialIdx * itersPerWG + extraIters)
skHasExtraLabel = Label("SK_HasExtra", "")
skDoneExtraLabel = Label("SK_DoneExtra", "")

# PartialIdx = itersPerTile % skSplit (skSplit is passed as SkSplit)
# extraIters = ItersPerTile - SkSplit * skItersPerWG
sSkExtraIters = writer.sgprPool.checkOut(1, "extraIters")
module.add(SMulI32(dst=sgpr(sSkExtraIters), src0=sgpr("SkSplit"), src1=sgpr("SKItersPerWG")))
module.add(SSubU32(dst=sgpr(sSkExtraIters), src0=sgpr("ItersPerTile"), src1=sgpr(sSkExtraIters), comment="extraIters = itersPerTile - SkSplit * skItersPerWG"))

module.add(SMulI32(dst=sgpr("StreamKIter"), src0=sgpr(stmpPartialIdx), src1=sgpr("SKItersPerWG"), comment="StreamK starting iteration (case: after extra iters)"))
module.add(SCmpLtU32(src0=sgpr(stmpPartialIdx), src1=sgpr(sSkExtraIters), comment="Check if WG gets an extra iteration"))
module.add(SCBranchSCC1(labelName=skHasExtraLabel.getLabelName(), comment="Has extra iter"))
Expand All @@ -2007,7 +2048,7 @@ def preLoop(self, writer, kernel):
module.add(SMovB32(dst=sgpr("SkPartialIdx"), src=sgpr(stmpPartialIdx), comment="Save partial idx for SrdD calculation"))
# Done init
module.add(SBranch(labelName=skInitDone.getLabelName(), comment="Done init for parallel reduction"))

# # Save PartialIdx for later, skExtraIters is unused for partial reduction
# module.add(SMovB32(dst=sgpr("skExtraIters"), src=sgpr(stmpPartialIdx), comment="Save partial idx for SrdD calculation"))
# # StreamKIter = tile * itersPerTile + itersPerWG * partialIndex
Expand Down
32 changes: 19 additions & 13 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
scalarStaticRemainder, scalarUInt32DivideAndRemainder, sMagicDiv, vectorStaticMultiply, \
vectorStaticMultiplyAdd, scalarStaticMultiply64, BranchIfZero, BranchIfNotZero, DSInit, \
ArgumentLoader
from rocisa.enum import InstType, SelectBit
from rocisa.enum import InstType, SelectBit, CacheScope
from rocisa.macro import MacroVMagicDiv, PseudoRandomGenerator
from . import CUSTOM_KERNEL_PATH
from rocisa.instruction import BranchInstruction, BufferLoadB128, BufferLoadB32, \
Expand Down Expand Up @@ -12295,7 +12295,9 @@ def bufferLoadImpl(soffset):
##############################################################################
def chooseGlobalWrite(self, useBuffer, bps, srcVgpr, rpv, \
addr0, addr1, offset, soffset=0, \
glc=False, slc=False, nt=False, hi16=0, comment="store"):
glc=False, slc=False, nt=False, dlc=False, \
scope=CacheScope.SCOPE_NONE, \
hi16=0, comment="store"):
"""
create the store instruction for requested vector width and other parms
rpv = regs per vector
Expand Down Expand Up @@ -12338,15 +12340,15 @@ def bufferStoreImpl(tmpSgpr, mubuf):
if offset2 >= 4096:
module.add(SMovB32(dst=tmpSgpr, src=offset2, comment="large offset"))
offset2 = 0
mubuf2 = MUBUFModifiers(offen=True, offset12=offset2, glc=glc, slc=slc, nt=nt, isStore=True)
mubuf2 = MUBUFModifiers(offen=True, offset12=offset2, glc=glc, slc=slc, dlc=dlc, scope=scope, nt=nt, isStore=True)
vgprOff = int(srcVgpr + shiftRpv * i) if isinstance(srcVgpr, int) else f"{srcVgpr}+{int(shiftRpv * i)}"
module.add(BufferStoreB128(src=vgpr(vgprOff, shiftRpv), vaddr=addr0, \
saddr=addr1, soffset=tmpSgpr, mubuf=mubuf2, comment=comment))
else:
assert 0, "bad bps"

if useBuffer:
mubuf = MUBUFModifiers(offen=True, offset12=offset, glc=glc, slc=slc, nt=nt, isStore=True)
mubuf = MUBUFModifiers(offen=True, offset12=offset, glc=glc, slc=slc, dlc=dlc, scope=scope, nt=nt, isStore=True)
if soffset != 0:
assert offset < 4096, "sgpr offset provided with large const offset"
# buffer_load offset field is 12-bit.
Expand All @@ -12357,13 +12359,13 @@ def bufferStoreImpl(tmpSgpr, mubuf):
tmpSgpr = sgpr(tmpSgprInfo.idx)
if offset >= 4096:
module.add(SMovB32(dst=tmpSgpr, src=offset, comment="large offset"))
mubuf = MUBUFModifiers(offen=True, offset12=0, glc=glc, slc=slc, nt=nt, isStore=True)
mubuf = MUBUFModifiers(offen=True, offset12=0, glc=glc, slc=slc, dlc=dlc, scope=scope, nt=nt, isStore=True)
bufferStoreImpl(tmpSgpr, mubuf)
else:
bufferStoreImpl(soffset, mubuf)

else:
flat = FLATModifiers(glc=glc, slc=slc, isStore=True)
flat = FLATModifiers(glc=glc, slc=slc, dlc=dlc, scope=scope, isStore=True)
if bps==2 and hi16:
module.add(FlatStoreD16HIB16(vaddr=addr0, src=vgpr(srcVgpr*2), flat=flat, comment=comment))
elif bps==2 and not hi16:
Expand Down Expand Up @@ -12493,6 +12495,8 @@ def addStore(self, kernel, ss, tc: str, addrCalc, sumIdx, tmpS01, edge, wsOffset
isGlc = False
isSlc = False
isNT = False
scope = CacheScope.SCOPE_NONE
isDlc = False

if tc == 'D':
isGlc = bool(kernel["NonTemporalD"] & 0x1)
Expand Down Expand Up @@ -12533,6 +12537,8 @@ def addStore(self, kernel, ss, tc: str, addrCalc, sumIdx, tmpS01, edge, wsOffset
isGlc = True
isSlc = True
isNT = bool(kernel["NonTemporalD"] & 0x4)
isDlc = True
scope = CacheScope.SCOPE_DEV

bps = self.states.bpeCinternal * ss.cfg.gwvw
rpv = self.states.bpeCinternal * ss.cfg.gwvw / self.states.bpr
Expand Down Expand Up @@ -12582,36 +12588,36 @@ def addStore(self, kernel, ss, tc: str, addrCalc, sumIdx, tmpS01, edge, wsOffset
if self.states.asmCaps["HasWMMA_V1"] and kernel["EnableMatrixInstruction"]:
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, hi16=0, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, hi16=0, comment=comment))
else:
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx//2, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, hi16=sumIdx%2, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, hi16=sumIdx%2, comment=comment))
else:
# (B,B,B,B,S,S), internal S
# (H,H,H,H,H,H), internal S
# (H,H,H,H,S,S), internal S
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, hi16=0, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, hi16=0, comment=comment))
elif dataType.isInt32() or dataType.isSingle():
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
elif dataType.isDouble() or dataType.isSingleComplex():
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx*2, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
elif dataType.isDoubleComplex():
rps = dataType.numRegisters()
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx*rps, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
elif dataType.isInt8() or dataType.isAnyFloat8() or dataType.isAnyBFloat8() or dataType.isAnyFloat8BFloat8() or dataType.isAnyBFloat8Float8():
if kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
return module

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,6 @@ def assignDerivedParameters(
state["GlobalSplitUAlgorithm"] = "MultipleBuffer" # Set default Algorithm
if not state["EnableMatrixInstruction"]:
reject(state, printRejectionReason, "Stream-K requires MatrixInstruction")
if isaInfoMap[isa].asmCaps["HasWMMA"]:
reject(state, printRejectionReason, "Stream-K untested with WMMA")
# if state["PersistentKernel"]:
# reject(state, printRejectionReason, "Cannot enable both Stream-K and PersistentKernel")
if not state["ProblemType"]["StridedBatched"]:
Expand Down
Loading