diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py b/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py index d9c3f23e717..973e80431c4 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py @@ -124,7 +124,11 @@ def skTileIndex(self, writer, kernel, sTmp, tPA, tPB): # Always reset pointers to handle odd-exit case which moves LRO to the upper bank if kernel["PrefetchGlobalRead"]: # not self.prefetchAcrossPersistent module.add(writer.localReadResetOffsets(kernel, tPA)) + if kernel["ProblemType"]["MXBlockA"]: + module.add(writer.localReadResetOffsets(kernel, tPA["MX"])) module.add(writer.localReadResetOffsets(kernel, tPB)) + if kernel["ProblemType"]["MXBlockB"]: + module.add(writer.localReadResetOffsets(kernel, tPB["MX"])) module.addComment0("StreamK calculate tile idx and map to WG") diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 8e238c6d8b5..ab39b02f8aa 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -5919,8 +5919,14 @@ def GNLCOInit(tc): # Need backup for the first LocalReadAddr only (others will be calculated from the first one) self.states.a.startVgprLocalReadAddrOrig = vgprIdx vgprIdx += 1 if self.states.a.numVgprLocalReadAddr > 0 else 0 + if kernel["ProblemType"]["MXBlockA"]: + self.states.mxsa.startVgprLocalReadAddrOrig = vgprIdx + vgprIdx += 1 if self.states.mxsa.numVgprLocalReadAddr > 0 else 0 self.states.b.startVgprLocalReadAddrOrig = vgprIdx vgprIdx += 1 if self.states.b.numVgprLocalReadAddr > 0 else 0 + if kernel["ProblemType"]["MXBlockB"]: + self.states.mxsb.startVgprLocalReadAddrOrig = vgprIdx + vgprIdx += 1 if self.states.mxsb.numVgprLocalReadAddr > 0 else 0 # ---------------------------- # TODO: alignment hack, figure out a better solution diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index e7e1257ea36..5c11107efce 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -1156,9 +1156,17 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: if self.states.a.numVgprLocalReadAddr > 0: module.add(RegSet("v", "vgprLocalReadAddrOrigA", \ self.states.a.startVgprLocalReadAddrOrig)) + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalReadAddr > 0: + module.add(RegSet("v", "vgprLocalReadAddrOrigMXSA", \ + self.states.mxsa.startVgprLocalReadAddrOrig)) if self.states.b.numVgprLocalReadAddr > 0: module.add(RegSet("v", "vgprLocalReadAddrOrigB", \ self.states.b.startVgprLocalReadAddrOrig)) + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalReadAddr > 0: + module.add(RegSet("v", "vgprLocalReadAddrOrigMXSB", \ + self.states.mxsb.startVgprLocalReadAddrOrig)) if self.states.m.numVgprLocalReadAddr > 0: module.add(RegSet("v", "vgprLocalReadAddrMetadata", \ self.states.m.startVgprLocalReadAddr))