Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/quantization/inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
% group_size
!= 0
):
# Gemma4 AutoRound row-parallel linears can produce TP shards that
# straddle a GPTQ group boundary. Fall back to a correctness-first
# path in that case instead of using Marlin/GPTQ kernels that
# assume group-aligned input shards.
return INCGPTQRowParallelTailLinearMethod(
weight_bits=weight_bits,
group_size=group_size,
Expand Down Expand Up @@ -733,6 +737,10 @@ def create_weights(

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.sym:
# The tail-shard fallback dequantizes weights on demand and handles
# the symmetric zero point via weight_type.bias in
# _get_dequantized_weight(), so the large packed qzeros tensor is
# replaced with a tiny placeholder after loading.
layer.qzeros = Parameter(
torch.tensor([8], dtype=torch.int8, device=layer.qweight.device),
requires_grad=False,
Expand Down Expand Up @@ -762,6 +770,7 @@ def _get_dequantized_weight(self, layer: torch.nn.Module) -> torch.Tensor:
scales = layer.scales.data.to(torch.float32)
dequant = qweight * scales.index_select(0, g_idx)
weight = dequant.t().contiguous()
# Cache the dequantized tail-shard weight after the first fallback use.
layer._inc_tail_dequant_weight = weight
return weight

Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/model_loader/gguf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def add_mapping(
) -> None:
if handled_name is None:
handled_name = hf_name
# handled_name must match the HF state_dict key emitted by the
# installed Gemma4 transformers config/model classes. If upstream
# renames these tensors, update handled_name alongside the manual
# GGUF mapping.
if handled_name in normalized_state_names:
gguf_to_hf_name_map[gguf_name] = hf_name
handled_params.add(handled_name)
Expand Down
10 changes: 10 additions & 0 deletions vllm/model_executor/models/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def _dequantize_autoround_gptq_router_weight(
qzeros, weight_type, packed_dim=1
).to(torch.float32)
if sym:
# AutoRound's GPTQ-style symmetric checkpoints store router qzeros
# with a -1 offset relative to the effective zero point.
unpacked_qzeros = unpacked_qzeros + 1
row_groups = (
torch.arange(unpacked_qweight.shape[0], device=qweight.device) // group_size
Expand Down Expand Up @@ -1580,6 +1582,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
weight_loader(param, loaded_weight)
loaded_params.add(name)

if router_quant_params:
for router_name, quant_params in router_quant_params.items():
logger.warning(
"Skipping incomplete quantized router params for %s: %s",
router_name,
sorted(quant_params),
)

return loaded_params


Expand Down
3 changes: 2 additions & 1 deletion vllm/transformers_utils/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,12 @@ def cached_processor_from_config(
)
tokenizer = getattr(processor, "tokenizer", None)
if tokenizer is not None:
_maybe_patch_gemma4_gguf_tokenizer(
tokenizer = _maybe_patch_gemma4_gguf_tokenizer(
tokenizer,
model_config.model,
getattr(model_config.hf_config, "model_type", None),
)
processor.tokenizer = tokenizer
return processor


Expand Down
Loading