Skip to content

Commit

Permalink
[GPU] Fix gemm_tiled_opt kernel bug for tile_n_size 32 (openvinotoolk…
Browse files Browse the repository at this point in the history
…it#23776)

### Details:
- Fixed a crash and an accuracy issue for tile_n_size 32 with a transposed
input in the static-shape case
- Updated the gemm_tiled_opt test to cover a wider variety of combinations and
added more test cases

### Tickets:
 - 137358
  • Loading branch information
yeonbok authored Apr 1, 2024
1 parent 91d922b commit 51906fe
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,12 @@ KERNEL(gemm_tiled_opt)(
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read[subtile_k_id], simd_local_id)),
b_tile[subtile_k_id * SIMD_WIDTH + simd_local_id], c_tile[dot_id]);
#else // TILE_K > SIMD_WIDTH
#if IS_DYNAMIC && B_VEC_SIZE > 1
#if TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
#if B_VEC_SIZE > 1 && TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
MAKE_VECTOR_TYPE(INPUT1_TYPE, B_VEC_SIZE) b_tile_tmp;
unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
b_tile_tmp[b_elem] = b_tile[b_elem][simd_local_id];
}
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_local_id)), b_tile_tmp, c_tile[dot_id]);
#else
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_local_id)), b_tile[simd_local_id], c_tile[dot_id]);
#endif
#else
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_local_id)), b_tile[simd_local_id], c_tile[dot_id]);
#endif
Expand Down Expand Up @@ -464,7 +460,15 @@ KERNEL(gemm_tiled_opt)(
// Tile C calculation for TN, TT cases
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
unroll_for (uint simd_local_id = 0; simd_local_id < SIMD_WIDTH; simd_local_id++) {
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_tile[dot_id], simd_local_id)), b_tile[simd_local_id], c_tile[dot_id]);
#if B_VEC_SIZE > 1 && TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
MAKE_VECTOR_TYPE(INPUT1_TYPE, B_VEC_SIZE) b_tile_tmp;
unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
b_tile_tmp[b_elem] = b_tile[b_elem][simd_local_id];
}
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_tile[dot_id], simd_local_id)), b_tile_tmp, c_tile[dot_id]);
#else
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_tile[dot_id], simd_local_id)), b_tile[simd_local_id], c_tile[dot_id]);
#endif
}
} // Tile C calculation for TN, TT cases end
#endif // !TRANSPOSE_INPUT0
Expand Down
Loading

0 comments on commit 51906fe

Please sign in to comment.