5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/distributed/ops.py
@@ -611,7 +611,10 @@ def __init__(self,
if self.strategy != AllReduceStrategy.UB:
if self.strategy == AllReduceStrategy.LOWPRECISION:
allocate_low_presicion_allreduce_workspace(self.mapping)
self.workspace = get_allreduce_workspace(self.mapping)
if self.strategy not in (AllReduceStrategy.UB,
AllReduceStrategy.NCCL,
AllReduceStrategy.NCCL_SYMMETRIC):
self.workspace = get_allreduce_workspace(self.mapping)

# Initialize MNNVL AllReduce if needed
if self.strategy in (AllReduceStrategy.AUTO,
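The hunk above stops allocating the custom allreduce workspace for the UB, NCCL, and NCCL_SYMMETRIC strategies, which presumably manage their own buffers. A minimal sketch of the resulting control flow, using a stubbed strategy enum and a placeholder workspace getter (both assumptions, not the real TRT-LLM definitions):

from enum import Enum, auto


class Strategy(Enum):
    AUTO = auto()
    NCCL = auto()
    NCCL_SYMMETRIC = auto()
    UB = auto()
    LOWPRECISION = auto()


def init_workspace(strategy: Strategy, mapping, get_workspace):
    # NCCL-backed and UB strategies are assumed to manage their own buffers;
    # only the remaining strategies need the shared fused-allreduce workspace.
    if strategy not in (Strategy.UB, Strategy.NCCL, Strategy.NCCL_SYMMETRIC):
        return get_workspace(mapping)
    return None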
49 changes: 31 additions & 18 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -536,6 +536,7 @@ def __init__(
layer_idx: Optional[int] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
mapping_with_cp: Optional[Mapping] = None,
reduce_output: bool = True,
):
config = model_config.pretrained_config
predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
@@ -559,7 +560,8 @@ def __init__(
dtype=config.torch_dtype,
config=model_config,
aux_stream=aux_stream,
mapping_with_cp=mapping_with_cp)
mapping_with_cp=mapping_with_cp,
reduce_output=reduce_output)
self.kv_a_proj_with_mqa = DeepseekV3Linear(
config.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim +
@@ -579,6 +581,7 @@ def __init__(
model_config: ModelConfig[PretrainedConfig],
layer_idx: Optional[int] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
reduce_output: bool = True,
):
config = model_config.pretrained_config
predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
@@ -602,7 +605,8 @@ def __init__(
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
aux_stream=aux_stream)
aux_stream=aux_stream,
reduce_output=reduce_output)

self.indexer = self.mqa.indexer

@@ -892,8 +896,10 @@ def __init__(self,
overridden_tp_size=shared_tp_size,
reduce_output=False)

self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.allreduce = None
if not self.use_dp and self.mapping.tp_size > 1:
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
@@ -1051,6 +1057,10 @@ def __init__(self,

self.mapping = model_config.mapping
mapping = self.mapping
self.enable_attention_dp = mapping.enable_attention_dp
self.mlp_tp_size = mapping.tp_size
self.is_p2p_supported = can_access_peer(mapping)

layer_idx_for_attention = layer_idx
if is_separate_draft_engine:
#KVCacheManager only support 1 layer for separate draft engine
@@ -1060,17 +1070,17 @@
self.self_attn = DeepseekV32Attention(
model_config,
layer_idx=layer_idx_for_attention,
aux_stream=aux_stream_dict[AuxStreamType.Attention])
aux_stream=aux_stream_dict[AuxStreamType.Attention],
reduce_output=not self.enable_attention_dp
and self.mapping.tp_size > 1)
else:
self.self_attn = DeepseekV3Attention(
model_config,
layer_idx=layer_idx_for_attention,
aux_stream=aux_stream_dict[AuxStreamType.Attention],
mapping_with_cp=mapping_with_cp)
self.enable_attention_dp = mapping.enable_attention_dp

self.mlp_tp_size = mapping.tp_size
self.is_p2p_supported = can_access_peer(mapping)
mapping_with_cp=mapping_with_cp,
reduce_output=not self.enable_attention_dp
and self.mapping.tp_size > 1)

self.fusion_config = EagerFusionConfig()
self.enable_fusion = os.environ.get(
@@ -1085,12 +1095,15 @@ def __init__(self,
quant_config.quant_algo
is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous"

has_tp = mapping.has_tp()
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
dtype=config.torch_dtype)
self.moe_allreduce = MoEAllReduce(self.mapping)
self.allreduce = None
self.moe_allreduce = None
if not self.enable_attention_dp and self.mapping.tp_size > 1:
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
dtype=config.torch_dtype)
self.moe_allreduce = MoEAllReduce(self.mapping)

has_tp = mapping.has_tp()
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
@@ -1127,7 +1140,7 @@ def __init__(self,
dtype=config.torch_dtype,
config=model_config,
overridden_tp_size=self.mlp_tp_size,
reduce_output=True)
reduce_output=has_mlp_tp)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
@@ -1277,8 +1290,8 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):

# Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now
do_finalize = self.mapping.is_multi_node() or (
not (hidden_states.shape[0] <= self.moe_allreduce.max_token
and self.fusion_config.POST_MOE_FUSION
not (self.fusion_config.POST_MOE_FUSION
and hidden_states.shape[0] <= self.moe_allreduce.max_token
and self.model_config.moe_backend == "TRTLLM"
and self.mlp.experts.has_nvfp4 and self.is_p2p_supported))

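Across this file the collectives become optional: AllReduce and MoEAllReduce are only constructed when tensor parallelism is actually in play (tp_size > 1) and attention data parallelism is off, and the attention modules receive reduce_output=False in the DP case so they stop reducing internally. A hedged sketch of the construction-time guard, with illustrative names rather than the real TRT-LLM classes:

from typing import Callable, Optional

import torch


class LayerWithOptionalCollectives(torch.nn.Module):
    """Illustrative only: build collectives conditionally so layers running
    with attention-DP or tp_size == 1 never allocate them."""

    def __init__(self, tp_size: int, enable_attention_dp: bool,
                 make_allreduce: Callable[[], torch.nn.Module]):
        super().__init__()
        self.allreduce: Optional[torch.nn.Module] = None
        if not enable_attention_dp and tp_size > 1:
            self.allreduce = make_allreduce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every call site now has to tolerate a missing collective.
        if self.allreduce is not None:
            x = self.allreduce(x)
        return x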
18 changes: 13 additions & 5 deletions tensorrt_llm/_torch/models/modeling_gpt_oss.py
@@ -47,6 +47,7 @@ def __init__(
self,
config: ModelConfig[GptOssConfig],
layer_idx: int = 0,
reduce_output: bool = True,
use_custom_cublas_mm: bool = False,
):
pretrained_config = config.pretrained_config
@@ -80,6 +81,7 @@ def __init__(
config=config,
q_scaling=1.0,
attention_chunk_size=None,
reduce_output=reduce_output,
use_custom_cublas_mm=use_custom_cublas_mm,
)

@@ -339,7 +341,10 @@ def __init__(
eps=pretrained_config.rms_norm_eps,
dtype=pretrained_config.torch_dtype)

self.attn = AttentionBlock(config, layer_idx, use_custom_cublas_mm)
self.attn = AttentionBlock(config,
layer_idx,
reduce_output=False,
use_custom_cublas_mm=use_custom_cublas_mm)

self.post_attention_layernorm = RMSNorm(
hidden_size=pretrained_config.hidden_size,
@@ -348,7 +353,7 @@

self.mlp = MLPBlock(config,
layer_idx,
reduce_results=not self.is_tp,
reduce_results=False,
use_custom_cublas_mm=use_custom_cublas_mm)

self.mapping = config.mapping
@@ -359,9 +364,12 @@ def __init__(
dtype=pretrained_config.torch_dtype)

# setup for tp
self.allreduce = AllReduce(mapping=config.mapping,
strategy=config.allreduce_strategy,
dtype=config.pretrained_config.torch_dtype)
self.allreduce = None
if self.is_tp:
self.allreduce = AllReduce(
mapping=config.mapping,
strategy=config.allreduce_strategy,
dtype=config.pretrained_config.torch_dtype)

def forward_normal(
self,
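The decoder layer now disables the per-submodule reductions (reduce_output=False on the attention block, reduce_results=False on the MLP block) and owns a single AllReduce that exists only when is_tp. The intended shape of the forward path, sketched with placeholder callables and without the fusion details of the real layer:

import torch


def decoder_layer_forward(attn, mlp, input_norm, post_attn_norm, allreduce,
                          hidden_states: torch.Tensor) -> torch.Tensor:
    # Attention and MLP return per-rank partial sums because their internal
    # reduce is disabled; the layer reduces once per block (the real layer
    # may fuse this with the residual add or the following norm).
    residual = hidden_states
    h = attn(input_norm(hidden_states))
    if allreduce is not None:
        h = allreduce(h)
    h = residual + h

    residual = h
    out = mlp(post_attn_norm(h))
    if allreduce is not None:
        out = allreduce(out)
    return residual + out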
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -31,6 +31,7 @@ def __init__(
attn_output_gate: bool = False,
use_gemma_rms_norm: bool = False,
disable_deep_gemm: bool = False,
reduce_output: bool = True,
):
config = model_config.pretrained_config
self.pretrained_config = config
@@ -69,6 +70,7 @@ def __init__(
attn_output_gate=self.attn_output_gate,
use_gemma_rms_norm=use_gemma_rms_norm,
disable_deep_gemm=disable_deep_gemm,
reduce_output=reduce_output,
)


30 changes: 19 additions & 11 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -99,8 +99,10 @@ def __init__(
self.top_k = config.num_experts_per_tok
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.mapping = model_config.mapping
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.allreduce = None
if not self.enable_attention_dp and self.mapping.tp_size > 1:
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)

self.gate = Qwen3Gate(
hidden_size=self.hidden_dim,
@@ -167,13 +169,14 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
super().__init__()
self.model_config = model_config
config = model_config.pretrained_config
self.mapping = model_config.mapping
self.enable_attention_dp = self.mapping.enable_attention_dp
self.self_attn = Qwen3Attention(
model_config,
layer_idx=layer_idx,
disable_deep_gemm=True,
)
self.mapping = model_config.mapping
self.enable_attention_dp = self.mapping.enable_attention_dp
reduce_output=not self.enable_attention_dp
and self.mapping.tp_size > 1)

self.mlp = Qwen3MoE(model_config, aux_stream_dict, layer_idx=layer_idx)

@@ -186,8 +189,10 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
dtype=config.torch_dtype)
self.layer_idx = layer_idx

self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.allreduce = None
if not self.enable_attention_dp and self.mapping.tp_size > 1:
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.next_layer_layernorm: RMSNorm = None

self.is_p2p_supported = can_access_peer(model_config.mapping)
@@ -205,7 +210,9 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
or self.mapping.tp_size == 1
or self.enable_attention_dp)
self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping)
self.moe_allreduce = None
if not self.enable_attention_dp and self.mapping.tp_size > 1:
self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping)

def forward(
self,
@@ -248,8 +255,8 @@ def forward(

# Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
do_finalize = not (
hidden_states.shape[0] <= self.moe_allreduce.max_token
and self.fusion_config.POST_MOE_FUSION
self.fusion_config.POST_MOE_FUSION
and hidden_states.shape[0] <= self.moe_allreduce.max_token
and self.model_config.moe_backend == 'TRTLLM'
and self.mlp.experts.has_nvfp4 and self.is_p2p_supported)

@@ -327,7 +334,8 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]):
self.embed_tokens = Embedding(
config.pretrained_config.vocab_size,
config.pretrained_config.hidden_size,
dtype=config.pretrained_config.torch_dtype)
dtype=config.pretrained_config.torch_dtype,
)
else:
self.embed_tokens = Embedding(
config.pretrained_config.vocab_size,
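Reordering the do_finalize condition (here and in modeling_deepseekv3.py above) is not cosmetic: moe_allreduce can now be None, so POST_MOE_FUSION has to be evaluated first so that, when it is False, Python's short-circuiting `and` never reaches the moe_allreduce.max_token access on a missing collective. A small sketch of the guard, with illustrative argument names:

def compute_do_finalize(post_moe_fusion: bool, moe_allreduce, num_tokens: int,
                        moe_backend: str, has_nvfp4: bool,
                        p2p_supported: bool) -> bool:
    # post_moe_fusion is checked first: when it is False, moe_allreduce may be
    # None and must not be dereferenced.
    return not (post_moe_fusion
                and num_tokens <= moe_allreduce.max_token
                and moe_backend == "TRTLLM"
                and has_nvfp4 and p2p_supported)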
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -381,6 +381,7 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
reduce_output=False,
use_custom_cublas_mm=getattr(model, 'use_custom_cublas_mm',
False),
)
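The lm_head is column-parallel with gather_output=True, so each rank produces a distinct vocab slice and the shards are concatenated rather than summed; passing the new reduce_output=False simply keeps the now-optional allreduce from being constructed for it. A toy illustration of why gathering, not reducing, is the right collective here (plain PyTorch, single process, no real TP):

import torch

hidden = torch.randn(2, 8)            # [tokens, hidden]
full_w = torch.randn(10, 8)           # [vocab, hidden]

# Column parallelism splits the output (vocab) dimension across ranks.
w0, w1 = full_w[:5], full_w[5:]
logits0 = hidden @ w0.t()             # rank 0's vocab slice
logits1 = hidden @ w1.t()             # rank 1's vocab slice

# gather_output=True concatenates the slices; a sum-allreduce would be wrong.
gathered = torch.cat([logits0, logits1], dim=-1)
assert torch.allclose(gathered, hidden @ full_w.t())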
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/modules/attention.py
@@ -138,6 +138,7 @@ def __init__(
disable_deep_gemm: bool = False,
attn_output_gate: Optional[bool] = None,
use_custom_cublas_mm: bool = False,
reduce_output: bool = True,
):
"""
Initialize the Attention module.
@@ -274,6 +275,7 @@ def __init__(
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.o_lora,
reduce_output=reduce_output,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
@@ -687,6 +689,7 @@ def __init__(
config: Optional[ModelConfig] = None,
enable_unit_test: bool = False,
mapping_with_cp: Optional[Mapping] = None,
reduce_output: bool = True,
):
"""
Initialize the MLA module.
@@ -894,6 +897,7 @@ def __init__(
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
reduce_output=reduce_output,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)

4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/modules/embedding.py
@@ -32,6 +32,7 @@ def __init__(
mapping: Optional[Mapping] = None,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
gather_output: bool = False,
reduce_output: bool = True,
use_custom_cublas_mm: bool = False,
):
local_in_features = embedding_dim
@@ -63,6 +64,7 @@ def __init__(
mapping=mapping,
tensor_parallel_mode=tensor_parallel_mode,
gather_output=gather_output,
reduce_output=reduce_output,
use_custom_cublas_mm=use_custom_cublas_mm,
)

@@ -201,6 +203,7 @@ def __init__(
mapping: Optional[Mapping] = None,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
gather_output: bool = False,
reduce_output: bool = True,
enable_torch_compile_for_embedding: Optional[bool] = False,
use_custom_cublas_mm: bool = False,
):
@@ -211,6 +214,7 @@
mapping=mapping,
tensor_parallel_mode=tensor_parallel_mode,
gather_output=gather_output,
reduce_output=reduce_output,
use_custom_cublas_mm=use_custom_cublas_mm,
)

9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -193,9 +193,12 @@ def __init__(
self.parallel_size = self.mapping.tp_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size

self.all_reduce = AllReduce(mapping=self.mapping,
strategy=model_config.allreduce_strategy,
dtype=self.dtype)
self.all_reduce = None
if not self.use_dp and self.mapping.tp_size > 1:
self.all_reduce = AllReduce(
mapping=self.mapping,
strategy=model_config.allreduce_strategy,
dtype=self.dtype)

# Initialize load balancer related attributes
if init_load_balancer:
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/linear.py
@@ -2052,15 +2052,15 @@ def __init__(
)
local_out_features = out_features // self.tp_size
else:
assert self.tp_mode is None, (
'unsupported tensor parallel mode: {self.tp_mode}')
assert self.tp_mode is None, f'unsupported tensor parallel mode: {self.tp_mode}'

self.in_features = local_in_features
self.out_features = local_out_features

self.all_reduce = AllReduce(mapping=self.mapping,
strategy=allreduce_strategy,
dtype=self.dtype) if reduce_output else None

self._weights_created = False
self.reduce_output = reduce_output
self.use_custom_cublas_mm = use_custom_cublas_mm
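linear.py is where reduce_output ultimately lands: the row-parallel Linear only instantiates its AllReduce when reduce_output is True, and otherwise returns per-rank partial results for the caller to reduce (or not); attention.py and embedding.py above merely thread the flag through to this module. A condensed sketch of that behavior, with a simplified constructor that is not the real Linear signature:

from typing import Optional

import torch


class RowParallelLinearSketch(torch.nn.Module):

    def __init__(self, in_features: int, out_features: int,
                 allreduce_factory=None, reduce_output: bool = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        # Mirror of the diff: the collective exists only when this module is
        # expected to produce a fully reduced output.
        self.all_reduce = (allreduce_factory()
                           if (reduce_output and allreduce_factory) else None)
        self.reduce_output = reduce_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.nn.functional.linear(x, self.weight)
        if self.all_reduce is not None:
            y = self.all_reduce(y)  # sum the per-rank partial outputs
        return y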