Skip to content

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@
from ..Common import printWarning, roundUp, print2, DebugConfig, DataDirection, \
INDEX_CHARS, IsaVersion


from rocisa.code import Module, TextBlock, StructuredModule, KernelBody, Label
from rocisa.label import LabelManager

from rocisa.container import MUBUFModifiers, vgpr, sgpr, accvgpr, mgpr
from rocisa.enum import InstType, SelectBit, CacheScope
from rocisa.instruction import MFMAInstruction

import math
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, List, NamedTuple, Optional, Tuple, Type
from contextlib import contextmanager
from collections import deque
from rocisa import rocIsa, countInstruction, countGlobalRead, \
countLocalRead, countLocalWrite, countDSStoreB256, getMFMAs
from rocisa.asmpass import rocIsaPass, rocIsaPassOption
Expand Down Expand Up @@ -1245,25 +1259,36 @@ def globalReadScalePtrUpdates(tc, writer, kernel):
def emitSingleBufferLoad(tileInfo, kernel, sId0, sId1):
"""Emit buffer_load instructions for a single subtile (sId0, sId1).

When loadRatioGR > 1, multiple local subtiles share the same global read.
Only the first subtile in each group emits the load; others return empty.

Args:
tileInfo: TileInfo for the tensor component
sId0: Subtile row index
sId1: Subtile column index (K-dimension)
"""
module = Module()
tc = tileInfo.tc

subtileInfo = tileInfo.localSubtiles[tileInfo.getLocalSubtileLinearId(sId0, sId1)]
grBaseId = subtileInfo.globalReadMap[0]

# When loadRatioGR > 1, multiple subtiles share one global read.
# Only emit the load for the first subtile of each group.
if tileInfo.loadRatioGR > 1:
linearId = tileInfo.getLocalSubtileLinearId(sId0, sId1)
firstInGroup = int(grBaseId * tileInfo.loadRatioGR)
if linearId != firstInGroup:
return module

tc = tileInfo.tc
isGlc = bool(kernel["NonTemporal%s"%tc] & 0x1)
isSlc = bool(kernel["NonTemporal%s"%tc] & 0x2)
isNT = bool(kernel["NonTemporal%s"%tc] & 0x4)

subtileInfo = tileInfo.localSubtiles[tileInfo.getLocalSubtileLinearId(sId0, sId1)]
regList = tileInfo.localSubtilesRegister[subtileInfo.regListId]

offsetK = sId1 * int(tileInfo.mmaTileShape[1] * tileInfo.subtileShape[1] * tileInfo.bpe)
# TODO: grBaseId is probably not needed..
grBaseId = subtileInfo.globalReadMap[0]


subtileOffset = math.ceil(tileInfo.loadRatioGR*tileInfo.subtileSize)
WriteBaseAddr = "LocalWriteBaseAddr%s"%tc
# Emit number of buffer loads equal to number of loads needed to load a subtile
Expand Down Expand Up @@ -1297,17 +1322,10 @@ def globalReadDoSubtile(tc, writer, kernel):

tileInfo = writer.states.a.tileInfo if tc == 'A' else writer.states.b.tileInfo

grTracker = set()
for j in range(tileInfo.localSubtileGrid[1]):
for i in range(tileInfo.localSubtileGrid[0]):
grIds = tileInfo.localSubtiles[tileInfo.getLocalSubtileLinearId(i ,j)].globalReadMap
if not set(grIds).issubset(grTracker):
for grId in grIds:
grTracker.add(grId)
module.addComment0("Emit load for %s subtile: [%u, %u]"%(tc, i, j))
module.add(emitSubtileBufferLoad(tc, writer, kernel, [i, j]))
else:
module.addComment0("Emit load for %s subtile: [%u, %u] - already covered"%(tc, i, j))
module.addComment0("Emit load for %s subtile: [%u, %u]"%(tc, i, j))
module.add(emitSubtileBufferLoad(tc, writer, kernel, [i, j]))

