Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/quantization/test_quark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
QuarkW8A8Fp8,
QuarkW8A8Int8,
)
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
QuarkW8A8Int8MoEMethod,
)
from vllm.platforms import current_platform

from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
Expand Down Expand Up @@ -126,6 +129,34 @@ def check_model(model):
assert output


@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w8a8_moe(vllm_runner, tp):
    """Test W8A8 INT8 MoE quantization with a tiny Qwen3 MoE model."""
    model_path = "nameistoken/tiny-qwen3-moe-w8a8-int8-quark"
    with vllm_runner(
        model_path,
        enforce_eager=True,
        tensor_parallel_size=tp,
        gpu_memory_utilization=0.1,
    ) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            # The MoE expert block must have been wired up with the
            # Quark INT8 MoE quantization method.
            quant_method = first_layer.mlp.experts.quant_method
            assert isinstance(quant_method, QuarkW8A8Int8MoEMethod), (
                f"Expected QuarkW8A8Int8MoEMethod, got {type(quant_method)}"
            )
            # Dense (non-MoE) projections keep the plain INT8 linear scheme.
            qkv_scheme = first_layer.self_attn.qkv_proj.scheme
            assert isinstance(qkv_scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        generated = llm.generate_greedy("Hello", max_tokens=4)
        assert generated


def test_quark_fp8_parity(vllm_runner):
quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
Expand Down
11 changes: 9 additions & 2 deletions vllm/model_executor/layers/fused_moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,15 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None:
assert per_act_token, "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
if per_act_token:
A, A_scale = per_token_quant_int8(A)
elif A_scale is not None:
# Static per-tensor: use the optimized CUDA kernel
A, A_scale, _ = ops.scaled_int8_quant(A, scale=A_scale)
elif A_scale is None:
# Dynamic per-tensor: compute scale then quantize via kernel
A_scale = torch.clamp(A.abs().max() / 127.0, min=1e-10)
A, A_scale, _ = ops.scaled_int8_quant(A, scale=A_scale)
else:
assert not per_act_token
assert len(block_shape) == 2
Expand Down
38 changes: 38 additions & 0 deletions vllm/model_executor/layers/quantization/quark/quark.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,37 @@ def _is_w4a8_mxfp4_fp8(

return is_weight_mxfp4 and is_input_fp8

def _is_dynamic_per_token_w8a8(
self,
weight_quant: dict[str, Any] | None,
input_quant: dict[str, Any] | None,
) -> bool:
"""Detect W8A8 INT8 with per-tensor or per-channel
weights and dynamic per-token input."""
if weight_quant is None or input_quant is None:
return False

is_int8_dtype = (
weight_quant.get("dtype") == "int8" and input_quant.get("dtype") == "int8"
)

is_valid_weight_scheme = weight_quant.get("qscheme") in [
"per_tensor",
"per_channel",
]
is_per_token_input = input_quant.get("qscheme") == "per_channel"

is_dynamic_input = input_quant.get("is_dynamic") is True
is_weight_symmetric = weight_quant.get("symmetric") is True

return (
is_int8_dtype
and is_valid_weight_scheme
and is_per_token_input
and is_dynamic_input
and is_weight_symmetric
)

def _is_w_ocp_mx_a_x(
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
) -> bool:
Expand Down Expand Up @@ -556,6 +587,13 @@ def _get_scheme_from_config(
)
if is_w4a8_supported:
return QuarkW4A8_MXFP4_FP8(weight_config, input_config)
elif self._is_dynamic_per_token_w8a8(weight_config, input_config):
weight_qscheme = cast(str, weight_config.get("qscheme"))
return QuarkW8A8Int8(
qscheme=weight_qscheme,
is_static_input_scheme=False,
input_symmetric=input_config.get("symmetric"),
)
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
Expand Down
Loading
Loading