Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/hipblaslt/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Full documentation for hipBLASLt is available at [rocm.docs.amd.com/projects/hip

* Replaced `install.sh` with an invoke-based task runner (`tasks.py`) to support cross-platform builds including Windows (ROCm 7.0+).
* gtest and msgpack-cxx are now fetched automatically via CMake FetchContent if not found on the system.
* Greatly improved MXFP4 GEMM performance when using HIPBLASLT_MATMUL_MATRIX_SCALE_BLK32_UE8M0_32_8_EXT

## hipBLASLt 1.2.2 for ROCm 7.2.1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,25 +167,21 @@ inline std::vector<size_t> preTileSizeForScaleB(hipblaslt_scaling_format s)
}
}

// Compute scale buffer size accounting for padding by preSwizzleScalesGFX950.
// Compute scale buffer size with padding for block-scaled MX formats.
// dataRow, dataCol are the raw data matrix dimensions (A_row/A_col or B_row/B_col).
// When pre-swizzle is active, the output may be larger than the unpadded size
// because rows are padded to a multiple of 32 and cols to a multiple of 8.
// Scale dimensions are padded to ensure kernels that process data in 32-element (M/N)
// or 256-element (K) blocks always have valid scale entries:
// scaleRows = ceil(dataRow / blockSize) rounded up to multiple of 8
// scaleCols = dataCol rounded up to multiple of 32
// When pre-swizzle is active, additional layout requirements may apply but are
// already satisfied by the rounding above.
inline size_t scaleBufferSize(int64_t dataRow, int64_t dataCol, hipblaslt_scaling_format s)
{
auto bs = blockSize(s);
size_t scaleRows = dataRow / bs;
size_t scaleCols = dataCol;
size_t scaleRows = ((dataRow + bs - 1) / bs + 7) / 8 * 8;
size_t scaleCols = ((dataCol + 31) / 32) * 32;

auto preSwizzle = preSwizzleSizeForScale(s);
if(preSwizzle.empty())
return scaleRows * scaleCols;

// preSwizzleScalesGFX950 is called with {scaleCols, scaleRows}.
// It pads numRows (=scaleCols) to multiple of 32, numCols (=scaleRows) to multiple of 8.
size_t paddedNumRows = ((scaleCols + 31) / 32) * 32;
size_t paddedNumCols = ((scaleRows + 7) / 8) * 8;
return paddedNumRows * paddedNumCols;
return scaleRows * scaleCols;
}

inline hipblaslt_internal_ostream& operator<<(hipblaslt_internal_ostream& os,
Expand Down
58 changes: 37 additions & 21 deletions projects/hipblaslt/clients/common/include/norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "utility.hpp"
#include <cstdio>
#include <hipblaslt/hipblaslt.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
Expand Down Expand Up @@ -557,45 +558,60 @@ bool norm_check(double norm_error)
return false;
}

// TODO: norm_check determines the required norm solely based on
// To (type). This might cause tests to fail when the input
// matrices are MX types (F4/F8/F6). A better way is
// needed to determine the required norm for MX types.
bool norm_check(double norm_error, hipDataType type)
// TODO: tune norm tolerance for MX FP6 and FP8 types
double norm_tolerance(hipDataType type)
{
switch(type)
{
case HIP_R_32F:
return norm_error < 0.00001;
return 0.00001;
case HIP_R_64F:
return norm_error < 0.000000000001;
return 0.000000000001;
case HIP_R_16F:
return norm_error < 0.01;
return 0.01;
case HIP_R_16BF:
return norm_error < 0.1;
return 0.1;
case HIP_R_8F_E4M3_FNUZ:
case HIP_R_8F_E4M3:
return norm_error < 0.125;
return 0.125;
case HIP_R_8F_E5M2_FNUZ:
case HIP_R_8F_E5M2:
return norm_error < 0.25;
return 0.25;
case HIP_R_32I:
return norm_error < 0.0001;
return 0.0001;
case HIP_R_8I:
return norm_error < 0.01;
// TODO: find a suitable rnom value for f6 and f4
return 0.01;
case HIP_R_4F_E2M1:
return 0.3;
case HIP_R_6F_E2M3:
case HIP_R_6F_E3M2:
case HIP_R_4F_E2M1:
return norm_error < 0.5;
return 0.5;
default:
return false;
return 0.0;
}
}

