Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0c4cf86
[CK_TILE] Add GetName functions for Gemm Kernels
aledudek Jan 2, 2025
ff5115b
[CK_TILE] Add GetName for grouped gemm
aledudek Jan 3, 2025
8adaf41
[CK_TILE] Add GetName for gemm - review changes
aledudek Jan 8, 2025
67ab389
Merge branch 'develop' into gemm_getname
aledudek Jan 8, 2025
67e305e
Merge branch 'develop' into getname_gemm
aledudek Jan 9, 2025
c32aa3f
Merge branch 'develop' into getname_gemm
aledudek Jan 14, 2025
08aad5e
[CK_TILE] Print also gemm problem pipeline and shape
aledudek Jan 14, 2025
56e6326
[CK_TILE] Print also GemmPipelineScheduler
aledudek Jan 22, 2025
8866825
Merge branch 'develop' into getname_gemm
aledudek Jan 22, 2025
615fd8f
[CK_TILE] GetName - fixed Scheduler <<operator visibility
aledudek Jan 28, 2025
6e7ac86
Merge branch 'develop' into getname_gemm
aledudek Jan 28, 2025
9258ed7
[CK_TILE] GetName info adjustments
aledudek Jan 30, 2025
5522165
Merge branch 'develop' into getname_gemm
aledudek Jan 30, 2025
75a3989
[CK_TILE] GetName post-merge fix
aledudek Jan 30, 2025
3416bef
[CK_TILE] GetName - add general concat function
aledudek Jan 31, 2025
afa81a7
[CK_TILE] GetName - small adjustments, format change
aledudek Feb 4, 2025
98b3516
Merge branch 'develop' into getname_gemm
aledudek Feb 4, 2025
d9c6a64
post merge develop fix
aledudek Feb 4, 2025
711b2dd
Remove commented code
aledudek Feb 4, 2025
800cf89
Merge from internal (#1857)
illsilin Feb 4, 2025
2bef550
restore cron trigger (#1863)
illsilin Feb 5, 2025
5bb041b
add vectorloads on non-k dim for memory pipelines (#1856)
jakpiase Feb 6, 2025
feb656d
Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm …
kylasa Feb 6, 2025
932071a
Extract prec_str and add separator to concat
aledudek Feb 6, 2025
510abd9
Merge branch 'develop' into getname_gemm
aledudek Feb 6, 2025
0f063f3
GetName add
aledudek Feb 6, 2025
b5d201d
CK Tile - small fix to hotloop scheduler & KPack value. (#1867)
aosewski Feb 7, 2025
ce51797
Merge branch 'develop' into getname_gemm
aledudek Feb 7, 2025
8976171
Merge branch 'develop' into getname_gemm
aledudek Feb 7, 2025
f4607cf
Merge branch 'getname_gemm' of https://github.com/ROCm/composable_ker…
aledudek Feb 7, 2025
fa3afaf
Resolve merge issues
aledudek Feb 7, 2025
7857894
Merge branch 'develop' into getname_gemm
aledudek Feb 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&

if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/03_gemm/gemm_basic.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ int run_gemm_example_with_layouts(int argc,
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
Expand Down Expand Up @@ -229,7 +229,7 @@ int run_gemm_example_with_layouts(int argc,
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
}

return pass;
Expand Down
7 changes: 5 additions & 2 deletions example/ck_tile/16_batched_gemm/batched_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre

if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;

std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,

if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}

return pass;
Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
Expand Down
1 change: 1 addition & 0 deletions include/ck_tile/host.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
Expand Down
122 changes: 122 additions & 0 deletions include/ck_tile/host/concat.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

template <typename T>
struct IsCharArray : std::false_type
{
};

template <std::size_t N>
struct IsCharArray<char[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<const char[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<char (&)[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<const char (&)[N]> : std::true_type
{
};

template <typename... Ts>
inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v<Ts, std::string_view> ||
IsCharArray<Ts>::value ||
std::is_same_v<Ts, char>)&&...);

template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
{
using ::operator<<;
thread_local std::ostringstream oss;
oss.str("");

(oss << ... << xs);
return oss.str();
}

template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept
{
return N;
}

template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept
{
return N;
}

[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept
{
const char* end = s;
while(*end++ != 0) {}
return end - s - 1;
}

[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; }

[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); }

[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept
{
return s.size();
}

template <typename... Ts>
auto concatInto(std::string& result, const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
{
const std::size_t space = (1 + ... + getSize(xs));
result.reserve(result.size() + space);
((result += xs), ...);
}

template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
{
std::string result;
concatInto(result, xs...);
return result;
}

// Function for types convertible to std::string_view
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
-> std::enable_if_t<AllConvertibleToStringView<First, Rest...>, std::string>
{
std::string result;
result += first;
((result += sep, result += rest), ...);
return result;
}

// Function for other types
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
-> std::enable_if_t<!AllConvertibleToStringView<First, Rest...>, std::string>
{
using ::operator<<;
thread_local std::ostringstream oss;
oss.str("");
oss << first;
((oss << sep << rest), ...);
return oss.str();
}

} // namespace ck_tile
1 change: 1 addition & 0 deletions include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
1 change: 1 addition & 0 deletions include/ck_tile/ops/batched_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
1 change: 1 addition & 0 deletions include/ck_tile/ops/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
34 changes: 34 additions & 0 deletions include/ck_tile/ops/common/utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"

namespace ck_tile {
Comment thread
aledudek marked this conversation as resolved.

// clang-format off
template <typename T> struct typeToStr;
template <> struct typeToStr<float> { static constexpr const char * name = "fp32"; };
template <> struct typeToStr<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct typeToStr<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on

template <typename ADataType_, typename BDataType_>
std::string gemm_prec_str()
{
std::string base_str = std::string(typeToStr<ADataType_>::name);
if(!std::is_same_v<ADataType_, BDataType_>)
{
base_str += "_" + std::string(typeToStr<BDataType_>::name);
}
return base_str;
}

} // namespace ck_tile
1 change: 1 addition & 0 deletions include/ck_tile/ops/elementwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
1 change: 1 addition & 0 deletions include/ck_tile/ops/epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
1 change: 1 addition & 0 deletions include/ck_tile/ops/flatmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
1 change: 1 addition & 0 deletions include/ck_tile/ops/fmha.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
1 change: 1 addition & 0 deletions include/ck_tile/ops/fused_moe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
1 change: 1 addition & 0 deletions include/ck_tile/ops/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
16 changes: 15 additions & 1 deletion include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {

Expand Down Expand Up @@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;

[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;

return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}

struct BatchedGemmKernelArgs : GemmKernelArgs
{
index_t batch_stride_A;
Expand Down
8 changes: 8 additions & 0 deletions include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {

Expand Down Expand Up @@ -75,6 +76,13 @@ struct GemmKernel
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();

[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}

CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
Expand Down
12 changes: 12 additions & 0 deletions include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
};

[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;

return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}

__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{
Expand Down
Loading