[TRTLLM-6741] [feat] enable LM tp for MTP, under attention dp case #7128
Changes from all commits
First file in the diff:

```diff
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, List, Optional

 import torch
+import torch.nn.functional as F
 from torch import nn

 from ..attention_backend import AttentionMetadata
@@ -17,6 +18,8 @@
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig

+import os
+from tensorrt_llm.mapping import Mapping

 @dataclass(kw_only=True)
 class SampleStateTensorsMTP(SampleStateTensors):
```
```diff
@@ -473,9 +476,23 @@ def forward(
             for _, mtp_layer in enumerate(draft_model.mtp_layers):
                 hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
                                           **draft_inputs)
-                logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
+                token_count = hidden_states.view(-1,
+                                                 hidden_states.shape[-1]).shape[0]
+                all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
+                pad_len = all_rank_max_num_tokens - token_count
+                if pad_len > 0:
+                    padded_hidden_states = F.pad(hidden_states.view(
+                        -1, hidden_states.shape[-1]), (0, 0, 0, pad_len),
+                                                 mode="constant",
+                                                 value=0)
+                else:
+                    padded_hidden_states = hidden_states.view(
+                        -1, hidden_states.shape[-1])
+                logits = mtp_layer.shared_head(padded_hidden_states,
+                                               draft_model.lm_head,
                                                attn_metadata).float()

                 new_draft_token = self.draft_sampler(logits)
+                new_draft_token = new_draft_token[:token_count]
                 next_draft_tokens.append(new_draft_token)
```
Comment on lines 494 to 496

Contributor:

Signature mismatch: `draft_sampler` now requires `iter`, but the call omits it. This will raise a TypeError at runtime. Pass the loop index and update `enumerate` accordingly:

```diff
-            for _, mtp_layer in enumerate(draft_model.mtp_layers):
+            for i, mtp_layer in enumerate(draft_model.mtp_layers):
             ...
-                new_draft_token = self.draft_sampler(logits)
+                new_draft_token = self.draft_sampler(logits, i)
```
```diff
                 # shift input_ids and hidden_states
                 input_ids = draft_inputs["input_ids"]
```
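To make the pad-and-trim pattern in this hunk concrete, here is a small standalone sketch (my own illustration with made-up shapes, not code from this PR): under attention DP each rank may hold a different number of tokens, so the hidden states are zero-padded up to `all_rank_max_num_tokens` before the tensor-parallel `shared_head`/`lm_head`, and the sampled draft tokens are sliced back to the rank's real `token_count` afterwards. The later MTPEagleWorker hunk applies the same pattern with `max_num_requests` as the target length.

```python
import torch
import torch.nn.functional as F

def pad_trim_demo(hidden_states: torch.Tensor, all_rank_max_num_tokens: int) -> torch.Tensor:
    # Flatten to (token_count, hidden_dim), as the MTP loop does.
    flat = hidden_states.view(-1, hidden_states.shape[-1])
    token_count = flat.shape[0]

    # Zero-pad the token dimension so every rank presents the same shape
    # to the tensor-parallel head / collective ops.
    pad_len = all_rank_max_num_tokens - token_count
    if pad_len > 0:
        flat = F.pad(flat, (0, 0, 0, pad_len), mode="constant", value=0)

    # Stand-in for shared_head(...) + draft_sampler(...): argmax over a fake vocab.
    fake_logits = flat @ torch.randn(flat.shape[-1], 32)
    draft_tokens = torch.argmax(fake_logits, dim=-1)

    # Discard the tokens produced for the padded rows.
    return draft_tokens[:token_count]

# Example: a rank with 3 tokens, padded to the all-rank max of 5.
print(pad_trim_demo(torch.randn(3, 8), all_rank_max_num_tokens=5).shape)  # torch.Size([3])
```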
```diff
@@ -1041,12 +1058,13 @@ def prepare_drafter_inputs(
         }

     @torch.compile(options={"max-autotune": True})
-    def get_local_max_and_combined(self, logits):
+    def get_local_max_and_combined(self, logits, mapping_lm_tp=None):
         local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
         # Adjust indices based on TP rank and size
         vocab_per_rank = logits.shape[-1]
+        mapping_lm_tp = mapping_lm_tp if mapping_lm_tp is not None else self.model_config.mapping
         max_index_per_rank = local_argmax.type(
-            torch.int32) + (self.model_config.mapping.tp_rank * vocab_per_rank)
+            torch.int32) + (mapping_lm_tp.tp_rank * vocab_per_rank)
         # Use torch.stack and flatten instead of view+cat to avoid torch.compile issues
         # Convert both to float32 to ensure consistent dtype
         max_index_per_rank_float = max_index_per_rank.float()
```
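For context on `get_local_max_and_combined`, here is a toy, single-process illustration of the idea (the real code packs the max value and the shifted index into one tensor for `allgather`): with a vocab-parallel `lm_head`, each rank only sees a slice of the vocabulary, so its local argmax is shifted by `tp_rank * vocab_per_rank` into the global vocab space, and the global argmax is the shifted index whose max value wins across ranks.

```python
import torch

def global_argmax_from_shards(logits_shards):
    """logits_shards: list of (batch, vocab_per_rank) tensors, one per TP rank."""
    vocab_per_rank = logits_shards[0].shape[-1]
    maxes, idxs = [], []
    for tp_rank, shard in enumerate(logits_shards):
        local_max, local_argmax = torch.max(shard, dim=-1, keepdim=True)
        maxes.append(local_max)
        # Shift the shard-local index into the global vocab space.
        idxs.append(local_argmax + tp_rank * vocab_per_rank)
    maxes = torch.cat(maxes, dim=-1)   # (batch, tp_size)
    idxs = torch.cat(idxs, dim=-1)     # (batch, tp_size)
    winner = torch.argmax(maxes, dim=-1, keepdim=True)
    return torch.gather(idxs, -1, winner).squeeze(-1)

# Equivalent to an argmax over the full vocabulary:
full = torch.randn(4, 16)
shards = list(full.split(8, dim=-1))
assert torch.equal(global_argmax_from_shards(shards), torch.argmax(full, dim=-1))
```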
```diff
@@ -1095,6 +1113,32 @@ def draft_sampler(
             combined = self.get_local_max_and_combined(logits)
             gathered = allgather(combined, self.model_config.mapping, dim=-1)
             draft_tokens = self.get_draft_tokens_from_gathered(gathered)
+        elif (self.model_config is not None
+              and hasattr(self.model_config, 'mapping')
+              and self.model_config.mapping.tp_size > 1) and (
+                  self.model_config.mapping.enable_attention_dp and getattr(
+                      self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
+            # For ADP + LM TP mode, we need to find the global argmax across all TP ranks
+            # First, get local argmax and max values
+            lm_tp_size = int(os.getenv('LM_TP_SIZE', 2))
+            assert self.model_config.mapping.tp_size % lm_tp_size == 0
+            lm_pp_size = self.model_config.mapping.pp_size * self.model_config.mapping.tp_size // lm_tp_size
+            mapping_lm_tp = Mapping(
+                world_size=lm_tp_size * lm_pp_size,
+                rank=self.model_config.mapping.rank,
+                gpus_per_node=self.model_config.mapping.gpus_per_node,
+                tp_size=lm_tp_size,
+                pp_size=lm_pp_size,
+                enable_attention_dp=self.model_config.mapping.enable_attention_dp,
+                enable_lm_tp_in_adp=self.model_config.mapping.enable_lm_tp_in_adp,
+            )
+            combined = self.get_local_max_and_combined(logits, mapping_lm_tp)
+            gathered = allgather(combined, mapping_lm_tp, dim=-1)
+            batch_size = logits.shape[0]
+            local_batch_size = batch_size // mapping_lm_tp.tp_size
+            gathered = gathered.view(mapping_lm_tp.tp_size, local_batch_size, -1)
+            sliced_gathered = gathered[mapping_lm_tp.tp_rank]
+            draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
         else:
             # Simple argmax if no TP or no model config
             draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
```
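The reshape-and-slice at the end of the new `elif` branch is easy to get lost in, so here is a hypothetical, shapes-only sketch of what `gathered.view(tp_size, local_batch_size, -1)[tp_rank]` selects, assuming the padded batch is laid out as one equal-sized chunk per LM-TP rank. `lm_tp_size` itself comes from the `LM_TP_SIZE` environment variable (default 2) and must divide the attention-DP TP size.

```python
import torch

def keep_my_slice(gathered: torch.Tensor, lm_tp_size: int, tp_rank: int) -> torch.Tensor:
    """gathered: (batch, width) tensor whose batch dim stacks lm_tp_size per-rank chunks."""
    batch = gathered.shape[0]
    assert batch % lm_tp_size == 0, "padded batch must divide evenly across the LM TP group"
    local_batch = batch // lm_tp_size
    # view(lm_tp_size, local_batch, -1)[tp_rank] is equivalent to this row slice:
    chunked = gathered.view(lm_tp_size, local_batch, -1)
    mine = chunked[tp_rank]
    assert torch.equal(mine.reshape(local_batch, -1),
                       gathered[tp_rank * local_batch:(tp_rank + 1) * local_batch])
    return mine

# e.g. an LM TP group of 2, a padded batch of 8 rows, width 2 (max value + global argmax):
out = keep_my_slice(torch.arange(16.0).view(8, 2), lm_tp_size=2, tp_rank=1)
print(out.shape)  # torch.Size([4, 2])
```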
```diff
@@ -1194,10 +1238,26 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
                 **inputs)
             # All of the seq_len are 1, use batch_indices_cuda as gather_ids
             gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
+            hidden_states_gathered = hidden_states[gather_ids]
+            token_count = hidden_states_gathered.view(-1,
+                hidden_states_gathered.shape[-1]).shape[0]
+            max_num_requests = spec_metadata.max_num_requests
+            pad_len = max_num_requests - token_count
+            if pad_len > 0:
+                padded_hidden_states = F.pad(hidden_states_gathered.view(
+                    -1, hidden_states_gathered.shape[-1]), (0, 0, 0, pad_len),
+                                             mode="constant",
+                                             value=0)
+            elif pad_len == 0:
+                padded_hidden_states = hidden_states_gathered.view(
+                    -1, hidden_states_gathered.shape[-1])
+            else:
+                raise ValueError(f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported")
             logits = draft_model.mtp_layers[0].shared_head(
-                hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
+                padded_hidden_states, draft_model.lm_head, attn_metadata,
                 True)
             new_draft_token = self.draft_sampler(logits)
+            new_draft_token = new_draft_token[:token_count]

             hidden_states, position_ids = self.update_draft_tokens(
                 next_draft_tokens, new_draft_token, hidden_states, gather_ids,
```
Second file in the diff:

```diff
@@ -225,6 +225,7 @@ class _ParallelConfig:
     moe_ep_size: int = 1
     cp_config: dict = field(default_factory=dict)
     enable_attention_dp: bool = False
+    enable_lm_tp_in_adp: bool = False
     auto_parallel: bool = False

     _world_size: int = field(default=1, init=False)
@@ -288,6 +289,7 @@ def to_mapping(self) -> Mapping:
             cp_size=self.cp_size,
             cp_config=self.cp_config,
             enable_attention_dp=self.enable_attention_dp,
+            enable_lm_tp_in_adp=self.enable_lm_tp_in_adp,
             moe_cluster_size=self.moe_cluster_size,
             moe_tp_size=self.moe_tp_size,
             moe_ep_size=self.moe_ep_size,
```
```diff
@@ -1261,6 +1263,11 @@ class BaseLlmArgs(StrictBaseModel):
         description="Enable attention data parallel.",
         status="beta")

+    enable_lm_tp_in_adp: bool = Field(
+        default=False,
+        description="Enable lm tp in attention dp.",
+        status="beta")
+
     cp_config: Optional[dict] = Field(default_factory=dict,
                                       description="Context parallel config.",
                                       status="prototype")
```

Review discussion on the new `enable_lm_tp_in_adp` argument:

Member: @Superjomn FYI - is it ok to add another argument here? Any other suggestions?

Collaborator: It is OK for the prototype stage. The mechanism, I think, is like this:

Member: Thanks, in that case, for this knob it should be `prototype`.

Collaborator: Yes, I think so, for a dangling knob, it should start from "prototype", as we may refactor it with some hierarchical Config later.
```diff
@@ -1508,6 +1515,7 @@ def validate_parallel_config(self):
             moe_tp_size=self.moe_tensor_parallel_size,
             moe_ep_size=self.moe_expert_parallel_size,
             enable_attention_dp=self.enable_attention_dp,
+            enable_lm_tp_in_adp=self.enable_lm_tp_in_adp,
             cp_config=self.cp_config)
         return self
```
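As a usage sketch (the exact spelling of arguments other than those added in this diff, and the import paths, are assumptions on my part), the new knob would be passed alongside the existing attention-DP flag on the LLM-args surface:

```python
import os
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig  # import path assumed

# LM_TP_SIZE is read via os.getenv() in the new draft_sampler branch (default 2);
# how it propagates to worker processes is outside this diff.
os.environ["LM_TP_SIZE"] = "2"

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",   # placeholder checkpoint
    tensor_parallel_size=8,
    enable_attention_dp=True,
    enable_lm_tp_in_adp=True,          # the new flag added in this PR
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=1),  # field name assumed
)
```

Per the review thread above, the knob may later move to `prototype` status or into a hierarchical config.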
🛠️ Refactor suggestion
Missing NVIDIA 2025 copyright header
Please prepend the standard NVIDIA copyright header (2025) to comply with repo guidelines.