
[Bug] Fix AttributeError: 'ColumnParallelLinear' object has no attribute weight_scale_inv#30823

Merged
vllm-bot merged 1 commit into main from wentao-fix-attribute-error
Dec 17, 2025
Conversation

@yewentao256 yewentao256 (Member) commented Dec 16, 2025

Purpose

Fix AttributeError: 'ColumnParallelLinear' object has no attribute weight_scale_inv

export MODEL="RedHatAI/Kimi-K2-Thinking-FP8-Block"

vllm serve $MODEL -tp 8 --port 9256 --enable-expert-parallel --enforce_eager --trust_remote_code --gpu_memory_utilization 0.94 --max_model_len 4096

(Worker_TP0_EP0 pid=1089072) INFO 12-16 23:53:39 [default_loader.py:308] Loading weights took 36.23 seconds
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751] WorkerProc failed to start.
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751] Traceback (most recent call last):
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/v1/executor/multiproc_executor.py", line 722, in worker_main
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     worker = WorkerProc(*args, **kwargs)
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/v1/executor/multiproc_executor.py", line 562, in __init__
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     self.worker.load_model()
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/v1/worker/gpu_worker.py", line 289, in load_model
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/v1/worker/gpu_model_runner.py", line 3588, in load_model
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     self.model = model_loader.load_model(
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]                  ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/model_executor/model_loader/base_loader.py", line 56, in load_model
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     process_weights_after_loading(model, model_config, target_device)
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/model_executor/model_loader/utils.py", line 108, in process_weights_after_loading
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     quant_method.process_weights_after_loading(module)
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 896, in process_weights_after_loading
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     layer.scheme.process_weights_after_loading(layer)
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 176, in process_weights_after_loading
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     maybe_post_process_fp8_weight_block(layer)
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/vllm-source/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 1442, in maybe_post_process_fp8_weight_block
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     ws=layer.weight_scale_inv.data,
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]        ^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py", line 1964, in __getattr__
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751]     raise AttributeError(
(Worker_TP0_EP0 pid=1089072) ERROR 12-16 23:53:40 [multiproc_executor.py:751] AttributeError: 'ColumnParallelLinear' object has no attribute 'weight_scale_inv'. Did you mean: 'weight_scale'?

Now it can run successfully.

…eight_scale_inv'. Did you mean: 'weight_scale'

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively resolves an AttributeError that occurs during the loading of certain FP8 models. The issue stemmed from hardcoding the weight_scale_inv attribute, which is not always present. The fix introduces a dynamic check for weight_scale_inv and falls back to weight_scale if the former is not found. This approach is robust and correctly handles different naming conventions for weight scales. The change is well-implemented and directly addresses the bug. I approve this pull request.
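The fallback described above can be sketched as follows. This is an illustrative stand-alone example, not vLLM's actual code: `FakeLayer` and `get_weight_scale` are hypothetical names, and `FakeLayer` stands in for a `torch.nn.Module`, which likewise raises `AttributeError` when a parameter is missing.

```python
class FakeLayer:
    """Hypothetical stand-in for a torch.nn.Module holding quantization
    parameters; attribute lookup fails with AttributeError when a
    parameter was never registered, just like a real module."""

    def __init__(self, **params):
        self.__dict__.update(params)


def get_weight_scale(layer):
    """Resolve the weight scale dynamically.

    Block-quantized FP8 checkpoints may register the scale as
    `weight_scale_inv`, while other schemes register `weight_scale`.
    Preferring the former and falling back to the latter avoids the
    AttributeError seen in the traceback above.
    """
    scale = getattr(layer, "weight_scale_inv", None)
    if scale is None:
        scale = layer.weight_scale  # fall back to the per-tensor name
    return scale


print(get_weight_scale(FakeLayer(weight_scale_inv=0.5)))  # -> 0.5
print(get_weight_scale(FakeLayer(weight_scale=2.0)))      # -> 2.0
```

A layer exposing only `weight_scale` now resolves cleanly instead of raising, while layers with `weight_scale_inv` keep using it.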

@yewentao256 yewentao256 added the ready label (ONLY add when PR is ready to merge / full CI is needed) Dec 17, 2025
@vllm-bot vllm-bot merged commit f284d7b into main Dec 17, 2025
57 of 59 checks passed
@vllm-bot vllm-bot deleted the wentao-fix-attribute-error branch December 17, 2025 10:00
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Dec 17, 2025
…ute `weight_scale_inv` (vllm-project#30823)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…ute `weight_scale_inv` (vllm-project#30823)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…ute `weight_scale_inv` (vllm-project#30823)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
npanpaliya pushed a commit to odh-on-pz/vllm-cpu that referenced this pull request Feb 16, 2026
…ute `weight_scale_inv` (vllm-project/vllm#30823)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
npanpaliya pushed a commit to odh-on-pz/vllm-cpu that referenced this pull request Feb 16, 2026
- [Misc] Implement `TokenizerLike.convert_tokens_to_ids`
(vllm-project/vllm#31796)
  [INFERENG-4151](https://issues.redhat.com/browse/INFERENG-4151)
- [Bug] Revert torch warning fix (vllm-project/vllm#31585)
  [INFERENG-4152](https://issues.redhat.com/browse/INFERENG-4152)
- [Bug] Fix AttributeError: `ColumnParallelLinear` object has no
attribute `weight_scale_inv` (vllm-project/vllm#30823)
  [INFERENG-4153](https://issues.redhat.com/browse/INFERENG-4153)
- Avoid `opencv-python-headless==4.13.0.90`, it's broken. See
opencv/opencv-python#1183
- [Bugfix] Handle mistral tokenizer in get_hf_processor
(vllm-project/vllm#31817)
  [INFERENG-4151](https://issues.redhat.com/browse/INFERENG-4151)
- [Bugfix] Fix Whisper/encoder-decoder GPU memory leak
vllm-project/vllm#32789
- [Model] Handle `trust_remote_code` for transformers backend
(vllm-project/vllm#32194) (fixes
GHSA-2pc9-4j83-qjmr)
- [Bugfix] CUDA: fix segfault by bumping numba to `numba==0.63.1`
([AIPCC-9384](https://issues.redhat.com/browse/AIPCC-9384))
- [Bugfix] pin `mistral_common==1.8.5` to avoid crash with Voxtral
([INFERENG-4154](https://issues.redhat.com/browse/INFERENG-4154))
- [Bugfix] fix tokenizer loading for mistral models
(vllm-project/vllm#33175)
  [INFERENG-4151](https://issues.redhat.com/browse/INFERENG-4151)
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…ute `weight_scale_inv` (vllm-project#30823)

Signed-off-by: yewentao256 <zhyanwentao@126.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed


3 participants