diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index d279ffe45d6d..0114a309b987 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
-            0] == 1 and qinput.shape[1] % 16 == 0:
+    from vllm.platforms.rocm import on_mi250_mi300
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
+    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a,
                                scale_b, current_platform.get_cu_count())
     else:
@@ -371,7 +372,7 @@ def apply(
 
         return w8a8_scaled_mm_func(qinput=qinput,
                                    weight=weight,
-                                   out_dtype=input.dtype,
+                                   out_dtype=out_dtype,
                                    scale_a=x_scale,
                                    scale_b=weight_scale,
                                    bias=bias,
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 686d031f7b72..adb966c4b1c0 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
 def rocm_unquantized_gemm(x: torch.Tensor,
                           weight: torch.Tensor,
                           bias: Optional[torch.Tensor] = None):
+    from vllm.platforms.rocm import on_mi250_mi300
     k = weight.shape[1]
-    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
+    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
                   x.dtype in [torch.float16, torch.bfloat16] \
                   and k % 8 == 0 and bias is None)
 
@@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()
 
-    if m > 8 and n < 4:
+    if m > 8 and 0 < n < 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:
-        out = ops.LLMM1(weight, x_view, out, 4)
+        out = ops.LLMM1(weight, x_view, 4)
         return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)
 
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 1eb0c8c2ef4e..d5eaeec1ae24 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -12,6 +12,7 @@
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
 from vllm.model_executor.parameter import BasevLLMParameter
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -40,7 +41,7 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
+        return dispatch_unquantized_gemm()(x, layer.weight, bias)
 
     def embedding(self, layer: torch.nn.Module,
                   input_: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 944879b94ecd..3d5e90dc32a8 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -98,22 +98,22 @@ def device_id_to_physical_device_id(device_id: int) -> int:
         return device_id
 
 
+def on_mi250_mi300() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
+
+
 @cache
 def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int,
                                     sliding_window: int) -> bool:
 
-    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
-    ON_NAVI = "gfx1" in GPU_ARCH
-    ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
-
-    # rocm custom page attention not support on navi (gfx1*)
+    # rocm custom page attention not support on gfx1*
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    return (ON_MI250_MI300 and not ON_NAVI
-            and (not envs.VLLM_USE_V1 or sliding_window == 0
-                 or sliding_window == (-1, -1))
+    return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
+                                  or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
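Note on dispatch_unquantized_gemm: the vocab_parallel_embedding.py hunk imports and calls dispatch_unquantized_gemm() from vllm.model_executor.layers.utils, but its definition falls outside the hunks shown. A minimal sketch of such a zero-argument dispatcher is given below; the helper name default_unquantized_gemm and the exact platform check are assumptions for illustration, not necessarily the module's actual implementation.

from typing import Callable, Optional

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from vllm.platforms import current_platform


def default_unquantized_gemm(x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Plain fallback: the same F.linear call this diff replaces in apply().
    return F.linear(x, weight, bias)


def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    # Hypothetical sketch: route to the ROCm skinny-GEMM path (patched in the
    # utils.py hunks above) when running on ROCm, otherwise fall back to the
    # plain linear wrapper. Matches the call site in the hunk above:
    #     return dispatch_unquantized_gemm()(x, layer.weight, bias)
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    return default_unquantized_gemm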