135 changes: 132 additions & 3 deletions tests/quantization/test_compressed_tensors.py
@@ -20,6 +20,20 @@
sparse_cutlass_supported)
from vllm.platforms import current_platform

ROCM_AITER_SUPPORTED_INT8_MODEL = [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
]

# TritonScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
@@ -57,6 +71,11 @@ def use_v0_only(monkeypatch):
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0, is_symmetric = model_args

Collaborator: This logic is a bit confusing. What models are and aren't supported by AITER vs. Triton?

Collaborator (author): AITER only supports per-channel-per-channel INT8 GEMM and per-tensor-per-tensor INT8 GEMM. It does not support mixed-precision MM or mixed quantization schemes.

Collaborator: Could you add that as a comment?

if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

with vllm_runner(model_path, enforce_eager=True) as llm:

def check_model(model):
Expand Down Expand Up @@ -131,6 +150,11 @@ def test_compressed_tensors_w8a8_logprobs(
max_tokens,
num_logprobs,
):

if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

dtype = "bfloat16"

# skip language translation prompt for the static per tensor asym model
@@ -154,6 +178,9 @@ def test_compressed_tensors_w8a8_logprobs(
name_1="vllm",
)

if current_platform.is_rocm():
torch.cuda.synchronize()


def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@@ -179,6 +206,11 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
model_path, strategy = model_args

if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

with vllm_runner(model_path, dtype=torch.float16) as llm:

def check_model(model):
Expand Down Expand Up @@ -207,6 +239,8 @@ def check_model(model):
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
],
)
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="The tests are skipped on non-CUDA platform.")
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args
with vllm_runner(model) as llm:
@@ -231,6 +265,8 @@ def check_model(model):
assert output


@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with vllm_runner(model_path) as llm:
@@ -271,7 +307,9 @@ def check_model(model):

if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.weight.dtype is (torch.float8_e4m3fnuz
if current_platform.is_rocm()
else torch.float8_e4m3fn)
Collaborator: Use current_platform.fp8_dtype()

Collaborator (author): Done

assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0

@@ -281,6 +319,8 @@ def check_model(model):
assert output


@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
@@ -309,7 +349,8 @@ def _test_2of4_quant_models(qkv_proj,


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
not current_platform.is_cuda()
or not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
@@ -356,7 +397,8 @@


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
not current_platform.is_cuda()
or not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
@@ -571,3 +613,90 @@ def check_model(model):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output


@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="This tests is skipped on non-ROCm platform.")
def test_compressed_tensors_w8a8_logprobs_rocm_aiter(
Collaborator: Could this be folded into the existing tests by adding a boolean use_aiter parameter? Then we can do [False] if <platform ...> else [False, True].

Collaborator (author): Done

(A sketch of this use_aiter parametrization follows at the end of this file's diff.)

hf_runner,
vllm_runner,
example_prompts,
model_path,
max_tokens,
num_logprobs,
monkeypatch,
):
# setting VLLM_ROCM_USE_AITER enables the AITER path;
# VLLM_ROCM_USE_AITER_LINEAR already defaults to True
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

dtype = "bfloat16"

# skip language translation prompt for the static per tensor asym model
if (model_path ==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
): # noqa: E501
example_prompts = example_prompts[0:-1]

with hf_runner(model_path, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize(
"model_args",
[
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
"channel",
),
],
)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="This tests is skipped on non-ROCm platform.")
def test_compressed_tensors_w8a8_dynamic_per_token_rocm_aiter(
vllm_runner,
model_args,
monkeypatch,
):

# setting VLLM_ROCM_USE_AITER enables the AITER path;
# VLLM_ROCM_USE_AITER_LINEAR already defaults to True
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8

llm.apply_model(check_model)

output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
assert output
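
As suggested in the review comment on test_compressed_tensors_w8a8_logprobs_rocm_aiter above, the AITER coverage could fold into the existing tests via a boolean parameter. Below is a minimal sketch of that pattern; the test name and body are hypothetical, only the parametrization and the monkeypatched env var mirror the PR.

```python
import pytest
import torch

from vllm.platforms import current_platform

# Only exercise AITER where it can run; elsewhere the parameter
# collapses to the single non-AITER case.
USE_AITER_CASES = [False, True] if current_platform.is_rocm() else [False]


@pytest.mark.parametrize("use_aiter", USE_AITER_CASES)
def test_w8a8_channel_dynamic_token_optional_aiter(vllm_runner, use_aiter,
                                                   monkeypatch):
    if use_aiter:
        # VLLM_ROCM_USE_AITER_LINEAR already defaults to "1", so the
        # master switch alone routes w8a8 scaled MM through AITER.
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    model_path = "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
    with vllm_runner(model_path, dtype=torch.float16) as llm:
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
```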
39 changes: 34 additions & 5 deletions vllm/_custom_ops.py
@@ -36,6 +36,12 @@ def register_fake(fn):
from torch.library import impl_abstract as register_fake


def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_LINEAR \
and envs.VLLM_ROCM_USE_AITER


# page attention ops
def paged_attention_v1(
out: torch.Tensor,
@@ -529,11 +535,34 @@ def cutlass_scaled_mm(a: torch.Tensor,
n = b.shape[1]

if current_platform.is_rocm():
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
if is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
Collaborator: This shouldn't live inside cutlass_scaled_mm; it has nothing to do with CUTLASS. This code should just live inside AiterScaledMMLinearKernel.apply. I know the Triton kernel is here, but it shouldn't be either; I'm currently refactoring that.

(A sketch of the suggested move follows at the end of this file's diff.)

per_tensor_scale_a = (scale_a.numel() == 1)
per_tensor_scale_b = (scale_b.numel() == 1)
per_channel_tensor_scale_a = (scale_a.numel() == m)
per_channel_tensor_scale_b = (scale_b.numel() == n)

# TODO: Consider broadcasting a per-tensor scale into a
# per-channel scale when the other scale is per-channel.
# For now, only per-tensor-per-tensor and
# per-channel-per-channel a8w8 scaled GEMM are supported.
assert (
(per_tensor_scale_a and per_tensor_scale_b) or
(per_channel_tensor_scale_a and per_channel_tensor_scale_b)), (
"Currently only per-tensor-per-tensor and "
"per-channel-per-channel w8a8 scaled GEMM are supported "
"through AITER. `cutlass_scaled_mm` does not support "
"AITER block-scaled GEMM yet.")

from aiter import gemm_a8w8_CK
return gemm_a8w8_CK(a, b.t(), scale_a, scale_b, bias).to(out_dtype)
else:
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

out = torch.empty((m, n), dtype=out_dtype, device=a.device)

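
A hedged sketch of the refactor suggested in the review comment above: the AITER branch moves out of cutlass_scaled_mm into a helper that the AITER kernel class can call from its apply path. The helper reuses the logic from this diff; the function name and the idea of calling it from AiterScaledMMLinearKernel.apply_weights are assumptions, not the final implementation.

```python
from typing import Optional

import torch


def rocm_aiter_scaled_mm(a: torch.Tensor,
                         b: torch.Tensor,
                         scale_a: torch.Tensor,
                         scale_b: torch.Tensor,
                         out_dtype: torch.dtype,
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Same restriction as the assert added to cutlass_scaled_mm:
    # per-tensor-per-tensor or per-channel-per-channel scales only.
    m, n = a.shape[0], b.shape[1]
    per_tensor = scale_a.numel() == 1 and scale_b.numel() == 1
    per_channel = scale_a.numel() == m and scale_b.numel() == n
    assert per_tensor or per_channel, (
        "AITER w8a8 scaled GEMM only supports per-tensor-per-tensor and "
        "per-channel-per-channel scaling.")

    from aiter import gemm_a8w8_CK
    return gemm_a8w8_CK(a, b.t(), scale_a, scale_b, bias).to(out_dtype)
```

AiterScaledMMLinearKernel.apply_weights would then quantize the activations the same way the CUTLASS kernel does and call this helper instead of ops.cutlass_scaled_mm.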
8 changes: 8 additions & 0 deletions vllm/envs.py
@@ -73,6 +73,7 @@
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
@@ -511,6 +512,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),

# use the AITER linear op if AITER ops are enabled
# Related ops:
# - scaled_mm (per-tensor / rowwise)
"VLLM_ROCM_USE_AITER_LINEAR":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
("true", "1")),

# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
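
For context on the two flags, a small illustration of how they combine, mirroring is_rocm_aiter_gemm_w8a8_scaled_mm_enabled from vllm/_custom_ops.py above; the values shown are only an example.

```python
import os

# Master switch: AITER ops are disabled unless this is set.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
# The linear sub-switch defaults to "1", so flipping the master switch
# alone routes w8a8 scaled MM through AITER; set it to "0" to keep
# AITER for other ops but fall back to Triton for linear layers.
os.environ.setdefault("VLLM_ROCM_USE_AITER_LINEAR", "1")

import vllm.envs as envs

aiter_linear_enabled = (envs.VLLM_ROCM_USE_AITER
                        and envs.VLLM_ROCM_USE_AITER_LINEAR)
```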
vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
@@ -143,10 +143,12 @@ def triton_scaled_mm(input: torch.Tensor,
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
[M, 1])
assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
[N, 1])
assert scale_a.shape == (1, 1) or scale_a.shape == (M, 1)
assert scale_b.shape == (1, 1) or scale_b.shape == (N, 1)
# assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
# [M, 1])
# assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
# [N, 1])
Collaborator: Please clean up comments

assert out_dtype.is_floating_point
assert bias is None or bias.is_floating_point()
assert is_weak_contiguous(input)
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -3,6 +3,8 @@
import os
from typing import Dict, List, Optional, Type

from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
@@ -17,7 +19,7 @@
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}

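
The ordering above matters because, to a first approximation, the kernel chooser walks the platform's list and takes the first kernel whose can_implement passes. A simplified sketch of that selection loop follows (the real chooser in this module also honors compute capability and VLLM_DISABLED_KERNELS; this is an approximation, not the actual function).

```python
from typing import List, Type

from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
    ScaledMMLinearKernel, ScaledMMLinearLayerConfig)


def choose_kernel_sketch(
    config: ScaledMMLinearLayerConfig,
    candidates: List[Type[ScaledMMLinearKernel]],
) -> Type[ScaledMMLinearKernel]:
    reasons: List[str] = []
    for kernel in candidates:
        # On ROCm, AiterScaledMMLinearKernel is tried first and wins
        # whenever aiter is installed, the AITER env flags are on, and
        # the scheme is symmetric; otherwise selection falls back to
        # TritonScaledMMLinearKernel as before.
        ok, why = kernel.can_implement(config)
        if ok:
            return kernel
        reasons.append(f"{kernel.__name__}: {why}")
    raise ValueError("No compatible ScaledMM kernel: " + "; ".join(reasons))
```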
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py (new file)
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):

@classmethod
def get_min_capability(cls) -> int:
return 90

@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if current_platform.is_cpu():
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not " +
"currently supported on CPU.")
if not current_platform.is_rocm():
Collaborator: This can be a single check.

Collaborator (author): Done.

return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is only " +
"currently supported on ROCm.")
# try to import aiter (annotated so Ruff does not strip the unused
# import and replace it with `pass`)
try:
import aiter  # noqa: F401
except Exception:
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not " +
"installed on ROCm.")
Member: It seems you forgot to import aiter here.

Collaborator: Yeah, agreed, this is missing the import.

Collaborator (author): Thank you. It seems Ruff removed the `import aiter`; I have annotated the line so Ruff will not change it into `pass`.

if not ops.is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
return (False, "AiterScaledMMLinearKernel is disabled. " +
"Enable by setting `VLLM_ROCM_USE_AITER=1`.")

if not c.input_symmetric:
return (False,
"AiterScaledMMLinearKernel only supports symmetric " +
"quantization.")
return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return super().apply_weights(layer, x, bias)
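
To close, a hedged illustration of the two scale layouts the AITER path accepts (the restriction discussed in the review and enforced via the numel() checks added to cutlass_scaled_mm); the shapes are chosen to satisfy those checks and the values are arbitrary.

```python
import torch

m, n, k = 16, 32, 64  # activation rows, output features, reduction dim
a = torch.randint(-128, 127, (m, k), dtype=torch.int8)  # int8 activations
b = torch.randint(-128, 127, (k, n), dtype=torch.int8)  # int8 weight

# Scheme 1: per-tensor activation scale and per-tensor weight scale.
scale_a_pt = torch.full((1, 1), 0.02)  # numel() == 1
scale_b_pt = torch.full((1, 1), 0.01)  # numel() == 1

# Scheme 2: per-channel scales -- one per activation row (numel() == m)
# and one per output channel (numel() == n).
scale_a_pc = torch.rand(m, 1)
scale_b_pc = torch.rand(1, n)

# Mixed layouts (e.g. per-tensor activations with per-channel weights)
# fail the assertion and are not dispatched to aiter.gemm_a8w8_CK.
```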