diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py index c776e349283..379e511ea90 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py @@ -309,6 +309,9 @@ def getLocalWriteMFMAStart(writer, kernel, tensorParametersA, tensorParametersB, # TODO: replace here for real number of globalReadIncInst # numGRIncInst = 18 # Always on. Original logic: 12 if not kernel["StaggerU"] else 18 numGRIncInst = 12 if not writer.states.staggerUCode else 18 + if kernel["ProblemType"]["MXBlockA"] and kernel["ProblemType"]["MXBlockB"]: + # MXSA+MXSB case, double the number + numGRIncInst *= 2 numInstPerMfma = max(roundUp(writer.states.miLatencyLeft/2),1) numMfmaToSched = roundUp(numGRIncInst/numInstPerMfma) lwStartMfmaIndex = 1 + numMfmaToSched diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 2d825c3825d..7659031a09b 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -4666,8 +4666,13 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB): self.asmAssert = Assert(self.states.laneSGPRCount, kernel["WavefrontSize"], self.db["EnableAsserts"]) self.states.tailloopInNll = kernel["TailloopInNll"] - # remove staggerU code for tailloopInNll (cannot support staggerU) - self.states.staggerUCode = not self.states.tailloopInNll + # remove staggerU code for the following cases + # - tailloopInNll (cannot support staggerU) + # - StreamK + MX (not enough sgpr) + self.states.staggerUCode = True + if self.states.tailloopInNll or \ + (kernel["StreamK"] and (kernel["ProblemType"]["MXBlockA"] or kernel["ProblemType"]["MXBlockB"])): + self.states.staggerUCode = False self.states.tailloopInNllmaxUnit = 1 if self.states.tailloopInNll: tluA = kernel["ProblemType"]["TLUA"] diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index ab98b2dac20..829e01b040f 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -3367,14 +3367,15 @@ def subCheckLdsBlockSizePerPad(tc, idx): # number of minimum GR inc inst per MFMA # default 1 # Set at least 2 for gfx950 + MI16 + smaller MT case + # Set 3 for MX (number of GRInc is doubled) if state["MinGRIncPerMfma"] == -1: state["MinGRIncPerMfma"] = 1 if isa == (9, 5, 0): - if state["EnableMatrixInstruction"] and state["MatrixInstM"] == 16 and state["MatrixInstK"] == 1: - if numMFMA<=32: - state["MinGRIncPerMfma"] = 2 - if numMFMA<=16: + if state["EnableMatrixInstruction"] and state["MatrixInstM"] == 16 and state["MatrixInstB"] == 1: + if numMFMA<=16 or (state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]): state["MinGRIncPerMfma"] = 3 + elif numMFMA<=32: + state["MinGRIncPerMfma"] = 2 # calculate ldsPad state["LdsPadA"], state["LdsPadB"], state["LdsPadMetadata"] = calcLdsPad(state["LocalReadVectorWidth"], isaInfoMap) diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp index 23384eda6cd..7bf48655c73 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -62,7 +62,7 @@ namespace rocisa None = Count }; - inline int dataTypeToBytes(DataType type) + inline float dataTypeToBytes(DataType type) { switch(type) { @@ -104,6 +104,12 @@ namespace rocisa return 1; case DataType::BFloat8Float8: return 1; + case DataType::Float6: + return 0.75; + case DataType::BFloat6: + return 0.75; + case DataType::Float4: + return 0.5; default: return -1; // Invalid type } diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp index 9293205e13e..255cba18dac 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -43,7 +43,7 @@ namespace rocisa && matrixInstB == 1) { if(dataType == DataType::Half || dataType == DataType::BFloat16 - || dataType == DataType::Int8 || is8bitFloat(dataType)) + || dataType == DataType::Int8 || is8bitFloat(dataType) || numBytes < 1) { mi_divisor = 4; miIssueLatency = 1;