Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,6 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000)
add_compile_options("SHELL: -mllvm --lsr-drop-solution=1")
endif()
endif()
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
if(HAS_ENABLE_POST_MISCHED)
message("Adding the enable-post-misched=0 compiler flag")
add_compile_options("SHELL: -mllvm -enable-post-misched=0")
endif()
endif()
set(check-coerce)
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ using CDEElementOp = PassThrough;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

static constexpr ck::index_t Scale_Block_M = 128;
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;

Expand All @@ -65,26 +65,27 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
128, 128,
128, 16, 16,
16, 128,
256, 16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
1, 2,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 16, 1, 16>, S<8>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on

int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool flush_cache = true;

// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t M = 128;
ck::index_t N = 1024;
ck::index_t K = 1024;

ck::index_t StrideA = K;
ck::index_t StrideB = K;
Expand All @@ -100,7 +101,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
else if(argc == 8)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
Expand All @@ -110,16 +111,19 @@ int main(int argc, char* argv[])
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);

StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
flush_cache = std::stoi(argv[7]);

StrideA = K;
StrideB = K;
StrideE = N;
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
printf("arg4 to 6: M, N, K\n");
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
exit(0);
}

Expand Down Expand Up @@ -182,9 +186,15 @@ int main(int argc, char* argv[])
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 4:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
default:
Expand All @@ -194,6 +204,16 @@ int main(int argc, char* argv[])
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
}
#endif
#if 0
for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){
float row_sum = .0;
for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){
printf("%lf ",a1_m_k(im, ik));
row_sum += a1_m_k(im, ik);
}
printf("sum: %lf\n", row_sum * 128);
}
#endif

DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
Expand Down Expand Up @@ -239,12 +259,24 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}

float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});

std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;

float ave_time = .0;

if(flush_cache)
{
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;

ave_time = invoker.Run(argument,
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
}
else
{
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
}

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;
Expand Down
Loading