Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions projects/hipblaslt/tensilelite/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ if(TENSILELITE_ENABLE_HOST)
add_subdirectory(include)

if(TENSILELITE_ENABLE_CLIENT)
# Add mxDataGenerator if it hasn't been added by clients build
if(NOT TARGET roc::mxDataGenerator)
if(NOT ROCM_LIBS_SUPERBUILD)
if(HIPBLASLT_ENABLE_THEROCK)
find_package(mxDataGenerator REQUIRED)
else()
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../../shared/mxdatagenerator"
"${CMAKE_CURRENT_BINARY_DIR}/mxdatagenerator")
endif()
endif()
endif()

add_subdirectory(client)
endif()

Expand Down
16 changes: 9 additions & 7 deletions projects/hipblaslt/tensilelite/client/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ target_link_libraries(tensilelite-client
Boost::filesystem
OpenMP::OpenMP_CXX
)

target_link_libraries(tensilelite-client PRIVATE roc::mxDataGenerator)
set_target_properties(tensilelite-client
PROPERTIES
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
)

if(NOT WIN32)
find_package(rocm_smi REQUIRED)
else()
Expand All @@ -24,13 +33,6 @@ endif()
target_include_directories(tensilelite-client PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include")
target_compile_definitions(tensilelite-client PRIVATE TENSILE_DEFAULT_SERIALIZATION)

set_target_properties(tensilelite-client
PROPERTIES
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
Comment thread
amd-chunxlin marked this conversation as resolved.
)

if(HIPBLASLT_ENABLE_ASAN)
hipblaslt_target_configure_sanitizers(tensilelite-client PRIVATE)
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

#include "RunListener.hpp"

#include <mxDataGen.hpp>

namespace po = boost::program_options;

namespace TensileLite
Expand Down Expand Up @@ -944,6 +946,8 @@ namespace TensileLite

void initializeConstantInputs(ContractionProblemGemm const& problem);

void initializeMXDataForFP4(ContractionProblemGemm const& problem);

void copyInputs(std::vector<void*>& ptrs,
std::vector<void**>& batchPtrs,
std::vector<size_t>& maxElements,
Expand Down
9 changes: 9 additions & 0 deletions projects/hipblaslt/tensilelite/client/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,12 @@ target_sources(tensilelite-client
"${CMAKE_CURRENT_SOURCE_DIR}/SolutionIterator.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/LibraryUpdateReporter.cpp"
)

target_sources(tensilelite-client
PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/../../../clients/common/src/mxDataGen.cpp"
Comment thread
amd-chunxlin marked this conversation as resolved.
)
target_include_directories(tensilelite-client
PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/../../../clients/common/include"
)
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,6 @@ namespace TensileLite
std::cout << "Tensor name " << m_vdata[i].name << " init mode "
<< ToString(m_vdata[i].init) << std::endl;
}

// Init contants
for(size_t i = 0; i < m_cdata.size(); i++)
{
Expand Down Expand Up @@ -1269,6 +1268,13 @@ namespace TensileLite
m_problemDependentData
|= (m_sparse
| (args["bias-type-args"].as<std::vector<rocisa::DataType>>().size() > 1));

// Force problem-dependent initialization for MX FP4 to enable mxDataGenerator
if(args.count("mx-block-a") && args["mx-block-a"].as<int>() > 0)
m_problemDependentData = true;
if(args.count("mx-block-b") && args["mx-block-b"].as<int>() > 0)
m_problemDependentData = true;

allocNewCPUInputs();
allocNewGPUInputs();

Expand Down Expand Up @@ -1692,13 +1698,21 @@ namespace TensileLite

void DataInitialization::initializeCPUInputs(ContractionProblemGemm const& problem)
{
bool useMXGenerator = (problem.a().dataType() == rocisa::DataType::Float4 && problem.mxBlockA() > 0)
|| (problem.b().dataType() == rocisa::DataType::Float4 && problem.mxBlockB() > 0);
if(useMXGenerator)
initializeMXDataForFP4(problem);

auto& tensors = problem.tensors();
for(size_t i = 0; i < m_vdata.size(); i++)
{
if(i == ContractionProblemGemm::TENSOR::COMPRESSED
or i == ContractionProblemGemm::TENSOR::METADATA)
continue;

if(useMXGenerator && (i == ContractionProblemGemm::TENSOR::A || i == ContractionProblemGemm::TENSOR::B))
continue;

if(m_problemDependentData)
{
// Should this m_cEqualsD set in ContractionProblem or boost args?
Expand Down Expand Up @@ -1756,6 +1770,70 @@ namespace TensileLite
}
}

void DataInitialization::initializeMXDataForFP4(ContractionProblemGemm const& problem)
{
std::vector<size_t> emptySwizzle;
std::vector<size_t> emptyTile;

if(problem.mxBlockA() > 0 && problem.a().dataType() == rocisa::DataType::Float4)
{
auto const& tensorA = problem.a();
auto rows = tensorA.sizes()[0];
auto cols = tensorA.sizes()[1];
auto stride = tensorA.strides()[1];

auto& pristineA
= m_vdata[ContractionProblemGemm::TENSOR::A].pristine[rocisa::DataType::Float4];
auto& pristineMXScaleA
= m_vdata[ContractionProblemGemm::TENSOR::MXSA].pristine[problem.mxsa().dataType()];

generateMXInput((hipDataType)HIP_R_4F_E2M1_EXT,
pristineA.cpuInput.valid.get(),
pristineMXScaleA.cpuInput.valid.get(),
rows,
cols,
stride,
problem.transA(),
emptySwizzle,
emptyTile,
problem.mxBlockA(),
1,
true,
"Bounded",
-1.0f,
1.0f);
}

if(problem.mxBlockB() > 0 && problem.b().dataType() == rocisa::DataType::Float4)
{
auto const& tensorB = problem.b();
auto rows = tensorB.sizes()[0];
auto cols = tensorB.sizes()[1];
auto stride = tensorB.strides()[1];

auto& pristineB
= m_vdata[ContractionProblemGemm::TENSOR::B].pristine[rocisa::DataType::Float4];
auto& pristineMXScaleB
= m_vdata[ContractionProblemGemm::TENSOR::MXSB].pristine[problem.mxsb().dataType()];

generateMXInput((hipDataType)HIP_R_4F_E2M1_EXT,
pristineB.cpuInput.valid.get(),
pristineMXScaleB.cpuInput.valid.get(),
rows,
cols,
stride,
problem.transB(),
emptySwizzle,
emptyTile,
problem.mxBlockB(),
1,
false,
"Bounded",
-1.0f,
1.0f);
}
}

void DataInitialization::initializeConstantInputs(ContractionProblemGemm const& problem)
{
// Update constants if needed
Expand Down
19 changes: 19 additions & 0 deletions projects/hipblaslt/tensilelite/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ target_link_libraries(tensilelite-tests
${CMAKE_DL_LIBS}
)

if(TARGET roc::mxDataGenerator)
target_sources(tensilelite-tests
PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/MXDataGen_test.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/../../clients/common/src/mxDataGen.cpp"
)
target_include_directories(tensilelite-tests
PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/../../clients/common/include"
)
target_link_libraries(tensilelite-tests PRIVATE roc::mxDataGenerator)
set_target_properties(tensilelite-tests
PROPERTIES
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
)
endif()

gtest_discover_tests(tensilelite-tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} TIMEOUT 60)

target_link_libraries(tensilelite-tests PUBLIC GTest::gtest)
Expand Down
107 changes: 107 additions & 0 deletions projects/hipblaslt/tensilelite/tests/MXDataGen_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <gtest/gtest.h>

#include <mxDataGen.hpp>

#include <cstdint>
#include <vector>

/**
* @brief Unpack two FP4 nibbles from a packed byte.
*
* FP4 E2M1 values are packed two-per-byte (low nibble first).
* Zero is represented by nibble value 0x0 (+0) or 0x8 (-0).
*/
static bool isZeroNibble(uint8_t nibble)
{
// FP4 E2M1: 0x0 = +0.0, 0x8 = -0.0
return (nibble == 0x0) || (nibble == 0x8);
}

/**
* @brief Count elements that decode to zero in a packed FP4 buffer.
*/
static size_t countZerosFP4(const uint8_t* packedData, size_t numPackedBytes)
{
size_t zeros = 0;
for(size_t i = 0; i < numPackedBytes; ++i)
{
uint8_t lo = packedData[i] & 0x0F;
uint8_t hi = (packedData[i] >> 4) & 0x0F;
if(isZeroNibble(lo))
++zeros;
if(isZeroNibble(hi))
++zeros;
}
return zeros;
}

class MXDataGenFP4Test : public ::testing::TestWithParam<std::tuple<uint64_t, uint64_t, int, bool>>
{
};

/**
* @brief Verify that generateMXInput produces FP4 data with an acceptable zero frequency.
*
* FP4 E2M1 has 16 nibble values, 2 of which are zero (0x0 = +0, 0x8 = -0), giving a
* naive baseline of 2/16 = 12.5%. MX block scaling slightly elevates this: the block
* maximum is guaranteed non-zero, pushing small elements toward zero. Empirically the
* zero frequency converges to ~12.89% for large matrices with bounded [-1, 1] input.
*/
TEST_P(MXDataGenFP4Test, ZeroFrequencyWithinBounds)
{
auto [rows, cols, mxBlock, isTranspose] = GetParam();

const uint64_t numElements = rows * cols;
const uint64_t numPacked = (numElements + 1) / 2;
const size_t numScales = ((rows + mxBlock - 1) / mxBlock) * cols;

std::vector<uint8_t> dataBuffer(numPacked, 0);
std::vector<uint8_t> scaleBuffer(numScales, 0);

std::vector<size_t> emptySwizzle;
std::vector<size_t> emptyTile;

generateMXInput((hipDataType)HIP_R_4F_E2M1_EXT,
dataBuffer.data(),
scaleBuffer.data(),
rows,
cols,
rows, // stride = rows (column-major)
isTranspose,
emptySwizzle,
emptyTile,
mxBlock,
1,
true,
"Bounded",
-1.0f,
1.0f);

size_t zeros = countZerosFP4(dataBuffer.data(), numPacked);
double zeroPercent = 100.0 * static_cast<double>(zeros) / static_cast<double>(numElements);

// Empirically ~12.5–12.9% zeros; naive baseline is 2/16 = 12.5% (2 zero values
// out of 16 FP4 nibble values), slightly elevated by MX block scaling bias.
EXPECT_LT(zeroPercent, 13.0)
<< "Zero frequency " << zeroPercent << "% exceeds 13% upper bound for "
<< rows << "x" << cols << " FP4 matrix (transpose=" << isTranspose << ")";

// Ensure non-trivial data was actually generated (not all zeros)
EXPECT_GT(numElements - zeros, 0u)
<< "All elements are zero for " << rows << "x" << cols << " FP4 matrix";
}

INSTANTIATE_TEST_SUITE_P(
FP4ZeroFrequency,
MXDataGenFP4Test,
::testing::Values(
// rows, cols, mxBlock, isTranspose
std::make_tuple(128u, 128u, 32, true),
std::make_tuple(256u, 256u, 32, true),
std::make_tuple(2048u, 1026u, 32, true),
std::make_tuple(2048u, 514u, 32, false)
)
);
Loading