Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,23 @@ void run_function(const func_map& map, const Arguments& arg, const std::string&

// Template to dispatch testing_matmul for performance tests
// the test is marked invalid when (TiA, TiB, To, Tc) not in (H/H/S, B/B/S)
template <typename TiA, typename TiB = TiA, typename To = TiB, typename Tc = To, typename = void>
template <typename TiA,
typename TiB = TiA,
typename To = TiB,
typename Tc = To,
typename Tci = TiA,
typename = void>
struct perf_matmul : hipblaslt_test_invalid
{
};

template <typename TiA, typename TiB, typename To, typename Tc>
template <typename TiA, typename TiB, typename To, typename Tc, typename Tci>
struct perf_matmul<
TiA,
TiB,
To,
Tc,
Tci,
std::enable_if_t<(std::is_same<TiA, hipblasLtHalf>{} && std::is_same<TiB, hipblasLtHalf>{})
|| (std::is_same<TiA, hip_bfloat16>{} && std::is_same<TiB, hip_bfloat16>{})
|| (std::is_same<TiA, float>{} && std::is_same<TiB, float>{})
Expand All @@ -100,7 +106,7 @@ struct perf_matmul<
{
void operator()(const Arguments& arg)
{
static const func_map map = {{"matmul", testing_matmul<TiA, TiB, To, Tc>}};
static const func_map map = {{"matmul", testing_matmul<TiA, TiB, To, Tc, Tci>}};
run_function(map, arg);
}
};
Expand Down
1,851 changes: 734 additions & 1,117 deletions clients/common/cblas_interface.cpp

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions clients/gtest/matmul_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,21 @@ namespace
typename TiB = TiA,
typename To = TiB,
typename Tc = To,
typename Tci = TiA,
typename = void>
struct matmul_testing : hipblaslt_test_invalid
{
};

// When Ti = To = Tc != void, this test applies.
// When converted to bool, this functor returns true.
template <typename TiA, typename TiB, typename To, typename Tc>
template <typename TiA, typename TiB, typename To, typename Tc, typename Tci>
struct matmul_testing<
TiA,
TiB,
To,
Tc,
Tci,
std::enable_if_t<
(std::is_same<TiA, hipblasLtHalf>{} && std::is_same<TiB, hipblasLtHalf>{})
|| (std::is_same<TiA, hip_bfloat16>{} && std::is_same<TiB, hip_bfloat16>{})
Expand All @@ -75,9 +77,9 @@ namespace
void operator()(const Arguments& arg)
{
if(!strcmp(arg.function, "matmul"))
testing_matmul<TiA, TiB, To, Tc>(arg);
testing_matmul<TiA, TiB, To, Tc, Tci>(arg);
else if(!strcmp(arg.function, "matmul_bad_arg"))
testing_matmul_bad_arg<TiA, TiB, To, Tc>(arg);
testing_matmul_bad_arg<TiA, TiB, To, Tc, Tci>(arg);
else
FAIL() << "Internal error: Test called with unknown function: " << arg.function;
}
Expand Down Expand Up @@ -162,11 +164,11 @@ namespace

if(arg.scaleC)
name << "_SC";

if(arg.scaleD)
name << "_SD";

if (arg.scaleE)
if(arg.scaleE)
name << "_SAux";

if(arg.scaleAlpha_vector)
Expand Down
23 changes: 4 additions & 19 deletions clients/include/cblas_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
*/

// gemm
template <typename TiA, typename TiB, typename To, typename Tc>
template <typename TiA, typename TiB, typename To, typename Tc, typename Tci = TiA>
void cblas_gemm(hipblasOperation_t transA,
hipblasOperation_t transB,
int64_t m,
Expand All @@ -50,23 +50,8 @@ void cblas_gemm(hipblasOperation_t transA,
Tc beta,
std::add_pointer_t<To> C,
int64_t ldc,
const Tc* AlphaVec,
Tc scaleA,
Tc scaleB,
Tc scaleD,
bool alt = false);

template <typename TiA, typename TiB, typename To, typename Tc>
void cblas_gemm_alphascale(hipblasOperation_t transA,
hipblasOperation_t transB,
int64_t m,
int64_t n,
int64_t k,
Tc alpha,
const TiA* A,
int64_t lda,
const TiB* B,
int64_t ldb,
Tc beta,
std::add_pointer_t<To> C,
int64_t ldc,
const Tc* AlphaVec,
Tc scaleD,
bool alt = false);
72 changes: 38 additions & 34 deletions clients/include/testing_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ auto _dgelu = [](auto in, auto /*arg1*/, auto /*arg2*/) -> decltype(in) {
return static_cast<decltype(in)>(0.5f * tanh(xx) + x1 * x2 + 0.5f);
};

template <typename TiA, typename TiB, typename To, typename Tc>
template <typename TiA, typename TiB, typename To, typename Tc, typename Tci>
void testing_matmul_bad_arg(const Arguments& arg)
{
const int64_t M = 128;
Expand Down Expand Up @@ -220,7 +220,7 @@ void testing_matmul_bad_arg(const Arguments& arg)
hipStream_t stream = nullptr;
}

template <typename TiA, typename TiB, typename To, typename Tc>
template <typename TiA, typename TiB, typename To, typename Tc, typename Tci>
void testing_matmul(const Arguments& arg)
{
double gpu_time_used, cpu_time_used;
Expand Down Expand Up @@ -853,7 +853,7 @@ void testing_matmul(const Arguments& arg)

for(int gemmIdx = 0; gemmIdx < gemm_count; gemmIdx++)
{
auto bias_type = static_cast<hipblasltDatatype_t>(0);
auto bias_type = HIPBLASLT_DATATYPE_INVALID;
void* bias_addr = nullptr;
if(arg.bias_vector)
{
Expand Down Expand Up @@ -1418,28 +1418,28 @@ void testing_matmul(const Arguments& arg)
scaleEValue, applyBias
for(int gemmIdx = 0; gemmIdx < gemm_count; gemmIdx++)
{
auto alphaTemp = h_alpha[gemmIdx];
auto betaTemp = h_beta[gemmIdx];
if(arg.scaleA)
alphaTemp *= (*hScaleA[gemmIdx])[0];
if(arg.scaleB)
alphaTemp *= (*hScaleB[gemmIdx])[0];
auto alpha = h_alpha[gemmIdx];
auto betaTemp = h_beta[gemmIdx];
if(arg.scaleC)
betaTemp *= (*hScaleC[gemmIdx])[0];
auto scaleAValue = arg.scaleA ? (*hScaleA[gemmIdx])[0] : 1;
auto scaleBValue = arg.scaleB ? (*hScaleB[gemmIdx])[0] : 1;
auto scaleDValue = arg.scaleD ? (*hScaleD[gemmIdx])[0] : 1;
auto scaleEValue = arg.scaleE ? (*hScaleE[gemmIdx])[0] : 1;

for(int batchIdx = 0; batchIdx < num_batches[gemmIdx]; batchIdx++)
{
if(epilogue_on[gemmIdx])
{
if(arg.scaleAlpha_vector)
{
cblas_gemm_alphascale<TiA, TiB, Talpha, Talpha>(
cblas_gemm<TiA, TiB, Talpha, Talpha, Tci>(
transA,
transB,
M[gemmIdx],
N[gemmIdx],
K[gemmIdx],
alphaTemp,
alpha,
*(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx,
lda[gemmIdx],
*(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx,
Expand All @@ -1448,34 +1448,37 @@ void testing_matmul(const Arguments& arg)
*(hD_gold_epl[gemmIdx]) + stride_d[gemmIdx] * batchIdx,
ldd[gemmIdx],
*(hScaleAlphaVec[gemmIdx]) + 0,
scaleAValue,
scaleBValue,
1,
false);
}
else
{
cblas_gemm<TiA, TiB, Talpha, Talpha>(
cblas_gemm<TiA, TiB, Talpha, Talpha, Tci>(
transA,
transB,
M[gemmIdx],
N[gemmIdx],
K[gemmIdx],
alphaTemp,
alpha,
*(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx,
lda[gemmIdx],
*(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx,
ldb[gemmIdx],
betaTemp,
*(hD_gold_epl[gemmIdx]) + stride_d[gemmIdx] * batchIdx,
ldd[gemmIdx],
nullptr,
scaleAValue,
scaleBValue,
1,
false);
}
auto pos = stride_d[gemmIdx] * batchIdx;
auto hEInst = arg.gradient ? hE : hE_gold;
auto ePos = (hEInst[gemmIdx] == nullptr) ? nullptr : (*(hEInst[gemmIdx]) + pos);
auto scaleDValue = arg.scaleD ? (*hScaleD[gemmIdx])[0] : 1;
auto scaleEValue = arg.scaleE ? (*hScaleE[gemmIdx])[0] : 1;
auto applyBias = arg.gradient ? false : arg.bias_vector;
auto applyBias = arg.gradient ? false : arg.bias_vector;

if(change_bias_type[gemmIdx] == false)
{
Expand Down Expand Up @@ -1652,24 +1655,25 @@ void testing_matmul(const Arguments& arg)
}
else
{
auto scaleDValue = arg.scaleD ? (*hScaleD[gemmIdx])[0] : 1;

cblas_gemm<TiA, TiB, To, Talpha>(transA,
transB,
M[gemmIdx],
N[gemmIdx],
K[gemmIdx],
alphaTemp,
*(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx,
lda[gemmIdx],
*(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx,
ldb[gemmIdx],
betaTemp,
*(hD_gold[gemmIdx])
+ stride_d[gemmIdx] * batchIdx,
ldd[gemmIdx],
scaleDValue,
false);
cblas_gemm<TiA, TiB, To, Talpha, Tci>(
transA,
transB,
M[gemmIdx],
N[gemmIdx],
K[gemmIdx],
alpha,
*(hA[gemmIdx]) + stride_a[gemmIdx] * batchIdx,
lda[gemmIdx],
*(hB[gemmIdx]) + stride_b[gemmIdx] * batchIdx,
ldb[gemmIdx],
betaTemp,
*(hD_gold[gemmIdx]) + stride_d[gemmIdx] * batchIdx,
ldd[gemmIdx],
nullptr,
scaleAValue,
scaleBValue,
scaleDValue,
false);
}
}
}
Expand Down
54 changes: 42 additions & 12 deletions clients/include/type_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,45 +149,75 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
{
return TEST<hipblaslt_f8, hipblaslt_bf8, hipblasLtHalf, float>{}(arg);
}
/*
else if(Ti == HIPBLASLT_R_8I && To == HIPBLASLT_R_8I && Tc == HIPBLASLT_COMPUTE_I32)
{
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t>{}(arg);
}
*/
else if(TiA == HIPBLASLT_R_8I && To == HIPBLASLT_R_32I && Tc == HIPBLASLT_COMPUTE_I32)
{
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t, int32_t>{}(arg);
}
else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_8F_E4M3
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
{
return TEST<hipblaslt_f8, hipblasLtHalf, hipblaslt_f8, float>{}(arg);
return TEST<hipblaslt_f8, hipblasLtHalf, hipblaslt_f8, float, hipblasLtHalf>{}(arg);
}
else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_8F_E4M3
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
{
return TEST<hipblasLtHalf, hipblaslt_f8, hipblaslt_f8, float>{}(arg);
return TEST<hipblasLtHalf, hipblaslt_f8, hipblaslt_f8, float, hipblasLtHalf>{}(arg);
}
else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
{
return TEST<hipblaslt_f8, hipblasLtHalf, hipblasLtHalf, float>{}(arg);
return TEST<hipblaslt_f8, hipblasLtHalf, hipblasLtHalf, float, hipblasLtHalf>{}(arg);
}
else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
{
return TEST<hipblasLtHalf, hipblaslt_f8, hipblasLtHalf, float>{}(arg);
return TEST<hipblasLtHalf, hipblaslt_f8, hipblasLtHalf, float, hipblasLtHalf>{}(arg);
}
else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
{
return TEST<hipblaslt_f8, hipblasLtHalf, float, float>{}(arg);
return TEST<hipblaslt_f8, hipblasLtHalf, float, float, hipblasLtHalf>{}(arg);
}
else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
{
return TEST<hipblasLtHalf, hipblaslt_f8, float, float>{}(arg);
return TEST<hipblasLtHalf, hipblaslt_f8, float, float, hipblasLtHalf>{}(arg);
}
/*
else if(Ti == HIPBLASLT_R_8I && To == HIPBLASLT_R_8I && Tc == HIPBLASLT_COMPUTE_I32)
else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_8F_E4M3
&& Tc == HIPBLASLT_COMPUTE_F32)
{
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t>{}(arg);
return TEST<hipblaslt_f8, hipblasLtHalf, hipblaslt_f8, float, hipblaslt_f8>{}(arg);
}
*/
else if(TiA == HIPBLASLT_R_8I && To == HIPBLASLT_R_32I && Tc == HIPBLASLT_COMPUTE_I32)
else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_8F_E4M3
&& Tc == HIPBLASLT_COMPUTE_F32)
{
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t, int32_t>{}(arg);
return TEST<hipblasLtHalf, hipblaslt_f8, hipblaslt_f8, float, hipblaslt_f8>{}(arg);
}
else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32)
{
return TEST<hipblaslt_f8, hipblasLtHalf, hipblasLtHalf, float, hipblaslt_f8>{}(arg);
}
else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32)
{
return TEST<hipblasLtHalf, hipblaslt_f8, hipblasLtHalf, float, hipblaslt_f8>{}(arg);
}
else if(TiA == HIPBLASLT_R_8F_E4M3 && TiB == HIPBLASLT_R_16F && To == HIPBLASLT_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
{
return TEST<hipblaslt_f8, hipblasLtHalf, float, float, hipblaslt_f8>{}(arg);
}
else if(TiA == HIPBLASLT_R_16F && TiB == HIPBLASLT_R_8F_E4M3 && To == HIPBLASLT_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
{
return TEST<hipblasLtHalf, hipblaslt_f8, float, float, hipblaslt_f8>{}(arg);
}
}
return TEST<void>{}(arg);
Expand Down
8 changes: 6 additions & 2 deletions library/include/hipblaslt-ext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ namespace hipblaslt_ext
{
hipblasLtEpilogue_t mode
= HIPBLASLT_EPILOGUE_DEFAULT; //!< The mode of epilogue. Default is gemm.
hipblasltDatatype_t bias_data_type = static_cast<hipblasltDatatype_t>(
0); //!< The bias datatype. Only works if mode is set to bias related epilogues.
hipblasltDatatype_t bias_data_type
= HIPBLASLT_DATATYPE_INVALID; //!< The bias datatype. Only works if mode is set to bias related epilogues.
int aux_ld
= 0; //!< The aux leading dimension. Only works if mode is set to aux related epilogues.
int aux_stride
Expand Down Expand Up @@ -167,6 +167,10 @@ namespace hipblaslt_ext
int8_t alpha[16]; //!< The alpha value.
int8_t beta[16]; //!< The beta value.
// Epilogue inputs
void* scaleA; //!< The scaleA input pointer.
void* scaleB; //!< The scaleA input pointer.
void* scaleC; //!< The scaleC input pointer.
void* scaleD; //!< The scaleD input pointer.
void* scaleAlphaVec; //!< The scaleAlpha vector input pointer.
void* bias; //!< The bias input pointer.
int biasType; //!< The bias datatype. Only works if mode is set to bias related epilogues.
Expand Down
Loading