Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 133 additions & 21 deletions csrc/rocm/skinny_gemms.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#define __HIP__MI3XX__
#endif

#if defined(__HIPCC__) && defined(__GFX11__)
#define __HIP__GFX11__
#endif

#if defined(__gfx950__)
#define LDS_SIZE 160 * 1024
#else
Expand All @@ -48,6 +52,15 @@ int get_lds_size() {
return result;
}

// Reports whether the current device's gcnArchName contains the substring
// "gfx11". The lookup runs once, on the first call; later calls return the
// stored answer.
// NOTE(review): the stored answer reflects whichever device was current on
// the first call - confirm this is acceptable for mixed-architecture hosts.
bool is_gfx11() {
  static const bool cached_answer = []() -> bool {
    const std::string arch_name =
        at::cuda::getCurrentDeviceProperties()->gcnArchName;
    return arch_name.find("gfx11") != std::string::npos;
  }();
  return cached_answer;
}

#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
Expand Down Expand Up @@ -285,21 +298,41 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
return out_c;
}

#define DOT2C(V0, V2, V3) \
if constexpr (std::is_same_v<scalar_t, half>) { \
asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
V0 += (s.x + s.y); \
// GFX11 (RDNA3) renamed v_dot2c_f32_f16 to v_dot2acc_f32_f16
// DOT2C_FP16_INSN is the mnemonic string that DOT2C splices into its inline
// asm; it is selected at compile time from __HIP__GFX11__.
#if defined(__HIP__GFX11__)
#define DOT2C_FP16_INSN "v_dot2acc_f32_f16"
#else
#define DOT2C_FP16_INSN "v_dot2c_f32_f16"
#endif

// DOT2C(V0, V2, V3): accumulate into V0 using the packed operands V2 and V3.
// Expects a type named `scalar_t` to be visible at the expansion site.
//  - scalar_t == half: emits DOT2C_FP16_INSN via inline asm. V0 is both the
//    output ("=v") and, through the matching constraint "0", the first input,
//    so the instruction updates V0 in place.
//  - scalar_t == __hip_bfloat16: reinterprets V2 and V3 as __hip_bfloat162,
//    converts each to float2, multiplies them element-wise, and adds
//    s.x + s.y to V0.
//  - any other scalar_t: expands to nothing.
// NOTE(review): the fp16 instruction presumably computes the same two-element
// dot-accumulate as the bf16 branch - confirm against the ISA reference.
// The expansion is an unbraced `if constexpr` chain, so use it only as a
// standalone statement.
#define DOT2C(V0, V2, V3) \
if constexpr (std::is_same_v<scalar_t, half>) { \
asm(DOT2C_FP16_INSN " %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
V0 += (s.x + s.y); \
}

// Wave32 all-lane sum for GFX11 (RDNA): five XOR-shuffle steps with lane
// masks 1, 2, 4, 8 and 16 fold the 32 per-lane values together, leaving the
// total in every lane of the wavefront.
// `acc` must be a modifiable lvalue; it is evaluated several times.
#if defined(__HIP__GFX11__)
  #define REDUCE_SUM_WAVE32(acc)        \
    do {                                \
      (acc) += __shfl_xor((acc), 1);    \
      (acc) += __shfl_xor((acc), 2);    \
      (acc) += __shfl_xor((acc), 4);    \
      (acc) += __shfl_xor((acc), 8);    \
      (acc) += __shfl_xor((acc), 16);   \
    } while (0)
#endif

// Minimum of two 32-bit unsigned values, kept as a dedicated typed helper.
// Per the original note, call sites use this to avoid LLVM silently
// upcasting to double.
__device__ inline unsigned int min__(uint32_t a, uint32_t b) {
  return (b < a) ? b : a;
}

#if defined(__HIP__GFX9__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
// This version targets cases where A[] fits LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
Expand Down Expand Up @@ -471,6 +504,28 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
#if defined(__HIP__GFX11__)
// Wave32: butterfly reduce within the single wavefront per row
for (int n = 0; n < N; n++)
for (int y = 0; y < YTILE; y++) REDUCE_SUM_WAVE32(sum[n][y]);

// Lane 0 has the complete sum; write the result
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
#else // GFX9 wave64 path
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
Expand Down Expand Up @@ -509,6 +564,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#endif // defined(__HIP__GFX11__)
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
Expand Down Expand Up @@ -560,7 +616,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) && !defined(__HIP__GFX11__)
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
Expand All @@ -570,9 +626,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) || defined(__HIP__GFX11__)

