Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
464ede1
Apply workaround to separate non-MX and MX by Tom Tang
NineKa Feb 5, 2026
3aa8800
Add temporary workarounds to allow Tensile to build. These workaround…
amd-chunxlin Feb 5, 2026
032906b
use setComputeInputTypeA instead of setComputeInputType
NineKa Feb 10, 2026
d4d1800
Fix elementSize issue (not sure if this is correct) and refactor code…
amd-chunxlin Feb 11, 2026
c7d5d40
remove manually set threads count for debugging
NineKa Feb 19, 2026
21d388d
revert changes in Tensile/Common/Parallel.py
NineKa Feb 19, 2026
223efa9
remove debug comments in Tensile/SolutionStructs/Solution.py
NineKa Feb 19, 2026
5cf7c51
normalize lds block-size-per-pad into int
NineKa Feb 19, 2026
3c1f207
format and remove commented lines of code
NineKa Feb 19, 2026
eedecd0
cleanup tensilelite/Tensile/KernelWriterAssembly.py
NineKa Feb 19, 2026
cff8a74
move cast to int to loadRangePerThread
NineKa Feb 19, 2026
a455dcb
Merge remote-tracking branch 'origin/gfx950_mx_rebase' into users/hon…
NineKa Feb 19, 2026
35b8195
Merge calcLdsNumBytesForNonMX and calcLdsNumBytesForMX
amd-chunxlin Feb 20, 2026
c946745
port changes for calcualting LDS padding for mx datatypes from #4683
NineKa Feb 20, 2026
681bb0b
Merge remote-tracking branch 'origin/gfx950_mx_rebase' into users/hon…
NineKa Feb 23, 2026
3632161
merge calcLdsForMX and calcLdsforNonMX
NineKa Feb 23, 2026
a05365e
Fix indentation issue for calculating LDS padding for B in calcLdsPad
NineKa Feb 23, 2026
7ea8fe3
reject larger than b192 loads for mx datatypes and b128 loads for non…
NineKa Feb 23, 2026
d2d9522
merge LdsBlockSizePerPadForMX and LdsBlockSizePerPadForNonMX
NineKa Feb 23, 2026
9bbc944
cleanup whitespaces
NineKa Feb 23, 2026
5ab315f
reject read larger than b128 for both mx and non-mx datatype
NineKa Feb 23, 2026
b7c8600
fix bad invocation for calcLdsPad and calcLdsBlockSizePerPad
NineKa Feb 23, 2026
3ef482e
fix bad invocation for calcLdsBlockSizePerPad for mx datatype
NineKa Feb 23, 2026
8b476e4
fix various issue in comments
NineKa Feb 24, 2026
f720d64
merge invocations of calcLdsNumBytes
NineKa Feb 24, 2026
f06129d
revert changes for state["DirectToLdsA"] and state["DirectToLdsB"]
NineKa Feb 24, 2026
75f268f
use numBytesA/B instead of state["ProblemType"]["MacDataTypeA/B"]
NineKa Feb 24, 2026
68028c2
enforce strideK to be integer
NineKa Feb 24, 2026
ea2fe94
merge check if auto LRVW logic into a single function
NineKa Feb 25, 2026
f976824
use DirectToLdsA/B instead of DirectToLds
NineKa Feb 25, 2026
af04bee
wlrA/B seems already been a ratio for MIInputPerThread
NineKa Feb 25, 2026
fb531be
expecting MFMA for MX datatypes
NineKa Feb 25, 2026
110ce44
consider lds consumption of A + B + MXA + MXB + Metadata
NineKa Feb 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi
strideTile = 1 # DTV case. Actual stride will be applied later.

strideK = offsetK if umlds else (mt + LdsPad) * offsetK

# Workaround: StrideK might be a float value and causes function
# signature error later (the function expected an int)
if not isinstance(strideK, int):
Comment thread
NineKa marked this conversation as resolved.
Outdated
strideK = int(strideK) if strideK.is_integer() else strideK

