192 changes: 192 additions & 0 deletions tests/quantization/test_gptq_marlin.py
@@ -0,0 +1,192 @@
from types import SimpleNamespace

import torch

import vllm.model_executor.layers.fused_moe as fused_moe_mod
import vllm.model_executor.layers.linear as linear_mod
import vllm.model_executor.layers.quantization.gptq_marlin as gptq_marlin_mod
import vllm.model_executor.parameter as parameter_mod
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig,
    GPTQMarlinLinearMethod,
)
from vllm.model_executor.layers.quantization.inc import INCConfig
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


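# Stand-in for the kernel class returned by choose_mp_linear_kernel; it skips
# Marlin repacking and returns zeros, so these tests run without GPU kernels.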
class _DummyKernel:
    def __init__(self, *args, **kwargs):
        pass

    def process_weights_after_loading(self, layer):
        return None

    def apply_weights(self, layer, x, bias):
        out = torch.zeros((*x.shape[:-1], layer.output_size), dtype=x.dtype)
        if bias is not None:
            out = out + bias
        return out


def test_gptq_marlin_create_weights_uses_ceil_groups_for_row_parallel(
    monkeypatch,
):
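    """create_weights should allocate ceil(input / group_size) scale groups."""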
    monkeypatch.setattr(
        gptq_marlin_mod, "verify_marlin_supported", lambda **kwargs: None
    )
    monkeypatch.setattr(
        gptq_marlin_mod, "choose_mp_linear_kernel", lambda _: _DummyKernel
    )
    monkeypatch.setattr(parameter_mod, "get_tensor_model_parallel_rank", lambda: 0)
    monkeypatch.setattr(parameter_mod, "get_tensor_model_parallel_world_size", lambda: 1)

    config = GPTQMarlinConfig(
        weight_bits=4,
        group_size=128,
        desc_act=False,
        is_sym=True,
        lm_head_quantized=False,
        dynamic={},
        full_config={},
    )
    method = GPTQMarlinLinearMethod(config)
    layer = torch.nn.Module()

    method.create_weights(
        layer=layer,
        input_size_per_partition=2112,
        output_partition_sizes=[2816],
        input_size=4224,
        output_size=2816,
        params_dtype=torch.float16,
        weight_loader=default_weight_loader,
    )

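    # ceil(2112 / 128) = 17 groups; qzeros pack 2816 outputs / 8 = 352 columns.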
    assert layer.scales.shape == (17, 2816)
    assert layer.qzeros.shape == (17, 352)


def test_inc_row_parallel_overlap_group_load_uses_global_group_offsets(
    monkeypatch,
):
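    """Row-parallel scales/qzeros load from this rank's global group offset."""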
    monkeypatch.setattr(
        gptq_marlin_mod, "verify_marlin_supported", lambda **kwargs: None
    )
    monkeypatch.setattr(
        gptq_marlin_mod, "choose_mp_linear_kernel", lambda _: _DummyKernel
    )
    monkeypatch.setattr(linear_mod, "get_tensor_model_parallel_rank", lambda: 1)
    monkeypatch.setattr(linear_mod, "get_tensor_model_parallel_world_size", lambda: 2)
    monkeypatch.setattr(parameter_mod, "get_tensor_model_parallel_rank", lambda: 1)
    monkeypatch.setattr(parameter_mod, "get_tensor_model_parallel_world_size", lambda: 2)

    quant_config = INCConfig(
        weight_bits=4,
        group_size=128,
        sym=True,
        packing_format="auto_round:auto_gptq",
        backend="marlin",
    )
    layer = RowParallelLinear(
        2112,
        2816,
        bias=False,
        quant_config=quant_config,
        prefix="model.layers.0.mlp.down_proj",
    )

    scales = torch.arange(17 * 2816, dtype=layer.scales.dtype).reshape(17, 2816)
    qzeros = torch.arange(17 * 352, dtype=layer.qzeros.dtype).reshape(17, 352)

    layer.weight_loader_v2(layer.scales, scales)
    layer.weight_loader_v2(layer.qzeros, qzeros)

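    # Rank 1 owns inputs 1056..2111: 1056 // 128 = group 8, and the 1056-wide
    # shard spans ceil(1056 / 128) = 9 groups, i.e. global rows 8..16.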
    assert torch.equal(layer.scales.data, scales[8:17])
    assert torch.equal(layer.qzeros.data, qzeros[8:17])


def test_gate_linear_quantized_forward_does_not_require_unquantized_weight(
    monkeypatch,
):
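    """A quantized GateLinear forwards via qweight and needs no dense weight."""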
    monkeypatch.setattr(
        gptq_marlin_mod, "verify_marlin_supported", lambda **kwargs: None
    )
    monkeypatch.setattr(
        gptq_marlin_mod, "choose_mp_linear_kernel", lambda _: _DummyKernel
    )
    monkeypatch.setattr(linear_mod, "get_tensor_model_parallel_rank", lambda: 0)
    monkeypatch.setattr(linear_mod, "get_tensor_model_parallel_world_size", lambda: 1)
    monkeypatch.setattr(parameter_mod, "get_tensor_model_parallel_rank", lambda: 0)
    monkeypatch.setattr(parameter_mod, "get_tensor_model_parallel_world_size", lambda: 1)

    quant_config = INCConfig(
        weight_bits=4,
        group_size=128,
        sym=True,
        packing_format="auto_round:auto_gptq",
        backend="marlin",
    )
    gate = GateLinear(
        2816,
        128,
        bias=False,
        out_dtype=torch.float32,
        quant_config=quant_config,
        prefix="model.layers.0.router.proj",
    )

    assert not hasattr(gate, "weight")
    assert hasattr(gate, "qweight")

    x = torch.randn(3, 2816, dtype=torch.float16)
    output, output_bias = gate(x)

    assert output_bias is None
    assert output.shape == (3, 128)
    assert output.dtype == torch.float32


def test_moe_wna16_apply_passes_layer_activation(monkeypatch):
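    """apply() should pass the layer's activation through to fused_experts."""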
    captured: dict[str, object] = {}

    def fake_fused_experts(x, w1, w2, **kwargs):
        captured["activation"] = kwargs["activation"]
        captured["quant_config"] = kwargs["quant_config"]
        return torch.zeros_like(x)

    monkeypatch.setattr(fused_moe_mod, "fused_experts", fake_fused_experts)

    method = MoeWNA16Method(
        SimpleNamespace(),
        SimpleNamespace(disable_inplace=False),
    )
    method.moe_quant_config = object()

    layer = SimpleNamespace(
        activation=MoEActivation.GELU,
        w13_qweight=torch.empty(2, 4, 4, dtype=torch.uint8),
        w2_qweight=torch.empty(2, 4, 4, dtype=torch.uint8),
        apply_router_weight_on_input=False,
        global_num_experts=2,
        expert_map=None,
    )

    x = torch.randn(3, 8, dtype=torch.float16)
    topk_weights = torch.ones(3, 1, dtype=torch.float32)
    topk_ids = torch.zeros(3, 1, dtype=torch.int32)

    output = method.apply(
        layer=layer,
        x=x,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        shared_experts_input=None,
    )

    assert captured["activation"] == MoEActivation.GELU
    assert captured["quant_config"] is method.moe_quant_config
    assert output.shape == x.shape
18 changes: 13 additions & 5 deletions vllm/model_executor/layers/fused_moe/router/gate_linear.py
@@ -5,6 +5,7 @@

 from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.platforms import current_platform


@@ -33,6 +34,7 @@ def __init__(
         out_dtype: torch.dtype | None = None,
         params_dtype: torch.dtype | None = None,
         force_fp32_compute: bool = False,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ):
         is_hopper_or_blackwell = current_platform.is_device_capability(
@@ -52,10 +54,12 @@
             output_size,
             bias=bias,
             params_dtype=params_dtype,
-            quant_config=None,
+            quant_config=quant_config,
             prefix=prefix,
         )
         self.out_dtype = out_dtype
+        weight = getattr(self, "weight", None)
+        weight_dtype = None if weight is None else weight.dtype

         # DSV3 specialized kernel eligibility (SM90+, exact dims)
         self.allow_specialized_router_gemm = can_use_specialized_kernels
@@ -68,7 +72,7 @@
         # cuBLAS bf16→fp32 eligibility
         self.allow_cublas_router_gemm = (
             self.allow_specialized_router_gemm
-            and self.weight.dtype == torch.bfloat16
+            and weight_dtype == torch.bfloat16
             and self.out_dtype == torch.float32
         )

@@ -87,7 +91,10 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None:
             and self.allow_specialized_router_gemm
             and out_dtype == torch.float32
         ):
-            self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16
+            weight = getattr(self, "weight", None)
+            self.allow_cublas_router_gemm = (
+                weight is not None and weight.dtype == torch.bfloat16
+            )

     def forward(
         self, x: torch.Tensor
@@ -109,8 +116,9 @@ def forward(
             return output, None

         # Tier 3: F.linear (ReplicatedLinear)
-        if self.out_dtype is not None and x.dtype != self.weight.dtype:
-            x = x.to(self.weight.dtype)
+        weight = getattr(self, "weight", None)
+        if self.out_dtype is not None and weight is not None and x.dtype != weight.dtype:
+            x = x.to(weight.dtype)
         output, output_bias = super().forward(x)
         if self.out_dtype is not None and output.dtype != self.out_dtype:
             output = output.to(self.out_dtype)
11 changes: 9 additions & 2 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -60,6 +60,7 @@
 from vllm.scalar_type import scalar_types
 from vllm.transformers_utils.config import get_safetensors_params_metadata
 from vllm.utils.collection_utils import is_list_of
+from vllm.utils.math_utils import cdiv

 logger = init_logger(__name__)

@@ -395,12 +396,12 @@ def create_weights(
             # By setting scale_dim == None, weight_loader will
             # repeat the scales on each GPU in TP>1 case.
             scales_and_zp_input_dim = None
-            scales_and_zp_size = input_size // group_size
+            scales_and_zp_size = cdiv(input_size, group_size)
         else:
             # By setting scale_dim == 0, weight_loader will
             # shard the scales in TP>1 case.
             scales_and_zp_input_dim = 0
-            scales_and_zp_size = input_size_per_partition // group_size
+            scales_and_zp_size = cdiv(input_size_per_partition, group_size)

         # Quantized weights
         qweight = PackedvLLMParameter(
@@ -463,6 +464,12 @@ def create_weights(
             packed_factor=self.quant_config.pack_factor,
             **qzeros_args,
         )
+        row_group_attrs = {
+            "row_group_size": group_size,
+            "row_input_size_per_partition": input_size_per_partition,
+        }
+        set_weight_attrs(scales, row_group_attrs)
+        set_weight_attrs(qzeros, row_group_attrs)

         layer.register_parameter("qweight", qweight)
         layer.register_parameter("g_idx", g_idx)
5 changes: 1 addition & 4 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -372,17 +372,14 @@ def apply(
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

-        assert layer.activation == MoEActivation.SILU, (
-            f"Only SiLU activation is supported, not {layer.activation}."
-        )
-
         return fused_experts(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=not self.moe.disable_inplace,
+            activation=layer.activation,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
1 change: 1 addition & 0 deletions vllm/model_executor/models/gemma4.py
@@ -161,6 +161,7 @@ def __init__(
             config.num_experts,
             bias=False,
             out_dtype=torch.float32,
+            quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )

10 changes: 7 additions & 3 deletions vllm/model_executor/parameter.py
@@ -219,9 +219,13 @@ def input_dim(self):

     def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
         shard_size = self.data.shape[self.input_dim]
-        loaded_weight = loaded_weight.narrow(
-            self.input_dim, self.tp_rank * shard_size, shard_size
-        )
+        group_size = getattr(self, "row_group_size", None)
+        input_partition_size = getattr(self, "row_input_size_per_partition", None)
+        if group_size is not None and input_partition_size is not None:
+            start_idx = (self.tp_rank * input_partition_size) // group_size
+        else:
+            start_idx = self.tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(self.input_dim, start_idx, shard_size)

         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)