118 changes: 27 additions & 91 deletions tests/quantization/test_compressed_tensors.py
@@ -142,19 +142,30 @@ def zp_valid(zp: Optional[torch.Tensor]):
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.parametrize(
"use_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_compressed_tensors_w8a8_logprobs(
hf_runner,
vllm_runner,
example_prompts,
model_path,
max_tokens,
num_logprobs,
use_aiter,
monkeypatch,
):

if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

if use_aiter:
if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
pytest.skip(
f"Skip model {model_path} as it is not support by aiter.")
# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

dtype = "bfloat16"

# skip language translation prompt for the static per tensor asym model
@@ -204,13 +215,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
),
],
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
@pytest.mark.parametrize(
"use_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_compressed_tensors_w8a8_dynamic_per_token(
vllm_runner,
model_args,
use_aiter,
monkeypatch,
):
model_path, strategy = model_args

if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

if use_aiter:
if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
pytest.skip(
f"Skip model {model_path} as it is not support by aiter.")
# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model_path, dtype=torch.float16) as llm:

def check_model(model):
@@ -307,9 +332,7 @@ def check_model(model):

if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is (torch.float8_e4m3fnuz
if current_platform.is_rocm()
else torch.float8_e4m3fn)
assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0

@@ -613,90 +636,3 @@ def check_model(model):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output


@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="This tests is skipped on non-ROCm platform.")
def test_compressed_tensors_w8a8_logprobs_rocm_aiter(
hf_runner,
vllm_runner,
example_prompts,
model_path,
max_tokens,
num_logprobs,
monkeypatch,
):
# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

dtype = "bfloat16"

# skip language translation prompt for the static per tensor asym model
if (model_path ==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
): # noqa: E501
example_prompts = example_prompts[0:-1]

with hf_runner(model_path, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize(
"model_args",
[
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
"channel",
),
],
)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="This tests is skipped on non-ROCm platform.")
def test_compressed_tensors_w8a8_dynamic_per_token_rocm_aiter(
vllm_runner,
model_args,
monkeypatch,
):

# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8

llm.apply_model(check_model)

output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
assert output
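
The two ROCm/AITER-only tests deleted above are replaced by the `use_aiter` parametrization added earlier in this file. A minimal, self-contained sketch of that pattern (placeholder test body; `is_rocm` stands in for `current_platform.is_rocm()`):

import pytest


def is_rocm() -> bool:
    # Stand-in for vllm.platforms.current_platform.is_rocm().
    return False


# Parametrize the AITER case only on ROCm; elsewhere the test runs once
# with use_aiter=False.
@pytest.mark.parametrize("use_aiter",
                         [True, False] if is_rocm() else [False])
def test_w8a8_sketch(use_aiter, monkeypatch):
    if use_aiter:
        # Setting VLLM_ROCM_USE_AITER=1 enables the AITER linear path.
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    # The real tests compare HF and vLLM greedy logprobs here.
    assert use_aiter in (True, False)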
39 changes: 5 additions & 34 deletions vllm/_custom_ops.py
@@ -36,12 +36,6 @@ def register_fake(fn):
from torch.library import impl_abstract as register_fake


def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_LINEAR \
and envs.VLLM_ROCM_USE_AITER


# page attention ops
def paged_attention_v1(
out: torch.Tensor,
@@ -547,34 +541,11 @@ def cutlass_scaled_mm(a: torch.Tensor,
n = b.shape[1]

if current_platform.is_rocm():
if is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
per_tensor_scale_a = (scale_a.numel() == 1)
per_tensor_scale_b = (scale_b.numel() == 1)
per_channel_tensor_scale_a = (scale_a.numel() == m)
per_channel_tensor_scale_b = (scale_b.numel() == n)

# @TODO:
# Maybe broadcast the per-tensor-scale into per-channel-scale
# if one of the scale is a per-channel-scale.
# For now, it only supports
# per-tensor-per-tensor a8w8 scaled GEMM and
# per-channel-per-channel a8w8 scacled GEMM
assert (
(per_tensor_scale_a and per_tensor_scale_b) or
(per_channel_tensor_scale_a and per_channel_tensor_scale_b)), (
"Currently only support per-tensor-per-tensor GEMM " +
" and per-channel-per-channel GEMM through AITER"
" w8a8 scaled gemm. `cutlass_scaled_mm` does not support" +
" ATIER block scaled GEMM yet.")

from aiter import gemm_a8w8_CK
return gemm_a8w8_CK(a, b.t(), scale_a, scale_b, bias).to(out_dtype)
else:
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

out = torch.empty((m, n), dtype=out_dtype, device=a.device)

vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
@@ -143,12 +143,10 @@ def triton_scaled_mm(input: torch.Tensor,
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
assert scale_a.shape == (1, 1) or scale_a.shape == (M, 1)
assert scale_b.shape == (1, 1) or scale_b.shape == (N, 1)
# assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
# [M, 1])
# assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
# [N, 1])
assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
[M, 1])
assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
[N, 1])
assert out_dtype.is_floating_point
assert bias is None or bias.is_floating_point()
assert is_weak_contiguous(input)
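
For illustration, scale shapes that satisfy the reinstated asserts (hypothetical sizes, not part of the change):

import torch

# scale_a: per-tensor (1, 1) or per-row (M, 1); scale_b: per-tensor (1, 1)
# or per-output-column (N, 1), matching the asserts above.
M, K, N = 8, 16, 32
a = torch.randint(-128, 127, (M, K), dtype=torch.int8)
b = torch.randint(-128, 127, (K, N), dtype=torch.int8)
scale_a = torch.rand(M, 1)  # or torch.rand(1, 1)
scale_b = torch.rand(N, 1)  # or torch.rand(1, 1)
assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size([M, 1])
assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size([N, 1])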
@@ -4,13 +4,20 @@

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool:
Review comment (Collaborator): This method should just be inlined to the sole callsite (unless I'm missing another use).
Reply (Author): Resolved.

return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_LINEAR \
and envs.VLLM_ROCM_USE_AITER


class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):

@classmethod
@@ -20,25 +27,20 @@ def get_min_capability(cls) -> int:
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if current_platform.is_cpu():
if current_platform.is_cpu() or not current_platform.is_rocm():
Review comment (Collaborator): Why can't this be simpler?
Suggested change: `if current_platform.is_cpu() or not current_platform.is_rocm():` -> `if not current_platform.is_rocm():`
Reply (Author, @tjtanaa, Mar 28, 2025): Ok. I have removed the check for CPU.

return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not " +
"currently supported on CPU.")
if not current_platform.is_rocm():
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is only " +
"currently supported on ROCm.")
# try import aiter
"currently supported on CPU and non-ROCm platform.")

try:
pass
import aiter # noqa: F401 # deliberately attempt to import aiter
except Exception:
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not " +
"installed supported on ROCm.")
if not ops.is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
if not is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
return (False, "AiterScaledMMLinearKernel is disabled. " +
"Enable by setting `VLLM_ROCM_USE_AITER=1`.")

@@ -55,4 +57,58 @@ def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return super().apply_weights(layer, x, bias)
"""
`AiterScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
Currently only per-tensor-per-tensor and per-channel-per-channel
w8a8 scaled GEMM are supported through AITER.
`AiterScaledMMLinearKernel` does not support AITER block-scaled GEMM
or mixed-precision GEMM.
"""
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)

# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
assert symmetric, ("AiterScaledMMLinearKernel only supports"
" symmetric quantization.")
x_q, x_s, x_zp = ops.scaled_int8_quant(x,
i_s,
i_zp,
symmetric=symmetric)

assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
" symmetric quantization.")
out_dtype = x.dtype

assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == w_q.shape[
1] and bias.dtype == out_dtype

m = x_q.shape[0] # a
n = w_q.shape[1] # b

per_tensor_scale_a = (x_s.numel() == 1)
per_tensor_scale_b = (w_s.numel() == 1)
per_channel_tensor_scale_a = (x_s.numel() == m)
per_channel_tensor_scale_b = (w_s.numel() == n)
Review comment (Collaborator): Isn't this more accurate here?
Suggested change:
per_channel_tensor_scale_a = (x_s.numel() == m)  ->  per_token_scale_a = (x_s.numel() == m)
per_channel_tensor_scale_b = (w_s.numel() == n)  ->  per_channel_scale_b = (w_s.numel() == n)
Reply (Author): Yes, you are right. I have made the amendments. Thank you so much.


# @TODO:
# Maybe broadcast the per-tensor scale into a per-channel scale
# if one of the scales is a per-channel scale.
# For now, only per-tensor-per-tensor and
# per-channel-per-channel a8w8 scaled GEMM are supported.
assert ((per_tensor_scale_a and per_tensor_scale_b) or
        (per_channel_tensor_scale_a and per_channel_tensor_scale_b)), (
    "Currently only per-tensor-per-tensor and per-channel-per-channel "
    "GEMM are supported through AITER w8a8 scaled GEMM. "
    "`AiterScaledMMLinearKernel` does not support AITER block-scaled "
    "GEMM yet.")

from aiter import gemm_a8w8_CK
return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
Review comment (Collaborator): Just curious for future work: does this kernel support fp8?
Also, can you add a comment on why w_q needs to be transposed here? I assume it is because the Cutlass-prepared weights are transposed, so here we restore them?
Reply (Author): ROCm/aiter does not support FP8 at this moment. I have added the comment.
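
To make the docstring's contract concrete, here is a dequantized reference of the fused op described above. This is a sketch only, with assumed shapes (x_q: (M, K) int8, w_q: (K, N) int8, x_s: (M, 1) or (1, 1), w_s: (1, N) or (1, 1)); it is not the production AITER path.

from typing import Optional

import torch


def scaled_mm_reference(x_q: torch.Tensor,
                        w_q: torch.Tensor,
                        x_s: torch.Tensor,
                        w_s: torch.Tensor,
                        bias: Optional[torch.Tensor],
                        out_dtype: torch.dtype) -> torch.Tensor:
    # output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype), with
    # numpy-style broadcasting of the scales, as in the apply_weights
    # docstring above.
    out = torch.mm(x_s * x_q.to(torch.float32),
                   w_s * w_q.to(torch.float32))
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)


# Per-token activation scales and per-channel weight scales:
M, K, N = 16, 32, 64
x_q = torch.randint(-128, 127, (M, K), dtype=torch.int8)
w_q = torch.randint(-128, 127, (K, N), dtype=torch.int8)
x_s = torch.rand(M, 1)  # numel == M -> per-token
w_s = torch.rand(1, N)  # numel == N -> per-channel
ref = scaled_mm_reference(x_q, w_q, x_s, w_s, None, torch.bfloat16)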