bool norm_check(double norm_error, hipDataType type, hipblasComputeType_t compute_type)
bool norm_check(double norm_error, hipDataType type)
{
if(compute_type == HIPBLAS_COMPUTE_32F_FAST_16BF && type == HIP_R_32F)
return norm_error < 0.5;
return norm_check(norm_error, type);
double tol = norm_tolerance(type);
return tol > 0.0 && norm_error < tol;
}

bool norm_check(double norm_error,
hipDataType outputType,
hipblasComputeType_t compute_type,
hipDataType inputTypeA = static_cast<hipDataType>(-1),
hipDataType inputTypeB = static_cast<hipDataType>(-1))
{
double tol = norm_tolerance(outputType);

if(compute_type == HIPBLAS_COMPUTE_32F_FAST_16BF && outputType == HIP_R_32F)
tol = std::max(tol, 0.5);

if(static_cast<int>(inputTypeA) >= 0)
tol = std::max(tol, norm_tolerance(inputTypeA));
if(static_cast<int>(inputTypeB) >= 0)
tol = std::max(tol, norm_tolerance(inputTypeB));

return tol > 0.0 && norm_error < tol;
}
115 changes: 75 additions & 40 deletions projects/hipblaslt/clients/common/include/testing_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ extern "C" __global__ void flush_icache()
:);
}

// Convert element count to byte count, accounting for sub-byte packing.
// FP4 (4-bit) packs 2 elements per byte; all other types use realDataTypeSize.
size_t elementsToBytes(size_t numElements, hipDataType dtype)
{
if(static_cast<int>(dtype) == HIP_R_4F_E2M1)
return numElements / 2;
return numElements * realDataTypeSize(dtype);
}

