Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
2dfbb22
Change the return type of run_gemm_combinations in the basic tests
SamiAario-AMD Aug 22, 2025
cf35f78
Change the return type of run_gemm_combinations in the universal tests
SamiAario-AMD Aug 22, 2025
1b115f3
Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4
SamiAario-AMD Aug 22, 2025
022d658
Add universal GEMM test for fp8 x pk_i4
SamiAario-AMD Aug 22, 2025
729f631
Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4.
SamiAario-AMD Aug 22, 2025
0a3d51b
Add missing GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_til…
SamiAario-AMD Aug 25, 2025
3b3a8e3
Add missing GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_ti…
SamiAario-AMD Aug 25, 2025
966792a
No need for utility in test_ck_tile_elementwise_1d
SamiAario-AMD Aug 28, 2025
7d7fd19
Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8
SamiAario-AMD Aug 28, 2025
80b9291
Avoid union-based type punning in float_to_bf16_truc_raw to make it c…
SamiAario-AMD Sep 1, 2025
0281878
For consistency also make float_to_bf16_truc_nan_raw constexpr compli…
SamiAario-AMD Sep 1, 2025
105a381
Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF…
SamiAario-AMD Sep 2, 2025
3b0e2d4
Convert from float to bf16 during compilation rather than using magic…
SamiAario-AMD Sep 1, 2025
7c3c17c
Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8
SamiAario-AMD Aug 28, 2025
bf810ef
Comment out the basic test for fp16 x pk_i4 as it does not pass
SamiAario-AMD Sep 2, 2025
94d7c9d
Add missing GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_til…
SamiAario-AMD Sep 5, 2025
0db0bc1
Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8
SamiAario-AMD Sep 5, 2025
3a89e68
Add basic and universal GEMM tests for bf8 x pk_i4
SamiAario-AMD Sep 5, 2025
5d561f8
Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it wor…
SamiAario-AMD Sep 10, 2025
da38672
Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it wor…
SamiAario-AMD Sep 10, 2025
53fc541
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 11, 2025
04061ab
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 12, 2025
e08b285
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 16, 2025
a786741
Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary…
SamiAario-AMD Sep 10, 2025
abab85e
Use explicit macros for enabling and disabling the the constexpr look…
SamiAario-AMD Sep 16, 2025
00f3792
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 17, 2025
9521ec2
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 18, 2025
c5e3701
Fix two failing tests
SamiAario-AMD Sep 18, 2025
1f9d21d
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 18, 2025
c9f3e13
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 18, 2025
da546c4
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 19, 2025
9f07656
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 19, 2025
1ef29c8
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 19, 2025
98f005b
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 24, 2025
95d1b3e
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 24, 2025
472c6a7
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 24, 2025
6892004
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 24, 2025
ea464d7
Avoid union-based type punning in float_to_bf16_rtn_raw to make it co…
SamiAario-AMD Sep 24, 2025
4586d30
Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16…
SamiAario-AMD Sep 24, 2025
d917303
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 25, 2025
28c69dc
On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t
SamiAario-AMD Sep 26, 2025
6a37228
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 26, 2025
55b5455
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 26, 2025
00f364c
Merge branch 'develop' into LWPCK-3548
SamiAario-AMD Sep 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions include/ck_tile/core/numeric/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,8 @@ using bf16_raw_t = uint16_t;
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
uint32_t bits = bit_cast<uint32_t>(f);
if(~bits & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
Expand All @@ -140,9 +136,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
else if(bits & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
Expand All @@ -152,9 +148,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
bits |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
return uint16_t(bits >> 16);
}

CK_TILE_HOST
Expand Down Expand Up @@ -225,24 +221,16 @@ uint16_t float_to_bf16_rta_asm(float f)
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
uint32_t bits = bit_cast<uint32_t>(f);
return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
}

// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16);
uint32_t bits = bit_cast<uint32_t>(f);
return static_cast<uint16_t>(bits >> 16);
}

template <bf16_rounding_mode rounding>
Expand Down Expand Up @@ -287,7 +275,7 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined(__gfx950__)
#if CK_TILE_USE_LLVM_BUILTIN_BF16
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,26 @@
#include <cstdint>
#include <type_traits>

#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1
#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0

namespace ck_tile {
namespace element_wise {

// Generalized constexpr lookup table generator
template <typename T, std::size_t N, typename F, std::size_t... Is>
constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Is...>)
{
return {func(Is)...};
}

template <typename T, std::size_t N, typename F>
constexpr std::array<T, N> make_lookup_table(F&& func)
{
return make_lookup_table_impl<T, N>(std::forward<F>(func), std::make_index_sequence<N>{});
}

/**
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
Expand Down Expand Up @@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
*/
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16
// This approach fails validation in GEMM tests.
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);

static constexpr uint32_t fp32_base = 0x4B000000;
Expand All @@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));

return res;
#else
// Lookup table for bf16_t values corresponding to int4 values -8 to 7
constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>(
[](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); });

return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf],
bf16_lookup_table[(q >> 16) & 0xf],
bf16_lookup_table[(q >> 4) & 0xf],
bf16_lookup_table[(q >> 20) & 0xf]};
Comment on lines +174 to +176
Copy link

Copilot AI Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.

Suggested change
bf16_lookup_table[(q >> 16) & 0xf],
bf16_lookup_table[(q >> 4) & 0xf],
bf16_lookup_table[(q >> 20) & 0xf]};
bf16_lookup_table[(q >> 4) & 0xf],
bf16_lookup_table[(q >> 8) & 0xf],
bf16_lookup_table[(q >> 12) & 0xf]};

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

