Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
c6eb7c3
comp v4 setup
Jan 17, 2025
3e0047a
add a file
Jan 17, 2025
cca67d1
Finished the coding of the feature, Compiler not in the way we suppos…
Jan 23, 2025
66a183d
Update some of the code to better format
Jan 24, 2025
4931698
get tback the restrict variable name, need to switch out to solve the…
Jan 27, 2025
71352c4
Solve the compiler issue on SHMEM conflict
Jan 30, 2025
dec32dc
Finish the feature and merge with develop on the computeV2
Jan 31, 2025
d1e7177
roll back to compute pipeline
Jan 31, 2025
b2c7d77
Add the changes from include/ck_tile
Jan 31, 2025
6db81a1
Address the comments
Feb 3, 2025
3b30146
pre-merge with the develop branch need to fix the bug
Feb 3, 2025
800cf89
Merge from internal (#1857)
illsilin Feb 4, 2025
987cc54
Finish the integration to develop and have the correct result
Feb 4, 2025
d1715c0
Fix the gtest compilation error
Feb 5, 2025
bd09b37
Fix the gemm_basic error
Feb 5, 2025
4db6526
clang format
Feb 5, 2025
6774dda
switch the default pipeline to V3
Feb 5, 2025
2bef550
restore cron trigger (#1863)
illsilin Feb 5, 2025
c2bb46f
fix the benchmark basic script
Feb 6, 2025
5bb041b
add vectorloads on non-k dim for memory pipelines (#1856)
jakpiase Feb 6, 2025
7409674
Solving the Review comments
Feb 6, 2025
feb656d
Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm …
kylasa Feb 6, 2025
b5d201d
CK Tile - small fix to hotloop scheduler & KPack value. (#1867)
aosewski Feb 7, 2025
34612ef
address the new comments
Feb 7, 2025
e3402c9
fix a small bug on the old
Feb 7, 2025
ae4243d
Add a host mx gemm reference kernel (#1864)
geyyer Feb 7, 2025
f49de49
External CI: enable amd-develop branch trigger (#1859)
danielsu-amd Feb 7, 2025
9ba504b
merge with the develop support the fp8 with computev4
Feb 7, 2025
4106dfa
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 7, 2025
2003487
Merge branch 'develop' into ck_tile/gemm_compute_v4
Feb 7, 2025
2154151
Solve FMHA error
Feb 8, 2025
96b135f
clang format
Feb 8, 2025
df6042c
Fix the memory pipleine
Feb 8, 2025
884a2f7
Merge branch 'develop' into ck_tile/gemm_compute_v4
illsilin Feb 10, 2025
a9df418
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 10, 2025
ef2b53a
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 12, 2025
7dc420a
Solve merge conflict and add the gtest for compv4
Feb 12, 2025
4658f2f
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 12, 2025
2672ead
Merge branch 'develop' into ck_tile/gemm_compute_v4
Feb 12, 2025
1160b99
sync with develop
Feb 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions example/ck_tile/03_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
target_compile_options(tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
16 changes: 11 additions & 5 deletions example/ck_tile/03_gemm/gemm_basic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,26 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"

#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3

#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif

#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
Expand Down Expand Up @@ -126,7 +131,8 @@ auto create_args(int argc, char* argv[])
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
Expand Down
17 changes: 14 additions & 3 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");

stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
Expand All @@ -122,9 +123,19 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
if (init_method == 0) {
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
} else if (init_method == 1) {
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
} else if (init_method == 2) {
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
} else {
a_m_k.SetZero();
b_k_n.SetZero();
}

ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
Expand Down
1 change: 0 additions & 1 deletion example/ck_tile/03_gemm/script/benchmark_basic.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=1


for b_matrix_layout in "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
Expand Down
48 changes: 42 additions & 6 deletions example/ck_tile/03_gemm/universal_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;

constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
Expand All @@ -48,6 +50,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;

constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;

constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;

constexpr bool DoubleSmemBuffer = true;
#endif

constexpr bool kPadM = false;
Expand All @@ -70,8 +90,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;

using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

Expand Down Expand Up @@ -99,8 +125,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
has_hot_loop_v,
tail_number_v>;

using GemmPipeline =
GEMM_PIPELINE<UniversalGemmProblem, ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
Expand Down Expand Up @@ -140,7 +165,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&

if(has_hot_loop)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
Expand Down Expand Up @@ -215,6 +240,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
Comment thread
ThomasNing marked this conversation as resolved.
#endif
}
else
Expand Down
116 changes: 73 additions & 43 deletions include/ck_tile/core/utility/transpose_vectors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,52 +68,82 @@ struct transpose_vectors
}
else if constexpr(sizeof(S) == 1)
{
static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
static_assert(((NX % 4 == 0 && NY % 4 == 0) || (NX % 2 == 0 && NY % 2 == 0)), "wrong!");

using S4 = array<S, 4>; // typename array<S, 4>::type;

// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 4>{}([&](auto iy) {
static_for<0, NX, 4>{}([&](auto ix) {
// 4 int8x4 data from vx_tuple
const int32_t x_s4_0 =
bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
const int32_t x_s4_1 =
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
const int32_t x_s4_2 =
bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
const int32_t x_s4_3 =
bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);

// transpose
int32_t t_s4_0, t_s4_1;
int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;

constexpr int32_t m0 = 0x05010400;
constexpr int32_t m1 = 0x05040100;
constexpr int32_t m2 = 0x07060302;
constexpr int32_t m3 = 0x07030602;

// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);

// 4 int8x4 data from vy_tuple
vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
using S2 = array<S, 2>; // typename array<S, 4>::type;

if constexpr(NX % 4 == 0 && NY % 4 == 0)
{
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 4>{}([&](auto iy) {
static_for<0, NX, 4>{}([&](auto ix) {
// 4 int8x4 data from vx_tuple
const int32_t x_s4_0 =
bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
const int32_t x_s4_1 =
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
const int32_t x_s4_2 =
bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
const int32_t x_s4_3 =
bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);

// transpose
int32_t t_s4_0, t_s4_1;
int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;

constexpr int32_t m0 = 0x05010400;
constexpr int32_t m1 = 0x05040100;
constexpr int32_t m2 = 0x07060302;
constexpr int32_t m3 = 0x07030602;

// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) ->
// 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits
// first)
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);

// 4 int8x4 data from vy_tuple
vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
});
});
});
}
else if constexpr(NX % 2 == 0 && NY % 2 == 0)
{
static_for<0, NY, 2>{}([&](auto ix) {
static_for<0, NX, 2>{}([&](auto iy) {
const int16_t x_s2_0 =
bit_cast<int16_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
const int16_t x_s2_1 =
bit_cast<int16_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
constexpr int32_t m0 = 0x05040100;
constexpr int32_t m1 = 0x07060302;

const int32_t x0_32 = static_cast<int32_t>(x_s2_0 & 0xFFFF);
const int32_t x1_32 = static_cast<int32_t>(x_s2_1 & 0xFFFF);

const int32_t y_s2_0 = __builtin_amdgcn_perm(x1_32, x0_32, m0);
const int32_t y_s2_1 = __builtin_amdgcn_perm(x1_32, x0_32, m1);

vy_tuple(iy).template get_as<S2>()[ix / I2] =
bit_cast<S2>(static_cast<int16_t>(y_s2_0 & 0xFFFF));
vy_tuple(iy + I1).template get_as<S2>()[ix / I2] =
bit_cast<S2>(static_cast<int16_t>(y_s2_1 & 0xFFFF));
});
});
}
}
else
{
Expand Down
2 changes: 2 additions & 0 deletions include/ck_tile/ops/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
Expand Down
Loading