return module

Expand Down Expand Up @@ -1775,15 +1793,16 @@ def mainLoop(writer, kernel):

# new path for PGR=2 pipelining with SubtileBasedScheduler
if pgr == 2:
from Tensile.Components.SubtileBasedScheduler import SubtileBasedScheduler, SchedulerConfig, PrefetchMode, VGPRTileReUseStrategy
from Tensile.Components.SubtileBasedScheduler import SubtileBasedScheduler, SchedulerConfig, PrefetchMode
tiA = writer.states.a.tileInfo
tiB = writer.states.b.tileInfo
scaleTiA = writer.states.mxsa.tileInfo if kernel["ProblemType"].get("MXBlockA", 0) else None
scaleTiB = writer.states.mxsb.tileInfo if kernel["ProblemType"].get("MXBlockB", 0) else None
# For 320x256, use a 5x1 partition grid.
# cfg = SchedulerConfig(tiA.localSubtileGrid[0]//5, tiB.localSubtileGrid[0],
# Use a single partition for now. TODO
cfg = SchedulerConfig(tiA.localSubtileGrid[0], tiB.localSubtileGrid[0],
#cfg = SchedulerConfig(tiA.localSubtileGrid[0]//2, tiB.localSubtileGrid[0]//2,
PrefetchMode.HALF_PREFETCH, VGPRTileReUseStrategy.ACROSS_SUBGROUP)
PrefetchMode.HALF_PREFETCH)
scheduler = SubtileBasedScheduler(tiA, tiB, cfg,
scaleTileInfoA=scaleTiA, scaleTileInfoB=scaleTiB)
# scheduler.printSchedule()
Expand All @@ -1804,10 +1823,12 @@ def mainLoop(writer, kernel):
module.addComment0("MAINLOOP")
numPartitions = len(scheduler.partitions)

# With scale double buffering, the scale set rotates per partition inside _emitLoop.
# After one iteration (N partitions), the set flips if N is odd → need 2x unrolling.
# If N is even, the set returns to starting position → no unrolling needed.
needsScaleUnroll = scheduler.hasScale and (numPartitions % 2 == 1)
# With scale double buffering, the scale set rotates inside _emitLoop:
# once per partition end + once per subtileK boundary = numSubtileK flips per partition.
# After one iteration (N partitions), total flips = N * numSubtileK.
# If odd → need 2x unrolling. If even → sets return to start, no unrolling needed.
scaleFlipsPerIter = numPartitions * scheduler.numSubtileK
needsScaleUnroll = scheduler.hasScale and (scaleFlipsPerIter % 2 == 1)

if needsScaleUnroll:
# 2x unrolled mainloop: needed when the number of scale-set flips per iteration is odd.
Expand Down Expand Up @@ -1849,7 +1870,7 @@ def mainLoop(writer, kernel):
module.add(Label("SkipToNGLL", ""))
if scheduler.hasScale:
endLabel = Label("SkipToEnd", "")
nllSet = 1 if numPartitions % 2 == 1 else 0
nllSet = 1 if scaleFlipsPerIter % 2 == 1 else 0

# Even path (or only path when no unrolling): mainloop ended at scaleSet=0.
module.add(scheduler._emitLoop(writer, kernel, "NGLL", scheduler.ngllSteps,
Expand All @@ -1866,7 +1887,7 @@ def mainLoop(writer, kernel):
module.add(scheduler._emitLoop(writer, kernel, "NGLL_odd", scheduler.ngllSteps,
scaleSet=1))
module.addComment0("NLL (odd)")
nllSetOdd = 0 if numPartitions % 2 == 1 else 1
nllSetOdd = 0 if scaleFlipsPerIter % 2 == 1 else 1
module.add(scheduler._emitLoop(writer, kernel, "NLL_odd", scheduler.nllSteps, scaleSet=nllSetOdd))

# NLLEarly: reached when counterL<=1 (preloop skip, no NGLL).
Expand Down
Loading