-
Notifications
You must be signed in to change notification settings - Fork 293
[CK_TILE] Fixing Type Conversions in PassThroughPack8 #2769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2dfbb22
cf35f78
1b115f3
022d658
729f631
0a3d51b
3b3a8e3
966792a
7d7fd19
80b9291
0281878
105a381
3b0e2d4
7c3c17c
bf810ef
94d7c9d
0db0bc1
3a89e68
5d561f8
da38672
53fc541
04061ab
e08b285
a786741
abab85e
00f3792
9521ec2
c5e3701
1f9d21d
c9f3e13
da546c4
9f07656
1ef29c8
98f005b
95d1b3e
472c6a7
6892004
ea464d7
4586d30
d917303
28c69dc
6a37228
55b5455
00f364c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,9 +7,26 @@ | |||||||||||||
| #include <cstdint> | ||||||||||||||
| #include <type_traits> | ||||||||||||||
|
|
||||||||||||||
| #define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1 | ||||||||||||||
| #define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0 | ||||||||||||||
| #define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0 | ||||||||||||||
|
|
||||||||||||||
| namespace ck_tile { | ||||||||||||||
| namespace element_wise { | ||||||||||||||
|
|
||||||||||||||
| // Generalized constexpr lookup table generator: make_lookup_table<T, N>(func) returns a std::array<T, N> whose element i is func(i) for each i in [0, N); make_lookup_table_impl performs the expansion over the index pack produced by std::make_index_sequence<N>. | ||||||||||||||
| template <typename T, std::size_t N, typename F, std::size_t... Is> | ||||||||||||||
| constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Is...>) | ||||||||||||||
| { | ||||||||||||||
| return {func(Is)...}; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| template <typename T, std::size_t N, typename F> | ||||||||||||||
| constexpr std::array<T, N> make_lookup_table(F&& func) | ||||||||||||||
| { | ||||||||||||||
| return make_lookup_table_impl<T, N>(std::forward<F>(func), std::make_index_sequence<N>{}); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * @brief Fast int4x4 to fp16x8_t data type conversion based on paper | ||||||||||||||
| * "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production" | ||||||||||||||
|
|
@@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale) | |||||||||||||
| */ | ||||||||||||||
| CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) | ||||||||||||||
| { | ||||||||||||||
| #if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16 | ||||||||||||||
| // This approach fails validation in GEMM tests. | ||||||||||||||
| uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); | ||||||||||||||
|
|
||||||||||||||
| static constexpr uint32_t fp32_base = 0x4B000000; | ||||||||||||||
|
|
@@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) | |||||||||||||
| __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); | ||||||||||||||
|
|
||||||||||||||
| return res; | ||||||||||||||
| #else | ||||||||||||||
| // Lookup table for bf16_t values corresponding to int4 values -8 to 7 | ||||||||||||||
| constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>( | ||||||||||||||
| [](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); }); | ||||||||||||||
|
|
||||||||||||||
| return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf], | ||||||||||||||
| bf16_lookup_table[(q >> 16) & 0xf], | ||||||||||||||
| bf16_lookup_table[(q >> 4) & 0xf], | ||||||||||||||
| bf16_lookup_table[(q >> 20) & 0xf]}; | ||||||||||||||
| #endif | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| #if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8 | ||||||||||||||
| /** | ||||||||||||||
| * @brief This function converts 8 packed 4-bit integers into 8 fp8 values. | ||||||||||||||
| * | ||||||||||||||
|
|
@@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) | |||||||||||||
|
|
||||||||||||||
| return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low); | ||||||||||||||
| } | ||||||||||||||
| #else | ||||||||||||||
| CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q) | ||||||||||||||
| { | ||||||||||||||
| // The approach below can be used once this compiler issue is resolved: | ||||||||||||||
| // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported" | ||||||||||||||
| // Lookup table for fp8_t values corresponding to int4 values -8 to 7 | ||||||||||||||
| constexpr auto fp8_lookup_table = make_lookup_table<fp8_t, 16>( | ||||||||||||||
| [](int i) { return impl::cast_to_f8<float, fp8_t, true, false>(i - 8, 0); }); | ||||||||||||||
|
|
||||||||||||||
| return fp8x4_t{fp8_lookup_table[(q >> 0) & 0xf], | ||||||||||||||
| fp8_lookup_table[(q >> 16) & 0xf], | ||||||||||||||
| fp8_lookup_table[(q >> 4) & 0xf], | ||||||||||||||
| fp8_lookup_table[(q >> 20) & 0xf]}; | ||||||||||||||
|
Comment on lines
+252
to
+254
|
||||||||||||||
| fp8_lookup_table[(q >> 16) & 0xf], | |
| fp8_lookup_table[(q >> 4) & 0xf], | |
| fp8_lookup_table[(q >> 20) & 0xf]}; | |
| fp8_lookup_table[(q >> 4) & 0xf], | |
| fp8_lookup_table[(q >> 8) & 0xf], | |
| fp8_lookup_table[(q >> 12) & 0xf]}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The situation here is identical to the above.
Copilot
AI
Sep 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.
| bf8_lookup_table[(q >> 16) & 0xf], | |
| bf8_lookup_table[(q >> 4) & 0xf], | |
| bf8_lookup_table[(q >> 20) & 0xf]}; | |
| bf8_lookup_table[(q >> 4) & 0xf], | |
| bf8_lookup_table[(q >> 8) & 0xf], | |
| bf8_lookup_table[(q >> 12) & 0xf]}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another identical instance of the use of the existing layout for pk_int4_t.
Copilot
AI
Sep 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit shift should be 16, not 8. The original code had >> 16 which correctly extracts the upper 16 bits for the high half of the bf16x8_t. Shifting by 8 will cause incorrect data extraction.
| y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8); | |
| y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The functionality of the new i4_to_bhalf4 was modeled after the existing i4_to_half4 function. These make use of the same layout and data extraction for pk_int4_t, and therefore the new operator for bf16x8_t shifts by the same amount as the existing operator for fp16x8_t (see lines #355 to #359).
Validation for both functions was added in test_gemm_pipeline_universal_bf16.cpp and test_gemm_pipeline_universal_fp16.cpp.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,3 @@ | ||
| if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") | ||
| add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp) | ||
| if(result EQUAL 0) | ||
| target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility) | ||
| endif() | ||
| endif() | ||
| endif() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, this is based on the existing layout and data extraction for
`pk_int4_t`, used by `i4_to_half4`. Lines #49 and #50 in the function `i4_to_half4` correspond to the behavior here, where the `LO` is formed from offsets at 0 and 16, and `HI` from offsets at 4 and 20.