30 changes: 27 additions & 3 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -70,6 +70,9 @@
from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
register_auto_model)

from ..distributed import allgather
import os
from tensorrt_llm.mapping import Mapping

@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
@@ -166,10 +169,31 @@ def forward(self,
else:
hidden_states = hidden_states[-1].unsqueeze(0)

if not (self.model_config.mapping.enable_attention_dp):
# Pre-LM-head gather logic
if (self.model_config.mapping.enable_attention_dp and
getattr(self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
# ADP + LM TP mode: perform All-Gather before LM_head
lm_tp_size = int(os.getenv('LM_TP_SIZE', 2))
assert self.model_config.mapping.tp_size % lm_tp_size == 0
lm_pp_size = self.model_config.mapping.pp_size * self.model_config.mapping.tp_size // lm_tp_size
mapping_lm_tp = Mapping(
world_size=lm_tp_size * lm_pp_size,
rank=self.model_config.mapping.rank,
gpus_per_node=self.model_config.mapping.gpus_per_node,
tp_size=lm_tp_size,
pp_size=lm_pp_size,
enable_attention_dp=self.model_config.mapping.enable_attention_dp,
enable_lm_tp_in_adp=self.model_config.mapping.enable_lm_tp_in_adp,
)
hidden_states = allgather(hidden_states, mapping_lm_tp, dim=0)

# Temporarily disable gather_output whenever the LM head runs tensor-parallel: plain TP (ADP disabled) or ADP with LM TP enabled
if (not self.model_config.mapping.enable_attention_dp) or (self.model_config.mapping.enable_attention_dp and
getattr(self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
lm_head.gather_output = False
logits = lm_head(hidden_states)
if not (self.model_config.mapping.enable_attention_dp):
logits = lm_head(hidden_states, is_mtp_head=True)
if (not self.model_config.mapping.enable_attention_dp) or (self.model_config.mapping.enable_attention_dp and
getattr(self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
lm_head.gather_output = True
return logits

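The LM-TP sub-mapping built above is repeated almost verbatim in embedding.py and mtp.py below. A minimal sketch of the shared idea, using only the Mapping arguments that appear in this diff; the helper name build_lm_tp_mapping is hypothetical and not part of the PR:

```python
import os

from tensorrt_llm.mapping import Mapping


def build_lm_tp_mapping(mapping: Mapping) -> Mapping:
    """Hypothetical helper: derive the LM-head TP mapping from the full mapping.

    Example: with tp_size=8, pp_size=1 and LM_TP_SIZE=2, the LM head runs TP=2
    and the remaining ranks fold into pp_size=4, so world_size stays 8.
    """
    lm_tp_size = int(os.getenv('LM_TP_SIZE', 2))
    assert mapping.tp_size % lm_tp_size == 0, (
        f"tp_size ({mapping.tp_size}) must be divisible by LM_TP_SIZE ({lm_tp_size})")
    lm_pp_size = mapping.pp_size * mapping.tp_size // lm_tp_size
    return Mapping(
        world_size=lm_tp_size * lm_pp_size,
        rank=mapping.rank,
        gpus_per_node=mapping.gpus_per_node,
        tp_size=lm_tp_size,
        pp_size=lm_pp_size,
        enable_attention_dp=mapping.enable_attention_dp,
        enable_lm_tp_in_adp=mapping.enable_lm_tp_in_adp,
    )
```

With such a helper, the pre-LM-head branch above would reduce to hidden_states = allgather(hidden_states, build_lm_tp_mapping(self.model_config.mapping), dim=0).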
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_utils.py
@@ -352,7 +352,7 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
self.pp_size = config.mapping.pp_size
self.has_custom_lm_head = False

if config.mapping.enable_attention_dp:
if config.mapping.enable_attention_dp and not getattr(config.mapping, 'enable_lm_tp_in_adp', False):
self.lm_head = LMHead(
vocab_size,
hidden_size,
34 changes: 32 additions & 2 deletions tensorrt_llm/_torch/modules/embedding.py
@@ -1,4 +1,5 @@
import math
Contributor review comment (🛠️ Refactor suggestion):

Missing NVIDIA 2025 copyright header

Please prepend the standard NVIDIA copyright header (2025) to comply with repo guidelines.

🤖 Prompt for AI Agents:
In tensorrt_llm/_torch/modules/embedding.py around lines 1 to 1, the file is missing the required NVIDIA 2025 copyright header; prepend the standard NVIDIA 2025 copyright header block at the very top of the file (above the existing import math line), ensuring the header text, year (2025), and any required license or ownership wording match the repo's standard header template exactly, and preserve a blank line after the header before the first import.

import os
from typing import Dict, List, Optional, Tuple

import torch
@@ -35,6 +36,21 @@ def __init__(
local_in_features = embedding_dim
local_out_features = num_embeddings
mapping = mapping or Mapping()
if (mapping.enable_attention_dp and
getattr(mapping, 'enable_lm_tp_in_adp', False)):
lm_tp_size = int(os.getenv('LM_TP_SIZE', 2))
assert mapping.tp_size % lm_tp_size == 0, f"mapping.tp_size ({mapping.tp_size}) must be divisible by lm_tp_size ({lm_tp_size})"
lm_pp_size = mapping.pp_size * mapping.tp_size // lm_tp_size
mapping = Mapping(
world_size=lm_tp_size * lm_pp_size,
rank=mapping.rank,
gpus_per_node=mapping.gpus_per_node,
tp_size=lm_tp_size,
pp_size=lm_pp_size,
enable_attention_dp=mapping.enable_attention_dp,
enable_lm_tp_in_adp=mapping.enable_lm_tp_in_adp,
)

tp_size = mapping.tp_size

# Attention DP doesn't work with embedding parallelization.
@@ -83,9 +99,23 @@ def forward(
self,
input: torch.Tensor,
*,
all_reduce_params: Optional[AllReduceParams] = None
all_reduce_params: Optional[AllReduceParams] = None,
is_mtp_head: bool = False,
) -> torch.Tensor:
output = super().forward(input, all_reduce_params=all_reduce_params)
if is_mtp_head and (self.mapping.enable_attention_dp and
getattr(self.mapping, 'enable_lm_tp_in_adp', False)):
tp_rank = self.mapping.tp_rank
tp_size = self.mapping.tp_size
tensor_shape = self.weight.shape
width = tensor_shape[0]
slice_width = math.ceil(width / tp_size)
slice_start = tp_rank * slice_width
slice_end = min((tp_rank + 1) * slice_width, width)
slice_obj = [slice(None)] * len(tensor_shape)
slice_obj[0] = slice(slice_start, slice_end)
output = F.linear(input, self.weight[tuple(slice_obj)], None)
else:
output = super().forward(input, all_reduce_params=all_reduce_params)
if (self.tp_mode == TensorParallelMode.COLUMN and self.gather_output
and self.padding_size > 0):
output = output[..., :-self.padding_size]
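A worked example of the slice arithmetic in the new is_mtp_head path, which indexes the full-vocabulary weight by rank at runtime. The numbers are illustrative (vocab size 129280, LM_TP_SIZE=2), not taken from the PR:

```python
import math

width = 129280        # weight.shape[0], i.e. the vocabulary size (illustrative)
tp_size = 2           # LM TP size inside ADP
slice_width = math.ceil(width / tp_size)   # 64640 rows per rank

for tp_rank in range(tp_size):
    slice_start = tp_rank * slice_width
    slice_end = min((tp_rank + 1) * slice_width, width)
    # rank 0 owns vocab rows [0, 64640), rank 1 owns [64640, 129280)
    print(tp_rank, slice_start, slice_end)
```

Each rank therefore emits logits only for its vocabulary shard; the shard-local argmax is mapped back to a global token id in draft_sampler, as shown further down in mtp.py.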
68 changes: 64 additions & 4 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, List, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..attention_backend import AttentionMetadata
@@ -17,6 +18,8 @@
if TYPE_CHECKING:
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig

import os
from tensorrt_llm.mapping import Mapping

@dataclass(kw_only=True)
class SampleStateTensorsMTP(SampleStateTensors):
@@ -473,9 +476,23 @@ def forward(
for _, mtp_layer in enumerate(draft_model.mtp_layers):
hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
token_count = hidden_states.view(-1,
hidden_states.shape[-1]).shape[0]
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
pad_len = all_rank_max_num_tokens - token_count
if pad_len > 0:
padded_hidden_states = F.pad(hidden_states.view(
-1, hidden_states.shape[-1]), (0, 0, 0, pad_len),
mode="constant",
value=0)
else:
padded_hidden_states = hidden_states.view(
-1, hidden_states.shape[-1])
logits = mtp_layer.shared_head(padded_hidden_states,
draft_model.lm_head,
attn_metadata).float()
new_draft_token = self.draft_sampler(logits)
new_draft_token = new_draft_token[:token_count]
next_draft_tokens.append(new_draft_token)
Contributor review comment on lines 494 to 496 (⚠️ Potential issue):

Signature mismatch: draft_sampler now requires iter but call omits it

This will raise a TypeError at runtime. Pass the loop index and update enumerate accordingly.

-        for _, mtp_layer in enumerate(draft_model.mtp_layers):
+        for i, mtp_layer in enumerate(draft_model.mtp_layers):
...
-            new_draft_token = self.draft_sampler(logits)
+            new_draft_token = self.draft_sampler(logits, i)

🤖 Prompt for AI Agents:
In tensorrt_llm/_torch/speculative/mtp.py around lines 492-494, the call to self.draft_sampler(logits) omits the required iter argument; update the surrounding loop to use enumerate (e.g., for i, ... in enumerate(...)) and pass the loop index into draft_sampler (self.draft_sampler(logits, i)), then keep the rest of the logic (slicing to token_count and appending) unchanged.

# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
@@ -1041,12 +1058,13 @@ def prepare_drafter_inputs(
}

@torch.compile(options={"max-autotune": True})
def get_local_max_and_combined(self, logits):
def get_local_max_and_combined(self, logits, mapping_lm_tp=None):
local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
# Adjust indices based on TP rank and size
vocab_per_rank = logits.shape[-1]
mapping_lm_tp = mapping_lm_tp if mapping_lm_tp is not None else self.model_config.mapping
max_index_per_rank = local_argmax.type(
torch.int32) + (self.model_config.mapping.tp_rank * vocab_per_rank)
torch.int32) + (mapping_lm_tp.tp_rank * vocab_per_rank)
# Use torch.stack and flatten instead of view+cat to avoid torch.compile issues
# Convert both to float32 to ensure consistent dtype
max_index_per_rank_float = max_index_per_rank.float()
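The tp_rank * vocab_per_rank offset above turns a shard-local argmax into a global token id. A minimal sketch of the same arithmetic with illustrative sizes (the shapes are assumptions, not taken from the PR):

```python
import torch

vocab_per_rank = 64640                   # logits.shape[-1] on this LM TP rank (illustrative)
tp_rank = 1                              # this rank owns global vocab ids [64640, 129280)
logits = torch.randn(4, vocab_per_rank)  # per-rank logits for 4 tokens

local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
# Shift the shard-local index by the rank's vocab offset to recover the global token id.
max_index_per_rank = local_argmax.type(torch.int32) + tp_rank * vocab_per_rank
```

After the allgather, get_draft_tokens_from_gathered presumably picks, per token, the index belonging to the rank with the largest value.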
@@ -1095,6 +1113,32 @@ def draft_sampler(
combined = self.get_local_max_and_combined(logits)
gathered = allgather(combined, self.model_config.mapping, dim=-1)
draft_tokens = self.get_draft_tokens_from_gathered(gathered)
elif (self.model_config is not None
and hasattr(self.model_config, 'mapping')
and self.model_config.mapping.tp_size > 1) and (
self.model_config.mapping.enable_attention_dp and getattr(
self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
# For ADP + LM TP mode, we need to find the global argmax across all TP ranks
# First, get local argmax and max values
lm_tp_size = int(os.getenv('LM_TP_SIZE', 2))
assert self.model_config.mapping.tp_size % lm_tp_size == 0
lm_pp_size = self.model_config.mapping.pp_size * self.model_config.mapping.tp_size // lm_tp_size
mapping_lm_tp = Mapping(
world_size=lm_tp_size * lm_pp_size,
rank=self.model_config.mapping.rank,
gpus_per_node=self.model_config.mapping.gpus_per_node,
tp_size=lm_tp_size,
pp_size=lm_pp_size,
enable_attention_dp=self.model_config.mapping.enable_attention_dp,
enable_lm_tp_in_adp=self.model_config.mapping.enable_lm_tp_in_adp,
)
combined = self.get_local_max_and_combined(logits, mapping_lm_tp)
gathered = allgather(combined, mapping_lm_tp, dim=-1)
batch_size = logits.shape[0]
local_batch_size = batch_size // mapping_lm_tp.tp_size
gathered = gathered.view(mapping_lm_tp.tp_size, local_batch_size, -1)
sliced_gathered = gathered[mapping_lm_tp.tp_rank]
draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
else:
# Simple argmax if no TP or no model config
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
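For the ADP + LM TP branch above: the hidden states fed to the LM head were already all-gathered across the LM TP group (dim=0), so the leading dimension of logits holds every rank's padded tokens. After gathering each rank's (value, index) pairs along the last dimension, the view/index pattern lets each rank keep only its own tokens. A shape-only sketch under illustrative sizes; the 2-wide last dimension per rank is an assumption about get_local_max_and_combined's output layout:

```python
import torch

tp_size = 2              # LM TP group size
local_batch_size = 3     # tokens this rank contributed before the pre-LM-head all-gather
batch_size = tp_size * local_batch_size   # rows of logits on every rank

# Stand-in for allgather(combined, mapping_lm_tp, dim=-1): each rank contributes
# a [batch_size, 2] tensor (max value and shifted argmax per token).
gathered = torch.randn(batch_size, 2 * tp_size)

gathered = gathered.view(tp_size, local_batch_size, -1)  # [tp_size, local_batch_size, 2 * tp_size]
tp_rank = 0
sliced_gathered = gathered[tp_rank]                      # this rank's own tokens: [local_batch_size, 2 * tp_size]
```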
@@ -1194,10 +1238,26 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
**inputs)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(-1,
hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]), (0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
else:
raise ValueError("In MTPEagleWorker.forward(), token_count > max_num_requests is not supported")
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
padded_hidden_states, draft_model.lm_head, attn_metadata,
True)
new_draft_token = self.draft_sampler(logits)
new_draft_token = new_draft_token[:token_count]

hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
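Both MTP paths above pad the hidden states to a fixed row count (all_rank_max_num_tokens or max_num_requests) before shared_head, so every ADP rank feeds the now tensor-parallel LM head the same number of rows, and then trim the sampled draft tokens back to token_count. A minimal sketch of that padding pattern; the helper name pad_rows_to is hypothetical:

```python
import torch
import torch.nn.functional as F


def pad_rows_to(hidden_states: torch.Tensor, target_rows: int) -> torch.Tensor:
    """Zero-pad a [*, hidden] tensor to target_rows rows (hypothetical helper)."""
    flat = hidden_states.view(-1, hidden_states.shape[-1])
    pad_len = target_rows - flat.shape[0]
    if pad_len > 0:
        # (0, 0, 0, pad_len): leave the hidden dim alone, append pad_len zero rows.
        flat = F.pad(flat, (0, 0, 0, pad_len), mode="constant", value=0)
    return flat
```

The corresponding new_draft_token = new_draft_token[:token_count] lines then drop whatever the padded rows sampled.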
8 changes: 8 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -225,6 +225,7 @@ class _ParallelConfig:
moe_ep_size: int = 1
cp_config: dict = field(default_factory=dict)
enable_attention_dp: bool = False
enable_lm_tp_in_adp: bool = False
auto_parallel: bool = False

_world_size: int = field(default=1, init=False)
@@ -288,6 +289,7 @@ def to_mapping(self) -> Mapping:
cp_size=self.cp_size,
cp_config=self.cp_config,
enable_attention_dp=self.enable_attention_dp,
enable_lm_tp_in_adp=self.enable_lm_tp_in_adp,
moe_cluster_size=self.moe_cluster_size,
moe_tp_size=self.moe_tp_size,
moe_ep_size=self.moe_ep_size,
@@ -1261,6 +1263,11 @@ class BaseLlmArgs(StrictBaseModel):
description="Enable attention data parallel.",
status="beta")

enable_lm_tp_in_adp: bool = Field(
Member review comment: @Superjomn FYI - is it ok to add another argument here? Any other suggestions?

Collaborator reply: It is OK for the prototype stage. The mechanism, I think, is like this:

  1. If there is no existing XxConfig to hold the new knob, it is fine to add a dangling knob, but mark it as a prototype.
  2. We can wait if there are more than two or three knobs in the same category, then we can consider grouping them into a XxConfig; no rush to introduce a hierarchical config before we are sure the knobs need it.
  3. When the feature is somewhat stable, we can mark the xxx_config beta then.

Member reply: Thanks, in that case, for this knob it should be status="prototype"?

Collaborator (@Superjomn, Sep 4, 2025) reply: Yes, I think so; for a dangling knob, it should start from "prototype", as we may refactor it with some hierarchical Config later.
default=False,
description="Enable lm tp in attention dp.",
status="beta")

cp_config: Optional[dict] = Field(default_factory=dict,
description="Context parallel config.",
status="prototype")
@@ -1508,6 +1515,7 @@ def validate_parallel_config(self):
moe_tp_size=self.moe_tensor_parallel_size,
moe_ep_size=self.moe_expert_parallel_size,
enable_attention_dp=self.enable_attention_dp,
enable_lm_tp_in_adp=self.enable_lm_tp_in_adp,
cp_config=self.cp_config)
return self

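A hedged usage sketch for the new knob through the LLM API. The keyword arguments mirror the BaseLlmArgs fields in this diff; the model id and parallel sizes are illustrative, and LM_TP_SIZE is the environment variable consumed by the modeling code above (defaulting to 2 if unset):

```python
import os

from tensorrt_llm import LLM

# Illustrative only: run attention data-parallel across 8 ranks while the
# LM head runs tensor-parallel across 2 of them (LM_TP_SIZE is read via os.getenv).
os.environ["LM_TP_SIZE"] = "2"

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # illustrative model id
    tensor_parallel_size=8,
    enable_attention_dp=True,
    enable_lm_tp_in_adp=True,
)
```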
6 changes: 5 additions & 1 deletion tensorrt_llm/mapping.py
@@ -141,7 +141,8 @@ def __init__(
attn_tp_size=-1,
attn_cp_size=-1,
auto_parallel=False,
enable_attention_dp=False):
enable_attention_dp=False,
enable_lm_tp_in_adp=False):
# set default values for non-moe cases
# or where only one MOE parallelism size is specified
if moe_cluster_size == -1:
@@ -224,6 +225,7 @@ def __init__(
self.auto_parallel = auto_parallel
self.world_size = world_size
self.enable_attention_dp = enable_attention_dp
self.enable_lm_tp_in_adp = enable_lm_tp_in_adp
self.rank = rank
self.gpus_per_node = gpus_per_node
self.pp_groups = []
@@ -510,4 +512,6 @@ def to_dict(self):
'attn_cp_size': self.attn_cp_size,
'cp_config': self.cp_config,
'auto_parallel': self.auto_parallel,
'enable_attention_dp': self.enable_attention_dp,
'enable_lm_tp_in_adp': self.enable_lm_tp_in_adp,
}