From 054b7499184ef63239aafc198cd1c508ac7cdd5a Mon Sep 17 00:00:00 2001 From: Koji Nakajima Date: Thu, 19 Feb 2026 06:19:24 +0000 Subject: [PATCH 1/3] Enable DirectToLds for MXSA/B --- .../Tensile/Components/LocalRead.py | 2 +- .../tensilelite/Tensile/KernelWriter.py | 4 +++ .../Tensile/KernelWriterAssembly.py | 26 ++++++++++--------- .../Tensile/SolutionStructs/Solution.py | 24 ++++++++--------- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py index 3eeaf40b3af4..5c747e3d6935 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py @@ -901,7 +901,7 @@ def applyPad(offset_val): offset_val = offset_val + tP["localReadSwapByteOffset"] # TODO: Add NLC>1 offset calcs here? if (kernel["DirectToLds%s" % tc] and \ - kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]: + kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]: # another address conversion for DirectToLds + NumLoadsCoalesced > 1 dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val) diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 8c1058c99840..694202c913bf 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -5756,6 +5756,10 @@ def GNLCOInit(tc): self.defineSgpr("LocalWriteAddrA", 1) if kernel["LocalWriteUseSgprB"]: self.defineSgpr("LocalWriteAddrB", 1) + if kernel["ProblemType"]["MXBlockA"] and kernel["LocalWriteUseSgprMXSA"]: + self.defineSgpr("LocalWriteAddrMXSA", 1) + if kernel["ProblemType"]["MXBlockB"] and kernel["LocalWriteUseSgprMXSB"]: + self.defineSgpr("LocalWriteAddrMXSB", 1) # Allocate registers to swap between lds buffers if kernel["StoreSwapAddr"]: diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index fe309fffab98..64931a18a1e0 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -3287,7 +3287,7 @@ def graFinalOffsetsSingleLoopGNLC(self, kernel, tP, tc, margin = -1): self.vgprPool.checkIn(reMap1) stride = "Strides%s"%(tc) - module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(kernel["GlobalReadVectorWidth%c"%tc]), src=vgpr(tmpv2))) + module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(kernel["GlobalReadVectorWidth%s"%tc]), src=vgpr(tmpv2))) module.add(VMulLOU32(dst=vgpr(tmpv), src0=sgpr(stride), src1=vgpr(tmpv))) if kernel["EdgeType"] == "ShiftPtr" and kernel["ProblemType"]["TLU%s"%tc] == 1: @@ -3387,11 +3387,11 @@ def graFinalOffsetsSingleLoop(self, kernel, tP, tc, tmp, graIdx, perp, sPerp, pa module.add(SMovB32(dst=sgpr(tmpSgpr), src=self.buff_load_inst_offset_max)) module.add(VAddU32(dst=vgpr(groVgpr), src0=vgpr(groVgpr), src1=sgpr(tmpSgpr), comment="shift for UseInstOffsetForGRO")) - ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"] + ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"] if kernel["LdsBlockSizePerPad%s"%tc] != 0: ldsInc += (ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"] else: - padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr + padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr ldsInc += (ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"] # buffer_load only support 12 bit instruction offset @@ -3447,11 +3447,11 @@ def computeScalarGroImpl(scalarGro): # add room for instruction offset module.add(SAddU32(dst=sgpr(scalarGro), src0=sgpr(scalarGro), src1=self.buff_load_inst_offset_max, comment="shift for UseInstOffsetForGRO")) - ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"] + ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"] if kernel["LdsBlockSizePerPad%s"%tc] != 0: ldsInc += (ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"] else: - padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr + padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr ldsInc += (ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"] # buffer_load only support 12 bit instruction offset @@ -4260,7 +4260,7 @@ def lwaFirstOffset(self, kernel, tP): dst=sgpr("LocalWriteAddr%s"%tc), \ src=vgpr(tmpv), \ comment="Copy lds write address VGPR to SGPR")) - lwastride = int((kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%c"%tc]+kernel["LdsPad%s"%tc]) * tP["bpeGR"]) + lwastride = int((kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%s"%tc]+kernel["LdsPad%s"%tc]) * tP["bpeGR"]) module.add(SMulI32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), src1=lwastride )) if tc == 'B': module.add(SAddU32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), \ @@ -8405,13 +8405,13 @@ def globalReadGuardKBody(tP, optParams = None): # need to increment ldsInc only once per each loopCnt # this is pre count up, so increment it at r == 0 if r == 0: - ldsInc = int((self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"]) + ldsInc = int((self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"]) else: ldsInc = 0 if kernel["LdsBlockSizePerPad%s"%tc] != 0: ldsInc += int((ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"]) else: - padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr + padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr ldsInc += int((ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"]) if kernel["UseInstOffsetForGRO"]: # buffer_load only support 12 bit instruction offset @@ -8889,8 +8889,6 @@ def directToLdsM0Update(self, kernel, mode, tP, skipWait = False): DtldsModule.addComment0("before DirectToLds load, ensure prior ds_reads have finished") DtldsModule.add(SWaitCnt(dscnt=0, comment="")) DtldsModule.add(SBarrier()) - if "MX" in tP: - imod.add(self.directToLdsM0Update(kernel, 0, tP["MX"])) return imod @@ -8980,6 +8978,10 @@ def globalReadBody(tP): else: destVgprPrefix = "G2L%s"%(tc) + # add m0 init code for MX here + if tc == "MXSA" or tc == "MXSB": + imod.middle.add(self.directToLdsM0Update(kernel, mode, tP, skipWait=True)) + loopCnt = -1 for perp in range(0, tP["nrp"]): for sPerp in range(0, tP["nrpv"]): @@ -9035,11 +9037,11 @@ def globalReadBody(tP): if kernel["DirectToLds%s"%tc]: # use bpe with GlobalReadVectorWidth - ldsInc = int((self.states.kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%c"%tc] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"] * kernel["GlobalReadVectorWidth%c"%tc]) * tP["bpeGR"]) + ldsInc = int((self.states.kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%s"%tc] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"] * kernel["GlobalReadVectorWidth%s"%tc]) * tP["bpeGR"]) if kernel["LdsBlockSizePerPad%s"%tc] != 0: ldsInc += int((ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"]) else: - padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr + padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr ldsInc += int((ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"]) if kernel["UseInstOffsetForGRO"]: diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index ad3bacd33db0..997d1345eb6b 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -1240,11 +1240,10 @@ def assignDerivedParameters( # set True for DTL state["UseGeneralizedNLCOneA"] = state["DirectToLdsA"] state["UseGeneralizedNLCOneB"] = state["DirectToLdsB"] - # MX block does not use DTL, so set to False if state["ProblemType"]["MXBlockA"]: - state["UseGeneralizedNLCOneMXSA"] = False + state["UseGeneralizedNLCOneMXSA"] = False #state["DirectToLdsA"] if state["ProblemType"]["MXBlockB"]: - state["UseGeneralizedNLCOneMXSB"] = False + state["UseGeneralizedNLCOneMXSB"] = False #state["DirectToLdsB"] state["LocalWriteUseSgprA"] = False state["LocalWriteUseSgprB"] = False @@ -1464,8 +1463,8 @@ def assignDerivedParameters( state["WaveSeparateGlobalReadMXSA"] = state["WaveSeparateGlobalReadA"] state["NumLoadsCoalescedMXSA"] = state["NumLoadsCoalescedA"] Solution.checkAndAssignWaveSeparateGlobalRead(state, 'MXSA', printRejectionReason) - state["DirectToLdsMXSA"] = False - state["LocalWriteUseSgprMXSA"] = False + state["DirectToLdsMXSA"] = state["DirectToLdsA"] + state["LocalWriteUseSgprMXSA"] = state["DirectToLdsMXSA"] state["ProblemType"]["MirrorDimsMXSA"] = list(state["ProblemType"]["MirrorDimsA"]) state["VectorWidthMXSA"] = state["VectorWidthA"] state["MIWaveTileMXSA"] = state["MIWaveTileA"] @@ -1478,8 +1477,8 @@ def assignDerivedParameters( state["WaveSeparateGlobalReadMXSB"] = state["WaveSeparateGlobalReadB"] state["NumLoadsCoalescedMXSB"] = state["NumLoadsCoalescedB"] Solution.checkAndAssignWaveSeparateGlobalRead(state, 'MXSB', printRejectionReason) - state["DirectToLdsMXSB"] = False - state["LocalWriteUseSgprMXSB"] = False + state["DirectToLdsMXSB"] = state["DirectToLdsB"] + state["LocalWriteUseSgprMXSB"] = state["DirectToLdsMXSB"] state["ProblemType"]["MirrorDimsMXSB"] = list(state["ProblemType"]["MirrorDimsB"]) state["VectorWidthMXSB"] = state["VectorWidthB"] state["MIWaveTileMXSB"] = state["MIWaveTileB"] @@ -1758,10 +1757,10 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int: else: ldsPadA = state["VectorWidthA"] else: - ldsPadA = max(state["GlobalReadVectorWidthA"],optPadA) - ## turn-off padding for directToLds if state["DirectToLdsA"]: - ldsPadA = 0 + ldsPadA = max(lrvwA, optPadA) if not state["ProblemType"]["TLUA"] else 0 + else: + ldsPadA = max(state["GlobalReadVectorWidthA"],optPadA) assert(ldsPadA >= 0) if ldsPadB == -1: @@ -1781,9 +1780,10 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int: else: ldsPadB = state["VectorWidthB"] else: - ldsPadB = max(state["GlobalReadVectorWidthB"],optPadB) if state["DirectToLdsB"]: - ldsPadB = 0 + ldsPadB = max(lrvwB, optPadB) if not state["ProblemType"]["TLUB"] else 0 + else: + ldsPadB = max(state["GlobalReadVectorWidthB"],optPadB) assert(ldsPadB >= 0) if state["ProblemType"]["Sparse"] and not state["DirectToVgprSparseMetadata"]: From 6a5c7772b1f25c18d18423d0f39844092fceda86 Mon Sep 17 00:00:00 2001 From: Koji Nakajima Date: Fri, 20 Feb 2026 07:13:43 +0000 Subject: [PATCH 2/3] Added reject conditions for DirectToLdsMXSA/B --- .../Tensile/SolutionStructs/Solution.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index 997d1345eb6b..e963451b1389 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -829,16 +829,12 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool): #TN # use for all precisions with TransposeLDS=1 - numBytesAB = state["ProblemType"]["DataType%s"%tc].numBytes() + if tc in ["MXSA", "MXSB"]: + numBytesAB = 1 + else: + numBytesAB = state["ProblemType"]["DataType%s"%tc].numBytes() numBytesPerLoad = state["GlobalReadVectorWidth%s"%tc] * numBytesAB - MT = state["MacroTile0"] if tc == 'A' else state["MacroTile1"] - - if (MT & (MT-1)) != 0 and not state["UseGeneralizedNLCOne%s"%tc]: # Check of MT not power of 2 - # so far, numBytesAB<=4 case, TLU=False only (continue with False) - if (numBytesAB <= 4) and state["ProblemType"]["TLU%c"%tc]: - return False - # x2 DTL is not supported if numBytesPerLoad == 8: printWarning("can't use DirectToLds with b64 buffer load, using non DirectToLds version instead") @@ -852,6 +848,17 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool): reject(state, printRejectionReason, "DirectToLds not supported for loads less than 32bits") return False + if tc in ["MXSA", "MXSB"]: + # MXSA/B case, check numBytesPerLoad only + return True + + MT = state["MacroTile0"] if tc == 'A' else state["MacroTile1"] + + if (MT & (MT-1)) != 0 and not state["UseGeneralizedNLCOne%s"%tc]: # Check of MT not power of 2 + # so far, numBytesAB<=4 case, TLU=False only (continue with False) + if (numBytesAB <= 4) and state["ProblemType"]["TLU%c"%tc]: + return False + # so far MFMA only (TODO: enable non MFMA case) if not state["EnableMatrixInstruction"]: reject(state, printRejectionReason, "DirectToLds is for MatrixInstruction only for now (tentative)") @@ -2817,13 +2824,21 @@ def calSwizzlePackK(state, tc): # LDS (load size coalesced) * LSPA must load some multiple of 256 bytes. # No longer support loadX2/loadx4 . for tc in ['A', 'B']: + tcmx = "MXS%s"%tc if state["DirectToLds%s"%tc]: isDtlDoable = Solution.isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason) if (not state["DirectToVgpr%s"%tc]) and isDtlDoable: state["DirectToLds%s"%tc] = True state["LocalWriteUseSgpr%s"%tc] = True + # MX case + if state["ProblemType"]["MXBlock%s"%tc]: + isDtlMxDoable = Solution.isDirectToLdsDoable(state, tcmx, isaInfoMap, printRejectionReason) + state["DirectToLds%s"%tcmx] = isDtlMxDoable else: state["DirectToLds%s"%tc] = False + # MX case + if state["ProblemType"]["MXBlock%s"%tc]: + state["DirectToLds%s"%tcmx] = False if not isDtlDoable: if state["UseGeneralizedNLCOne%s"%tc]: reject(state, printRejectionReason, "DirectToLds%s not doable, but GNLC%s enabled, rejecting"%(tc, tc)) @@ -2840,6 +2855,10 @@ def calSwizzlePackK(state, tc): if state["1LDSBuffer"] == -1 and state["DirectToLds"]: #1LDS buffer must be 0 for DirectToLdsA state["1LDSBuffer"] = 0 + # MX case + if state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]: + if state["DirectToLdsA"] != state["DirectToLdsMXSA"] or state["DirectToLdsB"] != state["DirectToLdsMXSB"]: + reject(state, printRejectionReason, "DirectToLdsA/B and DirectToLdsMXSA/B should match") # does not work with UnrollLoopSwapGlobalReadOrder if (state["DirectToLds"] == 2 or state["DirectToLds"] == 3) and state["UnrollLoopSwapGlobalReadOrder"]: From bf76ee59097e5e0e8a5ba7e9e8f73c79dc4fa67e Mon Sep 17 00:00:00 2001 From: Koji Nakajima Date: Fri, 20 Feb 2026 20:15:14 +0000 Subject: [PATCH 3/3] Added test case --- .../Tests/common/gemm/gfx950/mx32f4_tn.yaml | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f4_tn.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f4_tn.yaml index c9d7d2a3bdc6..bb0b55637dc3 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f4_tn.yaml +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f4_tn.yaml @@ -145,6 +145,36 @@ BenchmarkProblems: - BiasTypeArgs: ['s'] - ActivationArgs: - [Enum: none] + - # BenchmarkProblemSizeGroup - DTL + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 8,8, 2,2] # 64x64 + - [16, 16, 128, 1, 1, 8,2, 2,2] # 128x64 + - [16, 16, 128, 1, 1, 2,8, 2,2] # 64x128 + - [32, 32, 64, 1, 1, 4,4, 2,2] # 128x128 + - DepthU: [128] + - AssertSummationElementMultiple: [32] + - LocalReadVectorWidth: [32] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - ScheduleIterAlg: [3] + - 1LDSBuffer: [0] + - DirectToLds: [1,2,3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [32, 16, 1, 1024] + - Exact: [256, 256, 1, 128] + - Exact: [1025, 513, 1, 2048] + - Exact: [127, 127, 1, 640] #special cleanup case + - Exact: [129, 129, 1, 640] + - Exact: [128, 128, 1, 512] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] # LibraryLogic: # ScheduleName: "aquavanjaram"