Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
c6eb7c3
comp v4 setup
Jan 17, 2025
3e0047a
add a file
Jan 17, 2025
cca67d1
Finished the coding of the feature, Compiler not in the way we suppos…
Jan 23, 2025
66a183d
Update some of the code to better format
Jan 24, 2025
4931698
get tback the restrict variable name, need to switch out to solve the…
Jan 27, 2025
71352c4
Solve the compiler issue on SHMEM conflict
Jan 30, 2025
dec32dc
Finish the feature and merge with develop on the computeV2
Jan 31, 2025
d1e7177
roll back to compute pipeline
Jan 31, 2025
b2c7d77
Add the changes from include/ck_tile
Jan 31, 2025
6db81a1
Address the comments
Feb 3, 2025
3b30146
pre-merge with the develop branch need to fix the bug
Feb 3, 2025
800cf89
Merge from internal (#1857)
illsilin Feb 4, 2025
987cc54
Finish the integration to develop and have the correct result
Feb 4, 2025
d1715c0
Fix the gtest compilation error
Feb 5, 2025
bd09b37
Fix the gemm_basic error
Feb 5, 2025
4db6526
clang format
Feb 5, 2025
6774dda
switch the default pipeline to V3
Feb 5, 2025
2bef550
restore cron trigger (#1863)
illsilin Feb 5, 2025
c2bb46f
fix the benchmark basic script
Feb 6, 2025
5bb041b
add vectorloads on non-k dim for memory pipelines (#1856)
jakpiase Feb 6, 2025
7409674
Solving the Review comments
Feb 6, 2025
feb656d
Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm …
kylasa Feb 6, 2025
b5d201d
CK Tile - small fix to hotloop scheduler & KPack value. (#1867)
aosewski Feb 7, 2025
34612ef
address the new comments
Feb 7, 2025
e3402c9
fix a small bug on the old
Feb 7, 2025
ae4243d
Add a host mx gemm reference kernel (#1864)
geyyer Feb 7, 2025
f49de49
External CI: enable amd-develop branch trigger (#1859)
danielsu-amd Feb 7, 2025
9ba504b
merge with the develop support the fp8 with computev4
Feb 7, 2025
4106dfa
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 7, 2025
2003487
Merge branch 'develop' into ck_tile/gemm_compute_v4
Feb 7, 2025
2154151
Solve FMHA error
Feb 8, 2025
96b135f
clang format
Feb 8, 2025
df6042c
Fix the memory pipleine
Feb 8, 2025
884a2f7
Merge branch 'develop' into ck_tile/gemm_compute_v4
illsilin Feb 10, 2025
a9df418
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 10, 2025
ef2b53a
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 12, 2025
7dc420a
Solve merge conflict and add the gtest for compv4
Feb 12, 2025
4658f2f
Merge branch 'develop' of https://github.com/ROCm/composable_kernel i…
Feb 12, 2025
2672ead
Merge branch 'develop' into ck_tile/gemm_compute_v4
Feb 12, 2025
1160b99
sync with develop
Feb 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions example/ck_tile/03_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
target_compile_options(tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
11 changes: 10 additions & 1 deletion example/ck_tile/03_gemm/gemm_basic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#define CK_TILE_PIPELINE_COMPUTE 1
Comment thread
ThomasNing marked this conversation as resolved.
Outdated
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V2 3
Comment thread
ThomasNing marked this conversation as resolved.
Outdated

#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
Expand All @@ -22,10 +23,17 @@
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
constexpr bool isDoubleSmemBuffer = false;
Comment thread
ThomasNing marked this conversation as resolved.
Outdated
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
constexpr bool isDoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
constexpr bool isDoubleSmemBuffer = true;
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
Expand Down Expand Up @@ -89,7 +97,8 @@ auto create_args(int argc, char* argv[])
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
Expand Down
17 changes: 14 additions & 3 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");

stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
Expand All @@ -106,9 +107,19 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
if (init_method == 0) {
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
} else if (init_method == 1) {
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
} else if (init_method == 2) {
ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
} else {
a_m_k.SetZero();
b_k_n.SetZero();
}

ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
Expand Down
46 changes: 38 additions & 8 deletions example/ck_tile/03_gemm/universal_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;

constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
Expand All @@ -42,6 +42,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;

#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;

constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
#endif

constexpr bool kPadM = false;
Expand All @@ -63,9 +78,16 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;

using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
using Traits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, isDoubleSmemBuffer, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
isDoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

Expand Down Expand Up @@ -93,10 +115,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
has_hot_loop_v,
tail_number_v>;

using GemmPipeline =
GEMM_PIPELINE<UniversalGemmProblem, ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);

const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
Expand Down Expand Up @@ -196,6 +217,15 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
Comment thread
ThomasNing marked this conversation as resolved.
#endif
}
else
Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/host.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
Expand All @@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
2 changes: 1 addition & 1 deletion include/ck_tile/ops/batched_transpose.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 2 additions & 0 deletions include/ck_tile/ops/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
Expand Down
32 changes: 17 additions & 15 deletions include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ struct BlockGemmARegBRegCRegV1
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

static constexpr index_t KPack = WarpGemm::kKPerThread;

CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
Expand All @@ -43,7 +45,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});

return a_block_dstr_encode;
}
Expand All @@ -58,7 +60,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});

return b_block_dstr_encode;
}
Expand All @@ -73,7 +75,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});

return c_block_dstr_encode;
}
Expand Down Expand Up @@ -112,13 +114,13 @@ struct BlockGemmARegBRegCRegV1
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");

using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;

using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;

constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
Expand Down Expand Up @@ -157,7 +159,7 @@ struct BlockGemmARegBRegCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
Expand All @@ -180,7 +182,7 @@ struct BlockGemmARegBRegCRegV1
sequence<0, 0>>{};

constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
Expand Down
Loading