@SamiAario-AMD SamiAario-AMD Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this is based on the existing layout and data extraction for pk_int4_t, used by i4_to_half4. Lines #49 and #50 in the function i4_to_half4 correspond to the behavior here, where the LO is formed from offsets at 0 and 16, and HI from offsets at 4 and 20.

#endif
}

#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
/**
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
*
Expand Down Expand Up @@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)

return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q)
{
// The approach below can be used once this compiler issue is resolved:
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
// Lookup table for fp8_t values corresponding to int4 values -8 to 7
constexpr auto fp8_lookup_table = make_lookup_table<fp8_t, 16>(
[](int i) { return impl::cast_to_f8<float, fp8_t, true, false>(i - 8, 0); });

return fp8x4_t{fp8_lookup_table[(q >> 0) & 0xf],
fp8_lookup_table[(q >> 16) & 0xf],
fp8_lookup_table[(q >> 4) & 0xf],
fp8_lookup_table[(q >> 20) & 0xf]};
Comment on lines +252 to +254
Copy link

Copilot AI Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.

Suggested change
fp8_lookup_table[(q >> 16) & 0xf],
fp8_lookup_table[(q >> 4) & 0xf],
fp8_lookup_table[(q >> 20) & 0xf]};
fp8_lookup_table[(q >> 4) & 0xf],
fp8_lookup_table[(q >> 8) & 0xf],
fp8_lookup_table[(q >> 12) & 0xf]};

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The situation here is identical to the above.

}
#endif

CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
Expand All @@ -224,6 +269,7 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
return res;
}

#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
/**
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
*
Expand Down Expand Up @@ -285,6 +331,21 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)

return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
{
// The approach below can be used once this compiler issue is resolved:
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
// Lookup table for bf8_t values corresponding to int4 values -8 to 7
constexpr auto bf8_lookup_table = make_lookup_table<bf8_t, 16>(
[](int i) { return impl::cast_to_f8<float, bf8_t, true, false>(i - 8, 0); });

return bf8x4_t{bf8_lookup_table[(q >> 0) & 0xf],
bf8_lookup_table[(q >> 16) & 0xf],
bf8_lookup_table[(q >> 4) & 0xf],
bf8_lookup_table[(q >> 20) & 0xf]};
Comment on lines +344 to +346
Copy link

Copilot AI Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.

Suggested change
bf8_lookup_table[(q >> 16) & 0xf],
bf8_lookup_table[(q >> 4) & 0xf],
bf8_lookup_table[(q >> 20) & 0xf]};
bf8_lookup_table[(q >> 4) & 0xf],
bf8_lookup_table[(q >> 8) & 0xf],
bf8_lookup_table[(q >> 12) & 0xf]};

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another identical instance of the use of the existing layout for pk_int4_t.

}
#endif

struct PassThroughPack8
{
Expand All @@ -300,17 +361,27 @@ struct PassThroughPack8
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
{
y.lo = i4_to_bhalf4(bit_cast<int>(x));
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
Copy link

Copilot AI Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bit shift should be 16, not 8. The original code had >> 16 which correctly extracts the upper 16 bits for the high half of the bf16x8_t. Shifting by 8 will cause incorrect data extraction.

Suggested change
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The functionality of the new i4_to_bhalf4 was modeled after the existing i4_to_half4 function. These make use of the same layout and data extraction for pk_int4_t, and therefore the new operator for bf16x8_t shifts by the same amount as the existing operator for fp16x8_t (see lines #355 to #359).

Validation for both functions was added in test_gemm_pipeline_universal_bf16.cpp and test_gemm_pipeline_universal_fp16.cpp.

}

CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
#else
y.lo = i4_to_fp8x4(bit_cast<int>(x));
y.hi = i4_to_fp8x4(bit_cast<int>(x) >> 8);
#endif
}

CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
#else
y.lo = i4_to_bf8x4(bit_cast<int>(x));
y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
#endif
}
constexpr const static bool is_pack8_invocable = true;
};
Expand Down
5 changes: 1 addition & 4 deletions test/ck_tile/elementwise/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp)
if(result EQUAL 0)
target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility)
endif()
endif()
endif()
9 changes: 8 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::bf16_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
10 changes: 9 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,12 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
11 changes: 10 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,13 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::half_t>() && is_success;
#if 0
is_success =
run_gemm_combinations<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
#endif
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
10 changes: 9 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,12 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
4 changes: 2 additions & 2 deletions test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ bool run_gemm_test(int argc, char* argv[])
}

template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_combinations()
bool run_gemm_combinations()
{
// Define possible values for each parameter
std::vector<std::string> m_values = {"128", "1024"};
Expand Down Expand Up @@ -304,5 +304,5 @@ int run_gemm_combinations()
}
}
}
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
return is_success;
}
27 changes: 27 additions & 0 deletions test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,15 @@ struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
using CDataType = ck_tile::bf16_t;
};

template <>
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};

template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
Expand All @@ -281,6 +290,15 @@ struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
using CDataType = ck_tile::half_t;
};

template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};

template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
Expand All @@ -290,6 +308,15 @@ struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
using CDataType = ck_tile::half_t;
};

template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};

template <>
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
{
Expand Down
9 changes: 8 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,11 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::bf16_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
10 changes: 9 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
9 changes: 8 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,11 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
10 changes: 9 additions & 1 deletion test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"

int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}
Loading
Loading