diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index ecd94cacc659..d34714c97f6c 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -29,6 +29,10 @@
   #define __HIP__MI3XX__
 #endif
 
+#if defined(__HIPCC__) && defined(__GFX11__)
+  #define __HIP__GFX11__
+#endif
+
 #if defined(__gfx950__)
   #define LDS_SIZE 160 * 1024
 #else
@@ -48,6 +52,15 @@ int get_lds_size() {
   return result;
 }
 
+bool is_gfx11() {
+  static const bool result = [] {
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    std::string device_arch = dprops->gcnArchName;
+    return device_arch.find("gfx11") != std::string::npos;
+  }();
+  return result;
+}
+
 #if defined(NDEBUG)
   #undef NDEBUG
   #include <assert.h>
@@ -285,21 +298,41 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
   return out_c;
 }
 
-#define DOT2C(V0, V2, V3)                                                     \
-  if constexpr (std::is_same_v<scalar_t, half>) {                             \
-    asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
-  } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {            \
-    float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) *             \
-               __bfloat1622float2(*((__hip_bfloat162*)(&(V3))));              \
-    V0 += (s.x + s.y);                                                        \
-  }
+// GFX11 (RDNA3) renamed v_dot2c_f32_f16 to v_dot2acc_f32_f16
+#if defined(__HIP__GFX11__)
+  #define DOT2C_FP16_INSN "v_dot2acc_f32_f16"
+#else
+  #define DOT2C_FP16_INSN "v_dot2c_f32_f16"
+#endif
+
+#define DOT2C(V0, V2, V3)                                                      \
+  if constexpr (std::is_same_v<scalar_t, half>) {                              \
+    asm(DOT2C_FP16_INSN " %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
+  } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {             \
+    float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) *              \
+               __bfloat1622float2(*((__hip_bfloat162*)(&(V3))));               \
+    V0 += (s.x + s.y);                                                         \
+  }
 
+// GFX11 (RDNA, wave32) butterfly reduction: sum all 32 lanes within one
+// wavefront. Every lane gets the result.
+#if defined(__HIP__GFX11__)
+  #define REDUCE_SUM_WAVE32(val)  \
+    do {                          \
+      val += __shfl_xor(val, 1);  \
+      val += __shfl_xor(val, 2);  \
+      val += __shfl_xor(val, 4);  \
+      val += __shfl_xor(val, 8);  \
+      val += __shfl_xor(val, 16); \
+    } while (0)
+#endif
+
 // To avoid LLVM silently upcasting to double
 __device__ inline unsigned int min__(uint32_t a, uint32_t b) {
   return min(a, b);
 }
 
-#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
 
 // This version targets cases where A[] fits LDS capacity
 template <...>
@@ -471,6 +504,28 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     // Final reduction step using shuffle
     //----------------------------------------------------
     if constexpr (!use_mfma) {
+  #if defined(__HIP__GFX11__)
+      // Wave32: butterfly reduce within the single wavefront per row
+      for (int n = 0; n < N; n++)
+        for (int y = 0; y < YTILE; y++) REDUCE_SUM_WAVE32(sum[n][y]);
+
+      // Lane 0 has the complete sum; write the result
+      if (threadIdx.x == 0) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            if constexpr (std::is_same_v<scalar_t, half>) {
+              if (BIAS)
+                sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
+            } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
+              if (BIAS)
+                sum[n][i] +=
+                    __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
+            }
+            C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+          }
+        }
+      }
+  #else  // GFX9 wave64 path
       for (int n = 0; n < N; n++) {
         for (int y = 0; y < YTILE; y++) {
           asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -509,6 +564,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
           }
         }
       }
+  #endif  // defined(__HIP__GFX11__)
     } else {
 #pragma unroll
       for (int n = 0; n < N; n++) {
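The reduction rewrite above is the core of the port: GFX9 sums across a 64-lane wavefront with DPP modifiers (the `row_shr`/`bound_ctrl` sequences visible in the context lines) that do not map onto RDNA3's 32-wide wavefronts, so the GFX11 path uses an XOR-shuffle butterfly over 32 lanes instead. A standalone sketch of the same technique, assuming a wave32 RDNA device and HIP's `__shfl_xor`; the kernel and names here are illustrative, not code from the patch:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

// XOR butterfly: after the step with offset d, each lane's partial sum
// covers a block of 2*d lanes; after log2(32) = 5 steps every lane holds
// the sum of all 32. REDUCE_SUM_WAVE32 above is this loop, unrolled.
__global__ void wave32_sum(const float* in, float* out) {
  float v = in[threadIdx.x];
  for (int d = 1; d < 32; d <<= 1) v += __shfl_xor(v, d);
  if (threadIdx.x == 0) *out = v;  // lane 0 commits, as in the patch
}

int main() {
  float h[32], r = 0.0f, *din, *dout;
  for (int i = 0; i < 32; ++i) h[i] = float(i + 1);  // 1..32, sum = 528
  (void)hipMalloc(&din, sizeof(h));
  (void)hipMalloc(&dout, sizeof(float));
  (void)hipMemcpy(din, h, sizeof(h), hipMemcpyHostToDevice);
  wave32_sum<<<1, 32>>>(din, dout);
  (void)hipMemcpy(&r, dout, sizeof(float), hipMemcpyDeviceToHost);
  printf("%.1f\n", r);  // 528.0
  return 0;
}
```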
@@ -560,7 +616,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       m += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) && !defined(__HIP__GFX11__)
 
 template <...>
 __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
@@ -570,9 +626,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
 
-#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
 
 // This version targets cases where A[] marginally exceeds LDS capacity
 template <...>
@@ -771,6 +827,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     // Final reduction step using shuffle
     //----------------------------------------------------
     if constexpr (!use_mfma) {
+  #if defined(__HIP__GFX11__)
+      // Wave32: butterfly reduce within the single wavefront per row
+      for (int n = 0; n < N; n++)
+        for (int y = 0; y < YTILE; y++) REDUCE_SUM_WAVE32(sum[n][y]);
+
+      // Lane 0 has the complete sum; write the result
+      if (threadIdx.x == 0) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            if (commitColumn[i]) {
+              if constexpr (std::is_same_v<scalar_t, half>) {
+                if (BIAS)
+                  sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
+              } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
+                if (BIAS)
+                  sum[n][i] +=
+                      __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
+              }
+              C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+            }
+          }
+        }
+      }
+  #else  // GFX9 wave64 path
       for (int n = 0; n < N; n++) {
         for (int y = 0; y < YTILE; y++) {
           asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
@@ -811,6 +891,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
           }
         }
       }
+  #endif  // defined(__HIP__GFX11__)
     } else {
 #pragma unroll
       for (int n = 0; n < N; n++) {
@@ -877,7 +958,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
   }
 }
 
-#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) && !defined(__HIP__GFX11__)
 
 template <...>
 __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
@@ -887,9 +968,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
                              const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
 
-#if defined(__HIP__GFX9__)  // TODO: Add NAVI support
+#if defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
 
 // This version targets big A[] cases, where it is much larger than LDS capacity
 template <...>
@@ -1139,6 +1220,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
     // Final reduction step using shuffle
     //----------------------------------------------------
     if constexpr (!use_mfma) {
+  #if defined(__HIP__GFX11__)
+      // Wave32: butterfly reduce within the single wavefront per row
+      for (int n = 0; n < N; n++)
+        for (int y = 0; y < YTILE; y++) REDUCE_SUM_WAVE32(sum[n][y]);
+
+      // Lane 0 has the complete sum; write the result
+      if (threadIdx.x == 0) {
+        for (int n = 0; n < N; n++) {
+          for (int i = 0; i < YTILE; i++) {
+            if (commitColumn[i]) {
+              if constexpr (std::is_same_v<scalar_t, half>) {
+                if (BIAS)
+                  sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
+              } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
+                if (BIAS)
+                  sum[n][i] +=
+                      __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
+              }
+              C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
+            }
+          }
+        }
+      }
+  #else  // GFX9 wave64 path
       for (int n = 0; n < N; n++) {
         for (int y = 0; y < YTILE; y++) {
           asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
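The same wave32 epilogue is stamped into all three kernel variants (`wvSplitK_hf_sml_`, `wvSplitK_hf_`, `wvSplitK_hf_big_`); the latter two additionally guard stores with `commitColumn[i]` because they may process partial tiles. The bias conversion leans on `if constexpr`, so only the branch matching `scalar_t` is ever instantiated. A minimal sketch of that dispatch pattern; the helper name `to_float` is hypothetical, not from the patch:

```cpp
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <type_traits>

// Only the branch whose type test succeeds is compiled for a given
// scalar_t, so the fp16 intrinsic is never instantiated for bf16
// inputs and vice versa.
template <typename scalar_t>
__device__ float to_float(scalar_t v) {
  if constexpr (std::is_same_v<scalar_t, __half>) {
    return __half2float(v);  // fp16 -> fp32
  } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
    return __bfloat162float(v);  // bf16 -> fp32
  } else {
    return static_cast<float>(v);
  }
}
```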
@@ -1179,6 +1284,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
           }
         }
       }
+  #endif  // defined(__HIP__GFX11__)
     } else {
 #pragma unroll
       for (int n = 0; n < N; n++) {
@@ -1241,7 +1347,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       }
     }
   }
-#else   // !defined(__HIP__GFX9__) TODO: Add NAVI support
+#else   // !defined(__HIP__GFX9__) && !defined(__HIP__GFX11__)
 
 template <...>
 __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
@@ -1251,7 +1357,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
                                  const int _WvPrGrp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
+#endif  // defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
 
 // Find the min val of div2 that doesn't increase N/(div1*div2)
 int mindiv(int N, int div1, int div2) {
@@ -1296,24 +1402,30 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   const int max_lds_len = get_lds_size() / 2;
 
-#define WVSPLITK(_YTILE, _UNRL, _N)                                        \
+#define WVSPLITK_LAUNCH(_THRDS, _YTILE, _UNRL, _N)                         \
   {                                                                        \
-    dim3 block(64, 16);                                                    \
+    dim3 block(_THRDS, 16);                                                \
     int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16);                    \
     if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0))              \
-      wvSplitK_hf_sml_<64, ...>                                            \
+      wvSplitK_hf_sml_<_THRDS, ...>                                        \
          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
                                        biasf4, c, __wvPrGrp, CuCount);     \
     else if (K_in * N_in <= max_lds_len * 1.2)                             \
-      wvSplitK_hf_<64, ...>                                                \
+      wvSplitK_hf_<_THRDS, ...>                                            \
          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
                                        biasf4, c, __wvPrGrp, CuCount);     \
     else                                                                   \
-      wvSplitK_hf_big_<64, ...>                                            \
+      wvSplitK_hf_big_<_THRDS, ...>                                        \
          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
                                        biasf4, c, __wvPrGrp, CuCount);     \
   }
 
+#define WVSPLITK(_YTILE, _UNRL, _N)          \
+  if (is_gfx11())                            \
+    WVSPLITK_LAUNCH(32, _YTILE, _UNRL, _N)   \
+  else                                       \
+    WVSPLITK_LAUNCH(64, _YTILE, _UNRL, _N)
+
 #define WVSPLIT_TILE(_sYT, __N)                                            \
   {                                                                        \
     bool fit_lds = (K_in * N_in <= max_lds_len);                           \
diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py
index 566cb0239fe0..271a9866d109 100644
--- a/tests/kernels/quantization/test_rocm_skinny_gemms.py
+++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -194,7 +194,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     ref_out = torch.nn.functional.linear(A, B)
     out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
 
-    assert torch.allclose(out, ref_out, rtol=0.01)
+    # Accumulation error in fp16 GEMM scales with sqrt(K)
+    atol = torch.finfo(dtype).eps * math.sqrt(k)
+    assert torch.allclose(out, ref_out, rtol=0.01, atol=atol)
 
 
 @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index a6a5ef106f3b..103d07d53649 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -143,7 +143,7 @@ def use_aiter_triton_gemm(n, m, k, dtype):
 def rocm_unquantized_gemm_impl(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
-    from vllm.platforms.rocm import on_gfx9, on_gfx950
+    from vllm.platforms.rocm import on_gfx9, on_gfx11, on_gfx950
 
     n = x.numel() // x.size(-1)
     m = weight.shape[0]
@@ -188,7 +188,7 @@ def rocm_unquantized_gemm_impl(
 
     use_skinny = (
         envs.VLLM_ROCM_USE_SKINNY_GEMM
-        and on_gfx9()
+        and (on_gfx9() or on_gfx11())
         and x.dtype in [torch.float16, torch.bfloat16]
         and k % 8 == 0
         and x.is_contiguous()
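The test change encodes a standard error model: with round-to-nearest fp16/bf16 accumulation over K products, rounding errors behave like a random walk and grow roughly as sqrt(K), so a fixed atol either fails long-K cases or is uselessly loose for short ones. Worked through for fp16, eps = 2^-10 (about 9.8e-4), so K = 4096 gives atol = 9.8e-4 * 64, about 0.063. This matters for the GFX11 port because the wave32 butterfly sums in a different order than the wave64 DPP sequence, so the two paths drift from the fp32 reference by different, but similarly bounded, amounts.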
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 2545e46209a4..2f56a249edb5 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -139,6 +139,12 @@ def on_gfx9() -> bool:
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
 
 
+@cache
+def on_gfx11() -> bool:
+    GPU_ARCH = _get_gcn_arch_via_amdsmi()
+    return "gfx11" in GPU_ARCH
+
+
 @cache
 def on_gfx942() -> bool:
     GPU_ARCH = _get_gcn_arch_via_amdsmi()
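Taken together, the patch detects GFX11 at three layers, each used where it is the only one available: `__GFX11__` at device compile time (each offload arch gets its own compilation pass, so the kernels can switch instruction mnemonics and reduction shape with `#if`), `is_gfx11()` on the C++ host side (one fat binary serves all arches, so the 32- versus 64-thread launch shape must be chosen at run time), and `on_gfx11()` in Python to gate the skinny-GEMM path at all. A sketch of the host-side probe, under the same assumption the patch makes that `gcnArchName` contains the "gfx11" substring; the names here are illustrative, not from the patch:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstring>

// Runtime arch probe, mirroring is_gfx11()/on_gfx11() in the patch.
static bool device_is_gfx11(int device_id) {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, device_id) != hipSuccess) return false;
  return std::strstr(props.gcnArchName, "gfx11") != nullptr;
}

int main() {
  // Pick the wavefront-sized block dimension the way WVSPLITK does.
  int threads = device_is_gfx11(0) ? 32 : 64;
  printf("block.x = %d\n", threads);
  return 0;
}
```

An alternative would be to branch on `hipDeviceProp_t::warpSize` (32 on RDNA, 64 on CDNA), which names the property the kernels actually care about; matching the arch string instead keeps the C++ check symmetric with the Python helper.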