diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/Signature.py b/projects/hipblaslt/tensilelite/Tensile/Components/Signature.py index d178722aff7..5da27b664bf 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/Signature.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/Signature.py @@ -221,20 +221,19 @@ def __call__(self, writer) -> SignatureBase: if kernel["ProblemType"]["UseScaleAB"]: signature.addArg("AddressScaleA", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") signature.addArg("AddressScaleB", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") - userArgumentsInfo.scaleASize += 8 - userArgumentsInfo.scaleBSize += 8 + userArgumentsInfo.scaleASize += 8 + userArgumentsInfo.scaleBSize += 8 if kernel["ProblemType"]["UseScaleCD"]: signature.addArg("AddressScaleC", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") signature.addArg("AddressScaleD", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") - userArgumentsInfo.scaleCSize += 8 - userArgumentsInfo.scaleDSize += 8 + userArgumentsInfo.scaleCSize += 8 + userArgumentsInfo.scaleDSize += 8 if kernel["ProblemType"]["UseScaleAlphaVec"]: signature.addArg("AddressScaleAlphaVec", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") if kernel["ProblemType"]["UseScaleAlphaVec"] == 3: userArgumentsInfo.factorDimSize =4 - - userArgumentsInfo.scaleAlphaVecSize += 8 + userArgumentsInfo.scaleAlphaVecSize += 8 if writer.states.useBias != DataDirection.NONE: signature.addArg("bias", SVK.SIG_GLOBALBUFFER, biasValueType, "generic") # Note: We append the data in ws_d @@ -243,7 +242,7 @@ def __call__(self, writer) -> SignatureBase: signature.addArg("StrideBias", SVK.SIG_VALUE, "u32") if kernel["ProblemType"]["UseBias"] == 3: userArgumentsInfo.factorDimSize = 4 - userArgumentsInfo.biasSize += (8 + 4 + 4) + userArgumentsInfo.biasSize += (8 + 4 + 4) if userArgumentsInfo.factorDimSize == 4: signature.addArg("factorDim", SVK.SIG_VALUE, "u32") @@ -252,9 +251,9 @@ def __call__(self, writer) -> SignatureBase: signature.addArg( "E", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") for i in range(0, writer.states.e.numSgprStrides): signature.addArg("StrideE%u"%i, SVK.SIG_VALUE, "u32") - userArgumentsInfo.eSize += 8 - for i in range(0, writer.states.e.numSgprStrides): - userArgumentsInfo.eSize += 4 + userArgumentsInfo.eSize += 8 + for i in range(0, writer.states.e.numSgprStrides): + userArgumentsInfo.eSize += 4 if ((kernel["ProblemType"]["ActivationType"] != 'none') and kernel["ActivationFused"]): if kernel["ProblemType"]["ActivationComputeDataType"].isHalf(): diff --git a/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp b/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp index afbf36ddec5..51df3f01c4c 100644 --- a/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp +++ b/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp @@ -1709,7 +1709,10 @@ namespace TensileLite or i == ContractionProblemGemm::TENSOR::METADATA) continue; - if(useMXGenerator && (i == ContractionProblemGemm::TENSOR::A || i == ContractionProblemGemm::TENSOR::B)) + if(useMXGenerator && (i == ContractionProblemGemm::TENSOR::A + || i == ContractionProblemGemm::TENSOR::B + || i == ContractionProblemGemm::TENSOR::MXSA + || i == ContractionProblemGemm::TENSOR::MXSB)) continue; if(m_problemDependentData) diff --git a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp index 03b49626763..4152c3bfc21 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp @@ -809,8 +809,7 @@ namespace TensileLite // NOTE: an assumption here is A & B must be both MX data types or non-MX data types. // Mixing is not supported. - if(!problemType.useScaleAB.empty() or - (problemType.mxBlockA != 0 && problemType.mxBlockB != 0)) //kernel input data + if(!problemType.useScaleAB.empty()) { args.template append("scaleA", inputs.scaleA); args.template append("scaleB", inputs.scaleB);