diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 976874e6f4ad..15ebcc776ad7 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -1902,7 +1902,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) float sB = *s_B; while (m < M) { - floatx16 sum[N][YTILE] = {}; + scalar8 sum[N][YTILE] = {}; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { bigType bigA[N][UNRL] = {}; bigType bigB[YTILE][UNRL]; @@ -1936,7 +1936,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t n = 0; n < N; n++) { for (int i = 0; i < A_CHUNK; i += 8) { for (int y = 0; y < YTILE; ++y) { - sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, 0); } @@ -1949,31 +1949,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { float accm0 = sum[n][y][0]; - float accm16 = sum[n][y][8]; accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf, 1); // row_shl1 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1); accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf, 1); // row_shl2 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1); accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf, 1); // row_shl3 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf, - 1); // row_shl8 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf, - 1); // row_shl9 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf, - 1); // row_shl10 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf, - 1); // row_shl11 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1); - accm0 += __shfl(accm0, 36); - accm16 += __shfl(accm16, 52); - sum[n][y][0] = accm0 + __shfl(accm16, 16); + accm0 += __shfl_down(accm0, 20); + accm0 += __shfl_down(accm0, 40); + sum[n][y][0] = accm0; } } @@ -2064,7 +2048,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) float sB = *s_B; while (m < M) { - floatx16 sum[N][YTILE] = {}; + scalar8 sum[N][YTILE] = {}; for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { bigType bigA[N][UNRL] = {}; bigType bigB[YTILE][UNRL]; @@ -2100,7 +2084,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t n = 0; n < N; n++) { for (int i = 0; i < A_CHUNK; i += 8) { for (int y = 0; y < YTILE; ++y) { - sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, 0); } @@ -2113,31 +2097,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { float accm0 = sum[n][y][0]; - float accm16 = sum[n][y][8]; accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf, 1); // row_shl1 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1); accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf, 1); // row_shl2 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1); accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf, 1); // row_shl3 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf, - 1); // row_shl8 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf, - 1); // row_shl9 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf, - 1); // row_shl10 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1); - accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf, - 1); // row_shl11 - accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1); - accm0 += __shfl(accm0, 36); - accm16 += __shfl(accm16, 52); - sum[n][y][0] = accm0 + __shfl(accm16, 16); + accm0 += __shfl_down(accm0, 20); + accm0 += __shfl_down(accm0, 40); + sum[n][y][0] = accm0; } } @@ -2242,16 +2210,16 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a, : nullptr; switch (N_in) { case 1: - WVSPLITKQ(12, 2, 2, 2, 2, 1) + WVSPLITKQ(16, 2, 2, 2, 2, 1) break; case 2: - WVSPLITKQ(12, 2, 2, 2, 2, 2) + WVSPLITKQ(16, 2, 2, 2, 2, 2) break; case 3: - WVSPLITKQ(8, 2, 2, 1, 1, 3) + WVSPLITKQ(16, 2, 2, 2, 2, 3) break; case 4: - WVSPLITKQ(4, 2, 2, 1, 1, 4) + WVSPLITKQ(16, 2, 2, 2, 2, 4) break; default: throw std::runtime_error(