
Commit dad1051

remove all_rank_max_num_tokens
Signed-off-by: qgai <[email protected]>
1 parent: 1c3108f

File tree: 2 files changed (+7 −8 lines)

tensorrt_llm/_torch/models/modeling_deepseekv3.py
(file mode changed: 100644 → 100755)
Lines changed: 7 additions & 5 deletions
@@ -394,8 +394,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                 p.data.copy_(module_weights[n][:])

             if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-            ) and is_sm_100f() and hasattr(
-                    module, "weight_scale"):
+            ) and is_sm_100f() and hasattr(module, "weight_scale"):
                 weight, weight_scale = resmooth_to_fp8_e8m0(
                     module.weight, module.weight_scale)
                 transfromed_scale = transform_sf_into_required_layout(
@@ -787,8 +786,9 @@ def __init__(self,
             for key in [EventType.Main, EventType.MoeShared]
         }

-    def _compute_shared_expert_tp_size(self, intermediate_size: int,
-                                       block_size: int) -> int:
+    def _compute_shared_expert_tp_size(
+            self, intermediate_size: int,
+            block_size: int) -> tuple[int, float | None]:
         """
         In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size.
         For example, when the intermediate_size is 2048 and block scaling size is 128,
@@ -800,7 +800,9 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
         it's 128. For NVFP4, it's 16.

         Returns:
-            int: The computed tp_size.
+            tuple[int, float | None]: A tuple containing (shared_tp_size, shared_output_scale).
+                - shared_tp_size: The computed TP size.
+                - shared_output_scale: The output scale factor, or None if not needed.
         """

         assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size."
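As context for the docstring above, here is a minimal sketch of the capping rule it describes: the shared-expert TP size cannot exceed intermediate_size // block_size. The function name, the gcd-based cap, and the output-scale compensation below are illustrative assumptions, not the repository's implementation.

import math
from typing import Optional, Tuple

def compute_shared_expert_tp_size_sketch(
        intermediate_size: int, block_size: int,
        tp_size: int) -> Tuple[int, Optional[float]]:
    # Each rank must own a whole number of quantization blocks, so the
    # usable TP size is capped by the block count (gcd keeps it a divisor
    # of the requested tp_size).
    assert intermediate_size % block_size == 0
    shared_tp_size = math.gcd(intermediate_size // block_size, tp_size)
    # Hypothetical compensation: if the cap shrank the TP size, a later
    # reduction would overcount, so scale the output back down.
    shared_output_scale = None
    if shared_tp_size != tp_size:
        shared_output_scale = shared_tp_size / tp_size
    return shared_tp_size, shared_output_scale

# Docstring example: intermediate_size 2048 with FP8 block size 128 caps
# the shared-expert TP size at 2048 // 128 = 16.
print(compute_shared_expert_tp_size_sketch(2048, 128, tp_size=32))  # (16, 0.5)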

tensorrt_llm/_torch/models/modeling_speculative.py
(file mode changed: 100644 → 100755)
Lines changed: 0 additions & 3 deletions
@@ -393,7 +393,6 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = self.layers(
@@ -403,7 +402,6 @@ def forward(
             embed_tokens=self.embed_tokens,
             attn_metadata=attn_metadata,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
         )

         return hidden_states
@@ -458,7 +456,6 @@ def forward(self,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
             **kwargs)
         return self.logits_processor.forward(
             output,
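A plausible rationale for this removal (an inference; the commit message does not state it): the maximum is recoverable from the per-rank list all_rank_num_tokens, so threading a separate scalar through every forward signature is redundant. A minimal sketch, with a hypothetical helper name:

from typing import List, Optional

def max_num_tokens(all_rank_num_tokens: Optional[List[int]]) -> Optional[int]:
    # The maximum can be derived from the per-rank token counts on demand,
    # so no separate all_rank_max_num_tokens argument is needed.
    return max(all_rank_num_tokens) if all_rank_num_tokens else None

print(max_num_tokens([8, 16, 4]))  # 16
print(max_num_tokens(None))        # None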

0 commit comments