24 changes: 19 additions & 5 deletions tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -8,6 +8,7 @@
import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
from vllm.utils.platform_utils import get_cu_count

DTYPES = [torch.bfloat16, torch.float16]
# Specific (N, K, M) combinations for targeted testing
@@ -85,7 +86,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()

A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
@@ -102,7 +103,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()

xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -121,7 +122,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()

xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -153,7 +154,14 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())
out = ops.wvSplitKQ(
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
)

assert torch.allclose(out, ref_out, rtol=0.01)

@@ -180,7 +188,13 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(
B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
BIAS,
)

assert torch.allclose(out, ref_out, rtol=0.01)
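For context, a minimal standalone sketch of the FP8 test pattern updated above, using the relocated `get_cu_count` helper. It assumes a ROCm GPU with FP8 support (e.g. MI300-class) and a vLLM checkout with the custom ops built; the shapes and tolerance are illustrative, not the parametrized values from the test file.

```python
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.utils.platform_utils import get_cu_count

n, k, m, dtype = 2, 512, 512, torch.float16
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5

# Per-tensor dynamic FP8 quantization: returns (quantized tensor, scale)
A_q, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B_q, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

# Reference result from torch, then the skinny-GEMM kernel under test
ref_out = torch._scaled_mm(
    A_q, B_q.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(B_q, A_q, dtype, scale_a, scale_b, get_cu_count())
print(torch.allclose(out, ref_out, rtol=0.01))
```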
@@ -13,6 +13,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op

# Input scaling factors are no longer optional in _scaled_mm starting
@@ -200,7 +201,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
out_dtype,
scale_a,
scale_b,
current_platform.get_cu_count(),
get_cu_count(),
bias,
)
else:
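A hedged sketch of the dispatch pattern in the hunk above: take the ROCm wvSplitKQ skinny kernel (which needs the CU count), otherwise fall back to `torch._scaled_mm`. The guard condition and the (m, k) weight layout follow the test file in this PR, not necessarily the exact production code path; `per_tensor_fp8_mm` is a hypothetical name for illustration.

```python
import torch

import vllm._custom_ops as ops
from vllm.utils.platform_utils import get_cu_count


def per_tensor_fp8_mm(qinput, weight_q, out_dtype, scale_a, scale_b, bias=None):
    # qinput: (n, k) FP8 activations, weight_q: (m, k) FP8 weights
    n, k = qinput.shape
    if n <= 4 and k % 16 == 0:  # illustrative skinny-batch guard
        return ops.wvSplitKQ(
            weight_q, qinput, out_dtype, scale_a, scale_b, get_cu_count(), bias
        )
    return torch._scaled_mm(
        qinput, weight_q.t(), out_dtype=out_dtype,
        scale_a=scale_a, scale_b=scale_b, bias=bias,
    )
```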
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/utils.py
@@ -10,6 +10,7 @@
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)
@@ -151,7 +152,7 @@ def rocm_unquantized_gemm_impl(

x_view = x.reshape(-1, x.size(-1))
if m > 8 and 0 < n <= 4:
cu_count = current_platform.get_cu_count()
cu_count = get_cu_count()
out = ops.wvSplitK(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
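To make the fast path above concrete, here is a minimal sketch of calling `ops.wvSplitK` directly with the relocated helper. The `m > 8 and 0 < n <= 4` guard and the reshape mirror the hunk; the shapes, tolerance, and the meaning of `m`/`n` (weight rows vs. flattened activation rows) are inferred from it, and the snippet assumes a ROCm GPU with vLLM's custom ops built.

```python
import torch

import vllm._custom_ops as ops
from vllm.utils.platform_utils import get_cu_count

dtype = torch.float16
m, k, n = 4096, 4096, 2  # weight is (m, k), activations are (n, k)
weight = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
x = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5

x_view = x.reshape(-1, x.size(-1))
if m > 8 and 0 < n <= 4:  # skinny-GEMM condition from the hunk
    out = ops.wvSplitK(weight, x_view, get_cu_count(), None)  # bias may be None
    out = out.reshape(*x.shape[:-1], weight.shape[0])
else:
    out = torch.nn.functional.linear(x, weight)

ref = torch.nn.functional.linear(x, weight)
print(torch.allclose(out, ref, rtol=0.01, atol=0.01))
```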
7 changes: 0 additions & 7 deletions vllm/platforms/interface.py
@@ -542,13 +542,6 @@ def get_global_graph_pool(self) -> Any:
cls._global_graph_pool = self.graph_pool_handle()
return cls._global_graph_pool

@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
"""
Returns the total number of compute units (CU) on single GPU.
"""
raise NotImplementedError

@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
"""
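With the classmethod removed from the platform interface (and from the ROCm platform below), call sites switch from the platform object to the standalone utility. A minimal before/after sketch of the migration implied by this PR:

```python
# Before: resolved through the platform object
#   from vllm.platforms import current_platform
#   cu_count = current_platform.get_cu_count()

# After: imported directly from the utility module
from vllm.utils.platform_utils import get_cu_count

cu_count = get_cu_count()
```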
4 changes: 0 additions & 4 deletions vllm/platforms/rocm.py
@@ -431,10 +431,6 @@ def use_custom_allreduce(cls) -> bool:
def opaque_attention_op(cls) -> bool:
return True

@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(device_id).multi_processor_count

@classmethod
def is_navi(cls) -> bool:
return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
5 changes: 5 additions & 0 deletions vllm/utils/platform_utils.py
@@ -24,6 +24,11 @@ def xpu_is_initialized() -> bool:
return torch.xpu.is_initialized()


def get_cu_count(device_id: int = 0) -> int:
"""Returns the total number of compute units (CU) on a single GPU."""
return torch.cuda.get_device_properties(device_id).multi_processor_count


def cuda_get_device_properties(
device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
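For reference, a quick usage sketch of the relocated helper. It assumes a CUDA/ROCm-capable PyTorch build; device 0 is queried by default.

```python
import torch

from vllm.utils.platform_utils import get_cu_count

if torch.cuda.is_available():
    # multi_processor_count reports CUs on AMD GPUs and SMs on NVIDIA GPUs
    print(f"device 0 compute units: {get_cu_count()}")
    if torch.cuda.device_count() > 1:
        print(f"device 1 compute units: {get_cu_count(device_id=1)}")
```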