
Commit 2dacd57

[platform] Move get_cu_count to utils (#27005)
Signed-off-by: wangxiyuan <[email protected]>
1 parent d75ad04 commit 2dacd57

File tree: 6 files changed, +28 −18 lines
tests/kernels/quantization/test_rocm_skinny_gemms.py

Lines changed: 19 additions & 5 deletions
```diff
@@ -8,6 +8,7 @@
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
 from vllm.platforms import current_platform
+from vllm.utils.platform_utils import get_cu_count
 
 DTYPES = [torch.bfloat16, torch.float16]
 # Specific (N, K, M) combinations for targeted testing
@@ -85,7 +86,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = current_platform.get_cu_count()
+    cu_count = get_cu_count()
 
     A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
     B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
@@ -102,7 +103,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = current_platform.get_cu_count()
+    cu_count = get_cu_count()
 
     xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
     A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -121,7 +122,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = current_platform.get_cu_count()
+    cu_count = get_cu_count()
 
     xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
     A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -153,7 +154,14 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     ref_out = torch._scaled_mm(
         A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
     )
-    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())
+    out = ops.wvSplitKQ(
+        B,
+        A,
+        dtype,
+        scale_a,
+        scale_b,
+        get_cu_count(),
+    )
 
     assert torch.allclose(out, ref_out, rtol=0.01)
 
@@ -180,7 +188,13 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
         A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
     )
     out = ops.wvSplitKQ(
-        B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
+        B,
+        A,
+        dtype,
+        scale_a,
+        scale_b,
+        get_cu_count(),
+        BIAS,
     )
 
     assert torch.allclose(out, ref_out, rtol=0.01)
```
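
The FP8 tests above all follow the same pattern: compute a reference with `torch._scaled_mm`, run the ROCm `wvSplitKQ` kernel with the CU count from the new helper, and compare. A trimmed sketch of that pattern (tensor and scale setup abbreviated; the helper name `check_wvsplitk_fp8` is illustrative, not part of the test file):

```python
import torch
import vllm._custom_ops as ops
from vllm.utils.platform_utils import get_cu_count

def check_wvsplitk_fp8(A, B, scale_a, scale_b, dtype=torch.float16):
    # Reference result from PyTorch's scaled matmul.
    ref_out = torch._scaled_mm(A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b)
    # Kernel under test; the CU count now comes from the module-level helper.
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count())
    assert torch.allclose(out, ref_out, rtol=0.01)
```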

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
+from vllm.utils.platform_utils import get_cu_count
 from vllm.utils.torch_utils import direct_register_custom_op
 
 # Input scaling factors are no longer optional in _scaled_mm starting
@@ -200,7 +201,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
             out_dtype,
             scale_a,
             scale_b,
-            current_platform.get_cu_count(),
+            get_cu_count(),
             bias,
         )
     else:
```

vllm/model_executor/layers/utils.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -11,6 +11,7 @@
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
+from vllm.utils.platform_utils import get_cu_count
 from vllm.utils.torch_utils import direct_register_custom_op
 
 logger = init_logger(__name__)
@@ -151,7 +152,7 @@ def rocm_unquantized_gemm_impl(
 
     x_view = x.reshape(-1, x.size(-1))
     if m > 8 and 0 < n <= 4:
-        cu_count = current_platform.get_cu_count()
+        cu_count = get_cu_count()
         out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.reshape(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
```
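
For context, the branch touched here is the tall-and-skinny fast path of `rocm_unquantized_gemm_impl`: when the shape check passes, the CU count is fetched once and handed to the `wvSplitK` kernel as a launch/tuning parameter. A condensed sketch of that branch as it reads after this change (the surrounding function and its other branches are omitted; the wrapper name `skinny_gemm_branch` is illustrative):

```python
import vllm._custom_ops as ops
from vllm.utils.platform_utils import get_cu_count

def skinny_gemm_branch(x, weight, bias, m, n, k):
    # Condensed from the diff above; only the wvSplitK branch is shown.
    x_view = x.reshape(-1, x.size(-1))       # flatten batch dims into rows
    if m > 8 and 0 < n <= 4:                 # tall-and-skinny shape
        cu_count = get_cu_count()            # CU count of the current GPU
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.reshape(*x.shape[:-1], weight.shape[0])
    raise NotImplementedError("other shapes are handled elsewhere in the real function")
```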

vllm/platforms/interface.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -545,13 +545,6 @@ def get_global_graph_pool(self) -> Any:
             cls._global_graph_pool = self.graph_pool_handle()
         return cls._global_graph_pool
 
-    @classmethod
-    def get_cu_count(cls, device_id: int = 0) -> int:
-        """
-        Returns the total number of compute units (CU) on single GPU.
-        """
-        raise NotImplementedError
-
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
         """
```

vllm/platforms/rocm.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -423,10 +423,6 @@ def use_custom_allreduce(cls) -> bool:
     def opaque_attention_op(cls) -> bool:
         return True
 
-    @classmethod
-    def get_cu_count(cls, device_id: int = 0) -> int:
-        return torch.cuda.get_device_properties(device_id).multi_processor_count
-
     @classmethod
     def is_navi(cls) -> bool:
         return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
```

vllm/utils/platform_utils.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -24,6 +24,11 @@ def xpu_is_initialized() -> bool:
     return torch.xpu.is_initialized()
 
 
+def get_cu_count(device_id: int = 0) -> int:
+    """Returns the total number of compute units (CU) on single GPU."""
+    return torch.cuda.get_device_properties(device_id).multi_processor_count
+
+
 def cuda_get_device_properties(
     device, names: Sequence[str], init_cuda=False
 ) -> tuple[Any, ...]:
```
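
A minimal usage sketch of the new helper, assuming at least one GPU is visible; note that `multi_processor_count` reports Compute Units on AMD GPUs and Streaming Multiprocessors on NVIDIA GPUs:

```python
import torch
from vllm.utils.platform_utils import get_cu_count

if torch.cuda.is_available():
    for dev in range(torch.cuda.device_count()):
        name = torch.cuda.get_device_properties(dev).name
        print(f"device {dev} ({name}): {get_cu_count(dev)} compute units")
```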
