diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 442b20e41de5..60e10e53391a 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -26,6 +26,16 @@ #define __HIP__GFX9__ #endif +#if defined(__HIPCC__) && \ + (defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \ + defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__)) + #define __HIP__GFX1X__ +#endif + +#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__)) + #define __HIP__GFX12__ +#endif + #if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) #define __HIP__MI3XX__ #endif @@ -37,15 +47,31 @@ #endif int get_lds_size() { - static bool is_cached = false; - static int result; - if (is_cached == false) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - std::string device_arch = dprops->gcnArchName; - size_t substring = device_arch.find("gfx95"); - result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024); - is_cached = true; - } + static const int result = [] { + const auto* dprops = at::cuda::getCurrentDeviceProperties(); + const std::string device_arch = dprops->gcnArchName; + return device_arch.find("gfx95") == std::string::npos ? 64 * 1024 + : 160 * 1024; + }(); + return result; +} + +bool on_gfx1x() { + static const bool result = [] { + const auto* dprops = at::cuda::getCurrentDeviceProperties(); + const std::string device_arch = dprops->gcnArchName; + return device_arch.find("gfx11") != std::string::npos || + device_arch.find("gfx12") != std::string::npos; + }(); + return result; +} + +bool on_gfx12() { + static const bool result = [] { + const auto* dprops = at::cuda::getCurrentDeviceProperties(); + const std::string device_arch = dprops->gcnArchName; + return device_arch.find("gfx12") != std::string::npos; + }(); return result; } @@ -286,21 +312,35 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, return out_c; } -#define DOT2C(V0, V2, V3) \ - if constexpr (std::is_same_v) { \ - asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \ - } else if constexpr (std::is_same_v) { \ - float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ - __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ - V0 += (s.x + s.y); \ - } +#if defined(__HIP__GFX9__) && !defined(__HIP__GFX1X__) + #define DOT2C(V0, V2, V3) \ + if constexpr (std::is_same_v) { \ + asm("v_dot2c_f32_f16 %0, %2, %3" \ + : "=v"(V0) \ + : "0"(V0), "v"(V2), "v"(V3)); \ + } else if constexpr (std::is_same_v) { \ + float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ + __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ + V0 += (s.x + s.y); \ + } +#elif defined(__HIP__GFX1X__) + // gfx1x: v_dot2_f32_f16 (VOP3-P, dot10-insts, available on gfx11+gfx12) + #define DOT2C(V0, V2, V3) \ + if constexpr (std::is_same_v) { \ + asm("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(V0) : "v"(V2), "v"(V3)); \ + } else if constexpr (std::is_same_v) { \ + float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ + __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ + V0 += (s.x + s.y); \ + } +#endif // To avoid LLVM silently upcasting to double __device__ inline unsigned int min__(uint32_t a, uint32_t b) { return min(a, b); } -#if defined(__HIP__GFX9__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__) // This version targets cases where A[] fits LDS capacity template @@ -442,14 +482,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) 1); // row_shr2 sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf, 1); // row_shr1 + #if defined(__HIP__GFX9__) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf, 1); // ROW_BCAST15 sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf, 1); // ROW_BCAST31 + #else + sum[n][y] += __shfl_xor(sum[n][y], 16); + #endif } } - if (threadIdx.x == 63) { + if (threadIdx.x == (THRDS - 1)) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -469,9 +513,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } else { - #pragma unroll + #ifdef __HIP__GFX9__ + #pragma unroll for (int n = 0; n < N; n++) { - #pragma unroll + #pragma unroll for (int y = 0; y < YTILE; y++) { /*float accm1 = 0; for (int i=0; i<64; i++) @@ -498,7 +543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) sum4[n][y][0] = accm; } } - if (threadIdx.x == 63) { + if (threadIdx.x == (THRDS - 1)) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -513,11 +558,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } + #endif // __HIP__GFX9__ (MFMA path) } m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +#else template __global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, @@ -528,9 +574,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__GFX9__) TODO: Add NAVI support +#endif -#if defined(__HIP__GFX9__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__) // This version targets cases where A[] marginally exceeds LDS capacity template @@ -657,14 +703,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) 1); // row_shr2 sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf, 1); // row_shr1 + #if defined(__HIP__GFX9__) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf, 1); // ROW_BCAST15 sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf, 1); // ROW_BCAST31 + #else + sum[n][y] += __shfl_xor(sum[n][y], 16); + #endif } } - if (threadIdx.x == 63) { + if (threadIdx.x == (THRDS - 1)) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -686,9 +736,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } else { - #pragma unroll + #ifdef __HIP__GFX9__ + #pragma unroll for (int n = 0; n < N; n++) { - #pragma unroll + #pragma unroll for (int y = 0; y < YTILE; y++) { // float accm1 = 0; // for (int i=0; i<64; i++) @@ -713,7 +764,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) sum4[n][y][0] = accm; } } - if (threadIdx.x == 63) { + if (threadIdx.x == (THRDS - 1)) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -730,6 +781,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } + #endif // __HIP__GFX9__ (MFMA path) } m += CuCount * _WvPrGrp * YTILE; @@ -746,7 +798,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } -#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +#else template __global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap, @@ -756,9 +808,9 @@ __global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__GFX9__) TODO: Add NAVI support +#endif -#if defined(__HIP__GFX9__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__) // This version targets big A[] cases, where it is much larger than LDS capacity template @@ -1004,14 +1056,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) 1); // row_shr2 sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf, 1); // row_shr1 + #if defined(__HIP__GFX9__) sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf, 1); // ROW_BCAST15 sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf, 1); // ROW_BCAST31 + #else + sum[n][y] += __shfl_xor(sum[n][y], 16); + #endif } } - if (threadIdx.x == 63) { + if (threadIdx.x == (THRDS - 1)) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -1033,9 +1089,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } else { - #pragma unroll + #ifdef __HIP__GFX9__ + #pragma unroll for (int n = 0; n < N; n++) { - #pragma unroll + #pragma unroll for (int y = 0; y < YTILE; y++) { float accm = sum4[n][y][0]; accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf, @@ -1057,7 +1114,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) sum4[n][y][0] = accm; } } - if (threadIdx.x == 63) { + if (threadIdx.x == (THRDS - 1)) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -1074,6 +1131,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } + #endif // __HIP__GFX9__ (MFMA path) } m += CuCount * _WvPrGrp * YTILE; @@ -1090,7 +1148,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } -#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +#else template __global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, @@ -1101,7 +1159,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__GFX9__) TODO: Add NAVI support +#endif // Find the min val of div2 that doesn't increase N/(div1*div2) int mindiv(int N, int div1, int div2) { @@ -1148,40 +1206,40 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int max_lds_len = get_lds_size() / 2; -#define WVSPLITK(_YTILE, _UNRL, _N) \ +#define WVSPLITK_CFG(_THRDS, _WVPRGRP, _YTILE, _UNRL, _N) \ { \ - dim3 block(64, 16); \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \ + dim3 block(_THRDS, _WVPRGRP); \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, _WVPRGRP); \ if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \ - wvSplitK_hf_sml_ \ + wvSplitK_hf_sml_ \ <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ By_in, af4, bf4, biasf4, c, __wvPrGrp, \ CuCount); \ else if (Kbp_in * N_in <= max_lds_len * 1.2) \ - wvSplitK_hf_ \ + wvSplitK_hf_ \ <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ By_in, af4, bf4, biasf4, c, __wvPrGrp, \ CuCount); \ else \ - wvSplitK_hf_big_ \ + wvSplitK_hf_big_ \ <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ By_in, af4, bf4, biasf4, c, __wvPrGrp, \ CuCount); \ } -#define WVSPLIT_TILE(_sYT, __N) \ +#define WVSPLIT_TILE_CFG(_THRDS, _WVPRGRP, _sYT, __N) \ { \ bool fit_lds = (Kbp_in * N_in <= max_lds_len); \ if (_sYT <= 1) \ - WVSPLITK(1, 4, __N) \ + WVSPLITK_CFG(_THRDS, _WVPRGRP, 1, 4, __N) \ else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \ - WVSPLITK(2, 2, __N) \ + WVSPLITK_CFG(_THRDS, _WVPRGRP, 2, 2, __N) \ else if (_sYT <= 4 * 3) \ - WVSPLITK(3, 2, __N) \ + WVSPLITK_CFG(_THRDS, _WVPRGRP, 3, 2, __N) \ else if (__N == 4) \ - WVSPLITK(4, 1, __N) \ + WVSPLITK_CFG(_THRDS, _WVPRGRP, 4, 1, __N) \ else \ - WVSPLITK(4, 2, __N) \ + WVSPLITK_CFG(_THRDS, _WVPRGRP, 4, 2, __N) \ } AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { @@ -1198,18 +1256,31 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, // then cut the active waves to balance their distribution... int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4); + const bool use_wave32 = on_gfx1x(); switch (N_in) { case 1: - WVSPLIT_TILE(sYT, 1) + if (use_wave32) + WVSPLIT_TILE_CFG(32, 16, sYT, 1) + else + WVSPLIT_TILE_CFG(64, 16, sYT, 1) break; case 2: - WVSPLIT_TILE(sYT, 2) + if (use_wave32) + WVSPLIT_TILE_CFG(32, 16, sYT, 2) + else + WVSPLIT_TILE_CFG(64, 16, sYT, 2) break; case 3: - WVSPLIT_TILE(sYT, 3) + if (use_wave32) + WVSPLIT_TILE_CFG(32, 16, sYT, 3) + else + WVSPLIT_TILE_CFG(64, 16, sYT, 3) break; case 4: - WVSPLIT_TILE(sYT, 4) + if (use_wave32) + WVSPLIT_TILE_CFG(32, 16, sYT, 4) + else + WVSPLIT_TILE_CFG(64, 16, sYT, 4) break; default: throw std::runtime_error( @@ -1653,7 +1724,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #endif } } -#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +#else template __global__ void wvSplitKrc_(const int actlN, const int K, const int Kap, @@ -1688,6 +1759,8 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, TORCH_CHECK(in_a.dtype() == torch::kFloat16 || in_a.dtype() == torch::kBFloat16); + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + auto out_c = torch::empty( {N_in, M_in}, torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device())); @@ -1696,7 +1769,6 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, dim3 grid(CuCount); - const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // const int max_lds_len = get_lds_size() / 2; @@ -1773,7 +1845,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, return out_c; } -#if defined(__HIP__MI3XX__) // TODO: Add NAVI support +#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -1817,12 +1889,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; - using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; float sA = *s_A; float sB = *s_B; while (m < M) { + #ifdef __HIP__GFX12__ + // gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8 + float sum[N][YTILE] = {}; + #else + // gfx9: MFMA accumulation scalar8 sum[N][YTILE] = {}; + #endif for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { bigType bigA[N][UNRL] = {}; bigType bigB[YTILE][UNRL]; @@ -1854,6 +1931,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t n = 0; n < N; n++) { + #ifdef __HIP__GFX12__ + // gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4) + for (int y = 0; y < YTILE; ++y) { + #pragma unroll + for (int i = 0; i < A_CHUNK / 4; i++) { + sum[n][y] = __builtin_amdgcn_dot4_f32_fp8_fp8( + bigA[n][k2].i[i], bigB[y][k2].i[i], sum[n][y]); + } + } + #else + // gfx9: MFMA path for (int i = 0; i < A_CHUNK; i += 8) { for (int y = 0; y < YTILE; ++y) { sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( @@ -1861,11 +1949,33 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) 0); } } + #endif } } } // Final reduction + #ifdef __HIP__GFX12__ + // gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:1 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + sum[n][y] += __shfl_xor(sum[n][y], 16); + } + } + #else + // gfx9 MFMA reduction for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { float accm0 = sum[n][y][0]; @@ -1880,8 +1990,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) sum[n][y][0] = accm0; } } + #endif - if (threadIdx.x == 0) { + const bool writeback_lane = + #ifdef __HIP__GFX12__ + threadIdx.x == (THRDS - 1); + #else + threadIdx.x == 0; + #endif + if (writeback_lane) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -1892,13 +2009,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. - sum[n][y][0] *= sA * sB; + #ifdef __HIP__GFX12__ + float result = sum[n][y] * sA * sB; + #else + float result = sum[n][y][0] * sA * sB; + #endif if constexpr (std::is_same_v) { - sum[n][y][0] += __half2float(biases[n][y]); + result += __half2float(biases[n][y]); } else if constexpr (std::is_same_v) { - sum[n][y][0] += __bfloat162float(biases[n][y]); + result += __bfloat162float(biases[n][y]); } - C[m + y + n * M] = __float2s(sum[n][y][0]); + C[m + y + n * M] = __float2s(result); } } } @@ -1906,7 +2027,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support +#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__) template __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, @@ -1918,9 +2039,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support +#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) -#if defined(__HIP__MI3XX__) // TODO: Add NAVI support +#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) template __global__ void __launch_bounds__(WvPrGrp* THRDS) @@ -1963,12 +2084,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; - using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; float sA = *s_A; float sB = *s_B; while (m < M) { + #ifdef __HIP__GFX12__ + // gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8 + float sum[N][YTILE] = {}; + #else + // gfx9: MFMA accumulation scalar8 sum[N][YTILE] = {}; + #endif for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { bigType bigA[N][UNRL] = {}; bigType bigB[YTILE][UNRL]; @@ -2002,6 +2128,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t n = 0; n < N; n++) { + #ifdef __HIP__GFX12__ + // gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4) + for (int y = 0; y < YTILE; ++y) { + #pragma unroll + for (int i = 0; i < A_CHUNK / 4; i++) { + sum[n][y] = __builtin_amdgcn_dot4_f32_fp8_fp8( + bigA[n][k2].i[i], bigB[y][k2].i[i], sum[n][y]); + } + } + #else + // gfx9: MFMA path for (int i = 0; i < A_CHUNK; i += 8) { for (int y = 0; y < YTILE; ++y) { sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( @@ -2009,11 +2146,33 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) 0); } } + #endif } } } // Final reduction + #ifdef __HIP__GFX12__ + // gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:1 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + sum[n][y] += __shfl_xor(sum[n][y], 16); + } + } + #else + // gfx9 MFMA reduction for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { float accm0 = sum[n][y][0]; @@ -2028,8 +2187,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) sum[n][y][0] = accm0; } } + #endif - if (threadIdx.x == 0) { + const bool writeback_lane = + #ifdef __HIP__GFX12__ + threadIdx.x == (THRDS - 1); + #else + threadIdx.x == 0; + #endif + if (writeback_lane) { scalar_t biases[N][YTILE] = {}; if (BIAS) for (int n = 0; n < N; n++) { @@ -2040,13 +2206,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. - sum[n][y][0] *= sA * sB; + #ifdef __HIP__GFX12__ + float result = sum[n][y] * sA * sB; + #else + float result = sum[n][y][0] * sA * sB; + #endif if constexpr (std::is_same_v) { - sum[n][y][0] += __half2float(biases[n][y]); + result += __half2float(biases[n][y]); } else if constexpr (std::is_same_v) { - sum[n][y][0] += __bfloat162float(biases[n][y]); + result += __bfloat162float(biases[n][y]); } - C[m + y + n * M] = __float2s(sum[n][y][0]); + C[m + y + n * M] = __float2s(result); } } } @@ -2054,7 +2224,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support +#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__) template __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, @@ -2066,7 +2236,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support +#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__) void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a, const std::optional& in_bias, at::Tensor& out_c, @@ -2099,24 +2269,30 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int max_lds_len = get_lds_size(); -#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \ - { \ - dim3 block(64, _WvPrGrp); \ - if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ - int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \ - wvSplitKQ_hf_sml_ \ - <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ - By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ - } else { \ - int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \ - wvSplitKQ_hf_ \ - <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ - By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ - } \ +#define WVSPLITKQ_IMPL(_THRDS, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \ + { \ + dim3 block(_THRDS, _WvPrGrp); \ + if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \ + wvSplitKQ_hf_sml_<<>>( \ + K_in, Kap_in, Kbp_in, M_in, Bx_in, By_in, b_ptr, a_ptr, bias_ptr, \ + c_ptr, s_a, s_b, __wvPrGrp, CuCount); \ + } else { \ + int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \ + wvSplitKQ_hf_ \ + <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ + By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } \ } +#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \ + if (on_gfx12()) \ + WVSPLITKQ_IMPL(32, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \ + else \ + WVSPLITKQ_IMPL(64, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] { using fptype = typename scalar::type; auto c_ptr = reinterpret_cast(out_c.data_ptr()); @@ -2136,10 +2312,10 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a, WVSPLITKQ(16, 2, 2, 2, 2, 2) break; case 3: - WVSPLITKQ(16, 2, 2, 2, 2, 3) + WVSPLITKQ(16, 2, 2, 1, 1, 3) break; case 4: - WVSPLITKQ(16, 2, 2, 2, 2, 4) + WVSPLITKQ(16, 2, 2, 1, 1, 4) break; default: throw std::runtime_error( diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 91b774c47464..d2123db2e8da 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -160,6 +160,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1 elif bias_mode == 2: BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1 + elif bias_mode == 3: + BIAS = torch.rand(1, m, dtype=dtype, device="cuda") * 2 - 1 ref_out = torch.nn.functional.linear(A, B, BIAS) out = ops.wvSplitKrc(A, B, cu_count, BIAS) @@ -224,10 +226,9 @@ def test_rocm_wvsplitk_kernel( ref_out = torch.nn.functional.linear(A, B, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) - if xnorm: - assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8) - else: - assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2) + # Accumulation error in fp16 GEMM scales with sqrt(K) + atol = torch.finfo(dtype).eps * math.sqrt(k) + torch.testing.assert_close(out, ref_out, atol=atol, rtol=1e-2) @pytest.mark.parametrize("xnorm", [False, True]) diff --git a/tests/model_executor/layers/test_rocm_unquantized_gemm.py b/tests/model_executor/layers/test_rocm_unquantized_gemm.py new file mode 100644 index 000000000000..c435a6e72422 --- /dev/null +++ b/tests/model_executor/layers/test_rocm_unquantized_gemm.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.platforms import current_platform + +if current_platform.is_cuda(): + pytest.skip( + "ROCm skinny GEMM tests are not supported on CUDA.", + allow_module_level=True, + ) + +from vllm.model_executor.layers import utils + + +def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch): + x = torch.randn(1, 64, dtype=torch.float16) + weight = torch.randn(128, 64, dtype=torch.float16) + + monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False) + monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False) + monkeypatch.setattr(utils, "get_cu_count", lambda: 120) + + wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t()) + monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock) + llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t()) + monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock) + + out = utils.rocm_unquantized_gemm_impl(x, weight, None) + ref = torch.nn.functional.linear(x, weight, None) + + wvsplitk_mock.assert_called_once() + llmm1_mock.assert_not_called() + assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3) + + +def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch): + x = torch.randn(5, 64, dtype=torch.float16) + weight = torch.randn(128, 64, dtype=torch.float16) + + monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False) + monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False) + monkeypatch.setattr(utils, "get_cu_count", lambda: 120) + + wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t()) + monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock) + llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t()) + monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock) + + out = utils.rocm_unquantized_gemm_impl(x, weight, None) + ref = torch.nn.functional.linear(x, weight, None) + + wvsplitk_mock.assert_not_called() + llmm1_mock.assert_not_called() + assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3) + + +def test_rocm_unquantized_gemm_gfx950_wvsplitkrc_path(monkeypatch): + x = torch.randn(16, 1024, dtype=torch.float16) + weight = torch.randn(256, 1024, dtype=torch.float16) + + monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False) + monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: False) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False) + monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: True) + monkeypatch.setattr(utils, "get_cu_count", lambda: 120) + + wvsplitkrc_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t()) + monkeypatch.setattr(utils.ops, "wvSplitKrc", wvsplitkrc_mock) + wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t()) + monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock) + + out = utils.rocm_unquantized_gemm_impl(x, weight, None) + ref = torch.nn.functional.linear(x, weight, None) + + wvsplitkrc_mock.assert_called_once() + wvsplitk_mock.assert_not_called() + assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 757d1ecc5284..4918c83bdc39 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -122,7 +122,7 @@ def use_aiter_triton_gemm(n, m, k, dtype): def rocm_unquantized_gemm_impl( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: - from vllm.platforms.rocm import on_gfx9, on_gfx950 + from vllm.platforms.rocm import on_gfx1x, on_gfx9, on_gfx950 n = x.numel() // x.size(-1) m = weight.shape[0] @@ -169,12 +169,12 @@ def rocm_unquantized_gemm_impl( use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM - and on_gfx9() + and (on_gfx9() or on_gfx1x()) and x.dtype in [torch.float16, torch.bfloat16] and k % 8 == 0 ) - if use_skinny is not True: + if not use_skinny: return torch.nn.functional.linear(x, weight, bias) x_view = x.reshape(-1, x.size(-1))