Conversation

Documentation preview: https://vllm--33561.org.readthedocs.build/en/33561/

This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces support for MTP3, a multi-layer speculative decoding method, by adding the MultiLayerEagleProposer. The implementation includes custom Triton kernels for efficient input shifting and updates to the KV cache grouping logic. The changes are well-structured and include a comprehensive set of unit tests. I've identified one critical issue regarding the handling of 2D position tensors for M-RoPE, which would lead to a runtime error. A code suggestion has been provided to address this. Overall, this is a solid contribution that adds a powerful new feature.
```python
assert (
    cached_prev_positions[:, i].shape
    == draft_input_states.positions.shape
)
cached_prev_positions[:, i].copy_(draft_input_states.positions)
```
There's an issue with indexing `cached_prev_positions` when handling 2D position tensors (e.g., for M-RoPE). `cached_prev_positions` is a list of tensors, so `cached_prev_positions[:, i]` is invalid syntax and will cause a `TypeError`. The logic should iterate through the list to correctly copy the position data for each dimension:

```python
assert prev_positions.dim() == 2
for j in range(prev_positions.shape[0]):
    cached_prev_positions[j][i].copy_(draft_input_states.positions[j])
```
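Since this bug only surfaces at runtime with M-RoPE models, a plain-Python sketch (lists standing in for the list of per-dimension position tensors; all names and values below are illustrative, not from the PR) shows both the failure mode and the per-dimension loop fix:

```python
# `cached_prev_positions` is a *list* of per-dimension buffers, so
# NumPy-style 2D indexing raises TypeError at runtime.
cached_prev_positions = [[0, 0, 0, 0], [0, 0, 0, 0]]  # 2 M-RoPE dims, 4 slots

raised = False
try:
    cached_prev_positions[:, 1]  # list indices must be integers or slices, not a tuple
except TypeError:
    raised = True

# The suggested fix: iterate over the outer list and copy slot `i`
# for each position dimension separately.
new_positions = [7, 9]  # hypothetical per-dimension positions for slot 1
i = 1
for j in range(len(cached_prev_positions)):
    cached_prev_positions[j][i] = new_positions[j]
```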
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI, which runs only a small and essential subset of tests to quickly catch errors. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the `ready` label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
I have a few high-level concerns with this implementation:
Correctness details
There is a fundamental train/inference mismatch for MTP modules, so it is hard to define what the ideal "correct" behaviour should be. There is some room for interpretation. That said, this PR seems to overlook a detail which is central to a lot of discussion around multi-head MTP implementations:
When recomputing with only as many cached hidden states as rejected tokens, the input shapes are consistent, but the KV cache of the later MTP modules is not. Your example figure correctly highlights that when rejecting tokens and naively rolling, the KV cache of the later MTP modules is corrupted with bad tokens and hidden states. However, when rolling based on the number of accepted tokens, the later modules' KV caches are populated based on the hidden states of the previous MTP iterations, and not the hidden states from the base model. Consider the case where all 3 tokens are accepted, and therefore no adjusting is necessary: the input hidden states to MTP0 are "7,8,9,10" with the corresponding hidden states. Correspondingly, the inputs to MTP2 are "9,10,11,12". In this case, the KV cache for MTP2 for tokens "7,8" is built from the hidden states from MTP1 for those tokens, not the new target model hidden states from the verification of those tokens. This creates an inconsistency for MTP1 and MTP2 where their context states are derived from a mix of target and MTP output hidden states.
It is debatable whether this is even worth considering as long as the token ids are correct, since MTP are typically trained based on the inputs of the previous modules as context anyways, but nevertheless it is an aspect of correctness worth considering. The 'corrected' solution in this case would be to prepend the cached tokens/hiddens always, ensuring that MTP2 always updates its KV cache with the target model's hidden states. This does mean that the input shapes would increase and no longer be the same as the shapes from the target model. In this case, reusing some logic from Parallel Drafting (which similarly has to insert tokens into the batch for specdec) would be useful.
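As a toy illustration of the "corrected" prepend strategy described above (token ids follow the example in this thread; the variable names are mine, not the PR's):

```python
# Always prepend the cached, target-verified tokens so that MTP2 re-keys
# its KV cache from the target model's hidden states rather than MTP1's.
cached_tokens = [7, 8]                  # tokens whose target hidden states are cached
current_input_tokens = [9, 10, 11, 12]  # what MTP2 would otherwise see

drafter_input = cached_tokens + current_input_tokens
# The drafter input is now longer than the target model's batch for this
# request, which is why Parallel-Drafting-style insertion logic would help.
```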
Duplicated Code
The MultiLayerEagleProposer adds a lot of new code to maintain for EAGLE speculative decoding. I would hope that this can be implemented similar to Draft Models and Parallel Drafting where most of the EAGLE inference code is reused, for maintainability.
Caching state on the drafter
It is a major challenge to maintain consistent state in the GPU Model Runner. The input_batch class has a lot of utilities to ensure that under a dynamically changing batch, the state tensors remain consistent and input preparation is efficient. I am opposed to the style of decomposing the state into individual tensors and managing them in the Proposer class. I feel this will lead to a lot of subtle bugs when the batch is reordered in unexpected ways, and/or a lot of overhead needing to rebuild the batch from scratch every iteration. There are a lot of clone/copy/insert operations that look like they would cause a slowdown for large batch sizes. Have you measured any overheads of this approach compared to EAGLE-style MTP, across a range of batch sizes? Does the evaluation and drafting accuracy remain stable across batch sizes?
Underdocumented kernels
This PR introduces several custom Triton kernels for preparing the metadata. Being less readable than straight PyTorch code, they should be documented with comments explaining their purpose and inputs/outputs.
I am happy to discuss in more detail offline in the vLLM slack if desired. Feel free to reach out directly or in #feat-spec-decode.
```python
    vllm_config.speculative_config is not None
    and vllm_config.speculative_config.enable_multi_layers_mtp
):
    for i in range(0, len(layers), group_size):
```

Is this the right way to handle this? Can any vLLM KV-cache-manager experts weigh in here?
Thanks for your detailed feedback. I'll follow up by addressing these issues.

For Correctness details

Consequently, the positions of mtp0, mtp1, and mtp2 are recomputed, enabling mtp1 and mtp2 to reuse the correct state.

For Caching state on the drafter

I agree with your point here. I will try to move the cache-related logic into req_states, so that the drafter no longer needs to maintain this state explicitly.

Others

After addressing the caching state on the drafter, I will come back to further tidy up the code and add more documentation, in order to resolve the issues around duplicated code and underdocumented kernels.
I do not think this is possible as I have explained. Suppose all tokens are accepted, then we still want to recompute the states for them so that MTP2 sees the target model's hidden states for the (accepted) draft tokens. In this case, our MTP batch needs to include both the accepted tokens and the new drafted tokens, so that MTP2 can update the position for the tokens that were drafted by MTP0 and MTP1 in the previous pass. This is not illustrated in your figure because it only shows the all-rejected case, and not the all-accepted case. Please re-review my original response for more context. The issue is quite subtle.

I think this concern does not apply in the case where all speculative tokens are accepted. More concretely, although tokens "7,8" in the KV cache of MTP2 are populated using hidden states produced by earlier MTP stages, these hidden states ultimately originate from target-model–verified tokens "4,5". When all token ids along this trajectory are accepted by the target model, the entire generation trajectory is implicitly accepted as well. Consequently, the downstream hidden states derived along this trajectory (through MTP0 and MTP1) remain consistent and deterministic, and no incorrect execution path is introduced. As a result, the KV cache entries constructed along this trajectory are also correct and can be safely reused. This is indeed a very subtle issue.

@mingMelody I see what you mean. I think this is a consequence of a particular design decision around multi MTP to match the training style instead of the EAGLE inference style. In the existing MTP implementation in vLLM, we do EAGLE-style inference where the target model's hidden states are shared context for all draft positions, and we only use the hidden states from the later modules for the draft tokens. In such a case, with multi-mtp, we would need to refresh the hidden states for consistency as I have described. However for training-style multi MTP, all hidden states seen by MTP1 are outputs from MTP0. Is that correct? It seems that way from the code, but your figures do not distinguish between which model generated which hidden states. If this is the intended implementation, I think it should be fine. Feel free to ping me when you have addressed the other issues.
Signed-off-by: makubes <2416013822@qq.com>
Hi @mingMelody, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
The test results of MTP3 are shown below (based on H200). Decode TPS / request = output token throughput (tok/s) / maximum request concurrency.
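A quick sketch of how that per-request metric is computed (the numbers below are illustrative, not from the H200 benchmark):

```python
# Decode TPS per request = aggregate output token throughput divided by
# the maximum request concurrency of the benchmark run.
output_token_throughput = 2048.0  # tok/s, hypothetical aggregate
max_request_concurrency = 64      # hypothetical

decode_tps_per_request = output_token_throughput / max_request_concurrency
# → 32.0 tok/s per request
```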
@benchislett The previously mentioned issues have now been addressed. Feedback and suggestions would be very welcome.
when will it be merged? :)

Hi guys, is there any update on this? Will it be supported for Step 3.5 Flash?

Yes, it is already supported for Step 3.5 Flash.

Okay, when will it be merged?

Apologies for the delays. I will review this week. It will take a while to merge, as there is a lot of code, so it will require a lot of effort to review and eventually maintain.

Glad to hear it!

@benchislett How's the review going?

Sorry guys, I have been focusing on DFlash lately and have been spread pretty thin. Hopefully I can finish my review soon.
benchislett left a comment:

This is my first round of feedback. Overall, the code currently feels very opaque and a bit bloated. It has been challenging to parse some of the segments, and while I don't doubt the correctness, the complexity feels excessive given the scope.
```python
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
self.layer_num = 1
```
This needs a better name, or a comment explaining what `layer_num` means. Intuitively, I would assume it means "number of layers in each MTP/EAGLE module", but that seems incorrect.
```python
common_attn_metadata.seq_lens.sub_(shift)

# NOTE: ignore cpu data to avoid device sync
# common_attn_metadata.seq_lens_cpu.copy_(common_attn_metadata.seq_lens,
```
```python
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
    """
```

Doc comment not needed if this function raises an error anyways.
```python
    tokens (and newly sampled tokens). It also returns the token indices
    of the tokens that should be fed to the speculator.
    """
    raise Exception(
```

A generic exception is not recommended.
I suggest you ensure that when using multi layer eagle the padded drafter batch mode be automatically enabled as required (example). If so, you can change this to an assert
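A minimal sketch of that suggestion (the class and attribute names below are hypothetical, not vLLM's actual config API): auto-enable the required mode during config finalization, then downgrade the runtime check to an assert:

```python
class SpecConfig:
    """Hypothetical stand-in for a speculative-decoding config."""

    def __init__(self, enable_multi_layers_mtp: bool):
        self.enable_multi_layers_mtp = enable_multi_layers_mtp
        self.use_padded_drafter_batch = False


def finalize_config(cfg: SpecConfig) -> SpecConfig:
    # Auto-enable padded drafter batches instead of raising later at runtime.
    if cfg.enable_multi_layers_mtp:
        cfg.use_padded_drafter_batch = True
    return cfg


cfg = finalize_config(SpecConfig(enable_multi_layers_mtp=True))
assert cfg.use_padded_drafter_batch  # the drafter can now rely on this invariant
```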
```python
    self.model(**model_kwargs)


def _multi_layer_eagle_shift_and_cache(
```

Please add a detailed comment here. It is unclear what the responsibility of this function is.
```python
)
num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)

_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
```

This logic seems very involved. Can you explain why such a complicated implementation is necessary?
Given all the complexity here, a direct torch.compile'd implementation may be preferable
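For intuition, here is a plain-Python sketch of what a per-request shift-and-gather amounts to (lists in place of tensors; the function name and values are illustrative, not the PR's kernel). A torch.compile'd version would express the same per-row roll with vectorized index arithmetic:

```python
def shift_and_gather(cache, shifts):
    """Roll each request's cached window left by its own shift amount,
    dropping the rejected entries and zero-padding the tail."""
    out = []
    for row, s in zip(cache, shifts):
        out.append(row[s:] + [0] * s)  # roll left by s, zero-fill the tail
    return out


cache = [[7, 8, 9, 10], [4, 5, 6, 7]]
shifted = shift_and_gather(cache, [2, 1])
# → [[9, 10, 0, 0], [5, 6, 7, 0]]
```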
```python
if self.supports_mm_inputs:
    mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
draft_token_ids_list = []
for spec_step_idx in range(self.layer_num):
```

Please avoid implementing this feature in such a way that affects the readability of the codebase more broadly. The core flow of the EAGLE pathway should be preserved as much as possible.
```python
    target_positions,
    target_hidden_states,
    common_attn_metadata,
) = self.adjust_input(
```

What is the purpose of this function? Why can't the multi-layer component just specialize set_inputs_first_pass?
```python
pooling_states: PoolingStates | None = None

# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
```

I still don't feel great about having to cache all this state on the drafter. This feels like the wrong way to handle it.

In the meantime, can you add some documentation here about what these tensors represent, and how they are intended to be used for multi-layer EAGLE statefulness across iterations?

Purpose
To support MTP3, this PR introduces the MultiLayerEagleProposer class. The main challenges of MTP3 are illustrated in the figure below.

During multi-layer draft model execution, a naive just-roll strategy leads to incorrect KV cache states. In such scenarios, the affected tokens must be recomputed. However, direct recomputation depends on the hidden states produced in the previous iteration, which would require storing hidden states from (layer_num − 1) layers and handling a large number of corner cases.

To address this, this PR introduces an `adjust_input` function that performs input shifting before entering the MTP-layer inference. This proactively masks potential corner cases that could arise in future steps. As a result, the inference phase only needs to perform a single roll operation to proceed correctly. This approach applies one-time handling at boundary conditions and only caches the target model's hidden states, leading to a simpler overall design with minimal additional overhead.
Test Plan
Need to run with `enable_multi_layers_mtp` in `speculative_config` to enable MTP3. An example of `step3p5-flash` with MTP3 is below:

Test Result
acc tests
unit tests
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.