From 9fd396c70497c8b051b9acdd1de1318fa1cb1a33 Mon Sep 17 00:00:00 2001 From: Aditya Date: Tue, 12 Aug 2025 17:53:31 -0500 Subject: [PATCH 1/2] Enable PLR when DepthU=MI_K Reduce clutter of whitespace changes enablePLR for even wave tile Scaled down local read for even tile tile sizes WIP --- .../Tensile/Components/LocalRead.py | 7 +- .../tensilelite/Tensile/Components/SIA.py | 8 +- .../tensilelite/Tensile/KernelWriter.py | 103 ++++++++++-- .../Tensile/KernelWriterAssembly.py | 27 ++- .../Tensile/SolutionStructs/Solution.py | 13 +- .../Tests/common/gemm/gfx950/f16_plr.yaml | 143 ++++++++++++++++ .../Tests/common/gemm/gfx950/f8_plr.yaml | 159 ++++++++++++++++++ 7 files changed, 437 insertions(+), 23 deletions(-) create mode 100644 projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml create mode 100644 projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py index aeaf18cabc0..a96fde2162b 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py @@ -208,6 +208,7 @@ def pack4HiBits(self, kernel, tct, index, bufferIdx, baseValuiIdx, iui, writer, """ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): imod = Module("LocalReadDo%s_I%s" % (tP["tensorChar"],iui)) + subTileIdx = writer.states.SubTileIdx tc = tP["tensorChar"] if tc == "A": @@ -227,6 +228,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): vectorWidth = kernel["VectorWidth%s"%tc] + numSubTiles = kernel["numSubTiles"] MIWaveGroupShape = [ kernel["MatrixInstM"] * kernel["MatrixInstBM"] * kernel["MIWaveGroup"][0] * kernel["VectorWidthA"], \ kernel["MatrixInstN"] * kernel["MatrixInstBN"] * kernel["MIWaveGroup"][1] * kernel["VectorWidthB"]] @@ -290,6 +292,9 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): maxLDSConstOffset = writer.states.regCaps["maxLDSConstOffset"] valufIdx = 0 + eIdxCnt = numReadsPerVector//numSubTiles + eIdxStart = subTileIdx * (numReadsPerVector//numSubTiles) + valufIdx = eIdxStart * blockWidth *numReadsPerUnroll if enableLDSTr: numberMTilesPerWave = kernel["MIWaveTile"][tile01] highBits = 0 @@ -312,7 +317,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment)) else: for vIdx in range(0, numVectorsPerTile): - for eIdx in range(0, numReadsPerVector): + for eIdx in range(eIdxStart, (eIdxStart + eIdxCnt)): valuiIdx = int(valufIdx) baseValuiIdx = valuiIdx localReadCode = imod.add(Module("LocalRead%s Valu%u"%(tc,valuiIdx))) diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py index 07ac8a5818f..063126f6347 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py @@ -33,7 +33,6 @@ from copy import deepcopy from typing import Tuple - PRECISION = 100 class SIA3(SIA): kernel = {"ScheduleIterAlg": 3} @@ -258,6 +257,13 @@ def calculateLatencyLeft(numReads, localReadBlockWidth, localReadLatency): # final index definition writer.states.numMfmaForNextLoopLR = min(writer.states.numMfmaForNextLoopLR,numMfmaPerIter-1) writer.states.syncPlrMfmaIndex = numMfmaPerIter*(kernel["LoopIters"]-writer.states.numItersPLR+1) - writer.states.numMfmaForNextLoopLR - 1 if writer.states.numItersPLR else 0 + + if kernel["ForceUnrollSubIter"]: + if ( kernel["UseF32XEmulation"]) : + writer.states.syncPlrMfmaIndex = writer.states.syncPlrMfmaIndex *3 # TF32 + elif ( kernel["ProblemType"]["DataType"].isComplex()): + writer.states.syncPlrMfmaIndex = writer.states.syncPlrMfmaIndex *4 # Complex + numMfmaBetweenLWandBarrier = 2 if kernel["MatrixInstM"] == 32 else 3 writer.states.lwEndMfmaIndex = max(writer.states.syncPlrMfmaIndex - numMfmaBetweenLWandBarrier,0) if writer.states.numItersPLR else numMfmaPerIter*kernel["LoopIters"] - 1 if kernel["DirectToLds"] and kernel["PrefetchGlobalRead"] == 2: diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 0a79e9a0eec..37a9583ad86 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -52,6 +52,7 @@ from Tensile.SolutionStructs.Naming import getKernelNameMin from Tensile.Toolchain.Component import Assembler +import math import abc import sys import collections @@ -177,7 +178,7 @@ class StateValues: lrvwUnrollMetadata: int = 0 # For Sparse Metadat numMfmaPerIter: int = 0 - + SubTileIdx: int = 0 numReadsIterCoalescedA: int = 0 numReadsIterCoalescedB: int = 0 numReadsIterCoalescedMetadata: int = 0 @@ -1309,6 +1310,11 @@ def calculateRangeAndUpdateCounter(itemCounter, writeCounters, length): skipLocalWriteWaitcnt += countLocalWrite(writeItem) + countDSStoreB256(writeItem) if not localReadItemsThisLoop: self.states.perIterLocalWriteCanSkip[iteration] += countLocalWrite(writeItem) + countDSStoreB256(writeItem) + if kernel["ForceUnrollSubIter"] and (writeItems and i == (numMfmaPerIter - 1)): + # if ForceUnrollSubIter, we need to schedule all localWrite in last mfma + while writeItems: + writeItem = writeItems.pop(0) + iterCode.add(writeItem) if mfmaIndex == self.states.lwEndMfmaIndex: while writeItems: localWriteCodeCounts.pop(0) @@ -2153,12 +2159,18 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is self.makeSchedule(kernel, tensorParametersA, tensorParametersB, localWriteEndIter, skipGlobalReadInc=False, lastLoop=NLLlast, isNGLL=isNGLL) module.add(self.codes.unrollLoopHeader) + if kernel["ForceUnrollSubIter"]: + self.states.lwStartMfmaIndex = (kernel["MIWaveTile"][1] + kernel["MIWaveTile"][0])//2 -1 + # which loop iteration to reset the LRO, # note if PLR=0, isResetLroIter is False for all u - isResetLroIter = (u == localWriteEndIter) + isResetLroIter = 1 if kernel["ForceUnrollSubIter"] else (u == localWriteEndIter) isSwapAndResetLwoIter = isResetLroIter isSwapLroIter = isResetLroIter if kernel["ScheduleIterAlg"] == 3: + if kernel["ForceUnrollSubIter"]: + isSwapAndResetLwoIter = 1 + else: isSwapAndResetLwoIter = (u == self.states.lwEndMfmaIndex//(self.states.numMfmaPerIter)) extraComment = "" @@ -2171,8 +2183,10 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is extraComment += " (swap and reset local write pointers iteration) " if isSwapLroIter: extraComment += " (swap local read pointers iteration) " - - module.addComment1("iter %u%s"%(u,extraComment)) + if kernel["ForceUnrollSubIter"]: + module.addComment1("subiter %u"%(u)) + else: + module.addComment1("iter %u%s"%(u,extraComment)) plrIdx = (u+pflr) % self.states.numVgprBuffer plrIdxDTV = (u+pflr) % kernel["LoopIters"] localReads = Module() @@ -2209,6 +2223,10 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is doReadA = (u < kernel["LoopIters"]/self.states.numIterPerCoalescedReadA - self.states.numItersPLR) doReadB = (u < kernel["LoopIters"]/self.states.numIterPerCoalescedReadB - self.states.numItersPLR) doReadM = (u < kernel["LoopIters"]/self.states.numIterPerCoalescedReadMetadata - self.states.numItersPLR) + if kernel["ForceUnrollSubIter"]: + doReadA = 1 if u == 0 else 0 + doReadB = 1 if u == 0 else 0 + doReadM = 1 if u == 0 else 0 # reads for next loop doReadA = doReadA or (hasLiveLdsData and u > localWriteEndIter) doReadB = doReadB or (hasLiveLdsData and u > localWriteEndIter) @@ -2228,6 +2246,9 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is if needNextBufLR: localReads.add(localReadCodeA) pack[plrIdx*self.states.numIterPerCoalescedReadA].add(packCodeA) + if kernel["ForceUnrollSubIter"]: + pack[1] = Module() + pack[1].add(packCodeA) if doReadM: localReads.addComment1("local read metadata") localReadCodeM, packCodeM = self.localReadDo(kernel, plrIdx*self.states.numIterPerCoalescedReadMetadata, iui*self.states.numReadsIterCoalescedMetadata, 0, tPM) @@ -2241,6 +2262,8 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is # DTV + pack or input conversion case, offset bufferIdx for local read packing instructions bufferIdx = plrIdxDTV*self.states.numIterPerCoalescedReadB + vregSetIdxLR * kernel["LoopIters"] localReadCodeB, packCodeB = self.localReadDo(kernel, bufferIdx, iui*self.states.numReadsIterCoalescedB, 0, tensorParametersB) + if kernel["ForceUnrollSubIter"]: + pack[1].add(packCodeB) if needNextBufLR: localReads.add(localReadCodeB) pack[plrIdx*self.states.numIterPerCoalescedReadB].add(packCodeB) @@ -2279,10 +2302,12 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is # Swap, reset, or increment the LRO: # force internalPointerSwap = False in NGLL case internalPointerSwap = expand and not isNGLL - pointerLRCode.addComment1("local read swap offsets a") - pointerLRCode.add(self.localReadSwapOffsets(kernel, internalPointerSwap, tensorParametersA)) - pointerLRCode.addComment1("local read swap offsets b") - pointerLRCode.add(self.localReadSwapOffsets(kernel, internalPointerSwap, tensorParametersB)) + if not kernel["ForceUnrollSubIter"] or (doReadA and (u localWriteEndIter) doReadB = doReadB or (hasLiveLdsData and u > localWriteEndIter) @@ -2645,6 +2689,9 @@ def _loopBody( self, kernel, tensorParametersA, tensorParametersB, pack, lc, loo if kernel["UseCustomMainLoopSchedule"]: LRCodeAAllIters[uIdx].add(localReadCodeA) PackCodeAAllIters[uIdx].add(packCodeA) + if kernel["ForceUnrollSubIter"]: + pack[1] = Module() + pack[1].add(packCodeA) if doReadM: localReads.addComment1("local read metadata") localReadCodeM, packCodeM = self.localReadDo(kernel, plrIdx*self.states.numIterPerCoalescedReadMetadata, iui*self.states.numReadsIterCoalescedMetadata, 0, tPM) @@ -2664,6 +2711,9 @@ def _loopBody( self, kernel, tensorParametersA, tensorParametersB, pack, lc, loo if kernel["UseCustomMainLoopSchedule"]: LRCodeBAllIters[uIdx].add(localReadCodeB) PackCodeBAllIters[uIdx].add(packCodeB) + if kernel["ForceUnrollSubIter"]: + pack[1].add(packCodeB) + # Don't increment the LRO if we are going to reset them below: if not isResetLroIter or iui != kernel["InnerUnroll"]-1: if doReadA: @@ -2712,8 +2762,9 @@ def _loopBody( self, kernel, tensorParametersA, tensorParametersB, pack, lc, loo if kernel["ExpertSchedulingMode"] > 0: pointerLRCode.add(SWaitAlu(vm_vsrc=0, comment="wait for local read to vgpr complete")) # Swap, reset, or increment the LRO: - pointerLRCode.addComment1("local read swap offsets a") - pointerLRCode.add(self.localReadSwapOffsets(kernel, expand, tensorParametersA)) + if not kernel["ForceUnrollSubIter"] or (doReadA and (u 1: + # iter (idxOuter_start, idxOuter_stop) (idxInner_start, idxInner_stop) MFMA + # 0 (0,4) (0,4) MFMA(A0,B0) + # 1 (0,4) (4,8) MFMA(A1,B0) + # 2 (4,8) (0,4) MFMA(A0,B1) + # 3 (4,8) (4,8) MFMA(A1,B1) + outerBy2=(kernel["MIWaveTile"][outer]//numSubTiles) + innerBy2=(kernel["MIWaveTile"][inner]//numSubTiles) + outerMod2=(kernel["MIWaveTile"][outer]%numSubTiles) + innerMod2=(kernel["MIWaveTile"][inner]%numSubTiles) + idxHalfO = u//numSubTiles + idxHalfI = u % numSubTiles + idxOuter_start = (outerBy2 + outerMod2)*idxHalfO + idxInner_start = (innerBy2 + innerMod2)*idxHalfI + idxOuter_stop = kernel["MIWaveTile"][outer] - (1-idxHalfO)* outerBy2 + idxInner_stop = kernel["MIWaveTile"][inner] - (1-idxHalfI)* innerBy2 + + for idxOuter in range(idxOuter_start, idxOuter_stop): + for idxInner in range(idxInner_start, idxInner_stop): idx0 = idxInner idx1 = idxOuter if loopSwap: diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index c4de5e4a6da..5d74b9cef97 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -1361,7 +1361,7 @@ def assignDerivedParameters( if (state["MIWaveTile"][1] % state["VectorWidthB"]) != 0: reject(state, printRejectionReason, "MIWaveTile0(%u) should be multiple of VectorWidthB(%u)" % (state["MIWaveTile"][1], state["VectorWidthB"])) return - + if len(problemType["IndicesSummation"]) > 1: # not supported with multiple summations, bug is maybe something with # how stagger iteration is wrapped when unroll loop exits @@ -1426,6 +1426,17 @@ def assignDerivedParameters( if "ValidDepthU" in state: del state["ValidDepthU"] + if ( + state["DepthU"] == state["MatrixInstK"] and state["PrefetchGlobalRead"] and not state["1LDSBuffer"] + and (state["MIWaveTile"][0] > 2 and state["MIWaveTile"][1] > 2) + and (state["MIWaveTile"][0] % 2 == 0 and state["MIWaveTile"][1] % 2 == 0) + ): + state["ForceUnrollSubIter"] = True + state["numSubTiles"] = 2 + else: + state["ForceUnrollSubIter"] = False + state["numSubTiles"] = 1 + # 0: Normal mode. Hardware applies all of the normal data dependency checks # 1: Full expert mode (not suppoeted yet). Disable hardware checks against: VA_VDST, VA_SDST, VA_SSRC, VA_VCC, VM_VSRC and SA_SDST. # 2: Disable only VA_VDST and VM_VSRC checks. diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml new file mode 100644 index 00000000000..2d1b6ed34f6 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml @@ -0,0 +1,143 @@ +GlobalParameters: + MergeFiles: False + NumElementsToValidate: -1 + + NumWarmups: 1000 + EnqueuesPerSync: 10000 + + NumBenchmarks: 1 + SyncsPerBenchmark: 1 + SleepPercent: 0 + DataInitTypeA: 12 + DataInitTypeB: 13 + # DataInitTypeC: 0 + # DataInitTypeD: 3 + DataInitTypeBeta: 0 + DataInitTypeAlpha: 1 + # DataInitTypeBias: 0 + DataInitTypeScaleAlphaVec: 1 + DataInitTypeScaleA: 1 + DataInitTypeScaleB: 1 + CSVExportWinner: 1 + CSVMergeSameProblemID: 1 + #Device: 0 + MinKForGSU: 1 + #MaxWorkspaceSize: 3355443200 + MaxFileName: 256 + KernelTime: True + #RotatingBufferSize: 512 + MaxLDS: 163840 + DeviceLDS: 163840 + #GenerateSourcesAndExit: True + + PrintSolutionRejectionReason: True + #Device: 3 + + RotatingBufferSize: 512 + KeepBuildTmp: True + + +BenchmarkProblems: + ######################################## + # NN - standard + ######################################## + - + - # ProblemType + OperationType: GEMM + #DataTypeA: f8 + #DataTypeB: h + #UseScaleAB: True + + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + + #UseBias: True + #Activation: True + #UseScaleAlphaVec: True + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 2 , 2,2 ] + + + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DepthU: [32] + - TransposeLDS: [1] + #- DirectToLds: [1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [8] + - LocalReadVectorWidth: [8] + #- GlobalSplitU: [3] #disable GSU + - SourceSwap: [1] + # - NonTemporalA: [4] + # - NonTemporalB: [0,1] + # - NonTemporalC: [3] + # - NonTemporalD: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [4096, 4096, 1, 16384] + #- Exact: [256, 256, 1, 384] + #- BiasTypeArgs: ['s'] + #- ActivationArgs: + # - [Enum: none] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml new file mode 100644 index 00000000000..6f654ada334 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml @@ -0,0 +1,159 @@ +GlobalParameters: + MergeFiles: False + NumElementsToValidate: -1 + + NumWarmups: 1000 + EnqueuesPerSync: 10000 + + NumBenchmarks: 1 + SyncsPerBenchmark: 1 + SleepPercent: 0 + DataInitTypeA: 12 + DataInitTypeB: 13 + # DataInitTypeC: 0 + # DataInitTypeD: 3 + DataInitTypeBeta: 0 + DataInitTypeAlpha: 1 + # DataInitTypeBias: 0 + DataInitTypeScaleAlphaVec: 1 + DataInitTypeScaleA: 1 + DataInitTypeScaleB: 1 + CSVExportWinner: 1 + CSVMergeSameProblemID: 1 + #Device: 0 + MinKForGSU: 1 + #MaxWorkspaceSize: 3355443200 + MaxFileName: 256 + KernelTime: True + #RotatingBufferSize: 512 + MaxLDS: 163840 + DeviceLDS: 163840 + #GenerateSourcesAndExit: True + + PrintSolutionRejectionReason: True + #Device: 3 + + RotatingBufferSize: 512 + KeepBuildTmp: True + + +BenchmarkProblems: + ######################################## + # NN - standard + ######################################## + - + - # ProblemType + OperationType: GEMM + #DataTypeA: f8 + #DataTypeB: h + #UseScaleAB: True + + DataType: f8 + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + + #UseBias: True + #Activation: True + #UseScaleAlphaVec: True + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 8, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 3 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 7, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 3 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 3 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 5, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 3 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 3 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 3, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 3 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 5 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 3 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 1, 1 , 2,2 ] + + - AssertSummationElementMultiple: [128] + - AssertFree0ElementMultiple: [16] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DepthU: [128] + - TransposeLDS: [1] + #- DirectToLds: [1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [16] + - LocalReadVectorWidth: [16] + #- GlobalSplitU: [3] #disable GSU + - SourceSwap: [1] + # - NonTemporalA: [4] + # - NonTemporalB: [0,1] + # - NonTemporalC: [3] + # - NonTemporalD: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [4096, 4096, 1, 16384] + #- Exact: [256, 256, 1, 384] + #- BiasTypeArgs: ['s'] + #- ActivationArgs: + # - [Enum: none] From e668710cc69abbe76841db760923c32260a5638a Mon Sep 17 00:00:00 2001 From: b-shi Date: Mon, 13 Oct 2025 13:48:00 -0500 Subject: [PATCH 2/2] Fix CI failures, add more tests, enable for ldsTR --- .../Tensile/Common/RequiredParameters.py | 1 + .../Tensile/Components/CustomSchedule.py | 40 +- .../Tensile/Components/LocalRead.py | 249 ++++--- .../tensilelite/Tensile/Components/SIA.py | 4 +- .../tensilelite/Tensile/KernelWriter.py | 56 +- .../Tensile/KernelWriterAssembly.py | 2 +- .../Tensile/SolutionStructs/Solution.py | 31 +- .../Tests/common/gemm/gfx950/f16_plr.yaml | 143 ---- .../Tests/common/gemm/gfx950/f8_plr.yaml | 159 ----- .../Tests/common/gemm/gfx950/plr_zero.yaml | 641 ++++++++++++++++++ .../Tests/common/gemm/gfx950/xfp32.yaml | 96 +++ 11 files changed, 954 insertions(+), 468 deletions(-) delete mode 100644 projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml delete mode 100644 projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml create mode 100644 projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/plr_zero.yaml diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py index 385ee4dd653..cd897ad31f0 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py @@ -60,6 +60,7 @@ def getRequiredParametersMin() -> set: 'LdsPadA', 'LdsPadB', 'LdsPadMetadata', + 'LDSTrInst', 'LocalReadVectorWidth', 'LocalWritePerMfma', 'MIArchVgpr', diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py b/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py index 767125bc531..dd404ca48f8 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py @@ -427,22 +427,24 @@ def hasCustomSchedule(kernel): optSchedule = dict() syncCode = [] + plr = 3 if kernel["ForceUnrollSubIter"] else 1 + if isTN and TLDS == 1: optSchedule = { - 'SYNC' : [[6,7, 20,21, 46,47, 61]], - 'GRIncA' : [[0,1,2,3,4,4,4,4,4]], - 'GRIncB' : [[5,5,5,5,5,6,6,6,6]], - 'LRA0' : [[0,0, 1,1, 2,2, 3,3]], - 'GRA' : [[8,8,9,9,10,10,11,11,12,12, 23,23,24,24,25,25]], - 'LRB0' : [[13,13,14,14,15,15,16,16]], - 'LRA1' : [[48,48,49,49,50,50,51,51]], - 'LRB1' : [[52,52,54,54,55,55,56,56]], - 'GRB' : [[26,26,27,27, 39,39,40,40,41,41,42,42,43,43, 53,53]], - 'LCC' : [[60, 60]], - 'LRSA' : [[17]], - 'LRSB' : [[17]], - 'LWSA' : [[57]], - 'LWSB' : [[57]], + 'SYNC' : [[6,7, 20,21, 46,47, 61]], + 'GRIncA' : [[0,1,2,3,4,4,4,4,4]], + 'GRIncB' : [[5,5,5,5,5,6,6,6,6]], + 'LRA0' : [[0,0, 1,1, 2,2, 3,3]], + 'GRA' : [[8,8,9,9,10,10,11,11,12,12, 23,23,24,24,25,25]], + 'LRB0' : [[13,13,14,14,15,15,16,16]], + 'LRA%u'%plr : [[48,48,49,49,50,50,51,51]], + 'LRB%u'%plr : [[52,52,54,54,55,55,56,56]], + 'GRB' : [[26,26,27,27, 39,39,40,40,41,41,42,42,43,43, 53,53]], + 'LCC' : [[60, 60]], + 'LRSA' : [[17]], + 'LRSB' : [[17]], + 'LWSA' : [[57]], + 'LWSB' : [[57]], } syncCode = [SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for LRA0/LRB0 to complete"), SBarrier(comment=""), @@ -456,10 +458,12 @@ def hasCustomSchedule(kernel): numMfma = 64 # B0A0, B0A1, B1A0, B1A1 - mfmaReorder = [0,1,2,3, 8,9,10,11, 16,17,18,19, 24,25,26,27, - 4,5,6,7, 12,13,14,15, 20,21,22,23, 28,29,30,31, - 32,33,34,35, 40,41,42,43, 48,49,50,51, 56,57,58,59, - 36,37,38,39, 44,45,46,47, 52,53,54,55, 60,61,62,63] + mfmaReorder = [] + if not kernel["ForceUnrollSubIter"]: + mfmaReorder = [0,1,2,3, 8,9,10,11, 16,17,18,19, 24,25,26,27, + 4,5,6,7, 12,13,14,15, 20,21,22,23, 28,29,30,31, + 32,33,34,35, 40,41,42,43, 48,49,50,51, 56,57,58,59, + 36,37,38,39, 44,45,46,47, 52,53,54,55, 60,61,62,63] opt1 = ScheduleInfo(1, numMfma, optSchedule, syncCode, mfmaReorder) return True, opt1 elif is192x256x64DTL and is16bit and not isMixed and ([GRVWA, GRVWB, LRVW] == [8, 8, 8]) and MI == [16,16,32,1] and MIWG == [2,2]: diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py index a96fde2162b..5adfc17db18 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py @@ -291,33 +291,39 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): blockOffsetSMFMA = blockStride - elementsPerBlockSMFMA maxLDSConstOffset = writer.states.regCaps["maxLDSConstOffset"] + + subIterLoadCount = 0 valufIdx = 0 - eIdxCnt = numReadsPerVector//numSubTiles - eIdxStart = subTileIdx * (numReadsPerVector//numSubTiles) - valufIdx = eIdxStart * blockWidth *numReadsPerUnroll if enableLDSTr: numberMTilesPerWave = kernel["MIWaveTile"][tile01] + numOffsetsPerLoad = 2 highBits = 0 + totalLoads = numberMTilesPerWave * numOffsetsPerLoad for tIdx in range(0, numberMTilesPerWave): + valuiIdx = int(valufIdx) + comment = "LDS Transpose" + LocalReadX = instruction.getInst(highBits) + offset_val = (tP["localReadOffset"]+MIWaveGroupShape[tile01]*tIdx) * tP["bpeDS"] + tP["localReadSwapByteOffset"] if (kernel["LdsBlockSizePerPad%s"%tc] != 0) and (kernel["LdsPad%s"%tc] != 0): offset_val = offset_val + (offset_val // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeDS"] - offset, srcAddr = self.cal_offset_srcAddr(maxLDSConstOffset, tc, offset_val) - ds = DSModifiers(na=1, offset=offset) - LocalReadX = instruction.getInst(highBits) - destVgpr = vgpr("Valu%s_X%u_I%u+%u+0"%(tc,bufferIdx,iui, 4*tIdx), 2) - comment = "LDS Transpose" - valuiIdx = int(valufIdx) - localReadCode = imod.add(Module("LocalRead%s Valu%u"%(tc,valuiIdx))) - localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment)) - destVgpr = vgpr("Valu%s_X%u_I%u+%u+2"%(tc,bufferIdx,iui,4*tIdx), 2) - offset_val += UnrollStride*inputPerThread; - offset, srcAddr = self.cal_offset_srcAddr(maxLDSConstOffset, tc, offset_val) - ds = DSModifiers(na=1, offset=offset) - localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment)) + + for oIdx in range(0,numOffsetsPerLoad): + offset, srcAddr = self.cal_offset_srcAddr(maxLDSConstOffset, tc, offset_val) + ds = DSModifiers(na=1, offset=offset) + destVgpr = vgpr("Valu%s_X%u_I%u+%u+%u"%(tc,bufferIdx,iui, 4*tIdx, oIdx * 2), 2) + localReadCode = Module("LocalRead%s Valu%u"%(tc,valuiIdx)) + localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment)) + offset_val += UnrollStride*inputPerThread + if ((subTileIdx == 0 and subIterLoadCount < totalLoads // numSubTiles) \ + or (subTileIdx == 1 and subIterLoadCount >= totalLoads // numSubTiles) \ + or numSubTiles == 1) or writer.states.inTailLoop: + imod.add(localReadCode) + subIterLoadCount += 1 else: + totalLoads = numVectorsPerTile * numReadsPerVector * numReadsPerUnroll for vIdx in range(0, numVectorsPerTile): - for eIdx in range(eIdxStart, (eIdxStart + eIdxCnt)): + for eIdx in range(0, numReadsPerVector): valuiIdx = int(valufIdx) baseValuiIdx = valuiIdx localReadCode = imod.add(Module("LocalRead%s Valu%u"%(tc,valuiIdx))) @@ -332,6 +338,9 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): highBitsForHalf = (blockWidth == 0.5) and ((rIdx % 2) == 1) # rIdx = 1 isHigh16Bits = (blockWidth == 0.25) and ( ((rIdx % 4) //2) == 1) # 2,3 + packCodeT = Module() # Allocate temporary module for pack code + localReadCodeT = Module() + if needPack or numSplitMetadata: if kernel["UseF32XEmulation"]: # Pack data 0-7 with layout: @@ -346,8 +355,8 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): # For every 8 read vgprs of fp32, pack high bits of bf16 into first 4 vgprs if valuiIdx % 8 == 0: - self.pack4HiBits(kernel, tc, 0, bufferIdx, baseValuiIdx, iui, writer, packCode, tmpvgprFP32) - self.pack4HiBits(kernel, tc, 4, bufferIdx, baseValuiIdx, iui, writer, packCode, tmpvgprFP32) + self.pack4HiBits(kernel, tc, 0, bufferIdx, baseValuiIdx, iui, writer, packCodeT, tmpvgprFP32) + self.pack4HiBits(kernel, tc, 4, bufferIdx, baseValuiIdx, iui, writer, packCodeT, tmpvgprFP32) if valuiIdx % 4 == 0: tmpvgpr = [] tmp = writer.vgprPool.checkOut(1) @@ -386,27 +395,27 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): # Compute low bits = fp32(highBF16(A/B)) - fp32(A/B) if kernel["UseDot2F32XEmulation"]: - packCode.add(VDot2CF32BF16(dst=v0t, src0=hex(0x8000bf80), src1=vHi0)) - packCode.add(VDot2CF32BF16(dst=v1t, src0=hex(0xbf800000), src1=vHi0)) + packCodeT.add(VDot2CF32BF16(dst=v0t, src0=hex(0x8000bf80), src1=vHi0)) + packCodeT.add(VDot2CF32BF16(dst=v1t, src0=hex(0xbf800000), src1=vHi0)) else: - packCode.add(PVCvtBF16toFP32(dst=vgpr(tmp), src=vHi0, comment="begin"+str(valuiIdx))) - packCode.add(VSubF32(dst=v0t, src0=v0t, src1=vgpr(tmp))) - packCode.add(VCvtBF16toFP32(dst=vgpr(tmp), src=vHi0, vgprMask=None, vi=1)) - packCode.add(VSubF32(dst=v1t, src0=v1t, src1=vgpr(tmp))) + packCodeT.add(PVCvtBF16toFP32(dst=vgpr(tmp), src=vHi0, comment="begin"+str(valuiIdx))) + packCodeT.add(VSubF32(dst=v0t, src0=v0t, src1=vgpr(tmp))) + packCodeT.add(VCvtBF16toFP32(dst=vgpr(tmp), src=vHi0, vgprMask=None, vi=1)) + packCodeT.add(VSubF32(dst=v1t, src0=v1t, src1=vgpr(tmp))) if kernel["UseDot2F32XEmulation"]: - packCode.add(VDot2CF32BF16(dst=v2t, src0=hex(0x8000bf80), src1=vHi1)) + packCodeT.add(VDot2CF32BF16(dst=v2t, src0=hex(0x8000bf80), src1=vHi1)) else: - packCode.add(PVCvtBF16toFP32(dst=vgpr(tmp), src=vHi1)) - packCode.add(VSubF32(dst=v2t, src0=v2t, src1=vgpr(tmp))) + packCodeT.add(PVCvtBF16toFP32(dst=vgpr(tmp), src=vHi1)) + packCodeT.add(VSubF32(dst=v2t, src0=v2t, src1=vgpr(tmp))) # We use cvt+sub pair since dot2 requires adding 4 wait states. - packCode.add(VCvtBF16toFP32(dst=vgpr(tmp), src=vHi1, vgprMask=None, vi=1)) - packCode.add(VSubF32(dst=v3t, src0=v3t, src1=vgpr(tmp), comment="end")) + packCodeT.add(VCvtBF16toFP32(dst=vgpr(tmp), src=vHi1, vgprMask=None, vi=1)) + packCodeT.add(VSubF32(dst=v3t, src0=v3t, src1=vgpr(tmp), comment="end")) if kernel["UseDot2F32XEmulation"]: - packCode.add(VMovB32(dst=vgpr(tmp), src=0)) - packCode.add(VMovB32(dst=vgpr(tmp), src=0)) + packCodeT.add(VMovB32(dst=vgpr(tmp), src=0)) + packCodeT.add(VMovB32(dst=vgpr(tmp), src=0)) for val in tmpvgpr: writer.vgprPool.checkIn(val) @@ -430,11 +439,11 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): v5 = vgpr("Valu%s_X%u_I%u+%u+5"%(tc, bufferIdx, iui, baseValuiIdx)) v6 = vgpr("Valu%s_X%u_I%u+%u+6"%(tc, bufferIdx, iui, baseValuiIdx)) v7 = vgpr("Valu%s_X%u_I%u+%u+7"%(tc, bufferIdx, iui, baseValuiIdx)) - packCode.add(VCvtPkF32toBF16(dst=v7, src0=v6, src1=v7, comment="pack tail begin")) - packCode.add(VCvtPkF32toBF16(dst=v6, src0=v4, src1=v5)) - packCode.add(VCvtPkF32toBF16(dst=v5, src0=v2, src1=v3)) + packCodeT.add(VCvtPkF32toBF16(dst=v7, src0=v6, src1=v7, comment="pack tail begin")) + packCodeT.add(VCvtPkF32toBF16(dst=v6, src0=v4, src1=v5)) + packCodeT.add(VCvtPkF32toBF16(dst=v5, src0=v2, src1=v3)) commentStr ="__TF32_2_" + tc + " pack tail end" - packCode.add(VCvtPkF32toBF16(dst=v4, src0=v0, src1=v1, comment=commentStr)) + packCodeT.add(VCvtPkF32toBF16(dst=v4, src0=v0, src1=v1, comment=commentStr)) if kernel["UseDirect32XEmulation"]: index = len(tmpvgprFP32) - 1 while index >= 0: @@ -477,24 +486,24 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): for i in range(0, cvtTimes): offset = cvtTimes - i - 1 if writer.states.asmCaps["Hascvtf16_fp8"]: - packCode.add(VCvtScalePkFP8toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+1+offset*2)),\ + packCodeT.add(VCvtScalePkFP8toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+1+offset*2)),\ src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+offset)), scale=0x3f800000,\ vop3=VOP3PModifiers(op_sel=[1,0,0,0]), comment="convert fp8 to f16")) - packCode.add(VCvtScalePkFP8toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+0+offset*2)),\ + packCodeT.add(VCvtScalePkFP8toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+0+offset*2)),\ src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+offset)), scale=0x3f800000,\ vop3=VOP3PModifiers(op_sel=[0,0,0,0]), comment="convert fp8 to f16")) else: - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+offset)),\ + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+offset)),\ sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_1), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+1+offset*2)), src=vgpr("CvtTemp+0"),\ + packCodeT.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+1+offset*2)), src=vgpr("CvtTemp+0"),\ sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+1+offset*2)), src=vgpr("CvtTemp+1"),\ + packCodeT.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+1+offset*2)), src=vgpr("CvtTemp+1"),\ sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+offset)),\ + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+offset)),\ sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+0+offset*2)), src=vgpr("CvtTemp+0"),\ + packCodeT.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+0+offset*2)), src=vgpr("CvtTemp+0"),\ sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+0+offset*2)), src=vgpr("CvtTemp+1"),\ + packCodeT.add(VCvtF32toF16(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valuiIdx+0+offset*2)), src=vgpr("CvtTemp+1"),\ sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) #Case B elif (writer.states.lrvwTileA == 1 and tc == 'A') or (writer.states.lrvwTileB == 1 and tc == 'B'): @@ -505,10 +514,10 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): destVgpr = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%4, valuiIdx), numVgpr) if writer.states.asmCaps["Hascvtf16_fp8"]: sel = [0,0,0,0] if (rIdx % 2 == 0) else [0,0,1,0] - packCode.add(VCvtScaleFP8toF16(dst=CvtDstVgpr, src=destVgpr, scale=0x3f800000, vop3=VOP3PModifiers(op_sel=sel), comment="convert fp8 to f16")) + packCodeT.add(VCvtScaleFP8toF16(dst=CvtDstVgpr, src=destVgpr, scale=0x3f800000, vop3=VOP3PModifiers(op_sel=sel), comment="convert fp8 to f16")) else: - packCode.add(VCvtFP8toF32(dst=destVgpr, src=destVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.BYTE_0))) - packCode.add(VCvtF32toF16(dst=CvtDstVgpr, src=destVgpr, sdwa=sdwa, comment="Convert to FP16")) + packCodeT.add(VCvtFP8toF32(dst=destVgpr, src=destVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.BYTE_0))) + packCodeT.add(VCvtF32toF16(dst=CvtDstVgpr, src=destVgpr, sdwa=sdwa, comment="Convert to FP16")) #Case C elif (writer.states.lrvwTileA == 2 and tc == 'A') or (writer.states.lrvwTileB == 2 and tc == 'B'): if needPack or numSplitMetadata: @@ -516,11 +525,11 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): for i in range(0, numVgpr): cvtDstVgpr = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), vIdx*numVgpr), numVgpr) if writer.states.asmCaps["Hascvtf16_fp8"]: - packCode.add(VCvtScalePkFP8toF16(dst=destVgpr, src=destVgpr,scale=0x3f800000,vop3=VOP3PModifiers(op_sel=[0,0,0,0]),comment="convert F8 to F16")) + packCodeT.add(VCvtScalePkFP8toF16(dst=destVgpr, src=destVgpr,scale=0x3f800000,vop3=VOP3PModifiers(op_sel=[0,0,0,0]),comment="convert F8 to F16")) else: - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=destVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=destVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) + packCodeT.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) + packCodeT.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) if rIdx == numReadsPerUnroll-1: for i in range(0, numVgpr): @@ -528,7 +537,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): vgprOffset = 0 for vectorIdx in range(0, 2): for elementIdx in range(0, tP["bpe"]*kernel["MIInputPerThread%s"%tc]//writer.states.bpr): - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), \ src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*2+1, i+vIdx*numVgpr)), \ src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*2, i+vIdx*numVgpr)), \ src2=sgpr("PackKForV%u"%vectorIdx), \ @@ -539,17 +548,17 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if needPack or numSplitMetadata: destVgpr = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), 2 * vIdx * numVgpr), numVgpr) cvtDestVgpr = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), 2 * vIdx * numVgpr + 1), numVgpr) - packCode.add(VLShiftRightB32(dst=cvtDestVgpr, shiftHex=16, src=destVgpr, comment="shift 2 element to vgpr+1")) + packCodeT.add(VLShiftRightB32(dst=cvtDestVgpr, shiftHex=16, src=destVgpr, comment="shift 2 element to vgpr+1")) if writer.states.asmCaps["Hascvtf16_fp8"]: - packCode.add(VCvtScalePkFP8toF16(dst=destVgpr, src=destVgpr,scale=0x3f800000, comment="convert F8 to F16")) - packCode.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr, src=cvtDestVgpr,scale=0x3f800000, comment="convert F8 to F16")) + packCodeT.add(VCvtScalePkFP8toF16(dst=destVgpr, src=destVgpr,scale=0x3f800000, comment="convert F8 to F16")) + packCodeT.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr, src=cvtDestVgpr,scale=0x3f800000, comment="convert F8 to F16")) else: - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=destVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=destVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) + packCodeT.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) + packCodeT.add(VCvtF32toF16(dst=destVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) if rIdx == numReadsPerUnroll-1: for i in range(0, numVgpr*2): @@ -557,7 +566,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): vgprOffset = 0 for vectorIdx in range(0, 2): for elementIdx in range(0, tP["bpe"]*kernel["MIInputPerThread%s"%tc]//writer.states.bpr): - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), \ src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*2+1, i+2*vIdx*numVgpr)), \ src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*2, i+2*vIdx*numVgpr)), \ src2=sgpr("PackKForV%u"%vectorIdx), \ @@ -571,28 +580,28 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): cvtDestVgpr1 = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), 2*vIdx*numVgpr+1), 1) cvtDestVgpr2 = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), 2*vIdx*numVgpr+2), 1) cvtDestVgpr3 = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%(kernel["MIInputPerThread%s"%tc]), 2*vIdx*numVgpr+3), 1) - packCode.add(VLShiftRightB32(dst=cvtDestVgpr3, shiftHex=16, src=cvtDestVgpr1, comment="shift 2 element to vgpr+3")) + packCodeT.add(VLShiftRightB32(dst=cvtDestVgpr3, shiftHex=16, src=cvtDestVgpr1, comment="shift 2 element to vgpr+3")) - packCode.add(VMovB32(dst=cvtDestVgpr2, src=cvtDestVgpr1)) - packCode.add(VLShiftRightB32(dst=cvtDestVgpr1, shiftHex=16, src=cvtDestVgpr0, comment="shift 2 element to vgpr+1")) + packCodeT.add(VMovB32(dst=cvtDestVgpr2, src=cvtDestVgpr1)) + packCodeT.add(VLShiftRightB32(dst=cvtDestVgpr1, shiftHex=16, src=cvtDestVgpr0, comment="shift 2 element to vgpr+1")) if writer.states.asmCaps["Hascvtf16_fp8"]: - packCode.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr0, src=cvtDestVgpr0,scale=0x3f800000, comment="convert F8 to F16")) - packCode.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr1, src=cvtDestVgpr1,scale=0x3f800000, comment="convert F8 to F16")) - packCode.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr2, src=cvtDestVgpr2,scale=0x3f800000, comment="convert F8 to F16")) - packCode.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr3, src=cvtDestVgpr3,scale=0x3f800000, comment="convert F8 to F16")) + packCodeT.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr0, src=cvtDestVgpr0,scale=0x3f800000, comment="convert F8 to F16")) + packCodeT.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr1, src=cvtDestVgpr1,scale=0x3f800000, comment="convert F8 to F16")) + packCodeT.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr2, src=cvtDestVgpr2,scale=0x3f800000, comment="convert F8 to F16")) + packCodeT.add(VCvtScalePkFP8toF16(dst=cvtDestVgpr3, src=cvtDestVgpr3,scale=0x3f800000, comment="convert F8 to F16")) else: - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr0, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr0, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr0, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr1, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr1, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr1, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr2, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr2, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr2, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) - packCode.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr3, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr3, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) - packCode.add(VCvtF32toF16(dst=cvtDestVgpr3, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr0, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr0, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr0, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr1, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr1, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr1, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr2, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr2, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr2, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) + packCodeT.add(VCvtPkFP8toF32(dst=vgpr("CvtTemp", 2), src=cvtDestVgpr3, sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr3, src=vgpr("CvtTemp+0"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_0), comment="Convert to FP16")) + packCodeT.add(VCvtF32toF16(dst=cvtDestVgpr3, src=vgpr("CvtTemp+1"), sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1), comment="Convert to FP16")) if rIdx == numReadsPerUnroll-1: for i in range(0, numVgpr*2): @@ -600,7 +609,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): vgprOffset = 0 for vectorIdx in range(0, 2): for elementIdx in range(0, tP["bpe"]*kernel["MIInputPerThread%s"%tc]//writer.states.bpr): - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), \ src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*2+1, i+2*vIdx*numVgpr)), \ src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*2, i+2*vIdx*numVgpr)), \ src2=sgpr("PackKForV%u"%vectorIdx), \ @@ -638,11 +647,11 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if elementIdx >= writer.states.bpr: break # since the number of input thread is 4, so will alwasy be D0, D1, D2, D3 - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 0, i+vIdx*numVgpr)), src2=sgpr("PackKFor%sV%u"%(tPackM, vgprOffset)), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 0, i+vIdx*numVgpr)), src2=sgpr("PackKFor%sV%u"%(tPackM, vgprOffset)), \ comment="1 select K=%u%u for vector=%u"%(0, 1, vgprOffset))) - packCode.add(VPermB32(dst=vgpr("PackTemp"), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 3, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 2, i+vIdx*numVgpr)), src2=sgpr("PackKFor%sV%u"%(tPackM, vgprOffset)), \ + packCodeT.add(VPermB32(dst=vgpr("PackTemp"), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 3, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 2, i+vIdx*numVgpr)), src2=sgpr("PackKFor%sV%u"%(tPackM, vgprOffset)), \ comment="1 select K=%u%u for vector=%u"%(2, 3, vgprOffset))) - packCode.add(VLShiftLeftOrB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src0=vgpr("PackTemp"), shiftHex=16, src1=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), comment="pack two half Vgpr to one Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src0=vgpr("PackTemp"), shiftHex=16, src1=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), comment="pack two half Vgpr to one Vgpr")) vgprOffset += 1 elif kernel["MIInputPerThread%s"%tc] == 2: vgprOffset = 0 @@ -650,7 +659,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if elementIdx >= writer.states.bpr: break # since the number of input thread is 2, so will alwasy be D0 and D1 - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 0, i+vIdx*numVgpr)), src2=sgpr("PackKFor%sV%u"%(tPackM, vgprOffset)), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, 0, i+vIdx*numVgpr)), src2=sgpr("PackKFor%sV%u"%(tPackM, vgprOffset)), \ comment="select K=%u%u for vector=%u"%(0, 1, vgprOffset))) vgprOffset += 1 elif kernel["MIInputPerThread%s"%tc] == 1: @@ -661,23 +670,23 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if elementIdx >= writer.states.bpr: break comment_ = "another VGPR storing lshr %d-bit value %d %d" %(bitShift, vgprIdx, elementIdx) if bitShift != 0 else "" - packCode.add(VMovB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src=destVgpr_, comment=comment_)) + packCodeT.add(VMovB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), src=destVgpr_, comment=comment_)) if bitShift != 0: - packCode.add(VLShiftRightB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), shiftHex=hex(bitShift), src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), comment="ValuMetadata Vpgr >> %d" % bitShift)) + packCodeT.add(VLShiftRightB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), shiftHex=hex(bitShift), src=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx)), comment="ValuMetadata Vpgr >> %d" % bitShift)) bitShift += 8 else: assert False elif tP["isM"]: vgprOffset = 0 for elementIdx in range(0, kernel["MIInputPerThread%s"%tc]): - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx+vIdx*2)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, vgprOffset*2 + 1 , i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, vgprOffset*2, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%elementIdx), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+elementIdx+vIdx*2)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, vgprOffset*2 + 1 , i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, vgprOffset*2, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%elementIdx), \ comment="select K=%u%u for vector=%u"%(vgprOffset*2+1, vgprOffset*2, elementIdx))) vgprOffset += (1 if elementIdx % 2 == 1 else 0) elif kernel["ProblemType"]["DataType"].isHalf() or kernel["MFMA_BF16_1K"] or kernel["ProblemType"]["DataType"].isBFloat16(): vgprOffset = 0 for vectorIdx in range(0, numElementPerReg): for elementIdx in range(0, tP["bpe"]*kernel["MIInputPerThread%s"%tc]//writer.states.bpr): - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%vectorIdx), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%vectorIdx), \ comment="select K=%u%u for vector=%u"%(elementIdx*numElementPerReg, elementIdx*numElementPerReg+1, vectorIdx))) vgprOffset += 1 elif kernel["ProblemType"]["DataType"].isInt8() or kernel["ProblemType"]["DataType"].is8bitFloat(): @@ -687,11 +696,11 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if vectorWidth <= 2 and vectorIdx > 1: break for elementIdx in range(0, tP["bpe"]*kernel["MIInputPerThread%s"%tc]//writer.states.bpr): - packCode.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%vectorIdx), \ + packCodeT.add(VPermB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx+vgprOffset)), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+1, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%vectorIdx), \ comment="select K=%u%u for vector=%u"%(elementIdx*4, elementIdx*4+1, vectorIdx))) - packCode.add(VPermB32(dst=vgpr("PackTemp"), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+3, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+2, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%vectorIdx), \ + packCodeT.add(VPermB32(dst=vgpr("PackTemp"), src0=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+3, i+vIdx*numVgpr)), src1=vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, elementIdx*numElementPerReg+2, i+vIdx*numVgpr)), src2=sgpr("PackKForV%u"%vectorIdx), \ comment="select K=%u%u for vector=%u"%(elementIdx*4+2, elementIdx*4+3, vectorIdx))) - packCode.add(VLShiftLeftOrB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx + vgprOffset)), src0=vgpr("PackTemp"), shiftHex=16, src1=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx + vgprOffset)), comment="pack two half Vgpr to one Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx + vgprOffset)), src0=vgpr("PackTemp"), shiftHex=16, src1=vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, vgprIdx + vgprOffset)), comment="pack two half Vgpr to one Vgpr")) vgprOffset += 1 else: @@ -704,13 +713,13 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): dstVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valufIdx/2), numVgpr) lowVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valufIdx - 1), numVgpr) highVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valufIdx), numVgpr) - packCode.add(VLShiftLeftOrB32(dst=dstVgpr, src0=highVgpr, shiftHex=8, src1=lowVgpr, comment="pack two int8 Vgpr to one half Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=dstVgpr, src0=highVgpr, shiftHex=8, src1=lowVgpr, comment="pack two int8 Vgpr to one half Vgpr")) if isHigh16Bits: # every 4 metadatas will be packed into one vgpr, so divide 4 to let dstVgrp be 0,1,2,... dstVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valufIdx/4), numVgpr) lowVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valufIdx/2 - 1), numVgpr) highVgpr = vgpr("Valu%s_X%u_I%u+%u"%(tc, bufferIdx, iui, valufIdx/2), numVgpr) - packCode.add(VLShiftLeftOrB32(dst=dstVgpr, src0=highVgpr, shiftHex=hex(0x10), src1=lowVgpr, comment="pack two int8x2 Vgpr to one Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=dstVgpr, src0=highVgpr, shiftHex=hex(0x10), src1=lowVgpr, comment="pack two int8x2 Vgpr to one Vgpr")) # Metadata only use one vgpr in current SMFMA instructions, so doesn't need these two flags at localread (gfx94x, gfx95x). isHigh16Bits = False isHigh8Bits = False @@ -718,9 +727,9 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if highBitsForHalf: highVgpr = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%2, valuiIdx), numVgpr) if writer.states.archCaps["DSLow16NotPreserve"]: - packCode.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=highVgpr, shiftHex=hex(0x10), src1=baseLRVgpr, comment="pack two half Vgpr to one Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=highVgpr, shiftHex=hex(0x10), src1=baseLRVgpr, comment="pack two half Vgpr to one Vgpr")) else: - packCode.add(VOrB32(dst=baseLRVgpr, src0=baseLRVgpr, src1=highVgpr, comment="pack two half Vgpr to one Vgpr")) + packCodeT.add(VOrB32(dst=baseLRVgpr, src0=baseLRVgpr, src1=highVgpr, comment="pack two half Vgpr to one Vgpr")) destVgpr = highVgpr # pack for blockWidth 0.25 type if rIdx != 0: @@ -729,12 +738,12 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): destVgpr = highVgpr if isHigh8Bits: lowVgpr = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, (rIdx%4)-1, valuiIdx), numVgpr) if isHigh16Bits else baseLRVgpr - packCode.add(VLShiftLeftOrB32(dst=lowVgpr, src0=highVgpr, shiftHex=8, src1=lowVgpr, comment="pack two int8 Vgpr to one half Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=lowVgpr, src0=highVgpr, shiftHex=8, src1=lowVgpr, comment="pack two int8 Vgpr to one half Vgpr")) if isHigh16Bits: if writer.states.archCaps["DSLow16NotPreserve"]: - packCode.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=lowVgpr, shiftHex=hex(0x10), src1=baseLRVgpr, comment="pack two half Vgpr to one Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=lowVgpr, shiftHex=hex(0x10), src1=baseLRVgpr, comment="pack two half Vgpr to one Vgpr")) else: - packCode.add(VOrB32(dst=baseLRVgpr, src0=baseLRVgpr, src1=lowVgpr, comment="pack two half Vgpr to one Vgpr")) + packCodeT.add(VOrB32(dst=baseLRVgpr, src0=baseLRVgpr, src1=lowVgpr, comment="pack two half Vgpr to one Vgpr")) else: # no ECC pack # pack for No ECC blockwidth 0.25 type if rIdx != 0: @@ -742,7 +751,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): highVgpr = vgpr("Valu%s_X%u_I%u_D%u+%u"%(tc, bufferIdx, iui, rIdx%2, valuiIdx), numVgpr) destVgpr = highVgpr if isHigh8Bits and isHigh16Bits: - packCode.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=highVgpr, shiftHex=hex(0x8), src1=baseLRVgpr, comment="pack two int8x2 Vgpr to one Vgpr")) + packCodeT.add(VLShiftLeftOrB32(dst=baseLRVgpr, src0=highVgpr, shiftHex=hex(0x8), src1=baseLRVgpr, comment="pack two int8x2 Vgpr to one Vgpr")) if kernel["ConvertAfterDS"] and kernel["UnrollMajorLDS%s"%tc]: valufIdx += blockWidth * (tP["bpe"] // tP["bpeDS"]) if (not tP["isM"]) else 1 @@ -836,7 +845,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): if kernel["UseDirect32XEmulation"] and (valuiIdx % 8) < 4: index = baseValuiIdx // 2 + rIdx destVgpr = vgpr("Valu%s_T%u_I%u+%u"%(tc, bufferIdx, iui, index), blockWidth) - localReadCode.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment)) + localReadCodeT.add(LocalReadX(dst=destVgpr, src=srcAddr, ds=ds, comment=comment)) # TODO - handle vector-load with writer.allocTmpSgpr(1) as tmpSgprInfo: tmpSgpr = tmpSgprInfo.idx @@ -851,37 +860,53 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): # TODO: Handle vector, but need to take care the last one dbgVgprList = (dbgVgprList[1].split("]")[0]).split(':') dbgVgpr = "v[%s]"%dbgVgprList[0] - localReadCode.add(SWaitCnt(dscnt=0, vscnt=0, comment="CheckValue1 wait for LDS read")) + localReadCodeT.add(SWaitCnt(dscnt=0, vscnt=0, comment="CheckValue1 wait for LDS read")) if kernel["ProblemType"]["DataType"].isHalf(): hexValue = hex(0x3c003c00) # packed 1s if needPack: hexValue = hex(0x3c000000) if highBitsForHalf else hex(0x00003c00) - localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: FP16")) - localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) + localReadCodeT.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: FP16")) + localReadCodeT.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) elif kernel["ProblemType"]["DataType"].isBFloat16(): hexValue = hex(0x3f803f80) # packed 1s if needPack: hexValue = hex(0x3f800000) if highBitsForHalf else hex(0x00003f80) - localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: BF16")) - localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) + localReadCodeT.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: BF16")) + localReadCodeT.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) if kernel["ProblemType"]["DataType"].isInt8(): if needPack: hexValue = hex(0x00010000) if isHigh16Bits else hex(0x00000001) - localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: INT8")) - localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) + localReadCodeT.add(SMovB32(dst=sgpr(tmpSgpr), src=hexValue, comment="CheckValue1: INT8")) + localReadCodeT.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) # TODO - Check if this works. But need this? MFMA would use INT8 elif kernel["ProblemType"]["DataType"].isInt8x4(): - localReadCode.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(0x01010101), comment="CheckValue1: INT8x4")) - localReadCode.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) + localReadCodeT.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(0x01010101), comment="CheckValue1: INT8x4")) + localReadCodeT.add(writer.assert_eq( dbgVgpr, sgpr(tmpSgpr))) elif kernel["ProblemType"]["DataType"].isSingle(): - localReadCode.add(writer.assert_eq( dbgVgpr, 1.0) ) + localReadCodeT.add(writer.assert_eq( dbgVgpr, 1.0) ) + + addPackLR = False + if ((subTileIdx == 0 and subIterLoadCount < totalLoads // numSubTiles) \ + or (subTileIdx == 1 and subIterLoadCount >= totalLoads // numSubTiles) \ + or numSubTiles == 1) or writer.states.inTailLoop: + addPackLR = True + + if addPackLR: + if needPack or numSplitMetadata: + packCode.add(packCodeT) + localReadCode.add(localReadCodeT) + + subIterLoadCount += 1 + # End of loop3 if needPack: writer.states.numPackCvt = len(packCode.flatitems()) + # End of loop2 + # End of loop1 # DTV case, do not return local read code. if (tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]: imod = Module("LocalReadDo%s_I%s (Empty)" % (tP["tensorChar"],iui)) diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py index 063126f6347..d020f5bb7df 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py @@ -259,9 +259,7 @@ def calculateLatencyLeft(numReads, localReadBlockWidth, localReadLatency): writer.states.syncPlrMfmaIndex = numMfmaPerIter*(kernel["LoopIters"]-writer.states.numItersPLR+1) - writer.states.numMfmaForNextLoopLR - 1 if writer.states.numItersPLR else 0 if kernel["ForceUnrollSubIter"]: - if ( kernel["UseF32XEmulation"]) : - writer.states.syncPlrMfmaIndex = writer.states.syncPlrMfmaIndex *3 # TF32 - elif ( kernel["ProblemType"]["DataType"].isComplex()): + if ( kernel["ProblemType"]["DataType"].isComplex()): writer.states.syncPlrMfmaIndex = writer.states.syncPlrMfmaIndex *4 # Complex numMfmaBetweenLWandBarrier = 2 if kernel["MatrixInstM"] == 32 else 3 diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 37a9583ad86..ecc358d4f63 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -1263,7 +1263,8 @@ def calculateRangeAndUpdateCounter(itemCounter, writeCounters, length): loadModule = globalReadCode.popFirstItem() iterCode.add(loadModule) # schedule remaining globalReadIncInst - if i == numMfmaPerIter - 1 and globalReadCode.itemsSize(): + if (i == numMfmaPerIter - 1 and globalReadCode.itemsSize()) or \ + (i == 0 and globalReadCode.itemsSize() and (iteration == 1 and kernel["ForceUnrollSubIter"])): loadModules = globalReadCode.popFirstNItems(globalReadCode.itemsSize()) iterCode.addItems(loadModules) @@ -1312,7 +1313,7 @@ def calculateRangeAndUpdateCounter(itemCounter, writeCounters, length): self.states.perIterLocalWriteCanSkip[iteration] += countLocalWrite(writeItem) + countDSStoreB256(writeItem) if kernel["ForceUnrollSubIter"] and (writeItems and i == (numMfmaPerIter - 1)): # if ForceUnrollSubIter, we need to schedule all localWrite in last mfma - while writeItems: + while writeItems: writeItem = writeItems.pop(0) iterCode.add(writeItem) if mfmaIndex == self.states.lwEndMfmaIndex: @@ -1785,7 +1786,8 @@ def calculateRangeAndUpdateCounter(itemCounter, writeCounters, length): localReads += localReadsA + localReadsB # some of localReads is interleaved after waitcnt in SIA3 if kernel["ScheduleIterAlg"] == 3 and self.states.numItersPLR and\ - (iteration < numReadsIterA or iteration < numReadsIterB or numPrefetchIter): + (iteration < numReadsIterA or iteration < numReadsIterB or numPrefetchIter) and\ + not kernel["ForceUnrollSubIter"]: if ((iteration < numReadsIterA and not dataAtIterA < max(dataAtIterA,dataAtIterB)) or numPrefetchIter) and (not kernel["DirectToVgprA"]): localReads -= self.states.numReadsPerIterA if ((iteration < numReadsIterB and not dataAtIterB < max(dataAtIterA,dataAtIterB)) or numPrefetchIter) and (not kernel["DirectToVgprB"]): @@ -1799,6 +1801,8 @@ def calculateRangeAndUpdateCounter(itemCounter, writeCounters, length): localReadsNotWaited = self.states.numReadsPerIterB//kernel["InnerUnroll"] - self.states.numReadsPerUnrollB if localReadsNotWaited > 0: localReads += localReadsNotWaited + elif kernel["ForceUnrollSubIter"]: + localReads = 0 dscnt += localReads iterCode.addComment0("numPrefetchIter=%u" % numPrefetchIter) @@ -2168,10 +2172,7 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is isSwapAndResetLwoIter = isResetLroIter isSwapLroIter = isResetLroIter if kernel["ScheduleIterAlg"] == 3: - if kernel["ForceUnrollSubIter"]: - isSwapAndResetLwoIter = 1 - else: - isSwapAndResetLwoIter = (u == self.states.lwEndMfmaIndex//(self.states.numMfmaPerIter)) + isSwapAndResetLwoIter = (u == self.states.lwEndMfmaIndex//(self.states.numMfmaPerIter)) extraComment = "" if isLastLoop: @@ -2236,6 +2237,8 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is doReadA = doReadA and iui*self.states.numReadsIterCoalescedA < kernel["InnerUnroll"] doReadB = doReadB and iui*self.states.numReadsIterCoalescedB < kernel["InnerUnroll"] doReadM = doReadM and iui*self.states.numReadsIterCoalescedMetadata < kernel["InnerUnroll"] + if (doReadA or doReadB or doReadM) and kernel["ForceUnrollSubIter"]: + pack[1] = Module() if doReadA: localReads.addComment1("local read a") bufferIdx = plrIdx*self.states.numIterPerCoalescedReadA @@ -2247,7 +2250,6 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is localReads.add(localReadCodeA) pack[plrIdx*self.states.numIterPerCoalescedReadA].add(packCodeA) if kernel["ForceUnrollSubIter"]: - pack[1] = Module() pack[1].add(packCodeA) if doReadM: localReads.addComment1("local read metadata") @@ -2255,6 +2257,8 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is if needNextBufLR: localReads.add(localReadCodeM) pack[plrIdx*self.states.numIterPerCoalescedReadMetadata].add(packCodeM) + if kernel["ForceUnrollSubIter"]: + pack[1].add(packCodeM) if doReadB: localReads.addComment1("local read b") bufferIdx = plrIdx*self.states.numIterPerCoalescedReadB @@ -2308,7 +2312,8 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is if not kernel["ForceUnrollSubIter"] or (doReadB and (u 1: + if numSubTiles > 1 and not self.states.inTailLoop: # iter (idxOuter_start, idxOuter_stop) (idxInner_start, idxInner_stop) MFMA # 0 (0,4) (0,4) MFMA(A0,B0) # 1 (0,4) (4,8) MFMA(A1,B0) diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index 5d74b9cef97..27534781f3e 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -1361,7 +1361,7 @@ def assignDerivedParameters( if (state["MIWaveTile"][1] % state["VectorWidthB"]) != 0: reject(state, printRejectionReason, "MIWaveTile0(%u) should be multiple of VectorWidthB(%u)" % (state["MIWaveTile"][1], state["VectorWidthB"])) return - + if len(problemType["IndicesSummation"]) > 1: # not supported with multiple summations, bug is maybe something with # how stagger iteration is wrapped when unroll loop exits @@ -1426,16 +1426,36 @@ def assignDerivedParameters( if "ValidDepthU" in state: del state["ValidDepthU"] + ################################################################# + # ForceUnrollSubIter requirements + # - Needs PGR > 0, double buffer + # - MIWaveTile must be even and larger than 2 + # - TLU{A,B} cases only supported if using LdsTR or if VPerm not needed (size{A,B} >= 4) + # + # - Not supported for mixed precision cases currently + sizeDataTypeA = state["ProblemType"]["DataTypeA"].numBytes() + sizeDataTypeB = state["ProblemType"]["DataTypeB"].numBytes() + sizeDataType = state["ProblemType"]["DataType"].numBytes() + TLUA = state["ProblemType"]["TLUA"] + TLUB = state["ProblemType"]["TLUB"] if ( + state["EnableMatrixInstruction"] and not state["ExpandPointerSwap"] and state["DepthU"] == state["MatrixInstK"] and state["PrefetchGlobalRead"] and not state["1LDSBuffer"] and (state["MIWaveTile"][0] > 2 and state["MIWaveTile"][1] > 2) and (state["MIWaveTile"][0] % 2 == 0 and state["MIWaveTile"][1] % 2 == 0) + and (sizeDataTypeA == sizeDataType) and (sizeDataTypeB == sizeDataType) + and ((TLUA == False or state["enableLDSTrA"] or sizeDataTypeA >= 4) and (TLUB == False or state["enableLDSTrB"] or sizeDataTypeB >= 4) ) ): state["ForceUnrollSubIter"] = True - state["numSubTiles"] = 2 + state["numSubTiles"] = 2 + state["PrefetchLocalRead"] = 0 if state["ClusterLocalRead"] == 0 else state["PrefetchLocalRead"] else: state["ForceUnrollSubIter"] = False - state["numSubTiles"] = 1 + state["numSubTiles"] = 1 + + # Check if CMS is available for this solution + hasCMS,_ = hasCustomSchedule(state) + state["UseCustomMainLoopSchedule"] = hasCMS # 0: Normal mode. Hardware applies all of the normal data dependency checks # 1: Full expert mode (not suppoeted yet). Disable hardware checks against: VA_VDST, VA_SDST, VA_SSRC, VA_VCC, VM_VSRC and SA_SDST. @@ -3318,11 +3338,6 @@ def calcEpilogueTurns(factorDims: List) -> int: #print("Force to Disable PreloadKernArgs since this hipcc version doesn't support",) state["PreloadKernArgs"] = 0 - hasCMS,_ = hasCustomSchedule(state) - state["UseCustomMainLoopSchedule"] = hasCMS - - state["InternalSupportParams"]["UseSFC"] = (len(state["SpaceFillingAlgo"]) > 0) - ######################################## @ staticmethod def getParametersIndented(state, indent): diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml deleted file mode 100644 index 2d1b6ed34f6..00000000000 --- a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f16_plr.yaml +++ /dev/null @@ -1,143 +0,0 @@ -GlobalParameters: - MergeFiles: False - NumElementsToValidate: -1 - - NumWarmups: 1000 - EnqueuesPerSync: 10000 - - NumBenchmarks: 1 - SyncsPerBenchmark: 1 - SleepPercent: 0 - DataInitTypeA: 12 - DataInitTypeB: 13 - # DataInitTypeC: 0 - # DataInitTypeD: 3 - DataInitTypeBeta: 0 - DataInitTypeAlpha: 1 - # DataInitTypeBias: 0 - DataInitTypeScaleAlphaVec: 1 - DataInitTypeScaleA: 1 - DataInitTypeScaleB: 1 - CSVExportWinner: 1 - CSVMergeSameProblemID: 1 - #Device: 0 - MinKForGSU: 1 - #MaxWorkspaceSize: 3355443200 - MaxFileName: 256 - KernelTime: True - #RotatingBufferSize: 512 - MaxLDS: 163840 - DeviceLDS: 163840 - #GenerateSourcesAndExit: True - - PrintSolutionRejectionReason: True - #Device: 3 - - RotatingBufferSize: 512 - KeepBuildTmp: True - - -BenchmarkProblems: - ######################################## - # NN - standard - ######################################## - - - - # ProblemType - OperationType: GEMM - #DataTypeA: f8 - #DataTypeB: h - #UseScaleAB: True - - DataType: h - DestDataType: h - ComputeDataType: s - HighPrecisionAccumulate: True - TransposeA: 1 - TransposeB: 0 - UseBeta: True - Batched: True - - #UseBias: True - #Activation: True - #UseScaleAlphaVec: True - - - # BenchmarkProblemSizeGroup - Standard - InitialSolutionParameters: - BenchmarkCommonParameters: - - KernelLanguage: ["Assembly"] - ForkParameters: - - MatrixInstruction: - - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] - - [16, 16, 32, 1, 1, 8, 7 , 2,2 ] - - [16, 16, 32, 1, 1, 8, 6 , 2,2 ] - - [16, 16, 32, 1, 1, 8, 5 , 2,2 ] - - [16, 16, 32, 1, 1, 8, 4 , 2,2 ] - - [16, 16, 32, 1, 1, 8, 2 , 2,2 ] - - [16, 16, 32, 1, 1, 7, 8 , 2,2 ] - - [16, 16, 32, 1, 1, 7, 7 , 2,2 ] - - [16, 16, 32, 1, 1, 7, 6 , 2,2 ] - - [16, 16, 32, 1, 1, 7, 5 , 2,2 ] - - [16, 16, 32, 1, 1, 7, 4 , 2,2 ] - - [16, 16, 32, 1, 1, 7, 3 , 2,2 ] - - [16, 16, 32, 1, 1, 7, 2 , 2,2 ] - - [16, 16, 32, 1, 1, 6, 8 , 2,2 ] - - [16, 16, 32, 1, 1, 6, 7 , 2,2 ] - - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] - - [16, 16, 32, 1, 1, 6, 5 , 2,2 ] - - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] - - [16, 16, 32, 1, 1, 6, 3 , 2,2 ] - - [16, 16, 32, 1, 1, 6, 2 , 2,2 ] - - [16, 16, 32, 1, 1, 5, 8 , 2,2 ] - - [16, 16, 32, 1, 1, 5, 7 , 2,2 ] - - [16, 16, 32, 1, 1, 5, 6 , 2,2 ] - - [16, 16, 32, 1, 1, 5, 5 , 2,2 ] - - [16, 16, 32, 1, 1, 5, 4 , 2,2 ] - - [16, 16, 32, 1, 1, 5, 3 , 2,2 ] - - [16, 16, 32, 1, 1, 5, 2 , 2,2 ] - - [16, 16, 32, 1, 1, 4, 8 , 2,2 ] - - [16, 16, 32, 1, 1, 4, 7 , 2,2 ] - - [16, 16, 32, 1, 1, 4, 6 , 2,2 ] - - [16, 16, 32, 1, 1, 4, 5 , 2,2 ] - - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] - - [16, 16, 32, 1, 1, 4, 3 , 2,2 ] - - [16, 16, 32, 1, 1, 4, 2 , 2,2 ] - - [16, 16, 32, 1, 1, 3, 8 , 2,2 ] - - [16, 16, 32, 1, 1, 3, 7 , 2,2 ] - - [16, 16, 32, 1, 1, 3, 6 , 2,2 ] - - [16, 16, 32, 1, 1, 3, 5 , 2,2 ] - - [16, 16, 32, 1, 1, 3, 4 , 2,2 ] - - [16, 16, 32, 1, 1, 3, 3 , 2,2 ] - - [16, 16, 32, 1, 1, 3, 2 , 2,2 ] - - [16, 16, 32, 1, 1, 2, 8 , 2,2 ] - - [16, 16, 32, 1, 1, 2, 7 , 2,2 ] - - [16, 16, 32, 1, 1, 2, 6 , 2,2 ] - - [16, 16, 32, 1, 1, 2, 5 , 2,2 ] - - [16, 16, 32, 1, 1, 2, 4 , 2,2 ] - - [16, 16, 32, 1, 1, 2, 3 , 2,2 ] - - [16, 16, 32, 1, 1, 2, 2 , 2,2 ] - - - - PrefetchGlobalRead: [2] - - PrefetchLocalRead: [1] - - DepthU: [32] - - TransposeLDS: [1] - #- DirectToLds: [1] - - StaggerU: [0] - - 1LDSBuffer: [0] - - GlobalReadVectorWidthA: [8] - - GlobalReadVectorWidthB: [8] - - LocalReadVectorWidth: [8] - #- GlobalSplitU: [3] #disable GSU - - SourceSwap: [1] - # - NonTemporalA: [4] - # - NonTemporalB: [0,1] - # - NonTemporalC: [3] - # - NonTemporalD: [0] - BenchmarkJoinParameters: - BenchmarkFinalParameters: - - ProblemSizes: - - Exact: [4096, 4096, 1, 16384] - #- Exact: [256, 256, 1, 384] - #- BiasTypeArgs: ['s'] - #- ActivationArgs: - # - [Enum: none] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml deleted file mode 100644 index 6f654ada334..00000000000 --- a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/f8_plr.yaml +++ /dev/null @@ -1,159 +0,0 @@ -GlobalParameters: - MergeFiles: False - NumElementsToValidate: -1 - - NumWarmups: 1000 - EnqueuesPerSync: 10000 - - NumBenchmarks: 1 - SyncsPerBenchmark: 1 - SleepPercent: 0 - DataInitTypeA: 12 - DataInitTypeB: 13 - # DataInitTypeC: 0 - # DataInitTypeD: 3 - DataInitTypeBeta: 0 - DataInitTypeAlpha: 1 - # DataInitTypeBias: 0 - DataInitTypeScaleAlphaVec: 1 - DataInitTypeScaleA: 1 - DataInitTypeScaleB: 1 - CSVExportWinner: 1 - CSVMergeSameProblemID: 1 - #Device: 0 - MinKForGSU: 1 - #MaxWorkspaceSize: 3355443200 - MaxFileName: 256 - KernelTime: True - #RotatingBufferSize: 512 - MaxLDS: 163840 - DeviceLDS: 163840 - #GenerateSourcesAndExit: True - - PrintSolutionRejectionReason: True - #Device: 3 - - RotatingBufferSize: 512 - KeepBuildTmp: True - - -BenchmarkProblems: - ######################################## - # NN - standard - ######################################## - - - - # ProblemType - OperationType: GEMM - #DataTypeA: f8 - #DataTypeB: h - #UseScaleAB: True - - DataType: f8 - DestDataType: h - ComputeDataType: s - HighPrecisionAccumulate: True - TransposeA: 1 - TransposeB: 0 - UseBeta: True - Batched: True - - #UseBias: True - #Activation: True - #UseScaleAlphaVec: True - - - # BenchmarkProblemSizeGroup - Standard - InitialSolutionParameters: - BenchmarkCommonParameters: - - KernelLanguage: ["Assembly"] - ForkParameters: - - MatrixInstruction: - - [16, 16, 128, 1, 1, 8, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 8, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 8, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 8, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 8, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 8, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 8, 1 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 3 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 7, 1 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 3 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 6, 1 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 3 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 5, 1 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 3 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 4, 1 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 3 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 3, 1 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 3 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 2, 1 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 8 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 7 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 6 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 5 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 4 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 3 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 2 , 2,2 ] - - [16, 16, 128, 1, 1, 1, 1 , 2,2 ] - - - AssertSummationElementMultiple: [128] - - AssertFree0ElementMultiple: [16] - - PrefetchGlobalRead: [2] - - PrefetchLocalRead: [1] - - DepthU: [128] - - TransposeLDS: [1] - #- DirectToLds: [1] - - StaggerU: [0] - - 1LDSBuffer: [0] - - GlobalReadVectorWidthA: [16] - - GlobalReadVectorWidthB: [16] - - LocalReadVectorWidth: [16] - #- GlobalSplitU: [3] #disable GSU - - SourceSwap: [1] - # - NonTemporalA: [4] - # - NonTemporalB: [0,1] - # - NonTemporalC: [3] - # - NonTemporalD: [0] - BenchmarkJoinParameters: - BenchmarkFinalParameters: - - ProblemSizes: - - Exact: [4096, 4096, 1, 16384] - #- Exact: [256, 256, 1, 384] - #- BiasTypeArgs: ['s'] - #- ActivationArgs: - # - [Enum: none] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/plr_zero.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/plr_zero.yaml new file mode 100644 index 00000000000..96a826dd225 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/plr_zero.yaml @@ -0,0 +1,641 @@ +TestParameters: + marks: [skip-gfx942, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201] # not supported by arch + +GlobalParameters: + MergeFiles: False + NumElementsToValidate: -1 + + NumWarmups: 1 + EnqueuesPerSync: 1 + + NumBenchmarks: 1 + SyncsPerBenchmark: 1 + SleepPercent: 0 + DataInitTypeA: 12 + DataInitTypeB: 13 + DataInitTypeBeta: 0 + DataInitTypeAlpha: 1 + DataInitTypeScaleAlphaVec: 1 + DataInitTypeScaleA: 1 + DataInitTypeScaleB: 1 + CSVExportWinner: 1 + CSVMergeSameProblemID: 1 + MinKForGSU: 1 + MaxFileName: 256 + KernelTime: True + MaxLDS: 163840 + DeviceLDS: 163840 + + #PrintSolutionRejectionReason: True + #Device: 3 + + RotatingBufferSize: 512 + KeepBuildTmp: True + + +BenchmarkProblems: + ######################################## + # HHS TN - standard + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 7, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 5, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 3, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 7 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 5 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 3 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 2 , 2,2 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [32, 32, 16, 1, 1, 4, 4 , 2,2 ] + - [32, 32, 16, 1, 1, 4, 3 , 2,2 ] + - [32, 32, 16, 1, 1, 4, 2 , 2,2 ] + - [32, 32, 16, 1, 1, 3, 4 , 2,2 ] + - [32, 32, 16, 1, 1, 3, 3 , 2,2 ] + - [32, 32, 16, 1, 1, 3, 2 , 2,2 ] + - [32, 32, 16, 1, 1, 2, 4 , 2,2 ] + - [32, 32, 16, 1, 1, 2, 3 , 2,2 ] + - [32, 32, 16, 1, 1, 2, 2 , 2,2 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [16] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 1,4 ] + - [16, 16, 32, 1, 1, 4, 4 , 4,1 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [8192, 8192, 1, 8192] + + ######################################## + # HHS NN - standard + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 0 + TransposeB: 0 + UseBeta: True + Batched: True + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 2 , 2,2 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + #- Exact: [4096, 4096, 1, 16384] + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [32, 32, 16, 1, 1, 4, 4 , 2,2 ] + - [32, 32, 16, 1, 1, 4, 6 , 1,1 ] + - [32, 32, 16, 1, 1, 6, 4 , 1,1 ] + - [32, 32, 16, 1, 1, 6, 6 , 1,1 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [16] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 1,4 ] + - [16, 16, 32, 1, 1, 4, 4 , 4,1 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [8192, 8192, 1, 8192] + + ######################################## + # HHS TT - standard + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 1 + UseBeta: True + Batched: True + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 4 , 2,2 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + #- Exact: [4096, 4096, 1, 16384] + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [32, 32, 16, 1, 1, 4, 4 , 2,2 ] + - [32, 32, 16, 1, 1, 4, 6 , 1,1 ] + - [32, 32, 16, 1, 1, 6, 4 , 1,1 ] + - [32, 32, 16, 1, 1, 6, 6 , 1,1 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [16] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 1,4 ] + - [16, 16, 32, 1, 1, 4, 4 , 4,1 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [8192, 8192, 1, 8192] + + ######################################## + # HHS NT - standard + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 0 + TransposeB: 1 + UseBeta: True + Batched: True + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 8, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 6 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 2 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 8 , 2,2 ] + - [16, 16, 32, 1, 1, 2, 6 , 2,2 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [0] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + #- Exact: [4096, 4096, 1, 16384] + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [32, 32, 16, 1, 1, 4, 4 , 2,2 ] + - [32, 32, 16, 1, 1, 4, 6 , 1,1 ] + - [32, 32, 16, 1, 1, 6, 4 , 1,1 ] + - [32, 32, 16, 1, 1, 6, 6 , 1,1 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [16] + - TransposeLDS: [0] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 4, 4 , 1,4 ] + - [16, 16, 32, 1, 1, 4, 4 , 4,1 ] + - PrefetchGlobalRead: [1,2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [0] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,8] + - GlobalReadVectorWidthB: [1,2,8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + - LDSTrInst: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [1, 37, 128]] + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 8 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1,2] + - DepthU: [32] + - TransposeLDS: [0] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [8] + - LocalReadVectorWidth: [8] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [8192, 8192, 1, 8192] + + ######################################## + # F8HS TN - standard + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: f8 + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 8, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 7 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 8, 1 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 6, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 4, 2 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 8 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 6 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 4 , 2,2 ] + - [16, 16, 128, 1, 1, 2, 2 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1,2] + - DepthU: [128] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [4,8,16] + - GlobalReadVectorWidthB: [4,8,16] + - LocalReadVectorWidth: [16] + - SourceSwap: [1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [2, 37, 512]] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/xfp32.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/xfp32.yaml index 0601a31692e..474f7e62e89 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/xfp32.yaml +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/xfp32.yaml @@ -95,6 +95,30 @@ BenchmarkProblems: BenchmarkFinalParameters: - ProblemSizes: - Range: [[1,67,517], [1,93,709], [1], [1,76,865]] + - # BenchmarkProblemSizeGroup - PLR0 tests + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DepthU: [32] + - TransposeLDS: [0] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,4] + - GlobalReadVectorWidthB: [1,2,4] + - LocalReadVectorWidth: [4] + - SourceSwap: [1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [2, 37, 512]] ######################################## # F32X TN @@ -169,6 +193,30 @@ BenchmarkProblems: BenchmarkFinalParameters: - ProblemSizes: - Range: [[1,67,517], [1,93,709], [1], [1,76,865]] + - # BenchmarkProblemSizeGroup - PLR0 tests + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,4] + - GlobalReadVectorWidthB: [1,2,4] + - LocalReadVectorWidth: [4] + - SourceSwap: [1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [2, 37, 512]] ######################################## # F32X NN @@ -205,6 +253,30 @@ BenchmarkProblems: - Exact: [128, 128, 1, 128] - Exact: [128, 128, 1, 127] - Exact: [2048, 2048, 1, 256] + - # BenchmarkProblemSizeGroup - PLR0 tests + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,4] + - GlobalReadVectorWidthB: [1,2,4] + - LocalReadVectorWidth: [4] + - SourceSwap: [1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [2, 37, 512]] ######################################## # F32X TT @@ -239,3 +311,27 @@ BenchmarkProblems: - Exact: [128, 128, 1, 128] - Exact: [128, 128, 1, 127] - Exact: [2048, 2048, 1, 256] + - # BenchmarkProblemSizeGroup - PLR0 tests + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 32, 1, 1, 4, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 4 , 2,2 ] + - [16, 16, 32, 1, 1, 6, 6 , 2,2 ] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DepthU: [32] + - TransposeLDS: [1] + - DirectToLds: [0, 1] + - StaggerU: [0] + - 1LDSBuffer: [0] + - GlobalReadVectorWidthA: [1,2,4] + - GlobalReadVectorWidthB: [1,2,4] + - LocalReadVectorWidth: [4] + - SourceSwap: [1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[255, 1, 257], [255, 1, 257], [1], [2, 37, 512]]