bool isSwizzleSupported(hipDataType datatype)
{
switch(datatype)
Expand Down Expand Up @@ -1235,7 +1244,8 @@ void check(hipStream_t stream,
hipblaslt_error += norm_error;
if(arg.norm_check_assert)
{
CHECK_SUCCESS(norm_check(norm_error, To, arg.compute_type));
CHECK_SUCCESS(
norm_check(norm_error, To, arg.compute_type, arg.a_type, arg.b_type));
}
if(batchMode != HIPBLASLT_BATCH_MODE_POINTER_ARRAY)
{
Expand Down Expand Up @@ -1269,7 +1279,8 @@ void check(hipStream_t stream,
hipblaslt_error += norm_error;
if(arg.norm_check_assert)
{
CHECK_SUCCESS(norm_check(norm_error, Taux, arg.compute_type));
CHECK_SUCCESS(
norm_check(norm_error, Taux, arg.compute_type, arg.a_type, arg.b_type));
}
}
if(arg.gradient && arg.bias_vector)
Expand Down Expand Up @@ -2248,8 +2259,8 @@ void testing_matmul_with_bias(const Arguments& arg,
}
else if(isBlockScaling(arg.scaleA))
{
// For MX format, use uin8_t for the scale (E8M0)
dScaleA.emplace_back(HIP_R_8U, size_scaleAVec[i] * block_count, HMM);
// For MX format, use uint8_t for the scale (E8M0), allocate for all batches
dScaleA.emplace_back(HIP_R_8U, size_scaleAVec[i] * num_batches[i] * block_count, HMM);
CHECK_DEVICE_ALLOCATION(hipGetLastError());
}
if(arg.scaleB == hipblaslt_scaling_format::Scalar
Expand All @@ -2260,8 +2271,8 @@ void testing_matmul_with_bias(const Arguments& arg,
}
else if(isBlockScaling(arg.scaleB))
{
// For MX format, use uin8_t for the scale (E8M0)
dScaleB.emplace_back(HIP_R_8U, size_scaleBVec[i] * block_count, HMM);
// For MX format, use uint8_t for the scale (E8M0), allocate for all batches
dScaleB.emplace_back(HIP_R_8U, size_scaleBVec[i] * num_batches[i] * block_count, HMM);
CHECK_DEVICE_ALLOCATION(hipGetLastError());
}
if(arg.scaleC)
Expand Down Expand Up @@ -2312,7 +2323,7 @@ void testing_matmul_with_bias(const Arguments& arg,
}
else if(isBlockScaling(arg.scaleA))
{
hScaleA.emplace_back(HIP_R_8U, size_scaleAVec[i]);
hScaleA.emplace_back(HIP_R_8U, size_scaleAVec[i] * num_batches[i]);
}
if(arg.scaleB == hipblaslt_scaling_format::Scalar
|| arg.scaleB == hipblaslt_scaling_format::Vector)
Expand All @@ -2321,7 +2332,7 @@ void testing_matmul_with_bias(const Arguments& arg,
}
else if(isBlockScaling(arg.scaleB))
{
hScaleB.emplace_back(HIP_R_8U, size_scaleBVec[i]);
hScaleB.emplace_back(HIP_R_8U, size_scaleBVec[i] * num_batches[i]);
}
if(arg.scaleC)
hScaleC.emplace_back(Talpha, 1);
Expand Down Expand Up @@ -2507,23 +2518,34 @@ void testing_matmul_with_bias(const Arguments& arg,
// (consists of data part and scale part)
// preTile for A: {tileK, tileM} - swap from preTileSizeForScaleA which returns {tileM, tileK}
auto preTileATmp = preTileSizeForScaleA(arg.scaleA);
auto preTileA = (preTileATmp.size() == 2)
? std::vector<size_t>{preTileATmp[1], preTileATmp[0]}
: std::vector<size_t>{};
refA.emplace_back(generateMXInput(TiA,
scaleDataType(arg.scaleA),
hA[i].buf(),
hScaleA[i].buf(),
A_row[i],
A_col[i],
lda[i],
transA == HIPBLAS_OP_T,
preSwizzleSizeForScale(arg.scaleA),
preTileA,
blockSize(arg.scaleA),
1,
true,
hipblaslt_initialization2string(arg.initialization)));
auto preTileA = (preTileATmp.size() == 2) ? std::vector<size_t>{preTileATmp[1], preTileATmp[0]} : std::vector<size_t>{};
// Compute batch strides in bytes for data and scale buffers.
size_t dataBatchBytesA = (num_batches[i] > 1) ? elementsToBytes(stride_a[i], TiA) : 0;
size_t scaleBatchBytesA = (num_batches[i] > 1) ? size_scaleAVec[i] : 0;
// Generate MX data for each batch and collect reference floats
std::vector<float> refAAll;
refAAll.reserve(static_cast<size_t>(A_row[i]) * A_col[i] * num_batches[i]);
for(int64_t b = 0; b < num_batches[i]; b++)
{
auto* dataPtr = reinterpret_cast<uint8_t*>(hA[i].buf()) + b * dataBatchBytesA;
auto* scalePtr = reinterpret_cast<uint8_t*>(hScaleA[i].buf()) + b * scaleBatchBytesA;
auto batchRef = generateMXInput(TiA,
scaleDataType(arg.scaleA),
dataPtr,
scalePtr,
A_row[i],
A_col[i],
lda[i],
transA == HIPBLAS_OP_T,
preSwizzleSizeForScale(arg.scaleA),
preTileA,
blockSize(arg.scaleA),
1,
true,
hipblaslt_initialization2string(arg.initialization));
refAAll.insert(refAAll.end(), batchRef.begin(), batchRef.end());
}
refA.emplace_back(std::move(refAAll));
// Copy data and scale to device buffers
CHECK_HIP_ERROR(synchronize(dA[i], hA[i], block_count));
CHECK_HIP_ERROR(synchronize(dScaleA[i], hScaleA[i], block_count));
Expand Down Expand Up @@ -2609,20 +2631,33 @@ void testing_matmul_with_bias(const Arguments& arg,
// input data (consists of data part and scale part)
// preTile for B: {tileK, tileN}
auto preTileB = preTileSizeForScaleB(arg.scaleB);
refB.emplace_back(generateMXInput(TiB,
scaleDataType(arg.scaleB),
hB[i].buf(),
hScaleB[i].buf(),
B_row[i],
B_col[i],
ldb[i],
transB == HIPBLAS_OP_T,
preSwizzleSizeForScale(arg.scaleB),
preTileB,
1,
blockSize(arg.scaleB),
false,
hipblaslt_initialization2string(arg.initialization)));
// Compute batch strides in bytes for data and scale buffers.
size_t dataBatchBytesB = (num_batches[i] > 1) ? elementsToBytes(stride_b[i], TiB) : 0;
size_t scaleBatchBytesB = (num_batches[i] > 1) ? size_scaleBVec[i] : 0;
// Generate MX data for each batch and collect reference floats
std::vector<float> refBAll;
refBAll.reserve(static_cast<size_t>(B_row[i]) * B_col[i] * num_batches[i]);
for(int64_t b = 0; b < num_batches[i]; b++)
{
auto* dataPtr = reinterpret_cast<uint8_t*>(hB[i].buf()) + b * dataBatchBytesB;
auto* scalePtr = reinterpret_cast<uint8_t*>(hScaleB[i].buf()) + b * scaleBatchBytesB;
auto batchRef = generateMXInput(TiB,
scaleDataType(arg.scaleB),
dataPtr,
scalePtr,
B_row[i],
B_col[i],
ldb[i],
transB == HIPBLAS_OP_T,
preSwizzleSizeForScale(arg.scaleB),
preTileB,
1,
blockSize(arg.scaleB),
false,
hipblaslt_initialization2string(arg.initialization));
refBAll.insert(refBAll.end(), batchRef.begin(), batchRef.end());
}
refB.emplace_back(std::move(refBAll));
// Copy data and scale to device buffers
CHECK_HIP_ERROR(synchronize(dB[i], hB[i], block_count));
CHECK_HIP_ERROR(synchronize(dScaleB[i], hScaleB[i], block_count));
Expand Down Expand Up @@ -4595,7 +4630,7 @@ void testing_matmul_with_bias(const Arguments& arg,
lda[gemmIdx],
isScaleBMXFormat
? reinterpret_cast<char*>(refB[gemmIdx].data())
+ stride_a[gemmIdx] * batchIdx * realDataTypeSize(HIP_R_32F)
+ stride_b[gemmIdx] * batchIdx * realDataTypeSize(HIP_R_32F)
: hB[gemmIdx].as<char>()
+ stride_b[gemmIdx] * batchIdx * realDataTypeSize(TiB),
ldb[gemmIdx],
Expand Down
1 change: 1 addition & 0 deletions projects/hipblaslt/clients/common/src/mxDataGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ std::vector<float> generateMXInput(hipDataType dataType,
opt.min = initMethod == "uniform_01" ? 0. : (initMethod == "hpl" ? -.5 : min_val);
opt.max = initMethod == "uniform_01" ? 1. : (initMethod == "hpl" ? .5 : max_val);
opt.blockScaling = scaleBlockRowSize * scaleBlockColSize;
opt.forceDenorm = false;

// Map string initMethod to DataInitMode
if(initMethod == "Sequential")
Expand Down
Loading
Loading