if enableLDSTr:
if kernel["UseGeneralizedNLCOne%s"%tc] and perpStride > 1:
strideK = 8
Expand Down
6 changes: 3 additions & 3 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ class StateValues:
bpr: int = 4 # all registers are 32bit
# default setup
# AB=DataType / Cexternal=DestDataType / Cinternal=Accumulation (MAC or MFMA)
bpeA: float = field(init=False)
bpeB: float = field(init=False)
bpeA: float = field(init=False) # this is a float because of sub-byte data types (f6, f4)
bpeB: float = field(init=False) # this is a float because of sub-byte data types (f6, f4)
bpeE: int = field(init=False)
# Cexternal = the "current" kernel output type,
# - default: the "current" kernel is a non-GSU-kernel,
Expand Down Expand Up @@ -4101,7 +4101,7 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB):
"""

if kernel["EnableMatrixInstruction"] and kernel["LocalReadVectorWidthA"] >= kernel["MIInputPerThread"]:
WLR = max(kernel["LocalReadVectorWidthA"]//kernel["MIInputPerThread"], 1)
WLR = int(max(kernel["LocalReadVectorWidthA"]//kernel["MIInputPerThread"], 1))
self.states.numItersPLR = kernel["PrefetchLocalRead"]%(kernel["LoopIters"]//WLR)
else:
self.states.numItersPLR = kernel["PrefetchLocalRead"]%(kernel["LoopIters"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12794,7 +12794,7 @@ def chooseGlobalRead(self, useBuffer, bpl, destVgpr, \

if useBuffer:
rv = Module("Global Read")
mubuf = MUBUFModifiers(offen=True, offset12=offset, glc=glc, slc=slc, nt=nt, lds=lds)
mubuf = MUBUFModifiers(offen=True, offset12=int(offset), glc=glc, slc=slc, nt=nt, lds=lds)

# Nested buffer load implementation function for easy branching for soffset
def bufferLoadImpl(soffset):
Expand Down Expand Up @@ -12831,7 +12831,7 @@ def bufferLoadImpl(soffset):
dst = None if lds else vgpr(destVgpr, 4)
rv.add(BufferLoadB128(dst=dst, vaddr=addr0, saddr=addr1, \
soffset=soffset, mubuf=mubuf, comment=comment))
mubuf2 = MUBUFModifiers(offen=True, offset12=offset+16, glc=glc, slc=slc, nt=nt, lds=lds)
mubuf2 = MUBUFModifiers(offen=True, offset12=int(offset+16), glc=glc, slc=slc, nt=nt, lds=lds)
if isinstance(destVgpr, str):
dst2 = destVgpr + "+" + str(int(4))
elif isinstance(destVgpr, int):
Expand Down Expand Up @@ -12865,7 +12865,6 @@ def bufferLoadImpl(soffset):
dst = vgpr(destVgpr, rpv//4)
rv.add(BufferLoadB128(dst=dst, vaddr=addr0, saddr=addr1, \
soffset=soffset, mubuf=mubuf, comment=comment))

mubuf2 = MUBUFModifiers(offen=True, offset12=int(offset + bpl/4), glc=glc, slc=slc, nt=nt, lds=lds)
dst2 = destVgpr + "+" + str(int(rpv//4)) if isinstance(destVgpr, str) else int(destVgpr + int(rpv//4))

Expand Down
414 changes: 269 additions & 145 deletions projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ namespace Tensor
return Tensor(shape, sizeof(T));
}

Tensor(const Shape shape, float elementSize)
Tensor(const Shape shape, size_t elementSize)
: desc(shape)
, elementSize(elementSize)
, data(new char[TensileLite::multiplyElementSize(desc.flattenSize(), elementSize)])
, data(new char[elementSize * desc.flattenSize()])
{
}

Expand Down Expand Up @@ -241,7 +241,7 @@ namespace Tensor

size_t getNumBytes() const
{
return TensileLite::multiplyElementSize(getDesc().flattenSize(), getElementSize());
return getDesc().flattenSize() * getElementSize();
}

void reshape(const Shape& shape)
Expand All @@ -255,7 +255,7 @@ namespace Tensor
}

private:
float elementSize{};
size_t elementSize{};
TensorDesc desc;
std::unique_ptr<char[]> data;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -741,14 +741,14 @@ namespace TensileLite
size_t totalElements,
hipMemcpyKind kind)
{
auto const info = DataTypeInfo::Get(descriptor.dataType());
HIP_CHECK_EXC(
hipMemcpy(dst,
src,
multiplyElementSize(totalElements,
DataTypeInfo::Get(descriptor.dataType()).elementSize),
totalElements * info.elementSize / info.packing,
kind));
ptrdiff_t dPadding = totalElements - descriptor.totalAllocatedElements();
dPadding = multiplyElementSize(dPadding, descriptor.elementBytes());
dPadding *= descriptor.elementBytes();
void* dstOffset = (void*)((uint8_t*)dst + dPadding / 2);
TensileLite::hip::CopyTensorVoid(dstOffset, src, descriptor, kind);
return dstOffset;
Expand All @@ -767,8 +767,7 @@ namespace TensileLite
const size_t numElementsToCopy
= (customPadding == -1) ? descriptor.totalAllocatedElements()
: (descriptor.totalAllocatedElements() + customPadding);
uint8_t* dstOffset
= (uint8_t*)dst + multiplyElementSize(dPadding, descriptor.elementBytes());
uint8_t* dstOffset = (uint8_t*)dst + dPadding * descriptor.elementBytes();
HIP_CHECK_EXC(
hipMemcpy(dstOffset,
src,
Expand Down
14 changes: 8 additions & 6 deletions projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace rocisa
*/

std::string TypeAbbrev(rocisa::DataType d);
float GetElementSize(rocisa::DataType d);
size_t GetElementSize(rocisa::DataType d);
std::ostream& operator<<(std::ostream& stream, rocisa::DataType const& t);
std::istream& operator>>(std::istream& stream, rocisa::DataType& t);

Expand All @@ -90,7 +90,7 @@ namespace TensileLite
std::string name;
std::string abbrev;

float elementSize;
size_t elementSize;
size_t packing;
size_t segmentSize;

Expand Down Expand Up @@ -129,11 +129,13 @@ namespace TensileLite
constexpr static rocisa::DataType Enum = T_Enum;

/// Bytes of one element. May contain multiple segments.
constexpr static float ElementSize = float(sizeof(T)) / float(T_Packing);
constexpr static size_t ElementSize = sizeof(T);
/// Segments per element.
constexpr static size_t Packing = T_Packing;
/// Bytes per segment.
constexpr static float SegmentSize = ElementSize / Packing;
/// TODO: this needs to be enhanced as the value would be
/// 0 for MX data type, FP4: ElementSize=1 byte, Packing=2.
constexpr static size_t SegmentSize = ElementSize / Packing;

constexpr static bool IsComplex = T_IsComplex;
constexpr static bool IsIntegral = T_IsIntegral;
Expand All @@ -159,7 +161,7 @@ namespace TensileLite
int T_Packing,
bool T_IsComplex,
bool T_IsIntegral>
constexpr float BaseTypeInfo<T, T_Enum, T_Packing, T_IsComplex, T_IsIntegral>::ElementSize;
constexpr size_t BaseTypeInfo<T, T_Enum, T_Packing, T_IsComplex, T_IsIntegral>::ElementSize;
template <typename T,
rocisa::DataType T_Enum,
int T_Packing,
Expand All @@ -171,7 +173,7 @@ namespace TensileLite
int T_Packing,
bool T_IsComplex,
bool T_IsIntegral>
constexpr float BaseTypeInfo<T, T_Enum, T_Packing, T_IsComplex, T_IsIntegral>::SegmentSize;
constexpr size_t BaseTypeInfo<T, T_Enum, T_Packing, T_IsComplex, T_IsIntegral>::SegmentSize;

template <typename T,
rocisa::DataType T_Enum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,33 @@ namespace TensileLite
}
size_t totalAllocatedBytes() const
{
return multiplyElementSize(totalAllocatedElements(), elementBytes());
}

float elementBytes() const
{
// Cannot use elementSize directly as elementSize is
// packed size for MX data types.
auto const info = DataTypeInfo::Get(m_dataType);
assert(totalAllocatedElements() * info.elementSize % info.packing == 0);
return totalAllocatedElements() * info.elementSize / info.packing;
}

size_t elementBytes() const
{
// TODO:Need to enhance this to support segment type in tensileLite.
// tensileLite currently maps MX data types to unsegmented
// types in DataTypeInfo, i.e., Float4 -> Float4x2,
// Float6 -> Float6x32. As a result, return elementSize is
// incorrect for MX data types because elementSize represents
// unsegmented size in bytes not segment size.
//
// To get element size (in bytes) for f4/f6/bf6, use
//
// auto const info = DataTypeInfo::Get(m_dataType);
// auto elementSize = info.elementSize / info.packing
//
// tensileLite returns sizeof(Float4x2), sizeof(Float6x32),
// sizeof(BFloat6x32) for rocisa::f4,f6,bf6.
//
assert(m_dataType != rocisa::DataType::Float6 &&
m_dataType != rocisa::DataType::BFloat6 &&
m_dataType != rocisa::DataType::Float4);
return DataTypeInfo::Get(m_dataType).elementSize;
}

Expand Down Expand Up @@ -509,22 +531,22 @@ namespace TensileLite
{
coord[0] = 0;

auto const* localPtr = data + (desc.index(coord) / TypeInfo<T>::Packing);
auto const* localPtr = data + desc.index(coord);

if(sizes[0] > 0)
stream << localPtr[0];

for(coord[0] = TypeInfo<T>::Packing; coord[0] < sizes[0]; coord[0]+=TypeInfo<T>::Packing)
for(coord[0] = 1; coord[0] < sizes[0]; coord[0]++)
{
stream << " " << localPtr[coord[0] * stride0 / TypeInfo<T>::Packing];
stream << " " << localPtr[coord[0] * stride0];
}

stream << std::endl;
}

if(decorated)
{
stream << "]" << std::endl;
stream << std::endl << "]" << std::endl;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ namespace TensileLite

size_t maxStride
= *std::max_element(strides.begin(), strides.begin() + contiguousDimensions);
size_t copyBytes = multiplyElementSize(maxStride * sizes.at(contiguousDimensions - 1), desc.elementBytes());
size_t copyBytes = maxStride * sizes.at(contiguousDimensions - 1) * desc.elementBytes();

for(size_t idx = 0; idx < copyCount; idx++)
{
Expand All @@ -146,7 +146,7 @@ namespace TensileLite
sizes.end());

auto beginOffset = desc.index(coord);
size_t bytesOffset = multiplyElementSize(beginOffset, desc.elementBytes());
size_t bytesOffset = beginOffset * desc.elementBytes();
uint8_t* dstBytes = (uint8_t*)dst + bytesOffset;
uint8_t* srcBytes = (uint8_t*)dst + bytesOffset;

Expand Down
12 changes: 9 additions & 3 deletions projects/hipblaslt/tensilelite/src/ContractionProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1212,10 +1212,16 @@ namespace TensileLite
gflop += 2 * cSize * 1e-9; // Include (+ beta * C) in gflops
cSize *= 2; // Include read C and write D in gbytes
}
// TODO: for MX data types, the size is smaller than a byte
// so we need to use (elementSize/packing) to derive the actual
// byte size of a segment.
auto infoA = DataTypeInfo::Get(a().dataType());
auto infoB = DataTypeInfo::Get(b().dataType());
auto infoC = DataTypeInfo::Get(c().dataType());
double gbyte
= (multiplyElementSize(aSize, a().elementBytes()) +
multiplyElementSize(bSize, b().elementBytes()) +
multiplyElementSize(cSize, c().elementBytes()))
= ((aSize * infoA.elementSize / infoA.packing) +
(bSize * infoB.elementSize / infoB.packing) +
(cSize * infoC.elementSize / infoC.packing))
* 1e-9;

m_arithmeticIntensity = gflop / gbyte;
Expand Down
33 changes: 22 additions & 11 deletions projects/hipblaslt/tensilelite/src/ContractionSolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2951,26 +2951,35 @@ namespace TensileLite
auto cInfo = DataTypeInfo::Get(problemType.cType);
auto dInfo = DataTypeInfo::Get(problemType.dType);

spm.memReadBytesA = multiplyElementSize((NumBatches * M * N * K) / MT1, aInfo.elementSize);
spm.memReadBytesB = multiplyElementSize((NumBatches * M * N * K) / MT0, bInfo.elementSize);
spm.memReadBytesC = multiplyElementSize((NumBatches * M * N) * betaReads, cInfo.elementSize);
// TODO: For MX data types, their size is smaller than a bytes, so we can't use
// segmentSize which would be 0. This needs to be enhanaced.
assert( ((NumBatches * M * N * K) * aInfo.elementSize / aInfo.packing) % MT1 == 0);
assert( ((NumBatches * M * N * K) * bInfo.elementSize / bInfo.packing) % MT0 == 0);
spm.memReadBytesA = (NumBatches * M * N * K) * aInfo.elementSize / aInfo.packing / MT1;
spm.memReadBytesB = (NumBatches * M * N * K) * bInfo.elementSize / bInfo.packing / MT0;
spm.memReadBytesC = (NumBatches * M * N) * betaReads * cInfo.elementSize / cInfo.packing;

if(GlobalSplitU == 1)
spm.memWriteBytesD = multiplyElementSize((NumBatches * M * N) * (1 + betaWrites), dInfo.elementSize);
{
// Use segmentSize as D currently does not support MX data types.
spm.memWriteBytesD = (NumBatches * M * N) * (1 + betaWrites) * dInfo.elementSize;
}
else
{
bool hardwareAtomic = false; // TODO-model
double atomicOperations = hardwareAtomic ? 2 : 3; // read-mod-write or cas //TODO-model
double atomicCollisions = 1.0; // TODO-could be based on K, GSU
spm.memWriteBytesD = multiplyElementSize((NumBatches * M * N)
// Use segmentSize as D currently does not support MX data types.
spm.memWriteBytesD = (NumBatches * M * N)
* (betaWrites + atomicOperations * atomicCollisions)
, dInfo.elementSize);
* dInfo.segmentSize;
}
spm.memReadBytes = spm.memReadBytesA + spm.memReadBytesB + spm.memReadBytesC;
spm.memGlobalReads = divideElementSize(spm.memReadBytesA, aInfo.elementSize)
+ divideElementSize(spm.memReadBytesB, bInfo.elementSize)
+ divideElementSize(spm.memReadBytesC, cInfo.elementSize);
spm.memGlobalWrites = divideElementSize(spm.memWriteBytesD, dInfo.elementSize);

spm.memGlobalReads = spm.memReadBytesA * aInfo.packing / aInfo.elementSize
+ spm.memReadBytesB * bInfo.packing / bInfo.elementSize
+ spm.memReadBytesC * cInfo.packing / cInfo.elementSize;
spm.memGlobalWrites = spm.memWriteBytesD / dInfo.segmentSize;

return spm;
}
Expand Down Expand Up @@ -3115,7 +3124,9 @@ namespace TensileLite
if(problemType.outputAmaxD)
{
auto numWGS = getNumWorkGroups(problem, sizeMapping);
size += multiplyElementSize(numWGS, problem.amaxd().elementBytes());
// Use elementBytes instead of segmentSize beceause D currently
// does not support sub-byte data types
size += numWGS * problem.amaxd().elementBytes();
}

return size;
Expand Down
2 changes: 1 addition & 1 deletion projects/hipblaslt/tensilelite/src/DataTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ namespace rocisa
return "Invalid";
}

float GetElementSize(rocisa::DataType d)
size_t GetElementSize(rocisa::DataType d)
{
switch(d)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ TEST_P(RangeLibraryTest, SpecificSizes)
M, // ldd
M*N, // strided
2.0); // beta
problem.setComputeInputType(rocisa::DataType::BFloat16);
problem.setComputeInputTypeA(rocisa::DataType::BFloat16);
problem.setComputeInputTypeB(rocisa::DataType::BFloat16);
problem.setHighPrecisionAccumulate(true);
problem.setWorkspaceSize(120324096);

Expand Down
Loading