Skip to content

Added qwen3 vision language moe support for speculative decoding#32048

Merged
benchislett merged 35 commits intovllm-project:mainfrom
neuralmagic:qwen3-vl-moe-spec-update
Jan 21, 2026
Merged

Added qwen3 vision language moe support for speculative decoding#32048
benchislett merged 35 commits intovllm-project:mainfrom
neuralmagic:qwen3-vl-moe-spec-update

Conversation

@shanjiaz
Copy link
Contributor

@shanjiaz shanjiaz commented Jan 9, 2026

Purpose

To support qwen3 vision language moe speculator models.

Test Plan

Test command:

python examples/offline_inference/spec_decode.py   --method "eagle3"   --model-dir "Qwen/Qwen3-VL-30B-A3B-Instruct"   --eagle-dir "shanjiaz/qwen3-vl-5k-epoch9"   --dataset_name "hf" --dataset_path "philschmid/mt-bench" --num-spec-tokens 3

Test Result

Model is producing reasonable result. (It's a test model only trained on 5k examples so it's not very good.)

total_num_output_tokens: 213421
num_drafts: 154015
num_draft_tokens: 462045
num_accepted_tokens: 57143
mean acceptance length: 1.37
--------------------------------------------------
acceptance at token 0: 0.29
acceptance at token 1: 0.07
acceptance at token 2: 0.01

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added qwen Related to Qwen models speculative-decoding v1 labels Jan 9, 2026
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Qwen3 Vision-Language MoE models in speculative decoding. The changes correctly handle M-RoPE position adjustments for text-only draft models and add the necessary model-specific configurations. However, I found a critical bug in qwen3_vl_moe.py that would cause a TypeError at runtime due to an unhandled None value for the residual connection in the first layer. I've provided a code suggestion to fix this issue.

shanjiaz and others added 6 commits January 9, 2026 21:32
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
self.uses_mrope = self.vllm_config.model_config.uses_mrope
# Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal
self.uses_mrope = self.draft_model_config.uses_mrope
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fine to use as should support both multi-modal and text only draft models, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

# Convert M-RoPE positions to 1D if draft model is text-only
if not self.uses_mrope and target_positions.dim() == 2:
# For text inputs, all M-RoPE dimensions are identical
target_positions = target_positions[0]
Copy link
Contributor

@dsikka dsikka Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the first check be sufficient? Wont it always be 1D if no mrope? We can also just assert > 1 if the first condition is true?

Copy link
Contributor Author

@shanjiaz shanjiaz Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might need the second condition, since we only want to convert the 2D target_positions to 1D when draft does not use mrope but target does? If we only check self.uses_mrope, for non vl models that have 1D target_positions anyways this might error out? I can change this to something more specific like if not self.uses_mrope and self.vllm_config.model_config.uses_mrope?

@shanjiaz shanjiaz marked this pull request as ready for review January 14, 2026 17:36
Copy link
Collaborator

@benchislett benchislett left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

few quick questions

"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
id="qwen3-eagle3-speculator",
),
pytest.param(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this test run on nightly and/or part of the per-PR CI suite? I wonder if there might be issues adding more and more models to this test, or if it's not run frequently enough to drain resources.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this test is run as part of the per-PR CI suite. I can revert this change for now and open up a follow-up PR to move these models to an optional test suite similar to the setup of large lm-eval tests. Let me know if that's more appropriate!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will ask for opinions among committers and report back

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

30B is too large of a model size unfortunately. Should be good to merge as-is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can use a smaller model for testing

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would work

shanjiaz and others added 2 commits January 19, 2026 11:50
Copy link
Member

@yewentao256 yewentao256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2026-01-20T20:20:47Z] #58 651.6 error: Request failed after 3 retries

Unrelated CI failure, try merge from main

@benchislett benchislett merged commit 7ab80a8 into vllm-project:main Jan 21, 2026
53 checks passed
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…m-project#32048)

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
monajafi-amd pushed a commit to monajafi-amd/vllm that referenced this pull request Jan 23, 2026
…m-project#32048)

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 27, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026
…m-project#32048)

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: momochenchuw <chenchuw@huawei.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…m-project#32048)

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27df97c3eb79f891802fc0e858f8f7ac6a0)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16da1e423ede2c2f52a9850cbfbb39cefe96)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117ea2e689cd43df4be6892671a17cdae5833)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cceb877dfd13f98c538c4c96158047d98bd)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants