
[v0.13.0][Bugfix] Support ALL D-Nodes in fullgraph when running MTP in PD#5786

Merged
wangxiyuan merged 1 commit into vllm-project:releases/v0.13.0 from dragondream-chen:v0.13.0/bugfix_d_mtp
Jan 13, 2026

Conversation

@dragondream-chen
Collaborator

@dragondream-chen dragondream-chen commented Jan 12, 2026

What this PR does / why we need it?

BUG Problem
When using prefill-decode disaggregation + MTP + full graph + asynchronous scheduling, the KV cache that decode nodes pull from prefill nodes does not include spec tokens. As a result, the `total_num_scheduled_tokens` that decode nodes obtain from the scheduler is missing the spec tokens. When deciding whether to run the full graph on a decode node, the uniform_decode condition `scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs * max_query_len` is therefore not met, and the current instance is not enqueued into the full graph.

This leaves full-graph and eager-mode instances coexisting among the decode instances. Because of the synchronization wait in MoeDispatch, the decode instances running the full graph are significantly slowed down by the instances running in eager mode.
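For intuition, the uniform-decode gate described above can be sketched in a few lines of Python. This is an illustrative sketch, not the vllm-ascend source; the helper name and the concrete numbers are assumptions:

```python
# Illustrative sketch of the uniform-decode gate that controls full-graph
# capture; names mirror the quoted condition, but this is not the actual
# vllm-ascend implementation.
def is_uniform_decode(total_num_scheduled_tokens: int,
                      num_reqs: int,
                      max_query_len: int) -> bool:
    # The full graph is only used when every request schedules exactly
    # max_query_len tokens (1 verified token + the MTP spec tokens).
    return total_num_scheduled_tokens == num_reqs * max_query_len

num_reqs = 4
num_spec_tokens = 1                  # MTP drafts one token per request
max_query_len = 1 + num_spec_tokens

# Healthy decode step: every request carries its spec token.
print(is_uniform_decode(num_reqs * max_query_len, num_reqs, max_query_len))

# One request's KV cache arrived from a prefill node without spec tokens,
# so the batch is short by num_spec_tokens and the gate fails for everyone.
print(is_uniform_decode(num_reqs * max_query_len - num_spec_tokens,
                        num_reqs, max_query_len))
```

A single short request is enough to knock the whole batch out of the full graph, which is why the slowdown affects every decode instance.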

Solution
The scenario is PD disaggregation + MTP + full graph + asynchronous scheduling.
On the decode nodes, the spec tokens of requests whose KV cache was pulled from P nodes need to be padded. The padded spec tokens are then rejected by sampling. This ensures the uniform_decode condition is satisfied when deciding whether decode nodes enter the full graph, so all decode instances run in the full graph and the synchronous wait on MoeDispatch is avoided.
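A minimal sketch of the padding idea described above. The `Request` class, the sentinel token value, and the helper name are hypothetical, not the PR's actual API:

```python
from dataclasses import dataclass, field

# Hypothetical dummy draft token; in the real flow the padded drafts are
# discarded by rejection sampling, so correctness is unaffected.
PLACEHOLDER_TOKEN = 0

@dataclass
class Request:
    req_id: str
    spec_token_ids: list[int] = field(default_factory=list)
    kv_from_prefill: bool = False  # KV cache pulled from a P node

def pad_spec_tokens(reqs: list[Request], num_spec_tokens: int) -> None:
    """Pad dummy spec tokens so every request schedules the same number
    of tokens and the uniform_decode condition holds for the batch."""
    for req in reqs:
        if req.kv_from_prefill and len(req.spec_token_ids) < num_spec_tokens:
            missing = num_spec_tokens - len(req.spec_token_ids)
            req.spec_token_ids.extend([PLACEHOLDER_TOKEN] * missing)

# A locally decoded request already has its draft; the pulled one does not.
reqs = [Request("local", spec_token_ids=[123]),
        Request("pulled", kv_from_prefill=True)]
pad_spec_tokens(reqs, num_spec_tokens=1)
assert all(len(r.spec_token_ids) == 1 for r in reqs)
```

Only the graph-capture decision changes: the padded drafts never survive sampling, so the generated text is identical to the unpadded path.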

same with #5472

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in the MTP speculative decoding implementation with prefill-decode disaggregation, where decode nodes would fall back to eager execution, causing performance degradation. The fix involves padding speculative tokens on decode nodes to ensure they are correctly enqueued into the full graph.

The changes in the scheduler (RecomputeScheduler) appear correct and effectively implement the core of the solution by preparing the requests with placeholder speculative tokens.

However, I've identified a critical issue in the GPUModelRunner patch. While it correctly updates prev_num_draft_len for new requests, it fails to update input_batch.num_tokens. This inconsistency will cause the is_uniform_decode check to fail, preventing the model from entering full graph mode, which defeats the purpose of this bugfix. I have provided a detailed comment and a code suggestion to resolve this.

Comment on lines +277 to +282
if self.is_kv_consumer and self.speculative_config and \
self.speculative_config.method == "mtp" and self.use_async_scheduling:
req_state = self.requests[request.req_id]
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
request.req_id, [])
req_state.prev_num_draft_len = len(spec_token_ids)
Contributor


critical

This patch correctly sets prev_num_draft_len for new and resumed requests, which is necessary for the rejection sampling logic in the next step. However, it misses a critical part of the logic for handling speculative tokens.

For the uniform_decode condition to be met and to enable full graph mode, self.input_batch.num_tokens must be consistent across all requests in the batch. The current implementation updates num_tokens for running requests to include speculative tokens, but it does not do so for new or resumed requests (those in reqs_to_add). This will cause self.input_batch.is_uniform_decode to be False, preventing the use of the full graph, which is the main goal of this pull request.

To fix this, the logic for handling speculative tokens (updating token_ids_cpu, num_tokens, and spec_token_ids in the input_batch) that exists for running requests should also be applied here for new/resumed requests.

        if self.is_kv_consumer and self.speculative_config and \
            self.speculative_config.method == "mtp" and self.use_async_scheduling:
            req_state = self.requests[request.req_id]
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                request.req_id, [])
            num_spec_tokens = len(spec_token_ids)
            req_state.prev_num_draft_len = num_spec_tokens

            if num_spec_tokens > 0:
                req_index = self.input_batch.req_id_to_index.get(request.req_id)
                # This should not be None as we just added it.
                if req_index is not None:
                    # Update input_batch with spec tokens for new/resumed requests.
                    # This is crucial for is_uniform_decode to be true.
                    start_index = self.input_batch.num_tokens_no_spec[req_index]
                    end_token_index = start_index + num_spec_tokens
                    self.input_batch.token_ids_cpu[
                        req_index, start_index:end_token_index] = spec_token_ids
                    self.input_batch.num_tokens[req_index] += num_spec_tokens

                    # This part is also in the logic for running requests.
                    self.input_batch.spec_token_ids[req_index].clear()
                    self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)

@wangxiyuan
Collaborator

Please rebase to enable the CI test. Thanks.

@wangxiyuan wangxiyuan changed the title [Bugfix] Support ALL D-Nodes in fullgraph when running MTP in PD for v0.13.0 [v0.13.0][Bugfix] Support ALL D-Nodes in fullgraph when running MTP in PD Jan 12, 2026
…-decode disaggregation

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
@wangxiyuan wangxiyuan added the `ready` (read for review) and `ready-for-test` (start test by label for PR) labels Jan 12, 2026
@wangxiyuan wangxiyuan merged commit 895d32c into vllm-project:releases/v0.13.0 Jan 13, 2026
20 checks passed
wangxiyuan pushed a commit that referenced this pull request Jan 30, 2026
### What this PR does / why we need it?
This PR extends #5786 to eagle3 spec decode when used with
pd-disaggregation + async-scheduling.

Signed-off-by: Angazenn <supperccell@163.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
…n PD (vllm-project#5786)

(Commit message repeats this PR's description verbatim.)

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
…6443)

Signed-off-by: Angazenn <supperccell@163.com>
SkychenLee pushed a commit to SkychenLee/vllm-ascend that referenced this pull request Jan 31, 2026
…6443)

Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: l00832868 <litianchen2@huawei.com>
wangxiyuan pushed a commit that referenced this pull request Feb 11, 2026
…fect (#6615)

### What this PR does / why we need it?
This PR corrects the patch from #5786; otherwise it might not take effect when tp_size > 1.

Related changes in this patch has been merged in vLLM
[#31944](vllm-project/vllm#31944).

### Does this PR introduce _any_ user-facing change?
No.

Signed-off-by: Angazenn <supperccell@163.com>
