Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions projects/hipblaslt/clients/gtest/matmul_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,25 @@ Tests:
unit_check: 1
gpu_arch: '942'

- name: matmul_extapi_swizzleA
category: pre_checkin
function:
matmul: *real_precisions_swizzleA_support
M: [128, 129]
N: [128, 129]
K: [128, 129]
lda: [129, 2440]
batch_count: [1, 5]
transA: T
transB: N
use_ext: [1]
use_ext_setproblem: [0, 1]
alpha: 1
beta: [ 0.0, 2.0 ]
bias_vector: [0, 1]
unit_check: 1
gpu_arch: '942'

### GFX12 FP8
- name: matmul_f8_bf8_dst_fp32_gfx12
category: pre_checkin
Expand Down
34 changes: 21 additions & 13 deletions projects/hipblaslt/clients/include/testing_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,8 @@ void testing_matmul_with_bias(const Arguments& arg,

std::vector<void*> alpha_in(gemm_count);

bool do_swizzle_a = arg.swizzle_a && isSwizzleSupported(TiA);

// Need to split into two for loop to calculate the rotating buffer
int64_t totalRotatingSizeNeeded = 0;
for(int i = 0; i < gemm_count; i++)
Expand Down Expand Up @@ -1386,7 +1388,7 @@ void testing_matmul_with_bias(const Arguments& arg,
= stride_a[i] == 0 ? lda[i] * A_col[i] * num_batches[i] : stride_a[i] * num_batches[i];
size_dA[i] = size_A[i];
stride_da[i] = stride_a[i];
if(arg.swizzle_a && isSwizzleSupported(TiA))
if(do_swizzle_a)
{
//TODO:
// 1. support stride_a is 0 when swizzle
Expand Down Expand Up @@ -1491,7 +1493,7 @@ void testing_matmul_with_bias(const Arguments& arg,
CHECK_HIPBLASLT_ERROR(
hipblasLtMatrixLayoutCreate(&(matD[i]), arg.d_type, M[i], N[i], ldd[i]));

if(arg.swizzle_a && isSwizzleSupported(TiA))
if(do_swizzle_a)
{
hipblasLtOrder_t orderA = orderForDatatype(TiA);
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
Expand Down Expand Up @@ -1834,9 +1836,9 @@ void testing_matmul_with_bias(const Arguments& arg,
dA[i].buf(),
A_row[i],
A_col[i],
(arg.swizzle_a) ? A_row[i] : lda[i],
(do_swizzle_a) ? A_row[i] : lda[i],
TiA,
(arg.swizzle_a) ? A_row[i] * A_col[i] : stride_a[i],
(do_swizzle_a) ? A_row[i] * A_col[i] : stride_a[i],
num_batches[i]);
#ifdef HIPBLASLT_USE_ROCROLLER
}
Expand Down Expand Up @@ -1907,7 +1909,7 @@ void testing_matmul_with_bias(const Arguments& arg,
CHECK_HIP_ERROR(broadcast(dB[i], block_count));
CHECK_HIP_ERROR(broadcast(dC[i], block_count));

if(arg.unit_check || arg.norm_check || arg.allclose_check || arg.swizzle_a)
if(arg.unit_check || arg.norm_check || arg.allclose_check || do_swizzle_a)
{
CHECK_HIP_ERROR(synchronize(hA[i],
dA[i],
Expand All @@ -1916,12 +1918,12 @@ void testing_matmul_with_bias(const Arguments& arg,
A_col[i],
lda[i],
realDataTypeSize(TiA),
arg.swizzle_a));
do_swizzle_a));
CHECK_HIP_ERROR(synchronize(hB[i], dB[i]));
CHECK_HIP_ERROR(synchronize(hC[i], dC[i]));
}

if(arg.swizzle_a && isSwizzleSupported(TiA))
if(do_swizzle_a)
{
HipHostBuffer tmp(TiA, size_dA[i]);
swizzle_tensor_type(tmp, hA[i], TiA, arg, num_batches[i], M[i], K[i], lda[i], false);
Expand Down Expand Up @@ -2435,6 +2437,12 @@ void testing_matmul_with_bias(const Arguments& arg,
extproblemtype.setTypeC(arg.c_type);
extproblemtype.setTypeD(arg.d_type);
extproblemtype.setTypeCompute(arg.compute_type);

if(do_swizzle_a)
{
hipblasLtOrder_t orderA = orderForDatatype(TiA);
extproblemtype.setOrderA(orderA);
}
}
else if(arg.grouped_gemm)
{
Expand Down Expand Up @@ -2545,7 +2553,7 @@ void testing_matmul_with_bias(const Arguments& arg,
ldb[0],
ldc[0],
ldd[0],
stride_a[0],
(do_swizzle_a) ? stride_da[0] : stride_a[0],
stride_b[0],
stride_c[0],
stride_d[0],
Expand Down Expand Up @@ -2640,7 +2648,7 @@ void testing_matmul_with_bias(const Arguments& arg,
ldb,
ldc,
ldd,
stride_a,
(do_swizzle_a) ? stride_da : stride_a,
stride_b,
stride_c,
stride_d,
Expand Down Expand Up @@ -2738,7 +2746,7 @@ void testing_matmul_with_bias(const Arguments& arg,
ldb[0],
ldc[0],
ldd[0],
stride_a[0],
(do_swizzle_a) ? stride_da[0] : stride_a[0],
stride_b[0],
stride_c[0],
stride_d[0],
Expand Down Expand Up @@ -2842,7 +2850,7 @@ void testing_matmul_with_bias(const Arguments& arg,
ldb,
ldc,
ldd,
stride_a,
(do_swizzle_a) ? stride_da : stride_a,
stride_b,
stride_c,
stride_d,
Expand Down Expand Up @@ -2920,7 +2928,7 @@ void testing_matmul_with_bias(const Arguments& arg,
ldb[0],
ldc[0],
ldd[0],
stride_a[0],
(do_swizzle_a) ? stride_da[0] : stride_a[0],
stride_b[0],
stride_c[0],
stride_d[0],
Expand Down Expand Up @@ -3009,7 +3017,7 @@ void testing_matmul_with_bias(const Arguments& arg,
ldb,
ldc,
ldd,
stride_a,
(do_swizzle_a) ? stride_da : stride_a,
stride_b,
stride_c,
stride_d,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ namespace hipblaslt_ext
HIPBLASLT_EXPORT void setTypeD(hipDataType type); //!< Set the D matrix datatype.
HIPBLASLT_EXPORT void
setTypeCompute(hipblasComputeType_t type); //!< Set the compute datatype.
HIPBLASLT_EXPORT void setOrderA(hipblasLtOrder_t order); //!< Set the A matrix data order.
HIPBLASLT_EXPORT void setOrderB(hipblasLtOrder_t order); //!< Set the B matrix data order.

HIPBLASLT_EXPORT hipblasOperation_t getOpA() const; //!< The A matrix transpose.
HIPBLASLT_EXPORT hipblasOperation_t getOpB() const; //!< The B matrix transpose.
Expand All @@ -142,6 +144,8 @@ namespace hipblaslt_ext
HIPBLASLT_EXPORT hipDataType getTypeC() const; //!< The C matrix datatype.
HIPBLASLT_EXPORT hipDataType getTypeD() const; //!< The D matrix datatype.
HIPBLASLT_EXPORT hipblasComputeType_t getTypeCompute() const; //!< The compute datatype.
HIPBLASLT_EXPORT hipblasLtOrder_t getOrderA() const; //!< The A matrix data order.
HIPBLASLT_EXPORT hipblasLtOrder_t getOrderB() const; //!< The B matrix data order.
};

[[deprecated("GemmProblemTypeV2 is deprecated, use GemmProblemType instead.")]]
Expand Down
35 changes: 31 additions & 4 deletions projects/hipblaslt/library/src/amd_detail/hipblaslt-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ namespace hipblaslt_ext
hipDataType type_c; //!< The C matrix datatype.
hipDataType type_d; //!< The D matrix datatype.
hipblasComputeType_t type_compute; //!< The compute datatype.
hipblasLtOrder_t order_a; //!< The A matrix data layout order
hipblasLtOrder_t order_b; //!< The B matrix data layout order
};

GemmProblemType::GemmProblemType()
Expand All @@ -106,6 +108,11 @@ namespace hipblaslt_ext
pimpl->type_c = typeC;
pimpl->type_d = typeD;
pimpl->type_compute = typeCompute;

// The order defaults to COL regardless of opA/opB;
// currently only the swizzle cases use these variables.
pimpl->order_a = HIPBLASLT_ORDER_COL;
pimpl->order_b = HIPBLASLT_ORDER_COL;
}

GemmProblemType::~GemmProblemType() = default;
Expand Down Expand Up @@ -170,6 +177,16 @@ namespace hipblaslt_ext
pimpl->type_compute = type;
}

// Stores `order` as the A matrix data layout order on the implementation object.
void GemmProblemType::setOrderA(hipblasLtOrder_t order)
{
pimpl->order_a = order;
}

// Stores `order` as the B matrix data layout order on the implementation object.
void GemmProblemType::setOrderB(hipblasLtOrder_t order)
{
pimpl->order_b = order;
}

hipblasOperation_t GemmProblemType::getOpA() const
{
return pimpl->op_a;
Expand Down Expand Up @@ -205,6 +222,16 @@ namespace hipblaslt_ext
return pimpl->type_compute;
}

// Returns the A matrix data layout order currently held by the implementation object.
hipblasLtOrder_t GemmProblemType::getOrderA() const
{
return pimpl->order_a;
}

// Returns the B matrix data layout order currently held by the implementation object.
hipblasLtOrder_t GemmProblemType::getOrderB() const
{
return pimpl->order_b;
}

class GemmEpilogue::GemmEpilogueImpl
{
public:
Expand Down Expand Up @@ -307,12 +334,12 @@ namespace hipblaslt_ext

// Stores `act0` in the implementation object's `act0` field.
void GemmEpilogue::setAct0(float act0)
{
    pimpl->act0 = act0;
}

// Stores `act1` in the implementation object's `act1` field.
void GemmEpilogue::setAct1(float act1)
{
    pimpl->act1 = act1;
}

hipblasLtEpilogue_t GemmEpilogue::getMode() const
Expand Down Expand Up @@ -372,12 +399,12 @@ namespace hipblaslt_ext

// Returns the `act0` value held by the implementation object.
float GemmEpilogue::getAct0()
{
    return pimpl->act0;
}

// Returns the `act1` value held by the implementation object.
float GemmEpilogue::getAct1()
{
    return pimpl->act1;
}

class GemmTuning::GemmTuningImpl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ namespace rocblaslt
hipDataType type_c;
hipDataType type_d;
rocblaslt_compute_type type_compute;
hipblasLtOrder_t order_a;
hipblasLtOrder_t order_b;
};

class RocGemmEpilogueV2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,38 @@ inline bool isValidOrderForDatatype(hipDataType datatype, hipblasLtOrder_t order
return true;
}

/*******************************************************************************
* Calculate Default Swizzle-A-Batched-Stride
******************************************************************************/
inline void setDefaultSwizzledBatchedStride(const hipblasLtOrder_t& matOrder,
const uint64_t matUnrollDim,
const uint64_t matTileDim,
int64_t& batch_stride)
{
size_t MiM = 16, MiK = 0, MiKv = 0, PackK = 0;
Comment thread
solaslin marked this conversation as resolved.
Outdated
if(matOrder == HIPBLASLT_ORDER_COL16_4R8)
{
//f16
MiK = 16;
MiKv = 4;
PackK = 16 / MiKv / 2;
}
else if(matOrder == HIPBLASLT_ORDER_COL16_4R16)
{
//f8
MiK = 32;
MiKv = 8;
PackK = 16 / MiKv / 1;
}
else
return;

size_t K_block = MiK * PackK;
//align to k for swizzleK and to m for 16
batch_stride = ((matTileDim + MiM - 1) / MiM) * MiM * ((matUnrollDim + K_block - 1) / K_block)
* K_block;
}

/*******************************************************************************
* Validate Matmul Swizzle Arguments
******************************************************************************/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,32 +80,6 @@ inline void assignAlphaBeta1(const rocblaslt_compute_type& compute_type, void* a
}
}

// Computes the default batch stride for a swizzled matrix layout and writes it to
// batch_stride. Only HIPBLASLT_ORDER_COL16_4R8 (f16) and HIPBLASLT_ORDER_COL16_4R16 (f8)
// are handled; for any other order batch_stride is left unchanged.
inline void setDefaultSwizzledBatchedStride(const rocblaslt_matrix_layout& matLayout,
                                            int64_t&                       batch_stride)
{
    const bool isF16Order = (matLayout->order == HIPBLASLT_ORDER_COL16_4R8);
    const bool isF8Order  = !isF16Order && (matLayout->order == HIPBLASLT_ORDER_COL16_4R16);
    if(!isF16Order && !isF8Order)
        return;

    // Per-order constants: f16 uses (16, 4, 2), f8 uses (32, 8, 1).
    const size_t tileAlign = 16;
    const size_t miK       = isF16Order ? 16 : 32;
    const size_t miKv      = isF16Order ? 4 : 8;
    const size_t divisor   = isF16Order ? 2 : 1;
    const size_t packK     = 16 / miKv / divisor;
    const size_t kBlock    = miK * packK;

    // n is rounded up to a multiple of 16, m is rounded up to a multiple of kBlock.
    const auto tileBlocks = (matLayout->n + tileAlign - 1) / tileAlign;
    const auto kBlocks    = (matLayout->m + kBlock - 1) / kBlock;
    batch_stride          = tileBlocks * tileAlign * kBlocks * kBlock;
}

inline void heuristicResult_copy(rocblaslt_matmul_heuristic_result* heuristicResultsDest,
rocblaslt_matmul_heuristic_result* heuristicResultsSrc,
Expand Down Expand Up @@ -357,8 +331,9 @@ RocblasltContractionProblem construct_rocblaslt_problem(rocblaslt_handle

if(swizzleA && matA->batch_stride == 0)
{
//If batch_stride has never been assigned for swizzle, set it to the default value
setDefaultSwizzledBatchedStride(matA, matA->batch_stride);
// If batch_stride has never been assigned for swizzle, set it to the default value
// Swizzle is TN: For MatA, ->m (leading Dim) is unrollDim, ->n is tileDim
setDefaultSwizzledBatchedStride(matA->order, matA->m, matA->n, matA->batch_stride);
}
Comment thread
solaslin marked this conversation as resolved.
Outdated

rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr,
Expand Down
Loading