Conversation

@mgoin mgoin commented Jul 16, 2025

Purpose

FIX #20974 #20986

The issue was that the GPTQ caller was setting zero_points=False; however, since GPTQ always allocates its qzeros parameter, that tensor still got loaded into the Machete kernel. We just need to filter out that case to fix it.
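A minimal sketch of the filtering logic described above (the function name and signature here are hypothetical, not the exact vLLM diff): when zero points are disabled, the unconditionally allocated qzeros tensor should be dropped so the Machete kernel receives None instead of a dummy tensor.

```python
from typing import Optional

import torch


def filter_qzeros(qzeros: Optional[torch.Tensor],
                  zero_points: bool) -> Optional[torch.Tensor]:
    """Return qzeros only when zero points are actually in use.

    GPTQ allocates a qzeros parameter even with zero_points=False, so a
    non-None tensor alone does not mean zero points should be applied.
    """
    if not zero_points:
        # Ignore the dummy tensor; pass None to the kernel instead.
        return None
    return qzeros


# Example: a dummy qzeros tensor with zero_points=False is filtered out.
dummy_qzeros = torch.zeros(4, 8, dtype=torch.int32)
assert filter_qzeros(dummy_qzeros, zero_points=False) is None
assert filter_qzeros(dummy_qzeros, zero_points=True) is dummy_qzeros
```

This mirrors the shape of the fix: the decision is driven by the caller's zero_points flag rather than by whether a qzeros tensor happens to exist.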

Test Plan

Manual testing of the issue on H100

Test Result

Before

lm_eval --model vllm --model_args pretrained=Qwen/Qwen3-30B-A3B-GPTQ-Int4 --tasks gsm8k --num_fewshot 5 --batch_size auto
...
from user code:
   File "/home/mgoin/code/vllm/vllm/model_executor/models/qwen3_moe.py", line 369, in forward
    hidden_states, residual = layer(positions, hidden_states, residual)
  File "/home/mgoin/code/vllm/vllm/model_executor/models/qwen3_moe.py", line 305, in forward
    hidden_states = self.self_attn(
  File "/home/mgoin/code/vllm/vllm/model_executor/models/qwen3_moe.py", line 223, in forward
    qkv, _ = self.qkv_proj(hidden_states)
  File "/home/mgoin/code/vllm/vllm/model_executor/layers/linear.py", line 510, in forward
    output_parallel = self.quant_method.apply(self, input_, bias)
  File "/home/mgoin/code/vllm/vllm/model_executor/layers/quantization/gptq_marlin.py", line 372, in apply
    return self.kernel.apply_weights(layer, x, bias)
  File "/home/mgoin/code/vllm/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py", line 129, in apply_weights
    output = ops.machete_mm(a=x_2d,
  File "/home/mgoin/code/vllm/vllm/_custom_ops.py", line 1099, in machete_mm
    return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

After

lm_eval --model vllm --model_args pretrained=Qwen/Qwen3-30B-A3B-GPTQ-Int4 --tasks gsm8k --num_fewshot 5 --batch_size auto
...
vllm (pretrained=Qwen/Qwen3-30B-A3B-GPTQ-Int4), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7915|±  |0.0112|
|     |       |strict-match    |     5|exact_match|↑  |0.8901|±  |0.0086|

@github-actions
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mgoin mgoin added the bug Something isn't working label Jul 16, 2025
@mgoin mgoin added this to the v0.10.0 milestone Jul 16, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly fixes a crash for GPTQ models using the Machete kernel when an unexpected zero-point tensor is present. My review includes a suggestion to enhance robustness by adding a check for cases where zero points are expected but missing, which could otherwise lead to silent correctness issues.

@mgoin mgoin changed the title [Bugfix] Fix Machete zero point issue for GPTQ models [Bugfix] Fix Machete zero point issue for GPTQ models on SM90 Jul 16, 2025
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 16, 2025
@LucasWilkinson LucasWilkinson left a comment

LGTM! Thanks for the fix

@vllm-bot vllm-bot merged commit 28a6d54 into vllm-project:main Jul 17, 2025
76 of 80 checks passed
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Labels

bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Bug/Regression in handling GPTQ-Int4 models

3 participants