diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index cbaf9b26cf32..ee29bd611cbd 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -5481,13 +5481,19 @@ def readWriteVectors(mat, vw, kernel): self.states.a.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprA"] else 1 * self.states.rpla if kernel["ProblemType"]["MXBlockA"]: self.states.mxsa.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprMXSA"] else 1 * self.states.rpla + self.states.mxsa.numVgprLocalWriteSwapAddr = 0 self.states.b.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprB"] else 1 * self.states.rpla if kernel["ProblemType"]["MXBlockB"]: self.states.mxsb.numVgprLocalWriteAddr = 0 if kernel["LocalWriteUseSgprMXSB"] else 1 * self.states.rpla + self.states.mxsb.numVgprLocalWriteSwapAddr = 0 self.states.m.numVgprLocalWriteAddr = 0 if kernel["ProblemType"]["Sparse"] and kernel["LocalWriteUseSgprMetadata"] else 1 * self.states.rpla self.states.a.numVgprLocalReadSwapAddr = 0 self.states.b.numVgprLocalReadSwapAddr = 0 self.states.m.numVgprLocalReadSwapAddr = 0 + if kernel["ProblemType"]["MXBlockA"]: + self.states.mxsa.numVgprLocalReadSwapAddr = 0 + if kernel["ProblemType"]["MXBlockB"]: + self.states.mxsb.numVgprLocalReadSwapAddr = 0 self.states.a.numVgprLocalWriteSwapAddr = 0 self.states.b.numVgprLocalWriteSwapAddr = 0 self.states.m.numVgprLocalWriteSwapAddr = 0 @@ -5559,8 +5565,14 @@ def readWriteVectors(mat, vw, kernel): if kernel["StoreSwapAddr"]: if self.states.a.numVgprLocalReadAddr > 0: self.states.a.numVgprLocalReadSwapAddr = 1 + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalReadAddr > 0: + self.states.mxsa.numVgprLocalReadSwapAddr = 1 if self.states.b.numVgprLocalReadAddr > 0: self.states.b.numVgprLocalReadSwapAddr = 1 + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalReadAddr > 0: + 
self.states.mxsb.numVgprLocalReadSwapAddr = 1 if self.states.m.numVgprLocalReadAddr > 0: self.states.m.numVgprLocalReadSwapAddr = 1 if not kernel["LocalWriteUseSgprA"] and self.states.a.numVgprLocalWriteAddr > 0: @@ -5569,6 +5581,12 @@ def readWriteVectors(mat, vw, kernel): self.states.b.numVgprLocalWriteSwapAddr = 1 if kernel["ProblemType"]["Sparse"] and not kernel["LocalWriteUseSgprMetadata"] and self.states.m.numVgprLocalWriteAddr > 0: self.states.m.numVgprLocalWriteSwapAddr = 1 + if kernel["ProblemType"]["MXBlockA"]: + if not kernel["LocalWriteUseSgprMXSA"] and self.states.mxsa.numVgprLocalWriteAddr > 0: + self.states.mxsa.numVgprLocalWriteSwapAddr = 1 + if kernel["ProblemType"]["MXBlockB"]: + if not kernel["LocalWriteUseSgprMXSB"] and self.states.mxsb.numVgprLocalWriteAddr > 0: + self.states.mxsb.numVgprLocalWriteSwapAddr = 1 # Note: MXSA/MXSB do not use swap addresses #################################### @@ -6147,6 +6165,14 @@ def GNLCOInit(tc): if self.states.b.numVgprLocalReadSwapAddr > 0: self.states.b.startVgprLocalReadSwapAddr = vgprIdx vgprIdx += 1 + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalReadSwapAddr > 0: + self.states.mxsa.startVgprLocalReadSwapAddr = vgprIdx + vgprIdx += 1 + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalReadSwapAddr > 0: + self.states.mxsb.startVgprLocalReadSwapAddr = vgprIdx + vgprIdx += 1 if self.states.a.numVgprLocalWriteSwapAddr > 0: self.states.a.startVgprLocalWriteSwapAddr = vgprIdx vgprIdx += 1 @@ -6156,6 +6182,14 @@ def GNLCOInit(tc): if self.states.b.numVgprLocalWriteSwapAddr > 0: self.states.b.startVgprLocalWriteSwapAddr = vgprIdx vgprIdx += 1 + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalWriteSwapAddr > 0: + self.states.mxsa.startVgprLocalWriteSwapAddr = vgprIdx + vgprIdx += 1 + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalWriteSwapAddr > 0: + self.states.mxsb.startVgprLocalWriteSwapAddr = vgprIdx + vgprIdx += 
1 # X32F Emulation initializations # meaning of variables @@ -6583,6 +6617,10 @@ def checkVregOverflowTF32Emu(vgprIdx, numV): self.defineSgpr("SwapA", 1) if kernel["LocalWriteUseSgprB"]: self.defineSgpr("SwapB", 1) + if kernel["ProblemType"]["MXBlockA"] and kernel["LocalWriteUseSgprMXSA"]: + self.defineSgpr("SwapMXSA", 1) + if kernel["ProblemType"]["MXBlockB"] and kernel["LocalWriteUseSgprMXSB"]: + self.defineSgpr("SwapMXSB", 1) if kernel["ProblemType"]["Sparse"] and kernel["LocalWriteUseSgprMetadata"]: self.defineSgpr("SwapMetadata", 1) diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index d9a0f4bbe4df..b879a082a5cc 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -1171,6 +1171,14 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: if self.states.m.numVgprLocalReadSwapAddr > 0: module.add(RegSet("v", "vgprLocalReadSwapAddrMetadata", \ self.states.m.startVgprLocalReadSwapAddr)) + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalReadSwapAddr > 0: + module.add(RegSet("v", "vgprLocalReadSwapAddrMXSA", \ + self.states.mxsa.startVgprLocalReadSwapAddr)) + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalReadSwapAddr > 0: + module.add(RegSet("v", "vgprLocalReadSwapAddrMXSB", \ + self.states.mxsb.startVgprLocalReadSwapAddr)) if self.states.a.numVgprLocalWriteSwapAddr > 0: module.add(RegSet("v", "vgprLocalWriteSwapAddrA", \ self.states.a.startVgprLocalWriteSwapAddr)) @@ -1180,6 +1188,14 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: if self.states.m.numVgprLocalWriteSwapAddr > 0: module.add(RegSet("v", "vgprLocalWriteSwapAddrMetadata", \ self.states.m.startVgprLocalWriteSwapAddr)) + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalWriteSwapAddr > 0: + module.add(RegSet("v", "vgprLocalWriteSwapAddrMXSA", \ + 
self.states.mxsa.startVgprLocalWriteSwapAddr)) + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalWriteSwapAddr > 0: + module.add(RegSet("v", "vgprLocalWriteSwapAddrMXSB", \ + self.states.mxsb.startVgprLocalWriteSwapAddr)) if kernel["ProblemType"]["OutputAmaxD"]: module.add(RegSet("v", "vgprAmaxOut", self.startVgprAmaxOut)) @@ -1771,10 +1787,16 @@ def localReadAddresses(self, kernel, tPA, tPB, tPM): if self.states.a.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPA)) module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPA, False, True)) + if kernel["ProblemType"]["MXBlockA"]: + module.add(self.lraSwapAddressesForDTLPad(kernel, tPA["MX"])) + module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPA["MX"], False, True)) if self.states.b.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPB)) module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPB, False, True)) + if kernel["ProblemType"]["MXBlockB"]: + module.add(self.lraSwapAddressesForDTLPad(kernel, tPB["MX"])) + module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPB["MX"], False, True)) if self.states.m.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPM)) @@ -4598,8 +4620,7 @@ def lwaFirstOffset(self, kernel, tP): self.vgprPool.checkIn(tmpv) self.vgprPool.checkIn(destVgpr) - # MXSA/MXSB do not use swap addresses (per mx_tony implementation) - if kernel["StoreSwapAddr"] and tc not in ("MXSA", "MXSB"): + if kernel["StoreSwapAddr"]: if kernel["LocalWriteUseSgpr%s"%tc]: # needed for the VReadfirstlaneB32 in the prior code block if self.states.archCaps["CrosslaneWait"]: @@ -4868,8 +4889,7 @@ def lraSwapAddressesForDTLPad(self, kernel, tP): module = Module("lraSwapAddressesForDTLPad") tc = tP["tensorChar"] - # MXSA/MXSB do not use swap addresses (per mx_tony implementation) - if kernel["StoreSwapAddr"] and tc not in ("MXSA", "MXSB"): + if kernel["StoreSwapAddr"]: module.add(VAddU32(dst=vgpr("LocalReadSwapAddr%s"%tc), 
src0=kernel["LdsOffsetA_Blk"], src1=vgpr("LocalReadAddr%s"%tc), \ comment="Calculate starting lds addr of second buffer" )) module.add(VXorB32(dst=vgpr("LocalReadSwapAddr%s"%tc), \ @@ -9808,6 +9828,18 @@ def localWriteAddRound(tc): module.add(SCmpEQU32(src0=sgpr("LDSBufferWriteInc"), src1="LdsBlockEndSize", comment="LDSBufferWriteInc == End ?")) module.add(SCMovB32(dst=sgpr("LDSBufferWriteInc"), src=0, comment="LDSBufferWriteInc loop back to 0")) + def getSrc0Val(tc): + src0Val = None + if kernel["StoreSwapAddr"]: + if kernel["LocalWriteUseSgpr%s"%tc]: + src0Val = sgpr("Swap%s"%tc) + else: + src0Val = vgpr("LocalWriteSwapAddr%s"%tc) + else: + # Using inlined constants + src0Val = hex(kernel["LdsOffsetA_Blk"]) + return src0Val + if needSwap: #fixme-iui need to use wrapping increment for double or triple buffering: @@ -9822,25 +9854,13 @@ def localWriteAddRound(tc): # (numLDSBlk>=3 is for DTL (and LocalWriteUseSgpr) only) localWriteAddRound(tc) else: - src0Val = None - if kernel["StoreSwapAddr"]: - if kernel["LocalWriteUseSgpr%s"%tc]: - src0Val = sgpr("Swap%s"%tc) - else: - src0Val = vgpr("LocalWriteSwapAddr%s"%tc) - else: - # Using inlined constants - TODO: always use LdsOffsetA_Blk for A/B/MXSA/MXSB? 
- src0Val = hex(kernel["LdsOffsetA_Blk"]) - + src0Val = getSrc0Val(tc) numLwa = self.states.a.numVgprLocalWriteAddr if tP["isA"] else self.states.b.numVgprLocalWriteAddr localWriteSwapXOR(tc, src0Val, numLwa) - #TODO: if "MX" in tP: tc = tP["MX"]["tensorChar"] - src0Val = hex(kernel["LdsOffsetA_Blk"]) - #TODO:check to use isMXSA instead of isA - #numLwa = self.states.mxsa.numVgprLocalWriteAddr if tP["isA"] else self.states.mxsb.numVgprLocalWriteAddr - numLwa = self.states.mxsa.numVgprLocalWriteAddr if tP["isMXSA"] else self.states.mxsb.numVgprLocalWriteAddr + src0Val = getSrc0Val(tc) + numLwa = self.states.mxsa.numVgprLocalWriteAddr if tP["MX"]["isMXSA"] else self.states.mxsb.numVgprLocalWriteAddr localWriteSwapXOR(tc, src0Val, numLwa) # This used to control where to store the metadata @@ -9855,14 +9875,7 @@ def localWriteAddRound(tc): tPM["localWriteSwapByteOffset"] = 0 if tPM["localWriteSwapByteOffset"] else kernel["LdsOffsetA_Blk"] module.addComment1("(EPS=1) local write swap internal offset -> %u" % tPM["localWriteSwapByteOffset"]) else: - if kernel["StoreSwapAddr"]: - if kernel["LocalWriteUseSgpr%s"%tc]: - src0Val = sgpr("Swap%s"%tc) - else: - src0Val = vgpr("LocalWriteSwapAddr%s"%tc) - else: - # Using inlined constants - src0Val = hex(kernel["LdsOffsetA_Blk"]) + src0Val = getSrc0Val(tc) numLwa = self.states.m.numVgprLocalWriteAddr localWriteSwapXOR(tc, src0Val, numLwa) return module @@ -10843,13 +10856,6 @@ def localReadSwapOffsets(self, kernel, internalPointerSwap, tP): if not kernel["StoreSwapAddr"]: tP["localReadSwapByteOffset"] = 0 if tP["localReadSwapByteOffset"] else kernel["LdsOffsetA_Blk"] module.addComment1("local read swap internal offset -> %u" % tP["localReadSwapByteOffset"]) - elif tc in ("MXSA", "MXSB"): - #TODO: MXSA/MXSB do not use swap addresses, use fixed offset instead - module.add(VXorB32( - dst=vgpr("LocalReadAddr%s"%tc), \ - src0=hex(kernel["LdsOffsetA_Blk"]), \ - src1=vgpr("LocalReadAddr%s"%tc), \ - comment="swap Red Blk")) else: 
module.add(VXorB32( dst=vgpr("LocalReadAddr%s"%tc), \ diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index 987e3fc5039e..bd93f33ef22d 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -2321,6 +2321,10 @@ def calSwizzlePackK(state, tc): curGRVW *= 2 if state["ProblemType"]["MXBlockA"]: state["GlobalReadVectorWidthMXSA"] = max(state["MacroTile0"] * state["_DepthUMXSA"] // state["NumThreads"], 1) + # workaround for DTL + # use 32bit load for DTL if GlobalReadVectorWidthMXS is 8 + if state["DirectToLdsMXSA"] and state["GlobalReadVectorWidthMXSA"] == 8: + state["GlobalReadVectorWidthMXSA"] = 4 else: # dot2 if state["UseDotInstruction"]: @@ -2365,6 +2369,10 @@ def calSwizzlePackK(state, tc): curGRVW *= 2 if state["ProblemType"]["MXBlockB"]: state["GlobalReadVectorWidthMXSB"] = max(state["MacroTile1"] * state["_DepthUMXSB"] // state["NumThreads"], 1) + # workaround for DTL + # use 32bit load for DTL if GlobalReadVectorWidthMXS is 8 + if state["DirectToLdsMXSB"] and state["GlobalReadVectorWidthMXSB"] == 8: + state["GlobalReadVectorWidthMXSB"] = 4 else: # dot2 if state["UseDotInstruction"]: @@ -3462,13 +3470,10 @@ def setLdsOffsets(offsetBlk, numLdsBlk, ldsNumBytesB): # TODO: # Disable StoreSwapAddr to ensure LdsOffsetA_Blk is always a power of 2 # This is consistent with referenc implementation which doesn't have StoreSwapAddr - if state["ProblemType"]["MXBlockA"]: - state["StoreSwapAddr"] = False - else: - # Original logic: - state["StoreSwapAddr"] = (state["PrefetchGlobalRead"] == 2) and \ - (state["1LDSBuffer"] == 0) and numLdsBlk == 2 and \ - (offsetBlk + roundupOffsetBlk) > state["MaxLDS"] + # Original logic: + state["StoreSwapAddr"] = (state["PrefetchGlobalRead"] == 2) and \ + (state["1LDSBuffer"] == 0) and numLdsBlk == 2 and \ + (offsetBlk + roundupOffsetBlk) > 
state["MaxLDS"] if offsetBlk > 0 and not state["StoreSwapAddr"] and numLdsBlk == 2: # Rounds offsetBlk to a power of two to enable inlining {s,v}_xor constants for swapping offsets