Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 14 additions & 46 deletions csrc/rocm/skinny_gemms.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1902,7 +1902,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;

while (m < M) {
floatx16 sum[N][YTILE] = {};
scalar8 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
Expand Down Expand Up @@ -1936,7 +1936,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
Expand All @@ -1949,31 +1949,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
Comment on lines 1954 to 1956
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The comments for the __builtin_amdgcn_mov_dpp operations are inconsistent. While sum[n][y][1] uses 0x101 (row_shl1), sum[n][y][2] uses 0x102 and sum[n][y][3] uses 0x103. The comment // row_shl1 is repeated for all three, which is misleading. Please update the comments to accurately reflect the shift values (e.g., // row_shl2, // row_shl3).

                                          1);  // row_shl1
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
                                          1);  // row_shl2
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
                                          1);  // row_shl3

1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
accm0 += __shfl_down(accm0, 20);
accm0 += __shfl_down(accm0, 40);
sum[n][y][0] = accm0;
}
}

Expand Down Expand Up @@ -2064,7 +2048,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float sB = *s_B;

while (m < M) {
floatx16 sum[N][YTILE] = {};
scalar8 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
Expand Down Expand Up @@ -2100,7 +2084,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
0);
}
Expand All @@ -2113,31 +2097,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
accm0 += __shfl_down(accm0, 20);
accm0 += __shfl_down(accm0, 40);
sum[n][y][0] = accm0;
}
}

Expand Down Expand Up @@ -2242,16 +2210,16 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
: nullptr;
switch (N_in) {
case 1:
WVSPLITKQ(12, 2, 2, 2, 2, 1)
WVSPLITKQ(16, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITKQ(12, 2, 2, 2, 2, 2)
WVSPLITKQ(16, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(8, 2, 2, 1, 1, 3)
WVSPLITKQ(16, 2, 2, 2, 2, 3)
break;
case 4:
WVSPLITKQ(4, 2, 2, 1, 1, 4)
WVSPLITKQ(16, 2, 2, 2, 2, 4)
break;
default:
throw std::runtime_error(
Expand Down