7 changes: 1 addition & 6 deletions tensorrt_llm/_torch/autotuner.py
@@ -852,14 +852,9 @@ def _profile_runners(
# Handle None tensors for optional inputs
shapes = self._get_input_sizes(input_tensors)
logger.warning_once(
f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details.",
f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Error: {e}",
key=(custom_op, "warning_autotuning_profile_failure"),
)
(logger.info_once
if self._log_level_to_info else logger.debug_once)(
f"[Autotuner] Exception captured: {e}",
key=(custom_op, "debug_autotuning_exception"),
)

# Record the failed profiling combinations
self.stats.failed_profiling_count[custom_op].add(
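For illustration, here is a minimal, self-contained sketch of the consolidated failure logging in the hunk above: the captured exception is now embedded directly in the single warning message instead of being split into a separate debug/info entry. Standard-library `logging.warning` stands in for TensorRT-LLM's `logger.warning_once`, and the `profile_tactic`/`failing_runner` names are invented for this demo.

```python
# Illustrative sketch only: plain logging.warning stands in for logger.warning_once,
# and profile_tactic / failing_runner are hypothetical names for this demo.
import logging

logger = logging.getLogger("autotuner_demo")


def profile_tactic(runner, tactic, inputs):
    try:
        return runner(inputs, tactic=tactic)
    except Exception as e:
        shapes = [tuple(t.shape) for t in inputs if t is not None]
        # Single consolidated message: the error is reported inline instead of
        # requiring a second, debug-level log entry.
        logger.warning(
            f"[Autotuner] Failed when profiling runner={runner}, tactic={tactic}, "
            f"shapes={shapes}. Error: {e}")
        return None


if __name__ == "__main__":
    def failing_runner(inputs, tactic):
        raise RuntimeError("tactic unsupported for this shape")

    profile_tactic(failing_runner, tactic=0, inputs=[])
```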
156 changes: 90 additions & 66 deletions tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -4,6 +4,8 @@

import torch

from tensorrt_llm._torch.modules.fused_moe.routing import (
ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
@@ -18,10 +20,13 @@ def prepare_dummy_topk_and_hook(
topk_ids: Optional[torch.Tensor],
hidden_states: torch.Tensor,
routing_logits: Optional[torch.Tensor],
routing_method_type: int,
base_tuning_config: TuningConfig,
top_k: int,
num_experts: int,
local_num_experts: int,
n_group: Optional[int],
topk_group: Optional[int],
routed_scaling_factor: Optional[float],
hidden_states_index: int = 2,
) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor, TuningConfig]:
"""
@@ -32,6 +37,11 @@
and provides a hook to dynamically adjust tensor shapes when AutoTuner tries different
token counts.

NOTE: regardless of whether the MoE op accepts routing_logits or topk_ids/topk_weights, ALWAYS start
from dummy routing_logits, then derive the dummy topk_ids/topk_weights according to the model's
routing_method. This has been found to more closely mirror the actual expert distribution and thus
results in better e2e performance (see the condensed sketch after this function).

Args:
topk_weights: Pre-computed topk weights (None for normal routing scenario)
topk_ids: Pre-computed topk ids (None for normal routing scenario)
@@ -40,78 +50,87 @@
base_tuning_config: Base tuning config to add hook to
top_k: Number of top experts to select
num_experts: Total number of experts
local_num_experts: Number of local experts
hidden_states_index: Index of hidden_states in input_tensors list (default: 2)

Returns:
Tuple of (routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook)
"""

# NOTE: This prevents auto-tuning related code from being executed in actual runs
tuner = AutoTuner.get()
if not tuner.is_tuning_mode:
return routing_logits, topk_weights, topk_ids, base_tuning_config

if routing_logits is None:
routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
num_experts,
dtype=torch.bfloat16,
device=hidden_states.device)
else:
routing_logits_for_tuner = routing_logits

# Determine if we need dummy topk tensors (attention DP scenario)
need_dummy_topk = (topk_weights is not None or topk_ids is not None)

# Get routing method
routing_cls_kwargs = {}
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_cls_kwargs.update({
'n_group':
n_group,
'topk_group':
topk_group,
'routed_scaling_factor':
routed_scaling_factor,
'is_fused':
False, # fuse_routing_kernel
'callable_e_score_correction_bias':
lambda: torch.randn(
num_experts, dtype=torch.bfloat16, device=hidden_states.device)
})
routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type](
top_k=top_k, **routing_cls_kwargs)

# Create dummy topk tensors for attention DP scenario
if need_dummy_topk:
# Attention DP: topk is pre-computed, no routing needed
dummy_topk_weights = torch.randn(hidden_states.shape[0],
top_k,
dtype=torch.bfloat16,
device=hidden_states.device)
rand_num_for_argsort = torch.rand(hidden_states.shape[0],
local_num_experts,
device=hidden_states.device)
dummy_topk_ids = rand_num_for_argsort.argsort(dim=1)[:, :top_k].to(
torch.int32)
topk_weights_for_tuner = dummy_topk_weights
topk_ids_for_tuner = dummy_topk_ids
topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply(
routing_logits_for_tuner)
topk_weights_for_tuner = topk_weights_for_tuner.to(torch.bfloat16)
# Don't pass routing_logits to avoid C++ warning about all three being provided
routing_logits_for_tuner = None
else:
# Normal routing: need routing_logits, topk will be computed by kernel
topk_weights_for_tuner = topk_weights
topk_ids_for_tuner = topk_ids
# Create dummy routing_logits if None (needed for shape constraints)
if routing_logits is None:
routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
num_experts,
dtype=torch.bfloat16,
device=hidden_states.device)
else:
routing_logits_for_tuner = routing_logits
assert topk_weights_for_tuner is None
assert topk_ids_for_tuner is None

# Define hook to recreate dummy tensors when shape changes during profiling
def recreate_dummy_topk_if_needed(
inputs: List[torch.Tensor]) -> List[torch.Tensor]:
"""Recreate dummy topk tensors if token count changed during profiling."""
current_num_tokens = inputs[hidden_states_index].shape[0]
# Recreate routing logits if token count changed
if inputs[0] is None or inputs[0].shape[0] != current_num_tokens:
routing_logits_for_tuner = torch.randn(
current_num_tokens,
num_experts,
dtype=torch.bfloat16,
device=inputs[hidden_states_index].device)

# Only recreate if we originally created dummies
if need_dummy_topk:
# Check if shape changed
if inputs[-1] is not None and inputs[-1].shape[
0] != current_num_tokens:
# Recreate with new shape
inputs[-2] = torch.randn(
current_num_tokens,
top_k,
dtype=torch.bfloat16,
device=inputs[hidden_states_index].device)
rand_num_for_argsort = torch.rand(
current_num_tokens,
local_num_experts,
device=inputs[hidden_states_index].device)
inputs[-1] = rand_num_for_argsort.argsort(dim=1)[:, :top_k].to(
torch.int32)

topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply(
routing_logits_for_tuner)
inputs[-1] = topk_ids_for_tuner
inputs[-2] = topk_weights_for_tuner.to(torch.bfloat16)
# Note: routing_logits is None in attention DP, no need to adjust
else:
# Normal routing scenario: adjust routing_logits if it was originally None
if routing_logits is None and inputs[0] is not None and inputs[
0].shape[0] != current_num_tokens:
inputs[0] = torch.randn(
current_num_tokens,
num_experts,
dtype=torch.bfloat16,
device=inputs[hidden_states_index].device)
assert inputs[0] is None

return inputs
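For readers skimming the diff, here is a condensed, hedged sketch of the dummy-topk derivation described in the NOTE above. It reuses only names visible in this diff (ROUTING_METHOD_TYPE_TO_CLASS and the routing method's apply()); the helper name dummy_topk_from_logits is hypothetical, and it assumes a routing method whose constructor needs only top_k, so the DeepSeekV3-specific kwargs (n_group, topk_group, routed_scaling_factor, e_score_correction_bias) are omitted.

```python
# Condensed, illustrative sketch of the dummy-topk derivation above; the helper
# name is hypothetical and DeepSeekV3-specific routing kwargs are omitted.
import torch

from tensorrt_llm._torch.modules.fused_moe.routing import \
    ROUTING_METHOD_TYPE_TO_CLASS


def dummy_topk_from_logits(num_tokens: int, num_experts: int, top_k: int,
                           routing_method_type: int, device: torch.device):
    # Start from dummy routing logits so the derived topk ids/weights follow the
    # model's actual routing distribution rather than a uniform random argsort.
    logits = torch.randn(num_tokens, num_experts,
                         dtype=torch.bfloat16, device=device)
    routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type](
        top_k=top_k)
    topk_ids, topk_weights = routing_method.apply(logits)
    return topk_ids, topk_weights.to(torch.bfloat16)
```

Compared with the previous randn + argsort dummies, this gives the profiled MoE kernels a token-to-expert distribution shaped by the real routing method, which is the rationale for the change.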

@@ -375,14 +394,17 @@ def fp4_block_scale_moe_runner(
# Prepare dummy topk tensors and hook for AutoTuner profiling
routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
prepare_dummy_topk_and_hook(
routing_method_type=routing_method_type,
topk_weights=topk_weights,
topk_ids=topk_ids,
hidden_states=hidden_states,
routing_logits=routing_logits,
base_tuning_config=FP4BlockScaleMoERunner.get_tuning_config(),
top_k=top_k,
num_experts=num_experts,
local_num_experts=local_num_experts,
n_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
hidden_states_index=2,
)

@@ -410,24 +432,11 @@ def fp4_block_scale_moe_runner(
input_tensors_for_tuner,
)

# Final execution: use ACTUAL parameters (not dummies)
# topk_weights/topk_ids can be None (routing scenario) or real values (attention DP)
input_tensors = [
routing_logits, # Actual value (can be None for attention DP)
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm2_weights,
gemm2_weights_scale,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
topk_weights, # Actual value (None for routing, real tensor for attention DP)
topk_ids, # Actual value (None for routing, real tensor for attention DP)
]

input_tensors = input_tensors_for_tuner
input_tensors[
0] = routing_logits # replace dummy routing logits with actual routing logits
input_tensors[-2] = topk_weights # replace dummy topk_weights with actual
input_tensors[-1] = topk_ids # replace dummy topk_ids with actual
return kernel_runner(input_tensors,
tactic=[-1, -1] if best_tactic == -1 else best_tactic)

@@ -704,14 +713,17 @@ def fp8_block_scale_moe_runner(
# Prepare dummy topk tensors and hook for AutoTuner profiling
routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
prepare_dummy_topk_and_hook(
routing_method_type=routing_method_type,
topk_weights=topk_weights,
topk_ids=topk_ids,
hidden_states=hidden_states,
routing_logits=routing_logits,
base_tuning_config=FP8BlockScaleMoERunner.get_tuning_config(),
top_k=top_k,
num_experts=num_experts,
local_num_experts=local_num_experts,
n_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
hidden_states_index=2,
)

@@ -1011,14 +1023,17 @@ def mxe4m3_mxe2m1_block_scale_moe_runner(
# Prepare dummy topk tensors and hook for AutoTuner profiling
routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
prepare_dummy_topk_and_hook(
routing_method_type=routing_method_type,
topk_weights=topk_weights,
topk_ids=topk_ids,
hidden_states=hidden_states,
routing_logits=routing_logits,
base_tuning_config=MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
top_k=top_k,
num_experts=num_experts,
local_num_experts=local_num_experts,
n_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
hidden_states_index=2,
)

@@ -1280,14 +1295,17 @@ def e4m3_mxe2m1_block_scale_moe_runner(
# Prepare dummy topk tensors and hook for AutoTuner profiling
routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
prepare_dummy_topk_and_hook(
routing_method_type=routing_method_type,
topk_weights=topk_weights,
topk_ids=topk_ids,
hidden_states=hidden_states,
routing_logits=routing_logits,
base_tuning_config=E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
top_k=top_k,
num_experts=num_experts,
local_num_experts=local_num_experts,
n_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
hidden_states_index=2,
)

@@ -1543,14 +1561,17 @@ def bf16_mxe2m1_block_scale_moe_runner(
# Prepare dummy topk tensors and hook for AutoTuner profiling
routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
prepare_dummy_topk_and_hook(
routing_method_type=routing_method_type,
topk_weights=topk_weights,
topk_ids=topk_ids,
hidden_states=hidden_states,
routing_logits=routing_logits,
base_tuning_config=Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(),
top_k=top_k,
num_experts=num_experts,
local_num_experts=local_num_experts,
n_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
hidden_states_index=2,
)

@@ -1794,14 +1815,17 @@ def fp8_fp4_block_scale_moe_runner(
# Prepare dummy topk tensors and hook for AutoTuner profiling
routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
prepare_dummy_topk_and_hook(
routing_method_type=routing_method_type,
topk_weights=topk_weights,
topk_ids=topk_ids,
hidden_states=hidden_states,
routing_logits=routing_logits,
base_tuning_config=FP8FP4BlockScaleMoERunner.get_tuning_config(),
top_k=top_k,
num_experts=num_experts,
local_num_experts=local_num_experts,
n_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
hidden_states_index=2,
)
