diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index adb966c4b1c0..751b86787c7b 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()
 
-    if m > 8 and 0 < n < 4:
+    if m > 8 and 0 < n <= 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 43911aab5fd2..bb271513fb17 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
     return device_id
 
 
+@cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])