diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py b/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py index 57a96f3abd96..e1c92a6791d6 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py @@ -22,6 +22,8 @@ # ################################################################################ +import math + from rocisa.code import Module, Label from rocisa.container import vgpr, ContinuousRegister from rocisa.instruction import VAddU32, VAndB32, VLShiftLeftB32, VLShiftRightB32 @@ -226,6 +228,12 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi strideTile = 1 # DTV case. Actual stride will be applied later. strideK = offsetK if umlds else (mt + LdsPad) * offsetK + + # StrideK might be a float value due to sub-byte data types (e.g. fp4) + # and causes function signature error later. + # Use ceil to ensure no overlap between adjacent K groups in LDS. + strideK = int(math.ceil(strideK)) + if enableLDSTr: if kernel["UseGeneralizedNLCOne%s"%tc] and perpStride > 1: strideK = 8 diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 694202c913bf..c6f3341a28ec 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -128,8 +128,8 @@ class StateValues: bpr: int = 4 # all registers are 32bit # default setup # AB=DataType / Cexternal=DestDataType / Cinternal=Accumulation (MAC or MFMA) - bpeA: float = field(init=False) - bpeB: float = field(init=False) + bpeA: float = field(init=False) # this is a float because of sub-byte data types (f6, f4) + bpeB: float = field(init=False) # this is a float because of sub-byte data types (f6, f4) bpeE: int = field(init=False) # Cexternal = the "current" kernel output type, # - default: the "current" kernel is a non-GSU-kernel, @@ -4101,7 +4101,7 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB): """ if kernel["EnableMatrixInstruction"] and kernel["LocalReadVectorWidthA"] >= kernel["MIInputPerThread"]: - WLR = max(kernel["LocalReadVectorWidthA"]//kernel["MIInputPerThread"], 1) + WLR = int(max(kernel["LocalReadVectorWidthA"]//kernel["MIInputPerThread"], 1)) self.states.numItersPLR = kernel["PrefetchLocalRead"]%(kernel["LoopIters"]//WLR) else: self.states.numItersPLR = kernel["PrefetchLocalRead"]%(kernel["LoopIters"]) diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index 64931a18a1e0..7cd44ef30370 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -12794,7 +12794,7 @@ def chooseGlobalRead(self, useBuffer, bpl, destVgpr, \ if useBuffer: rv = Module("Global Read") - mubuf = MUBUFModifiers(offen=True, offset12=offset, glc=glc, slc=slc, nt=nt, lds=lds) + mubuf = MUBUFModifiers(offen=True, offset12=int(offset), glc=glc, slc=slc, nt=nt, lds=lds) # Nested buffer load implementation function for easy branching for soffset def bufferLoadImpl(soffset): @@ -12831,7 +12831,7 @@ def bufferLoadImpl(soffset): dst = None if lds else vgpr(destVgpr, 4) rv.add(BufferLoadB128(dst=dst, vaddr=addr0, saddr=addr1, \ soffset=soffset, mubuf=mubuf, comment=comment)) - mubuf2 = MUBUFModifiers(offen=True, offset12=offset+16, glc=glc, slc=slc, nt=nt, lds=lds) + mubuf2 = MUBUFModifiers(offen=True, offset12=int(offset+16), glc=glc, slc=slc, nt=nt, lds=lds) if isinstance(destVgpr, str): dst2 = destVgpr + "+" + str(int(4)) elif isinstance(destVgpr, int): @@ -12865,7 +12865,6 @@ def bufferLoadImpl(soffset): dst = vgpr(destVgpr, rpv//4) rv.add(BufferLoadB128(dst=dst, vaddr=addr0, saddr=addr1, \ soffset=soffset, mubuf=mubuf, comment=comment)) - mubuf2 = MUBUFModifiers(offen=True, offset12=int(offset + bpl/4), glc=glc, slc=slc, nt=nt, lds=lds) dst2 = destVgpr + "+" + str(int(rpv//4)) if isinstance(destVgpr, str) else int(destVgpr + int(rpv//4)) diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index e963451b1389..5300c8875acf 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -930,7 +930,7 @@ def isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason: bool): return False # DTL + input type conversion - if state["ProblemType"]["DataType%s"%tc] != state["ProblemType"]["MacDataTypeA"]: + if state["ProblemType"]["DataType%s"%tc] != state["ProblemType"]["MacDataType%s"%tc]: if not state["ConvertAfterDS"]: reject(state, printRejectionReason, "DirectToLds%s + input conversion + ConvertAfterDS=False not supported"%(tc)) return False @@ -1720,16 +1720,20 @@ def depthUIteration( state["_staggerStrideShift"] = (int)(math.ceil(math.log(state["StaggerUStride"] / (state["DepthU"] * bpeA), 2))) - def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int: - lrvwA = state["LocalReadVectorWidthA"] - lrvwB = state["LocalReadVectorWidthB"] + def calcLdsPad(lrvw: int, isaInfoMap: Dict[str, IsaInfo]) -> int: + isMX = state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"] ldsPadA = state["LdsPadA"] ldsPadB = state["LdsPadB"] ldsPadM = state["LdsPadMetadata"] - optPadA = lrvwA - optPadB = lrvwB - readRegsA = int(lrvwA * state["ProblemType"]["MacDataTypeA"].numBytes() // 4) - readRegsB = int(lrvwB * state["ProblemType"]["MacDataTypeB"].numBytes() // 4) + lrvwA = state["LocalReadVectorWidthA"] if isMX else lrvw + lrvwB = state["LocalReadVectorWidthB"] if isMX else lrvw + optPadA = lrvwA if isMX else lrvw + optPadB = lrvwB if isMX else lrvw + numBytesA = state["ProblemType"]["MacDataTypeA"].numBytes() if isMX else state["ProblemType"]["DataType"].numBytes() + numBytesB = state["ProblemType"]["MacDataTypeB"].numBytes() if isMX else state["ProblemType"]["DataType"].numBytes() + readRegsA = int(lrvwA * numBytesA // 4) + readRegsB = int(lrvwB * numBytesB // 4) + if state["ProblemType"]["Sparse"]: if state["ProblemType"]["Sparse"] == 2: optPadB //= 2 @@ -1737,8 +1741,8 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int: else: optPadA //= 2 readRegsA //= 2 - if (not isaInfoMap[isa].asmCaps['HasWMMA']) and (readRegsA > 6 or readRegsB > 6): - reject(state, "LocalReadVectorWidth results in attemping to read LDS larger than b192, reject") + if (not isaInfoMap[isa].asmCaps['HasWMMA']) and (readRegsA > 4 or readRegsB > 4): + reject(state, "LocalReadVectorWidth results in attemping to read LDS larger than b128, reject") return ldsPadA, ldsPadB, ldsPadM if state["EnableMatrixInstruction"]: # for readRegs = 1 or 4, we need to double pad for MI16x16xNx1 to avoid bank conflict. @@ -1748,16 +1752,19 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int: if readRegsB == 4 or readRegsB == 1: optPadB *= 2 if ldsPadA == -1: - if state["ProblemType"]["DataTypeA"].is6bitFloat(): + if isMX and state["ProblemType"]["DataTypeA"].is6bitFloat(): ldsPadA = 0 else: if not state["UnrollMajorLDSA"]: if state["EnableMatrixInstruction"]: ldsPadA = 0 if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: - ldsPadA = int(((16 * state["VectorWidthA"] * state["ProblemType"]["MacDataTypeA"].numBytes() + state["MacroTile0"] * state["ProblemType"]["MacDataTypeA"].numBytes() * lrvwA) % 128) // state["ProblemType"]["MacDataTypeA"].numBytes()) - if state["GlobalReadVectorWidthA"] * state["ProblemType"]["MacDataTypeA"].numBytes() == 32 and ldsPadA == 0: - ldsPadA = int(16 // state["ProblemType"]["MacDataTypeA"].numBytes()) + ldsPadA = int(((16 * state["VectorWidthA"] * numBytesA + state["MacroTile0"] * numBytesA * lrvwA) % 128) // numBytesA) + if state["GlobalReadVectorWidthA"] * numBytesA == 32 and ldsPadA == 0: + ldsPadA = int(16 // numBytesA) + if state["DirectToLdsA"]: + # TODO: Check if there are cases which benefit from padding, currently set to zero by default + ldsPadA = state["MatrixInstM"] if state["enableLDSTrA"] else 0 else: # mac instruction if state["ProblemType"]["TLUA"]: ldsPadA = 0 @@ -1771,17 +1778,20 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int: assert(ldsPadA >= 0) if ldsPadB == -1: - if state["ProblemType"]["DataTypeB"].is6bitFloat(): + if isMX and state["ProblemType"]["DataTypeB"].is6bitFloat(): ldsPadB = 0 else: if not state["UnrollMajorLDSB"]: if state["EnableMatrixInstruction"]: ldsPadB = 0 if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: - ldsPadB = int(((16 * state["VectorWidthB"] * state["ProblemType"]["MacDataTypeB"].numBytes() + state["MacroTile1"] * state["ProblemType"]["MacDataTypeB"].numBytes() * lrvwB) % 128) // state["ProblemType"]["MacDataTypeB"].numBytes()) - if state["GlobalReadVectorWidthB"] * state["ProblemType"]["MacDataTypeB"].numBytes() == 32 and ldsPadB == 0: - ldsPadB = int(16 // state["ProblemType"]["MacDataTypeB"].numBytes()) - else: + ldsPadB = int(((16 * state["VectorWidthB"] * numBytesB + state["MacroTile1"] * numBytesB * lrvwB) % 128) // numBytesB) + if state["GlobalReadVectorWidthB"] * numBytesB == 32 and ldsPadB == 0: + ldsPadB = int(16 // numBytesB) + if state["DirectToLdsB"]: + # TODO: Check if there are cases which benefit from padding, currently set to zero by default + ldsPadB = state["MatrixInstM"] if state["enableLDSTrB"] else 0 + else: # mac instruction if state["ProblemType"]["TLUB"]: ldsPadB = 0 else: @@ -1833,37 +1843,48 @@ def removeLdsPadLogicForDTL(tc, ldsPad): return ldsPadA, ldsPadB, ldsPadM - def calcLdsBlockSizePerPad() -> int: - lrvwA = state["LocalReadVectorWidthA"] - lrvwB = state["LocalReadVectorWidthB"] + def calcLdsBlockSizePerPad(lrvw: int) -> int: + isMX = state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"] LdsBlockSizePerPadA = state["LdsBlockSizePerPadA"] LdsBlockSizePerPadB = state["LdsBlockSizePerPadB"] - tmpBpe = state["ProblemType"]["DataTypeA"].numBytes() if state["ConvertAfterDS"] else state["ProblemType"]["MacDataTypeA"].numBytes() + lrvwA = state["LocalReadVectorWidthA"] if isMX else lrvw + lrvwB = state["LocalReadVectorWidthB"] if isMX else lrvw + numBytesA = state["ProblemType"]["MacDataTypeA"].numBytes() if isMX else state["ProblemType"]["DataType"].numBytes() + numBytesB = state["ProblemType"]["MacDataTypeB"].numBytes() if isMX else state["ProblemType"]["DataType"].numBytes() + + tmpBpe = state["ProblemType"]["DataTypeA"].numBytes() if state["ConvertAfterDS"] else numBytesA if LdsBlockSizePerPadA == -1: - if state["EnableMatrixInstruction"] and not state["ProblemType"]["DataTypeA"].is6bitFloat(): - if state["UnrollMajorLDSA"]: - LdsBlockSizePerPadA = roundUpToNearestMultiple(int(state["_DepthUA"] * tmpBpe), 128) - if state["_DepthUA"] * tmpBpe * state["VectorWidthA"] > 128: - LdsBlockSizePerPadA = roundUpToNearestMultiple(int(state["_DepthUA"] * tmpBpe * state["VectorWidthA"]), 128) + if state["EnableMatrixInstruction"]: + if isMX and state["ProblemType"]["DataTypeA"].is6bitFloat(): + LdsBlockSizePerPadA = 0 else: - if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: - LdsBlockSizePerPadA = int(state["MacroTile0"] * tmpBpe * lrvwA) + if state["UnrollMajorLDSA"]: + LdsBlockSizePerPadA = roundUpToNearestMultiple(int(state["_DepthUA"] * tmpBpe), 128) + if state["_DepthUA"] * tmpBpe * state["VectorWidthA"] > 128: + LdsBlockSizePerPadA = roundUpToNearestMultiple(int(state["_DepthUA"] * tmpBpe * state["VectorWidthA"]), 128) else: - LdsBlockSizePerPadA = 0 + if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: + LdsBlockSizePerPadA = int(state["MacroTile0"] * tmpBpe * lrvwA) + else: + LdsBlockSizePerPadA = 0 else: LdsBlockSizePerPadA = 0 - tmpBpe = state["ProblemType"]["DataTypeB"].numBytes() if state["ConvertAfterDS"] else state["ProblemType"]["MacDataTypeB"].numBytes() + + tmpBpe = state["ProblemType"]["DataTypeB"].numBytes() if state["ConvertAfterDS"] else numBytesB if LdsBlockSizePerPadB == -1: - if state["EnableMatrixInstruction"] and not state["ProblemType"]["DataTypeB"].is6bitFloat(): - if state["UnrollMajorLDSB"]: - LdsBlockSizePerPadB = roundUpToNearestMultiple(int(state["_DepthUB"] * tmpBpe), 128) - if state["_DepthUB"] * tmpBpe * state["VectorWidthB"] > 128: - LdsBlockSizePerPadB = roundUpToNearestMultiple(int(state["_DepthUB"] * tmpBpe * state["VectorWidthB"]), 128) + if state["EnableMatrixInstruction"]: + if isMX and state["ProblemType"]["DataTypeB"].is6bitFloat(): + LdsBlockSizePerPadB = 0 else: - if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: - LdsBlockSizePerPadB = int(state["MacroTile1"] * tmpBpe * lrvwB) + if state["UnrollMajorLDSB"]: + LdsBlockSizePerPadB = roundUpToNearestMultiple(int(state["_DepthUB"] * tmpBpe), 128) + if state["_DepthUB"] * tmpBpe * state["VectorWidthB"] > 128: + LdsBlockSizePerPadB = roundUpToNearestMultiple(int(state["_DepthUB"] * tmpBpe * state["VectorWidthB"]), 128) else: - LdsBlockSizePerPadB = 0 + if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: + LdsBlockSizePerPadB = int(state["MacroTile1"] * tmpBpe * lrvwB) + else: + LdsBlockSizePerPadB = 0 else: LdsBlockSizePerPadB = 0 @@ -1896,6 +1917,10 @@ def calcLdsBlockSizePerPad() -> int: return LdsBlockSizePerPadA, LdsBlockSizePerPadB def calcLdsNumBytes(ldsPadA: int, LdsBlockSizePerPadA: int, ldsPadB: int, LdsBlockSizePerPadB: int) -> int: + if state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]: + assert state["ProblemType"]["MacDataTypeA"] == state["ProblemType"]["MacDataTypeB"], \ + "Matrix A and B must have the same MX data types (mixed data types not supported yet)." + bpeA = state["ProblemType"]["DataTypeA"].numBytes() if state["ConvertAfterDS"] else state["ProblemType"]["MacDataTypeA"].numBytes() bpeB = state["ProblemType"]["DataTypeB"].numBytes() if state["ConvertAfterDS"] else state["ProblemType"]["MacDataTypeB"].numBytes() ldsAlign = int(64 / state["ProblemType"]["MacDataTypeA"].numRegisters()) @@ -1944,13 +1969,19 @@ def calcLdsNumBytes(ldsPadA: int, LdsBlockSizePerPadA: int, ldsPadB: int, LdsBlo ldsNumBytesMetadata = 0 ldsNumBytesAlignedMetadata = 0 + if not state["ProblemType"]["MXBlockA"]: + assert not state["ProblemType"]["MXBlockB"], \ + "A and B must be both MX data types or non-MX data types." + return ldsNumBytesA, ldsNumBytesAlignedA, ldsNumBytesB, ldsNumBytesAlignedB, ldsNumBytesMetadata, ldsNumBytesAlignedMetadata, \ + 0, 0, 0, 0 + if state["ProblemType"]["MXBlockA"]: ldsAlign = 64 ldsNumBytesMXSA = state["_DepthUMXSA"] * state["MacroTileA"] ldsNumBytesAlignedMXSA = roundUpToNearestMultiple(ldsNumBytesMXSA, ldsAlign) else: ldsNumBytesMXSA = 0 - ldsNumBytesAlignedMXSA = 0; + ldsNumBytesAlignedMXSA = 0 if state["ProblemType"]["MXBlockB"]: ldsAlign = 64 @@ -1969,109 +2000,147 @@ def calcLdsNumBytes(ldsPadA: int, LdsBlockSizePerPadA: int, ldsPadB: int, LdsBlo if state["LocalReadVectorWidthB"] == -1: state["LocalReadVectorWidthB"] = state["LocalReadVectorWidth"] - # Default LocalReadVectorWidth - if state["EnableMatrixInstruction"]: - autoLRVWA = 0 - if state["LocalReadVectorWidthA"] != -1: - tmplrvw = (state["LocalReadVectorWidthA"] // 2) if state["ProblemType"]["Sparse"] else state["LocalReadVectorWidthA"] - if tmplrvw * state["ProblemType"]["MacDataTypeA"].numRegisters() < 1: - reject(state, "LocalReadVectorWidth * dataRegister < 1") - if state["LocalReadVectorWidthA"] > state["MIInputPerThread"] and not state["TransposeLDS"]: - reject(state, "LocalReadVectorWidth require Transpose LDS") - else: - if state["ProblemType"]["MacDataTypeA"].is6bitFloat(): - state["LocalReadVectorWidthA"] = 32 if state["UnrollMajorLDSA"] else 16 + + # This function calculates LRVW by separating MX and non-MX types + def calLRVWForMX() -> int: + # Determine if we need to infer LRVW. Returns True if, + # - state["LocalReadVectorWidth{tc}"] is -1, and, + # - state["ProblemType"]["MacDataType{tc}"] is not 6-bit float + # If the LRVW is set by the user, validate the configuration and rejects if, + # - state["LocalReadVectorWidth{tc}"] * state["ProblemType"]["MacDataType{tc}"].numRegisters() < 1 if not sparse + # - state["LocalReadVectorWidth{tc}"] // 2 * state["ProblemType"]["MacDataType{tc}"].numRegisters() < 1 is sparse + # - state["LocalReadVectorWidth{tc}"] > state["MIInputPerThread"] and LDS is not transposed + def isAutoLRVW(tc) -> bool: + autoLRVW = False + if state[f"LocalReadVectorWidth{tc}"] != -1: + tmplrvw = (state[f"LocalReadVectorWidth{tc}"] // 2) if state["ProblemType"]["Sparse"] else state[f"LocalReadVectorWidth{tc}"] + if tmplrvw * state["ProblemType"][f"MacDataType{tc}"].numRegisters() < 1: + reject(state, "LocalReadVectorWidth * dataRegister < 1") + if state[f"LocalReadVectorWidth{tc}"] > state["MIInputPerThread"] and not state["TransposeLDS"]: + reject(state, "LocalReadVectorWidth require Transpose LDS") else: - autoLRVWA = 1 - if state["TransposeLDS"] and (not state["DirectToLds"]): - state["LocalReadVectorWidthA"] = int(16 // state["ProblemType"]["MacDataTypeA"].numBytes()) + if state["ProblemType"][f"MacDataType{tc}"].is6bitFloat(): + state[f"LocalReadVectorWidth{tc}"] = 32 if state[f"UnrollMajorLDS{tc}"] else 16 else: - if state["ProblemType"]["Sparse"] and state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeA"].numBytes() > 16: - state["LocalReadVectorWidthA"] = int(16 // state["ProblemType"]["MacDataTypeA"].numBytes()) + autoLRVW = True + if state["TransposeLDS"] and (not state[f"DirectToLds{tc}"]): + state[f"LocalReadVectorWidth{tc}"] = int(16 // state["ProblemType"][f"MacDataType{tc}"].numBytes()) else: - state["LocalReadVectorWidthA"] = state["MIInputPerThread"] + if state["ProblemType"]["Sparse"] and state["MIInputPerThread"] * state["ProblemType"][f"MacDataType{tc}"].numBytes() > 16: + state[f"LocalReadVectorWidth{tc}"] = int(16 // state["ProblemType"][f"MacDataType{tc}"].numBytes()) + else: + state[f"LocalReadVectorWidth{tc}"] = state["MIInputPerThread"] + if state[f"LocalReadVectorWidth{tc}"] // state["MIInputPerThread"] > 1: + if (state["DepthU"] // state["MatrixInstK"] <= state[f"LocalReadVectorWidth{tc}"] // state["MIInputPerThread"]): + # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) + state[f"LocalReadVectorWidth{tc}"] //= 2 + return autoLRVW - if state["LocalReadVectorWidthA"] // state["MIInputPerThread"] > 1: - if (state["DepthU"] // state["MatrixInstK"] <= state["LocalReadVectorWidthA"] // state["MIInputPerThread"]): - # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) - state["LocalReadVectorWidthA"] //= 2 - - autoLRVWB = 0 - if state["LocalReadVectorWidthB"] != -1: - tmplrvw = (state["LocalReadVectorWidthB"] // 2) if state["ProblemType"]["Sparse"] else state["LocalReadVectorWidthB"] - if tmplrvw * state["ProblemType"]["MacDataTypeB"].numRegisters() < 1: - reject(state, "LocalReadVectorWidth * dataRegister < 1") - if state["LocalReadVectorWidthB"] > state["MIInputPerThread"] and not state["TransposeLDS"]: - reject(state, "LocalReadVectorWidth require Transpose LDS") + if state["EnableMatrixInstruction"]: + autoLRVWA = isAutoLRVW("A") + autoLRVWB = isAutoLRVW("B") + if autoLRVWA or autoLRVWB: + wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) + wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) + + if (wlrA > 1) or (wlrB > 1): + padA, padB, padM = calcLdsPad(state["LocalReadVectorWidth"], isaInfoMap) + ldsBlockSizePerPadA, ldsBlockSizePerPadB = calcLdsBlockSizePerPad(state["LocalReadVectorWidth"]) + ldsNumBytesA, ldsNumBytesAlignedA, ldsNumBytesB, ldsNumBytesAlignedB, ldsNumBytesMetadata, ldsNumBytesAlignedMetadata, \ + ldsNumBytesMXSA, ldsNumBytesAlignedMXSA, ldsNumBytesMXSB, ldsNumBytesAlignedMXSB \ + = calcLdsNumBytes(padA, ldsBlockSizePerPadA, padB, ldsBlockSizePerPadB) + ldsNumBytes = ldsNumBytesAlignedA + ldsNumBytesAlignedB + \ + ldsNumBytesAlignedMXSA + ldsNumBytesAlignedMXSB + \ + ldsNumBytesAlignedMetadata + if ldsNumBytes > state["MaxLDS"]: + if wlrA > 1: + state["LocalReadVectorWidthA"] //= 2 + if wlrB > 1: + state["LocalReadVectorWidthB"] //= 2 + + wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) + wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) + if wlrA > wlrB: + state["LocalReadVectorWidthA"] = wlrB * state["MIInputPerThread"] + if wlrA < wlrB: + state["LocalReadVectorWidthB"] = wlrA * state["MIInputPerThread"] + + if state["ProblemType"]["Sparse"] == 1: + state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthA"] + elif state["ProblemType"]["Sparse"] == 2: + state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthB"] + + if state["ProblemType"]["MXBlockA"]: + state["LocalReadVectorWidthMXSA"] = 1 # TODO: check if need to fomulization + if state["ProblemType"]["MXBlockB"]: + state["LocalReadVectorWidthMXSB"] = 1 # TODO: check if need to fomulization else: - if state["ProblemType"]["MacDataTypeB"].is6bitFloat(): - state["LocalReadVectorWidthB"] = 32 if state["UnrollMajorLDSB"] else 16 - else: - autoLRVWB = 1 - if state["TransposeLDS"] and (not state["DirectToLds"]): - state["LocalReadVectorWidthB"] = int(16 // state["ProblemType"]["MacDataTypeB"].numBytes()) - else: - if state["ProblemType"]["Sparse"] and state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeB"].numBytes() > 16: - state["LocalReadVectorWidthB"] = int(16 // state["ProblemType"]["MacDataTypeB"].numBytes()) - else: - state["LocalReadVectorWidthB"] = state["MIInputPerThread"] + assert False, "expecting MFMA for MX datatypes" - if state["LocalReadVectorWidthB"] // state["MIInputPerThread"] > 1: - if (state["DepthU"] // state["MatrixInstK"] <= state["LocalReadVectorWidthB"] // state["MIInputPerThread"]): + def calLRVWForNonMX() -> int: + if state["EnableMatrixInstruction"]: + autoLRVW = 0 + if state["LocalReadVectorWidth"] == -1: + autoLRVW = 1 + if state["TransposeLDS"] or (state["MIInputPerThread"] * state["ProblemType"]["DataType"].numBytes() > 16): + state["LocalReadVectorWidth"] = 16 // state["ProblemType"]["DataType"].numBytes() + else: + state["LocalReadVectorWidth"] = state["MIInputPerThread"] + else: + if state["ProblemType"]["Sparse"] and state["MIInputPerThread"] * state["ProblemType"]["DataType"].numBytes() > 16: + if state["LocalReadVectorWidth"] < state["MIInputPerThread"] // 2: + reject(state, printRejectionReason, "LocalReadVectorWidth < %u" %(state["MIInputPerThread"] // 2)) + elif not state["ProblemType"]["Sparse"] and not state["UseF32XEmulation"] and not(state["ProblemType"]["DataType"].is8bitFloat() and (state["MatrixInstK"] == 64 or state["MatrixInstK"] == 128)): + if state["LocalReadVectorWidth"] < state["MIInputPerThread"]: + reject(state, printRejectionReason, "LocalReadVectorWidth < %u" %(state["MIInputPerThread"])) + if state["LocalReadVectorWidth"] > state["MIInputPerThread"] and not state["TransposeLDS"]: + reject(state, printRejectionReason, "LocalReadVectorWidth require Transpose LDS") + + if autoLRVW: + if state["LocalReadVectorWidth"] // state["MIInputPerThread"] > 1: + if (state["DepthU"] // state["MatrixInstK"] <= state["LocalReadVectorWidth"] // state["MIInputPerThread"]): # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) - state["LocalReadVectorWidthB"] //= 2 - - if autoLRVWA or autoLRVWB: - wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) - wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) - - if (wlrA > 1) or (wlrB > 1): - padA, padB, padM = calcLdsPad(isaInfoMap) - ldsBlockSizePerPadA, ldsBlockSizePerPadB = calcLdsBlockSizePerPad() - ldsNumBytesA, ldsNumBytesAlignedA, ldsNumBytesB, ldsNumBytesAlignedB, ldsNumBytesMetadata, ldsNumBytesAlignedMetadata, \ - ldsNumBytesMXSA, ldsNumBytesAlignedMXSA, ldsNumBytesMXSB, ldsNumBytesAlignedMXSB = calcLdsNumBytes(padA, ldsBlockSizePerPadA, padB, ldsBlockSizePerPadB) - if (ldsNumBytesAlignedA + ldsNumBytesAlignedB) > state["MaxLDS"]: - if wlrA > 1: - state["LocalReadVectorWidthA"] //= 2 - if wlrB > 1: - state["LocalReadVectorWidthB"] //= 2 - - wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) - wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) - if wlrA > wlrB: - state["LocalReadVectorWidthA"] = wlrB // state["MIInputPerThread"] - if wlrA < wlrB: - state["LocalReadVectorWidthB"] = wlrA // state["MIInputPerThread"] - - if state["ProblemType"]["Sparse"] == 1: - state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthA"] - elif state["ProblemType"]["Sparse"] == 2: - state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthB"] + state["LocalReadVectorWidth"] //= 2 + if state["LocalReadVectorWidth"] // state["MIInputPerThread"] > 1: + padA, padB, padM = calcLdsPad(state["LocalReadVectorWidth"], isaInfoMap) + ldsBlockSizePerPadA, ldsBlockSizePerPadB = calcLdsBlockSizePerPad(state["LocalReadVectorWidth"]) + ldsBlockSizePerPadA = 0 if padA == 0 else ldsBlockSizePerPadA + ldsBlockSizePerPadB = 0 if padB == 0 else ldsBlockSizePerPadB + ldsNumBytesA, ldsNumBytesAlignedA, ldsNumBytesB, ldsNumBytesAlignedB, ldsNumBytesMetadata, ldsNumBytesAlignedMetadata, \ + ldsNumBytesMXSA, ldsNumBytesAlignedMXSA, ldsNumBytesMXSB, ldsNumBytesAlignedMXSB \ + = calcLdsNumBytes(padA, ldsBlockSizePerPadA, padB, ldsBlockSizePerPadB) + if (ldsNumBytesAlignedA + ldsNumBytesAlignedB) > state["MaxLDS"]: + state["LocalReadVectorWidth"] //= 2 + + if state["enableGLTrA"] or state["enableGLTrB"]: + state["LocalReadVectorWidth"] = 8 + else: + if state["UseDotInstruction"]: + # dot2: LRVW should be equal to NumDotElements * InnerUnroll + if state["LocalReadVectorWidth"] not in [-1, state["NumDotElements"] * state["InnerUnroll"]]: + reject(state, printRejectionReason, "dot kernel requires LocalReadVectorWidth = NumDotElements(%u) * InnerUnroll(%u)" \ + % (state["NumDotElements"], state["InnerUnroll"])) + return + state["LocalReadVectorWidth"] = state["NumDotElements"] * state["InnerUnroll"] + else: # mac + if state["LocalReadVectorWidth"] == -1: + state["LocalReadVectorWidth"] = state["VectorWidthA"] + if state["LocalReadVectorWidth"] != state["VectorWidthA"] or \ + state["LocalReadVectorWidth"] != state["VectorWidthB"]: + reject(state, printRejectionReason, "LocalReadVectorWidth must equal VectorWidthA/B for MAC kernels") - if state["ProblemType"]["MXBlockA"]: - state["LocalReadVectorWidthMXSA"] = 1 # TODO: check if need to fomulization - if state["ProblemType"]["MXBlockB"]: - state["LocalReadVectorWidthMXSB"] = 1 # TODO: check if need to fomulization + + # Default LocalReadVectorWidth + if state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]: + calLRVWForMX() else: - if state["UseDotInstruction"]: - # dot2: LRVW should be equal to NumDotElements * InnerUnroll - if state["LocalReadVectorWidth"] not in [-1, state["NumDotElements"] * state["InnerUnroll"]]: - reject(state, printRejectionReason, "dot kernel requires LocalReadVectorWidth = NumDotElements(%u) * InnerUnroll(%u)" \ - % (state["NumDotElements"], state["InnerUnroll"])) - return - state["LocalReadVectorWidth"] = state["NumDotElements"] * state["InnerUnroll"] - state["LocalReadVectorWidthA"] = state["NumDotElements"] * state["InnerUnroll"] - state["LocalReadVectorWidthB"] = state["NumDotElements"] * state["InnerUnroll"] - else: - if state["LocalReadVectorWidth"] == -1: - state["LocalReadVectorWidth"] = state["VectorWidthA"] - if state["LocalReadVectorWidthA"] == -1: - state["LocalReadVectorWidthA"] = state["VectorWidthA"] - if state["LocalReadVectorWidthB"] == -1: - state["LocalReadVectorWidthB"] = state["VectorWidthB"] - if state["LocalReadVectorWidth"] != state["VectorWidthA"] or \ - state["LocalReadVectorWidth"] != state["VectorWidthB"]: - reject(state, printRejectionReason, "LocalReadVectorWidth must equal VectorWidthA/B for MAC kernels") + calLRVWForNonMX() + # We still need to set LocalReadVectorWidthA & LocalReadVectorWidthB + # because subsequent might be looking at them not LocalReadVectorWidth + state["LocalReadVectorWidthA"] = state["LocalReadVectorWidth"] + state["LocalReadVectorWidthB"] = state["LocalReadVectorWidth"] + + def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int: # with UnrollMajorLDS, GRVW need to less or equal than LRVW to have conflict free LDS read with padding. @@ -2436,13 +2505,6 @@ def calSwizzlePackK(state, tc): if not Solution.setGlobalReadVectorWidth(state, "Metadata", tvm, glvwMlimit, printRejectionReason): validDepthU = False - if validDepthU and state["KernelLanguage"] == "Assembly": - if isaInfoMap[isa].archCaps["HasEccHalf"]: - if state["ProblemType"]["MacDataTypeA"].numRegisters() == 0.5 and (not state["ProblemType"]["HighPrecisionAccumulate"]): - if state["GlobalReadVectorWidthA"] == 1 or state["GlobalReadVectorWidthB"] == 1: - reject(state, "HalfEcc requires HPA if glvw = 1") - break - if state["ProblemType"]["Sparse"] and state["DirectToVgprSparseMetadata"]: if state["VectorWidthA"] > 1 or state["VectorWidthB"] > 1 : reject(state, printRejectionReason, "Not implement DTVSM with VW>1") @@ -2885,7 +2947,7 @@ def calSwizzlePackK(state, tc): auto_LdsBlockSizePerPadB_for_mix = 0 if state["LdsBlockSizePerPadB"] == -1: auto_LdsBlockSizePerPadB_for_mix = 1 - state["LdsBlockSizePerPadA"], state["LdsBlockSizePerPadB"] = calcLdsBlockSizePerPad() + state["LdsBlockSizePerPadA"], state["LdsBlockSizePerPadB"] = calcLdsBlockSizePerPad(-1) # for MX datatypes, the lrvw argument is ignored if state["LdsBlockSizePerPadMetadata"] == -1: state["LdsBlockSizePerPadMetadata"] = state["LdsBlockSizePerPadA"] @@ -3100,7 +3162,7 @@ def subCheckLdsBlockSizePerPad(tc, idx): state["NoLdsWriteCode"] = False # calculate ldsPad - state["LdsPadA"], state["LdsPadB"], state["LdsPadMetadata"] = calcLdsPad(isaInfoMap) + state["LdsPadA"], state["LdsPadB"], state["LdsPadMetadata"] = calcLdsPad(state["LocalReadVectorWidth"], isaInfoMap) if state["GlobalReadVectorWidthA"] * state["ProblemType"]["MacDataTypeA"].numBytes() == 32 and state["LdsPadA"] == 16 // state["ProblemType"]["MacDataTypeA"].numBytes(): if auto_LdsBlockSizePerPadA_for_mix: @@ -3117,11 +3179,20 @@ def subCheckLdsBlockSizePerPad(tc, idx): if state["LdsPad%s"%tc] == 0: state["LdsBlockSizePerPad%s"%tc] = 0 + # Normalize lds block-size-per-pad fields to native Python int. + assert(int(state["LdsBlockSizePerPadA"]) == state["LdsBlockSizePerPadA"]) + assert(int(state["LdsBlockSizePerPadB"]) == state["LdsBlockSizePerPadB"]) + assert(int(state["LdsBlockSizePerPadMetadata"]) == state["LdsBlockSizePerPadMetadata"]) + state["LdsBlockSizePerPadA"] = int(state["LdsBlockSizePerPadA"]) + state["LdsBlockSizePerPadB"] = int(state["LdsBlockSizePerPadB"]) + state["LdsBlockSizePerPadMetadata"] = int(state["LdsBlockSizePerPadMetadata"]) + if (state["UnrollMajorLDSA"] or state["UnrollMajorLDSB"]) and (not state["EnableMatrixInstruction"]) and (not state["UseDotInstruction"]): reject(state, printRejectionReason, "UnrollMajorLDS Supports only in EnableMatrixInstruction=1 or dot2 kernel") ldsNumBytesA, ldsNumBytesAlignedA, ldsNumBytesB, ldsNumBytesAlignedB, ldsNumBytesMetadata, ldsNumBytesAlignedMetadata, \ - ldsNumBytesMXSA, ldsNumBytesAlignedMXSA, ldsNumBytesMXSB, ldsNumBytesAlignedMXSB = calcLdsNumBytes(state["LdsPadA"], state["LdsBlockSizePerPadA"], state["LdsPadB"], state["LdsBlockSizePerPadB"]) + ldsNumBytesMXSA, ldsNumBytesAlignedMXSA, ldsNumBytesMXSB, ldsNumBytesAlignedMXSB \ + = calcLdsNumBytes(state["LdsPadA"], state["LdsBlockSizePerPadA"], state["LdsPadB"], state["LdsBlockSizePerPadB"]) state["LdsOffsetA_Blk"] = 0 state["LdsOffsetB_Blk"] = 0 @@ -3144,14 +3215,18 @@ def subCheckLdsBlockSizePerPad(tc, idx): if state["PrefetchGlobalRead"]: offsetBlk = state["LdsOffsetB"] + state["LdsNumElementsAlignedB"] + + # Separate MX and non-MX, otherwise, will hit LDS overflow issue. # TODO: # Disable StoreSwapAddr to ensure LdsOffsetA_Blk is always a power of 2 # This is consistent with referenc implementation which doesn't have StoreSwapAddr - state["StoreSwapAddr"] = False - # Original logic (disabled): - # state["StoreSwapAddr"] = (state["PrefetchGlobalRead"] == 2) and \ - # (state["1LDSBuffer"] == 0) and \ - # (offsetBlk + int(2**(math.ceil(math.log(offsetBlk, 2)))) > state["MaxLDS"]) + if state["ProblemType"]["MXBlockA"]: + state["StoreSwapAddr"] = False + else: + # Original logic: + state["StoreSwapAddr"] = (state["PrefetchGlobalRead"] == 2) and \ + (state["1LDSBuffer"] == 0) and \ + (offsetBlk + int(2**(math.ceil(math.log(offsetBlk, 2)))) > state["MaxLDS"]) if offsetBlk > 0 and not state["StoreSwapAddr"]: # Rounds offsetBlk to a power of two to enable inlining {s,v}_xor constants for swapping offsets diff --git a/projects/hipblaslt/tensilelite/client/include/TensorDataManipulation.hpp b/projects/hipblaslt/tensilelite/client/include/TensorDataManipulation.hpp index ee93894b7729..cfe1d6640299 100644 --- a/projects/hipblaslt/tensilelite/client/include/TensorDataManipulation.hpp +++ b/projects/hipblaslt/tensilelite/client/include/TensorDataManipulation.hpp @@ -175,10 +175,10 @@ namespace Tensor return Tensor(shape, sizeof(T)); } - Tensor(const Shape shape, float elementSize) + Tensor(const Shape shape, size_t elementSize) : desc(shape) , elementSize(elementSize) - , data(new char[TensileLite::multiplyElementSize(desc.flattenSize(), elementSize)]) + , data(new char[elementSize * desc.flattenSize()]) { } @@ -241,7 +241,7 @@ namespace Tensor size_t getNumBytes() const { - return TensileLite::multiplyElementSize(getDesc().flattenSize(), getElementSize()); + return getDesc().flattenSize() * getElementSize(); } void reshape(const Shape& shape) @@ -255,7 +255,7 @@ namespace Tensor } private: - float elementSize{}; + size_t elementSize{}; TensorDesc desc; std::unique_ptr data; }; diff --git a/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp b/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp index 804f626476d1..a0bc69947df5 100644 --- a/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp +++ b/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp @@ -741,14 +741,14 @@ namespace TensileLite size_t totalElements, hipMemcpyKind kind) { + auto const info = DataTypeInfo::Get(descriptor.dataType()); HIP_CHECK_EXC( hipMemcpy(dst, src, - multiplyElementSize(totalElements, - DataTypeInfo::Get(descriptor.dataType()).elementSize), + totalElements * info.elementSize / info.packing, kind)); ptrdiff_t dPadding = totalElements - descriptor.totalAllocatedElements(); - dPadding = multiplyElementSize(dPadding, descriptor.elementBytes()); + dPadding *= descriptor.elementBytes(); void* dstOffset = (void*)((uint8_t*)dst + dPadding / 2); TensileLite::hip::CopyTensorVoid(dstOffset, src, descriptor, kind); return dstOffset; @@ -767,8 +767,7 @@ namespace TensileLite const size_t numElementsToCopy = (customPadding == -1) ? descriptor.totalAllocatedElements() : (descriptor.totalAllocatedElements() + customPadding); - uint8_t* dstOffset - = (uint8_t*)dst + multiplyElementSize(dPadding, descriptor.elementBytes()); + uint8_t* dstOffset = (uint8_t*)dst + dPadding * descriptor.elementBytes(); HIP_CHECK_EXC( hipMemcpy(dstOffset, src, diff --git a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp index a00222a9e573..c5ec3ac848d7 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp @@ -67,7 +67,7 @@ namespace rocisa */ std::string TypeAbbrev(rocisa::DataType d); - float GetElementSize(rocisa::DataType d); + size_t GetElementSize(rocisa::DataType d); std::ostream& operator<<(std::ostream& stream, rocisa::DataType const& t); std::istream& operator>>(std::istream& stream, rocisa::DataType& t); @@ -90,7 +90,7 @@ namespace TensileLite std::string name; std::string abbrev; - float elementSize; + size_t elementSize; size_t packing; size_t segmentSize; @@ -129,11 +129,13 @@ namespace TensileLite constexpr static rocisa::DataType Enum = T_Enum; /// Bytes of one element. May contain multiple segments. - constexpr static float ElementSize = float(sizeof(T)) / float(T_Packing); + constexpr static size_t ElementSize = sizeof(T); /// Segments per element. constexpr static size_t Packing = T_Packing; /// Bytes per segment. - constexpr static float SegmentSize = ElementSize / Packing; + /// TODO: this needs to be enhanced as the value would be + /// 0 for MX data type, FP4: ElementSize=1 byte, Packing=2. + constexpr static size_t SegmentSize = ElementSize / Packing; constexpr static bool IsComplex = T_IsComplex; constexpr static bool IsIntegral = T_IsIntegral; @@ -159,7 +161,7 @@ namespace TensileLite int T_Packing, bool T_IsComplex, bool T_IsIntegral> - constexpr float BaseTypeInfo::ElementSize; + constexpr size_t BaseTypeInfo::ElementSize; template - constexpr float BaseTypeInfo::SegmentSize; + constexpr size_t BaseTypeInfo::SegmentSize; template Float4x2, + // Float6 -> Float6x32. As a result, return elementSize is + // incorrect for MX data types because elementSize represents + // unsegmented size in bytes not segment size. + // + // To get element size (in bytes) for f4/f6/bf6, use + // + // auto const info = DataTypeInfo::Get(m_dataType); + // auto elementSize = info.elementSize / info.packing + // + // tensileLite returns sizeof(Float4x2), sizeof(Float6x32), + // sizeof(BFloat6x32) for rocisa::f4,f6,bf6. + // + assert(m_dataType != rocisa::DataType::Float6 && + m_dataType != rocisa::DataType::BFloat6 && + m_dataType != rocisa::DataType::Float4); return DataTypeInfo::Get(m_dataType).elementSize; } @@ -509,14 +531,14 @@ namespace TensileLite { coord[0] = 0; - auto const* localPtr = data + (desc.index(coord) / TypeInfo::Packing); + auto const* localPtr = data + desc.index(coord); if(sizes[0] > 0) stream << localPtr[0]; - for(coord[0] = TypeInfo::Packing; coord[0] < sizes[0]; coord[0]+=TypeInfo::Packing) + for(coord[0] = 1; coord[0] < sizes[0]; coord[0]++) { - stream << " " << localPtr[coord[0] * stride0 / TypeInfo::Packing]; + stream << " " << localPtr[coord[0] * stride0]; } stream << std::endl; @@ -524,7 +546,7 @@ namespace TensileLite if(decorated) { - stream << "]" << std::endl; + stream << std::endl << "]" << std::endl; } } } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/hip/HipUtils.hpp b/projects/hipblaslt/tensilelite/include/Tensile/hip/HipUtils.hpp index cd8bd2dbe645..5c5ed02a5eee 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/hip/HipUtils.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/hip/HipUtils.hpp @@ -135,7 +135,7 @@ namespace TensileLite size_t maxStride = *std::max_element(strides.begin(), strides.begin() + contiguousDimensions); - size_t copyBytes = multiplyElementSize(maxStride * sizes.at(contiguousDimensions - 1), desc.elementBytes()); + size_t copyBytes = maxStride * sizes.at(contiguousDimensions - 1) * desc.elementBytes(); for(size_t idx = 0; idx < copyCount; idx++) { @@ -146,7 +146,7 @@ namespace TensileLite sizes.end()); auto beginOffset = desc.index(coord); - size_t bytesOffset = multiplyElementSize(beginOffset, desc.elementBytes()); + size_t bytesOffset = beginOffset * desc.elementBytes(); uint8_t* dstBytes = (uint8_t*)dst + bytesOffset; uint8_t* srcBytes = (uint8_t*)dst + bytesOffset; diff --git a/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp b/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp index 5756428420e9..38a23df6d186 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp @@ -1212,10 +1212,16 @@ namespace TensileLite gflop += 2 * cSize * 1e-9; // Include (+ beta * C) in gflops cSize *= 2; // Include read C and write D in gbytes } + // TODO: for MX data types, the size is smaller than a byte + // so we need to use (elementSize/packing) to derive the actual + // byte size of a segment. + auto infoA = DataTypeInfo::Get(a().dataType()); + auto infoB = DataTypeInfo::Get(b().dataType()); + auto infoC = DataTypeInfo::Get(c().dataType()); double gbyte - = (multiplyElementSize(aSize, a().elementBytes()) + - multiplyElementSize(bSize, b().elementBytes()) + - multiplyElementSize(cSize, c().elementBytes())) + = ((aSize * infoA.elementSize / infoA.packing) + + (bSize * infoB.elementSize / infoB.packing) + + (cSize * infoC.elementSize / infoC.packing)) * 1e-9; m_arithmeticIntensity = gflop / gbyte; diff --git a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp index 493f7ce8a6f6..03b496267637 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp @@ -2951,26 +2951,35 @@ namespace TensileLite auto cInfo = DataTypeInfo::Get(problemType.cType); auto dInfo = DataTypeInfo::Get(problemType.dType); - spm.memReadBytesA = multiplyElementSize((NumBatches * M * N * K) / MT1, aInfo.elementSize); - spm.memReadBytesB = multiplyElementSize((NumBatches * M * N * K) / MT0, bInfo.elementSize); - spm.memReadBytesC = multiplyElementSize((NumBatches * M * N) * betaReads, cInfo.elementSize); + // TODO: For MX data types, their size is smaller than a bytes, so we can't use + // segmentSize which would be 0. This needs to be enhanaced. + assert( ((NumBatches * M * N * K) * aInfo.elementSize / aInfo.packing) % MT1 == 0); + assert( ((NumBatches * M * N * K) * bInfo.elementSize / bInfo.packing) % MT0 == 0); + spm.memReadBytesA = (NumBatches * M * N * K) * aInfo.elementSize / aInfo.packing / MT1; + spm.memReadBytesB = (NumBatches * M * N * K) * bInfo.elementSize / bInfo.packing / MT0; + spm.memReadBytesC = (NumBatches * M * N) * betaReads * cInfo.elementSize / cInfo.packing; if(GlobalSplitU == 1) - spm.memWriteBytesD = multiplyElementSize((NumBatches * M * N) * (1 + betaWrites), dInfo.elementSize); + { + // Use segmentSize as D currently does not support MX data types. + spm.memWriteBytesD = (NumBatches * M * N) * (1 + betaWrites) * dInfo.elementSize; + } else { bool hardwareAtomic = false; // TODO-model double atomicOperations = hardwareAtomic ? 2 : 3; // read-mod-write or cas //TODO-model double atomicCollisions = 1.0; // TODO-could be based on K, GSU - spm.memWriteBytesD = multiplyElementSize((NumBatches * M * N) + // Use segmentSize as D currently does not support MX data types. + spm.memWriteBytesD = (NumBatches * M * N) * (betaWrites + atomicOperations * atomicCollisions) - , dInfo.elementSize); + * dInfo.segmentSize; } spm.memReadBytes = spm.memReadBytesA + spm.memReadBytesB + spm.memReadBytesC; - spm.memGlobalReads = divideElementSize(spm.memReadBytesA, aInfo.elementSize) - + divideElementSize(spm.memReadBytesB, bInfo.elementSize) - + divideElementSize(spm.memReadBytesC, cInfo.elementSize); - spm.memGlobalWrites = divideElementSize(spm.memWriteBytesD, dInfo.elementSize); + + spm.memGlobalReads = spm.memReadBytesA * aInfo.packing / aInfo.elementSize + + spm.memReadBytesB * bInfo.packing / bInfo.elementSize + + spm.memReadBytesC * cInfo.packing / cInfo.elementSize; + spm.memGlobalWrites = spm.memWriteBytesD / dInfo.segmentSize; return spm; } @@ -3115,7 +3124,9 @@ namespace TensileLite if(problemType.outputAmaxD) { auto numWGS = getNumWorkGroups(problem, sizeMapping); - size += multiplyElementSize(numWGS, problem.amaxd().elementBytes()); + // Use elementBytes instead of segmentSize beceause D currently + // does not support sub-byte data types + size += numWGS * problem.amaxd().elementBytes(); } return size; diff --git a/projects/hipblaslt/tensilelite/src/DataTypes.cpp b/projects/hipblaslt/tensilelite/src/DataTypes.cpp index 7185131b337c..02f06e9a4438 100644 --- a/projects/hipblaslt/tensilelite/src/DataTypes.cpp +++ b/projects/hipblaslt/tensilelite/src/DataTypes.cpp @@ -87,7 +87,7 @@ namespace rocisa return "Invalid"; } - float GetElementSize(rocisa::DataType d) + size_t GetElementSize(rocisa::DataType d) { switch(d) { diff --git a/projects/hipblaslt/tensilelite/tests/RangeLibrary_test.cpp b/projects/hipblaslt/tensilelite/tests/RangeLibrary_test.cpp index 0619707b42ba..652fb6d77b95 100644 --- a/projects/hipblaslt/tensilelite/tests/RangeLibrary_test.cpp +++ b/projects/hipblaslt/tensilelite/tests/RangeLibrary_test.cpp @@ -142,7 +142,8 @@ TEST_P(RangeLibraryTest, SpecificSizes) M, // ldd M*N, // strided 2.0); // beta - problem.setComputeInputType(rocisa::DataType::BFloat16); + problem.setComputeInputTypeA(rocisa::DataType::BFloat16); + problem.setComputeInputTypeB(rocisa::DataType::BFloat16); problem.setHighPrecisionAccumulate(true); problem.setWorkspaceSize(120324096);