Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions projects/hipblaslt/tensilelite/tests/MXDataGen_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
#include <vector>

/**
* @brief Unpack two FP4 nibbles from a packed byte.
* @brief Returns true if a 4-bit FP4 E2M1 nibble represents zero.
*
* FP4 E2M1 values are packed two-per-byte (low nibble first).
* Zero is represented by nibble value 0x0 (+0) or 0x8 (-0).
* Both 0x0 (+0) and 0x8 (-0) decode to zero.
*/
static bool isZeroNibble(uint8_t nibble)
{
Expand Down Expand Up @@ -64,7 +64,7 @@ TEST_P(MXDataGenFP4Test, ZeroFrequencyWithinBounds)
std::vector<size_t> emptySwizzle;
std::vector<size_t> emptyTile;

generateMXInput((hipDataType)HIP_R_4F_E2M1_EXT,
generateMXInput((hipDataType)HIP_R_4F_E2M1,
dataBuffer.data(),
scaleBuffer.data(),
rows,
Expand All @@ -83,8 +83,6 @@ TEST_P(MXDataGenFP4Test, ZeroFrequencyWithinBounds)
size_t zeros = countZerosFP4(dataBuffer.data(), numPacked);
double zeroPercent = 100.0 * static_cast<double>(zeros) / static_cast<double>(numElements);

// Empirically ~12.5–12.9% zeros; naive baseline is 2/16 = 12.5% (2 zero values
// out of 16 FP4 nibble values), slightly elevated by MX block scaling bias.
EXPECT_LT(zeroPercent, 13.0)
<< "Zero frequency " << zeroPercent << "% exceeds 13% upper bound for "
<< rows << "x" << cols << " FP4 matrix (transpose=" << isTranspose << ")";
Expand All @@ -105,3 +103,65 @@ INSTANTIATE_TEST_SUITE_P(
std::make_tuple(2048u, 514u, 32, false)
)
);

/**
* @brief Regression guard: generateMXInput must be deterministic (fixed seed).
*
* Any post-generation overwrite of the MXSA/MXSB buffers (e.g., the general
* tensor-init loop in initializeCPUInputs) desynchronises the CPU reference
* from GPU data, causing intermittent single-element validation failures.
* rows=K (must be mxBlock-aligned), cols=M/N (need not be).
*/
class MXGeneratorDeterminismTest
: public ::testing::TestWithParam<std::tuple<uint64_t, uint64_t, int, bool, bool>>
{
};

TEST_P(MXGeneratorDeterminismTest, GeneratorOutputIsDeterministic)
{
auto [rows, cols, mxBlock, isTranspose, isMatrixA] = GetParam();

const size_t numPacked = (rows * cols + 1) / 2;
const size_t numScales = (rows / mxBlock) * cols;

std::vector<uint8_t> data1(numPacked);
std::vector<uint8_t> data2(numPacked);
std::vector<uint8_t> scale1(numScales, 0x00);
std::vector<uint8_t> scale2(numScales, 0xFF); // sentinel: catches no-write if scale1==scale2 passes

std::vector<size_t> emptySwizzle, emptyTile;

generateMXInput((hipDataType)HIP_R_4F_E2M1,
data1.data(), scale1.data(),
rows, cols, rows, isTranspose,
emptySwizzle, emptyTile,
mxBlock, 1, isMatrixA, "Bounded", -1.f, 1.f);

generateMXInput((hipDataType)HIP_R_4F_E2M1,
Comment thread
amd-chunxlin marked this conversation as resolved.
data2.data(), scale2.data(),
rows, cols, rows, isTranspose,
emptySwizzle, emptyTile,
mxBlock, 1, isMatrixA, "Bounded", -1.f, 1.f);

EXPECT_EQ(data1, data2)
<< "FP4 data is non-deterministic";
EXPECT_EQ(scale1, scale2)
<< "Scale data is non-deterministic; any post-generation overwrite will corrupt validation";

bool allZero = std::all_of(scale1.begin(), scale1.end(), [](uint8_t b){ return b == 0; });
bool allOnes = std::all_of(scale1.begin(), scale1.end(), [](uint8_t b){ return b == 0xFF; });
Comment thread
amd-chunxlin marked this conversation as resolved.
EXPECT_FALSE(allZero) << "Scale buffer is all-zero — generator did not write";
EXPECT_FALSE(allOnes) << "Scale buffer is all-0xFF (max UE8M0 value) — generator likely failed; bounded [-1,1] input should produce varied scales";
}

INSTANTIATE_TEST_SUITE_P(
GeneratorDeterminism,
MXGeneratorDeterminismTest,
::testing::Values(
// rows=K, cols=M or N (tensorA.sizes()={K,M}, tensorB.sizes()={K,N})
std::make_tuple(1024u, 128u, 32, true, true), // transposed A
std::make_tuple(1024u, 128u, 32, false, false), // non-transposed B
std::make_tuple(1024u, 204u, 32, true, true), // M=204, non-32-aligned (was failing)
std::make_tuple(1024u, 213u, 32, true, true) // M=213, non-32-aligned (was failing)
)
);
Loading