tests/quantization/test_auto_round.py (99 additions, 0 deletions)
@@ -9,8 +9,19 @@
"""

import pytest
import torch

from vllm.model_executor.layers.quantization.inc import (
    INCGPTQRowParallelTailLinearMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
)
from vllm.model_executor.models.gemma4 import (
    _dequantize_autoround_gptq_router_weight,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

MODELS = [
    "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc",  ##auto_round:auto_gptq
@@ -30,3 +41,91 @@ def test_auto_round(vllm_runner, model):
    output = llm.generate_greedy(["The capital of France is"], max_tokens=8)
    assert output
    print(f"{output[0][1]}")


def test_autoround_gptq_router_weight_dequantizes_symmetric_zero_point():
    qweight_unpacked = (torch.arange(64, dtype=torch.int32).reshape(8, 8) % 8) + 8
    qzeros_unpacked = torch.full((2, 8), 7, dtype=torch.int32)
    scales = torch.stack(
        (
            torch.linspace(0.5, 1.2, 8),
            torch.linspace(1.5, 2.2, 8),
        )
    )

    qweight = pack_quantized_values_into_int32(
        qweight_unpacked, scalar_types.uint4b8, packed_dim=0
    )
    qzeros = pack_quantized_values_into_int32(
        qzeros_unpacked, scalar_types.uint4b8, packed_dim=1
    )

    weight = _dequantize_autoround_gptq_router_weight(
        qweight=qweight,
        qzeros=qzeros,
        scales=scales,
        num_bits=4,
        group_size=4,
        sym=True,
        params_dtype=torch.float16,
    )

    expected_qzeros = qzeros_unpacked + 1
    row_groups = torch.arange(qweight_unpacked.shape[0]) // 4
    expected = (
        (qweight_unpacked - expected_qzeros[row_groups]) * scales[row_groups]
    ).t()
    torch.testing.assert_close(weight, expected.to(torch.float16))
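
A note on `expected_qzeros = qzeros_unpacked + 1` above: GPTQ-format checkpoints store each zero point minus one, so a symmetric uint4 zero point of 8 is serialized as 7, and the dequantizer has to add the offset back before subtracting. A minimal sketch of that arithmetic, using values taken from this test:

```python
# Minimal sketch of the GPTQ stored-minus-one zero-point convention,
# with values taken from the test above.
stored_qzero = 7                     # what the checkpoint's qzeros holds
effective_qzero = stored_qzero + 1   # 8: the true symmetric uint4 zero point
q, scale = 9, 0.5                    # one unpacked 4-bit weight and its scale
dequant = (q - effective_qzero) * scale
assert dequant == 0.5                # (9 - 8) * 0.5
```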


def test_inc_gptq_row_parallel_tail_fallback_uses_global_group_indices(monkeypatch):
    import vllm.model_executor.layers.quantization.inc as inc
    import vllm.model_executor.parameter as parameter

    monkeypatch.setattr(inc, "get_tensor_model_parallel_rank", lambda: 1)
    monkeypatch.setattr(parameter, "get_tensor_model_parallel_rank", lambda: 1)
    monkeypatch.setattr(parameter, "get_tensor_model_parallel_world_size", lambda: 2)

    method = INCGPTQRowParallelTailLinearMethod(
        weight_bits=4,
        group_size=16,
        sym=True,
    )
    layer = torch.nn.Module()
    layer.input_size_per_partition = 24
    method.create_weights(
        layer,
        input_size_per_partition=24,
        output_partition_sizes=[8],
        input_size=48,
        output_size=8,
        params_dtype=torch.float32,
    )

    assert layer.g_idx.tolist() == [1] * 8 + [2] * 16

    qweight_unpacked = torch.full((24, 8), 9, dtype=torch.int32)
    layer.qweight.data.copy_(
        pack_quantized_values_into_int32(
            qweight_unpacked, scalar_types.uint4b8, packed_dim=0
        )
    )
    layer.scales.data.copy_(
        torch.tensor(
            [
                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
                [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
            ],
            dtype=torch.float32,
        )
    )
    method.process_weights_after_loading(layer)

    x = torch.ones(1, 24, dtype=torch.float16)
    output = method.apply(layer, x)

    # qweight 9 minus uint4 symmetric bias 8 gives dequant value 1.
    expected = 8 * layer.scales.data[1] + 16 * layer.scales.data[2]
    expected = expected.unsqueeze(0)
    torch.testing.assert_close(output, expected.to(torch.float16))
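
For context on the shapes in this test: input_size=48 split across two ranks leaves 24 input channels per rank, and 24 % 16 != 0, so rank 1's shard starts mid-group. A minimal sketch of both the trigger condition and the global group indices, mirroring the `create_weights` logic exercised above:

```python
import torch

# Values from the test above: full input 48, TP world size 2, group size 16.
input_size, world_size, group_size = 48, 2, 16
shard_width = input_size // world_size        # 24 channels per rank

# The inc.py fallback triggers exactly when a shard straddles a group:
assert shard_width % group_size != 0          # 24 % 16 == 8

# Rank 1 owns global channels 24..47: the tail of group 1 plus all of group 2.
rank = 1
g_idx = (torch.arange(shard_width) + rank * shard_width) // group_size
assert g_idx.tolist() == [1] * 8 + [2] * 16   # matches the assert in the test
```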
vllm/model_executor/layers/quantization/inc.py (157 additions, 0 deletions)
@@ -6,8 +6,10 @@

import regex as re
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
    LinearBase,
@@ -18,9 +20,17 @@
    QuantizationConfig,
    QuantizationMethods,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supports_layer,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    unpack_quantized_values_into_int32,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
)
@@ -341,6 +351,22 @@ def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
                group_size,
                sym,
            )
        if (
            isinstance(layer, LinearBase)
            and group_size > 0
            and getattr(layer, "input_size_per_partition", layer.input_size)
            % group_size
            != 0
        ):
            # Gemma4 AutoRound row-parallel linears can produce TP shards that
            # straddle a GPTQ group boundary. Fall back to a correctness-first
            # path in that case instead of using Marlin/GPTQ kernels that
            # assume group-aligned input shards.
            return INCGPTQRowParallelTailLinearMethod(
                weight_bits=weight_bits,
                group_size=group_size,
                sym=sym,
            )

Contributor review comment (severity: high) on `isinstance(layer, LinearBase)`:

The INCGPTQRowParallelTailLinearMethod fallback is specifically designed for row-parallel layers where tensor parallelism can split a quantization group across ranks. Its create_weights implementation (lines 728-731) calculates shard_offset by multiplying the rank by the partition width, which is only correct for RowParallelLinear. For ColumnParallelLinear (and its variants like QKVParallelLinear), the input dimension is not sharded, so shard_offset should be 0 for all ranks. Using this fallback for non-row-parallel layers will result in incorrect group indices (g_idx) and corrupted dequantized weights on ranks > 0. You should restrict this fallback to layers that are actually sharded along the input dimension.

Suggested change:

```diff
-            isinstance(layer, LinearBase)
+            isinstance(layer, LinearBase) and getattr(layer, "input_is_parallel", False)
```
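
To make the reviewer's point concrete, here is a small illustrative sketch (hypothetical toy numbers, not code from this PR) of how the rank-based offset diverges between the two layer types:

```python
import torch

# Illustrative sketch of the review comment above (hypothetical toy numbers).
# With world_size=2, a 24-wide shard, and group_size=16:
rank, shard_width, group_size = 1, 24, 16

# RowParallelLinear: the input IS sharded, so rank 1 owns global channels
# 24..47 and the rank-based offset yields the correct group indices.
row_g_idx = (torch.arange(shard_width) + rank * shard_width) // group_size

# ColumnParallelLinear / QKVParallelLinear: the input is NOT sharded; every
# rank owns channels 0..23, so the offset must stay 0. Applying the rank-based
# offset here would shift g_idx and corrupt the dequantized weights.
col_g_idx = torch.arange(shard_width) // group_size

assert row_g_idx.tolist() == [1] * 8 + [2] * 16
assert col_g_idx.tolist() == [0] * 16 + [1] * 8
```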
if backend == "auto" or "marlin" in backend:
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
Expand All @@ -353,6 +379,10 @@ def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
elif isinstance(layer, LinearBase):
use_marlin = use_marlin and check_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
@@ -625,3 +655,130 @@ def apply(
            None,  # g_idx not needed: desc_act is always False for INC models
        )
        return out.reshape(out_shape)


class INCGPTQRowParallelTailLinearMethod(LinearMethodBase):
    """Fallback for row-parallel GPTQ-family linears with group-tail shards."""

    def __init__(self, weight_bits: int, group_size: int, sym: bool):
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.sym = sym
        self.pack_factor = 32 // weight_bits
        self.weight_type = {
            4: scalar_types.uint4b8,
            8: scalar_types.uint8b128,
        }[weight_bits]

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        full_num_groups = (input_size + self.group_size - 1) // self.group_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.pack_factor,
            weight_loader=weight_loader,
        )
        scales = ChannelQuantScaleParameter(
            data=torch.empty(
                full_num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            output_dim=1,
            weight_loader=weight_loader,
        )
        qzeros = PackedColumnParameter(
            data=torch.empty(
                full_num_groups,
                output_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
            output_dim=1,
            packed_dim=1,
            packed_factor=self.pack_factor,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

        shard_width = getattr(
            layer, "input_size_per_partition", input_size_per_partition
        )
        shard_offset = get_tensor_model_parallel_rank() * shard_width
        g_idx = (
            torch.arange(input_size_per_partition, dtype=torch.int32) + shard_offset
        ) // self.group_size
        layer.register_parameter("g_idx", Parameter(g_idx, requires_grad=False))
        layer._inc_tail_dequant_weight = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.sym:
            # The tail-shard fallback dequantizes weights on demand and handles
            # the symmetric zero point via weight_type.bias in
            # _get_dequantized_weight(), so the large packed qzeros tensor is
            # replaced with a tiny placeholder after loading.
            layer.qzeros = Parameter(
                torch.tensor([8], dtype=torch.int8, device=layer.qweight.device),
                requires_grad=False,
            )
        else:
            layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.scales = Parameter(layer.scales.data, requires_grad=False)

    def _get_dequantized_weight(self, layer: torch.nn.Module) -> torch.Tensor:
        cached = layer._inc_tail_dequant_weight
        if cached is not None:
            return cached

        if not self.sym:
            raise NotImplementedError(
                "INCGPTQRowParallelTailLinearMethod currently supports only "
                "symmetric checkpoints."
            )

        qweight = unpack_quantized_values_into_int32(
            layer.qweight.data, self.weight_type, packed_dim=0
        ).to(torch.float32)
        qweight = qweight - float(self.weight_type.bias)

        g_idx = layer.g_idx.data.to(torch.long)
        scales = layer.scales.data.to(torch.float32)
        dequant = qweight * scales.index_select(0, g_idx)
        weight = dequant.t().contiguous()
        # Cache the dequantized tail-shard weight after the first fallback use.
        layer._inc_tail_dequant_weight = weight
        return weight

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        out_shape = x.shape[:-1] + (layer.qweight.shape[1],)
        x_2d = x.reshape(-1, x.shape[-1]).to(torch.float32)
        bias_2d = bias.to(torch.float32) if bias is not None else None
        output = F.linear(x_2d, self._get_dequantized_weight(layer), bias_2d)
        return output.to(x.dtype).reshape(out_shape)
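
For readers tracing the fallback end to end, here is a condensed, self-contained sketch of the math that `_get_dequantized_weight` and `apply` perform, using hypothetical toy tensors (the real path unpacks int32-packed 4-bit weights and caches the result):

```python
import torch
import torch.nn.functional as F

# Hypothetical toy shapes standing in for an unpacked 4-bit weight shard.
in_features, out_features, group_size = 8, 2, 4
q = torch.full((in_features, out_features), 9.0)  # unpacked quantized values
bias = 8.0                                        # uint4b8 symmetric bias
g_idx = torch.arange(in_features) // group_size   # [0, 0, 0, 0, 1, 1, 1, 1]
scales = torch.tensor([[0.5, 0.5], [2.0, 2.0]])   # one scale row per group

dequant = (q - bias) * scales.index_select(0, g_idx)  # per-group dequantization
weight = dequant.t().contiguous()                     # (out, in) for F.linear

x = torch.ones(1, in_features)
y = F.linear(x, weight)  # each output: 4 * 0.5 + 4 * 2.0 == 10.0
assert torch.allclose(y, torch.full((1, out_features), 10.0))
```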