1 change: 1 addition & 0 deletions tests/quantization/test_cpu_wna16.py
@@ -10,6 +10,7 @@
MODELS = [
"TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ",
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", # with g_idx
"Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4", # without g_idx
]
DTYPE = ["bfloat16"]

1 change: 0 additions & 1 deletion vllm/config/model.py
@@ -876,7 +876,6 @@ def _verify_quantization(self) -> None:
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
"cpu_gptq",
"cpu_awq",
]
quantization_methods = [
4 changes: 1 addition & 3 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -38,7 +38,6 @@
"inc",
"mxfp4",
"petit_nvfp4",
"cpu_gptq",
"cpu_awq",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -109,7 +108,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from .cpu_wna16 import CPUAWQConfig, CPUGPTQConfig
from .cpu_wna16 import CPUAWQConfig
from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
@@ -162,7 +161,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_gptq": CPUGPTQConfig,
"cpu_awq": CPUAWQConfig,
}
# Update the `method_to_config` with customized quantization methods.
326 changes: 0 additions & 326 deletions vllm/model_executor/layers/quantization/cpu_wna16.py
@@ -20,12 +20,6 @@
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
pack_cols,
@@ -34,335 +28,15 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils.collection_utils import is_list_of

logger = init_logger(__name__)


class CPUGPTQConfig(QuantizationConfig):
"""Config class for CPU GPTQ quant"""

def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, int | bool]],
full_config: dict[str, Any],
modules_in_block_to_quantize: list[str] | None = None,
) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False

# GPTQModel uses the `dynamic` config property to allow a per-module
# quantization config so each module can be individually optimized.
# The format is dict[str, dict] where the key is a regex string that can
# perform either positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Defaults to positive matching and overrides the base quant config if no
# prefix is used. The value is a dict of field keys and override values.
# Negative matching skips quantization init for the module entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
assert weight_bits == 4
self.dynamic = dynamic
self.weight_bits = weight_bits
self.is_sym = is_sym
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.full_config = full_config
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []

def __repr__(self) -> str:
return (
f"CPUWNA16Config("
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)

@classmethod
def get_name(cls) -> QuantizationMethods:
return "cpu_gptq"

@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
return -1

@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: dict[str, Any]) -> "CPUGPTQConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
group_size = cls.get_from_keys(config, ["group_size"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None
)
return cls(
weight_bits,
group_size,
desc_act,
is_sym,
lm_head_quantized,
dynamic,
config,
modules_in_block_to_quantize,
)

@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
quant_method = hf_quant_cfg.get("quant_method", "").lower()
if current_platform.is_cpu() and (quant_method == "gptq"):
return cls.get_name()
return None

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
return get_linear_quant_method(self, layer, prefix, CPUGPTQLinearMethod) # type: ignore

def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize
)

def maybe_update_config(self, model_name: str, revision: str | None = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item
for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return

unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name, revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get("dtype", None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)


class CPUGPTQLinearMethod(LinearMethodBase):
"""Linear method for GPTQ on CPU.

Args:
quant_config: The CPUWNA16 quantization config.
"""

def __init__(self, quant_config: CPUGPTQConfig) -> None:
self.quant_config = quant_config
assert self.quant_config.is_sym, "GPTQ asym quant is not supported on CPU"

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
assert output_size_per_partition * self.quant_config.weight_bits % 32 == 0
assert output_size_per_partition % 32 == 0
assert input_size_per_partition % 32 == 0

is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")

# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size

# Determine sharding
if marlin_repeat_scales_on_all_ranks(
self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each rank in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size

# Quantized weights
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)

# Activation order
g_idx = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(
g_idx,
{"ignore_warning": True},
)

qzeros_args = {
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader": weight_loader,
}
weight_scale_args = {
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader": weight_loader,
}

if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)

else:
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)

layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False)
packed_weight = layer.qweight.data
bits = self.quant_config.weight_bits
pack_factor = int(self.quant_config.pack_factor)
p_w_k, p_w_n = packed_weight.size()
input_size = p_w_k * pack_factor
output_size = p_w_n
isa_hint = _get_isa_hint(layer.scales.dtype)
layer.isa_hint = isa_hint

layer.qzeros = None
if not self.quant_config.desc_act:
layer.g_idx = None

# convert input dim packed to output dim packed
weight = unpack_cols(packed_weight, bits, p_w_k, p_w_n * pack_factor).view(
p_w_k, p_w_n, pack_factor
)
weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
weight = pack_cols(weight, bits, input_size, output_size)
# group 16 output channels into a block and transpose to make
# the block contiguous
weight = (
weight.view(input_size, -1, 16 // pack_factor)
.permute(1, 0, 2)
.reshape(-1, input_size * 16 // pack_factor)
.contiguous()
)
layer.qweight.data = weight

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
x = cpu_gemm_wna16(
input=x,
q_weight=layer.qweight,
scales=layer.scales,
zeros=layer.qzeros,
g_idx=layer.g_idx,
bias=bias,
pack_factor=8,
isa_hint=layer.isa_hint,
)
return x


class CPUAWQConfig(QuantizationConfig):
"""Config class for CPU AWQ"""

2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -276,7 +276,7 @@ def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")

if not current_platform.is_cuda():
if not (current_platform.is_cuda() or current_platform.is_cpu()):
return False

if quant_method != "gptq":