Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
################################################################################

from rocisa.enum import CacheScope
from rocisa.code import Module, Label
from rocisa.container import vgpr, sgpr, SMEMModifiers, replaceHolder, EXEC,\
from rocisa.container import vgpr, sgpr, SMEMModifiers, MUBUFModifiers, replaceHolder, EXEC,\
VOP3PModifiers, ContinuousRegister
from rocisa.instruction import SAddCU32, SAddI32, SAddU32, SAndB32, SBarrier, \
SBranch, SCBranchSCC0, SCBranchSCC1, SCMovB32, SCSelectB32, SCmpEQU32, SCmpEQU64, \
SCmpGtU32, SCmpLeU32, SCmpLtU32, SLShiftLeftB32, SLShiftLeftB64, SLShiftRightB32, SLoadB32, \
SMinU32, SMovB32, SMovB64, SMulI32, SNop, SSleep, SStoreB32, SSubU32, \
SWaitCnt, VAddF32, VAddF64, VAddPKF16, VAddU32, VLShiftRightB32, VMovB32, \
VReadfirstlaneB32, VCvtBF16toFP32
VReadfirstlaneB32, VCvtBF16toFP32, BufferStoreB32
from rocisa.functions import scalarStaticDivideAndRemainder, sMagicDiv2, \
vectorStaticMultiply, BranchIfNotZero, scalarUInt32DivideAndRemainder

Expand Down Expand Up @@ -370,7 +371,7 @@ def storeBranchesCommon(self, writer, kernel, skPartialsLabel, vectorWidths, ele

# Check flag
module.add(SLShiftLeftB32(dst=sgpr(tmpSgpr), src=sgpr(sCtaIdx), shiftHex=log2(4), comment="flag offset based on CTA index"))
module.add(SLoadB32(dst=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="get flag"))
module.add(SLoadB32(dst=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True, dlc=True, scope=CacheScope.SCOPE_DEV), comment="get flag"))

