
[Feature] support mtp3#33561

Open
mingMelody wants to merge 11 commits into vllm-project:main from mingMelody:feat/support-mtp3

Conversation

@mingMelody

@mingMelody mingMelody commented Feb 2, 2026

Purpose

To support MTP3, this PR introduces the MultiLayerEagleProposer class. The main challenges of MTP3 are illustrated in the figure below.

*(figure: main challenges of MTP3)*

During multi-layer draft model execution, a naive just-roll strategy leads to incorrect KV cache states. In such scenarios, the affected tokens must be recomputed. However, direct recomputation depends on the hidden states produced in the previous iteration, which would require storing hidden states from (layer_num − 1) layers and handling a large number of corner cases.

To address this, this PR introduces an adjust_input function that performs input shifting before entering the MTP-layer inference. This proactively masks potential corner cases that could arise in future steps. As a result, the inference phase only needs to perform a single roll operation to proceed correctly.

This approach applies one-time handling at boundary conditions and only caches the target model’s hidden states, leading to a simpler overall design with minimal additional overhead.
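The bound on how far a sequence can be shifted can be sketched as follows. This is a hypothetical illustration of the idea, not the PR's actual code; the function name and parameters are invented here, mirroring the unit-test case names later in this PR (e.g. `shift_bounded_by_start_pos`, `shift_2_bounded_by_remaining`).

```python
# Hypothetical sketch: before MTP-layer inference, adjust_input shifts each
# sequence's window so that later layers only ever need a single roll.
def compute_shift(num_rejected: int, cached_len: int, start_pos: int) -> int:
    # We can only shift back over tokens whose hidden states are still cached,
    # and never shift past position 0 of the sequence.
    return min(num_rejected, cached_len, start_pos)
```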

Test Plan

MTP3 is enabled by setting enable_multi_layers_mtp in speculative_config.

An example of serving step3p5-flash with MTP3:

```shell
vllm serve <MODEL_PATH_OR_HF_ID> \
  --served-model-name step3p5-flash \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --disable-cascade-attn \
  --reasoning-parser step3p5 \
  --enable-auto-tool-choice \
  --tool-call-parser step3p5 \
  --speculative_config '{"method": "step3p5_mtp", "num_speculative_tokens": 3, "enable_multi_layers_mtp": true}' \
  --trust-remote-code
```

Test Result

Acceptance tests

```shell
# ACC Test
python examples/offline_inference/spec_decode.py --num-prompts 80 --dataset-name hf --dataset-path philschmid/mt-bench --method mtp --model-dir <MODEL_PATH_OR_HF_ID> --num-spec-tokens 3 --tp 8 --enable-multi-layers-mtp
```

```text
# Test Result
--------------------------------------------------
total_num_output_tokens: 20207
num_drafts: 7176
num_draft_tokens: 21528
num_accepted_tokens: 13042
mean acceptance length: 2.82
--------------------------------------------------
acceptance at token 0: 0.80
acceptance at token 1: 0.59
acceptance at token 2: 0.42
```
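As a sanity check, the reported mean acceptance length is consistent with the raw counters above, assuming the conventional definition (one bonus token per draft plus the accepted draft tokens):

```python
num_drafts = 7176
num_draft_tokens = 21528
num_accepted_tokens = 13042

# Each draft proposes num-spec-tokens tokens: 21528 / 7176 = 3.
assert num_draft_tokens // num_drafts == 3

# Mean acceptance length = 1 bonus token + accepted draft tokens per draft.
mean_acceptance_length = 1 + num_accepted_tokens / num_drafts
print(f"{mean_acceptance_length:.2f}")  # 2.82
```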

Unit tests

```shell
# Run unit tests
pytest tests/v1/spec_decode/test_mtp3.py -v
```

```text
============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-8.3.5, pluggy-1.6.0 -- /home/i-zhangmingming/open_source/vllm/.venv/bin/python
cachedir: .pytest_cache
rootdir: /home/i-zhangmingming/open_source/vllm
configfile: pyproject.toml
plugins: anyio-4.12.1
collecting ... collected 12 items

tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[shift_0_at_sequence_end] PASSED [  8%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[batch2_short_seq_no_shift] PASSED [ 16%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[batch2_short_seq_shift_on_first] PASSED [ 25%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[short_seq_len_2_shift_0_cache_len_1] PASSED [ 33%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[short_seq_len_2_shift_1_cache_len_2] PASSED [ 41%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[shift_bounded_by_start_pos_zero] PASSED [ 50%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[shift_bounded_by_start_pos] PASSED [ 58%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[shift_2_bounded_by_remaining] PASSED [ 66%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[shift_3_full_cache_window] PASSED [ 75%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[batch2_shift_1_and_1] PASSED [ 83%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[batch4_mixed_shifts] PASSED [ 91%]
tests/v1/spec_decode/test_mtp3.py::test_adjust_input_layer3_cases[batch2_shift_0_and_2] PASSED [100%]
======================= 12 passed, 2 warnings in 11.45s ========================
```

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify

mergify bot commented Feb 2, 2026

Documentation preview: https://vllm--33561.org.readthedocs.build/en/33561/

@mergify mergify bot added documentation Improvements or additions to documentation speculative-decoding v1 labels Feb 2, 2026
@mergify

mergify bot commented Feb 2, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mingMelody.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 2, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for MTP3, a multi-layer speculative decoding method, by adding the MultiLayerEagleProposer. The implementation includes custom Triton kernels for efficient input shifting and updates to the KV cache grouping logic. The changes are well-structured and include a comprehensive set of unit tests. I've identified one critical issue regarding the handling of 2D position tensors for M-RoPE, which would lead to a runtime error. A code suggestion has been provided to address this. Overall, this is a solid contribution that adds a powerful new feature.

Comment on lines +175 to +179
```python
assert (
    cached_prev_positions[:, i].shape
    == draft_input_states.positions.shape
)
cached_prev_positions[:, i].copy_(draft_input_states.positions)
```

critical

There's an issue with indexing cached_prev_positions when handling 2D position tensors (e.g., for M-RoPE). cached_prev_positions is a list of tensors, so cached_prev_positions[:, i] is invalid syntax and will cause a TypeError. The logic should iterate through the list to correctly copy the position data for each dimension.

```python
assert prev_positions.dim() == 2
for j in range(prev_positions.shape[0]):
    cached_prev_positions[j][i].copy_(draft_input_states.positions[j])
```
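The failure mode the reviewer describes can be reproduced with a minimal stand-in (plain lists here instead of tensors): a Python list does not support NumPy/PyTorch-style 2D indexing, so `cached_prev_positions[:, i]` raises a `TypeError` at runtime.

```python
# cached_prev_positions is a Python *list* of tensors; lists reject
# tuple-style indexing such as [:, i].
cached_prev_positions = [[10, 11], [20, 21]]  # stand-in for a list of tensors
try:
    cached_prev_positions[:, 0]
    raised = False
except TypeError:
    raised = True
assert raised  # TypeError: list indices must be integers or slices, not tuple

# The suggested fix iterates over the outer (list) dimension instead:
for j in range(len(cached_prev_positions)):
    cached_prev_positions[j][0] = 99
assert [row[0] for row in cached_prev_positions] == [99, 99]
```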

@github-actions

github-actions bot commented Feb 2, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Collaborator

@benchislett benchislett left a comment


I have a few high-level concerns with this implementation:

Correctness details

There is a fundamental train/inference mismatch for MTP modules, so it is hard to define what the ideal "correct" behaviour should be. There is some room for interpretation. That said, this PR seems to overlook a detail which is central to a lot of discussion around multi-head MTP implementations:

When recomputing with only as many cached hidden states as rejected tokens, the input shapes are consistent but the KV cache of the later MTP modules is not consistent. Your example figure correctly highlights that when rejecting tokens and naively rolling, the KV cache of the later MTP modules are corrupted with bad tokens and hidden states. However, when rolling based on the number of accepted tokens, the later modules' KV caches are populated based on the hidden states of the previous MTP iterations, and not the hidden states from the base model. Consider the case where all 3 tokens are accepted, and therefore no adjusting is necessary: the input hidden states to MTP0 are "7,8,9,10" with the corresponding hidden states. Correspondingly the inputs to MTP2 are "9,10,11,12". In this case the KV cache for MTP2 for tokens "7,8" are the hidden states from MTP1 for those tokens, not the new target model hidden states from the verification of those tokens. This creates an inconsistency for MTP1 and MTP2 where their context states are derived from a mix of target and MTP output hidden states.

It is debatable whether this is even worth considering as long as the token ids are correct, since MTP are typically trained based on the inputs of the previous modules as context anyways, but nevertheless it is an aspect of correctness worth considering. The 'corrected' solution in this case would be to prepend the cached tokens/hiddens always, ensuring that MTP2 always updates its KV cache with the target model's hidden states. This does mean that the input shapes would increase and no longer be the same as the shapes from the target model. In this case, reusing some logic from Parallel Drafting (which similarly has to insert tokens into the batch for specdec) would be useful.

Duplicated Code

The MultiLayerEagleProposer adds a lot of new code to maintain for EAGLE speculative decoding. I would hope that this can be implemented similar to Draft Models and Parallel Drafting where most of the EAGLE inference code is reused, for maintainability.

Caching state on the drafter

It is a major challenge to maintain consistent state in the GPU Model Runner. The input_batch class has a lot of utilities to ensure that under a dynamically changing batch, the state tensors remain consistent and input preparation is efficient. I am opposed to the style of decomposing the state into individual tensors and managing them in the Proposer class. I feel this will lead to a lot of subtle bugs when the batch is reordered in unexpected ways, and/or a lot of overhead needing to rebuild the batch from scratch every iteration. There are a lot of clone/copy/insert operations that look like they would cause a slowdown for large batch sizes. Have you measured any overheads of this approach compared to EAGLE-style MTP, across a range of batch sizes? Does the evaluation and drafting accuracy remain stable across batch sizes?

Underdocumented kernels

This PR introduces several custom Triton kernels for preparing the metadata. Being less readable than straight PyTorch code, they should be documented with comments explaining their purpose and inputs/outputs.

I am happy to discuss in more detail offline in the vLLM slack if desired. Feel free to reach out directly or in #feat-spec-decode.

```python
vllm_config.speculative_config is not None
and vllm_config.speculative_config.enable_multi_layers_mtp
):
    for i in range(0, len(layers), group_size):
```

Is this the right way to handle this? Can any vLLM KV-cache-manager experts weigh in here?

@mingMelody
Author

> (quoting @benchislett's review above in full)

Thanks for your detailed feedback. I’ll follow up by addressing these issues.

For Correctness details

The main logic of the adjust_input function is illustrated in the figure below.

*(figure: adjust_input input-shifting logic)*

In adjust_input, we update positions, hidden_states, token_ids, slot_mappings, and related fields to ensure that all uncertain positions are recomputed in every scenario, without altering the actual input length.

Consequently, the positions of mtp0, mtp1, and mtp2 are recomputed, enabling mtp1 and mtp2 to reuse the correct hidden_states from earlier layers.

For Caching state on the drafter

I agree with your point here. I will try to move the cache-related logic into req_states, so that the drafter no longer needs to maintain this state explicitly.

Others

After addressing the caching state on the drafter, I will come back to further tidy up the code and add more documentation, in order to resolve the issues around duplicated code and underdocumented kernels.

@benchislett
Collaborator

@mingMelody

> all uncertain positions are recomputed in every scenario, without altering the actual input length

I do not think this is possible as I have explained. Suppose all tokens are accepted, then we still want to recompute the states for them so that MTP2 sees the target model's hidden states for the (accepted) draft tokens. In this case, our MTP batch needs to include both the accepted tokens and the new drafted tokens, so that MTP2 can update the position for the tokens that were drafted by MTP0 and MTP1 in the previous pass.

This is not illustrated in your figure because it only shows the all-rejected case, and not the all-accepted case. Please re-review my original response for more context. The issue is quite subtle.

@mingMelody
Author

> Consider the case where all 3 tokens are accepted, and therefore no adjusting is necessary: the input hidden states to MTP0 are "7,8,9,10" with the corresponding hidden states. Correspondingly the inputs to MTP2 are "9,10,11,12". In this case the KV cache for MTP2 for tokens "7,8" are the hidden states from MTP1 for those tokens, not the new target model hidden states from the verification of those tokens. This creates an inconsistency for MTP1 and MTP2 where their context states are derived from a mix of target and MTP output hidden states.

I think this concern does not apply in the case where all speculative tokens are accepted.

More concretely, although tokens “7,8” in the KV cache of MTP2 are populated using hidden states produced by earlier MTP stages, these hidden states ultimately originate from target-model–verified tokens “4,5”. When all token ids along this trajectory are accepted by the target model, the entire generation trajectory is implicitly accepted as well. Consequently, the downstream hidden states derived along this trajectory (through MTP0 and MTP1) remain consistent and deterministic, and no incorrect execution path is introduced. As a result, the KV cache entries constructed along this trajectory are also correct and can be safely reused.

This is indeed a very subtle issue.

@benchislett
Collaborator

@mingMelody I see what you mean.

I think this is a consequence of a particular design decision around multi MTP to match the training style instead of the EAGLE inference style. In the existing MTP implementation in vLLM, we do EAGLE-style inference where the target model's hidden states are shared context for all draft positions, and we only use the hidden states from the later modules for the draft tokens. In such a case, with multi-mtp, we would need to refresh the hidden states for consistency as I have described. However for training-style multi MTP, all hidden states seen by MTP1 are outputs from MTP0. Is that correct? It seems that way from the code, but your figures do not distinguish between which model generated which hidden states.

If this is the intended implementation, I think it should be fine. Feel free to ping me when you have addressed the other issues.

Signed-off-by: makubes <2416013822@qq.com>
@mergify

mergify bot commented Feb 10, 2026

Hi @mingMelody, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```shell
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

Signed-off-by: makubes <2416013822@qq.com>
@mingMelody
Author

mingMelody commented Feb 10, 2026

The test results of MTP3 are shown below (based on an H200).

```shell
# serve
vllm serve <MODEL_PATH_OR_HF_ID> \
  --served-model-name step3p5-flash \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --disable-cascade-attn \
  --reasoning-parser step3p5 \
  --enable-auto-tool-choice \
  --tool-call-parser step3p5 \
  --speculative_config '{"method": "step3p5_mtp", "num_speculative_tokens": 3, "enable_multi_layers_mtp": true}' \
  --trust-remote-code
```

```shell
# bench test
vllm bench serve --backend vllm --model <MODEL_PATH_OR_HF_ID> --served-model-name step3p5-flash --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 500 --max-concurrency $BATCH_SIZE
```

Decode Tps / request = Output token throughput (tok/s) / Maximum request concurrency

| Batch | mtp3 Decode Tps / request | mtp3 Avg Acc Rate | mtp1 draft 1 Decode Tps / request | mtp1 draft 1 Avg Acc Rate | mtp1 draft 3 Decode Tps / request | mtp1 draft 3 Avg Acc Rate |
|---|---|---|---|---|---|---|
| 1 | 263.54 | 53.97 | 192.54 | 76.88 | 232.78 | 42.67 |
| 8 | 165.43 | 54.03 | 132.54 | 77.21 | 144.33 | 42.83 |
| 16 | 121.43 | 54.13 | 106.19 | 77.42 | 104.68 | 42.96 |
| 32 | 95.30 | 54.11 | 76.86 | 77.36 | 82.22 | 42.99 |
| 48 | 84.75 | 54.19 | 66.15 | 77.56 | 72.88 | 42.53 |
| 64 | 78.66 | 54.26 | 60.92 | 77.21 | 67.73 | 42.76 |
| 96 | 66.96 | 53.98 | 53.35 | 76.83 | 58.37 | 42.56 |
| 128 | 39.85 | 54.32 | 49.92 | 77.26 | 35.89 | 42.70 |
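The derived metric in the table follows directly from the definition stated above it; a trivial helper (illustrative only, not part of `vllm bench`) makes the computation explicit:

```python
def decode_tps_per_request(output_token_throughput: float, max_concurrency: int) -> float:
    # "Decode Tps / request" = aggregate output token throughput (tok/s)
    # divided by the maximum request concurrency of the benchmark run.
    return output_token_throughput / max_concurrency

# e.g. an aggregate 640 tok/s at concurrency 8 gives 80 tok/s per request
assert decode_tps_per_request(640.0, 8) == 80.0
```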

@benchislett Previously mentioned issues have now been addressed. Feedback and suggestions would be very welcome.

@MaoJianwei
Contributor

when will it be merged? :)

@javilima01

Hi guys, is there any update on this? Will it be supported for Step 3.5 Flash?

@mingMelody
Author

> Hi guys, is there any update on this? Will it be supported for Step 3.5 Flash?

Yes, it is already supported for Step 3.5 Flash.

@javilima01

Okay, when will it be merged ?

@benchislett
Collaborator

Apologies for the delays. I will review this week.

It will take a while to merge: there is a lot of code, so it will require significant effort to review and eventually maintain.

@javilima01

Glad to hear it!
Much thanks for all of the effort!!!

@csy0225
Contributor

csy0225 commented Mar 12, 2026

@benchislett How's the review going?

@benchislett
Collaborator

Sorry guys, I have been focusing on DFlash lately and have been spread pretty thin. Hopefully I can finish my review soon.

Collaborator

@benchislett benchislett left a comment


This is my first round of feedback. Overall, the code currently feels very opaque and a bit bloated. It has been challenging to parse some of the segments, and while I don't doubt the correctness, the complexity feels excessive given the scope.

```python
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
self.layer_num = 1
```

This needs a better name, or a comment explaining what "layer_num" means. Intuitively, I would assume it means "number of layers in each MTP/EAGLE module", but that seems incorrect.

```python
common_attn_metadata.seq_lens.sub_(shift)

# NOTE: ignore cpu data to avoid device sync
# common_attn_metadata.seq_lens_cpu.copy_(common_attn_metadata.seq_lens,
```

Is this dead code?

```python
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
    """
```

Doc comment not needed if this function raises an error anyways.

```python
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
raise Exception(
```

Generic exception is not recommended.

I suggest you ensure that when using multi layer eagle the padded drafter batch mode be automatically enabled as required (example). If so, you can change this to an assert

```python
self.model(**model_kwargs)


def _multi_layer_eagle_shift_and_cache(
```

Please add a detailed comment here. It is unclear what the responsibility of this function is

```python
)
num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)

_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
```

This logic seems very involved. Can you explain why such a complicated implementation is necessary?


Given all the complexity here, a direct torch.compile'd implementation may be preferable

```python
if self.supports_mm_inputs:
    mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
draft_token_ids_list = []
for spec_step_idx in range(self.layer_num):
```

@benchislett benchislett Mar 12, 2026


Please avoid implementing this feature in such a way that affects the readability of the codebase more broadly. The core flow of the EAGLE pathway should be preserved as much as possible.

```python
target_positions,
target_hidden_states,
common_attn_metadata,
) = self.adjust_input(
```

What is the purpose of this function? Why can't the multi layer component just specialize set_inputs_first_pass?

```python
pooling_states: PoolingStates | None = None

# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
```

I still don't feel great about having to cache all this state on the drafter. This feels like the wrong way to handle it


In the meantime, can you add some documentation here about what these tensors represent, and how they are intended to be used for multi-layer EAGLE statefulness across iterations?
