diff --git a/src/tl_templates/hip/common.h b/src/tl_templates/hip/common.h index b00944a18..e53565174 100644 --- a/src/tl_templates/hip/common.h +++ b/src/tl_templates/hip/common.h @@ -86,6 +86,7 @@ typedef __attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec; using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; +using float32x2 = __attribute__((__vector_size__(2 * sizeof(float)))) float; using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; diff --git a/src/tl_templates/hip/gemm.h b/src/tl_templates/hip/gemm.h index 068d57a64..b24df9305 100644 --- a/src/tl_templates/hip/gemm.h +++ b/src/tl_templates/hip/gemm.h @@ -182,17 +182,25 @@ class GemmTensorOp { for (int i = 0; i < warp_rows; i++) { const auto l = warp_m * warp_row_tiles + i * micro_size_x; const auto r = ki * (kPack * micro_size_k); - for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) { - if constexpr (TransposeA) { + if constexpr (TransposeA) { + for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) { auto [row, col] = reverse_index_map_transposed(lane_id, local_id); A_local[i * kPack * local_size_a + local_id] = A_shared[make_swizzle_layout( r + row, l + col)]; - } else { + } + } else { + for (int local_id = 0; local_id < (kPack * local_size_a); local_id += (kPack * vec_size)) { auto [row, col] = reverse_index_map(lane_id, local_id); - A_local[i * kPack * local_size_a + local_id] = - A_shared[make_swizzle_layout( - l + row, r + col)]; + if constexpr (kPack == 1) { + *(float32x2*)(&A_local[i * kPack * local_size_a + local_id]) = + *(float32x2*)(&A_shared[make_swizzle_layout( + l + row, r + col)]); + } else { + *(float32x4*)(&A_local[i * kPack * local_size_a + local_id]) = + *(float32x4*)(&A_shared[make_swizzle_layout( + l + row, r + col)]); + } } } } @@ -200,13 +208,21 @@ class GemmTensorOp { for (int j = 0; j < warp_cols; j++) { 
const auto l = warp_n * warp_col_tiles + j * micro_size_y; const auto r = ki * (kPack * micro_size_k); - for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) { - if constexpr (TransposeB) { + if constexpr (TransposeB) { + for (int local_id = 0; local_id < (kPack * local_size_b); local_id += (kPack * vec_size)) { auto [row, col] = reverse_index_map(lane_id, local_id); - B_local[j * kPack * local_size_b + local_id] = - B_shared[make_swizzle_layout( - l + row, r + col)]; - } else { + if constexpr (kPack == 1) { + *(float32x2*)(&B_local[j * kPack * local_size_b + local_id]) = + *(float32x2*)(&B_shared[make_swizzle_layout( + l + row, r + col)]); + } else { + *(float32x4*)(&B_local[j * kPack * local_size_b + local_id]) = + *(float32x4*)(&B_shared[make_swizzle_layout( + l + row, r + col)]); + } + } + } else { + for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) { auto [row, col] = reverse_index_map_transposed(lane_id, local_id); B_local[j * kPack * local_size_b + local_id] = B_shared[make_swizzle_layout( @@ -257,13 +273,21 @@ class GemmTensorOp { for (int j = 0; j < warp_cols; j++) { const auto l = warp_n * warp_col_tiles + j * micro_size_y; const auto r = ki * kPack * micro_size_k; - for (int local_id = 0; local_id < kPack * local_size_b; local_id++) { - if constexpr (TransposeB) { + if constexpr (TransposeB) { + for (int local_id = 0; local_id < (kPack * local_size_b); local_id += (kPack * vec_size)) { auto [row, col] = reverse_index_map(lane_id, local_id); - B_local[j * kPack * local_size_b + local_id] = - B_shared[make_swizzle_layout( - l + row, r + col)]; - } else { + if constexpr (kPack == 1) { + *(float32x2*)(&B_local[j * kPack * local_size_b + local_id]) = + *(float32x2*)(&B_shared[make_swizzle_layout( + l + row, r + col)]); + } else { + *(float32x4*)(&B_local[j * kPack * local_size_b + local_id]) = + *(float32x4*)(&B_shared[make_swizzle_layout( + l + row, r + col)]); + } + } + } else { + for (int local_id = 0; local_id < (kPack 
* local_size_b); local_id++) { auto [row, col] = reverse_index_map_transposed(lane_id, local_id); B_local[j * kPack * local_size_b + local_id] = B_shared[make_swizzle_layout( diff --git a/src/tl_templates/hip/hip_fp8.h b/src/tl_templates/hip/hip_fp8.h index 0000745b5..09bfe3e2e 100644 --- a/src/tl_templates/hip/hip_fp8.h +++ b/src/tl_templates/hip/hip_fp8.h @@ -127,3 +127,47 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z, res.y = *reinterpret_cast(&b); return res; }
+
+// Packs sixteen fp8_e4_t values into a single fp8_e4_16_t.
+// x0..x7 populate res.x and y0..y7 populate res.y; within each group of four
+// the lowest-numbered argument lands in the least-significant byte of the word.
+__device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
+                                        fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
+                                        fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
+                                        fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
+  // Read the leading byte of every operand as unsigned. A signed char with the
+  // top bit set is sign-extended on promotion to int, and its 1-bits would then
+  // clobber the higher bytes OR-ed into the same word.
+  const unsigned int x0_byte = *reinterpret_cast<unsigned char *>(&x0);
+  const unsigned int x1_byte = *reinterpret_cast<unsigned char *>(&x1);
+  const unsigned int x2_byte = *reinterpret_cast<unsigned char *>(&x2);
+  const unsigned int x3_byte = *reinterpret_cast<unsigned char *>(&x3);
+  const unsigned int x4_byte = *reinterpret_cast<unsigned char *>(&x4);
+  const unsigned int x5_byte = *reinterpret_cast<unsigned char *>(&x5);
+  const unsigned int x6_byte = *reinterpret_cast<unsigned char *>(&x6);
+  const unsigned int x7_byte = *reinterpret_cast<unsigned char *>(&x7);
+  const unsigned int y0_byte = *reinterpret_cast<unsigned char *>(&y0);
+  const unsigned int y1_byte = *reinterpret_cast<unsigned char *>(&y1);
+  const unsigned int y2_byte = *reinterpret_cast<unsigned char *>(&y2);
+  const unsigned int y3_byte = *reinterpret_cast<unsigned char *>(&y3);
+  const unsigned int y4_byte = *reinterpret_cast<unsigned char *>(&y4);
+  const unsigned int y5_byte = *reinterpret_cast<unsigned char *>(&y5);
+  const unsigned int y6_byte = *reinterpret_cast<unsigned char *>(&y6);
+  const unsigned int y7_byte = *reinterpret_cast<unsigned char *>(&y7);
+  // Assemble four bytes per word, least-significant byte first.
+  unsigned int a = (x3_byte << 24) | (x2_byte << 16) | (x1_byte << 8) | x0_byte;
+  unsigned int b = (x7_byte << 24) | (x6_byte << 16) | (x5_byte << 8) | x4_byte;
+  unsigned int c = (y3_byte << 24) | (y2_byte << 16) | (y1_byte << 8) | y0_byte;
+  unsigned int d = (y7_byte << 24) | (y6_byte << 16) | (y5_byte << 8) | y4_byte;
+  // Reinterpret each word as the member type it is stored into.
+  fp8_e4_8_t res_x;
+  res_x.x = *reinterpret_cast<decltype(res_x.x) *>(&a);
+  res_x.y = *reinterpret_cast<decltype(res_x.y) *>(&b);
+  fp8_e4_8_t res_y;
+  res_y.x = *reinterpret_cast<decltype(res_y.x) *>(&c);
+  res_y.y = *reinterpret_cast<decltype(res_y.y) *>(&d);
+  fp8_e4_16_t res;
+  res.x = res_x;
+  res.y = res_y;
+  return res;
+}
\ No newline at end of file