[FEAT] [ROCm] Add AITER int8 scaled gemm kernel #15433
@@ -20,6 +20,20 @@
    sparse_cutlass_supported)
from vllm.platforms import current_platform

ROCM_AITER_SUPPORTED_INT8_MODEL = [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
]

# TritonScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):

@@ -57,6 +71,11 @@ def use_v0_only(monkeypatch):
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    model_path, strategy, quant_type, shape_0, is_symmetric = model_args

    if current_platform.is_rocm(
    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):

@@ -131,6 +150,11 @@ def test_compressed_tensors_w8a8_logprobs(
    max_tokens,
    num_logprobs,
):
    if current_platform.is_rocm(
    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    dtype = "bfloat16"

    # skip language translation prompt for the static per tensor asym model

@@ -154,6 +178,9 @@ def test_compressed_tensors_w8a8_logprobs(
        name_1="vllm",
    )

    if current_platform.is_rocm():
        torch.cuda.synchronize()


def test_compressed_tensors_no_enforce_eager(vllm_runner):
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"

@@ -179,6 +206,11 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
    model_path, strategy = model_args

    if current_platform.is_rocm(
    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    with vllm_runner(model_path, dtype=torch.float16) as llm:

        def check_model(model):

@@ -207,6 +239,8 @@ def check_model(model):
        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
    ],
)
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="The tests are skipped on non-CUDA platform.")
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    model, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model) as llm:

@@ -231,6 +265,8 @@ def check_model(model):
    assert output


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:

@@ -271,7 +307,9 @@ def check_model(model):
            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                assert len(qkv_proj.input_scale.shape) == 0
                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
                assert qkv_proj.weight.dtype is (torch.float8_e4m3fnuz
                                                 if current_platform.is_rocm()
                                                 else torch.float8_e4m3fn)
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

@@ -281,6 +319,8 @@ def check_model(model):
    assert output


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_kv_cache(vllm_runner):
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:

@@ -309,7 +349,8 @@ def _test_2of4_quant_models(qkv_proj,

@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    not current_platform.is_cuda()
    or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(

@@ -356,7 +397,8 @@ def check_model(model):

@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    not current_platform.is_cuda()
    or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(

@@ -571,3 +613,90 @@ def check_model(model):
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    ],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="This test is skipped on non-ROCm platform.")
def test_compressed_tensors_w8a8_logprobs_rocm_aiter(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
    monkeypatch,
):
    # this will enable VLLM_ROCM_USE_AITER_LINEAR
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    dtype = "bfloat16"

    # skip language translation prompt for the static per tensor asym model
    if (model_path ==
            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
        ):  # noqa: E501
        example_prompts = example_prompts[0:-1]

    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model_path, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
    ],
)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="This test is skipped on non-ROCm platform.")
def test_compressed_tensors_w8a8_dynamic_per_token_rocm_aiter(
    vllm_runner,
    model_args,
    monkeypatch,
):
    # this will enable VLLM_ROCM_USE_AITER_LINEAR
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    model_path, strategy = model_args
    with vllm_runner(model_path, dtype=torch.float16) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            assert not qkv_proj.scheme.is_static_input_scheme
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
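For readers skimming the test changes above: the model names encode the int8 schemes being exercised (static per-tensor versus dynamic per-token activations, per-tensor versus per-channel weights), and the skip lists gate which of them the ROCm Triton fallback can run. The snippet below is a generic, self-contained illustration of symmetric int8 quantization at those two granularities; the helper names are invented here and are not part of vLLM or this PR.

```python
import torch


def quantize_per_tensor_int8(x: torch.Tensor):
    # One scale for the whole tensor; symmetric, so the zero-point is 0.
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.reshape(1, 1)


def quantize_per_row_int8(x: torch.Tensor):
    # One scale per row: "per token" for activations, "per channel" for weights.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale


x = torch.randn(4, 16)
q_t, s_t = quantize_per_tensor_int8(x)  # scale shape (1, 1)
q_r, s_r = quantize_per_row_int8(x)     # scale shape (4, 1)
print(s_t.shape, s_r.shape)
```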
@@ -36,6 +36,12 @@ def register_fake(fn):
    from torch.library import impl_abstract as register_fake


def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_LINEAR \
        and envs.VLLM_ROCM_USE_AITER


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,

@@ -529,11 +535,34 @@ def cutlass_scaled_mm(a: torch.Tensor,
    n = b.shape[1]

    if current_platform.is_rocm():
        triton_scaled_mm_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm")
        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
        if is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
            per_tensor_scale_a = (scale_a.numel() == 1)
            per_tensor_scale_b = (scale_b.numel() == 1)
            per_channel_tensor_scale_a = (scale_a.numel() == m)
            per_channel_tensor_scale_b = (scale_b.numel() == n)

            # @TODO:
            # Maybe broadcast the per-tensor scale into a per-channel scale
            # if one of the scales is a per-channel scale.
            # For now, it only supports
            # per-tensor-per-tensor a8w8 scaled GEMM and
            # per-channel-per-channel a8w8 scaled GEMM.
            assert (
                (per_tensor_scale_a and per_tensor_scale_b) or
                (per_channel_tensor_scale_a and per_channel_tensor_scale_b)), (
                    "Currently only support per-tensor-per-tensor GEMM "
                    "and per-channel-per-channel GEMM through AITER"
                    " w8a8 scaled gemm. `cutlass_scaled_mm` does not support"
                    " AITER block scaled GEMM yet.")

            from aiter import gemm_a8w8_CK
            return gemm_a8w8_CK(a, b.t(), scale_a, scale_b, bias).to(out_dtype)
        else:
            triton_scaled_mm_module = importlib.import_module(
                "vllm.model_executor.layers.quantization.compressed_tensors."
                "triton_scaled_mm")
            triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
            return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
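To make the new ROCm branch in `cutlass_scaled_mm` above easier to follow, here is a small self-contained sketch, in plain PyTorch rather than the PR's code, of which scale layouts the AITER kernel accepts and of the arithmetic a symmetric w8a8 scaled GEMM performs (layout details such as passing `b.t()` to `gemm_a8w8_CK` are ignored). The names `pick_rocm_scaled_mm_backend` and `reference_scaled_mm` are invented for this illustration.

```python
import torch


def pick_rocm_scaled_mm_backend(scale_a: torch.Tensor, scale_b: torch.Tensor,
                                m: int, n: int, aiter_enabled: bool) -> str:
    # Per-tensor x per-tensor or per-channel x per-channel scales can go to
    # AITER; anything else is only served by the Triton kernel (the PR's code
    # asserts if AITER is enabled but the scales are mixed).
    per_tensor = scale_a.numel() == 1 and scale_b.numel() == 1
    per_channel = scale_a.numel() == m and scale_b.numel() == n
    return "aiter" if aiter_enabled and (per_tensor or per_channel) else "triton"


def reference_scaled_mm(a: torch.Tensor, b: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        out_dtype=torch.bfloat16) -> torch.Tensor:
    # Symmetric w8a8 semantics: dequantize, multiply, cast to the output dtype.
    a_f = a.to(torch.float32) * scale_a      # scale_a is (M, 1) or (1, 1)
    b_f = b.to(torch.float32) * scale_b.t()  # scale_b is (N, 1) or (1, 1)
    return (a_f @ b_f).to(out_dtype)


m, n, k = 4, 8, 16
a = torch.randint(-128, 128, (m, k), dtype=torch.int8)
b = torch.randint(-128, 128, (k, n), dtype=torch.int8)
scale_a, scale_b = torch.rand(m, 1), torch.rand(n, 1)
print(pick_rocm_scaled_mm_backend(scale_a, scale_b, m, n, True))  # "aiter"
print(reference_scaled_mm(a, b, scale_a, scale_b).shape)          # (4, 8)
```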
@@ -143,10 +143,12 @@ def triton_scaled_mm(input: torch.Tensor,
    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
        [M, 1])
    assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
        [N, 1])
    assert scale_a.shape == (1, 1) or scale_a.shape == (M, 1)
    assert scale_b.shape == (1, 1) or scale_b.shape == (N, 1)
    # assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
    #     [M, 1])
    # assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
    #     [N, 1])
    assert out_dtype.is_floating_point
    assert bias is None or bias.is_floating_point()
    assert is_weak_contiguous(input)
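One small note on the assertion change above: `torch.Size` is a subclass of `tuple`, so comparing a shape against a plain tuple accepts exactly the same values as comparing against `torch.Size([...])`; the rewritten asserts are just a shorter spelling of the originals, as the snippet below demonstrates.

```python
import torch

x = torch.zeros(4, 1)
# torch.Size is a tuple subclass, so both comparison styles are equivalent.
assert isinstance(x.shape, tuple)
assert x.shape == torch.Size([4, 1])
assert x.shape == (4, 1)
```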
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if current_platform.is_cpu():
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "currently supported on CPU.")
        if not current_platform.is_rocm():
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is only " +
                "currently supported on ROCm.")
        # try to import aiter
        try:
            import aiter  # noqa: F401
        except Exception:
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "installed on ROCm.")
        if not ops.is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
            return (False, "AiterScaledMMLinearKernel is disabled. " +
                    "Enable by setting `VLLM_ROCM_USE_AITER=1`.")

        if not c.input_symmetric:
            return (False,
                    "AiterScaledMMLinearKernel only supports symmetric " +
                    "quantization.")
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return super().apply_weights(layer, x, bias)

This logic is a bit confusing. What models are and aren't supported by aiter vs Triton?
AITER only supports per-channel-per-channel INT8 GEMM and per-tensor-per-tensor INT8 GEMM. It does not support mixed-precision matmul or mixed quantization schemes.
Could you add that as a comment?
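In the spirit of that request, a comment along the following lines could sit next to the AITER branch in `cutlass_scaled_mm` or in `AiterScaledMMLinearKernel.can_implement`; the wording is a suggestion based on the reviewer's explanation and the assertion in the diff, not text taken from the PR.

```python
# AITER's gemm_a8w8_CK handles exactly two symmetric int8 configurations:
#   * per-tensor activation scale with per-tensor weight scale
#     (scale_a.numel() == 1 and scale_b.numel() == 1), and
#   * per-channel scales on both operands
#     (scale_a.numel() == M and scale_b.numel() == N).
# Mixed granularities (e.g. per-tensor activations with per-channel weights),
# mixed precision, and block-scaled GEMM are not supported: the AITER branch
# asserts on such inputs, and with AITER disabled they are served by the
# Triton fallback kernel.
```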