module.add(SWaitCnt(kmcnt=0, comment="wait for flag load"))
if kernel["DebugStreamK"] & 2 == 0:
Expand All @@ -383,8 +384,12 @@ def storeBranchesCommon(self, writer, kernel, skPartialsLabel, vectorWidths, ele
module.add(VReadfirstlaneB32(dst=sgpr(tmpSgpr+2), src=vgpr("Serial"), comment="Wave 0 updates flags"))
module.add(SCmpEQU32(src0=sgpr(tmpSgpr+2), src1=0, comment="Check for wave 0"))
module.add(SCBranchSCC0(labelName=skipFlagReset.getLabelName(), comment="Skip flag reset"))
# (tmpSgpr+2) contains a value of 0, use it to reset the flag
module.add(SStoreB32(src=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="reset flag"))
if writer.states.asmCaps["HasScalarStore"]:
# (tmpSgpr+2) contains a value of 0, use it to reset the flag
module.add(SStoreB32(src=sgpr(tmpSgpr+2), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="reset flag"))
else:
module.add(VMovB32(dst=vgpr(tmpVgpr), src=0, comment="move 0 to tmpVgpr"))
module.add(self.setFlagValue(writer, src=vgpr(tmpVgpr), soffset=sgpr(tmpSgpr), comment="reset flag"))
module.add(skipFlagReset)
writer.sgprPool.checkIn(tmpSgpr)

Expand Down Expand Up @@ -684,8 +689,12 @@ def partialsWriteProcedure(self, writer, kernel, vectorWidths, elements, alpha,
module.add(VReadfirstlaneB32(dst=sgpr(flagSgpr), src=vgpr("Serial"), comment="Wave 0 updates flags"))
module.add(SCmpEQU32(src0=sgpr(flagSgpr), src1=0, comment="Check for wave 0"))
module.add(SCBranchSCC0(labelName=skipFlagSet.getLabelName(), comment="Skip flag set"))
module.add(SMovB32(dst=sgpr(flagSgpr), src=1, comment="flag data"))
module.add(SStoreB32(src=sgpr(flagSgpr), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="set flag"))
if writer.states.asmCaps["HasScalarStore"]:
module.add(SMovB32(dst=sgpr(flagSgpr), src=1, comment="flag data"))
module.add(SStoreB32(src=sgpr(flagSgpr), base=sgpr("AddressFlags", 2), soffset=sgpr(tmpSgpr), smem=SMEMModifiers(glc=True), comment="set flag"))
else:
module.add(VMovB32(dst=vgpr(tmpVgpr), src=1, comment="move 1 to tmpVgpr"))
module.add(self.setFlagValue(writer, src=vgpr(tmpVgpr), soffset=sgpr(tmpSgpr), comment="set flag"))
module.add(skipFlagSet)
module.add(SWaitCnt(kmcnt=0, comment="wait for flag")) # TODO just for testing

Expand All @@ -696,6 +705,20 @@ def partialsWriteProcedure(self, writer, kernel, vectorWidths, elements, alpha,

return module

def setFlagValue(self, writer, src, soffset, comment=""):
    """
    Build a module that writes `src` to the flags buffer with a buffer store.

    Four aligned SGPRs are checked out of writer.sgprPool and filled with
    sgpr "AddressFlags" (two registers), "BufferOOB" and "Srd127_96". That
    register quad is passed to BufferStoreB32 as saddr, with vaddr off and
    `soffset` as the scalar offset. The store is issued with glc, dlc and
    CacheScope.SCOPE_DEV modifiers and is followed by a wait on vscnt=0.
    The SGPRs are checked back in before returning.

    writer:  provides the sgprPool used for the temporary registers
    src:     register operand holding the value to store
    soffset: scalar offset operand forwarded to the store
    comment: comment attached to the store instruction

    Returns the Module containing the emitted instructions.
    """
    flagModule = Module("Buffer Store Flag Value")
    srdBase = writer.sgprPool.checkOutAligned(4, 4, preventOverflow=False)

    # Fill the four SGPRs that are handed to the store as saddr.
    flagModule.add(SMovB64(dst=sgpr(srdBase, 2), src=sgpr("AddressFlags", 2)))
    for slot, value in ((2, "BufferOOB"), (3, "Srd127_96")):
        flagModule.add(SMovB32(dst=sgpr(srdBase + slot), src=value))

    storeModifiers = MUBUFModifiers(glc=True, dlc=True, scope=CacheScope.SCOPE_DEV)
    flagModule.add(BufferStoreB32(
        src=src,
        vaddr=vgpr("off", isOff=True),
        saddr=sgpr(srdBase, 4),
        soffset=soffset,
        mubuf=storeModifiers,
        comment=comment))
    # TODO: check whether this wait is necessary
    flagModule.add(SWaitCnt(vscnt=0, comment="wait for data store"))

    writer.sgprPool.checkIn(srdBase)
    return flagModule

def partialsWriteBatch(self, writer, kernel, ss, batchIdx, applyAlpha, beta, edge, gwvw, atomicW, \
batchElements, addrD, addrC, \
tmpVgpr, cvtVgprStruct, batchElementSgprs, tmpSgpr, codeAccVgprRead):
Expand Down Expand Up @@ -1631,6 +1654,12 @@ def preLoop(self, writer, kernel):
xccMapping = Component.XCCMapping.find(writer)
module.add(xccMapping(writer, kernel))

# Workaround for gfx12
if writer.states.archCaps["WorkGroupIdFromTTM"]:
module.add(SMovB32(dst=sgpr("WorkGroup0"), src="ttmp9", comment="workaround"))
module.add(SAndB32(dst=sgpr("WorkGroup1"), src0=hex(0xFFFF), src1="ttmp7", comment="workaround"))
module.add(SLShiftRightB32(dst=sgpr("WorkGroup2"), shiftHex=hex(0x10), src="ttmp7", comment="workaround"))

module.add(SMovB32(dst=sgpr("StreamKIdx"), src=sgpr("WorkGroup0"), comment="Save original StreamK index"))
# Basic SK
module.add(SMulI32(dst=sgpr("StreamKIter"), src0=sgpr("StreamKIdx"), src1=sgpr("SKItersPerWG"), comment="StreamK starting iteration"))
Expand Down Expand Up @@ -1709,6 +1738,12 @@ def preLoop(self, writer, kernel):
xccMapping = Component.XCCMapping.find(writer)
module.add(xccMapping(writer, kernel))

# Workaround for gfx12
if writer.states.archCaps["WorkGroupIdFromTTM"]:
module.add(SMovB32(dst=sgpr("WorkGroup0"), src="ttmp9", comment="workaround"))
module.add(SAndB32(dst=sgpr("WorkGroup1"), src0=hex(0xFFFF), src1="ttmp7", comment="workaround"))
module.add(SLShiftRightB32(dst=sgpr("WorkGroup2"), shiftHex=hex(0x10), src="ttmp7", comment="workaround"))

module.add(SMovB32(dst=sgpr("StreamKIdx"), src=sgpr("WorkGroup0"), comment="Save original StreamK index"))
# Two-tile SK (SK first)
# iter count after all extra iters have been distributed
Expand Down Expand Up @@ -1825,6 +1860,12 @@ def preLoop(self, writer, kernel):
xccMapping = Component.XCCMapping.find(writer)
module.add(xccMapping(writer, kernel))

# Workaround for gfx12
if writer.states.archCaps["WorkGroupIdFromTTM"]:
module.add(SMovB32(dst=sgpr("WorkGroup0"), src="ttmp9", comment="workaround"))
module.add(SAndB32(dst=sgpr("WorkGroup1"), src0=hex(0xFFFF), src1="ttmp7", comment="workaround"))
module.add(SLShiftRightB32(dst=sgpr("WorkGroup2"), shiftHex=hex(0x10), src="ttmp7", comment="workaround"))

module.add(SMovB32(dst=sgpr("StreamKIdx"), src=sgpr("WorkGroup0"), comment="Save original StreamK index"))
# Two-tile SK (DP first)
# Do DP tiles before SK
Expand Down
32 changes: 19 additions & 13 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
scalarStaticRemainder, scalarUInt32DivideAndRemainder, sMagicDiv, vectorStaticMultiply, \
vectorStaticMultiplyAdd, scalarStaticMultiply64, BranchIfZero, BranchIfNotZero, DSInit, \
ArgumentLoader
from rocisa.enum import InstType, SelectBit
from rocisa.enum import InstType, SelectBit, CacheScope
from rocisa.macro import MacroVMagicDiv, PseudoRandomGenerator
from . import CUSTOM_KERNEL_PATH
from rocisa.instruction import BranchInstruction, BufferLoadB128, BufferLoadB32, \
Expand Down Expand Up @@ -12123,7 +12123,9 @@ def bufferLoadImpl(soffset):
##############################################################################
def chooseGlobalWrite(self, useBuffer, bps, srcVgpr, rpv, \
addr0, addr1, offset, soffset=0, \
glc=False, slc=False, nt=False, hi16=0, comment="store"):
glc=False, slc=False, nt=False, dlc=False, \
scope=CacheScope.SCOPE_NONE, \
hi16=0, comment="store"):
"""
create the store instruction for requested vector width and other parms
rpv = regs per vector
Expand Down Expand Up @@ -12166,15 +12168,15 @@ def bufferStoreImpl(tmpSgpr, mubuf):
if offset2 >= 4096:
module.add(SMovB32(dst=tmpSgpr, src=offset2, comment="large offset"))
offset2 = 0
mubuf2 = MUBUFModifiers(offen=True, offset12=offset2, glc=glc, slc=slc, nt=nt, isStore=True)
mubuf2 = MUBUFModifiers(offen=True, offset12=offset2, glc=glc, slc=slc, dlc=dlc, scope=scope, nt=nt, isStore=True)
vgprOff = int(srcVgpr + shiftRpv * i) if isinstance(srcVgpr, int) else f"{srcVgpr}+{int(shiftRpv * i)}"
module.add(BufferStoreB128(src=vgpr(vgprOff, shiftRpv), vaddr=addr0, \
saddr=addr1, soffset=tmpSgpr, mubuf=mubuf2, comment=comment))
else:
assert 0, "bad bps"

if useBuffer:
mubuf = MUBUFModifiers(offen=True, offset12=offset, glc=glc, slc=slc, nt=nt, isStore=True)
mubuf = MUBUFModifiers(offen=True, offset12=offset, glc=glc, slc=slc, dlc=dlc, scope=scope, nt=nt, isStore=True)
if soffset != 0:
assert offset < 4096, "sgpr offset provided with large const offset"
# buffer_load offset field is 12-bit.
Expand All @@ -12185,13 +12187,13 @@ def bufferStoreImpl(tmpSgpr, mubuf):
tmpSgpr = sgpr(tmpSgprInfo.idx)
if offset >= 4096:
module.add(SMovB32(dst=tmpSgpr, src=offset, comment="large offset"))
mubuf = MUBUFModifiers(offen=True, offset12=0, glc=glc, slc=slc, nt=nt, isStore=True)
mubuf = MUBUFModifiers(offen=True, offset12=0, glc=glc, slc=slc, dlc=dlc, scope=scope, nt=nt, isStore=True)
bufferStoreImpl(tmpSgpr, mubuf)
else:
bufferStoreImpl(soffset, mubuf)

else:
flat = FLATModifiers(glc=glc, slc=slc, isStore=True)
flat = FLATModifiers(glc=glc, slc=slc, dlc=dlc, scope=scope, isStore=True)
if bps==2 and hi16:
module.add(FlatStoreD16HIB16(vaddr=addr0, src=vgpr(srcVgpr*2), flat=flat, comment=comment))
elif bps==2 and not hi16:
Expand Down Expand Up @@ -12321,6 +12323,8 @@ def addStore(self, kernel, ss, tc: str, addrCalc, sumIdx, tmpS01, edge, wsOffset
isGlc = False
isSlc = False
isNT = False
scope = CacheScope.SCOPE_NONE
isDlc = False

if tc == 'D':
isGlc = bool(kernel["NonTemporalD"] & 0x1)
Expand Down Expand Up @@ -12361,6 +12365,8 @@ def addStore(self, kernel, ss, tc: str, addrCalc, sumIdx, tmpS01, edge, wsOffset
isGlc = True
isSlc = True
isNT = bool(kernel["NonTemporalD"] & 0x4)
isDlc = True
scope = CacheScope.SCOPE_DEV

bps = self.states.bpeCinternal * ss.cfg.gwvw
rpv = self.states.bpeCinternal * ss.cfg.gwvw / self.states.bpr
Expand Down Expand Up @@ -12410,36 +12416,36 @@ def addStore(self, kernel, ss, tc: str, addrCalc, sumIdx, tmpS01, edge, wsOffset
if self.states.asmCaps["HasWMMA_V1"] and kernel["EnableMatrixInstruction"]:
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, hi16=0, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, hi16=0, comment=comment))
else:
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx//2, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, hi16=sumIdx%2, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, hi16=sumIdx%2, comment=comment))
else:
# (B,B,B,B,S,S), internal S
# (H,H,H,H,H,H), internal S
# (H,H,H,H,S,S), internal S
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, hi16=0, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, hi16=0, comment=comment))
elif dataType.isInt32() or dataType.isSingle():
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
elif dataType.isDouble() or dataType.isSingleComplex():
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx*2, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
elif dataType.isDoubleComplex():
rps = dataType.numRegisters()
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx*rps, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
elif dataType.isInt8() or dataType.isAnyFloat8() or dataType.isAnyBFloat8() or dataType.isAnyFloat8BFloat8() or dataType.isAnyBFloat8Float8():
if kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(self.chooseGlobalWrite(useBuffer, bps, sumIdx, rpv, \
addr0, addr1, globalOffset, soffset=wsOffset, \
glc=isGlc, slc=isSlc, nt=isNT, comment=comment))
glc=isGlc, slc=isSlc, nt=isNT, dlc=isDlc, scope=scope, comment=comment))
return module

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -954,8 +954,8 @@ def assignDerivedParameters(
state["GlobalSplitUAlgorithm"] = "MultipleBuffer" # Set default Algorithm
if not state["EnableMatrixInstruction"]:
reject(state, printRejectionReason, "Stream-K requires MatrixInstruction")
if isaInfoMap[isa].asmCaps["HasWMMA"]:
reject(state, printRejectionReason, "Stream-K untested with WMMA")
# if isaInfoMap[isa].asmCaps["HasWMMA"]:
# reject(state, printRejectionReason, "Stream-K untested with WMMA")
# if state["PersistentKernel"]:
# reject(state, printRejectionReason, "Cannot enable both Stream-K and PersistentKernel")
if not state["ProblemType"]["StridedBatched"]:
Expand Down
Loading