#if defined(__HIP__GFX9__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
// This version targets cases where A[] marginally exceeds LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
Expand Down Expand Up @@ -771,6 +827,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
#if defined(__HIP__GFX11__)
// Wave32: butterfly reduce within the single wavefront per row
for (int n = 0; n < N; n++)
for (int y = 0; y < YTILE; y++) REDUCE_SUM_WAVE32(sum[n][y]);

// Lane 0 has the complete sum; write the result
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
}
#else // GFX9 wave64 path
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
Expand Down Expand Up @@ -811,6 +891,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#endif // defined(__HIP__GFX11__)
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
Expand Down Expand Up @@ -877,7 +958,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}

#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) && !defined(__HIP__GFX11__)
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
Expand All @@ -887,9 +968,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) || defined(__HIP__GFX11__)

#if defined(__HIP__GFX9__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) || defined(__HIP__GFX11__)
// This version targets big A[] cases, where it is much larger than LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
Expand Down Expand Up @@ -1139,6 +1220,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
#if defined(__HIP__GFX11__)
// Wave32: butterfly reduce within the single wavefront per row
for (int n = 0; n < N; n++)
for (int y = 0; y < YTILE; y++) REDUCE_SUM_WAVE32(sum[n][y]);

// Lane 0 has the complete sum; write the result
if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
}
#else // GFX9 wave64 path
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
Expand Down Expand Up @@ -1179,6 +1284,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#endif // defined(__HIP__GFX11__)
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
Expand Down Expand Up @@ -1241,7 +1347,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) && !defined(__HIP__GFX11__)
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
Expand All @@ -1251,7 +1357,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) || defined(__HIP__GFX11__)

// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) {
Expand Down Expand Up @@ -1296,24 +1402,30 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;

#define WVSPLITK(_YTILE, _UNRL, _N) \
#define WVSPLITK_LAUNCH(_THRDS, _YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
dim3 block(_THRDS, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_sml_<fptype, _THRDS, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_<fptype, _THRDS, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_big_<fptype, _THRDS, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
}

// WVSPLITK(_YTILE, _UNRL, _N): host-side dispatch that picks the launch
// width at run time - 32 threads in x when is_gfx11() reports a gfx11
// device, 64 otherwise - and forwards to WVSPLITK_LAUNCH.
//
// The whole dispatch is wrapped in braces so the macro expands to exactly one
// compound statement, the same shape as WVSPLITK_LAUNCH. Without the braces
// the expansion is a bare `if ... else ...`, whose `else` pairing depends on
// the surrounding code when the macro is used inside an unbraced
// `if / else if` chain.
// Invoke without a trailing semicolon when another `else` follows.
#define WVSPLITK(_YTILE, _UNRL, _N)            \
  {                                            \
    if (is_gfx11())                            \
      WVSPLITK_LAUNCH(32, _YTILE, _UNRL, _N)   \
    else                                       \
      WVSPLITK_LAUNCH(64, _YTILE, _UNRL, _N)   \
  }

#define WVSPLIT_TILE(_sYT, __N) \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len); \
Expand Down
4 changes: 3 additions & 1 deletion tests/kernels/quantization/test_rocm_skinny_gemms.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)

assert torch.allclose(out, ref_out, rtol=0.01)
# Accumulation error in a reduced-precision (fp16/bf16) GEMM scales with sqrt(K)
atol = torch.finfo(dtype).eps * math.sqrt(k)
assert torch.allclose(out, ref_out, rtol=0.01, atol=atol)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def use_aiter_triton_gemm(n, m, k, dtype):
def rocm_unquantized_gemm_impl(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9, on_gfx950
from vllm.platforms.rocm import on_gfx9, on_gfx11, on_gfx950

n = x.numel() // x.size(-1)
m = weight.shape[0]
Expand Down Expand Up @@ -188,7 +188,7 @@ def rocm_unquantized_gemm_impl(

use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx9()
and (on_gfx9() or on_gfx11())
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
and x.is_contiguous()
Expand Down
6 changes: 6 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def on_gfx9() -> bool:
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])


@cache
def on_gfx11() -> bool:
    """Return True when the value from `_get_gcn_arch_via_amdsmi()` contains "gfx11".

    The result is memoized by `@cache`, so the lookup runs once per process.
    """
    return "gfx11" in _get_gcn_arch_via_amdsmi()


@cache
def on_gfx942() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
Expand Down