Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tl_templates/hip/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ typedef
__attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec;

// Vector type aliases declared with the GCC/Clang `__vector_size__` attribute.
// Each alias's total size in bytes is the lane count in its name multiplied by
// sizeof(element type), exactly as spelled in the attribute argument.
using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x2 = __attribute__((__vector_size__(2 * sizeof(float)))) float;
using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;

Expand Down
60 changes: 42 additions & 18 deletions src/tl_templates/hip/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,31 +182,47 @@ class GemmTensorOp {
for (int i = 0; i < warp_rows; i++) {
const auto l = warp_m * warp_row_tiles + i * micro_size_x;
const auto r = ki * (kPack * micro_size_k);
for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
if constexpr (TransposeA) {
if constexpr (TransposeA) {
for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
A_local[i * kPack * local_size_a + local_id] =
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
r + row, l + col)];
} else {
}
} else {
for (int local_id = 0; local_id < (kPack * local_size_a); local_id += (kPack * vec_size)) {
auto [row, col] = reverse_index_map(lane_id, local_id);
A_local[i * kPack * local_size_a + local_id] =
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
l + row, r + col)];
if constexpr (kPack == 1) {
*(float32x2*)(&A_local[i * kPack * local_size_a + local_id]) =
*(float32x2*)(&A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
l + row, r + col)]);
} else {
*(float32x4*)(&A_local[i * kPack * local_size_a + local_id]) =
*(float32x4*)(&A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
l + row, r + col)]);
}
}
}
}
// Fetch B into register
for (int j = 0; j < warp_cols; j++) {
const auto l = warp_n * warp_col_tiles + j * micro_size_y;
const auto r = ki * (kPack * micro_size_k);
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
if constexpr (TransposeB) {
if constexpr (TransposeB) {
for (int local_id = 0; local_id < (kPack * local_size_b); local_id += (kPack * vec_size)) {
auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
} else {
if constexpr (kPack == 1) {
*(float32x2*)(&B_local[j * kPack * local_size_b + local_id]) =
*(float32x2*)(&B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)]);
} else {
*(float32x4*)(&B_local[j * kPack * local_size_b + local_id]) =
*(float32x4*)(&B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)]);
}
}
} else {
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
Expand Down Expand Up @@ -257,13 +273,21 @@ class GemmTensorOp {
for (int j = 0; j < warp_cols; j++) {
const auto l = warp_n * warp_col_tiles + j * micro_size_y;
const auto r = ki * kPack * micro_size_k;
for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
if constexpr (TransposeB) {
if constexpr (TransposeB) {
for (int local_id = 0; local_id < (kPack * local_size_b); local_id += (kPack * vec_size)) {
auto [row, col] = reverse_index_map(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
} else {
if constexpr (kPack == 1) {
*(float32x2*)(&B_local[j * kPack * local_size_b + local_id]) =
*(float32x2*)(&B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)]);
} else {
*(float32x4*)(&B_local[j * kPack * local_size_b + local_id]) =
*(float32x4*)(&B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)]);
}
}
} else {
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
B_local[j * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
Expand Down
36 changes: 36 additions & 0 deletions src/tl_templates/hip/hip_fp8.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,39 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
return res;
}

/// Packs sixteen fp8_e4_t values into a single fp8_e4_16_t.
///
/// x0..x3 form the 32-bit word stored in `res.x.x`, x4..x7 the word stored in
/// `res.x.y`, y0..y3 the word stored in `res.y.x`, and y4..y7 the word stored
/// in `res.y.y`. Within each word the first listed argument occupies the
/// least-significant byte and the fourth the most-significant byte.
///
/// Only the first byte of each argument's object representation is read
/// (fp8_e4_t is presumably exactly one byte wide -- TODO confirm).
///
/// The bytes are read as `unsigned char` and combined in `unsigned int`.
/// Reading them as `signed char` would sign-extend any byte whose top bit is
/// set during integer promotion, so OR-ing it in would overwrite the higher
/// bytes of the word with 1s; shifting a negative int left is also undefined
/// behaviour before C++20.
__device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
                                        fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
                                        fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
                                        fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
  // Raw storage byte of every argument, zero-extended to 32 bits.
  const unsigned int x0_bits = *reinterpret_cast<unsigned char *>(&x0);
  const unsigned int x1_bits = *reinterpret_cast<unsigned char *>(&x1);
  const unsigned int x2_bits = *reinterpret_cast<unsigned char *>(&x2);
  const unsigned int x3_bits = *reinterpret_cast<unsigned char *>(&x3);
  const unsigned int x4_bits = *reinterpret_cast<unsigned char *>(&x4);
  const unsigned int x5_bits = *reinterpret_cast<unsigned char *>(&x5);
  const unsigned int x6_bits = *reinterpret_cast<unsigned char *>(&x6);
  const unsigned int x7_bits = *reinterpret_cast<unsigned char *>(&x7);
  const unsigned int y0_bits = *reinterpret_cast<unsigned char *>(&y0);
  const unsigned int y1_bits = *reinterpret_cast<unsigned char *>(&y1);
  const unsigned int y2_bits = *reinterpret_cast<unsigned char *>(&y2);
  const unsigned int y3_bits = *reinterpret_cast<unsigned char *>(&y3);
  const unsigned int y4_bits = *reinterpret_cast<unsigned char *>(&y4);
  const unsigned int y5_bits = *reinterpret_cast<unsigned char *>(&y5);
  const unsigned int y6_bits = *reinterpret_cast<unsigned char *>(&y6);
  const unsigned int y7_bits = *reinterpret_cast<unsigned char *>(&y7);

  // Four bytes per word, first argument in the least-significant byte.
  unsigned int a = (x3_bits << 24) | (x2_bits << 16) | (x1_bits << 8) | x0_bits;
  unsigned int b = (x7_bits << 24) | (x6_bits << 16) | (x5_bits << 8) | x4_bits;
  unsigned int c = (y3_bits << 24) | (y2_bits << 16) | (y1_bits << 8) | y0_bits;
  unsigned int d = (y7_bits << 24) | (y6_bits << 16) | (y5_bits << 8) | y4_bits;

  // Reinterpret each packed word as a 4-element fp8 group, then assemble the
  // two 8-element halves into the 16-element result.
  fp8_e4_8_t res_x;
  res_x.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
  res_x.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
  fp8_e4_8_t res_y;
  res_y.x = *reinterpret_cast<fp8_e4_4_t *>(&c);
  res_y.y = *reinterpret_cast<fp8_e4_4_t *>(&d);
  fp8_e4_16_t res;
  res.x = res_x;
  res.y = res_y;
  return res;
}