From beff7e17435516de96231601640d478f6e1d205a Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Tue, 12 Sep 2023 05:07:34 -0500 Subject: [PATCH 1/9] Change the formula of ScaleA/B if DataTypeA/B > DataType Add: if DataTypeA/B > DataType data[i]/scale before mfma --- tensilelite/Tensile/Components/Signature.py | 8 +- tensilelite/Tensile/KernelWriter.py | 7 + tensilelite/Tensile/KernelWriterAssembly.py | 113 ++++++++++--- .../Tensile/Source/client/include/TypedId.hpp | 56 +++--- .../Source/client/source/Reference.cpp | 159 +++++++++++++++--- .../Source/lib/source/ContractionSolution.cpp | 14 +- 6 files changed, 290 insertions(+), 67 deletions(-) diff --git a/tensilelite/Tensile/Components/Signature.py b/tensilelite/Tensile/Components/Signature.py index 10047fc0ad..9d0df42a48 100644 --- a/tensilelite/Tensile/Components/Signature.py +++ b/tensilelite/Tensile/Components/Signature.py @@ -188,9 +188,11 @@ def __call__(self, writer) -> SignatureBase: if not kernel["ProblemType"]["GroupedGemm"]: signature.addArg( "gsu", SVK.SIG_VALUE, "u32") - if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1): - signature.addArg("AddressScaleA", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") - signature.addArg("AddressScaleB", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") + if kernel["ProblemType"]["UseScaleAB"]: + if (kernel["GlobalSplitU"] == 1) or (kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters()): + signature.addArg("AddressScaleA", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") + if (kernel["GlobalSplitU"] == 1) or (kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters()): + signature.addArg("AddressScaleB", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") if kernel["ProblemType"]["UseScaleCD"] and (kernel["GlobalSplitU"] == 1): signature.addArg("AddressScaleC", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") signature.addArg("AddressScaleD", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py index 081ec83984..6fea2a88c2 100644 --- a/tensilelite/Tensile/KernelWriter.py +++ b/tensilelite/Tensile/KernelWriter.py @@ -3449,6 +3449,13 @@ def readWriteVectors(mat, vw, kernel): # for conditionals self.states.lastPostLoopSgpr = self.sgprPool.size() + if kernel["ProblemType"]["UseScaleAB"]: + for name in ['A','B']: + if kernel["ProblemType"]["DataType%s"%name].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + self.defineSgpr("AddressScale%s"%name, 2, 2) + self.defineSgpr("Scale%s"%name, numSgprAlpha, numSgprAlpha if numSgprAlpha > 1 else 2) + + self.states.numSgprToLoad = self.states.numSgprSizesFree + self.states.numSgprSizesSum + \ numSgprAddressD + numSgprAddressC + numSgprAddressA + numSgprAddressB + numSgprAlpha + numSgprAddressMetadata + \ (numSgprBeta if kernel["ProblemType"]["UseBeta"] else 0) + \ diff --git a/tensilelite/Tensile/KernelWriterAssembly.py b/tensilelite/Tensile/KernelWriterAssembly.py index 8eb7813a0b..1534c6a223 100644 --- a/tensilelite/Tensile/KernelWriterAssembly.py +++ b/tensilelite/Tensile/KernelWriterAssembly.py @@ -449,8 +449,8 @@ def defineVariableSgprs(self, kernel): self.states.numStoreSgprNameSizes = [] storeSgprLoad = 0 if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1): - self.states.numSgprAddressScaleA = self.states.rpga - self.states.numSgprAddressScaleB = self.states.rpga + self.states.numSgprAddressScaleA = self.states.rpga if kernel["ProblemType"]["DataTypeA"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters() else 0 + self.states.numSgprAddressScaleB = self.states.rpga if kernel["ProblemType"]["DataTypeB"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters() else 0 storeSgprLoad += self.states.numSgprAddressScaleA + self.states.numSgprAddressScaleB if self.states.numSgprAddressScaleA: self.states.numStoreSgprNames.append("AddressScaleA") @@ -1122,6 +1122,13 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): load = self.states.numSgprToLoad sgprStart = self.sgprs["SizesFree"] moduleLoadAllKernArg = self.argLoader.loadAllKernArg(sgprStart, "KernArgAddress", load, 0) + if kernel["ProblemType"]["UseScaleAB"]: + sgprOffset = self.argLoader.getOffset() + for name in ['A','B']: + if kernel["ProblemType"]["DataType%s"%name].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + moduleLoadAllKernArg.add(self.argLoader.loadKernArg("AddressScale%s"%name, "KernArgAddress", sgprOffset=hex(sgprOffset), dword=2)) + sgprOffset += (self.states.rpga * self.states.bpr) + moduleArgs.addModuleAsFlatItems(moduleLoadAllKernArg) if self.states.numSgprPreload > 0: # add common kern entry label in the begining of common reg init code @@ -1183,6 +1190,19 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): if self.states.kernel["WavefrontSize"] == 32: moduleRegInit.add(SMovB32(dst=VCC(setHi=True), src=0, comment="Ensure hi bits are zero")) + waitForScaleAB = False + moduleScaleAB = Module("Load ScaleAB") + if kernel["ProblemType"]["UseScaleAB"]: + for name in ['A','B']: + if kernel["ProblemType"]["DataType%s"%name].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + waitForScaleAB = True + moduleScaleAB.add(SMovB32(dst=sgpr("Scale%s"%name), src=1.0 , comment="init as 1" )) + label = Label(self.labels.getNameInc("Scale%sValid"%name), "") + moduleScaleAB.add(SBranchIfZero("AddressScale%s"%name, DataType('int64'), None, kernel["WavefrontSize"]/32, label, kernel["WavefrontSize"])) + # load scale data + moduleScaleAB.add(SLoadB32(dst=sgpr("Scale%s"%name), base=sgpr("AddressScale%s"%name,2), soffset=0, comment="load scale%s"%name)) + moduleScaleAB.add(label) + moduleWg = Module("Calculate Workgroup") moduleWg.addModuleAsFlatItems(lralwaCode) if kernel["ProblemType"]["SupportUserArgs"]: @@ -1191,6 +1211,7 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): else: moduleWg.add(SWaitCnt(lgkmcnt=0, comment="wait for %u bytes of kern args" % \ (self.argLoader.getOffset() - (self.states.numSgprPreload*4)))) + moduleWg.addModuleAsFlatItems(moduleScaleAB) moduleWg.add(Label(label="stop", comment="")) #### calculate numWorkGroup #### qReg = self.vgprPool.checkOut(4) @@ -1213,6 +1234,31 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): with self.allocTmpSgpr(self.states.laneSGPRCount) as tmpSgpr: moduleWg.add(self.loadBatchedAddress(kernel, "WorkGroup2", tmpSgpr)) moduleWg.add(SWaitCnt(lgkmcnt=0, comment="wait global buffer address ready")) + elif waitForScaleAB: + moduleWg.add(SWaitCnt(lgkmcnt=0, comment="wait for scaleA/B to load")) + + # Calculate RCP and update Alpha + if kernel["ProblemType"]["UseScaleAB"]: + tmpVgpr = self.vgprPool.checkOut(1) + newAlphaVgpr = None + NeedUpdateAlpha = False + for name in ['A','B']: + if kernel["ProblemType"]["DataType%s"%name].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + NeedUpdateAlpha = True + if NeedUpdateAlpha: + newAlphaVgpr = self.vgprPool.checkOut(1) + moduleWg.add(VMovB32(dst=vgpr(newAlphaVgpr), src=sgpr("Alpha"))) + for name in ['A','B']: + if kernel["ProblemType"]["DataType%s"%name].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + moduleWg.add(VMovB32(dst=vgpr(tmpVgpr), src=sgpr("Scale%s"%name))) + moduleWg.add(VRcpF32(dst=vgpr(tmpVgpr), src=vgpr(tmpVgpr))) + moduleWg.add(VMulF32(dst=vgpr(newAlphaVgpr), src0=vgpr(newAlphaVgpr), src1=sgpr("Scale%s"%name))) + moduleWg.add(VReadfirstlaneB32(dst=sgpr("Scale%s"%name), src=vgpr(tmpVgpr), comment="Get Rcp")) + if NeedUpdateAlpha: + moduleWg.add(VReadfirstlaneB32(dst=sgpr("Alpha"), src=vgpr(newAlphaVgpr), comment="Update Alpha")) + self.vgprPool.checkIn(tmpVgpr) + if newAlphaVgpr != None: + self.vgprPool.checkIn(newAlphaVgpr) if not kernel["ProblemType"]["GroupedGemm"]: ###### SingleGemm ############ @@ -4324,7 +4370,21 @@ def endSummation(self, kernel, tPA, tPB, label = None): module.add(extReadEpilogueLabelEnd) else: argOffset = self.argLoader.getOffset() # Backup offset - loadModule = module.addModuleAsFlatItems(self.argLoader.loadAllKernArg(sgpxIdxVec[0], "KernArgAddress", self.states.numStoreSgprToLoad)) + startVgprName = sgpxIdxVec[0] + numStoreSgprToLoad = self.states.numStoreSgprToLoad + if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1): + if (kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters()) and (kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters()): + self.argLoader.setOffset(argOffset + ((self.states.rpga * self.states.bpr) * 2)) + elif kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + assert sgpxIdxVec[0] == self.sgprs["AddressScaleB"] + self.argLoader.setOffset(argOffset + (self.states.rpga * self.states.bpr)) + elif kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + assert sgpxIdxVec[0] == self.sgprs["AddressScaleA"] + module.add(self.argLoader.loadKernArg(self.sgprs["AddressScaleA"], "KernArgAddress", dword=2)) + startVgprName = sgpxIdxVec[1] + numStoreSgprToLoad -= self.states.rpga + self.argLoader.setOffset(argOffset + ((self.states.rpga * self.states.bpr) * 2)) + loadModule = module.addModuleAsFlatItems(self.argLoader.loadAllKernArg(startVgprName, "KernArgAddress", numStoreSgprToLoad)) self.states.numStoreSgprInst = loadModule.countType(SMemLoadInstruction) self.argLoader.setOffset(argOffset) # Restore offset @@ -6356,18 +6416,23 @@ def localWriteBody(tP): vgprTmp = self.vgprPool.checkOut(1) src_sel = SelectBit.WORD_1 if isHigh16Bits else SelectBit.WORD_0 localWriteCode.add(VCvtF16toF32(dst=vgpr(vgprTmp), src=vgpr("G2L%s+%u"%(tP["tensorChar"], g2lIdx)), sdwa=SDWAModifiers(src0_sel=src_sel), comment="convert to F32")) + # ScaleA/B + if kernel["ProblemType"]["UseScaleAB"] and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + localWriteCode.add(VMulF32(dst=vgpr(vgprTmp), src0=vgpr(vgprTmp), src1=sgpr("Scale%s"%tc), comment="Input *= scale %s"%tc)) localWriteCode.add(VCvtPkF32toFP8(dst=paramList[0], src0=vgpr(vgprTmp), src1=vgpr(vgprTmp), vop3=VOP3PModifiers(op_sel=[0,0,0,0]), comment="Convert to FP8")) self.vgprPool.checkIn(vgprTmp) else: - vgprTmp = self.vgprPool.checkOut(1) - vgprTmp2 = self.vgprPool.checkOut(1) + vgprTmp = self.vgprPool.checkOutAligned(2, 2) + vgprTmp2 = vgprTmp + 1 for vi in range(0, int(newBlockWidth)): sel = 1 if vi %2 == 1 else 0 localWriteCode.add(VCvtF16toF32(dst=vgpr(vgprTmp), src=vgpr("G2L%s+%u+%u"%(tP["tensorChar"], g2lIdx, vi)), sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_0), comment="convert to F32")) localWriteCode.add(VCvtF16toF32(dst=vgpr(vgprTmp2), src=vgpr("G2L%s+%u+%u"%(tP["tensorChar"], g2lIdx, vi)), sdwa=SDWAModifiers(src0_sel=SelectBit.WORD_1), comment="convert to F32")) + # ScaleA/B, sgpr upper is dummy. + if kernel["ProblemType"]["UseScaleAB"] and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + localWriteCode.add(VMulPKF32(dst=vgpr(vgprTmp, 2), src0=vgpr(vgprTmp, 2), src1=sgpr("Scale%s"%tc, 2), vop3=VOP3PModifiers(op_sel_hi=[1,0,1]), comment="Input *= scale %s"%tc)) localWriteCode.add(VCvtPkF32toFP8(dst=vgpr("G2L%s+%u+%u"%(tP["tensorChar"], g2lIdx, vi//2)), src0=vgpr(vgprTmp), src1=vgpr(vgprTmp2), vop3=VOP3PModifiers(op_sel=[0,0,sel]), comment="Convert to FP8")) self.vgprPool.checkIn(vgprTmp) - self.vgprPool.checkIn(vgprTmp2) elif (kernel["ProblemType"]["DataType%s"%tc].isFloat8() and kernel["ProblemType"]["DataType"].isHalf()): if newBlockWidth == 0.25: new_src = fastdeepcopy(paramList[0]) @@ -7783,16 +7848,21 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths, elements, useSize = [] # Issue read scale A/B value for later use - if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1): + if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1) and \ + ((kernel["ProblemType"]["DataTypeA"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters()) or \ + (kernel["ProblemType"]["DataTypeB"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters())): assert(kernel["ProblemType"]["ComputeDataType"].isSingle()) - sgprScaleAB = self.sgprPool.checkOut(1) + sgprScaleA = self.sgprPool.checkOut(1, preventOverflow=False) + sgprScaleB = self.sgprPool.checkOut(1, preventOverflow=False) for i,name in enumerate(['A','B']): - module.add(SMovB32(dst=sgpr(sgprScaleAB+i), src=1.0 , comment="init as 1" )) - label = Label(self.labels.getNameInc("Scale%sValid"%name), "") - module.add(SBranchIfZero("AddressScale%s"%name, DataType('int64'), None, kernel["WavefrontSize"]/32, label, kernel["WavefrontSize"])) - # load scale data - module.add(SLoadB32(dst=sgpr(sgprScaleAB+i), base=sgpr("AddressScale%s"%name,2), soffset=0, comment="load scale%s"%name)) - module.add(label) + if kernel["ProblemType"]["DataType%s"%name].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters(): + sgprScale = sgprScaleA if name == 'A' else sgprScaleB + module.add(SMovB32(dst=sgpr(sgprScale), src=1.0 , comment="init as 1" )) + label = Label(self.labels.getNameInc("Scale%sValid"%name), "") + module.add(SBranchIfZero("AddressScale%s"%name, DataType('int64'), None, kernel["WavefrontSize"]/32, label, kernel["WavefrontSize"])) + # load scale data + module.add(SLoadB32(dst=sgpr(sgprScale), base=sgpr("AddressScale%s"%name,2), soffset=0, comment="load scale%s"%name)) + module.add(label) # Issue read scale C/D value for later use if kernel["ProblemType"]["UseScaleCD"] and (kernel["GlobalSplitU"] == 1): @@ -7803,7 +7873,7 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths, elements, # load scale data module.add(SLoadB32(dst=sgpr("ScaleD"), base=sgpr("AddressScaleD",2), soffset=0, comment="load scaleD")) module.add(label) - sgprScaleC = self.sgprPool.checkOut(1) + sgprScaleC = self.sgprPool.checkOut(1, preventOverflow=False) module.add(SMovB32(dst=sgpr(sgprScaleC), src=1.0 , comment="init as 1" )) label = Label(self.labels.getNameInc("ScaleCValid"), "") module.add(SBranchIfZero("AddressScaleC", DataType('int64'), None, kernel["WavefrontSize"]/32, label, kernel["WavefrontSize"])) @@ -7912,17 +7982,22 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths, elements, ssslist.append("Bias") useSize.append(True) - if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1): + if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1) and \ + ((kernel["ProblemType"]["DataTypeA"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters()) or \ + (kernel["ProblemType"]["DataTypeB"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters())): assert(kernel["ProblemType"]["ComputeDataType"].isSingle()) newAlphaVgpr = self.vgprPool.checkOut(1) module.add(VMovB32(dst=vgpr(newAlphaVgpr), src=sgpr("Alpha"))) module.add(SWaitCnt(lgkmcnt=0, comment="wait for scaleAB load")) - module.add(VMulF32(dst=vgpr(newAlphaVgpr), src0=vgpr(newAlphaVgpr), src1=sgpr(sgprScaleAB))) - module.add(VMulF32(dst=vgpr(newAlphaVgpr), src0=vgpr(newAlphaVgpr), src1=sgpr(sgprScaleAB+1))) + if kernel["ProblemType"]["DataTypeA"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters(): + module.add(VMulF32(dst=vgpr(newAlphaVgpr), src0=vgpr(newAlphaVgpr), src1=sgpr(sgprScaleA))) + if kernel["ProblemType"]["DataTypeB"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters(): + module.add(VMulF32(dst=vgpr(newAlphaVgpr), src0=vgpr(newAlphaVgpr), src1=sgpr(sgprScaleB))) module.add(SNop(waitState=0, comment="1 wait states")) module.add(VReadfirstlaneB32(dst=sgpr("Alpha"), src=vgpr(newAlphaVgpr), comment="Update Alpha")) self.vgprPool.checkIn(newAlphaVgpr) - self.sgprPool.checkIn(sgprScaleAB) + self.sgprPool.checkIn(sgprScaleA) + self.sgprPool.checkIn(sgprScaleB) # Update beta if kernel["ProblemType"]["UseScaleCD"] and (kernel["GlobalSplitU"] == 1): diff --git a/tensilelite/Tensile/Source/client/include/TypedId.hpp b/tensilelite/Tensile/Source/client/include/TypedId.hpp index 8522dd9b60..a8a1891463 100644 --- a/tensilelite/Tensile/Source/client/include/TypedId.hpp +++ b/tensilelite/Tensile/Source/client/include/TypedId.hpp @@ -35,9 +35,10 @@ namespace Tensile DataType cType, DataType dType, DataType alphaType, - DataType betaType) + DataType betaType, + DataType computeInputType) { - static_assert(BitFieldGenerator::ElementWidth((uint32_t)DataType::Count) * 6 + static_assert(BitFieldGenerator::ElementWidth((uint32_t)DataType::Count) * 7 <= BitFieldGenerator::maxBitFieldWidth, "Max bitfield width exceeded"); @@ -48,23 +49,26 @@ namespace Tensile (uint32_t)cType, (uint32_t)dType, (uint32_t)alphaType, - (uint32_t)betaType); + (uint32_t)betaType, + (uint32_t)computeInputType); } - template + template struct TypedGemm { - using AType = A; - using BType = B; - using CType = C; - using DType = D; - using AlphaType = Alpha; - using BetaType = Beta; + using AType = A; + using BType = B; + using CType = C; + using DType = D; + using AlphaType = Alpha; + using BetaType = Beta; + using ComputeInputType = ComputeInput; constexpr static uint32_t TypeId() { @@ -73,7 +77,8 @@ namespace Tensile TypeInfo::Enum, TypeInfo::Enum, TypeInfo::Enum, - TypeInfo::Enum); + TypeInfo::Enum, + TypeInfo::Enum); } }; @@ -120,12 +125,19 @@ namespace Tensile using TypedGemm_F8B8_B8_S = TypedGemm; using TypedGemm_B8F8_B8_S = TypedGemm; #ifdef TENSILE_USE_HALF - using TypedGemm_HF8_S_S = TypedGemm; - using TypedGemm_F8H_S_S = TypedGemm; - using TypedGemm_HF8_H_S = TypedGemm; - using TypedGemm_F8H_H_S = TypedGemm; - using TypedGemm_HF8_FP8_S = TypedGemm; - using TypedGemm_F8H_FP8_S = TypedGemm; + // Mix precision + using TypedGemm_HF8_H_S_S = TypedGemm; + using TypedGemm_F8H_H_S_S = TypedGemm; + using TypedGemm_HF8_H_H_S = TypedGemm; + using TypedGemm_F8H_H_H_S = TypedGemm; + using TypedGemm_HF8_H_FP8_S = TypedGemm; + using TypedGemm_F8H_H_FP8_S = TypedGemm; + using TypedGemm_HF8_FP8_S_S = TypedGemm; + using TypedGemm_F8H_FP8_S_S = TypedGemm; + using TypedGemm_HF8_FP8_H_S = TypedGemm; + using TypedGemm_F8H_FP8_H_S = TypedGemm; + using TypedGemm_HF8_FP8_FP8_S = TypedGemm; + using TypedGemm_F8H_FP8_FP8_S = TypedGemm; #endif // TENSILE_USE_HALF #endif // TENSILE_USE_FP8_BF8 } // namespace Tensile diff --git a/tensilelite/Tensile/Source/client/source/Reference.cpp b/tensilelite/Tensile/Source/client/source/Reference.cpp index 8df3629f97..192c7e2c72 100644 --- a/tensilelite/Tensile/Source/client/source/Reference.cpp +++ b/tensilelite/Tensile/Source/client/source/Reference.cpp @@ -91,6 +91,35 @@ namespace Tensile * static_cast(static_cast(r))); } + template + inline Accumulator div(TypeL l, TypeR r) + { + /* Transform the data type from TypeL/TypeR to Accumulator if TypeL!=ACC or TypeR!=ACC, but filter out cases, I8/I32/I32 and I8x4/I32/I32 + * + * There are three cases of doing multiplication and their conditions to do transform or not are as below. + * 1. AxB : (A!=ACC or B!=ACC) and A!=I8 and A!=I8x4 + * 2. Alpha x rC : (Alpha!=ACC or rC!=ACC) + * 3. Beta x C : (Beta!=ACC or C!=ACC) + */ + constexpr bool needAccumCast + = !(std::is_same() && std::is_same()) + && !std::is_same() //case I8/I32/I32, I8 be implicitly cast to int. + && !std::is_same(); //case I8x4/I32/I32, I8x4 overloading the op*. + + using LMultT = std::conditional_t; + using RMultT = std::conditional_t; + + constexpr bool needMathOpAccumCast = !std::is_same(); + using LMathOpMultT = std::conditional_t; + using RMathOpMultT = std::conditional_t; + + return static_cast(static_cast(static_cast(l)) + / static_cast(static_cast(r))); + } + template inline Accumulator cast(Type val) { @@ -359,7 +388,7 @@ namespace Tensile { return static_cast(exp(static_cast(val))); } - else if(new_type == ActivationType::Gelu || new_type == ActivationType::Geluscaling ) + else if(new_type == ActivationType::Gelu || new_type == ActivationType::Geluscaling) { auto castedVal = static_cast(val); auto k0 = static_cast(0.7978845608028654); @@ -369,9 +398,9 @@ namespace Tensile + multiply(k1, multiply(castedVal, castedVal))); tmp = multiply(k0, multiply(castedVal, tmp)); tmp = static_cast(1) + static_cast(tanh(tmp)); - tmp = multiply(static_cast(0.5f), multiply(castedVal, tmp)); + tmp = multiply(static_cast(0.5f), multiply(castedVal, tmp)); if(new_type == ActivationType::Geluscaling) - tmp = multiply(tmp, static_cast(args[0])); + tmp = multiply(tmp, static_cast(args[0])); return static_cast(tmp); } else if(new_type == ActivationType::Leakyrelu) @@ -726,7 +755,68 @@ namespace Tensile bVal = Transform::Input( bPtr[bIndex + (bI * bStride)], bConjugate); - value += multiply(aVal, bVal); + if constexpr(sizeof(typename Inputs::AType) + > sizeof(typename Inputs::ComputeInputType) + && sizeof(typename Inputs::BType) + > sizeof(typename Inputs::ComputeInputType)) + { + typename Inputs::ComputeInputType aValCast, bValCast; + if(problem.useScaleAB()) + { + Accumulator scaleA = GetValue( + problem.alphaType(), inputs.scaleA, 0, aConjugate); + auto tmp = div(aVal, scaleA); + aValCast = static_cast(tmp); + Accumulator scaleB = GetValue( + problem.alphaType(), inputs.scaleB, 0, aConjugate); + tmp = div(bVal, scaleB); + bValCast = static_cast(tmp); + } + else + { + aValCast = static_cast(aVal); + bValCast = static_cast(bVal); + } + value += multiply(aValCast, bValCast); + } + else if constexpr(sizeof(typename Inputs::AType) + > sizeof(typename Inputs::ComputeInputType)) + { + typename Inputs::ComputeInputType aValCast; + if(problem.useScaleAB()) + { + Accumulator scaleA = GetValue( + problem.alphaType(), inputs.scaleA, 0, aConjugate); + auto tmp = div(aVal, scaleA); + aValCast = static_cast(tmp); + } + else + { + aValCast = static_cast(aVal); + } + value += multiply(aValCast, bVal); + } + else if constexpr(sizeof(typename Inputs::BType) + > sizeof(typename Inputs::ComputeInputType)) + { + typename Inputs::ComputeInputType bValCast; + if(problem.useScaleAB()) + { + Accumulator scaleB = GetValue( + problem.alphaType(), inputs.scaleB, 0, aConjugate); + auto tmp = div(bVal, scaleB); + bValCast = static_cast(tmp); + } + else + { + bValCast = static_cast(bVal); + } + value += multiply(aVal, bValCast); + } + else + { + value += multiply(aVal, bVal); + } } } } @@ -763,8 +853,8 @@ namespace Tensile Accumulator cValue = multiply(beta, cPtr[cIndex]); if(problem.useScaleCD()) { - Accumulator scaleC - = GetValue(problem.betaType(), inputs.scaleC, 0, aConjugate); + Accumulator scaleC = GetValue( + problem.betaType(), inputs.scaleC, 0, aConjugate); cValue *= scaleC; } @@ -924,7 +1014,8 @@ namespace Tensile problem.c().dataType(), problem.d().dataType(), alphaType, - betaType); + betaType, + problem.computeInputType()); } template @@ -1127,34 +1218,64 @@ namespace Tensile problem, inputs, elementsToValidate); } #ifdef TENSILE_USE_HALF - case TypedGemm_HF8_S_S::TypeId(): + case TypedGemm_HF8_H_S_S::TypeId(): + { + return ReferenceSolution::SolveCPU( + problem, inputs, elementsToValidate); + } + case TypedGemm_F8H_H_S_S::TypeId(): + { + return ReferenceSolution::SolveCPU( + problem, inputs, elementsToValidate); + } + case TypedGemm_HF8_H_H_S::TypeId(): + { + return ReferenceSolution::SolveCPU( + problem, inputs, elementsToValidate); + } + case TypedGemm_F8H_H_H_S::TypeId(): + { + return ReferenceSolution::SolveCPU( + problem, inputs, elementsToValidate); + } + case TypedGemm_HF8_H_FP8_S::TypeId(): + { + return ReferenceSolution::SolveCPU( + problem, inputs, elementsToValidate); + } + case TypedGemm_F8H_H_FP8_S::TypeId(): + { + return ReferenceSolution::SolveCPU( + problem, inputs, elementsToValidate); + } + case TypedGemm_HF8_FP8_S_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } - case TypedGemm_F8H_S_S::TypeId(): + case TypedGemm_F8H_FP8_S_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } - case TypedGemm_HF8_H_S::TypeId(): + case TypedGemm_HF8_FP8_H_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } - case TypedGemm_F8H_H_S::TypeId(): + case TypedGemm_F8H_FP8_H_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } - case TypedGemm_HF8_FP8_S::TypeId(): + case TypedGemm_HF8_FP8_FP8_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } - case TypedGemm_F8H_FP8_S::TypeId(): + case TypedGemm_F8H_FP8_FP8_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } #endif // TENSILE_USE_HALF diff --git a/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp b/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp index f47c5aef71..c6add00104 100644 --- a/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp +++ b/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp @@ -634,10 +634,16 @@ namespace Tensile if constexpr(insertKernelArgs) kernelArgs(args); - if(problemType.useScaleAB && (sizeMapping.globalSplitU == 1)) //kernel input data - { - args.template append("scaleA", inputs.scaleA); - args.template append("scaleB", inputs.scaleB); + if(problemType.useScaleAB) //kernel input data + { + if(DataTypeInfo::Get(problemType.aType).elementSize + > DataTypeInfo::Get(problemType.computeInputType).elementSize + || (sizeMapping.globalSplitU == 1)) + args.template append("scaleA", inputs.scaleA); + if(DataTypeInfo::Get(problemType.bType).elementSize + > DataTypeInfo::Get(problemType.computeInputType).elementSize + || (sizeMapping.globalSplitU == 1)) + args.template append("scaleB", inputs.scaleB); } if(problemType.useScaleCD && (sizeMapping.globalSplitU == 1)) //kernel input data { From 3474dbf6193f4e78577c26c8b180f19182ed97df Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Wed, 13 Sep 2023 02:57:50 -0500 Subject: [PATCH 2/9] Fix mix precision not working properly if global read block width = 0.5 and local write is 0.25 --- tensilelite/Tensile/KernelWriterAssembly.py | 18 +++++++++++++---- .../Tests/common/gemm/fp8fp16mix_fp8ss.yaml | 20 ++++++++++++++++--- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/tensilelite/Tensile/KernelWriterAssembly.py b/tensilelite/Tensile/KernelWriterAssembly.py index 1534c6a223..8c8bb8e1cf 100644 --- a/tensilelite/Tensile/KernelWriterAssembly.py +++ b/tensilelite/Tensile/KernelWriterAssembly.py @@ -6328,8 +6328,17 @@ def localWriteBody(tP): if (blockWidth == 0.25) and ((s % 4) == 0) and not tP["isM"]: src = "G2L%s+%u" % (tc, g2lIdx) dst = "G2L%s+%u+%u" % (tc, tmpVgprOffset, g2lIdx) - localWriteCode.add(VMovB32(dst=vgpr(dst), src=vgpr(src), comment="another VGPR storing lshr 8-bit value")) - localWriteCode.add(VLShiftRightB32(dst=vgpr(dst), shiftHex=hex(0x8), src=vgpr(dst), comment="G2L Vpgr >> 8")) + if tP["bpe"] != tP["bpeGR"]: + if kernel["ProblemType"]["DataType%s"%tc].isHalf(): + localWriteCode.add(VPackF16toB32(dst=vgpr(dst), src0=vgpr(src), src1=vgpr("G2L%s+%u" % (tc, g2lIdx+1)), \ + vop3=VOP3PModifiers(op_sel=[1,1,0]), comment="Pack with neighbor")) + localWriteCode.add(VPackF16toB32(dst=vgpr(src), src0=vgpr(src), src1=vgpr("G2L%s+%u" % (tc, g2lIdx+1)), \ + vop3=VOP3PModifiers(op_sel=[0,0,0]), comment="Pack with neighbor")) + else: + printExit("Unsupported combination DataType%s (%s) -> DataType (%s)"%(tc, kernel["ProblemType"]["DataType%s"%tc].toChar(), kernel["ProblemType"]["DataType"].toChar())) + else: + localWriteCode.add(VMovB32(dst=vgpr(dst), src=vgpr(src), comment="another VGPR storing lshr 8-bit value")) + localWriteCode.add(VLShiftRightB32(dst=vgpr(dst), shiftHex=hex(8), src=vgpr(dst), comment="G2L Vpgr >> 8")) if self.states.archCaps["HasEccHalf"]: numVgprG2L = self.states.a.numVgprG2L if tc == 'A' else self.states.b.numVgprG2L if tc == 'B' else self.states.m.numVgprG2L @@ -6415,11 +6424,12 @@ def localWriteBody(tP): if newBlockWidth == 0.5: vgprTmp = self.vgprPool.checkOut(1) src_sel = SelectBit.WORD_1 if isHigh16Bits else SelectBit.WORD_0 - localWriteCode.add(VCvtF16toF32(dst=vgpr(vgprTmp), src=vgpr("G2L%s+%u"%(tP["tensorChar"], g2lIdx)), sdwa=SDWAModifiers(src0_sel=src_sel), comment="convert to F32")) + sel = 1 if isHigh16Bits else 0 + localWriteCode.add(VCvtF16toF32(dst=vgpr(vgprTmp), src=paramList[0], sdwa=SDWAModifiers(src0_sel=src_sel), comment="convert to F32")) # ScaleA/B if kernel["ProblemType"]["UseScaleAB"] and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): localWriteCode.add(VMulF32(dst=vgpr(vgprTmp), src0=vgpr(vgprTmp), src1=sgpr("Scale%s"%tc), comment="Input *= scale %s"%tc)) - localWriteCode.add(VCvtPkF32toFP8(dst=paramList[0], src0=vgpr(vgprTmp), src1=vgpr(vgprTmp), vop3=VOP3PModifiers(op_sel=[0,0,0,0]), comment="Convert to FP8")) + localWriteCode.add(VCvtPkF32toFP8(dst=paramList[0], src0=vgpr(vgprTmp), src1=vgpr(vgprTmp), vop3=VOP3PModifiers(op_sel=[0,0,sel]), comment="Convert to FP8")) self.vgprPool.checkIn(vgprTmp) else: vgprTmp = self.vgprPool.checkOutAligned(2, 2) diff --git a/tensilelite/Tensile/Tests/common/gemm/fp8fp16mix_fp8ss.yaml b/tensilelite/Tensile/Tests/common/gemm/fp8fp16mix_fp8ss.yaml index 565b85775c..6a98a8c07d 100644 --- a/tensilelite/Tensile/Tests/common/gemm/fp8fp16mix_fp8ss.yaml +++ b/tensilelite/Tensile/Tests/common/gemm/fp8fp16mix_fp8ss.yaml @@ -12,7 +12,7 @@ GlobalParameters: KernelTime: True MaxWorkspaceSize: 13421772800 DataInitTypeAlpha: 1 - DataInitTypeBeta: 0 + DataInitTypeBeta: 1 NumElementsToValidate: -1 BoundsCheck: 2 @@ -32,6 +32,12 @@ BenchmarkProblems: TransposeB: 0 UseBeta: True Batched: True + Activation: True + ActivationHPA: True + UseScaleAB: True + UseScaleCD: False + UseScaleAlphaVec: True + UseBias: True BiasDataTypeList: [s,b] - # BenchmarkProblemSizeGroup - Standard InitialSolutionParameters: @@ -40,6 +46,7 @@ BenchmarkProblems: ForkParameters: - MatrixInstruction: - [16,16,32, 1, 1, 1,1, 1,1] # 16x16 + - [16,16,32, 1, 1, 2,2, 1,1] - DepthU: [ 64 ] - AssertFree0ElementMultiple: [1] - PrefetchGlobalRead: [2] @@ -52,7 +59,7 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1] - ExpandPointerSwap: [1] - - TransposeLDS: [1] + - TransposeLDS: [0,1] - LdsBlockSizePerPadA: [0] - LdsBlockSizePerPadB: [0] - LdsPadA: [0] @@ -262,6 +269,12 @@ BenchmarkProblems: TransposeB: 0 UseBeta: True Batched: True + Activation: True + ActivationHPA: True + UseScaleAB: True + UseScaleCD: False + UseScaleAlphaVec: True + UseBias: True BiasDataTypeList: [s,b] - # BenchmarkProblemSizeGroup - Standard InitialSolutionParameters: @@ -270,6 +283,7 @@ BenchmarkProblems: ForkParameters: - MatrixInstruction: - [16,16,32, 1, 1, 1,1, 1,1] # 16x16 + - [16,16,32, 1, 1, 2,2, 1,1] - DepthU: [ 64 ] - AssertFree0ElementMultiple: [1] - PrefetchGlobalRead: [2] @@ -282,7 +296,7 @@ BenchmarkProblems: - ScheduleIterAlg: [3] - InnerUnroll: [1] - ExpandPointerSwap: [1] - - TransposeLDS: [1] + - TransposeLDS: [0,1] - LdsBlockSizePerPadA: [0] - LdsBlockSizePerPadB: [0] - LdsPadA: [0] From 330ab521e227d85183fc59f88592098c7659e46a Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Thu, 14 Sep 2023 22:59:59 -0500 Subject: [PATCH 3/9] Add ScaleABCD support for UserArguments --- library/include/hipblaslt-ext.hpp | 4 +++ tensilelite/Tensile/Components/Signature.py | 12 +++++++ tensilelite/Tensile/KernelWriter.py | 13 ++++++-- tensilelite/Tensile/KernelWriterAssembly.py | 33 +++++++++++++++++-- .../include/Tensile/ContractionSolution.hpp | 4 +++ .../Source/lib/source/ContractionSolution.cpp | 4 +++ 6 files changed, 65 insertions(+), 5 deletions(-) diff --git a/library/include/hipblaslt-ext.hpp b/library/include/hipblaslt-ext.hpp index 97f75c471e..a6441ff228 100644 --- a/library/include/hipblaslt-ext.hpp +++ b/library/include/hipblaslt-ext.hpp @@ -167,6 +167,10 @@ namespace hipblaslt_ext int8_t alpha[16]; //!< The alpha value. int8_t beta[16]; //!< The beta value. // Epilogue inputs + void* scaleA; //!< The scaleA input pointer. + void* scaleB; //!< The scaleA input pointer. + void* scaleC; //!< The scaleC input pointer. + void* scaleD; //!< The scaleD input pointer. void* scaleAlphaVec; //!< The scaleAlpha vector input pointer. void* bias; //!< The bias input pointer. int biasType; //!< The bias datatype. Only works if mode is set to bias related epilogues. diff --git a/tensilelite/Tensile/Components/Signature.py b/tensilelite/Tensile/Components/Signature.py index 9d0df42a48..f8c9035ff6 100644 --- a/tensilelite/Tensile/Components/Signature.py +++ b/tensilelite/Tensile/Components/Signature.py @@ -39,6 +39,10 @@ class UserArgumentsInfo: alphaMaxRegisterSize: int = field(init=False) betaMaxSize: int = 16 betaMaxRegisterSize: int = field(init=False) + scaleASize: int = 0 + scaleBSize: int = 0 + scaleCSize: int = 0 + scaleDSize: int = 0 actMaxSize: int = 4 actMaxRegisterSize: int = field(init=False) # gemm related @@ -193,9 +197,13 @@ def __call__(self, writer) -> SignatureBase: signature.addArg("AddressScaleA", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") if (kernel["GlobalSplitU"] == 1) or (kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters()): signature.addArg("AddressScaleB", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") + userArgumentsInfo.scaleASize += 8 + userArgumentsInfo.scaleBSize += 8 if kernel["ProblemType"]["UseScaleCD"] and (kernel["GlobalSplitU"] == 1): signature.addArg("AddressScaleC", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") signature.addArg("AddressScaleD", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") + userArgumentsInfo.scaleCSize += 8 + userArgumentsInfo.scaleDSize += 8 if kernel["ProblemType"]["UseScaleAlphaVec"] and (kernel["GlobalSplitU"] == 1): signature.addArg("AddressScaleAlphaVec", SVK.SIG_GLOBALBUFFER, cptValueType, "generic") @@ -231,6 +239,10 @@ def __call__(self, writer) -> SignatureBase: # Calculate total size userArgumentsInfo.totalSize = userArgumentsInfo.gemmArgumentSize + \ + userArgumentsInfo.scaleASize + \ + userArgumentsInfo.scaleBSize + \ + userArgumentsInfo.scaleCSize + \ + userArgumentsInfo.scaleDSize + \ userArgumentsInfo.scaleAlphaVecSize + \ userArgumentsInfo.biasSize + \ userArgumentsInfo.eSize + \ diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py index 6fea2a88c2..0c753fbb28 100644 --- a/tensilelite/Tensile/KernelWriter.py +++ b/tensilelite/Tensile/KernelWriter.py @@ -258,8 +258,10 @@ class StateValues: lraTileProperties: Dict[int, LraTileProperties] = field(init=False) # Epilogue states - useBias = DataDirection.NONE - needBiasType = False + preloadScaleA = False + preloadScaleB = False + useBias = DataDirection.NONE + needBiasType = False def __post_init__(self): """ How many SGPRs does it take to have one bit per lane? """ @@ -3562,6 +3564,13 @@ def readWriteVectors(mat, vw, kernel): assert not self.db["CheckValueC"] or canCheckValueC # Epilogue related + self.states.preloadScaleA = False + if kernel["ProblemType"]["UseScaleAB"]: + if kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + self.states.preloadScaleA = True + if kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + self.states.preloadScaleB = True + self.states.useBias = DataDirection.NONE self.states.needBiasType = False if kernel["ProblemType"]["UseBias"]: diff --git a/tensilelite/Tensile/KernelWriterAssembly.py b/tensilelite/Tensile/KernelWriterAssembly.py index 8c8bb8e1cf..9fb3120bc3 100644 --- a/tensilelite/Tensile/KernelWriterAssembly.py +++ b/tensilelite/Tensile/KernelWriterAssembly.py @@ -449,8 +449,8 @@ def defineVariableSgprs(self, kernel): self.states.numStoreSgprNameSizes = [] storeSgprLoad = 0 if kernel["ProblemType"]["UseScaleAB"] and (kernel["GlobalSplitU"] == 1): - self.states.numSgprAddressScaleA = self.states.rpga if kernel["ProblemType"]["DataTypeA"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters() else 0 - self.states.numSgprAddressScaleB = self.states.rpga if kernel["ProblemType"]["DataTypeB"].numRegisters() <= kernel["ProblemType"]["DataType"].numRegisters() else 0 + self.states.numSgprAddressScaleA = self.states.rpga if (not self.states.preloadScaleA) else 0 + self.states.numSgprAddressScaleB = self.states.rpga if (not self.states.preloadScaleB) else 0 storeSgprLoad += self.states.numSgprAddressScaleA + self.states.numSgprAddressScaleB if self.states.numSgprAddressScaleA: self.states.numStoreSgprNames.append("AddressScaleA") @@ -461,7 +461,7 @@ def defineVariableSgprs(self, kernel): if kernel["ProblemType"]["UseScaleCD"] and (kernel["GlobalSplitU"] == 1): self.states.numSgprAddressScaleC = self.states.rpga self.states.numSgprAddressScaleD = self.states.rpga - storeSgprLoad += self.states.numSgprAddressScaleA + self.states.numSgprAddressScaleD + storeSgprLoad += self.states.numSgprAddressScaleC + self.states.numSgprAddressScaleD if self.states.numSgprAddressScaleC: self.states.numStoreSgprNames.append("AddressScaleC") self.states.numStoreSgprNameSizes.append(self.states.numSgprAddressScaleC) @@ -1686,6 +1686,12 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): moduleExternalArgs.addComment("Read Beta") moduleExternalArgs.addModuleAsFlatItems(self.externalArgLoader.loadAllKernArg(self.sgprs["Beta"], "ExternalArgAddress", self.states.numSgprBeta)) offset = self.externalArgLoader.getOffset() + self.states.bpr * (self.states.userArgsInfo.betaMaxRegisterSize - self.states.numSgprBeta) + if kernel["ProblemType"]["UseScaleAB"]: + sgprOffset = self.externalArgLoader.getOffset() + for name in ['A','B']: + if kernel["ProblemType"]["DataType%s"%name].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + moduleExternalArgs.add(self.externalArgLoader.loadKernArg("AddressScale%s"%name, "ExternalArgAddress", sgprOffset=hex(sgprOffset), dword=2)) + sgprOffset += self.states.userArgsInfo.scaleASize self.externalArgLoader.setOffset(offset) module.add(moduleExternalArgs) module.add(extLabelEnd) @@ -4319,6 +4325,27 @@ def endSummation(self, kernel, tPA, tPB, label = None): extArgOffset = self.externalArgLoader.getOffset() backupExtArgOffset = extArgOffset loadList = [[-1, 0, extArgOffset]] + extArgOffset += self.states.userArgsInfo.scaleASize + if kernel["ProblemType"]["UseScaleAB"] and (not self.states.preloadScaleA) and (kernel["GlobalSplitU"] == 1): + if loadList[-1][0] == -1: + loadList[-1][0] = self.sgprs["AddressScaleA"] + loadList[-1][1] += self.states.userArgsInfo.scaleASize + else: + loadList.append([-1, 0, extArgOffset]) # Need to start a new loadAllKernArg cause the argument is not consecutively anymore. + extArgOffset += self.states.userArgsInfo.scaleBSize + if kernel["ProblemType"]["UseScaleAB"] and (not self.states.preloadScaleB) and (kernel["GlobalSplitU"] == 1): + if loadList[-1][0] == -1: + loadList[-1][0] = self.sgprs["AddressScaleB"] + loadList[-1][1] += self.states.userArgsInfo.scaleBSize + else: + loadList.append([-1, 0, extArgOffset]) # Need to start a new loadAllKernArg cause the argument is not consecutively anymore. + extArgOffset += self.states.userArgsInfo.scaleCSize + self.states.userArgsInfo.scaleDSize + if kernel["ProblemType"]["UseScaleCD"] and (kernel["GlobalSplitU"] == 1): + if loadList[-1][0] == -1: + loadList[-1][0] = self.sgprs["AddressScaleC"] + loadList[-1][1] += self.states.userArgsInfo.scaleCSize + self.states.userArgsInfo.scaleDSize + else: + loadList.append([-1, 0, extArgOffset]) # Need to start a new loadAllKernArg cause the argument is not consecutively anymore. extArgOffset += self.states.userArgsInfo.scaleAlphaVecSize if kernel["ProblemType"]["UseScaleAlphaVec"] and (kernel["GlobalSplitU"] == 1): if loadList[-1][0] == -1: diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp index 1a09807ac6..058693355f 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp @@ -63,6 +63,10 @@ namespace Tensile uint32_t strideB2; int8_t alpha[16]; int8_t beta[16]; + void* scaleA; + void* scaleB; + void* scaleC; + void* scaleD; void* scaleAlphaVec; void* bias; int biasType; diff --git a/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp b/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp index c6add00104..4aee0094bd 100644 --- a/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp +++ b/tensilelite/Tensile/Source/lib/source/ContractionSolution.cpp @@ -225,6 +225,10 @@ namespace Tensile inputs.grouped[i].alpha, arg.alpha, sizeof(arg.alpha), problems[i].alphaType()); setVariantToBuffer( inputs.grouped[i].beta, arg.beta, sizeof(arg.beta), problems[i].betaType()); + arg.scaleA = const_cast(inputs.grouped[i].scaleA); + arg.scaleB = const_cast(inputs.grouped[i].scaleB); + arg.scaleC = const_cast(inputs.grouped[i].scaleC); + arg.scaleD = const_cast(inputs.grouped[i].scaleD); arg.bias = const_cast(inputs.grouped[i].bias); arg.scaleAlphaVec = const_cast(inputs.grouped[i].scaleAlphaVec); arg.e = const_cast(inputs.grouped[i].e); From c60ed8e785edbff17013c07421e4a7b7ed84aecd Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Thu, 14 Sep 2023 23:58:51 -0500 Subject: [PATCH 4/9] Fix scaleAB not working in GG --- tensilelite/Tensile/KernelWriterAssembly.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tensilelite/Tensile/KernelWriterAssembly.py b/tensilelite/Tensile/KernelWriterAssembly.py index 9fb3120bc3..7f4a5bb962 100644 --- a/tensilelite/Tensile/KernelWriterAssembly.py +++ b/tensilelite/Tensile/KernelWriterAssembly.py @@ -1266,6 +1266,12 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): module.add(moduleRegInit) module.add(moduleWg) else: + numStoreSgprToLoad = self.states.numStoreSgprToLoad + if kernel["ProblemType"]["UseScaleAB"]: + if self.states.preloadScaleA: + numStoreSgprToLoad += 2 + if self.states.preloadScaleB: + numStoreSgprToLoad += 2 ###### GroupedGemm ############ ###### # linear search @@ -1330,7 +1336,7 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): if kernel["ProblemType"]["SupportUserArgs"]: module.addComment0("Check if custom structure pointer is null") module.add(SBranchIfNotZero("ExternalArgAddress", DataType('int64'), extValidLabel)) - module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (self.states.numStoreSgprToLoad * 4)))) + module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (numStoreSgprToLoad * 4)))) module.add(SMulI32(dst=sgpr(tmpSgprAddrM), src0=sgpr(tmpSgprNumGemm), src1=4)) # offset wgTable module.add(SMovB64(dst=sgpr(tmpSgprArgAddress0,2), src=sgpr("KernArgAddress",2))) module.add(SBranch(extValidLabelEnd.getLabelName())) @@ -1340,7 +1346,7 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): module.add(SMovB64(dst=sgpr(tmpSgprArgAddress0,2), src=sgpr("ExternalArgAddress",2))) module.add(extValidLabelEnd) else: - module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (self.states.numStoreSgprToLoad * 4)))) + module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (numStoreSgprToLoad * 4)))) module.add(SMulI32(dst=sgpr(tmpSgprAddrM), src0=sgpr(tmpSgprNumGemm), src1=4)) # offset wgTable module.add(SMovB64(dst=sgpr(tmpSgprArgAddress0,2), src=sgpr("KernArgAddress",2))) @@ -1486,7 +1492,7 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): if kernel["ProblemType"]["SupportUserArgs"]: module.addComment0("Check if custom structure pointer is null") module.add(SBranchIfNotZero("ExternalArgAddress", DataType('int64'), extValidLabel)) - module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (self.states.numStoreSgprToLoad * 4)))) + module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (numStoreSgprToLoad * 4)))) module.add(SMulI32(dst=sgpr(tmpSgprAddrM), src0=sgpr(tmpSgprNumGemm), src1=4)) # offset wgTable module.add(SMovB64(dst=sgpr(tmpSgprArgAddress0,2), src=sgpr("KernArgAddress",2))) module.add(SBranch(extValidLabelEnd.getLabelName())) @@ -1496,7 +1502,7 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): module.add(SMovB64(dst=sgpr(tmpSgprArgAddress0,2), src=sgpr("ExternalArgAddress",2))) module.add(extValidLabelEnd) else: - module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (self.states.numStoreSgprToLoad * 4)))) + module.add(SMovB32(dst=sgpr(tmpSgprArgOffsett), src=(self.argLoader.getOffset() + (numStoreSgprToLoad * 4)))) module.add(SMulI32(dst=sgpr(tmpSgprAddrM), src0=sgpr(tmpSgprNumGemm), src1=4)) # offset wgTable module.add(SMovB64(dst=sgpr(tmpSgprArgAddress0,2), src=sgpr("KernArgAddress",2))) @@ -1665,7 +1671,7 @@ def defineAndResources(self, kernel, tPA, tPB, lralwaCode): module.add(SAddCU32(dst=sgpr("KernArgAddress+1"), src0=sgpr("KernArgAddress+1"), src1=hex(0))) module.addComment0("Grouped Gemm: offset address from args_start to gemm_start") module.add(SMulI32(dst=sgpr(tmpSgprGemmIdxLeft), src0=sgpr(tmpSgprGemmIdxLeft),\ - src1=(self.argLoader.getOffset() + (self.states.numStoreSgprToLoad * 4)))) + src1=(self.argLoader.getOffset() + (numStoreSgprToLoad * 4)))) module.add(SAddU32(dst=sgpr("KernArgAddress"), src0=sgpr("KernArgAddress"), src1=sgpr(tmpSgprGemmIdxLeft))) module.add(SAddCU32(dst=sgpr("KernArgAddress+1"), src0=sgpr("KernArgAddress+1"), src1=hex(0))) module.add(moduleArgs) From d80329561838082a338f4c916d7e0853e607ee79 Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Fri, 15 Sep 2023 01:02:10 -0500 Subject: [PATCH 5/9] Set default value for bias_type if it is not initialized --- clients/include/testing_matmul.hpp | 2 +- library/include/hipblaslt-ext.hpp | 4 +- .../amd_detail/rocblaslt/src/include/handle.h | 2 +- .../rocblaslt/src/include/tensile_host.hpp | 54 ++++++++++++++++++- 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/clients/include/testing_matmul.hpp b/clients/include/testing_matmul.hpp index 031fc707ab..b98413bdda 100644 --- a/clients/include/testing_matmul.hpp +++ b/clients/include/testing_matmul.hpp @@ -853,7 +853,7 @@ void testing_matmul(const Arguments& arg) for(int gemmIdx = 0; gemmIdx < gemm_count; gemmIdx++) { - auto bias_type = static_cast(0); + auto bias_type = HIPBLASLT_DATATYPE_INVALID; void* bias_addr = nullptr; if(arg.bias_vector) { diff --git a/library/include/hipblaslt-ext.hpp b/library/include/hipblaslt-ext.hpp index a6441ff228..f393dc10e0 100644 --- a/library/include/hipblaslt-ext.hpp +++ b/library/include/hipblaslt-ext.hpp @@ -107,8 +107,8 @@ namespace hipblaslt_ext { hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT; //!< The mode of epilogue. Default is gemm. - hipblasltDatatype_t bias_data_type = static_cast( - 0); //!< The bias datatype. Only works if mode is set to bias related epilogues. + hipblasltDatatype_t bias_data_type + = HIPBLASLT_DATATYPE_INVALID; //!< The bias datatype. Only works if mode is set to bias related epilogues. int aux_ld = 0; //!< The aux leading dimension. Only works if mode is set to aux related epilogues. int aux_stride diff --git a/library/src/amd_detail/rocblaslt/src/include/handle.h b/library/src/amd_detail/rocblaslt/src/include/handle.h index 30869ffe6d..c3062b36fc 100644 --- a/library/src/amd_detail/rocblaslt/src/include/handle.h +++ b/library/src/amd_detail/rocblaslt/src/include/handle.h @@ -150,7 +150,7 @@ struct _rocblaslt_matmul_desc void* scaleD = nullptr; void* scaleE = nullptr; void* pointermode = nullptr; - hipblasltDatatype_t bias_type = static_cast(0); + hipblasltDatatype_t bias_type = HIPBLASLT_DATATYPE_INVALID; // E void* e = nullptr; int64_t lde = 0; diff --git a/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp b/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp index fff335114d..ca2eec7a3f 100644 --- a/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp +++ b/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp @@ -42,8 +42,8 @@ #include "handle.h" //#include "tuple_helper.hpp" #include "utility.hpp" -#include #include +#include // Return the value category for a value, as a double precision value, such // such as whether it's 0, 1, -1 or some other value. Tensile uses a double @@ -55,6 +55,33 @@ constexpr double value_category(const T& beta) return beta == T(0) ? 0.0 : beta == T(1) ? 1.0 : beta == T(-1) ? -1.0 : 2.0; } +template +inline constexpr auto hipblaslt_datatype = nullptr; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_16F; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_32F; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_64F; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_16B; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_8F_E4M3; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_8F_E5M2; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_8I; + +template <> +inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_32I; + /******************************************************************** * RocblasltContractionProblem captures the arguments for a GEMM-like * * contraction problem, to be passed to runContractionProblem. * @@ -224,6 +251,31 @@ struct RocblasltContractionProblem , workspaceSize(workspaceSize) , stream(stream) { + // Tensile DataTypes corresponding to rocblaslt data types + static constexpr hipblasltDatatype_t dataType_TiA = hipblaslt_datatype; + static constexpr hipblasltDatatype_t dataType_TiB = hipblaslt_datatype; + static constexpr hipblasltDatatype_t dataType_To = hipblaslt_datatype; + static constexpr hipblasltDatatype_t dataType_Tc = hipblaslt_datatype; + if(this->bias_type == HIPBLASLT_DATATYPE_INVALID) + { + if((dataType_TiA == HIPBLASLT_R_8F_E4M3 && dataType_TiB == HIPBLASLT_R_16F) + || (dataType_TiA == HIPBLASLT_R_16F && dataType_TiB == HIPBLASLT_R_8F_E4M3)) + { + this->bias_type = HIPBLASLT_R_32F; + } + else if(dataType_TiA == HIPBLASLT_R_8F_E4M3 || dataType_TiA == HIPBLASLT_R_8F_E5M2) + { + this->bias_type = HIPBLASLT_R_16F; + } + else if(dataType_Tc == HIPBLASLT_R_32I) + { + this->bias_type = HIPBLASLT_R_32F; + } + else + { + this->bias_type = dataType_To; + } + } } }; From 93b73610ce733eb9e662762fda9fa3e33807c6af Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Thu, 14 Sep 2023 03:24:09 -0500 Subject: [PATCH 6/9] Add typename compute input type for hipblaslt testing --- clients/benchmarks/client.cpp | 12 +- clients/common/cblas_interface.cpp | 1810 +++++++-------------------- clients/gtest/matmul_gtest.cpp | 12 +- clients/include/cblas_interface.hpp | 23 +- clients/include/testing_matmul.hpp | 70 +- clients/include/type_dispatch.hpp | 12 +- 6 files changed, 543 insertions(+), 1396 deletions(-) diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 9ecdcfae4e..d8688958b1 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -75,17 +75,23 @@ void run_function(const func_map& map, const Arguments& arg, const std::string& // Template to dispatch testing_matmul for performance tests // the test is marked invalid when (TiA, TiB, To, Tc) not in (H/H/S, B/B/S) -template +template struct perf_matmul : hipblaslt_test_invalid { }; -template +template struct perf_matmul< TiA, TiB, To, Tc, + Tci, std::enable_if_t<(std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}) @@ -100,7 +106,7 @@ struct perf_matmul< { void operator()(const Arguments& arg) { - static const func_map map = {{"matmul", testing_matmul}}; + static const func_map map = {{"matmul", testing_matmul}}; run_function(map, arg); } }; diff --git a/clients/common/cblas_interface.cpp b/clients/common/cblas_interface.cpp index 370c31da09..c19744ac16 100644 --- a/clients/common/cblas_interface.cpp +++ b/clients/common/cblas_interface.cpp @@ -58,6 +58,9 @@ void cblas_gemm(hipblasOperatio float beta, hip_bfloat16* C, int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -70,13 +73,27 @@ void cblas_gemm(hipblasOperatio host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -120,6 +137,9 @@ void cblas_gemm(hipblasOperation_t tr float beta, float* C, int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -130,13 +150,27 @@ void cblas_gemm(hipblasOperation_t tr size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); + host_vector A_float(sizeA), B_float(sizeB); - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -175,6 +209,9 @@ void cblas_gemm(hipblasOperation_t tr float beta, float* C, int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -182,13 +219,27 @@ void cblas_gemm(hipblasOperation_t tr size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_float(sizeA), B_float(sizeB); + host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -205,6 +256,7 @@ void cblas_gemm(hipblasOperation_t tr beta, C, ldc); + if(scaleD != 1) { for(size_t i = 0; i < sizeC; i++) @@ -226,6 +278,9 @@ void cblas_gemm(hipblasOperation_t float beta, float* C, int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -233,13 +288,27 @@ void cblas_gemm(hipblasOperation_t size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_float(sizeA), B_float(sizeB); + host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -256,6 +325,7 @@ void cblas_gemm(hipblasOperation_t beta, C, ldc); + if(scaleD != 1) { for(size_t i = 0; i < sizeC; i++) @@ -277,6 +347,9 @@ void cblas_gemm(hipblasOperation_t float beta, float* C, int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -284,13 +357,27 @@ void cblas_gemm(hipblasOperation_t size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_float(sizeA), B_float(sizeB); + host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -314,7 +401,6 @@ void cblas_gemm(hipblasOperation_t C[i] = static_cast(C[i] * scaleD); } } - template <> void cblas_gemm(hipblasOperation_t transA, hipblasOperation_t transB, @@ -329,6 +415,9 @@ void cblas_gemm(hipblasOperati float beta, hipblasLtHalf* C, int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -338,74 +427,29 @@ void cblas_gemm(hipblasOperati host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) + if(AlphaVec != nullptr) { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } } else { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_bf8* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - float scaleD, - bool alt) -{ - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, HIPOperationToCBLASTanspose(transA), HIPOperationToCBLASTanspose(transB), @@ -429,1096 +473,29 @@ void cblas_gemm(hipblasOperat else { for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); + C[i] = hipblasLtHalf(C_float[i]); } } template <> -void cblas_gemm(hipblasOperation_t transA, +void cblas_gemm(hipblasOperation_t transA, hipblasOperation_t transB, int64_t m, int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblaslt_bf8* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - float scaleD, - bool alt) -{ - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - hipblaslt_f8* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - hipblaslt_f8* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB); - - for(size_t i = 0; i < sizeA; i++) - A_float[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hipblasLtHalf, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - if(alt) - { - for(size_t i = 0; i < sizeA; i++) - A_float[i] = float_to_bfloat16_truncate(float(A[i])); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = float_to_bfloat16_truncate(float(B[i])); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = float_to_bfloat16_truncate(float(C[i])); - } - else - { - for(size_t i = 0; i < sizeA; i++) - A_float[i] = A[i]; - for(size_t i = 0; i < sizeB; i++) - B_float[i] = B[i]; - for(size_t i = 0; i < sizeC; i++) - C_float[i] = C[i]; - } - - // just directly cast, since transA, transB are integers in the enum - //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = hipblasLtHalf(C_float[i]); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - float scaleD, - bool alt) -{ - // cblas does not support hipblasLtHalf, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB); - - if(alt) - { - for(size_t i = 0; i < sizeA; i++) - A_float[i] = float_to_bfloat16_truncate(float(A[i])); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = float_to_bfloat16_truncate(float(B[i])); - } - else - { - for(size_t i = 0; i < sizeA; i++) - A_float[i] = A[i]; - for(size_t i = 0; i < sizeB; i++) - B_float[i] = B[i]; - } - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const float* A, - int64_t lda, - const float* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - float scaleD, - bool alt) -{ - size_t sizeC = n * size_t(ldc); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - double alpha, - const double* A, - int64_t lda, - const double* B, - int64_t ldb, - double beta, - double* C, - int64_t ldc, - double scaleD, - bool alt) -{ - size_t sizeC = n * size_t(ldc); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_dgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - int32_t alpha, - const int8_t* A, - int64_t lda, - const int8_t* B, - int64_t ldb, - int32_t beta, - int32_t* C, - int64_t ldc, - int32_t scaleD, - bool alt) -{ - - // cblas does not support int8_t input / int32_t output, however non-overflowing - // 32-bit integer operations can be represented accurately with double-precision - // floats, so convert to doubles and downcast result down to int32_t. - // NOTE: This will not properly account for 32-bit integer overflow, however - // the result should be acceptable for testing. - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_double(sizeA); - host_vector B_double(sizeB); - host_vector C_double(sizeC); - - for(size_t i = 0; i < sizeA; i++) - A_double[i] = static_cast(A[i]); - for(size_t i = 0; i < sizeB; i++) - B_double[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_double[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_dgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_double, - lda, - B_double, - ldb, - beta, - C_double, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_double[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_double[i]); - } -} - -// AlphaVec gemm -template <> -void cblas_gemm_alphascale( - hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hip_bfloat16* A, - int64_t lda, - const hip_bfloat16* B, - int64_t ldb, - float beta, - hip_bfloat16* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; - } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i]); - } -} - -template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hip_bfloat16* A, - int64_t lda, - const hip_bfloat16* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) -{ - // cblas does not support hip_bfloat16, so convert to higher precision float - // This will give more precise result which is acceptable for testing - - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB); - - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; - } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) -{ - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; - } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_bf8* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) -{ - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; - } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} - -template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblaslt_bf8* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) -{ - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; - } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - - // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); - } -} -template <> -void cblas_gemm_alphascale( - hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) -{ - size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); - size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); - size_t sizeC = n * size_t(ldc); - - host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; - } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = static_cast(C[i]); - - // just directly cast, since transA, transB are integers in the enum - //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_sgemm(CblasColMajor, - HIPOperationToCBLASTanspose(transA), - HIPOperationToCBLASTanspose(transB), - m, - n, - k, - alpha, - A_float, - lda, - B_float, - ldb, - beta, - C_float, - ldc); - - if(scaleD != 1) - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = hipblasLtHalf(C_float[i]); - } -} - -template <> -void cblas_gemm_alphascale( - hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_bf8* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) + int64_t k, + float alpha, + const hipblaslt_bf8* A, + int64_t lda, + const hipblaslt_f8* B, + int64_t ldb, + float beta, + hipblasLtHalf* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); @@ -1526,16 +503,27 @@ void cblas_gemm_alphascale( host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -1566,23 +554,24 @@ void cblas_gemm_alphascale( } template <> -void cblas_gemm_alphascale( - hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblaslt_bf8* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblaslt_f8* A, + int64_t lda, + const hipblaslt_bf8* B, + int64_t ldb, + float beta, + hipblasLtHalf* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); @@ -1590,16 +579,27 @@ void cblas_gemm_alphascale( host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -1630,7 +630,7 @@ void cblas_gemm_alphascale( } template <> -void cblas_gemm_alphascale( +void cblas_gemm( hipblasOperation_t transA, hipblasOperation_t transB, int64_t m, @@ -1645,6 +645,8 @@ void cblas_gemm_alphascale( hipblaslt_f8* C, int64_t ldc, const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -1657,16 +659,27 @@ void cblas_gemm_alphascale( host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -1697,7 +710,7 @@ void cblas_gemm_alphascale( } template <> -void cblas_gemm_alphascale( +void cblas_gemm( hipblasOperation_t transA, hipblasOperation_t transB, int64_t m, @@ -1712,6 +725,8 @@ void cblas_gemm_alphascale( hipblaslt_f8* C, int64_t ldc, const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -1724,16 +739,27 @@ void cblas_gemm_alphascale( host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -1764,7 +790,7 @@ void cblas_gemm_alphascale( } template <> -void cblas_gemm_alphascale( +void cblas_gemm( hipblasOperation_t transA, hipblasOperation_t transB, int64_t m, @@ -1779,6 +805,8 @@ void cblas_gemm_alphascale( hipblasLtHalf* C, int64_t ldc, const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -1791,16 +819,27 @@ void cblas_gemm_alphascale( host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -1831,7 +870,7 @@ void cblas_gemm_alphascale( } template <> -void cblas_gemm_alphascale( +void cblas_gemm( hipblasOperation_t transA, hipblasOperation_t transB, int64_t m, @@ -1846,6 +885,8 @@ void cblas_gemm_alphascale( hipblasLtHalf* C, int64_t ldc, const float* AlphaVec, + float scaleA, + float scaleB, float scaleD, bool alt) { @@ -1858,16 +899,27 @@ void cblas_gemm_alphascale( host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_float[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -1898,22 +950,24 @@ void cblas_gemm_alphascale( } template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblaslt_f8* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblaslt_f8* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { // cblas does not support hipblasLtHalf, so convert to higher precision float // This will give more precise result which is acceptable for testing @@ -1924,14 +978,25 @@ void cblas_gemm_alphascale(hipblasOpe host_vector A_float(sizeA), B_float(sizeB); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -1957,22 +1022,24 @@ void cblas_gemm_alphascale(hipblasOpe } template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblaslt_f8* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblaslt_f8* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { // cblas does not support hipblasLtHalf, so convert to higher precision float // This will give more precise result which is acceptable for testing @@ -1983,14 +1050,25 @@ void cblas_gemm_alphascale(hipblasOpe host_vector A_float(sizeA), B_float(sizeB); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = static_cast(A[i]); - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = static_cast(B[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -2016,23 +1094,24 @@ void cblas_gemm_alphascale(hipblasOpe } template <> -void cblas_gemm_alphascale( - hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + hipblasLtHalf* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { // cblas does not support hipblasLtHalf, so convert to higher precision float // This will give more precise result which is acceptable for testing @@ -2045,8 +1124,16 @@ void cblas_gemm_alphascale( if(alt) { - for(size_t i = 0; i < sizeA; i++) - A_float[i] = float_to_bfloat16_truncate(float(A[i])); + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + A_float[i] = float_to_bfloat16_truncate(float(A[i])) * AlphaVec[i % m]; + } + else + { + for(size_t i = 0; i < sizeA; i++) + A_float[i] = float_to_bfloat16_truncate(float(A[i])); + } for(size_t i = 0; i < sizeB; i++) B_float[i] = float_to_bfloat16_truncate(float(B[i])); for(size_t i = 0; i < sizeC; i++) @@ -2054,10 +1141,19 @@ void cblas_gemm_alphascale( } else { - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = A[i]; - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = B[i]; @@ -2065,6 +1161,8 @@ void cblas_gemm_alphascale( C_float[i] = C[i]; } + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -2095,22 +1193,24 @@ void cblas_gemm_alphascale( } template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { // cblas does not support hipblasLtHalf, so convert to higher precision float // This will give more precise result which is acceptable for testing @@ -2130,15 +1230,26 @@ void cblas_gemm_alphascale(hipblasOp } else { - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = A[i]; - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_float[i] = B[i]; } + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -2164,22 +1275,24 @@ void cblas_gemm_alphascale(hipblasOp } template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const float* A, - int64_t lda, - const float* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const float* A, + int64_t lda, + const float* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); @@ -2187,12 +1300,23 @@ void cblas_gemm_alphascale(hipblasOperation_t transA host_vector A_float(sizeA); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_float[i] = A[i]; - A_float[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } } + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, @@ -2218,22 +1342,24 @@ void cblas_gemm_alphascale(hipblasOperation_t transA } template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - double alpha, - const double* A, - int64_t lda, - const double* B, - int64_t ldb, - double beta, - double* C, - int64_t ldc, - const double* AlphaVec, - double scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + double alpha, + const double* A, + int64_t lda, + const double* B, + int64_t ldb, + double beta, + double* C, + int64_t ldc, + const double* AlphaVec, + double scaleA, + double scaleB, + double scaleD, + bool alt) { size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); @@ -2241,12 +1367,23 @@ void cblas_gemm_alphascale(hipblasOperation_t tr host_vector A_double(sizeA); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else { - A_double[i] = A[i]; - A_double[i] *= AlphaVec[i % m]; + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]); + } } + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_dgemm(CblasColMajor, @@ -2272,22 +1409,24 @@ void cblas_gemm_alphascale(hipblasOperation_t tr } template <> -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - int32_t alpha, - const int8_t* A, - int64_t lda, - const int8_t* B, - int64_t ldb, - int32_t beta, - int32_t* C, - int64_t ldc, - const int32_t* AlphaVec, //cm review - int32_t scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + int32_t alpha, + const int8_t* A, + int64_t lda, + const int8_t* B, + int64_t ldb, + int32_t beta, + int32_t* C, + int64_t ldc, + const int32_t* AlphaVec, //cm review + int32_t scaleA, + int32_t scaleB, + int32_t scaleD, + bool alt) { size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); @@ -2297,16 +1436,27 @@ void cblas_gemm_alphascale(hipblasOperation_t host_vector B_double(sizeB); host_vector C_double(sizeC); - for(size_t i = 0; i < sizeA; i++) + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]) * static_cast(AlphaVec[i % m]); + } + } + else { - A_double[i] = static_cast(A[i]); - A_double[i] *= static_cast(AlphaVec[i % m]); + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]); + } } for(size_t i = 0; i < sizeB; i++) B_double[i] = static_cast(B[i]); for(size_t i = 0; i < sizeC; i++) C_double[i] = static_cast(C[i]); + alpha *= scaleA * scaleB; + // just directly cast, since transA, transB are integers in the enum // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_dgemm(CblasColMajor, diff --git a/clients/gtest/matmul_gtest.cpp b/clients/gtest/matmul_gtest.cpp index 67c6f026cb..e331b188fe 100644 --- a/clients/gtest/matmul_gtest.cpp +++ b/clients/gtest/matmul_gtest.cpp @@ -46,6 +46,7 @@ namespace typename TiB = TiA, typename To = TiB, typename Tc = To, + typename Tci = TiA, typename = void> struct matmul_testing : hipblaslt_test_invalid { @@ -53,12 +54,13 @@ namespace // When Ti = To = Tc != void, this test applies. // When converted to bool, this functor returns true. - template + template struct matmul_testing< TiA, TiB, To, Tc, + Tci, std::enable_if_t< (std::is_same{} && std::is_same{}) || (std::is_same{} && std::is_same{}) @@ -75,9 +77,9 @@ namespace void operator()(const Arguments& arg) { if(!strcmp(arg.function, "matmul")) - testing_matmul(arg); + testing_matmul(arg); else if(!strcmp(arg.function, "matmul_bad_arg")) - testing_matmul_bad_arg(arg); + testing_matmul_bad_arg(arg); else FAIL() << "Internal error: Test called with unknown function: " << arg.function; } @@ -162,11 +164,11 @@ namespace if(arg.scaleC) name << "_SC"; - + if(arg.scaleD) name << "_SD"; - if (arg.scaleE) + if(arg.scaleE) name << "_SAux"; if(arg.scaleAlpha_vector) diff --git a/clients/include/cblas_interface.hpp b/clients/include/cblas_interface.hpp index 5967967dbb..49cb9befcf 100644 --- a/clients/include/cblas_interface.hpp +++ b/clients/include/cblas_interface.hpp @@ -36,7 +36,7 @@ */ // gemm -template +template void cblas_gemm(hipblasOperation_t transA, hipblasOperation_t transB, int64_t m, @@ -50,23 +50,8 @@ void cblas_gemm(hipblasOperation_t transA, Tc beta, std::add_pointer_t C, int64_t ldc, + const Tc* AlphaVec, + Tc scaleA, + Tc scaleB, Tc scaleD, bool alt = false); - -template -void cblas_gemm_alphascale(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - Tc alpha, - const TiA* A, - int64_t lda, - const TiB* B, - int64_t ldb, - Tc beta, - std::add_pointer_t C, - int64_t ldc, - const Tc* AlphaVec, - Tc scaleD, - bool alt = false); diff --git a/clients/include/testing_matmul.hpp b/clients/include/testing_matmul.hpp index b98413bdda..7f34bf5135 100644 --- a/clients/include/testing_matmul.hpp +++ b/clients/include/testing_matmul.hpp @@ -178,7 +178,7 @@ auto _dgelu = [](auto in, auto /*arg1*/, auto /*arg2*/) -> decltype(in) { return static_cast(0.5f * tanh(xx) + x1 * x2 + 0.5f); }; -template +template void testing_matmul_bad_arg(const Arguments& arg) { const int64_t M = 128; @@ -220,7 +220,7 @@ void testing_matmul_bad_arg(const Arguments& arg) hipStream_t stream = nullptr; } -template +template void testing_matmul(const Arguments& arg) { double gpu_time_used, cpu_time_used; @@ -1418,14 +1418,14 @@ void testing_matmul(const Arguments& arg) scaleEValue, applyBias for(int gemmIdx = 0; gemmIdx < gemm_count; gemmIdx++) { - auto alphaTemp = h_alpha[gemmIdx]; - auto betaTemp = h_beta[gemmIdx]; - if(arg.scaleA) - alphaTemp *= (*hScaleA[gemmIdx])[0]; - if(arg.scaleB) - alphaTemp *= (*hScaleB[gemmIdx])[0]; + auto alpha = h_alpha[gemmIdx]; + auto betaTemp = h_beta[gemmIdx]; if(arg.scaleC) betaTemp *= (*hScaleC[gemmIdx])[0]; + auto scaleAValue = arg.scaleA ? (*hScaleA[gemmIdx])[0] : 1; + auto scaleBValue = arg.scaleB ? (*hScaleB[gemmIdx])[0] : 1; + auto scaleDValue = arg.scaleD ? (*hScaleD[gemmIdx])[0] : 1; + auto scaleEValue = arg.scaleE ? (*hScaleE[gemmIdx])[0] : 1; for(int batchIdx = 0; batchIdx < num_batches[gemmIdx]; batchIdx++) { @@ -1433,13 +1433,13 @@ void testing_matmul(const Arguments& arg) { if(arg.scaleAlpha_vector) { - cblas_gemm_alphascale( + cblas_gemm( transA, transB, M[gemmIdx], N[gemmIdx], K[gemmIdx], - alphaTemp, + alpha, *(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx, lda[gemmIdx], *(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx, @@ -1448,18 +1448,20 @@ void testing_matmul(const Arguments& arg) *(hD_gold_epl[gemmIdx]) + stride_d[gemmIdx] * batchIdx, ldd[gemmIdx], *(hScaleAlphaVec[gemmIdx]) + 0, + scaleAValue, + scaleBValue, 1, false); } else { - cblas_gemm( + cblas_gemm( transA, transB, M[gemmIdx], N[gemmIdx], K[gemmIdx], - alphaTemp, + alpha, *(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx, lda[gemmIdx], *(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx, @@ -1467,15 +1469,16 @@ void testing_matmul(const Arguments& arg) betaTemp, *(hD_gold_epl[gemmIdx]) + stride_d[gemmIdx] * batchIdx, ldd[gemmIdx], + nullptr, + scaleAValue, + scaleBValue, 1, false); } auto pos = stride_d[gemmIdx] * batchIdx; auto hEInst = arg.gradient ? hE : hE_gold; auto ePos = (hEInst[gemmIdx] == nullptr) ? nullptr : (*(hEInst[gemmIdx]) + pos); - auto scaleDValue = arg.scaleD ? (*hScaleD[gemmIdx])[0] : 1; - auto scaleEValue = arg.scaleE ? (*hScaleE[gemmIdx])[0] : 1; - auto applyBias = arg.gradient ? false : arg.bias_vector; + auto applyBias = arg.gradient ? false : arg.bias_vector; if(change_bias_type[gemmIdx] == false) { @@ -1652,24 +1655,25 @@ void testing_matmul(const Arguments& arg) } else { - auto scaleDValue = arg.scaleD ? (*hScaleD[gemmIdx])[0] : 1; - - cblas_gemm(transA, - transB, - M[gemmIdx], - N[gemmIdx], - K[gemmIdx], - alphaTemp, - *(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx, - lda[gemmIdx], - *(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx, - ldb[gemmIdx], - betaTemp, - *(hD_gold[gemmIdx]) - + stride_d[gemmIdx] * batchIdx, - ldd[gemmIdx], - scaleDValue, - false); + cblas_gemm( + transA, + transB, + M[gemmIdx], + N[gemmIdx], + K[gemmIdx], + alpha, + *(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx, + lda[gemmIdx], + *(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx, + ldb[gemmIdx], + betaTemp, + *(hD_gold[gemmIdx]) + stride_d[gemmIdx] * batchIdx, + ldd[gemmIdx], + nullptr, + scaleAValue, + scaleBValue, + scaleDValue, + false); } } } diff --git a/clients/include/type_dispatch.hpp b/clients/include/type_dispatch.hpp index ee0d5d1aca..5317dd6975 100644 --- a/clients/include/type_dispatch.hpp +++ b/clients/include/type_dispatch.hpp @@ -152,32 +152,32 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg) else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_8F_E4M3 && Tc == HIPBLASLT_COMPUTE_F32_FAST_F16) { - return TEST{}(arg); + return TEST{}(arg); } else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_8F_E4M3 && Tc == HIPBLASLT_COMPUTE_F32_FAST_F16) { - return TEST{}(arg); + return TEST{}(arg); } else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_16F && Tc == HIPBLASLT_COMPUTE_F32_FAST_F16) { - return TEST{}(arg); + return TEST{}(arg); } else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_16F && Tc == HIPBLASLT_COMPUTE_F32_FAST_F16) { - return TEST{}(arg); + return TEST{}(arg); } else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_32F && Tc == HIPBLASLT_COMPUTE_F32_FAST_F16) { - return TEST{}(arg); + return TEST{}(arg); } else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_32F && Tc == HIPBLASLT_COMPUTE_F32_FAST_F16) { - return TEST{}(arg); + return TEST{}(arg); } /* else if(Ti == HIPBLASLT_R_8I && To == HIPBLASLT_R_8I && Tc == HIPBLASLT_COMPUTE_I32) From 576aaa80b9cdc2dd1967d8c4231342c1f4abdbf6 Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Sun, 17 Sep 2023 22:01:25 -0500 Subject: [PATCH 7/9] Add test yamls for fp8 mfma mix precision types --- ...lk_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml | 341 ++++++++++ ...lk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml | 595 ++++++++++++++++++ ...ik_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml | 341 ++++++++++ ...ik_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml | 341 ++++++++++ 4 files changed, 1618 insertions(+) create mode 100644 library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml create mode 100644 library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml create mode 100644 library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml create mode 100644 library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml diff --git a/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml new file mode 100644 index 0000000000..b41449d07a --- /dev/null +++ b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml @@ -0,0 +1,341 @@ +- {MinimumRequiredVersion: 4.33.0} +- aquavanjaram +- gfx941 +- [Device 0050, Device 0049] +- Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [0, 3, 2] + IndexAssignmentsB: [1, 3, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 1 + IndexUnrollB: 1 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: true + TLUB: true + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 0 + TransposeB: 1 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false +- - 1LDSBuffer: 0 + ActivationAlt: false + ActivationFuncCall: true + ActivationFused: true + AssertFree0ElementMultiple: 1 + AssertFree1ElementMultiple: 1 + AssertSummationElementMultiple: 1 + AssignedDerivedParameters: true + AssignedProblemIndependentDerivedParameters: true + BufferLoad: true + BufferStore: true + CUCount: null + ClusterLocalRead: 1 + CodeObjectVersion: V3 + CustomKernelName: '' + DepthU: 64 + DirectToLds: false + DirectToLdsA: false + DirectToLdsB: false + DirectToVgprSparseMetadata: false + EdgeType: ShiftPtr + EnableF32XdlMathOp: false + EnableMatrixInstruction: true + ExpandPointerSwap: 0 + GlobalReadPerMfma: 1 + GlobalReadVectorWidthA: 8 + GlobalReadVectorWidthB: 16 + GlobalSplitU: 1 + GlobalSplitUAlgorithm: MultipleBuffer + GlobalWriteVectorWidth: 1 + GroupLoadStore: false + GuaranteeNoPartialA: false + GuaranteeNoPartialB: false + GuaranteeNoPartialMetadata: true + ISA: [9, 4, 1] + InnerUnroll: 1 + InterleaveAlpha: 0 + KernelLanguage: Assembly + KernelNameMin: Cijk_Ailk_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GSU1_MIWT1_1 + LSCA: 16 + LSCB: 16 + LSPA: 32 + LSPB: 64 + LVCA: 2 + LVCB: 1 + LVPA: 4 + LVPB: 4 + LdsBlockSizePerPadA: 0 + LdsBlockSizePerPadB: 0 + LdsBlockSizePerPadMetadata: 0 + LdsInitCVgprs: false + LdsNumElements: 4096 + LdsNumElementsAlignedA: 1024 + LdsNumElementsAlignedB: 1024 + LdsNumElementsAlignedMetadata: 0 + LdsOffsetA: 0 + LdsOffsetA_Blk: 2048 + LdsOffsetB: 1024 + LdsOffsetB_Blk: 3072 + LdsOffsetBias: 0 + LdsOffsetMetadata: 1024 + LdsOffsetMetadata_Blk: 3072 + LdsPadA: 0 + LdsPadB: 0 + LdsPadMetadata: 0 + LocalReadVectorWidth: 8 + LocalSplitU: 1 + LocalWritePerMfma: -1 + LocalWriteUseSgprA: false + LocalWriteUseSgprB: false + LoopIters: 2 + LoopUnroll: 64 + MFMA_BF16_1K: false + MIArchVgpr: false + MIBlock: [16, 16, 32, 1, 1, 1] + MIInputPerThread: 8 + MIInputPerThreadA: 8 + MIInputPerThreadB: 8 + MIInputPerThreadMetadata: 8 + MIOutputVectorWidth: 4 + MIRegPerOut: 1 + MIWaveGroup: [1, 1] + MIWaveTile: [1, 1] + MIWaveTileA: 1 + MIWaveTileB: 1 + MIWaveTileMetadata: 0 + MacroTile0: 16 + MacroTile1: 16 + MacroTileA: 16 + MacroTileB: 16 + MagicDivAlg: 2 + MatrixInstB: 1 + MatrixInstBM: 1 + MatrixInstBN: 1 + MatrixInstK: 32 + MatrixInstM: 16 + MatrixInstN: 16 + MatrixInstruction: [16, 16, 32, 1] + MaxOccupancy: 40 + MaxVgprNumber: 256 + MinVgprNumber: 0 + NoLdsWriteCode: false + NoReject: false + NoTailLoop: false + NonTemporal: -1 + NonTemporalA: 0 + NonTemporalB: 0 + NonTemporalC: 0 + NonTemporalD: 0 + NonTemporalE: 0 + NonTemporalMetadata: 0 + NumElementsPerBatchStore: 0 + NumElementsPerThread: 4 + NumGlobalWriteVectorsPerThread: 4 + NumLoadsA: 2 + NumLoadsB: 1 + NumLoadsCoalescedA: 1 + NumLoadsCoalescedB: 1 + NumLoadsPerpendicularA: 2 + NumLoadsPerpendicularB: 1 + NumThreads: 64 + OptNoLoadLoop: 0 + PackedC0IdxChars: [I] + PackedC0IndicesX: [0] + PackedC1IdxChars: [J] + PackedC1IndicesX: [1] + PrefetchGlobalRead: 2 + PrefetchLocalRead: 1 + PreloadKernArgs: false + ProblemType: + Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [0, 3, 2] + IndexAssignmentsB: [1, 3, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 1 + IndexUnrollB: 1 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: true + TLUB: true + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 0 + TransposeB: 1 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false + ScheduleGlobalRead: 1 + ScheduleIterAlg: 3 + ScheduleLocalWrite: 1 + SolutionIndex: 0 + SolutionNameMin: Cijk_Ailk_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GSU1_MIWT1_1 + SourceSwap: 1 + StaggerU: 32 + StaggerUMapping: 0 + StaggerUStride: 256 + StorePriorityOpt: 0 + StoreRemapVectorWidth: 0 + StoreSyncOpt: 0 + StoreVectorWidth: 1 + SubGroup0: 4 + SubGroup1: 16 + SubGroupA: 4 + SubGroupB: 16 + SuppressNoLoadLoop: false + ThreadTile: [1, 1] + ThreadTile0: 4 + ThreadTile1: 1 + ThreadTileA: 4 + ThreadTileB: 1 + TransposeLDS: 0 + TransposeLDSMetadata: true + UnrollMajorLDSA: 0 + UnrollMajorLDSB: 0 + UnrollMajorLDSMetadata: true + Use64bShadowLimit: 1 + UseInstOffsetForGRO: 0 + UseSgprForGRO: -1 + Valid: true + VectorStore: -1 + VectorWidthA: 1 + VectorWidthB: 1 + WaveSeparateGlobalReadA: 0 + WaveSeparateGlobalReadB: 1 + WaveSeparateGlobalReadMetadata: 0 + WavefrontSize: 64 + WorkGroup: [16, 4, 1] + WorkGroupMapping: 8 + WorkGroupReduction: false + WorkspaceCheck: [0, 0] + _DepthU: 64 + _DepthUA: 64 + _DepthUB: 64 + _DepthUMetadata: 64 + _GlobalAccumulation: null + _UseSgprForGRO: false + _VectorStore: 1 + _WorkspaceSizePerElemBias: 0 + _WorkspaceSizePerElemC: 0 + _staggerStrideShift: 2 +- [2, 3, 0, 1] +- - - [127, 128, 1, 640, 127, 127, 127, 128] + - [0, 322.3] + - - [128, 128, 1, 640, 128, 128, 128, 128] + - [0, 497.899] + - - [129, 128, 1, 640, 129, 129, 129, 128] + - [0, 505.147] +- null +- null +- DeviceEfficiency +- null +- GridBased diff --git a/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml new file mode 100644 index 0000000000..68853bc207 --- /dev/null +++ b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml @@ -0,0 +1,595 @@ +- {MinimumRequiredVersion: 4.33.0} +- aquavanjaram +- gfx941 +- [Device 0050, Device 0049] +- Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [0, 3, 2] + IndexAssignmentsB: [3, 1, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 1 + IndexUnrollB: 0 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: true + TLUB: false + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 0 + TransposeB: 0 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false +- - 1LDSBuffer: 0 + ActivationAlt: false + ActivationFuncCall: true + ActivationFused: true + AssertFree0ElementMultiple: 1 + AssertFree1ElementMultiple: 1 + AssertSummationElementMultiple: 1 + AssignedDerivedParameters: true + AssignedProblemIndependentDerivedParameters: true + BufferLoad: true + BufferStore: true + CUCount: null + ClusterLocalRead: 1 + CodeObjectVersion: V3 + CustomKernelName: '' + DepthU: 64 + DirectToLds: false + DirectToLdsA: false + DirectToLdsB: false + DirectToVgprSparseMetadata: false + EdgeType: ShiftPtr + EnableF32XdlMathOp: false + EnableMatrixInstruction: true + ExpandPointerSwap: 0 + GlobalReadPerMfma: 1 + GlobalReadVectorWidthA: 8 + GlobalReadVectorWidthB: 16 + GlobalSplitU: 1 + GlobalSplitUAlgorithm: MultipleBuffer + GlobalWriteVectorWidth: 1 + GroupLoadStore: false + GuaranteeNoPartialA: false + GuaranteeNoPartialB: true + GuaranteeNoPartialMetadata: true + ISA: [9, 4, 1] + InnerUnroll: 1 + InterleaveAlpha: 0 + KernelLanguage: Assembly + KernelNameMin: Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GRVWB16_GSU1_MIWT1_1_TLDS0 + LSCA: 16 + LSCB: 64 + LSPA: 32 + LSPB: 16 + LVCA: 2 + LVCB: 4 + LVPA: 4 + LVPB: 1 + LdsBlockSizePerPadA: 0 + LdsBlockSizePerPadB: 0 + LdsBlockSizePerPadMetadata: 0 + LdsInitCVgprs: false + LdsNumElements: 4096 + LdsNumElementsAlignedA: 1024 + LdsNumElementsAlignedB: 1024 + LdsNumElementsAlignedMetadata: 0 + LdsOffsetA: 0 + LdsOffsetA_Blk: 2048 + LdsOffsetB: 1024 + LdsOffsetB_Blk: 3072 + LdsOffsetBias: 0 + LdsOffsetMetadata: 1024 + LdsOffsetMetadata_Blk: 3072 + LdsPadA: 0 + LdsPadB: 0 + LdsPadMetadata: 0 + LocalReadVectorWidth: 8 + LocalSplitU: 1 + LocalWritePerMfma: -1 + LocalWriteUseSgprA: false + LocalWriteUseSgprB: false + LoopIters: 2 + LoopUnroll: 64 + MFMA_BF16_1K: false + MIArchVgpr: false + MIBlock: [16, 16, 32, 1, 1, 1] + MIInputPerThread: 8 + MIInputPerThreadA: 8 + MIInputPerThreadB: 8 + MIInputPerThreadMetadata: 8 + MIOutputVectorWidth: 4 + MIRegPerOut: 1 + MIWaveGroup: [1, 1] + MIWaveTile: [1, 1] + MIWaveTileA: 1 + MIWaveTileB: 1 + MIWaveTileMetadata: 0 + MacroTile0: 16 + MacroTile1: 16 + MacroTileA: 16 + MacroTileB: 16 + MagicDivAlg: 2 + MatrixInstB: 1 + MatrixInstBM: 1 + MatrixInstBN: 1 + MatrixInstK: 32 + MatrixInstM: 16 + MatrixInstN: 16 + MatrixInstruction: [16, 16, 32, 1] + MaxOccupancy: 40 + MaxVgprNumber: 256 + MinVgprNumber: 0 + NoLdsWriteCode: false + NoReject: false + NoTailLoop: false + NonTemporal: -1 + NonTemporalA: 0 + NonTemporalB: 0 + NonTemporalC: 0 + NonTemporalD: 0 + NonTemporalE: 0 + NonTemporalMetadata: 0 + NumElementsPerBatchStore: 0 + NumElementsPerThread: 4 + NumGlobalWriteVectorsPerThread: 4 + NumLoadsA: 2 + NumLoadsB: 1 + NumLoadsCoalescedA: 1 + NumLoadsCoalescedB: 1 + NumLoadsPerpendicularA: 2 + NumLoadsPerpendicularB: 1 + NumThreads: 64 + OptNoLoadLoop: 0 + PackedC0IdxChars: [I] + PackedC0IndicesX: [0] + PackedC1IdxChars: [J] + PackedC1IndicesX: [1] + PrefetchGlobalRead: 2 + PrefetchLocalRead: 1 + PreloadKernArgs: false + ProblemType: + Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [0, 3, 2] + IndexAssignmentsB: [3, 1, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 1 + IndexUnrollB: 0 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: true + TLUB: false + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 0 + TransposeB: 0 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false + ScheduleGlobalRead: 1 + ScheduleIterAlg: 3 + ScheduleLocalWrite: 1 + SolutionIndex: 0 + SolutionNameMin: Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GRVWB16_GSU1_MIWT1_1_TLDS0 + SourceSwap: 1 + StaggerU: 32 + StaggerUMapping: 0 + StaggerUStride: 256 + StorePriorityOpt: 0 + StoreRemapVectorWidth: 0 + StoreSyncOpt: 0 + StoreVectorWidth: 1 + SubGroup0: 4 + SubGroup1: 16 + SubGroupA: 4 + SubGroupB: 16 + SuppressNoLoadLoop: false + ThreadTile: [1, 1] + ThreadTile0: 4 + ThreadTile1: 1 + ThreadTileA: 4 + ThreadTileB: 1 + TransposeLDS: 0 + TransposeLDSMetadata: true + UnrollMajorLDSA: 0 + UnrollMajorLDSB: 0 + UnrollMajorLDSMetadata: true + Use64bShadowLimit: 1 + UseInstOffsetForGRO: 0 + UseSgprForGRO: -1 + Valid: true + VectorStore: -1 + VectorWidthA: 1 + VectorWidthB: 1 + WaveSeparateGlobalReadA: 0 + WaveSeparateGlobalReadB: 1 + WaveSeparateGlobalReadMetadata: 0 + WavefrontSize: 64 + WorkGroup: [16, 4, 1] + WorkGroupMapping: 8 + WorkGroupReduction: false + WorkspaceCheck: [0, 0] + _DepthU: 64 + _DepthUA: 64 + _DepthUB: 64 + _DepthUMetadata: 64 + _GlobalAccumulation: null + _UseSgprForGRO: false + _VectorStore: 1 + _WorkspaceSizePerElemBias: 0 + _WorkspaceSizePerElemC: 0 + _staggerStrideShift: 2 + - 1LDSBuffer: 0 + ActivationAlt: false + ActivationFuncCall: true + ActivationFused: true + AssertFree0ElementMultiple: 1 + AssertFree1ElementMultiple: 1 + AssertSummationElementMultiple: 1 + AssignedDerivedParameters: true + AssignedProblemIndependentDerivedParameters: true + BufferLoad: true + BufferStore: true + CUCount: null + ClusterLocalRead: 1 + CodeObjectVersion: V3 + CustomKernelName: '' + DepthU: 64 + DirectToLds: false + DirectToLdsA: false + DirectToLdsB: false + DirectToVgprSparseMetadata: false + EdgeType: ShiftPtr + EnableF32XdlMathOp: false + EnableMatrixInstruction: true + ExpandPointerSwap: 0 + GlobalReadPerMfma: 1 + GlobalReadVectorWidthA: 8 + GlobalReadVectorWidthB: 8 + GlobalSplitU: 1 + GlobalSplitUAlgorithm: MultipleBuffer + GlobalWriteVectorWidth: 1 + GroupLoadStore: false + GuaranteeNoPartialA: false + GuaranteeNoPartialB: true + GuaranteeNoPartialMetadata: true + ISA: [9, 4, 1] + InnerUnroll: 1 + InterleaveAlpha: 0 + KernelLanguage: Assembly + KernelNameMin: Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GRVWB8_GSU1_MIWT1_1_TLDS1 + LSCA: 16 + LSCB: 64 + LSPA: 32 + LSPB: 8 + LVCA: 2 + LVCB: 8 + LVPA: 4 + LVPB: 1 + LdsBlockSizePerPadA: 0 + LdsBlockSizePerPadB: 0 + LdsBlockSizePerPadMetadata: 0 + LdsInitCVgprs: false + LdsNumElements: 4096 + LdsNumElementsAlignedA: 1024 + LdsNumElementsAlignedB: 1024 + LdsNumElementsAlignedMetadata: 0 + LdsOffsetA: 0 + LdsOffsetA_Blk: 2048 + LdsOffsetB: 1024 + LdsOffsetB_Blk: 3072 + LdsOffsetBias: 0 + LdsOffsetMetadata: 1024 + LdsOffsetMetadata_Blk: 3072 + LdsPadA: 0 + LdsPadB: 0 + LdsPadMetadata: 0 + LocalReadVectorWidth: 8 + LocalSplitU: 1 + LocalWritePerMfma: -1 + LocalWriteUseSgprA: false + LocalWriteUseSgprB: false + LoopIters: 2 + LoopUnroll: 64 + MFMA_BF16_1K: false + MIArchVgpr: false + MIBlock: [16, 16, 32, 1, 1, 1] + MIInputPerThread: 8 + MIInputPerThreadA: 8 + MIInputPerThreadB: 8 + MIInputPerThreadMetadata: 8 + MIOutputVectorWidth: 4 + MIRegPerOut: 1 + MIWaveGroup: [1, 1] + MIWaveTile: [1, 1] + MIWaveTileA: 1 + MIWaveTileB: 1 + MIWaveTileMetadata: 0 + MacroTile0: 16 + MacroTile1: 16 + MacroTileA: 16 + MacroTileB: 16 + MagicDivAlg: 2 + MatrixInstB: 1 + MatrixInstBM: 1 + MatrixInstBN: 1 + MatrixInstK: 32 + MatrixInstM: 16 + MatrixInstN: 16 + MatrixInstruction: [16, 16, 32, 1] + MaxOccupancy: 40 + MaxVgprNumber: 256 + MinVgprNumber: 0 + NoLdsWriteCode: false + NoReject: false + NoTailLoop: false + NonTemporal: -1 + NonTemporalA: 0 + NonTemporalB: 0 + NonTemporalC: 0 + NonTemporalD: 0 + NonTemporalE: 0 + NonTemporalMetadata: 0 + NumElementsPerBatchStore: 0 + NumElementsPerThread: 4 + NumGlobalWriteVectorsPerThread: 4 + NumLoadsA: 2 + NumLoadsB: 2 + NumLoadsCoalescedA: 1 + NumLoadsCoalescedB: 1 + NumLoadsPerpendicularA: 2 + NumLoadsPerpendicularB: 2 + NumThreads: 64 + OptNoLoadLoop: 0 + PackedC0IdxChars: [I] + PackedC0IndicesX: [0] + PackedC1IdxChars: [J] + PackedC1IndicesX: [1] + PrefetchGlobalRead: 2 + PrefetchLocalRead: 1 + PreloadKernArgs: false + ProblemType: + Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [0, 3, 2] + IndexAssignmentsB: [3, 1, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 1 + IndexUnrollB: 0 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: true + TLUB: false + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 0 + TransposeB: 0 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false + ScheduleGlobalRead: 1 + ScheduleIterAlg: 3 + ScheduleLocalWrite: 1 + SolutionIndex: 1 + SolutionNameMin: Cijk_Ailk_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GRVWB8_GSU1_MIWT1_1_TLDS1 + SourceSwap: 1 + StaggerU: 32 + StaggerUMapping: 0 + StaggerUStride: 256 + StorePriorityOpt: 0 + StoreRemapVectorWidth: 0 + StoreSyncOpt: 0 + StoreVectorWidth: 1 + SubGroup0: 4 + SubGroup1: 16 + SubGroupA: 4 + SubGroupB: 16 + SuppressNoLoadLoop: false + ThreadTile: [1, 1] + ThreadTile0: 4 + ThreadTile1: 1 + ThreadTileA: 4 + ThreadTileB: 1 + TransposeLDS: 1 + TransposeLDSMetadata: true + UnrollMajorLDSA: false + UnrollMajorLDSB: true + UnrollMajorLDSMetadata: true + Use64bShadowLimit: 1 + UseInstOffsetForGRO: 0 + UseSgprForGRO: -1 + Valid: true + VectorStore: -1 + VectorWidthA: 1 + VectorWidthB: 1 + WaveSeparateGlobalReadA: 0 + WaveSeparateGlobalReadB: 1 + WaveSeparateGlobalReadMetadata: 0 + WavefrontSize: 64 + WorkGroup: [16, 4, 1] + WorkGroupMapping: 8 + WorkGroupReduction: false + WorkspaceCheck: [0, 0] + _DepthU: 64 + _DepthUA: 64 + _DepthUB: 64 + _DepthUMetadata: 64 + _GlobalAccumulation: null + _UseSgprForGRO: false + _VectorStore: 1 + _WorkspaceSizePerElemBias: 0 + _WorkspaceSizePerElemC: 0 + _staggerStrideShift: 2 +- [2, 3, 0, 1] +- - - [127, 128, 1, 640, 127, 127, 127, 640] + - [1, 629.011] + - - [128, 128, 1, 640, 128, 128, 128, 640] + - [1, 745.76] + - - [129, 128, 1, 640, 129, 129, 129, 640] + - [0, 652.326] +- null +- null +- DeviceEfficiency +- null +- GridBased diff --git a/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml new file mode 100644 index 0000000000..330d8798a2 --- /dev/null +++ b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml @@ -0,0 +1,341 @@ +- {MinimumRequiredVersion: 4.33.0} +- aquavanjaram +- gfx941 +- [Device 0050, Device 0049] +- Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [3, 0, 2] + IndexAssignmentsB: [1, 3, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 0 + IndexUnrollB: 1 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: false + TLUB: true + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 1 + TransposeB: 1 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false +- - 1LDSBuffer: 0 + ActivationAlt: false + ActivationFuncCall: true + ActivationFused: true + AssertFree0ElementMultiple: 1 + AssertFree1ElementMultiple: 1 + AssertSummationElementMultiple: 1 + AssignedDerivedParameters: true + AssignedProblemIndependentDerivedParameters: true + BufferLoad: true + BufferStore: true + CUCount: null + ClusterLocalRead: 1 + CodeObjectVersion: V3 + CustomKernelName: '' + DepthU: 64 + DirectToLds: false + DirectToLdsA: false + DirectToLdsB: false + DirectToVgprSparseMetadata: false + EdgeType: ShiftPtr + EnableF32XdlMathOp: false + EnableMatrixInstruction: true + ExpandPointerSwap: 0 + GlobalReadPerMfma: 1 + GlobalReadVectorWidthA: 8 + GlobalReadVectorWidthB: 16 + GlobalSplitU: 1 + GlobalSplitUAlgorithm: MultipleBuffer + GlobalWriteVectorWidth: 1 + GroupLoadStore: false + GuaranteeNoPartialA: true + GuaranteeNoPartialB: false + GuaranteeNoPartialMetadata: true + ISA: [9, 4, 1] + InnerUnroll: 1 + InterleaveAlpha: 0 + KernelLanguage: Assembly + KernelNameMin: Cijk_Alik_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GSU1_MIWT1_1_TLDS1 + LSCA: 64 + LSCB: 16 + LSPA: 8 + LSPB: 64 + LVCA: 8 + LVCB: 1 + LVPA: 1 + LVPB: 4 + LdsBlockSizePerPadA: 0 + LdsBlockSizePerPadB: 0 + LdsBlockSizePerPadMetadata: 0 + LdsInitCVgprs: false + LdsNumElements: 4096 + LdsNumElementsAlignedA: 1024 + LdsNumElementsAlignedB: 1024 + LdsNumElementsAlignedMetadata: 0 + LdsOffsetA: 0 + LdsOffsetA_Blk: 2048 + LdsOffsetB: 1024 + LdsOffsetB_Blk: 3072 + LdsOffsetBias: 0 + LdsOffsetMetadata: 1024 + LdsOffsetMetadata_Blk: 3072 + LdsPadA: 0 + LdsPadB: 0 + LdsPadMetadata: 0 + LocalReadVectorWidth: 8 + LocalSplitU: 1 + LocalWritePerMfma: -1 + LocalWriteUseSgprA: false + LocalWriteUseSgprB: false + LoopIters: 2 + LoopUnroll: 64 + MFMA_BF16_1K: false + MIArchVgpr: false + MIBlock: [16, 16, 32, 1, 1, 1] + MIInputPerThread: 8 + MIInputPerThreadA: 8 + MIInputPerThreadB: 8 + MIInputPerThreadMetadata: 8 + MIOutputVectorWidth: 4 + MIRegPerOut: 1 + MIWaveGroup: [1, 1] + MIWaveTile: [1, 1] + MIWaveTileA: 1 + MIWaveTileB: 1 + MIWaveTileMetadata: 0 + MacroTile0: 16 + MacroTile1: 16 + MacroTileA: 16 + MacroTileB: 16 + MagicDivAlg: 2 + MatrixInstB: 1 + MatrixInstBM: 1 + MatrixInstBN: 1 + MatrixInstK: 32 + MatrixInstM: 16 + MatrixInstN: 16 + MatrixInstruction: [16, 16, 32, 1] + MaxOccupancy: 40 + MaxVgprNumber: 256 + MinVgprNumber: 0 + NoLdsWriteCode: false + NoReject: false + NoTailLoop: false + NonTemporal: -1 + NonTemporalA: 0 + NonTemporalB: 0 + NonTemporalC: 0 + NonTemporalD: 0 + NonTemporalE: 0 + NonTemporalMetadata: 0 + NumElementsPerBatchStore: 0 + NumElementsPerThread: 4 + NumGlobalWriteVectorsPerThread: 4 + NumLoadsA: 2 + NumLoadsB: 1 + NumLoadsCoalescedA: 1 + NumLoadsCoalescedB: 1 + NumLoadsPerpendicularA: 2 + NumLoadsPerpendicularB: 1 + NumThreads: 64 + OptNoLoadLoop: 0 + PackedC0IdxChars: [I] + PackedC0IndicesX: [0] + PackedC1IdxChars: [J] + PackedC1IndicesX: [1] + PrefetchGlobalRead: 2 + PrefetchLocalRead: 1 + PreloadKernArgs: false + ProblemType: + Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [3, 0, 2] + IndexAssignmentsB: [1, 3, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 0 + IndexUnrollB: 1 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: false + TLUB: true + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 1 + TransposeB: 1 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false + ScheduleGlobalRead: 1 + ScheduleIterAlg: 3 + ScheduleLocalWrite: 1 + SolutionIndex: 0 + SolutionNameMin: Cijk_Alik_Bjlk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GSU1_MIWT1_1_TLDS1 + SourceSwap: 1 + StaggerU: 32 + StaggerUMapping: 0 + StaggerUStride: 256 + StorePriorityOpt: 0 + StoreRemapVectorWidth: 0 + StoreSyncOpt: 0 + StoreVectorWidth: 1 + SubGroup0: 4 + SubGroup1: 16 + SubGroupA: 4 + SubGroupB: 16 + SuppressNoLoadLoop: false + ThreadTile: [1, 1] + ThreadTile0: 4 + ThreadTile1: 1 + ThreadTileA: 4 + ThreadTileB: 1 + TransposeLDS: 1 + TransposeLDSMetadata: true + UnrollMajorLDSA: true + UnrollMajorLDSB: false + UnrollMajorLDSMetadata: true + Use64bShadowLimit: 1 + UseInstOffsetForGRO: 0 + UseSgprForGRO: -1 + Valid: true + VectorStore: -1 + VectorWidthA: 1 + VectorWidthB: 1 + WaveSeparateGlobalReadA: 0 + WaveSeparateGlobalReadB: 1 + WaveSeparateGlobalReadMetadata: 0 + WavefrontSize: 64 + WorkGroup: [16, 4, 1] + WorkGroupMapping: 8 + WorkGroupReduction: false + WorkspaceCheck: [0, 0] + _DepthU: 64 + _DepthUA: 64 + _DepthUB: 64 + _DepthUMetadata: 64 + _GlobalAccumulation: null + _UseSgprForGRO: false + _VectorStore: 1 + _WorkspaceSizePerElemBias: 0 + _WorkspaceSizePerElemC: 0 + _staggerStrideShift: 2 +- [2, 3, 0, 1] +- - - [127, 128, 1, 640, 127, 127, 640, 128] + - [0, 618.54] + - - [128, 128, 1, 640, 128, 128, 640, 128] + - [0, 750.054] + - - [129, 128, 1, 640, 129, 129, 640, 128] + - [0, 697.997] +- null +- null +- DeviceEfficiency +- null +- GridBased diff --git a/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml new file mode 100644 index 0000000000..458d19eeca --- /dev/null +++ b/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/gfx941/GridBased/aquavanjaram_Cijk_Alik_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV.yaml @@ -0,0 +1,341 @@ +- {MinimumRequiredVersion: 4.33.0} +- aquavanjaram +- gfx941 +- [Device 0050, Device 0049] +- Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [3, 0, 2] + IndexAssignmentsB: [3, 1, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 0 + IndexUnrollB: 0 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: false + TLUB: false + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 1 + TransposeB: 0 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false +- - 1LDSBuffer: 0 + ActivationAlt: false + ActivationFuncCall: true + ActivationFused: true + AssertFree0ElementMultiple: 1 + AssertFree1ElementMultiple: 1 + AssertSummationElementMultiple: 1 + AssignedDerivedParameters: true + AssignedProblemIndependentDerivedParameters: true + BufferLoad: true + BufferStore: true + CUCount: null + ClusterLocalRead: 1 + CodeObjectVersion: V3 + CustomKernelName: '' + DepthU: 64 + DirectToLds: false + DirectToLdsA: false + DirectToLdsB: false + DirectToVgprSparseMetadata: false + EdgeType: ShiftPtr + EnableF32XdlMathOp: false + EnableMatrixInstruction: true + ExpandPointerSwap: 0 + GlobalReadPerMfma: 1 + GlobalReadVectorWidthA: 8 + GlobalReadVectorWidthB: 8 + GlobalSplitU: 1 + GlobalSplitUAlgorithm: MultipleBuffer + GlobalWriteVectorWidth: 1 + GroupLoadStore: false + GuaranteeNoPartialA: true + GuaranteeNoPartialB: true + GuaranteeNoPartialMetadata: true + ISA: [9, 4, 1] + InnerUnroll: 1 + InterleaveAlpha: 0 + KernelLanguage: Assembly + KernelNameMin: Cijk_Alik_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GRVWB8_GSU1_MIWT1_1_TLDS1 + LSCA: 64 + LSCB: 64 + LSPA: 8 + LSPB: 8 + LVCA: 8 + LVCB: 8 + LVPA: 1 + LVPB: 1 + LdsBlockSizePerPadA: 0 + LdsBlockSizePerPadB: 0 + LdsBlockSizePerPadMetadata: 0 + LdsInitCVgprs: false + LdsNumElements: 4096 + LdsNumElementsAlignedA: 1024 + LdsNumElementsAlignedB: 1024 + LdsNumElementsAlignedMetadata: 0 + LdsOffsetA: 0 + LdsOffsetA_Blk: 2048 + LdsOffsetB: 1024 + LdsOffsetB_Blk: 3072 + LdsOffsetBias: 0 + LdsOffsetMetadata: 1024 + LdsOffsetMetadata_Blk: 3072 + LdsPadA: 0 + LdsPadB: 0 + LdsPadMetadata: 0 + LocalReadVectorWidth: 8 + LocalSplitU: 1 + LocalWritePerMfma: -1 + LocalWriteUseSgprA: false + LocalWriteUseSgprB: false + LoopIters: 2 + LoopUnroll: 64 + MFMA_BF16_1K: false + MIArchVgpr: false + MIBlock: [16, 16, 32, 1, 1, 1] + MIInputPerThread: 8 + MIInputPerThreadA: 8 + MIInputPerThreadB: 8 + MIInputPerThreadMetadata: 8 + MIOutputVectorWidth: 4 + MIRegPerOut: 1 + MIWaveGroup: [1, 1] + MIWaveTile: [1, 1] + MIWaveTileA: 1 + MIWaveTileB: 1 + MIWaveTileMetadata: 0 + MacroTile0: 16 + MacroTile1: 16 + MacroTileA: 16 + MacroTileB: 16 + MagicDivAlg: 2 + MatrixInstB: 1 + MatrixInstBM: 1 + MatrixInstBN: 1 + MatrixInstK: 32 + MatrixInstM: 16 + MatrixInstN: 16 + MatrixInstruction: [16, 16, 32, 1] + MaxOccupancy: 40 + MaxVgprNumber: 256 + MinVgprNumber: 0 + NoLdsWriteCode: false + NoReject: false + NoTailLoop: false + NonTemporal: -1 + NonTemporalA: 0 + NonTemporalB: 0 + NonTemporalC: 0 + NonTemporalD: 0 + NonTemporalE: 0 + NonTemporalMetadata: 0 + NumElementsPerBatchStore: 0 + NumElementsPerThread: 4 + NumGlobalWriteVectorsPerThread: 4 + NumLoadsA: 2 + NumLoadsB: 2 + NumLoadsCoalescedA: 1 + NumLoadsCoalescedB: 1 + NumLoadsPerpendicularA: 2 + NumLoadsPerpendicularB: 2 + NumThreads: 64 + OptNoLoadLoop: 0 + PackedC0IdxChars: [I] + PackedC0IndicesX: [0] + PackedC1IdxChars: [J] + PackedC1IndicesX: [1] + PrefetchGlobalRead: 2 + PrefetchLocalRead: 1 + PreloadKernArgs: false + ProblemType: + Activation: true + ActivationComputeDataType: 0 + ActivationNoGuard: false + ActivationType: all + AllowNoFreeDims: false + AssignedDerivedParameters: true + Batched: true + BetaOnlyUseBias: true + BiasDataTypeList: [0, 7] + BiasSrc: D + ComplexConjugateA: false + ComplexConjugateB: false + ComputeDataType: 0 + DataType: 11 + DataTypeA: 4 + DataTypeB: 11 + DestDataType: 0 + F32XdlMathOp: 0 + Fp16AltImpl: false + Gradient: false + GroupedGemm: false + HighPrecisionAccumulate: true + Index0: 0 + Index01A: 0 + Index01B: 1 + Index1: 1 + IndexAssignmentsA: [3, 0, 2] + IndexAssignmentsB: [3, 1, 2] + IndexAssignmentsLD: [4, 5, 6, 7] + IndexAssignmentsMetadata: [3, 0, 2] + IndexUnroll: 3 + IndexUnrollA: 0 + IndexUnrollB: 0 + IndexUnrollM: 0 + IndicesBatch: [2] + IndicesFree: [0, 1] + IndicesSummation: [3] + MirrorDimsA: [] + MirrorDimsB: [] + MirrorDimsMetadata: [] + NumIndicesBatch: 1 + NumIndicesC: 3 + NumIndicesFree: 2 + NumIndicesLD: 4 + NumIndicesSummation: 1 + OperationType: GEMM + SetConstStrideA: [] + SetConstStrideB: [] + SetConstStrideBias: [] + SilentHighPrecisionAccumulate: false + SparseA: false + StridedBatched: true + SupportUserArgs: false + TLUA: false + TLUB: false + Tensor0: 0 + Tensor1: 1 + TileA: 0 + TileAwareSelection: false + TileB: 1 + TotalIndices: 4 + TransposeA: 1 + TransposeB: 0 + UseBeta: true + UseBias: true + UseE: false + UseInitialStridesAB: false + UseInitialStridesCD: false + UseScaleAB: true + UseScaleAlphaVec: true + UseScaleCD: false + ScheduleGlobalRead: 1 + ScheduleIterAlg: 3 + ScheduleLocalWrite: 1 + SolutionIndex: 0 + SolutionNameMin: Cijk_Alik_Bljk_HF8_F8SS_BH_BiasSB_AS_SAB_SAV_MT16x16x64_MI16x16x1_SN_GRVWB8_GSU1_MIWT1_1_TLDS1 + SourceSwap: 1 + StaggerU: 32 + StaggerUMapping: 0 + StaggerUStride: 256 + StorePriorityOpt: 0 + StoreRemapVectorWidth: 0 + StoreSyncOpt: 0 + StoreVectorWidth: 1 + SubGroup0: 4 + SubGroup1: 16 + SubGroupA: 4 + SubGroupB: 16 + SuppressNoLoadLoop: false + ThreadTile: [1, 1] + ThreadTile0: 4 + ThreadTile1: 1 + ThreadTileA: 4 + ThreadTileB: 1 + TransposeLDS: 1 + TransposeLDSMetadata: true + UnrollMajorLDSA: true + UnrollMajorLDSB: true + UnrollMajorLDSMetadata: true + Use64bShadowLimit: 1 + UseInstOffsetForGRO: 0 + UseSgprForGRO: -1 + Valid: true + VectorStore: -1 + VectorWidthA: 1 + VectorWidthB: 1 + WaveSeparateGlobalReadA: 0 + WaveSeparateGlobalReadB: 1 + WaveSeparateGlobalReadMetadata: 0 + WavefrontSize: 64 + WorkGroup: [16, 4, 1] + WorkGroupMapping: 8 + WorkGroupReduction: false + WorkspaceCheck: [0, 0] + _DepthU: 64 + _DepthUA: 64 + _DepthUB: 64 + _DepthUMetadata: 64 + _GlobalAccumulation: null + _UseSgprForGRO: 1 + _VectorStore: 1 + _WorkspaceSizePerElemBias: 0 + _WorkspaceSizePerElemC: 0 + _staggerStrideShift: 2 +- [2, 3, 0, 1] +- - - [127, 128, 1, 640, 127, 127, 640, 640] + - [0, 628.251] + - - [128, 128, 1, 640, 128, 128, 640, 640] + - [0, 727.168] + - - [129, 128, 1, 640, 129, 129, 640, 640] + - [0, 736.937] +- null +- null +- DeviceEfficiency +- null +- GridBased From b74591fb266bbc4003108bd8970526bd24fdb128 Mon Sep 17 00:00:00 2001 From: "yangwen.huang" Date: Sun, 17 Sep 2023 21:36:13 -0500 Subject: [PATCH 8/9] Add support for fp8 mix precision types --- clients/common/cblas_interface.cpp | 815 ++++++++++++++---- clients/include/type_dispatch.hpp | 42 +- .../rocblaslt/src/rocblaslt_auxiliary.cpp | 152 ++-- .../rocblaslt/src/rocblaslt_mat.hpp | 108 ++- 4 files changed, 816 insertions(+), 301 deletions(-) diff --git a/clients/common/cblas_interface.cpp b/clients/common/cblas_interface.cpp index c19744ac16..b79f734c91 100644 --- a/clients/common/cblas_interface.cpp +++ b/clients/common/cblas_interface.cpp @@ -629,6 +629,401 @@ void cblas_gemm(hipblasOperat } } +template <> +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + hipblasLtHalf* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) +{ + // cblas does not support hipblasLtHalf, so convert to higher precision float + // This will give more precise result which is acceptable for testing + + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); + size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); + size_t sizeC = n * size_t(ldc); + + host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); + + if(alt) + { + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + A_float[i] = float_to_bfloat16_truncate(float(A[i])) * AlphaVec[i % m]; + } + else + { + for(size_t i = 0; i < sizeA; i++) + A_float[i] = float_to_bfloat16_truncate(float(A[i])); + } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = float_to_bfloat16_truncate(float(B[i])); + for(size_t i = 0; i < sizeC; i++) + C_float[i] = float_to_bfloat16_truncate(float(C[i])); + } + else + { + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = B[i]; + for(size_t i = 0; i < sizeC; i++) + C_float[i] = C[i]; + } + + alpha *= scaleA * scaleB; + + // just directly cast, since transA, transB are integers in the enum + //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(CblasColMajor, + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_float, + lda, + B_float, + ldb, + beta, + C_float, + ldc); + + if(scaleD != 1) + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C_float[i] * scaleD); + } + else + { + for(size_t i = 0; i < sizeC; i++) + C[i] = hipblasLtHalf(C_float[i]); + } +} + +template <> +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) +{ + // cblas does not support hipblasLtHalf, so convert to higher precision float + // This will give more precise result which is acceptable for testing + + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); + size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); + size_t sizeC = n * size_t(ldc); + + host_vector A_float(sizeA), B_float(sizeB); + + if(alt) + { + for(size_t i = 0; i < sizeA; i++) + A_float[i] = float_to_bfloat16_truncate(float(A[i])); + for(size_t i = 0; i < sizeB; i++) + B_float[i] = float_to_bfloat16_truncate(float(B[i])); + } + else + { + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = B[i]; + } + + alpha *= scaleA * scaleB; + + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(CblasColMajor, + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_float, + lda, + B_float, + ldb, + beta, + C, + ldc); + + if(scaleD != 1) + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C[i] * scaleD); + } +} + +template <> +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const float* A, + int64_t lda, + const float* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) +{ + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); + size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); + size_t sizeC = n * size_t(ldc); + + host_vector A_float(sizeA); + + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(A[i]); + } + } + + alpha *= scaleA * scaleB; + + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(CblasColMajor, + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); + + if(scaleD != 1) + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C[i] * scaleD); + } +} + +template <> +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + double alpha, + const double* A, + int64_t lda, + const double* B, + int64_t ldb, + double beta, + double* C, + int64_t ldc, + const double* AlphaVec, + double scaleA, + double scaleB, + double scaleD, + bool alt) +{ + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); + size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); + size_t sizeC = n * size_t(ldc); + + host_vector A_double(sizeA); + + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]) * AlphaVec[i % m]; + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]); + } + } + + alpha *= scaleA * scaleB; + + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_dgemm(CblasColMajor, + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); + + if(scaleD != 1) + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C[i] * scaleD); + } +} + +template <> +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + int32_t alpha, + const int8_t* A, + int64_t lda, + const int8_t* B, + int64_t ldb, + int32_t beta, + int32_t* C, + int64_t ldc, + const int32_t* AlphaVec, //cm review + int32_t scaleA, + int32_t scaleB, + int32_t scaleD, + bool alt) +{ + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); + size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); + size_t sizeC = n * size_t(ldc); + + host_vector A_double(sizeA); + host_vector B_double(sizeB); + host_vector C_double(sizeC); + + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]) * static_cast(AlphaVec[i % m]); + } + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]); + } + } + for(size_t i = 0; i < sizeB; i++) + B_double[i] = static_cast(B[i]); + for(size_t i = 0; i < sizeC; i++) + C_double[i] = static_cast(C[i]); + + alpha *= scaleA * scaleB; + + // just directly cast, since transA, transB are integers in the enum + // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_dgemm(CblasColMajor, + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_double, + lda, + B_double, + ldb, + beta, + C_double, + ldc); + + if(scaleD != 1) + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C_double[i] * scaleD); + } + else + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C_double[i]); + } +} + +// Mix precision +// FP16FP8 mix FP16 in MFMA template <> void cblas_gemm( hipblasOperation_t transA, @@ -1093,25 +1488,27 @@ void cblas_gemm(hipbla } } +// FP16FP8 mix FP8 in MFMA template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - hipblasLtHalf* C, - int64_t ldc, - const float* AlphaVec, - float scaleA, - float scaleB, - float scaleD, - bool alt) +void cblas_gemm( + hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblaslt_f8* B, + int64_t ldb, + float beta, + hipblaslt_f8* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { // cblas does not support hipblasLtHalf, so convert to higher precision float // This will give more precise result which is acceptable for testing @@ -1122,44 +1519,104 @@ void cblas_gemm(hipblasOpera host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - if(alt) + if(AlphaVec != nullptr) { - if(AlphaVec != nullptr) + for(size_t i = 0; i < sizeA; i++) { - for(size_t i = 0; i < sizeA; i++) - A_float[i] = float_to_bfloat16_truncate(float(A[i])) * AlphaVec[i % m]; + A_float[i] = static_cast(static_cast(A[i])) * AlphaVec[i % m]; } - else + } + else + { + for(size_t i = 0; i < sizeA; i++) + { + A_float[i] = static_cast(static_cast(A[i])); + } + } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = static_cast(B[i]); + for(size_t i = 0; i < sizeC; i++) + C_float[i] = static_cast(C[i]); + + alpha *= scaleA * scaleB; + + // just directly cast, since transA, transB are integers in the enum + //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(CblasColMajor, + HIPOperationToCBLASTanspose(transA), + HIPOperationToCBLASTanspose(transB), + m, + n, + k, + alpha, + A_float, + lda, + B_float, + ldb, + beta, + C_float, + ldc); + + if(scaleD != 1) + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C_float[i] * scaleD); + } + else + { + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C_float[i]); + } +} + +template <> +void cblas_gemm( + hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblaslt_f8* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + hipblaslt_f8* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) +{ + // cblas does not support hipblasLtHalf, so convert to higher precision float + // This will give more precise result which is acceptable for testing + + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); + size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); + size_t sizeC = n * size_t(ldc); + + host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); + + if(AlphaVec != nullptr) + { + for(size_t i = 0; i < sizeA; i++) { - for(size_t i = 0; i < sizeA; i++) - A_float[i] = float_to_bfloat16_truncate(float(A[i])); + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = float_to_bfloat16_truncate(float(B[i])); - for(size_t i = 0; i < sizeC; i++) - C_float[i] = float_to_bfloat16_truncate(float(C[i])); } else { - if(AlphaVec != nullptr) - { - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; - } - } - else + for(size_t i = 0; i < sizeA; i++) { - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - } + A_float[i] = static_cast(A[i]); } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = B[i]; - for(size_t i = 0; i < sizeC; i++) - C_float[i] = C[i]; } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = static_cast(static_cast(B[i])); + for(size_t i = 0; i < sizeC; i++) + C_float[i] = static_cast(C[i]); alpha *= scaleA * scaleB; @@ -1183,34 +1640,35 @@ void cblas_gemm(hipblasOpera if(scaleD != 1) { for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_float[i] * scaleD); + C[i] = static_cast(C_float[i] * scaleD); } else { for(size_t i = 0; i < sizeC; i++) - C[i] = hipblasLtHalf(C_float[i]); + C[i] = static_cast(C_float[i]); } } template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const hipblasLtHalf* A, - int64_t lda, - const hipblasLtHalf* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleA, - float scaleB, - float scaleD, - bool alt) +void cblas_gemm( + hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblaslt_f8* B, + int64_t ldb, + float beta, + hipblasLtHalf* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { // cblas does not support hipblasLtHalf, so convert to higher precision float // This will give more precise result which is acceptable for testing @@ -1219,39 +1677,31 @@ void cblas_gemm(hipblasOperation_t size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_float(sizeA), B_float(sizeB); + host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); - if(alt) + if(AlphaVec != nullptr) { for(size_t i = 0; i < sizeA; i++) - A_float[i] = float_to_bfloat16_truncate(float(A[i])); - for(size_t i = 0; i < sizeB; i++) - B_float[i] = float_to_bfloat16_truncate(float(B[i])); + { + A_float[i] = static_cast(static_cast(A[i])) * AlphaVec[i % m]; + } } else { - if(AlphaVec != nullptr) - { - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; - } - } - else + for(size_t i = 0; i < sizeA; i++) { - for(size_t i = 0; i < sizeA; i++) - { - A_float[i] = static_cast(A[i]); - } + A_float[i] = static_cast(static_cast(A[i])); } - for(size_t i = 0; i < sizeB; i++) - B_float[i] = B[i]; } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = static_cast(B[i]); + for(size_t i = 0; i < sizeC; i++) + C_float[i] = static_cast(C[i]); alpha *= scaleA * scaleB; // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, HIPOperationToCBLASTanspose(transA), HIPOperationToCBLASTanspose(transB), @@ -1264,41 +1714,50 @@ void cblas_gemm(hipblasOperation_t B_float, ldb, beta, - C, + C_float, ldc); if(scaleD != 1) { for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); + C[i] = static_cast(C_float[i] * scaleD); + } + else + { + for(size_t i = 0; i < sizeC; i++) + C[i] = hipblasLtHalf(C_float[i]); } } template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - float alpha, - const float* A, - int64_t lda, - const float* B, - int64_t ldb, - float beta, - float* C, - int64_t ldc, - const float* AlphaVec, - float scaleA, - float scaleB, - float scaleD, - bool alt) +void cblas_gemm( + hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblaslt_f8* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + hipblasLtHalf* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { + // cblas does not support hipblasLtHalf, so convert to higher precision float + // This will give more precise result which is acceptable for testing + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_float(sizeA); + host_vector A_float(sizeA), B_float(sizeB), C_float(sizeC); if(AlphaVec != nullptr) { @@ -1314,11 +1773,15 @@ void cblas_gemm(hipblasOperation_t transA, A_float[i] = static_cast(A[i]); } } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = static_cast(static_cast(B[i])); + for(size_t i = 0; i < sizeC; i++) + C_float[i] = static_cast(C[i]); alpha *= scaleA * scaleB; // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); cblas_sgemm(CblasColMajor, HIPOperationToCBLASTanspose(transA), HIPOperationToCBLASTanspose(transB), @@ -1326,76 +1789,86 @@ void cblas_gemm(hipblasOperation_t transA, n, k, alpha, - A, + A_float, lda, - B, + B_float, ldb, beta, - C, + C_float, ldc); if(scaleD != 1) { for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); + C[i] = static_cast(C_float[i] * scaleD); + } + else + { + for(size_t i = 0; i < sizeC; i++) + C[i] = hipblasLtHalf(C_float[i]); } } template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - double alpha, - const double* A, - int64_t lda, - const double* B, - int64_t ldb, - double beta, - double* C, - int64_t ldc, - const double* AlphaVec, - double scaleA, - double scaleB, - double scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblasLtHalf* A, + int64_t lda, + const hipblaslt_f8* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { + // cblas does not support hipblasLtHalf, so convert to higher precision float + // This will give more precise result which is acceptable for testing + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_double(sizeA); + host_vector A_float(sizeA), B_float(sizeB); if(AlphaVec != nullptr) { for(size_t i = 0; i < sizeA; i++) { - A_double[i] = static_cast(A[i]) * AlphaVec[i % m]; + A_float[i] = static_cast(static_cast(A[i])) * AlphaVec[i % m]; } } else { for(size_t i = 0; i < sizeA; i++) { - A_double[i] = static_cast(A[i]); + A_float[i] = static_cast(static_cast(A[i])); } } + for(size_t i = 0; i < sizeB; i++) + B_float[i] = static_cast(B[i]); alpha *= scaleA * scaleB; // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_dgemm(CblasColMajor, + //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(CblasColMajor, HIPOperationToCBLASTanspose(transA), HIPOperationToCBLASTanspose(transB), m, n, k, alpha, - A, + A_float, lda, - B, + B_float, ldb, beta, C, @@ -1404,84 +1877,78 @@ void cblas_gemm(hipblasOperation_t transA, if(scaleD != 1) { for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C[i] * scaleD); + C[i] = static_cast(C[i] * scaleD); } } template <> -void cblas_gemm(hipblasOperation_t transA, - hipblasOperation_t transB, - int64_t m, - int64_t n, - int64_t k, - int32_t alpha, - const int8_t* A, - int64_t lda, - const int8_t* B, - int64_t ldb, - int32_t beta, - int32_t* C, - int64_t ldc, - const int32_t* AlphaVec, //cm review - int32_t scaleA, - int32_t scaleB, - int32_t scaleD, - bool alt) +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int64_t m, + int64_t n, + int64_t k, + float alpha, + const hipblaslt_f8* A, + int64_t lda, + const hipblasLtHalf* B, + int64_t ldb, + float beta, + float* C, + int64_t ldc, + const float* AlphaVec, + float scaleA, + float scaleB, + float scaleD, + bool alt) { + // cblas does not support hipblasLtHalf, so convert to higher precision float + // This will give more precise result which is acceptable for testing + size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda); size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb); size_t sizeC = n * size_t(ldc); - host_vector A_double(sizeA); - host_vector B_double(sizeB); - host_vector C_double(sizeC); + host_vector A_float(sizeA), B_float(sizeB); if(AlphaVec != nullptr) { for(size_t i = 0; i < sizeA; i++) { - A_double[i] = static_cast(A[i]) * static_cast(AlphaVec[i % m]); + A_float[i] = static_cast(A[i]) * AlphaVec[i % m]; } } else { for(size_t i = 0; i < sizeA; i++) { - A_double[i] = static_cast(A[i]); + A_float[i] = static_cast(A[i]); } } for(size_t i = 0; i < sizeB; i++) - B_double[i] = static_cast(B[i]); - for(size_t i = 0; i < sizeC; i++) - C_double[i] = static_cast(C[i]); + B_float[i] = static_cast(static_cast(B[i])); alpha *= scaleA * scaleB; // just directly cast, since transA, transB are integers in the enum - // printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); - cblas_dgemm(CblasColMajor, + //printf("transA: hipblaslt =%d, cblas=%d\n", transA, HIPOperationToCBLASTanspose(transA) ); + cblas_sgemm(CblasColMajor, HIPOperationToCBLASTanspose(transA), HIPOperationToCBLASTanspose(transB), m, n, k, alpha, - A_double, + A_float, lda, - B_double, + B_float, ldb, beta, - C_double, + C, ldc); if(scaleD != 1) { for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_double[i] * scaleD); - } - else - { - for(size_t i = 0; i < sizeC; i++) - C[i] = static_cast(C_double[i]); + C[i] = static_cast(C[i] * scaleD); } } diff --git a/clients/include/type_dispatch.hpp b/clients/include/type_dispatch.hpp index 5317dd6975..96301cdec6 100644 --- a/clients/include/type_dispatch.hpp +++ b/clients/include/type_dispatch.hpp @@ -149,6 +149,16 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg) { return TEST{}(arg); } + /* + else if(Ti == HIPBLASLT_R_8I && To == HIPBLASLT_R_8I && Tc == HIPBLASLT_COMPUTE_I32) + { + return TEST{}(arg); + } + */ + else if(TiA == HIPBLASLT_R_8I && To == HIPBLASLT_R_32I && Tc == HIPBLASLT_COMPUTE_I32) + { + return TEST{}(arg); + } else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_8F_E4M3 && Tc == HIPBLASLT_COMPUTE_F32_FAST_F16) { @@ -179,15 +189,35 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg) { return TEST{}(arg); } - /* - else if(Ti == HIPBLASLT_R_8I && To == HIPBLASLT_R_8I && Tc == HIPBLASLT_COMPUTE_I32) + else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_8F_E4M3 + && Tc == HIPBLASLT_COMPUTE_F32) { - return TEST{}(arg); + return TEST{}(arg); } - */ - else if(TiA == HIPBLASLT_R_8I && To == HIPBLASLT_R_32I && Tc == HIPBLASLT_COMPUTE_I32) + else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_8F_E4M3 + && Tc == HIPBLASLT_COMPUTE_F32) { - return TEST{}(arg); + return TEST{}(arg); + } + else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_16F + && Tc == HIPBLASLT_COMPUTE_F32) + { + return TEST{}(arg); + } + else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_16F + && Tc == HIPBLASLT_COMPUTE_F32) + { + return TEST{}(arg); + } + else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_32F + && Tc == HIPBLASLT_COMPUTE_F32) + { + return TEST{}(arg); + } + else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_32F + && Tc == HIPBLASLT_COMPUTE_F32) + { + return TEST{}(arg); } } return TEST{}(arg); diff --git a/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp b/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp index d12fa07273..2979a9365f 100644 --- a/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp +++ b/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp @@ -35,8 +35,8 @@ #include #endif -#include #include +#include #include #define TO_STR2(x) #x @@ -1392,7 +1392,8 @@ rocblaslt_status rocblaslt_matmul_is_algo_supported(rocblaslt_handle hand { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float* alphaf = (float*)alpha; float* betaf = (float*)beta; @@ -1407,14 +1408,14 @@ rocblaslt_status rocblaslt_matmul_is_algo_supported(rocblaslt_handle hand alphaf, betaf, algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); + status = isSolutionSupported( + handle, prob, gemmData, algo, workspaceSizeInBytes); } } else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float* alphaf = (float*)alpha; float* betaf = (float*)beta; @@ -1436,24 +1437,23 @@ rocblaslt_status rocblaslt_matmul_is_algo_supported(rocblaslt_handle hand } else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float* alphaf = (float*)alpha; float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); + auto prob + = construct_rocblaslt_problem( + matmul_descr, + matA, + matB, + matC, + matD, + alphaf, + betaf, + algo->max_workspace_bytes); + status = isSolutionSupported( + handle, prob, gemmData, algo, workspaceSizeInBytes); } } } @@ -1461,7 +1461,8 @@ rocblaslt_status rocblaslt_matmul_is_algo_supported(rocblaslt_handle hand { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float* alphaf = (float*)alpha; float* betaf = (float*)beta; @@ -1476,14 +1477,14 @@ rocblaslt_status rocblaslt_matmul_is_algo_supported(rocblaslt_handle hand alphaf, betaf, algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); + status = isSolutionSupported( + handle, prob, gemmData, algo, workspaceSizeInBytes); } } else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float* alphaf = (float*)alpha; float* betaf = (float*)beta; @@ -1505,24 +1506,23 @@ rocblaslt_status rocblaslt_matmul_is_algo_supported(rocblaslt_handle hand } else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float* alphaf = (float*)alpha; float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); + auto prob + = construct_rocblaslt_problem( + matmul_descr, + matA, + matB, + matC, + matD, + alphaf, + betaf, + algo->max_workspace_bytes); + status = isSolutionSupported( + handle, prob, gemmData, algo, workspaceSizeInBytes); } } } @@ -1983,7 +1983,8 @@ rocblaslt_status { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float alpha = 1.0; float beta = 1.0; @@ -2010,7 +2011,8 @@ rocblaslt_status } else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float alpha = 1.0; float beta = 1.0; @@ -2037,21 +2039,21 @@ rocblaslt_status } else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float alpha = 1.0; float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); + auto prob + = construct_rocblaslt_problem( + matmul_desc, + matA, + matB, + matC, + matD, + &alpha, + &beta, + pref->max_workspace_bytes); status = getBestSolutions( prob, handle, @@ -2067,7 +2069,8 @@ rocblaslt_status { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float alpha = 1.0; float beta = 1.0; @@ -2094,7 +2097,8 @@ rocblaslt_status } else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float alpha = 1.0; float beta = 1.0; @@ -2121,21 +2125,21 @@ rocblaslt_status } else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { float alpha = 1.0; float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); + auto prob + = construct_rocblaslt_problem( + matmul_desc, + matA, + matB, + matC, + matD, + &alpha, + &beta, + pref->max_workspace_bytes); status = getBestSolutions( prob, handle, @@ -2788,14 +2792,13 @@ std::string rocblaslt_internal_get_arch_name() return ArchName{}(deviceProperties); } -bool rocblaslt_internal_test_path(const std::string &path) +bool rocblaslt_internal_test_path(const std::string& path) { #ifdef WIN32 return ((_access(path.c_str(), 4) != -1) || (_access(path.c_str(), 6) != -1)); #else return access(path.c_str(), R_OK) == 0; #endif - } #ifndef WIN32 @@ -2803,7 +2806,8 @@ int hipblaslt_dl_iterate_phdr_callback(struct dl_phdr_info* hdr_info, size_t siz { // uncomment to see all dependent .so files // fprintf(stderr, "hipblaslt so file: %s\n", hdr_info->dlpi_name); - std::pair *typedData = reinterpret_cast *>(data); + std::pair* typedData + = reinterpret_cast*>(data); if(hdr_info->dlpi_name && strstr(hdr_info->dlpi_name, typedData->second.c_str())) { typedData->first.assign(hdr_info->dlpi_name); @@ -2813,9 +2817,9 @@ int hipblaslt_dl_iterate_phdr_callback(struct dl_phdr_info* hdr_info, size_t siz } #endif -std::string rocblaslt_internal_get_so_path(const std::string &keyword) +std::string rocblaslt_internal_get_so_path(const std::string& keyword) { std::pair result{"", keyword}; dl_iterate_phdr(hipblaslt_dl_iterate_phdr_callback, &result); return result.first; -} \ No newline at end of file +} diff --git a/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp b/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp index 8d7133ca99..8dbb8e28b4 100644 --- a/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp +++ b/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp @@ -647,11 +647,11 @@ inline rocblaslt_status rocblaslt_matmul_template(rocblaslt_handle h { rocblaslt_status rs_status = rocblaslt_status_not_implemented; -#define EX_TYPECASTING_PARM \ - handle, trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, \ - beta, c, ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, \ - batch_count, strided_batch, grouped_gemm, gradient, compute_type, algo, workspace, \ - workspaceSizeInBytes, bias, scaleA, scaleB, scaleC, scaleD, scaleE, scaleAlphaVec, \ +#define EX_TYPECASTING_PARM \ + handle, trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, \ + beta, c, ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, \ + batch_count, strided_batch, grouped_gemm, gradient, compute_type, algo, workspace, \ + workspaceSizeInBytes, bias, scaleA, scaleB, scaleC, scaleD, scaleE, scaleAlphaVec, \ bias_type, epilogue, gemmData, stream if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) @@ -811,7 +811,8 @@ inline rocblaslt_status rocblaslt_matmul_template(rocblaslt_handle h { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); + rs_status + = rocblaslt_matmul_typecasting( + EX_TYPECASTING_PARM); } } } @@ -844,7 +846,8 @@ inline rocblaslt_status rocblaslt_matmul_template(rocblaslt_handle h { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); + rs_status + = rocblaslt_matmul_typecasting( + EX_TYPECASTING_PARM); } } } @@ -927,10 +931,10 @@ inline rocblaslt_status rocblaslt_gemm_create_template_cpp(hipblasOperation_t { rocblaslt_status rs_status = rocblaslt_status_not_implemented; -#define EX_TYPECASTING_PARM_GEMM_CPP \ - trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, beta, c, \ - ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, batch_count, \ - strided_batch, grouped_gemm, gradient, compute_type, bias, scaleA, scaleB, scaleC, scaleD,\ +#define EX_TYPECASTING_PARM_GEMM_CPP \ + trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, beta, c, \ + ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, batch_count, \ + strided_batch, grouped_gemm, gradient, compute_type, bias, scaleA, scaleB, scaleC, scaleD, \ scaleE, scaleAlphaVec, bias_type, epilogue, gemmData, gemmCount if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) @@ -1021,7 +1025,8 @@ inline rocblaslt_status rocblaslt_gemm_create_template_cpp(hipblasOperation_t { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); + rs_status + = rocblaslt_gemm_create_typecasting( + EX_TYPECASTING_PARM_GEMM_CPP); } } } @@ -1054,7 +1060,8 @@ inline rocblaslt_status rocblaslt_gemm_create_template_cpp(hipblasOperation_t { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); + rs_status + = rocblaslt_gemm_create_typecasting( + EX_TYPECASTING_PARM_GEMM_CPP); } } } @@ -1138,10 +1146,10 @@ inline rocblaslt_status { rocblaslt_status rs_status = rocblaslt_status_not_implemented; -#define EX_TYPECASTING_PARM_GroupedGemm_CPP \ - trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, beta, c, \ - ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, batch_count, \ - strided_batch, grouped_gemm, compute_type, gradient, bias, scaleA, scaleB, scaleC, scaleD,\ +#define EX_TYPECASTING_PARM_GroupedGemm_CPP \ + trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, beta, c, \ + ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, batch_count, \ + strided_batch, grouped_gemm, compute_type, gradient, bias, scaleA, scaleB, scaleC, scaleD, \ scaleE, scaleAlphaVec, bias_type, epilogue, gemmData, gemmCount if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) @@ -1238,7 +1246,8 @@ inline rocblaslt_status { if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) { - if(compute_type == rocblaslt_compute_f32_fast_f16) + if(compute_type == rocblaslt_compute_f32_fast_f16 + || compute_type == rocblaslt_compute_f32) { rs_status = rocblaslt_groupedgemm_create_typecasting Date: Mon, 18 Sep 2023 09:56:19 -0500 Subject: [PATCH 9/9] Remove RocblasltContractionProblem template --- .../src/include/rocblaslt_mat_utils.hpp | 29 +- .../rocblaslt/src/include/tensile_host.hpp | 231 ++- .../rocblaslt/src/rocblaslt_auxiliary.cpp | 1608 ++--------------- .../rocblaslt/src/rocblaslt_mat.cpp | 435 ++++- .../rocblaslt/src/rocblaslt_mat.hpp | 1331 -------------- .../amd_detail/rocblaslt/src/tensile_host.cpp | 465 ++--- 6 files changed, 756 insertions(+), 3343 deletions(-) delete mode 100644 library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp diff --git a/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp b/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp index b86777f70e..7d3adbe283 100644 --- a/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp +++ b/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp @@ -250,12 +250,16 @@ inline rocblaslt_status rocblaslt_matmul_valid_args(const rocblaslt_matmul_desc int64_t& m, int64_t& n, int64_t& k, + hipblasltDatatype_t& a_type, int64_t& lda, int64_t& batch_stride_a, + hipblasltDatatype_t& b_type, int64_t& ldb, int64_t& batch_stride_b, + hipblasltDatatype_t& c_type, int64_t& ldc, int64_t& batch_stride_c, + hipblasltDatatype_t& d_type, int64_t& ldd, int64_t& batch_stride_d, int64_t& lde, @@ -274,16 +278,19 @@ inline rocblaslt_status rocblaslt_matmul_valid_args(const rocblaslt_matmul_desc int64_t num_rows_a = matA->m; int64_t num_cols_a = matA->n; int num_batches_a = matA->batch_count; + a_type = matA->type; lda = matA->ld; batch_stride_a = matA->batch_stride; // matrix B int num_batches_b = matB->batch_count; + b_type = matB->type; ldb = matB->ld; batch_stride_b = matB->batch_stride; // matrix C int num_batches_c = matC->batch_count; + c_type = matC->type; ldc = matC->ld; batch_stride_c = matC->batch_stride; @@ -291,6 +298,7 @@ inline rocblaslt_status rocblaslt_matmul_valid_args(const rocblaslt_matmul_desc int64_t num_rows_d = matD->m; int64_t num_cols_d = matD->n; int num_batches_d = matD->batch_count; + d_type = matD->type; ldd = matD->ld; batch_stride_d = matD->batch_stride; @@ -352,10 +360,23 @@ inline rocblaslt_status rocblaslt_matmul_valid_args(const rocblaslt_matmul_desc return status; } -template -inline int rocblaslt_get_matmul_alg_config_max_id(hipblasOperation_t opA, hipblasOperation_t opB) +// Assign 1 to onePtr then set set the address to dst. +inline void setTo1(const rocblaslt_compute_type& compute_type, const void* onePtr, const void** dst) { - // TODO - return true; + if(compute_type == rocblaslt_compute_f64) + { + *((double*)onePtr) = 1.f; + *dst = onePtr; + } + else if(compute_type == rocblaslt_compute_i32) + { + *((int32_t*)onePtr) = 1.f; + *dst = onePtr; + } + else + { + *((float*)onePtr) = 1.f; + *dst = onePtr; + } } #endif diff --git a/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp b/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp index ca2eec7a3f..2786b724f6 100644 --- a/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp +++ b/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp @@ -55,38 +55,10 @@ constexpr double value_category(const T& beta) return beta == T(0) ? 0.0 : beta == T(1) ? 1.0 : beta == T(-1) ? -1.0 : 2.0; } -template -inline constexpr auto hipblaslt_datatype = nullptr; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_16F; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_32F; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_64F; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_16B; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_8F_E4M3; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_8F_E5M2; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_8I; - -template <> -inline constexpr auto hipblaslt_datatype = HIPBLASLT_R_32I; - /******************************************************************** * RocblasltContractionProblem captures the arguments for a GEMM-like * * contraction problem, to be passed to runContractionProblem. * ********************************************************************/ -template struct RocblasltContractionProblem { hipblasOperation_t trans_a; @@ -100,39 +72,43 @@ struct RocblasltContractionProblem size_t n; size_t k; - const Tc* alpha; - - const TiA* A; - const TiA* const* batch_A; - size_t row_stride_a; - size_t col_stride_a; - size_t batch_stride_a; - - const TiB* B; - const TiB* const* batch_B; - size_t row_stride_b; - size_t col_stride_b; - size_t batch_stride_b; - - const Tc* beta; - - const To* C; - const To* const* batch_C; - size_t row_stride_c; - size_t col_stride_c; - size_t batch_stride_c; - - To* D; - To* const* batch_D; - size_t row_stride_d; - size_t col_stride_d; - size_t batch_stride_d; - - Tc* E; - Tc* const* batch_E; - size_t row_stride_e; - size_t col_stride_e; - size_t batch_stride_e; + const void* alpha; + + hipblasltDatatype_t a_type; + const void* A; + const void* const* batch_A; + size_t row_stride_a; + size_t col_stride_a; + size_t batch_stride_a; + + hipblasltDatatype_t b_type; + const void* B; + const void* const* batch_B; + size_t row_stride_b; + size_t col_stride_b; + size_t batch_stride_b; + + const void* beta; + + hipblasltDatatype_t c_type; + const void* C; + const void* const* batch_C; + size_t row_stride_c; + size_t col_stride_c; + size_t batch_stride_c; + + hipblasltDatatype_t d_type; + void* D; + void* const* batch_D; + size_t row_stride_d; + size_t col_stride_d; + size_t batch_stride_d; + + void* E; + void* const* batch_E; + size_t row_stride_e; + size_t col_stride_e; + size_t batch_stride_e; size_t batch_count; bool strided_batch; @@ -142,12 +118,12 @@ struct RocblasltContractionProblem rocblaslt_compute_type compute_type; const void* bias; - const Tc* scaleA; - const Tc* scaleB; - const Tc* scaleC; - const Tc* scaleD; - const Tc* scaleE; - const Tc* scaleAlphaVec; + const void* scaleA; + const void* scaleB; + const void* scaleC; + const void* scaleD; + const void* scaleE; + const void* scaleAlphaVec; hipblasltDatatype_t bias_type; rocblaslt_epilogue epilogue; void* workspace; @@ -162,26 +138,30 @@ struct RocblasltContractionProblem int64_t m, int64_t n, int64_t k, - const Tc* alpha, - const TiA* A, - const TiA* const* batch_A, + const void* alpha, + hipblasltDatatype_t a_type, + const void* A, + const void* const* batch_A, int64_t ld_a, int64_t batch_stride_a, - const TiB* B, - const TiB* const* batch_B, + hipblasltDatatype_t b_type, + const void* B, + const void* const* batch_B, int64_t ld_b, int64_t batch_stride_b, - const Tc* beta, - const To* C, - const To* const* batch_C, + const void* beta, + hipblasltDatatype_t c_type, + const void* C, + const void* const* batch_C, int64_t ld_c, int64_t batch_stride_c, - To* D, - To* const* batch_D, + hipblasltDatatype_t d_type, + void* D, + void* const* batch_D, int64_t ld_d, int64_t batch_stride_d, - Tc* E, - Tc* const* batch_E, + void* E, + void* const* batch_E, int64_t ld_e, int64_t batch_stride_e, int64_t batch_count, @@ -190,12 +170,12 @@ struct RocblasltContractionProblem bool gradient, rocblaslt_compute_type compute_type, const void* bias, - const Tc* scaleA, - const Tc* scaleB, - const Tc* scaleC, - const Tc* scaleD, - const Tc* scaleE, - const Tc* scaleAlphaVec, + const void* scaleA, + const void* scaleB, + const void* scaleC, + const void* scaleD, + const void* scaleE, + const void* scaleAlphaVec, hipblasltDatatype_t bias_type, rocblaslt_epilogue epilogue, void* workspace, @@ -207,22 +187,26 @@ struct RocblasltContractionProblem , n(n) , k(k) , alpha(alpha) + , a_type(a_type) , A(A) , batch_A(batch_A) , row_stride_a(1) , col_stride_a(ld_a) , batch_stride_a(batch_stride_a) + , b_type(b_type) , B(B) , batch_B(batch_B) , row_stride_b(1) , col_stride_b(ld_b) , batch_stride_b(batch_stride_b) , beta(beta) + , c_type(c_type) , C(C) , batch_C(batch_C) , row_stride_c(1) , col_stride_c(ld_c) , batch_stride_c(batch_stride_c) + , d_type(d_type) , D(D) , batch_D(batch_D) , row_stride_d(1) @@ -251,29 +235,24 @@ struct RocblasltContractionProblem , workspaceSize(workspaceSize) , stream(stream) { - // Tensile DataTypes corresponding to rocblaslt data types - static constexpr hipblasltDatatype_t dataType_TiA = hipblaslt_datatype; - static constexpr hipblasltDatatype_t dataType_TiB = hipblaslt_datatype; - static constexpr hipblasltDatatype_t dataType_To = hipblaslt_datatype; - static constexpr hipblasltDatatype_t dataType_Tc = hipblaslt_datatype; if(this->bias_type == HIPBLASLT_DATATYPE_INVALID) { - if((dataType_TiA == HIPBLASLT_R_8F_E4M3 && dataType_TiB == HIPBLASLT_R_16F) - || (dataType_TiA == HIPBLASLT_R_16F && dataType_TiB == HIPBLASLT_R_8F_E4M3)) + if((this->a_type == HIPBLASLT_R_8F_E4M3 && this->b_type == HIPBLASLT_R_16F) + || (this->a_type == HIPBLASLT_R_16F && this->b_type == HIPBLASLT_R_8F_E4M3)) { this->bias_type = HIPBLASLT_R_32F; } - else if(dataType_TiA == HIPBLASLT_R_8F_E4M3 || dataType_TiA == HIPBLASLT_R_8F_E5M2) + else if(this->a_type == HIPBLASLT_R_8F_E4M3 || this->a_type == HIPBLASLT_R_8F_E5M2) { this->bias_type = HIPBLASLT_R_16F; } - else if(dataType_Tc == HIPBLASLT_R_32I) + else if(this->compute_type == rocblaslt_compute_i32) { this->bias_type = HIPBLASLT_R_32F; } else { - this->bias_type = dataType_To; + this->bias_type = this->d_type; } } } @@ -294,22 +273,18 @@ void initTensileGemmData(rocblaslt_handle handle, /******************************************************************************* * runContractionProblem() solves a RocblasltContractionProblem * *******************************************************************************/ -template -rocblaslt_status runContractionProblem(rocblaslt_handle handle, - const rocblaslt_matmul_algo* algo, - RocblasltContractionProblem const& problem, - std::shared_ptr gemmData); - -template -rocblaslt_status gemmCreate(RocblasltContractionProblem const& problem, - std::shared_ptr& gemmData, - size_t& gemmCount); - -template -rocblaslt_status - groupedGemmCreate(std::vector>& probs, - std::shared_ptr& gemmData, - size_t& gemmCount); +rocblaslt_status runContractionProblem(rocblaslt_handle handle, + const rocblaslt_matmul_algo* algo, + RocblasltContractionProblem const& problem, + std::shared_ptr gemmData); + +rocblaslt_status gemmCreate(RocblasltContractionProblem const& problem, + std::shared_ptr& gemmData, + size_t& gemmCount); + +rocblaslt_status groupedGemmCreate(std::vector& probs, + std::shared_ptr& gemmData, + size_t& gemmCount); rocblaslt_status makeArgument(rocblaslt_handle handle, const rocblaslt::RocGemmType gemmType, @@ -360,15 +335,13 @@ inline bool& rocblaslt_suppress_tensile_error_messages() return t_suppress; } -template -rocblaslt_status getAllSolutions(RocblasltContractionProblem& prob, +rocblaslt_status getAllSolutions(RocblasltContractionProblem& prob, rocblaslt_handle handle, std::vector& heuristicResults, size_t maxWorkSpaceBytes); -template -rocblaslt_status getAllSolutions(std::vector>& probs, - rocblaslt_handle handle, +rocblaslt_status getAllSolutions(std::vector& probs, + rocblaslt_handle handle, std::vector& heuristicResults, size_t maxWorkSpaceBytes); @@ -378,12 +351,11 @@ rocblaslt_status std::vector& heuristicResults, size_t maxWorkSpaceBytes); -template -rocblaslt_status isSolutionSupported(rocblaslt_handle handle, - RocblasltContractionProblem& prob, - std::shared_ptr gemmData, - rocblaslt_matmul_algo* algo, - size_t* workspaceSizeInBytes); +rocblaslt_status isSolutionSupported(rocblaslt_handle handle, + RocblasltContractionProblem& prob, + std::shared_ptr gemmData, + rocblaslt_matmul_algo* algo, + size_t* workspaceSizeInBytes); rocblaslt_status isSolutionSupported(rocblaslt_handle handle, const rocblaslt::RocGemmType& gemmType, @@ -395,14 +367,13 @@ rocblaslt_status isSolutionSupported(rocblaslt_handle handle, * getBestSolutions() calls finTopSolutions from Tensile and converts to * * rocblaslt_matmul_heuristic_result * *******************************************************************************/ -template -rocblaslt_status getBestSolutions(RocblasltContractionProblem prob, - rocblaslt_handle handle, - std::shared_ptr gemmData, - int requestedAlgoCount, - rocblaslt_matmul_heuristic_result heuristicResultsArray[], - int* returnAlgoCount, - size_t maxWorkSpaceBytes); +rocblaslt_status getBestSolutions(RocblasltContractionProblem const& prob, + rocblaslt_handle handle, + std::shared_ptr gemmData, + int requestedAlgoCount, + rocblaslt_matmul_heuristic_result heuristicResultsArray[], + int* returnAlgoCount, + size_t maxWorkSpaceBytes); rocblaslt_status getBestSolutions(rocblaslt_handle handle, rocblaslt::RocGemmType gemmType, diff --git a/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp b/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp index 2979a9365f..f7c2914d0a 100644 --- a/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp +++ b/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp @@ -42,26 +42,44 @@ #define TO_STR2(x) #x #define TO_STR(x) TO_STR2(x) +inline void assignAlphaBeta1(const rocblaslt_compute_type& compute_type, void* alpha, void* beta) +{ + if(compute_type == rocblaslt_compute_f64) + { + *((double*)alpha) = 1.f; + *((double*)beta) = 1.f; + } + else if(compute_type == rocblaslt_compute_i32) + { + *((int32_t*)alpha) = 1.f; + *((int32_t*)beta) = 1.f; + } + else + { + *((float*)alpha) = 1.f; + *((float*)beta) = 1.f; + } +} + /****************************************************************************** * construct_rocblaslt_problem creates RocblasltContractionProblem from mat * * layout and descriptor for Tensile's findTopSolutions. * ******************************************************************************/ -template -RocblasltContractionProblem - construct_rocblaslt_problem(const rocblaslt_matmul_desc matmul_descr, - rocblaslt_matrix_layout matA, - rocblaslt_matrix_layout matB, - rocblaslt_matrix_layout matC, - rocblaslt_matrix_layout matD, - const Tc* alpha, - const Tc* beta, - size_t maxWorkSpaceBytes) +RocblasltContractionProblem construct_rocblaslt_problem(const rocblaslt_matmul_desc matmul_descr, + rocblaslt_matrix_layout matA, + rocblaslt_matrix_layout matB, + rocblaslt_matrix_layout matC, + rocblaslt_matrix_layout matD, + const void* alpha, + const void* beta, + size_t maxWorkSpaceBytes) { int8_t dummy; const void* dummy_ptr = &dummy; int64_t m, n, k, lda, ldb, ldc, ldd, lde, batch_stride_a, batch_stride_b, batch_stride_c, batch_stride_d, batch_stride_e; hipblasltDatatype_t bias_type; + hipblasltDatatype_t a_type, b_type, c_type, d_type; rocblaslt_compute_type compute_type; void * bias = nullptr, *scaleAlphaVec = nullptr, *e = nullptr; bool gradient = false; @@ -79,12 +97,16 @@ RocblasltContractionProblem m, n, k, + a_type, lda, batch_stride_a, + b_type, ldb, batch_stride_b, + c_type, ldc, batch_stride_c, + d_type, ldd, batch_stride_d, lde, @@ -117,54 +139,60 @@ RocblasltContractionProblem constexpr bool strided_batch = true; constexpr bool grouped_gemm = false; - Tc alpha_1 = 1.0; // use dScaleAlphaVec instead, original alpha => 1.0 + int8_t alpha_1[16] = {0}; // use dScaleAlphaVec instead, original alpha => 1.0 if(scaleAlphaVec) - alpha = &alpha_1; + { + setTo1(matmul_descr->compute_type, (void*)alpha_1, &alpha); + } - RocblasltContractionProblem problem{opA, - opB, - m, - n, - k, - alpha, - nullptr, - nullptr, - lda, - batch_stride_a, - nullptr, - nullptr, - ldb, - batch_stride_b, - beta, - nullptr, - nullptr, - ldc, - batch_stride_c, - nullptr, - nullptr, - ldd, - batch_stride_d, - (Tc*)e, - nullptr, - lde, - batch_stride_e, - num_batches_a, - strided_batch, - grouped_gemm, - gradient, - compute_type, - bias, - (const Tc*)scaleA, - (const Tc*)scaleB, - (const Tc*)scaleC, - (const Tc*)scaleD, - (const Tc*)scaleE, - (const Tc*)scaleAlphaVec, - bias_type, - epilogue, - nullptr, - maxWorkSpaceBytes, - nullptr}; + RocblasltContractionProblem problem{opA, + opB, + m, + n, + k, + alpha, + a_type, + nullptr, + nullptr, + lda, + batch_stride_a, + b_type, + nullptr, + nullptr, + ldb, + batch_stride_b, + beta, + c_type, + nullptr, + nullptr, + ldc, + batch_stride_c, + d_type, + nullptr, + nullptr, + ldd, + batch_stride_d, + e, + nullptr, + lde, + batch_stride_e, + num_batches_a, + strided_batch, + grouped_gemm, + gradient, + compute_type, + bias, + scaleA, + scaleB, + scaleC, + scaleD, + scaleE, + scaleAlphaVec, + bias_type, + epilogue, + nullptr, + maxWorkSpaceBytes, + nullptr}; return problem; } @@ -1111,426 +1139,12 @@ rocblaslt_status rocblaslt_matmul_is_algo_supported(rocblaslt_handle hand hipblasltDatatype_t d_type = matD->type; rocblaslt_compute_type compute_type = matmul_descr->compute_type; auto& gemmData = matmul_descr->m_data; - if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32 - || compute_type == rocblaslt_compute_f32_fast_xf32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem( - matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_16F) - { - if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob - = construct_rocblaslt_problem( - matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_16B && b_type == HIPBLASLT_R_16B) - { - if(c_type == HIPBLASLT_R_16B && d_type == HIPBLASLT_R_16B) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_8F_E4M3) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob - = construct_rocblaslt_problem( - matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_8F_E5M2) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob - = construct_rocblaslt_problem( - matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E5M2 && b_type == HIPBLASLT_R_8F_E4M3) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob - = construct_rocblaslt_problem( - matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_8I && b_type == HIPBLASLT_R_8I) //INT8 - { - if(c_type == HIPBLASLT_R_8I && d_type == HIPBLASLT_R_8I) - { - if(compute_type == rocblaslt_compute_i32) - { - int32_t* alphaf = (int32_t*)alpha; - int32_t* betaf = (int32_t*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_32I && d_type == HIPBLASLT_R_32I) - { - if(compute_type == rocblaslt_compute_i32) - { - int32_t* alphaf = (int32_t*)alpha; - int32_t* betaf = (int32_t*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_8F_E4M3) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob - = construct_rocblaslt_problem( - matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_16F) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob = construct_rocblaslt_problem(matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status - = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float* alphaf = (float*)alpha; - float* betaf = (float*)beta; - auto prob - = construct_rocblaslt_problem( - matmul_descr, - matA, - matB, - matC, - matD, - alphaf, - betaf, - algo->max_workspace_bytes); - status = isSolutionSupported( - handle, prob, gemmData, algo, workspaceSizeInBytes); - } - } - } - else - { - log_error(__func__, "No such template."); - status = rocblaslt_status_not_implemented; - } + + void* alphaf = (void*)alpha; + void* betaf = (void*)beta; + auto prob = construct_rocblaslt_problem( + matmul_descr, matA, matB, matC, matD, alphaf, betaf, algo->max_workspace_bytes); + status = isSolutionSupported(handle, prob, gemmData, algo, workspaceSizeInBytes); if(status != rocblaslt_status_success) { @@ -1581,581 +1195,20 @@ rocblaslt_status hipblasltDatatype_t d_type = matD->type; rocblaslt_compute_type compute_type = matmul_desc->compute_type; auto& tensile_data = matmul_desc->m_data; - if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32 - || compute_type == rocblaslt_compute_f32_fast_xf32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status - = getBestSolutions(prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_16F) - { - if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status - = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_16B && b_type == HIPBLASLT_R_16B) - { - if(c_type == HIPBLASLT_R_16B && d_type == HIPBLASLT_R_16B) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions(prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_8F_E4M3) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_8F_E5M2) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_64F && b_type == HIPBLASLT_R_64F) - { - if(c_type == HIPBLASLT_R_64F && d_type == HIPBLASLT_R_64F) - { - if(compute_type == rocblaslt_compute_f64) - { - double alpha = 1.0; - double beta = 1.0; - auto prob = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E5M2 && b_type == HIPBLASLT_R_8F_E4M3) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_8I && b_type == HIPBLASLT_R_8I) //int8 support - { - if(c_type == HIPBLASLT_R_32I && d_type == HIPBLASLT_R_32I) - { - if(compute_type == rocblaslt_compute_i32) - { - int32_t alpha = 1; - int32_t beta = 1; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_8I && d_type == HIPBLASLT_R_8I) - { - if(compute_type == rocblaslt_compute_i32) - { - int32_t alpha = 1; - int32_t beta = 1; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_8F_E4M3) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_16F) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem(matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - matmul_desc, - matA, - matB, - matC, - matD, - &alpha, - &beta, - pref->max_workspace_bytes); - status = getBestSolutions( - prob, - handle, - tensile_data, - requestedAlgoCount, - heuristicResultsArray, - returnAlgoCount, - pref->max_workspace_bytes); - } - } - } - else - { - log_error(__func__, "No such template."); - status = rocblaslt_status_not_implemented; - } + + int8_t alpha[16] = {0}; + int8_t beta[16] = {0}; + assignAlphaBeta1(compute_type, (void*)alpha, (void*)beta); + + auto prob = construct_rocblaslt_problem( + matmul_desc, matA, matB, matC, matD, &alpha, &beta, pref->max_workspace_bytes); + status = getBestSolutions(prob, + handle, + tensile_data, + requestedAlgoCount, + heuristicResultsArray, + returnAlgoCount, + pref->max_workspace_bytes); log_api(__func__, "returnAlogCount", *returnAlgoCount); if(status != rocblaslt_status_success) @@ -2240,453 +1293,24 @@ rocblaslt_status rocblaslt_matmul_get_all_algos_cpp( size_t maxWorkspaceSize = std::numeric_limits::max(); try { - if(typeA == HIPBLASLT_R_32F && typeB == HIPBLASLT_R_32F) - { - if(typeC == HIPBLASLT_R_32F && typeD == HIPBLASLT_R_32F) - { - if(typeCompute == rocblaslt_compute_f32 - || typeCompute == rocblaslt_compute_f32_fast_xf32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> probs - = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - } - else if(typeA == HIPBLASLT_R_16F && typeB == HIPBLASLT_R_16F) - { - if(typeC == HIPBLASLT_R_16F && typeD == HIPBLASLT_R_16F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - else if(typeC == HIPBLASLT_R_32F && typeD == HIPBLASLT_R_32F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - &matmul_desc, - &matA, - &matB, - &matC, - &matD, - &alpha, - &beta, - maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - } - else if(typeA == HIPBLASLT_R_16B && typeB == HIPBLASLT_R_16B) - { - if(typeC == HIPBLASLT_R_16B && typeD == HIPBLASLT_R_16B) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - } - else if(typeA == HIPBLASLT_R_8F_E4M3 && typeB == HIPBLASLT_R_8F_E4M3) - { - if(typeC == HIPBLASLT_R_32F && typeD == HIPBLASLT_R_32F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - &matmul_desc, - &matA, - &matB, - &matC, - &matD, - &alpha, - &beta, - maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector< - RocblasltContractionProblem> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - else if(typeC == HIPBLASLT_R_16F && typeD == HIPBLASLT_R_16F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - } - else if(typeA == HIPBLASLT_R_8F_E4M3 && typeB == HIPBLASLT_R_8F_E5M2) - { - if(typeC == HIPBLASLT_R_32F && typeD == HIPBLASLT_R_32F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - &matmul_desc, - &matA, - &matB, - &matC, - &matD, - &alpha, - &beta, - maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector< - RocblasltContractionProblem> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - else if(typeC == HIPBLASLT_R_16F && typeD == HIPBLASLT_R_16F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status - = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status - = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - } - else if(typeA == HIPBLASLT_R_8F_E5M2 && typeB == HIPBLASLT_R_8F_E4M3) + int8_t alpha[16] = {0}; + int8_t beta[16] = {0}; + assignAlphaBeta1(matmul_desc.compute_type, (void*)alpha, (void*)beta); + + auto prob = construct_rocblaslt_problem( + &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); + if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) { - if(typeC == HIPBLASLT_R_32F && typeD == HIPBLASLT_R_32F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob - = construct_rocblaslt_problem( - &matmul_desc, - &matA, - &matB, - &matC, - &matD, - &alpha, - &beta, - maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector< - RocblasltContractionProblem> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - else if(typeC == HIPBLASLT_R_16F && typeD == HIPBLASLT_R_16F) - { - if(typeCompute == rocblaslt_compute_f32) - { - float alpha = 1.0; - float beta = 1.0; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status - = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status - = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } + status = getAllSolutions(prob, handle, heuristicResults, maxWorkspaceSize); } - else if(typeA == HIPBLASLT_R_8I && typeB == HIPBLASLT_R_8I) //int8 + else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) { - if(typeC == HIPBLASLT_R_32I && typeD == HIPBLASLT_R_32I) - { - if(typeCompute == rocblaslt_compute_i32) - { - int32_t alpha = 1; - int32_t beta = 1; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } - else if(typeC == HIPBLASLT_R_8I && typeD == HIPBLASLT_R_8I) - { - if(typeCompute == rocblaslt_compute_i32) - { - int32_t alpha = 1; - int32_t beta = 1; - auto prob = construct_rocblaslt_problem( - &matmul_desc, &matA, &matB, &matC, &matD, &alpha, &beta, maxWorkspaceSize); - if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GEMM) - { - status - = getAllSolutions( - prob, handle, heuristicResults, maxWorkspaceSize); - } - else if(typeGemm == rocblaslt::RocGemmType::ROCBLASLT_GROUPED_GEMM) - { - std::vector> - probs = {prob}; - status - = getAllSolutions( - probs, handle, heuristicResults, maxWorkspaceSize); - } - else - { - log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); - status = rocblaslt_status_not_implemented; - } - } - } + std::vector probs = {prob}; + status = getAllSolutions(probs, handle, heuristicResults, maxWorkspaceSize); } else { - log_error(__func__, "No such template."); + log_api(__func__, "Invalid gemm type", static_cast(typeGemm)); status = rocblaslt_status_not_implemented; } diff --git a/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp b/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp index 1710c823fb..d2d648d3e2 100644 --- a/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp +++ b/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp @@ -24,10 +24,10 @@ * * ************************************************************************ */ -#include "rocblaslt_mat.hpp" #include "definitions.h" #include "handle.h" #include "rocblaslt_mat_utils.hpp" +#include "tensile_host.hpp" #include @@ -59,6 +59,7 @@ rocblaslt_status rocblaslt_matmul_impl(const rocblaslt_handle handle, int64_t m, n, k, lda, ldb, ldc, ldd, lde, batch_stride_a, batch_stride_b, batch_stride_c, batch_stride_d, batch_stride_e; hipblasltDatatype_t bias_type; + hipblasltDatatype_t type_a, type_b, type_c, type_d; rocblaslt_compute_type compute_type; void * bias = nullptr, *scaleAlphaVec = nullptr, *E = nullptr; bool gradient = false; @@ -76,12 +77,16 @@ rocblaslt_status rocblaslt_matmul_impl(const rocblaslt_handle handle, m, n, k, + type_a, lda, batch_stride_a, + type_b, ldb, batch_stride_b, + type_c, ldc, batch_stride_c, + type_d, ldd, batch_stride_d, lde, @@ -96,19 +101,15 @@ rocblaslt_status rocblaslt_matmul_impl(const rocblaslt_handle handle, return isValid; // Internal assign - hipblasOperation_t opA = matmul_descr->op_A; - hipblasOperation_t opB = matmul_descr->op_B; - hipblasltDatatype_t type_a = matA->type; - hipblasltDatatype_t type_b = matB->type; - hipblasltDatatype_t type_c = matC->type; - hipblasltDatatype_t type_d = matD->type; - int num_batches_a = matA->batch_count; - rocblaslt_epilogue epilogue = matmul_descr->epilogue; - void* scaleA = matmul_descr->scaleA; - void* scaleB = matmul_descr->scaleB; - void* scaleC = matmul_descr->scaleC; - void* scaleD = matmul_descr->scaleD; - void* scaleE = matmul_descr->scaleE; + hipblasOperation_t opA = matmul_descr->op_A; + hipblasOperation_t opB = matmul_descr->op_B; + int num_batches_a = matA->batch_count; + rocblaslt_epilogue epilogue = matmul_descr->epilogue; + void* scaleA = matmul_descr->scaleA; + void* scaleB = matmul_descr->scaleB; + void* scaleC = matmul_descr->scaleC; + void* scaleD = matmul_descr->scaleD; + void* scaleE = matmul_descr->scaleE; // Others bool strided_batch = true; @@ -136,14 +137,65 @@ rocblaslt_status rocblaslt_matmul_impl(const rocblaslt_handle handle, } } -#define EX_PARM \ - handle, opA, opB, m, n, k, alpha, A, type_a, lda, batch_stride_a, B, type_b, ldb, \ - batch_stride_b, beta, C, type_c, ldc, batch_stride_c, D, type_d, ldd, batch_stride_d, E, \ - lde, batch_stride_e, num_batches_a, strided_batch, grouped_gemm, gradient, compute_type, \ - algo, workspace, workspaceSizeInBytes, bias, scaleA, scaleB, scaleC, scaleD, scaleE, \ - scaleAlphaVec, bias_type, epilogue, gemmData, stream - - return rocblaslt_matmul_template(EX_PARM); + // FIXME: Is this still needed? + // // check alignment of pointers before casting + // if(!isAligned(a, sizeof(TiA)) || !isAligned(b, sizeof(TiB)) || !isAligned(c, sizeof(To)) + // || !isAligned(d, sizeof(To))) + // { + // std::cerr << "memmory is not aligned" << std::endl; + // return rocblaslt_status_invalid_size; + // } + + workspaceSizeInBytes = min(workspaceSizeInBytes, algo->max_workspace_bytes); + RocblasltContractionProblem problem{opA, + opB, + m, + n, + k, + alpha, + type_a, + A, + nullptr, + lda, + batch_stride_a, + type_b, + B, + nullptr, + ldb, + batch_stride_b, + beta, + type_c, + C, + nullptr, + ldc, + batch_stride_c, + type_d, + D, + nullptr, + ldd, + batch_stride_d, + E, + nullptr, + lde, + batch_stride_e, + num_batches_a, + strided_batch, + grouped_gemm, + gradient, + compute_type, + bias, + scaleA, + scaleB, + scaleC, + scaleD, + scaleE, + scaleAlphaVec, + bias_type, + epilogue, + workspace, + workspaceSizeInBytes, + stream}; + return runContractionProblem(handle, algo, problem, gemmData); } rocblaslt_status rocblaslt_gemm_create_cpp_impl(rocblaslt_matmul_desc matmul_descr, @@ -164,6 +216,7 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl(rocblaslt_matmul_desc m int64_t m, n, k, lda, ldb, ldc, ldd, lde, batch_stride_a, batch_stride_b, batch_stride_c, batch_stride_d, batch_stride_e; hipblasltDatatype_t bias_type; + hipblasltDatatype_t type_a, type_b, type_c, type_d; rocblaslt_compute_type compute_type; void * bias = nullptr, *scaleAlphaVec = nullptr, *E = nullptr; bool gradient = false; @@ -181,12 +234,16 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl(rocblaslt_matmul_desc m m, n, k, + type_a, lda, batch_stride_a, + type_b, ldb, batch_stride_b, + type_c, ldc, batch_stride_c, + type_d, ldd, batch_stride_d, lde, @@ -201,19 +258,15 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl(rocblaslt_matmul_desc m return isValid; // Internal assign - hipblasOperation_t opA = matmul_descr->op_A; - hipblasOperation_t opB = matmul_descr->op_B; - hipblasltDatatype_t type_a = matA->type; - hipblasltDatatype_t type_b = matB->type; - hipblasltDatatype_t type_c = matC->type; - hipblasltDatatype_t type_d = matD->type; - int num_batches_a = matA->batch_count; - rocblaslt_epilogue epilogue = matmul_descr->epilogue; - void* scaleA = matmul_descr->scaleA; - void* scaleB = matmul_descr->scaleB; - void* scaleC = matmul_descr->scaleC; - void* scaleD = matmul_descr->scaleD; - void* scaleE = matmul_descr->scaleE; + hipblasOperation_t opA = matmul_descr->op_A; + hipblasOperation_t opB = matmul_descr->op_B; + int num_batches_a = matA->batch_count; + rocblaslt_epilogue epilogue = matmul_descr->epilogue; + void* scaleA = matmul_descr->scaleA; + void* scaleB = matmul_descr->scaleB; + void* scaleC = matmul_descr->scaleC; + void* scaleD = matmul_descr->scaleD; + void* scaleE = matmul_descr->scaleE; // Others bool strided_batch = true; @@ -222,31 +275,66 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl(rocblaslt_matmul_desc m int8_t alpha_1[16] = {0}; // use dScaleAlphaVec instead, original alpha => 1.0 if(scaleAlphaVec) { - if(matmul_descr->compute_type == rocblaslt_compute_f64) - { - *((double*)alpha_1) = 1.f; - alpha = alpha_1; - } - else if(matmul_descr->compute_type == rocblaslt_compute_i32) - { - *((int32_t*)alpha_1) = 1.f; - alpha = alpha_1; - } - else - { - *((float*)alpha_1) = 1.f; - alpha = alpha_1; - } + setTo1(matmul_descr->compute_type, (void*)alpha_1, &alpha); } -#define EX_PARM_GEMM_CPP \ - opA, opB, m, n, k, alpha, A, type_a, lda, batch_stride_a, B, type_b, ldb, batch_stride_b, \ - beta, C, type_c, ldc, batch_stride_c, D, type_d, ldd, batch_stride_d, E, lde, \ - batch_stride_e, num_batches_a, strided_batch, grouped_gemm, gradient, compute_type, bias, \ - scaleA, scaleB, scaleC, scaleD ,scaleE, scaleAlphaVec, bias_type, epilogue, gemmData, \ - gemmCount - - return rocblaslt_gemm_create_template_cpp(EX_PARM_GEMM_CPP); + // // check alignment of pointers before casting + // if(!isAligned(a, sizeof(TiA)) || !isAligned(b, sizeof(TiB)) || !isAligned(c, sizeof(To)) + // || !isAligned(d, sizeof(To))) + // { + // std::cerr << "memmory is not aligned" << std::endl; + // return rocblaslt_status_invalid_size; + // } + + RocblasltContractionProblem problem{opA, + opB, + m, + n, + k, + alpha, + type_a, + A, + nullptr, + lda, + batch_stride_a, + type_b, + B, + nullptr, + ldb, + batch_stride_b, + beta, + type_c, + C, + nullptr, + ldc, + batch_stride_c, + type_d, + D, + nullptr, + ldd, + batch_stride_d, + E, + nullptr, + lde, + batch_stride_e, + num_batches_a, + strided_batch, + grouped_gemm, + gradient, + compute_type, + bias, + scaleA, + scaleB, + scaleC, + scaleD, + scaleE, + scaleAlphaVec, + bias_type, + epilogue, + nullptr, + 0, + 0}; + return gemmCreate(problem, gemmData, gemmCount); } rocblaslt_status rocblaslt_gemm_create_cpp_impl_2(int64_t m, @@ -338,22 +426,69 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl_2(int64_t bool strided_batch = true; bool grouped_gemm = false; - problemtype.op_a = opA; - problemtype.op_b = opB; - problemtype.type_a = type_a; - problemtype.type_b = type_b; - problemtype.type_c = type_c; - problemtype.type_d = type_d; - problemtype.type_compute = compute_type; - -#define EX_PARM_GEMM_CPP_2 \ - opA, opB, m, n, k, alpha, A, type_a, lda, batch_stride_a, B, type_b, ldb, batch_stride_b, \ - beta, C, type_c, ldc, batch_stride_c, D, type_d, ldd, batch_stride_d, E, lde, \ - batch_stride_e, num_batches_a, strided_batch, grouped_gemm, gradient, compute_type, bias, \ - scaleA, scaleB, scaleC, scaleD, scaleE, scaleAlphaVec, bias_type, epilogue, gemmData, \ - gemmCount - - return rocblaslt_gemm_create_template_cpp(EX_PARM_GEMM_CPP_2); + int8_t alpha_1[16] = {0}; // use dScaleAlphaVec instead, original alpha => 1.0 + if(scaleAlphaVec) + { + setTo1(compute_type, (void*)alpha_1, (const void**)&alpha); + } + + // // check alignment of pointers before casting + // if(!isAligned(a, sizeof(TiA)) || !isAligned(b, sizeof(TiB)) || !isAligned(c, sizeof(To)) + // || !isAligned(d, sizeof(To))) + // { + // std::cerr << "memmory is not aligned" << std::endl; + // return rocblaslt_status_invalid_size; + // } + + RocblasltContractionProblem problem{opA, + opB, + m, + n, + k, + alpha, + type_a, + A, + nullptr, + lda, + batch_stride_a, + type_b, + B, + nullptr, + ldb, + batch_stride_b, + beta, + type_c, + C, + nullptr, + ldc, + batch_stride_c, + type_d, + D, + nullptr, + ldd, + batch_stride_d, + E, + nullptr, + lde, + batch_stride_e, + num_batches_a, + strided_batch, + grouped_gemm, + gradient, + compute_type, + bias, + scaleA, + scaleB, + scaleC, + scaleD, + scaleE, + scaleAlphaVec, + bias_type, + epilogue, + nullptr, + 0, + 0}; + return gemmCreate(problem, gemmData, gemmCount); } rocblaslt_status @@ -397,6 +532,7 @@ rocblaslt_status std::vector ldc_vec, batch_stride_c_vec, num_batches_c_vec; std::vector ldd_vec, batch_stride_d_vec, num_batches_d_vec; std::vector lde_vec, batch_stride_e_vec, num_batches_e_vec; + std::vector alpha_1(matmul_descr.size()); std::vector gradient_vec; @@ -485,6 +621,17 @@ rocblaslt_status if(validArgs != rocblaslt_status_continue) return validArgs; + const void* alphaTmp = nullptr; + memset(alpha_1[i], 0, sizeof(int8_t) * 16); + if(scaleAlphaVec) + { + setTo1(compute_type, (void*)alpha_1[i], &alphaTmp); + } + else + { + alphaTmp = alpha[i]; + } + tempprobemtype.push_back({matmul_descr[i]->op_A, matmul_descr[i]->op_B, matA[i]->type, @@ -540,7 +687,7 @@ rocblaslt_status C_vec.push_back(C[i]); D_vec.push_back(D[i]); E_vec.push_back(E); - alpha_vec.push_back(alpha[i]); + alpha_vec.push_back(alphaTmp); beta_vec.push_back(beta[i]); gradient_vec.push_back(gradient); @@ -551,15 +698,59 @@ rocblaslt_status bool strided_batch = true; bool grouped_gemm = true; -#define EX_PARM_GroupedGemm_CPP \ - opA, opB, m_vec, n_vec, k_vec, alpha_vec, A_vec, type_a, lda_vec, batch_stride_a_vec, B_vec, \ - type_b, ldb_vec, batch_stride_b_vec, beta_vec, C_vec, type_c, ldc_vec, batch_stride_c_vec, \ - D_vec, type_d, ldd_vec, batch_stride_d_vec, E_vec, lde_vec, batch_stride_e_vec, \ - num_batches_a_vec, strided_batch, grouped_gemm, gradient_vec, compute_type, bias_vec, \ - scaleA_vec, scaleB_vec, scaleC_vec, scaleD_vec, scaleE_vec, scaleAlpha_vec, bias_type_vec, \ - epilogue_vec, gemmData, gemmCount - - return rocblaslt_groupedgemm_create_template_cpp(EX_PARM_GroupedGemm_CPP); + std::vector problems; + for(int i = 0; i < m_vec.size(); i++) + { + problems.push_back(RocblasltContractionProblem{opA, + opB, + m_vec[i], + n_vec[i], + k_vec[i], + alpha_vec[i], + type_a, + A_vec[i], + nullptr, + lda_vec[i], + batch_stride_a_vec[i], + type_b, + B_vec[i], + nullptr, + ldb_vec[i], + batch_stride_b_vec[i], + beta_vec[i], + type_c, + C_vec[i], + nullptr, + ldc_vec[i], + batch_stride_c_vec[i], + type_d, + D_vec[i], + nullptr, + ldd_vec[i], + batch_stride_d_vec[i], + E_vec[i], + nullptr, + lde_vec[i], + batch_stride_e_vec[i], + num_batches_a_vec[i], + strided_batch, + grouped_gemm, + gradient_vec[i], + compute_type, + bias_vec[i], + scaleA_vec[i], + scaleB_vec[i], + scaleC_vec[i], + scaleD_vec[i], + scaleE_vec[i], + scaleAlpha_vec[i], + bias_type_vec[i], + epilogue_vec[i], + nullptr, + 0, + 0}); + } + return groupedGemmCreate(problems, gemmData, gemmCount); } rocblaslt_status @@ -604,6 +795,8 @@ rocblaslt_status std::vector lde_vec, batch_stride_e_vec, num_batches_e_vec; std::vector gradient_vec; + std::vector alpha_1(m.size()); + for(int i = 0; i < m.size(); i++) { auto validArgs = validateMatmulArgs(m[i], @@ -658,6 +851,17 @@ rocblaslt_status if(validArgs != rocblaslt_status_continue) return validArgs; + const void* alphaTmp = nullptr; + memset(alpha_1[i], 0, sizeof(int8_t) * 16); + if(scaleAlphaVec) + { + setTo1(compute_type, (void*)alpha_1[i], &alphaTmp); + } + else + { + alphaTmp = inputs[i].alpha; + } + bias_type_vec.push_back(bias_type); epilogue_vec.push_back(epilogue); bias_vec.push_back(bias); @@ -673,7 +877,7 @@ rocblaslt_status C_vec.push_back(inputs[i].c); D_vec.push_back(inputs[i].d); E_vec.push_back(E); - alpha_vec.push_back(inputs[i].alpha); + alpha_vec.push_back(alphaTmp); beta_vec.push_back(inputs[i].beta); lde_vec.push_back(lde); @@ -684,14 +888,59 @@ rocblaslt_status bool strided_batch = true; bool grouped_gemm = true; -#define EX_PARM_GroupedGemm_CPP_2 \ - opA, opB, m, n, k, alpha_vec, A_vec, type_a, lda, strideA, B_vec, type_b, ldb, strideB, \ - beta_vec, C_vec, type_c, ldc, strideC, D_vec, type_d, ldd, strideD, E_vec, lde_vec, \ - batch_stride_e_vec, b, strided_batch, grouped_gemm, gradient_vec, compute_type, bias_vec, \ - scaleA_vec, scaleB_vec, scaleC_vec, scaleD_vec, scaleE_vec, scaleAlpha_vec, bias_type_vec,\ - epilogue_vec, gemmData, gemmCount - - return rocblaslt_groupedgemm_create_template_cpp(EX_PARM_GroupedGemm_CPP_2); + std::vector problems; + for(int i = 0; i < m.size(); i++) + { + problems.push_back(RocblasltContractionProblem{opA, + opB, + m[i], + n[i], + k[i], + alpha_vec[i], + type_a, + A_vec[i], + nullptr, + lda[i], + strideA[i], + type_b, + B_vec[i], + nullptr, + ldb[i], + strideB[i], + beta_vec[i], + type_c, + C_vec[i], + nullptr, + ldc[i], + strideC[i], + type_d, + D_vec[i], + nullptr, + ldd[i], + strideD[i], + E_vec[i], + nullptr, + lde_vec[i], + batch_stride_e_vec[i], + b[i], + strided_batch, + grouped_gemm, + gradient_vec[i], + compute_type, + bias_vec[i], + scaleA_vec[i], + scaleB_vec[i], + scaleC_vec[i], + scaleD_vec[i], + scaleE_vec[i], + scaleAlpha_vec[i], + bias_type_vec[i], + epilogue_vec[i], + nullptr, + 0, + 0}); + } + return groupedGemmCreate(problems, gemmData, gemmCount); } /******************************************************************************** diff --git a/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp b/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp deleted file mode 100644 index 8dbb8e28b4..0000000000 --- a/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.hpp +++ /dev/null @@ -1,1331 +0,0 @@ -/* ************************************************************************ - * - * MIT License - * - * Copyright (C) 2022-2023 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - * ************************************************************************ */ - -#pragma once -#ifndef ROCBLASLT_MAT_HPP -#define ROCBLASLT_MAT_HPP - -#include "handle.h" - -#include "tensile_host.hpp" - -template -rocblaslt_status rocblaslt_batched_template(rocblaslt_handle handle, - hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int64_t m, - int64_t n, - int64_t k, - const Tc* alpha, - const TiA* a, - int64_t ld_a, - int64_t batch_stride_a, - const TiB* b, - int64_t ld_b, - int64_t batch_stride_b, - const Tc* beta, - const To* c, - int64_t ld_c, - int64_t batch_stride_c, - To* d, - int64_t ld_d, - int64_t batch_stride_d, - Tc* e, - int64_t ld_e, - int64_t batch_stride_e, - int64_t batch_count, - bool strided_batch, - bool grouped_gemm, - bool gradient, - rocblaslt_compute_type compute_type, - const rocblaslt_matmul_algo* algo, - void* workspace, - size_t workspaceSizeInBytes, - const void* bias, - const Tc* scaleA, - const Tc* scaleB, - const Tc* scaleC, - const Tc* scaleD, - const Tc* scaleE, - const Tc* scaleAlphaVec, - hipblasltDatatype_t bias_type, - rocblaslt_epilogue epilogue, - std::shared_ptr gemmData, - hipStream_t stream) -{ - workspaceSizeInBytes = min(workspaceSizeInBytes, algo->max_workspace_bytes); - RocblasltContractionProblem problem{trans_a, - trans_b, - m, - n, - k, - alpha, - a, - nullptr, - ld_a, - batch_stride_a, - b, - nullptr, - ld_b, - batch_stride_b, - beta, - c, - nullptr, - ld_c, - batch_stride_c, - d, - nullptr, - ld_d, - batch_stride_d, - e, - nullptr, - ld_e, - batch_stride_e, - batch_count, - strided_batch, - grouped_gemm, - gradient, - compute_type, - bias, - scaleA, - scaleB, - scaleC, - scaleD, - scaleE, - scaleAlphaVec, - bias_type, - epilogue, - workspace, - workspaceSizeInBytes, - stream}; - return runContractionProblem(handle, algo, problem, gemmData); -} - -template -rocblaslt_status rocblaslt_gemm_create_batched_template(hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int64_t m, - int64_t n, - int64_t k, - const Tc* alpha, - const TiA* a, - int64_t ld_a, - int64_t batch_stride_a, - const TiB* b, - int64_t ld_b, - int64_t batch_stride_b, - const Tc* beta, - const To* c, - int64_t ld_c, - int64_t batch_stride_c, - To* d, - int64_t ld_d, - int64_t batch_stride_d, - Tc* e, - int64_t ld_e, - int64_t batch_stride_e, - int64_t batch_count, - bool strided_batch, - bool grouped_gemm, - bool gradient, - rocblaslt_compute_type compute_type, - const void* bias, - const Tc* scaleA, - const Tc* scaleB, - const Tc* scaleC, - const Tc* scaleD, - const Tc* scaleE, - const Tc* scaleAlphaVec, - hipblasltDatatype_t bias_type, - rocblaslt_epilogue epilogue, - std::shared_ptr& gemmData, - size_t& gemmCount) -{ - Tc alpha_1 = 1.0; // use dScaleAlphaVec instead, original alpha => 1.0 - if(scaleAlphaVec) - alpha = &alpha_1; - RocblasltContractionProblem problem{trans_a, - trans_b, - m, - n, - k, - alpha, - a, - nullptr, - ld_a, - batch_stride_a, - b, - nullptr, - ld_b, - batch_stride_b, - beta, - c, - nullptr, - ld_c, - batch_stride_c, - d, - nullptr, - ld_d, - batch_stride_d, - e, - nullptr, - ld_e, - batch_stride_e, - batch_count, - strided_batch, - grouped_gemm, - gradient, - compute_type, - bias, - scaleA, - scaleB, - scaleC, - scaleD, - scaleE, - scaleAlphaVec, - bias_type, - epilogue, - nullptr, - 0, - 0}; - return gemmCreate(problem, gemmData, gemmCount); -} - -template -rocblaslt_status - rocblaslt_groupedgemm_create_batched_template(hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - std::vector& m, - std::vector& n, - std::vector& k, - std::vector& alpha, - std::vector& a, - std::vector& ld_a, - std::vector& batch_stride_a, - std::vector& b, - std::vector& ld_b, - std::vector& batch_stride_b, - std::vector& beta, - std::vector& c, - std::vector& ld_c, - std::vector& batch_stride_c, - std::vector& d, - std::vector& ld_d, - std::vector& batch_stride_d, - std::vector& e, - std::vector& ld_e, - std::vector& batch_stride_e, - std::vector& batch_count, - bool strided_batch, - bool grouped_gemm, - rocblaslt_compute_type compute_type, - std::vector& gradient, - std::vector& bias, - std::vector& scaleAVec, - std::vector& scaleBVec, - std::vector& scaleCVec, - std::vector& scaleDVec, - std::vector& scaleEVec, - std::vector& scaleAlphaVec, - std::vector& bias_type, - std::vector& epilogue, - std::shared_ptr& gemmData, - size_t& gemmCount) -{ - std::vector> problems; - - for(int i = 0; i < m.size(); i++) - { - problems.push_back(RocblasltContractionProblem{trans_a, - trans_b, - m[i], - n[i], - k[i], - alpha[i], - a[i], - nullptr, - ld_a[i], - batch_stride_a[i], - b[i], - nullptr, - ld_b[i], - batch_stride_b[i], - beta[i], - c[i], - nullptr, - ld_c[i], - batch_stride_c[i], - d[i], - nullptr, - ld_d[i], - batch_stride_d[i], - (Tc*)e[i], - nullptr, - ld_e[i], - batch_stride_e[i], - batch_count[i], - strided_batch, - grouped_gemm, - gradient[i], - compute_type, - bias[i], - scaleAVec[i], - scaleBVec[i], - scaleCVec[i], - scaleDVec[i], - scaleEVec[i], - scaleAlphaVec[i], - bias_type[i], - epilogue[i], - nullptr, - 0, - 0}); - } - return groupedGemmCreate(problems, gemmData, gemmCount); -} - -template -rocblaslt_status rocblaslt_matmul_typecasting(rocblaslt_handle handle, - hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int64_t m, - int64_t n, - int64_t k, - const void* alpha, - const void* a, - int64_t ld_a, - int64_t batch_stride_a, - const void* b, - int64_t ld_b, - int64_t batch_stride_b, - const void* beta, - const void* c, - int64_t ld_c, - int64_t batch_stride_c, - void* d, - int64_t ld_d, - int64_t batch_stride_d, - void* e, - int64_t ld_e, - int64_t batch_stride_e, - int64_t batch_count, - bool strided_batch, - bool grouped_gemm, - bool gradient, - rocblaslt_compute_type compute_type, - const rocblaslt_matmul_algo* algo, - void* workspace, - size_t workspaceSizeInBytes, - const void* bias, - const void* scaleA, - const void* scaleB, - const void* scaleC, - const void* scaleD, - const void* scaleE, - const void* scaleAlphaVec, - hipblasltDatatype_t bias_type, - rocblaslt_epilogue epilogue, - std::shared_ptr gemmData, - hipStream_t stream) -{ - // check alignment of pointers before casting - if(!isAligned(a, sizeof(TiA)) || !isAligned(b, sizeof(TiB)) || !isAligned(c, sizeof(To)) - || !isAligned(d, sizeof(To))) - { - std::cerr << "memmory is not aligned" << std::endl; - return rocblaslt_status_invalid_size; - } - return rocblaslt_batched_template(handle, - trans_a, - trans_b, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(a), - ld_a, - batch_stride_a, - reinterpret_cast(b), - ld_b, - batch_stride_b, - reinterpret_cast(beta), - reinterpret_cast(c), - ld_c, - batch_stride_c, - (To*)d, - ld_d, - batch_stride_d, - (Tc*)e, - ld_e, - batch_stride_e, - batch_count, - strided_batch, - grouped_gemm, - gradient, - compute_type, - algo, - workspace, - workspaceSizeInBytes, - reinterpret_cast(bias), - reinterpret_cast(scaleA), - reinterpret_cast(scaleB), - reinterpret_cast(scaleC), - reinterpret_cast(scaleD), - reinterpret_cast(scaleE), - reinterpret_cast(scaleAlphaVec), - bias_type, - epilogue, - gemmData, - stream); -} - -template -rocblaslt_status rocblaslt_gemm_create_typecasting(hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int64_t m, - int64_t n, - int64_t k, - const void* alpha, - const void* a, - int64_t ld_a, - int64_t batch_stride_a, - const void* b, - int64_t ld_b, - int64_t batch_stride_b, - const void* beta, - const void* c, - int64_t ld_c, - int64_t batch_stride_c, - void* d, - int64_t ld_d, - int64_t batch_stride_d, - void* e, - int64_t ld_e, - int64_t batch_stride_e, - int64_t batch_count, - bool strided_batch, - bool grouped_gemm, - bool gradient, - rocblaslt_compute_type compute_type, - const void* bias, - const void* scaleA, - const void* scaleB, - const void* scaleC, - const void* scaleD, - const void* scaleE, - const void* scaleAlphaVec, - hipblasltDatatype_t bias_type, - rocblaslt_epilogue epilogue, - std::shared_ptr& gemmData, - size_t& gemmCount) -{ - // check alignment of pointers before casting - if(!isAligned(a, sizeof(TiA)) || !isAligned(b, sizeof(TiB)) || !isAligned(c, sizeof(To)) - || !isAligned(d, sizeof(To))) - { - std::cerr << "memmory is not aligned" << std::endl; - return rocblaslt_status_invalid_size; - } - return rocblaslt_gemm_create_batched_template(trans_a, - trans_b, - m, - n, - k, - reinterpret_cast(alpha), - reinterpret_cast(a), - ld_a, - batch_stride_a, - reinterpret_cast(b), - ld_b, - batch_stride_b, - reinterpret_cast(beta), - reinterpret_cast(c), - ld_c, - batch_stride_c, - (To*)d, - ld_d, - batch_stride_d, - (Tc*)e, - ld_e, - batch_stride_e, - batch_count, - strided_batch, - grouped_gemm, - gradient, - compute_type, - reinterpret_cast(bias), - reinterpret_cast(scaleA), - reinterpret_cast(scaleB), - reinterpret_cast(scaleC), - reinterpret_cast(scaleD), - reinterpret_cast(scaleE), - reinterpret_cast(scaleAlphaVec), - bias_type, - epilogue, - gemmData, - gemmCount); -} - -template -rocblaslt_status - rocblaslt_groupedgemm_create_typecasting(hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - std::vector& m, - std::vector& n, - std::vector& k, - std::vector& alpha, - std::vector& a, - std::vector& ld_a, - std::vector& batch_stride_a, - std::vector& b, - std::vector& ld_b, - std::vector& batch_stride_b, - std::vector& beta, - std::vector& c, - std::vector& ld_c, - std::vector& batch_stride_c, - std::vector& d, - std::vector& ld_d, - std::vector& batch_stride_d, - std::vector& e, - std::vector& ld_e, - std::vector& batch_stride_e, - std::vector& batch_count, - bool strided_batch, - bool grouped_gemm, - rocblaslt_compute_type compute_type, - std::vector& gradient, - std::vector& bias, - std::vector& scaleA, - std::vector& scaleB, - std::vector& scaleC, - std::vector& scaleD, - std::vector& scaleE, - std::vector& scaleAlphaVec, - std::vector& bias_type, - std::vector& epilogue, - std::shared_ptr& gemmData, - size_t& gemmCount) -{ - std::vector groupedAlpha, groupedBeta; - std::vector groupedA; - std::vector groupedB; - std::vector groupedC; - std::vector groupedD; - std::vector groupedScaleA; - std::vector groupedScaleB; - std::vector groupedScaleC; - std::vector groupedScaleD; - std::vector groupedScaleE; - std::vector groupedScaleAlphaVec; - - for(int i = 0; i < alpha.size(); i++) - { - groupedAlpha.push_back(reinterpret_cast(alpha[i])); - groupedBeta.push_back(reinterpret_cast(beta[i])); - groupedA.push_back(reinterpret_cast(a[i])); - groupedB.push_back(reinterpret_cast(b[i])); - groupedC.push_back(reinterpret_cast(c[i])); - groupedD.push_back(reinterpret_cast(d[i])); - groupedScaleA.push_back(reinterpret_cast(scaleA[i])); - groupedScaleB.push_back(reinterpret_cast(scaleB[i])); - groupedScaleC.push_back(reinterpret_cast(scaleC[i])); - groupedScaleD.push_back(reinterpret_cast(scaleD[i])); - groupedScaleE.push_back(reinterpret_cast(scaleE[i])); - groupedScaleAlphaVec.push_back(reinterpret_cast(scaleAlphaVec[i])); - } - - return rocblaslt_groupedgemm_create_batched_template(trans_a, - trans_b, - m, - n, - k, - groupedAlpha, - groupedA, - ld_a, - batch_stride_a, - groupedB, - ld_b, - batch_stride_b, - groupedBeta, - groupedC, - ld_c, - batch_stride_c, - groupedD, - ld_d, - batch_stride_d, - e, - ld_e, - batch_stride_e, - batch_count, - strided_batch, - grouped_gemm, - compute_type, - gradient, - bias, - groupedScaleA, - groupedScaleB, - groupedScaleC, - groupedScaleD, - groupedScaleE, - groupedScaleAlphaVec, - bias_type, - epilogue, - gemmData, - gemmCount); -} - -inline rocblaslt_status rocblaslt_matmul_template(rocblaslt_handle handle, - hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int64_t m, - int64_t n, - int64_t k, - const void* alpha, - const void* a, - hipblasltDatatype_t a_type, - int64_t ld_a, - int64_t batch_stride_a, - const void* b, - hipblasltDatatype_t b_type, - int64_t ld_b, - int64_t batch_stride_b, - const void* beta, - const void* c, - hipblasltDatatype_t c_type, - int64_t ld_c, - int64_t batch_stride_c, - void* d, - hipblasltDatatype_t d_type, - int64_t ld_d, - int64_t batch_stride_d, - void* e, - int64_t ld_e, - int64_t batch_stride_e, - int64_t batch_count, - bool strided_batch, - bool grouped_gemm, - bool gradient, - rocblaslt_compute_type compute_type, - const rocblaslt_matmul_algo* algo, - void* workspace, - size_t workspaceSizeInBytes, - const void* bias, - const void* scaleA, - const void* scaleB, - const void* scaleC, - const void* scaleD, - const void* scaleE, - const void* scaleAlphaVec, - hipblasltDatatype_t bias_type, - rocblaslt_epilogue epilogue, - std::shared_ptr gemmData, - hipStream_t stream) -{ - rocblaslt_status rs_status = rocblaslt_status_not_implemented; - -#define EX_TYPECASTING_PARM \ - handle, trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, \ - beta, c, ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, \ - batch_count, strided_batch, grouped_gemm, gradient, compute_type, algo, workspace, \ - workspaceSizeInBytes, bias, scaleA, scaleB, scaleC, scaleD, scaleE, scaleAlphaVec, \ - bias_type, epilogue, gemmData, stream - - if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32 - || compute_type == rocblaslt_compute_f32_fast_xf32) - { - rs_status - = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_16F) - { - if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status - = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_16B && b_type == HIPBLASLT_R_16B) - { - if(c_type == HIPBLASLT_R_16B && d_type == HIPBLASLT_R_16B) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_8F_E4M3) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_8F_E5M2) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_64F && b_type == HIPBLASLT_R_64F) - { - if(c_type == HIPBLASLT_R_64F && d_type == HIPBLASLT_R_64F) - { - if(compute_type == rocblaslt_compute_f64) - { - rs_status = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E5M2 && b_type == HIPBLASLT_R_8F_E4M3) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_8I && b_type == HIPBLASLT_R_8I) //int8 - { - if(c_type == HIPBLASLT_R_8I && d_type == HIPBLASLT_R_8I) - { - if(compute_type == rocblaslt_compute_i32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_32I && d_type == HIPBLASLT_R_32I) - { - if(compute_type == rocblaslt_compute_i32) - { - rs_status - = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_8F_E4M3) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status - = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_16F) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_matmul_typecasting(EX_TYPECASTING_PARM); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status - = rocblaslt_matmul_typecasting( - EX_TYPECASTING_PARM); - } - } - } - else - { - log_error(__func__, "No such template."); - rs_status = rocblaslt_status_not_implemented; - } - - return rs_status; -} - -inline rocblaslt_status rocblaslt_gemm_create_template_cpp(hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int64_t m, - int64_t n, - int64_t k, - const void* alpha, - const void* a, - hipblasltDatatype_t a_type, - int64_t ld_a, - int64_t batch_stride_a, - const void* b, - hipblasltDatatype_t b_type, - int64_t ld_b, - int64_t batch_stride_b, - const void* beta, - const void* c, - hipblasltDatatype_t c_type, - int64_t ld_c, - int64_t batch_stride_c, - void* d, - hipblasltDatatype_t d_type, - int64_t ld_d, - int64_t batch_stride_d, - void* e, - int64_t ld_e, - int64_t batch_stride_e, - int64_t batch_count, - bool strided_batch, - bool grouped_gemm, - bool gradient, - rocblaslt_compute_type compute_type, - const void* bias, - const void* scaleA, - const void* scaleB, - const void* scaleC, - const void* scaleD, - const void* scaleE, - const void* scaleAlphaVec, - hipblasltDatatype_t bias_type, - rocblaslt_epilogue epilogue, - std::shared_ptr& gemmData, - size_t& gemmCount) -{ - rocblaslt_status rs_status = rocblaslt_status_not_implemented; - -#define EX_TYPECASTING_PARM_GEMM_CPP \ - trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, beta, c, \ - ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, batch_count, \ - strided_batch, grouped_gemm, gradient, compute_type, bias, scaleA, scaleB, scaleC, scaleD, \ - scaleE, scaleAlphaVec, bias_type, epilogue, gemmData, gemmCount - - if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32 - || compute_type == rocblaslt_compute_f32_fast_xf32) - { - rs_status = rocblaslt_gemm_create_typecasting( - EX_TYPECASTING_PARM_GEMM_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_16F) - { - if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_16B && b_type == HIPBLASLT_R_16B) - { - if(c_type == HIPBLASLT_R_16B && d_type == HIPBLASLT_R_16B) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_64F && b_type == HIPBLASLT_R_64F) - { - if(c_type == HIPBLASLT_R_64F && d_type == HIPBLASLT_R_64F) - { - if(compute_type == rocblaslt_compute_f64) - { - rs_status = rocblaslt_gemm_create_typecasting( - EX_TYPECASTING_PARM_GEMM_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_8I && b_type == HIPBLASLT_R_8I) //int8 - { - if(c_type == HIPBLASLT_R_8I && d_type == HIPBLASLT_R_8I) - { - if(compute_type == rocblaslt_compute_i32) - { - rs_status - = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - else if(c_type == HIPBLASLT_R_32I && d_type == HIPBLASLT_R_32I) - { - if(compute_type == rocblaslt_compute_i32) - { - rs_status - = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_8F_E4M3) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status - = rocblaslt_gemm_create_typecasting( - EX_TYPECASTING_PARM_GEMM_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_16F) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_gemm_create_typecasting(EX_TYPECASTING_PARM_GEMM_CPP); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status - = rocblaslt_gemm_create_typecasting( - EX_TYPECASTING_PARM_GEMM_CPP); - } - } - } - else - { - log_error(__func__, "No such template."); - rs_status = rocblaslt_status_not_implemented; - } - - return rs_status; -} - -inline rocblaslt_status - rocblaslt_groupedgemm_create_template_cpp(hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - std::vector& m, - std::vector& n, - std::vector& k, - std::vector& alpha, - std::vector& a, - hipblasltDatatype_t a_type, - std::vector& ld_a, - std::vector& batch_stride_a, - std::vector& b, - hipblasltDatatype_t b_type, - std::vector& ld_b, - std::vector& batch_stride_b, - std::vector& beta, - std::vector& c, - hipblasltDatatype_t c_type, - std::vector& ld_c, - std::vector& batch_stride_c, - std::vector& d, - hipblasltDatatype_t d_type, - std::vector& ld_d, - std::vector& batch_stride_d, - std::vector& e, - std::vector& ld_e, - std::vector& batch_stride_e, - std::vector& batch_count, - bool strided_batch, - bool grouped_gemm, - std::vector& gradient, - rocblaslt_compute_type compute_type, - std::vector& bias, - std::vector& scaleA, - std::vector& scaleB, - std::vector& scaleC, - std::vector& scaleD, - std::vector& scaleE, - std::vector& scaleAlphaVec, - std::vector& bias_type, - std::vector& epilogue, - std::shared_ptr& gemmData, - size_t& gemmCount) -{ - rocblaslt_status rs_status = rocblaslt_status_not_implemented; - -#define EX_TYPECASTING_PARM_GroupedGemm_CPP \ - trans_a, trans_b, m, n, k, alpha, a, ld_a, batch_stride_a, b, ld_b, batch_stride_b, beta, c, \ - ld_c, batch_stride_c, d, ld_d, batch_stride_d, e, ld_e, batch_stride_e, batch_count, \ - strided_batch, grouped_gemm, compute_type, gradient, bias, scaleA, scaleB, scaleC, scaleD, \ - scaleE, scaleAlphaVec, bias_type, epilogue, gemmData, gemmCount - - if(a_type == HIPBLASLT_R_32F && b_type == HIPBLASLT_R_32F) - { - if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32 - || compute_type == rocblaslt_compute_f32_fast_xf32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_16F) - { - if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_16B && b_type == HIPBLASLT_R_16B) - { - if(c_type == HIPBLASLT_R_16B && d_type == HIPBLASLT_R_16B) - { - if(compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_64F && b_type == HIPBLASLT_R_64F) - { - if(c_type == HIPBLASLT_R_64F && d_type == HIPBLASLT_R_64F) - { - if(compute_type == rocblaslt_compute_f64) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_8I && b_type == HIPBLASLT_R_8I) //int8 - { - if(c_type == HIPBLASLT_R_8I && d_type == HIPBLASLT_R_8I) - { - if(compute_type == rocblaslt_compute_i32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - else if(c_type == HIPBLASLT_R_32I && d_type == HIPBLASLT_R_32I) - { - if(compute_type == rocblaslt_compute_i32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_16F && b_type == HIPBLASLT_R_8F_E4M3) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - } - else if(a_type == HIPBLASLT_R_8F_E4M3 && b_type == HIPBLASLT_R_16F) // mix types - { - if(c_type == HIPBLASLT_R_8F_E4M3 && d_type == HIPBLASLT_R_8F_E4M3) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - else if(c_type == HIPBLASLT_R_16F && d_type == HIPBLASLT_R_16F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - else if(c_type == HIPBLASLT_R_32F && d_type == HIPBLASLT_R_32F) - { - if(compute_type == rocblaslt_compute_f32_fast_f16 - || compute_type == rocblaslt_compute_f32) - { - rs_status = rocblaslt_groupedgemm_create_typecasting( - EX_TYPECASTING_PARM_GroupedGemm_CPP); - } - } - } - else - { - log_error(__func__, "No such template."); - rs_status = rocblaslt_status_not_implemented; - } - - return rs_status; -} -#endif diff --git a/library/src/amd_detail/rocblaslt/src/tensile_host.cpp b/library/src/amd_detail/rocblaslt/src/tensile_host.cpp index 5f7a5994f5..6c210aec12 100644 --- a/library/src/amd_detail/rocblaslt/src/tensile_host.cpp +++ b/library/src/amd_detail/rocblaslt/src/tensile_host.cpp @@ -76,91 +76,35 @@ namespace { return rocblaslt_internal_get_so_path("libhipblaslt"); } - /****************************************************** - * Map a rocblaslt type to a corresponding Tensile type * - ******************************************************/ - template - struct rocblaslt_to_tensile_type - { - using tensile_type = T; - }; - - template <> - struct rocblaslt_to_tensile_type - { - using tensile_type = Tensile::Half; - }; - - template <> - struct rocblaslt_to_tensile_type - { - using tensile_type = Tensile::BFloat16; - }; - - template <> - struct rocblaslt_to_tensile_type - { - using tensile_type = Tensile::Float8; - }; - - template <> - struct rocblaslt_to_tensile_type - { - using tensile_type = Tensile::BFloat8; - }; - - template <> - struct rocblaslt_to_tensile_type - { - using tensile_type = Tensile::Int8; - }; - /******************************************************************** - * Variable template to map a rocblaslt type into a Tensile::DataType * - ********************************************************************/ - template - constexpr auto tensile_datatype = nullptr; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::Half; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::Float; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::Double; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::BFloat16; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::Float8; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::BFloat8; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::Int8; - - template <> - constexpr auto tensile_datatype = Tensile::DataType::Int32; /************************************************************************* * Class for converting alpha and beta between rocblaslt and Tensile types * * By default, alpha and beta are the same type as Tc compute_type * *************************************************************************/ - template - struct AlphaBeta + static void assignAlphaBeta(Tensile::DataType type, + const void* alphaPtr, + const void* betaPtr, + double* alpha, + double* beta) { - using tensile_type = typename rocblaslt_to_tensile_type::tensile_type; - static void copy(tensile_type* dst, const Tc* src) + switch(type) { - static_assert(sizeof(*src) == sizeof(*dst), - "Tensile and rocblaslt types are not the same size"); - static_assert(std::is_standard_layout{} && std::is_standard_layout{}, - "Tensile or rocblaslt types are not standard layout types"); - memcpy(dst, src, sizeof(*dst)); + case Tensile::DataType::Float: + *alpha = *(float*)alphaPtr; + *beta = *(int32_t*)betaPtr; + break; + case Tensile::DataType::Double: + *alpha = *(double*)alphaPtr; + *beta = *(int32_t*)betaPtr; + break; + case Tensile::DataType::Int32: + *alpha = *(int32_t*)alphaPtr; + *beta = *(int32_t*)betaPtr; + break; + default: + throw std::runtime_error("Unsupported alpha, beta type."); } - }; + } inline Tensile::ActivationType getTensileActivationType(rocblaslt_epilogue epilogue) { @@ -336,15 +280,13 @@ namespace /**************************************************************** * Construct a Tensile Problem from a RocblasltContractionProblem * ****************************************************************/ - // Should remove this - template - auto ConstructTensileProblem(const RocblasltContractionProblem& prob) + auto ConstructTensileProblem(const RocblasltContractionProblem& prob) { - // Tensile DataTypes corresponding to rocblaslt data types - static constexpr Tensile::DataType Tensile_TiA = tensile_datatype; - static constexpr Tensile::DataType Tensile_TiB = tensile_datatype; - static constexpr Tensile::DataType Tensile_To = tensile_datatype; - static constexpr Tensile::DataType Tensile_Tc = tensile_datatype; + auto a_type = hipblasltDatatype_to_tensile_type(prob.a_type); + auto b_type = hipblasltDatatype_to_tensile_type(prob.b_type); + auto c_type = hipblasltDatatype_to_tensile_type(prob.c_type); + auto d_type = hipblasltDatatype_to_tensile_type(prob.d_type); + auto compute_type = roc2TensileType(prob.compute_type); // Tensor descriptors for a, b Tensile::TensorDescriptor a, b; @@ -364,7 +306,9 @@ namespace // This makes alpha==0 a change in the problem, and not just a change in the // inputs. It optimizes all problems with alpha==0 into K=0 and alpha=(don't // care) - auto k = prob.k && *prob.alpha ? prob.k : 0; + double alpha = 0, beta = 0; + assignAlphaBeta(compute_type, prob.alpha, prob.beta, &alpha, &beta); + auto k = prob.k && alpha ? prob.k : 0; // clang-format off @@ -373,7 +317,7 @@ namespace { a = { "a", - Tensile_TiA, + a_type, {k, prob.m, prob.batch_count}, {prob.row_stride_a, prob.col_stride_a, prob.batch_stride_a} }; @@ -384,7 +328,7 @@ namespace { a = { "a", - Tensile_TiA, + a_type, {prob.m, k, prob.batch_count}, {prob.row_stride_a, prob.col_stride_a, prob.batch_stride_a} }; @@ -397,7 +341,7 @@ namespace { b = { "b", - Tensile_TiB, + b_type, {prob.n, k, prob.batch_count}, {prob.row_stride_b, prob.col_stride_b, prob.batch_stride_b} }; @@ -408,7 +352,7 @@ namespace { b = { "b", - Tensile_TiB, + b_type, {k, prob.n, prob.batch_count}, {prob.row_stride_b, prob.col_stride_b, prob.batch_stride_b} }; @@ -420,13 +364,13 @@ namespace // Descriptor for input matrix C Tensile::TensorDescriptor c{"c", - Tensile_To, + c_type, {prob.m, prob.n, prob.batch_count}, {prob.row_stride_c, prob.col_stride_c, prob.batch_stride_c}}; // Descriptor for output matrix D Tensile::TensorDescriptor d{"d", - Tensile_To, + d_type, {prob.m, prob.n, prob.batch_count}, {prob.row_stride_d, prob.col_stride_d, prob.batch_stride_d}}; @@ -453,16 +397,18 @@ namespace freeIndex, batchIndex, boundIndex, - value_category(*prob.beta), + value_category(beta), prob.workspaceSize}; tensileProblem.setComputeInputType( - roc2TensileComputeInputType(Tensile_TiA, Tensile_TiB, prob.compute_type)); - tensileProblem.setAlphaType(Tensile_Tc); - tensileProblem.setBetaType(Tensile_Tc); + roc2TensileComputeInputType(a_type, b_type, prob.compute_type)); + tensileProblem.setAlphaType(compute_type); + tensileProblem.setBetaType(compute_type); // HPA is active iff sizeof(compute type) > sizeof(input type) - tensileProblem.setHighPrecisionAccumulate(sizeof(Tc) > sizeof(TiA)); + tensileProblem.setHighPrecisionAccumulate( + Tensile::DataTypeInfo::Get(compute_type).elementSize + > Tensile::DataTypeInfo::Get(a_type).elementSize); // set batch mode tensileProblem.setStridedBatched(prob.strided_batch); @@ -476,12 +422,10 @@ namespace // alpha and beta are copied from host to Tensile::TypedContractionInputs // If k==0, we do not need to dereference prob.alpha and can set // tensileAlpha=0 Not positive if this is necessary here as well - typename AlphaBeta::tensile_type tensileAlpha; + double alphaRestriction = 0; if(prob.k) - AlphaBeta::copy(&tensileAlpha, prob.alpha); - else - memset(&tensileAlpha, 0, sizeof(tensileAlpha)); - tensileProblem.setAlphaRestriction(Tensile::toScalarValueEnum(tensileAlpha)); + alphaRestriction = alpha; + tensileProblem.setAlphaRestriction(Tensile::toScalarValueEnum(alphaRestriction)); // Add problem predicates for CEqualsD tensileProblem.setCEqualsD(prob.C == prob.D); @@ -490,7 +434,7 @@ namespace { bool isOutput = prob.gradient ? false : true; tensileProblem.setUseE(true); - tensileProblem.setE(Tensile_Tc, + tensileProblem.setE(compute_type, {prob.m, prob.n, prob.batch_count}, {prob.row_stride_e, prob.col_stride_e, prob.batch_stride_e}, isOutput); @@ -505,21 +449,27 @@ namespace hipblasltDatatype_to_tensile_type(prob.bias_type), biasSize, 0, prob.gradient, biasSrc); // ScaleAB is only supported on F8/BF8 - if(Tensile_TiA == Tensile::DataType::Float8 || Tensile_TiA == Tensile::DataType::BFloat8 - || Tensile_TiB == Tensile::DataType::Float8 || Tensile_TiB == Tensile::DataType::BFloat8) + if(a_type == Tensile::DataType::Float8 || a_type == Tensile::DataType::BFloat8 + || b_type == Tensile::DataType::Float8 || b_type == Tensile::DataType::BFloat8) { tensileProblem.setUseScaleAB(true); - tensileProblem.setScaleA(Tensile_Tc); - tensileProblem.setScaleB(Tensile_Tc); + if(d_type == Tensile::DataType::Float8 || d_type == Tensile::DataType::BFloat8) + tensileProblem.setUseScaleCD(true); + else + tensileProblem.setUseScaleCD(false); + tensileProblem.setScaleA(compute_type); + tensileProblem.setScaleB(compute_type); + tensileProblem.setScaleC(compute_type); + tensileProblem.setScaleD(compute_type); tensileProblem.setUseScaleAlphaVec(true); - tensileProblem.setScaleAlphaVec(Tensile_Tc, d.sizes()[0]); + tensileProblem.setScaleAlphaVec(compute_type, d.sizes()[0]); } else { tensileProblem.setUseScaleAB(false); // set ScaleAlphaVec mode tensileProblem.setUseScaleAlphaVec(true); - tensileProblem.setScaleAlphaVec(Tensile_Tc, d.sizes()[0]); + tensileProblem.setScaleAlphaVec(compute_type, d.sizes()[0]); } // set Actvation @@ -527,7 +477,7 @@ namespace if(!is_biasSrc_AB(prob.epilogue)) { tensileProblem.setActivationType(Tensile::ActivationType::All); - tensileProblem.setActivationComputeType(Tensile_Tc); + tensileProblem.setActivationComputeType(compute_type); tensileProblem.setActivationEnumArg(getTensileActivationType(prob.epilogue)); } else @@ -538,24 +488,21 @@ namespace // set use gradient tensileProblem.setUseGradient(is_grad_enabled(prob.epilogue)); - if constexpr(std::is_same{} && std::is_same{} - && std::is_same{}) - if(prob.compute_type == rocblaslt_compute_f32_fast_xf32) - tensileProblem.setF32XdlMathOp(Tensile::DataType::XFloat32); + if(prob.compute_type == rocblaslt_compute_f32_fast_xf32) + tensileProblem.setF32XdlMathOp(Tensile::DataType::XFloat32); return tensileProblem; } - template - void updateTensileProblem(const bool fallback, - const RocblasltContractionProblem& prob, - Tensile::ContractionProblemGemm& tensileProblem) + void updateTensileProblem(const bool fallback, + const RocblasltContractionProblem& prob, + Tensile::ContractionProblemGemm& tensileProblem) { - // Tensile DataTypes corresponding to rocblaslt data types - static constexpr Tensile::DataType Tensile_TiA = tensile_datatype; - static constexpr Tensile::DataType Tensile_TiB = tensile_datatype; - static constexpr Tensile::DataType Tensile_To = tensile_datatype; - static constexpr Tensile::DataType Tensile_Tc = tensile_datatype; + auto a_type = hipblasltDatatype_to_tensile_type(prob.a_type); + auto b_type = hipblasltDatatype_to_tensile_type(prob.b_type); + auto c_type = hipblasltDatatype_to_tensile_type(prob.c_type); + auto d_type = hipblasltDatatype_to_tensile_type(prob.d_type); + auto compute_type = roc2TensileType(prob.compute_type); // Tensile Indices for contraction problem Tensile::ContractionProblemGemm::FreeIndices freeIndex(2); @@ -580,7 +527,7 @@ namespace if(prob.trans_a != HIPBLAS_OP_N) { tensileProblem.resetTensor(Tensile::ContractionProblemGemm::TENSOR::A, - Tensile_TiA, + a_type, {k, prob.m, prob.batch_count}, {prob.row_stride_a, prob.col_stride_a, prob.batch_stride_a}); freeIndex[0].i = 1; @@ -589,7 +536,7 @@ namespace else { tensileProblem.resetTensor(Tensile::ContractionProblemGemm::TENSOR::A, - Tensile_TiA, + a_type, {prob.m, k, prob.batch_count}, {prob.row_stride_a, prob.col_stride_a, prob.batch_stride_a}); freeIndex[0].i = 0; @@ -600,7 +547,7 @@ namespace if(prob.trans_b != HIPBLAS_OP_N) { tensileProblem.resetTensor(Tensile::ContractionProblemGemm::TENSOR::B, - Tensile_TiB, + b_type, {prob.n, k, prob.batch_count}, {prob.row_stride_b, prob.col_stride_b, prob.batch_stride_b}); freeIndex[1].i = 0; @@ -609,7 +556,7 @@ namespace else { tensileProblem.resetTensor(Tensile::ContractionProblemGemm::TENSOR::B, - Tensile_TiB, + b_type, {k, prob.n, prob.batch_count}, {prob.row_stride_b, prob.col_stride_b, prob.batch_stride_b}); freeIndex[1].i = 1; @@ -620,26 +567,30 @@ namespace // Descriptor for input matrix C tensileProblem.resetTensor(Tensile::ContractionProblemGemm::TENSOR::C, - Tensile_To, + c_type, {prob.m, prob.n, prob.batch_count}, {prob.row_stride_c, prob.col_stride_c, prob.batch_stride_c}); // Descriptor for output matrix D tensileProblem.resetTensor(Tensile::ContractionProblemGemm::TENSOR::D, - Tensile_To, + d_type, {prob.m, prob.n, prob.batch_count}, {prob.row_stride_d, prob.col_stride_d, prob.batch_stride_d}); - tensileProblem.updateProblem( - freeIndex, batchIndex, boundIndex, (double)(*prob.beta), prob.workspaceSize); + double alpha = 0, beta = 0; + assignAlphaBeta(compute_type, prob.alpha, prob.beta, &alpha, &beta); + + tensileProblem.updateProblem(freeIndex, batchIndex, boundIndex, beta, prob.workspaceSize); tensileProblem.setComputeInputType( - roc2TensileComputeInputType(Tensile_TiA, Tensile_TiB, prob.compute_type)); - tensileProblem.setAlphaType(Tensile_Tc); - tensileProblem.setBetaType(Tensile_Tc); + roc2TensileComputeInputType(a_type, b_type, prob.compute_type)); + tensileProblem.setAlphaType(compute_type); + tensileProblem.setBetaType(compute_type); // HPA is active iff sizeof(compute type) > sizeof(input type) - tensileProblem.setHighPrecisionAccumulate(sizeof(Tc) > sizeof(TiA)); + tensileProblem.setHighPrecisionAccumulate( + Tensile::DataTypeInfo::Get(compute_type).elementSize + > Tensile::DataTypeInfo::Get(a_type).elementSize); // set batch mode tensileProblem.setStridedBatched(prob.strided_batch); @@ -653,12 +604,10 @@ namespace // alpha and beta are copied from host to Tensile::TypedContractionInputs // If k==0, we do not need to dereference prob.alpha and can set // tensileAlpha=0 Not positive if this is necessary here as well - typename AlphaBeta::tensile_type tensileAlpha; + double alphaRestriction = 0; if(prob.k) - AlphaBeta::copy(&tensileAlpha, prob.alpha); - else - memset(&tensileAlpha, 0, sizeof(tensileAlpha)); - tensileProblem.setAlphaRestriction(Tensile::toScalarValueEnum(tensileAlpha)); + alphaRestriction = alpha; + tensileProblem.setAlphaRestriction(Tensile::toScalarValueEnum(alphaRestriction)); // Add problem predicates for CEqualsD tensileProblem.setCEqualsD(prob.C == prob.D); @@ -689,22 +638,20 @@ namespace biasSrc); // ScaleAB is only supported on F8/BF8 - if(Tensile_TiA == Tensile::DataType::Float8 || Tensile_TiA == Tensile::DataType::BFloat8 - || Tensile_TiB == Tensile::DataType::Float8 - || Tensile_TiB == Tensile::DataType::BFloat8) + if(a_type == Tensile::DataType::Float8 || a_type == Tensile::DataType::BFloat8 + || b_type == Tensile::DataType::Float8 || b_type == Tensile::DataType::BFloat8) { tensileProblem.setUseScaleAB(true); - if(Tensile_To == Tensile::DataType::Float8 - || Tensile_To == Tensile::DataType::BFloat8) + if(d_type == Tensile::DataType::Float8 || d_type == Tensile::DataType::BFloat8) tensileProblem.setUseScaleCD(true); else tensileProblem.setUseScaleCD(false); - tensileProblem.setScaleA(Tensile_Tc); - tensileProblem.setScaleB(Tensile_Tc); - tensileProblem.setScaleC(Tensile_Tc); - tensileProblem.setScaleD(Tensile_Tc); + tensileProblem.setScaleA(compute_type); + tensileProblem.setScaleB(compute_type); + tensileProblem.setScaleC(compute_type); + tensileProblem.setScaleD(compute_type); tensileProblem.setUseScaleAlphaVec(true); - tensileProblem.setScaleAlphaVec(Tensile_Tc, d.sizes()[0]); + tensileProblem.setScaleAlphaVec(compute_type, d.sizes()[0]); } else { @@ -712,7 +659,7 @@ namespace tensileProblem.setUseScaleCD(false); // set ScaleAlphaVec mode tensileProblem.setUseScaleAlphaVec(true); - tensileProblem.setScaleAlphaVec(Tensile_Tc, d.sizes()[0]); + tensileProblem.setScaleAlphaVec(compute_type, d.sizes()[0]); } // set Actvation @@ -720,7 +667,7 @@ namespace if(!is_biasSrc_AB(prob.epilogue)) { tensileProblem.setActivationType(Tensile::ActivationType::All); - tensileProblem.setActivationComputeType(Tensile_Tc); + tensileProblem.setActivationComputeType(compute_type); tensileProblem.setActivationEnumArg(tensileAct); } else @@ -734,7 +681,7 @@ namespace { bool isOutput = prob.gradient ? false : true; tensileProblem.setUseE(true); - tensileProblem.setE(Tensile_Tc, + tensileProblem.setE(compute_type, {prob.m, prob.n, prob.batch_count}, {prob.row_stride_e, prob.col_stride_e, prob.batch_stride_e}, isOutput); @@ -744,32 +691,16 @@ namespace tensileProblem.setUseGradient(is_grad_enabled(prob.epilogue)); } - if constexpr(std::is_same{} && std::is_same{} - && std::is_same{}) - if(prob.compute_type == rocblaslt_compute_f32_fast_xf32) - tensileProblem.setF32XdlMathOp(Tensile::DataType::XFloat32); + if(prob.compute_type == rocblaslt_compute_f32_fast_xf32) + tensileProblem.setF32XdlMathOp(Tensile::DataType::XFloat32); } /*************************************************************** * Construct the inputs to a Tensile ContractionProblemGemm * ***************************************************************/ - template - auto GetTensileInputs(const RocblasltContractionProblem& prob) + auto GetTensileInputs(const RocblasltContractionProblem& prob) { - // Tensile types corresponding to TiA, TiB, To, Tc - using Tensile_TiA = typename rocblaslt_to_tensile_type::tensile_type; - using Tensile_TiB = typename rocblaslt_to_tensile_type::tensile_type; - using Tensile_To = typename rocblaslt_to_tensile_type::tensile_type; - using Tensile_Talpha_beta = typename AlphaBeta::tensile_type; - - // Make sure rocblaslt and Tensile types are compatible - // (Even if Ti=rocblaslt_int8x4, Tensile_Ti=Int8x4, they are both 32-byte) - static_assert(sizeof(Tensile_TiA) == sizeof(TiA) && sizeof(Tensile_To) == sizeof(To), - "Tensile and rocblaslt types are not the same size"); - - static_assert(std::is_standard_layout{} && std::is_standard_layout{} - && std::is_standard_layout{} && std::is_standard_layout{}, - "Tensile or rocblaslt types are not standard layout types"); + auto compute_type = roc2TensileType(prob.compute_type); // Structure describing the inputs (A, B, C, D, alpha, beta) Tensile::ContractionInputs inputs; @@ -798,18 +729,41 @@ namespace inputs.scaleAlphaVec = reinterpret_cast(prob.scaleAlphaVec); // push 2 activation arguments - inputs.activationArgs.push_back(static_cast(0.0f)); - inputs.activationArgs.push_back(static_cast(0.0f)); - - // alpha and beta are stored by value in Tensile::TypedContractionInputs - // alpha and beta are copied from host to Tensile::TypedContractionInputs - // If k==0, we do not need to dereference prob.alpha and can set - // inputs.alpha=0 - if(prob.k) - inputs.alpha = static_cast((*prob.alpha)); + if(compute_type == Tensile::DataType::Float || compute_type == Tensile::DataType::XFloat32) + { + inputs.activationArgs.push_back(0.0f); + inputs.activationArgs.push_back(0.0f); + if(prob.k) + inputs.alpha = *(float*)(prob.alpha); + else + inputs.alpha = 0.f; + inputs.beta = *(float*)(prob.beta); + } + else if(compute_type == Tensile::DataType::Int32) + { + inputs.activationArgs.push_back((int32_t)0); + inputs.activationArgs.push_back((int32_t)0); + if(prob.k) + inputs.alpha = *(int32_t*)(prob.alpha); + else + inputs.alpha = (int32_t)0; + inputs.beta = *(int32_t*)(prob.beta); + } + else if(compute_type == Tensile::DataType::Double) + { + inputs.activationArgs.push_back((double)0.0); + inputs.activationArgs.push_back((double)0.0); + if(prob.k) + inputs.alpha = *(double*)(prob.alpha); + else + inputs.alpha = (double)0; + inputs.beta = *(double*)(prob.beta); + } else - inputs.alpha = static_cast(0); - inputs.beta = static_cast((*prob.beta)); + { + log_error(__func__, "Unsupported compute type"); + throw std::runtime_error("[GetTensileInputs] Unsupported compute type."); + } return inputs; } @@ -1203,11 +1157,10 @@ void initTensileGemmData(rocblaslt_handle handle, * runContractionProblem calls Tensile to run a contraction problem described * * by RocblasltContractionProblem * ******************************************************************************/ -template -rocblaslt_status runContractionProblem(rocblaslt_handle handle, - const rocblaslt_matmul_algo* algo, - const RocblasltContractionProblem& prob, - std::shared_ptr gemmData) +rocblaslt_status runContractionProblem(rocblaslt_handle handle, + const rocblaslt_matmul_algo* algo, + const RocblasltContractionProblem& prob, + std::shared_ptr gemmData) { rocblaslt_status status = rocblaslt_status_internal_error; try @@ -1264,10 +1217,9 @@ rocblaslt_status runContractionProblem(rocblaslt_handle return status; } -template -rocblaslt_status gemmCreate(RocblasltContractionProblem const& problem, - std::shared_ptr& gemmData, - size_t& gemmCount) +rocblaslt_status gemmCreate(RocblasltContractionProblem const& problem, + std::shared_ptr& gemmData, + size_t& gemmCount) { rocblaslt_status status = rocblaslt_status_internal_error; try @@ -1318,11 +1270,9 @@ rocblaslt_status gemmCreate(RocblasltContractionProblem const& return status; } -template -rocblaslt_status - groupedGemmCreate(std::vector>& probs, - std::shared_ptr& gemmData, - size_t& gemmCount) +rocblaslt_status groupedGemmCreate(std::vector& probs, + std::shared_ptr& gemmData, + size_t& gemmCount) { gemmCount = probs.size(); if(gemmCount == 0) @@ -1850,14 +1800,13 @@ inline auto getSolutions( return solutions; } -template -rocblaslt_status getBestSolutions(RocblasltContractionProblem prob, - rocblaslt_handle handle, - std::shared_ptr gemmData, - int requestedAlgoCount, - rocblaslt_matmul_heuristic_result heuristicResultsArray[], - int* returnAlgoCount, - size_t maxWorkSpaceBytes) +rocblaslt_status getBestSolutions(RocblasltContractionProblem const& prob, + rocblaslt_handle handle, + std::shared_ptr gemmData, + int requestedAlgoCount, + rocblaslt_matmul_heuristic_result heuristicResultsArray[], + int* returnAlgoCount, + size_t maxWorkSpaceBytes) { std::shared_ptr> library; std::shared_ptr deviceProp; @@ -1876,15 +1825,13 @@ rocblaslt_status getBestSolutions(RocblasltContractionProblem = getSolutions(prob, library, hardware, data->problem, requestedAlgoCount, fallbackSize); // when there is no solution for xfloat32, fallback comput_type to fp32 - if constexpr(std::is_same{} && std::is_same{} - && std::is_same{}) - if(solutions.size() == 0 && prob.compute_type == rocblaslt_compute_f32_fast_xf32) - { - log_api(__func__, "no solutions found, try to fallback"); - data->problem.setF32XdlMathOp(Tensile::DataType::Float); - solutions = getSolutions( - prob, library, hardware, data->problem, requestedAlgoCount, fallbackSize); - } + if(solutions.size() == 0 && prob.compute_type == rocblaslt_compute_f32_fast_xf32) + { + log_api(__func__, "no solutions found, try to fallback"); + data->problem.setF32XdlMathOp(Tensile::DataType::Float); + solutions = getSolutions( + prob, library, hardware, data->problem, requestedAlgoCount, fallbackSize); + } _convertToHeuristicResultArray(solutions, requestedAlgoCount, @@ -1962,8 +1909,7 @@ rocblaslt_status getAllSolutions(MyProblem& return rocblaslt_status_success; } -template -rocblaslt_status getAllSolutions(RocblasltContractionProblem& prob, +rocblaslt_status getAllSolutions(RocblasltContractionProblem& prob, rocblaslt_handle handle, std::vector& heuristicResults, size_t maxWorkSpaceBytes) @@ -1972,9 +1918,8 @@ rocblaslt_status getAllSolutions(RocblasltContractionProblem& return getAllSolutions(tensile_prob, handle, heuristicResults, maxWorkSpaceBytes); } -template -rocblaslt_status getAllSolutions(std::vector>& probs, - rocblaslt_handle handle, +rocblaslt_status getAllSolutions(std::vector& probs, + rocblaslt_handle handle, std::vector& heuristicResults, size_t maxWorkSpaceBytes) { @@ -2189,20 +2134,19 @@ rocblaslt_status isSolutionSupported(rocblaslt_handle handle, return rocblaslt_status_success; } -template -rocblaslt_status isSolutionSupported(rocblaslt_handle handle, - RocblasltContractionProblem& prob, - std::shared_ptr gemmData, - rocblaslt_matmul_algo* algo, - size_t* workspaceSizeInBytes) +rocblaslt_status isSolutionSupported(rocblaslt_handle handle, + RocblasltContractionProblem& prob, + std::shared_ptr gemmData, + rocblaslt_matmul_algo* algo, + size_t* workspaceSizeInBytes) { std::shared_ptr data = std::static_pointer_cast(gemmData); updateTensileProblem(false, prob, data->problem); return isSolutionSupported(handle, data->problem, prob, algo, workspaceSizeInBytes); } -template -void setRestrictions(Tensile::ContractionProblemGemm& tensile_prob, const Tc* alpha, const Tc* beta) +template +void setRestrictions(Tensile::ContractionProblemGemm& tensile_prob, const T* alpha, const T* beta) { tensile_prob.setAlphaRestriction(Tensile::toScalarValueEnum(*alpha)); tensile_prob.setBetaRestriction(Tensile::toScalarValueEnum(*beta)); @@ -2371,71 +2315,6 @@ extern "C" void rocblaslt_createialize() static_cast(get_library_and_adapter()); } -/****************************************************************************** - * Intantiate the cases of runContractionProblem which are needed to satisfy * - * rocblaslt dependencies. This file's template functions are not defined in a * - * header file, in order to keep Tensile and rocblaslt separate. * - ******************************************************************************/ - -// types -#define CREATEFUNCTION(TiA, TiB, To, Tc) \ - template rocblaslt_status runContractionProblem( \ - rocblaslt_handle handle, \ - const rocblaslt_matmul_algo* algo, \ - const RocblasltContractionProblem&, \ - std::shared_ptr); \ - template rocblaslt_status gemmCreate(const RocblasltContractionProblem&, \ - std::shared_ptr& gemmData, \ - size_t& gemmCount); \ - template rocblaslt_status groupedGemmCreate( \ - std::vector>&, \ - std::shared_ptr&, \ - size_t&); \ - template rocblaslt_status getAllSolutions( \ - RocblasltContractionProblem& prob, \ - rocblaslt_handle handle, \ - std::vector& heuristicResults, \ - size_t maxWorkSpaceBytes); \ - template rocblaslt_status getAllSolutions( \ - std::vector>& probs, \ - rocblaslt_handle handle, \ - std::vector& heuristicResults, \ - size_t maxWorkSpaceBytes); \ - template rocblaslt_status isSolutionSupported( \ - rocblaslt_handle handle, \ - RocblasltContractionProblem& prob, \ - std::shared_ptr gemmData, \ - rocblaslt_matmul_algo* algo, \ - size_t* workspaceSizeInBytes); \ - template rocblaslt_status getBestSolutions( \ - RocblasltContractionProblem prob, \ - rocblaslt_handle handle, \ - std::shared_ptr gemmData, \ - int requestedAlgoCount, \ - rocblaslt_matmul_heuristic_result heuristicResultsArray[], \ - int* returnAlgoCount, \ - size_t maxWorkSpaceBytes); - -CREATEFUNCTION(float, float, float, float) -CREATEFUNCTION(double, double, double, double) -CREATEFUNCTION(rocblaslt_half, rocblaslt_half, rocblaslt_half, float) -CREATEFUNCTION(rocblaslt_half, rocblaslt_half, float, float) -CREATEFUNCTION(rocblaslt_bfloat16, rocblaslt_bfloat16, rocblaslt_bfloat16, float) -CREATEFUNCTION(rocblaslt_f8, rocblaslt_f8, float, float) -CREATEFUNCTION(rocblaslt_f8, rocblaslt_bf8, float, float) -CREATEFUNCTION(rocblaslt_bf8, rocblaslt_f8, float, float) -CREATEFUNCTION(rocblaslt_f8, rocblaslt_f8, rocblaslt_half, float) -CREATEFUNCTION(rocblaslt_f8, rocblaslt_bf8, rocblaslt_half, float) -CREATEFUNCTION(rocblaslt_bf8, rocblaslt_f8, rocblaslt_half, float) -CREATEFUNCTION(rocblasltInt8, rocblasltInt8, int32_t, int32_t) -CREATEFUNCTION(rocblasltInt8, rocblasltInt8, rocblasltInt8, int32_t) -CREATEFUNCTION(rocblaslt_f8, rocblaslt_half, rocblaslt_f8, float) -CREATEFUNCTION(rocblaslt_half, rocblaslt_f8, rocblaslt_f8, float) -CREATEFUNCTION(rocblaslt_f8, rocblaslt_half, rocblaslt_half, float) -CREATEFUNCTION(rocblaslt_half, rocblaslt_f8, rocblaslt_half, float) -CREATEFUNCTION(rocblaslt_f8, rocblaslt_half, float, float) -CREATEFUNCTION(rocblaslt_half, rocblaslt_f8, float, float) - /*********************************************************************************** * Whether Tensile has been initialized for at least one device (used for *testing) *