[V1] Support Deepseek MTP #18435
Merged
Commits (21 total; the diff below shows changes from 10 commits):
70e688c  [V1] MTP Support Prototype (ruisearch42)
4cf308b  fix MTP tp (YaoJiayi)
c3a32cd  resolve conflicts (YaoJiayi)
e2f065c  fix bugs (YaoJiayi)
fce4d58  fix pp (YaoJiayi)
fade742  fix format (YaoJiayi)
0e6ed11  fix format (YaoJiayi)
a79b116  revert changes (YaoJiayi)
ff6a1b3  fix format (YaoJiayi)
ca0d63d  fix unit test (YaoJiayi)
3699a98  address minor comments (YaoJiayi)
c505dd7  add cudagraph compatibility (YaoJiayi)
8b8c8ba  unify eagle and mtp (YaoJiayi)
5bc04ee  Merge branch 'main' into localdev/v1-mtp (YaoJiayi)
5d02965  fix minor bug and remove mtp_proposer (YaoJiayi)
56e08d7  add blank line (YaoJiayi)
a76facb  fix format (YaoJiayi)
1271a57  fix minor issues (YaoJiayi)
485332a  fix warning (YaoJiayi)
e02f26c  fix unit test (YaoJiayi)
a9c890c  format fix (YaoJiayi)
New file added by this PR (@@ -0,0 +1,215 @@):
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch
import torch.nn as nn

from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
                         set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
    process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_input_kernel


# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs

class MtpProposer:

    def __init__(
        self,
        vllm_config: VllmConfig,
        runner,
    ):
        self.vllm_config = vllm_config
        self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens)
        self.block_size = vllm_config.cache_config.block_size
        self.runner = runner

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        prepare_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        input_ids = torch.empty_like(target_token_ids)
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        input_ids[:-1] = target_token_ids[1:]
        # Replace the last token of each request with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        input_ids[last_token_indices] = next_token_ids

        query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
        max_query_len = query_lens.max().item()

        seq_lens = (target_positions[last_token_indices] + 1)

        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=cu_num_tokens, seq_lens=seq_lens)

        # FIXME: reorder_batch() needs to be called before build()
        # because fields of attn_metadata_builder need to be updated.
        # However, reorder_batch() currently takes input_batch and
        # scheduler_output as arguments; we should probably refactor
        # the method to use new data structures that are independent
        # of input_batch and scheduler_output.
        # self.runner.attn_metadata_builder.reorder_batch(
        #     input_batch=self.runner.input_batch,
        #     scheduler_output=self.runner.scheduler_output,
        # )

        # FIXME: need to consider multiple kv_cache_groups
        attn_metadata = self.runner.attn_metadata_builders[0].build(
            num_reqs=batch_size,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )

        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=target_positions,
                previous_hidden_states=target_hidden_states,
            )
        sample_hidden_states = hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

        assert self.num_speculative_tokens == 1
        # [batch_size, 1]
        return draft_token_ids.view(-1, 1)

    def load_model(self, target_model: nn.Module) -> None:
        loader = get_model_loader(self.vllm_config.load_config)
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        # FIXME(lily): This does not handle distributed inference.
        target_device = self.vllm_config.device_config.device
        # We need to set the vllm_config here to register attention
        # layers in the forward context.
        with set_default_torch_dtype(
                draft_model_config.dtype), set_current_vllm_config(
                    self.vllm_config):

            with target_device:
                self.model = DeepSeekMTP(vllm_config=self.vllm_config)

            draft_attn_layer_names = (get_layers_from_vllm_config(
                self.vllm_config, Attention).keys() - target_attn_layer_names)
            assert len(draft_attn_layer_names) == 1
            self.attn_layer_name = next(iter(draft_attn_layer_names))
            self.model.load_weights(
                loader.get_all_weights(
                    self.vllm_config.speculative_config.draft_model_config,
                    self.model))

        process_weights_after_loading(
            self.model,
            self.vllm_config.speculative_config.draft_model_config,
            target_device)
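
For reference, here is a small worked example of what prepare_inputs() computes, based on the a/b/c and n1/n2/n3 comments in the method above. This is a pure-PyTorch sketch for illustration only: the input values are made up, and token_indices is filled with torch.repeat_interleave and torch.arange instead of the Triton kernel prepare_input_kernel used by the actual code.

import torch

def prepare_inputs_reference(cu_target_query_lens, num_rejected_tokens):
    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    # Prefix sums of the per-request accepted-token counts.
    cu_num_tokens = torch.zeros_like(cu_target_query_lens)
    torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
    # For each request, keep the first (query_len - num_rejected) positions
    # of the flattened target batch.
    starts = torch.repeat_interleave(cu_target_query_lens[:-1],
                                     num_tokens_per_req)
    offsets = (torch.arange(int(cu_num_tokens[-1])) -
               torch.repeat_interleave(cu_num_tokens[:-1], num_tokens_per_req))
    token_indices = (starts + offsets).to(torch.int32)
    return cu_num_tokens, token_indices

# Three requests with target query lengths [3, 3, 3] and
# [1, 0, 2] rejected draft tokens respectively.
cu_num_tokens, token_indices = prepare_inputs_reference(
    torch.tensor([0, 3, 6, 9]), torch.tensor([1, 0, 2]))
print(cu_num_tokens)   # tensor([0, 2, 5, 6])
print(token_indices)   # tensor([0, 1, 3, 4, 5, 6], dtype=torch.int32)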
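
Similarly, the shift-by-one construction at the top of propose() can be illustrated in isolation: the verified target tokens are shifted left by one position and the last slot of each request is overwritten with the token just sampled from the target model, so the MTP head always predicts one step ahead of the verified sequence. The token ids and request boundaries below are hypothetical.

import torch

# Hypothetical flattened batch of three requests (lengths 1, 2, 3),
# mirroring the [a1, b1, b2, c1, c2, c3] example in the comments above.
target_token_ids = torch.tensor([11, 21, 22, 31, 32, 33])
next_token_ids = torch.tensor([12, 23, 34])      # sampled by the target model
cu_num_tokens = torch.tensor([0, 1, 3, 6])       # per-request boundaries
last_token_indices = cu_num_tokens[1:] - 1       # [0, 2, 5]

input_ids = torch.empty_like(target_token_ids)
# Shift left by one token: [21, 22, 31, 32, 33, <unset>]
input_ids[:-1] = target_token_ids[1:]
# Overwrite each request's last slot with the freshly sampled token.
input_ids[last_token_indices] = next_token_ids
print(input_ids)  # tensor([12, 22, 23, 32, 33, 34])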