Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1f508c9
Add shuffle
archana-ramalingam Feb 24, 2026
a6608d0
Correct hipblaslt dependency
archana-ramalingam Feb 24, 2026
56298a7
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Feb 24, 2026
569114f
helper functions for MXFP4 tensor checks & update comments
archana-ramalingam Feb 25, 2026
c82599f
Merge branch 'users/aramalin/shuffle_scale_tensile' of https://github…
archana-ramalingam Feb 25, 2026
3778440
Update comments
archana-ramalingam Feb 25, 2026
229b1d8
Remove redundant static
archana-ramalingam Feb 26, 2026
65ba9d6
Fix tests
archana-ramalingam Feb 26, 2026
13dc00f
Make shuffle optional
archana-ramalingam Feb 26, 2026
495f0ae
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Feb 26, 2026
623ad5a
Fix failing tests
archana-ramalingam Feb 26, 2026
460884c
Merge branch 'users/aramalin/shuffle_scale_tensile' of https://github…
archana-ramalingam Feb 26, 2026
b4975ac
Revert "Fix failing tests"
archana-ramalingam Feb 26, 2026
3d0ad47
Revert "Fix tests"
archana-ramalingam Feb 26, 2026
27e41e6
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Feb 27, 2026
db25924
Add optional flag
archana-ramalingam Feb 28, 2026
4ddb182
Merge branch 'users/aramalin/shuffle_scale_tensile' of https://github…
archana-ramalingam Feb 28, 2026
f49ff93
Shorten comments
archana-ramalingam Feb 28, 2026
bf0cff1
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Feb 28, 2026
9bbeef1
Added mxDataGenerator once before both tensilelite and clients
archana-ramalingam Feb 28, 2026
c4d37d7
Merge branch 'users/aramalin/shuffle_scale_tensile' of https://github…
archana-ramalingam Feb 28, 2026
5d8c217
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Mar 2, 2026
0690204
Add flag as tensile parameter
archana-ramalingam Mar 3, 2026
c74ff94
Revert "Added mxDataGenerator once before both tensilelite and clients"
archana-ramalingam Mar 3, 2026
f741399
Merge branch 'users/aramalin/shuffle_scale_tensile' of https://github…
archana-ramalingam Mar 3, 2026
8532609
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Mar 4, 2026
8147265
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Mar 5, 2026
528c20b
Update swizzle flag
archana-ramalingam Mar 6, 2026
f940a8d
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Mar 6, 2026
5dadf0e
Merge branch 'gfx950_mx_rebase' into users/aramalin/shuffle_scale_ten…
archana-ramalingam Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions projects/hipblaslt/tensilelite/Tensile/ClientWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def runNewClient(scriptPath, clientParametersPath, cxxCompiler: str, cCompiler:
iniFile = "--config-file={}".format(clientParametersPath)
args = [clientExe, iniFile]

# Add MX scale format if set
if globalParameters["MXScaleFormat"]:
args.append("--mx-scale-format={}".format(globalParameters["MXScaleFormat"]))

try:
subprocess.run(args, check=True)
except (subprocess.CalledProcessError, OSError) as e:
Expand Down Expand Up @@ -324,8 +328,9 @@ def writeRunScript(path, forBenchmark, enableTileSelection, cxxCompiler: str, cC

clientExe = getClientExecutablePath()
timingFlag = " --timing-instrumentation" if globalParameters["TimingInstrumentation"] else ""
mxScaleFormatFlag = " --mx-scale-format={}".format(globalParameters["MXScaleFormat"]) if globalParameters["MXScaleFormat"] else ""
for configFile in configPaths:
runScriptFile.write("{} --config-file {}{}\n".format(clientExe, configFile, timingFlag))
runScriptFile.write("{} --config-file {}{}{}\n".format(clientExe, configFile, timingFlag, mxScaleFormatFlag))
runScriptFile.write("ERR2=$?\n\n")

runScriptFile.write("""
Expand All @@ -347,8 +352,9 @@ def writeRunScript(path, forBenchmark, enableTileSelection, cxxCompiler: str, cC
runScriptFile.write("%s -d 0 --resetclocks\n" % globalParameters["ROCmSMIPath"])
runScriptFile.write("%s -d 0 --setfan 50\n" % globalParameters["ROCmSMIPath"])
else:
mxScaleFormatFlag = " --mx-scale-format={}".format(globalParameters["MXScaleFormat"]) if globalParameters["MXScaleFormat"] else ""
for configFile in configPaths:
runScriptFile.write("{} --config-file {} --best-solution 1\n".format(getClientExecutablePath(), configFile))
runScriptFile.write("{} --config-file {} --best-solution 1{}\n".format(getClientExecutablePath(), configFile, mxScaleFormatFlag))

if os.name != "nt":
runScriptFile.write("exit $ERR\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
)
globalParameters["LogicFormat"] = "yaml" # set library backend (yaml, or json)
globalParameters["LibraryFormat"] = "yaml" # set library backend (yaml, or msgpack)
globalParameters["MXScaleFormat"] = 0 # MX scale data format (0=none, 1=pre-swizzle for GPU kernel layout)

# True/False: CSV will/won't export WinnerGFlops, WinnerTimeUS, WinnerIdx, WinnerName.
# TODO - if no side-effect, we can set default to True. This can make analyzing "LibraryLogic" (AddFromCSV) faster
Expand Down
5 changes: 5 additions & 0 deletions projects/hipblaslt/tensilelite/Tensile/Tensile.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def splitExtraParameters(par):
argParser.add_argument("--client-lock", default=None)
argParser.add_argument("--prebuilt-client", default=str(TENSILE_CLIENT_PATH),
type=os.path.abspath, help="Specify the full path to a pre-built tensilelite-client executable")
argParser.add_argument("--mx-scale-format", dest="MXScaleFormat", type=int, default=0,
help="MX scale data format (0=none, 1=pre-swizzle for GPU kernel layout)")

argParser.add_argument("--global-parameters", nargs="+", type=splitExtraParameters, default=[])

Expand All @@ -246,6 +248,9 @@ def argUpdatedGlobalParameters(args):
rv["ClientExecutionLockPath"] = args.client_lock
if args.prebuilt_client:
rv["PrebuiltClient"] = args.prebuilt_client
if args.MXScaleFormat:
print1("# Command-line override: MXScaleFormat")
rv["MXScaleFormat"] = args.MXScaleFormat

for key, value in args.global_parameters:
rv[key] = value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ namespace TensileLite
{
namespace Client
{
inline bool isMXFP4Tensor(const TensorDescriptor& tensor, size_t mxBlock)
{
return tensor.dataType() == rocisa::DataType::Float4 && mxBlock > 0;
}

inline bool isMXFP4Problem(const ContractionProblemGemm& problem)
{
return isMXFP4Tensor(problem.a(), problem.mxBlockA())
|| isMXFP4Tensor(problem.b(), problem.mxBlockB());
}

// Problem-indept. from 0~7, and 16, and 23~26 (fixed values for every problem)
// And problem-dept. from 8~15 (values depend on problem)
// RandomNegPosLimited: integer -128~128. fp -1.0~1.0
Expand Down Expand Up @@ -832,9 +843,35 @@ namespace TensileLite
}
virtual void preBenchmarkRun() override {}
virtual void postBenchmarkRun() override {}
virtual void preProblem(ContractionProblem* const problem) override {}
virtual void preProblem(ContractionProblem* const problem) override
{
m_currentGemmProblem
= dynamic_cast<ContractionProblemGemm const*>(problem);
}
virtual void postProblem() override {}
virtual void preSolution(ContractionSolution* const solution) override {}
virtual void preSolution(ContractionSolution* const solution) override
{
m_currentSolution = solution;
// Re-init MX scale with preSwizzle now that solution (useScaleAB) is available
if(m_currentSolution != nullptr
&& !m_currentSolution->problemType.useScaleAB.empty()
&& m_currentGemmProblem != nullptr
&& !m_gpuPtrs.empty())
{
bool isMXFP4 = isMXFP4Problem(*m_currentGemmProblem);
if(isMXFP4)
{
initializeMXDataForFP4(*m_currentGemmProblem);
copyValidToGPUBuffer(*m_currentGemmProblem);
copyInputs(m_gpuPtrs,
m_gpuBatchPtrs,
m_maxElements,
m_groupedOffsets,
*m_currentGemmProblem,
hipMemcpyDeviceToDevice);
}
}
}
virtual void postSolution() override {}
virtual bool needMoreRunsInSolution() const override
{
Expand Down Expand Up @@ -1044,6 +1081,11 @@ namespace TensileLite
int64_t m_rotatingBuffer = 0;
std::shared_ptr<RotatingMemory> m_rm;
int32_t m_rotatingMode = 0;

ContractionSolution const* m_currentSolution = nullptr;
ContractionProblemGemm const* m_currentGemmProblem = nullptr;

int m_mxScaleFormat = 0;
};

template <>
Expand Down
1 change: 1 addition & 0 deletions projects/hipblaslt/tensilelite/client/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ namespace TensileLite
("swizzle-tensor-b", po::value<bool>()->default_value(false), "Swizzle input tensor B.")
("mx-block-a", po::value<int>()->default_value(0), "block of mx datatype input matrix A")
("mx-block-b", po::value<int>()->default_value(0), "block of mx datatype input matrix B")
("mx-scale-format", po::value<int>()->default_value(0), "MX scale data format (0=none, 1=pre-swizzle for GPU kernel layout)")
("activation-compute-type", po::value<rocisa::DataType>()->default_value(rocisa::DataType::None), "Activation compute type.")
("high-precision-accumulate", po::value<bool>()->default_value(false), "Use high-precision accumulate.")
("sparse", po::value<int>()->default_value(0), "A or B matrix is sparse matrix.")
Expand Down
59 changes: 49 additions & 10 deletions projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ namespace TensileLite
, m_keepPristineCopyOnGPU(args["pristine-on-gpu"].as<bool>())
, m_workspaceSize(problemFactory.workspaceSize())
, m_pruneMode(args["prune-mode"].as<PruneSparseMode>())
, m_mxScaleFormat(args["mx-scale-format"].as<int>())

{
m_rotatingBuffer
Expand Down Expand Up @@ -1697,8 +1698,7 @@ namespace TensileLite

void DataInitialization::initializeCPUInputs(ContractionProblemGemm const& problem)
{
bool useMXGenerator = (problem.a().dataType() == rocisa::DataType::Float4 && problem.mxBlockA() > 0)
|| (problem.b().dataType() == rocisa::DataType::Float4 && problem.mxBlockB() > 0);
bool useMXGenerator = isMXFP4Problem(problem);
if(useMXGenerator)
initializeMXDataForFP4(problem);

Expand Down Expand Up @@ -1774,10 +1774,49 @@ namespace TensileLite

void DataInitialization::initializeMXDataForFP4(ContractionProblemGemm const& problem)
{
std::vector<size_t> emptySwizzle;
std::vector<size_t> emptyTile;
// Compute preSwizzle parameters from the solution's matrix instruction to rearrange
// the scale tensor into the GPU kernel's expected memory layout
std::vector<size_t> preSwizzleA, preTileA, preSwizzleB, preTileB;

if(problem.mxBlockA() > 0 && problem.a().dataType() == rocisa::DataType::Float4)
if(m_mxScaleFormat > 0 && m_currentSolution != nullptr
&& !m_currentSolution->problemType.useScaleAB.empty())
{
auto const& mi = m_currentSolution->sizeMapping.matrixInstruction;
size_t MiK = static_cast<size_t>(mi[2]);
constexpr size_t swizzleTileMN = 32; // 2 SIMDs * 16 lanes per wave for MN access
constexpr size_t tileK = 256 / swizzleTileMN; // scale blocks per wave in K

if(MiK > 0)
{
if(problem.mxBlockA() > 0 && MiK % problem.mxBlockA() == 0)
{
// scale tensor: scaleRows = sizes[0]/mxBlock, scaleCols = sizes[1]
// preSwizzle requires both to be multiples of their tile dimensions
size_t scaleRowsA = problem.a().sizes()[0] / problem.mxBlockA();
size_t scaleColsA = problem.a().sizes()[1];
if(scaleRowsA % tileK == 0 && scaleColsA % swizzleTileMN == 0)
{
size_t subTileK = MiK / problem.mxBlockA();
preSwizzleA = {swizzleTileMN, tileK, subTileK};
preTileA = {tileK, swizzleTileMN};
}
}

if(problem.mxBlockB() > 0 && MiK % problem.mxBlockB() == 0)
{
size_t scaleRowsB = problem.b().sizes()[0] / problem.mxBlockB();
size_t scaleColsB = problem.b().sizes()[1];
if(scaleRowsB % tileK == 0 && scaleColsB % swizzleTileMN == 0)
{
size_t subTileK = MiK / problem.mxBlockB();
preSwizzleB = {swizzleTileMN, tileK, subTileK};
preTileB = {tileK, swizzleTileMN};
}
}
}
}

if(isMXFP4Tensor(problem.a(), problem.mxBlockA()))
{
auto const& tensorA = problem.a();
auto rows = tensorA.sizes()[0];
Expand All @@ -1796,8 +1835,8 @@ namespace TensileLite
cols,
stride,
problem.transA(),
emptySwizzle,
emptyTile,
preSwizzleA,
preTileA,
problem.mxBlockA(),
1,
true,
Expand All @@ -1806,7 +1845,7 @@ namespace TensileLite
1.0f);
}

if(problem.mxBlockB() > 0 && problem.b().dataType() == rocisa::DataType::Float4)
if(isMXFP4Tensor(problem.b(), problem.mxBlockB()))
{
auto const& tensorB = problem.b();
auto rows = tensorB.sizes()[0];
Expand All @@ -1825,8 +1864,8 @@ namespace TensileLite
cols,
stride,
problem.transB(),
emptySwizzle,
emptyTile,
preSwizzleB,
preTileB,
problem.mxBlockB(),
1,
false,
Expand Down
88 changes: 88 additions & 0 deletions projects/hipblaslt/tensilelite/tests/MXDataGen_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,91 @@ INSTANTIATE_TEST_SUITE_P(
std::make_tuple(1024u, 213u, 32, true, true) // M=213, non-32-aligned (was failing)
)
);

// ============================================================================
// PreSwizzle scale tests
//
// Verify generateMXInput with preSwizzle produces scale data that is a
// permutation of the unswizzled layout. gfx950 FP4 MX kernels expect:
// preSwizzle = {swizzleTileMN=32, tileK=8, subTileK=MiK/mxBlock}
// preTile = {tileK=8, swizzleTileMN=32}
// swizzleTileMN=32 is fixed (2 SIMDs * 16 lanes); subTileK=4 for MiK=128, mxBlock=32.
// ============================================================================

// Params: {rows, cols, mxBlock, isTranspose, isMatrixA}
class MXPreSwizzleTest
: public ::testing::TestWithParam<std::tuple<uint64_t, uint64_t, int, bool, bool>>
{
};

/** @brief Verify preSwizzle produces a non-trivial permutation of scale data. */
TEST_P(MXPreSwizzleTest, ScaleIsPermutationOfUnswizzled)
{
auto [rows, cols, mxBlock, isTranspose, isMatrixA] = GetParam();

const std::vector<size_t> preSwizzle = {32, 8, 4};
const std::vector<size_t> preTile = {8, 32};

const uint64_t numElements = rows * cols;
const uint64_t numPacked = (numElements + 1) / 2;
const size_t numScales = ((rows + mxBlock - 1) / mxBlock) * cols;

std::vector<uint8_t> dataNoShuf(numPacked, 0);
std::vector<uint8_t> scaleNoShuf(numScales, 0);
std::vector<uint8_t> dataShuf(numPacked, 0);
std::vector<uint8_t> scaleShuf(numScales, 0);

// Generate without preSwizzle
generateMXInput((hipDataType)HIP_R_4F_E2M1,
dataNoShuf.data(),
scaleNoShuf.data(),
rows, cols, rows,
isTranspose,
{}, {},
mxBlock, 1, isMatrixA,
"Bounded", -1.0f, 1.0f);

// Generate with preSwizzle
generateMXInput((hipDataType)HIP_R_4F_E2M1,
dataShuf.data(),
scaleShuf.data(),
rows, cols, rows,
isTranspose,
preSwizzle, preTile,
mxBlock, 1, isMatrixA,
"Bounded", -1.0f, 1.0f);

// The scale buffers must be different
EXPECT_NE(scaleNoShuf, scaleShuf)
<< "Scale data was not shuffled for " << rows << "x" << cols
<< " (transpose=" << isTranspose << ", isMatrixA=" << isMatrixA << ")";

// The shuffled scale must be a permutation: same multiset of bytes
std::vector<uint8_t> sortedNoShuf = scaleNoShuf;
std::vector<uint8_t> sortedShuf = scaleShuf;
std::sort(sortedNoShuf.begin(), sortedNoShuf.end());
std::sort(sortedShuf.begin(), sortedShuf.end());
EXPECT_EQ(sortedNoShuf, sortedShuf)
<< "Pre-shuffled scale is not a permutation of the unshuffled scale for "
<< rows << "x" << cols;

// Data buffer must be identical (preSwizzle only affects scale, not data)
EXPECT_EQ(dataNoShuf, dataShuf)
<< "Data buffer changed unexpectedly with preSwizzle for "
<< rows << "x" << cols;
}

INSTANTIATE_TEST_SUITE_P(
FP4PreSwizzle,
MXPreSwizzleTest,
::testing::Values(
// rows, cols, mxBlock, isTranspose, isMatrixA
// Test size constraints for preSwizzle {32,8,4} + preTile {8,32}:
// rows % 256 == 0 (scaleRows = rows/mxBlock must be divisible by tileK=8)
// cols % 32 == 0 (scaleCols must be divisible by swizzleTileMN=32) std::make_tuple(256u, 256u, 32, true, true), // scale A transposed
std::make_tuple(256u, 256u, 32, false, false), // scale B non-transposed
std::make_tuple(512u, 256u, 32, true, true), // larger scale A
std::make_tuple(256u, 512u, 32, false, false), // larger scale B
std::make_tuple(4096u, 16384u, 32, true, true) // benchmark-scale problem
)
);
Loading