diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
index e72cc7b6815..82d28dfa2db 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu
@@ -613,22 +613,8 @@ void run(Data& data, void* stream)
     TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
         "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
         data.mNumExpertGroups);
-    if (data.mNumExpertGroups > 1)
-    {
-        TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups,
-            "Routing kernel expects #experts groups %d to be <= #warps %d", data.mNumExpertGroups, MaxNumGroups);
-        TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
-            "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts,
-            data.mNumExpertGroups);
-        TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
-            "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d",
-            data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups);
-    }
-    else
-    {
-        TLLM_CHECK_WITH_INFO(data.mTopK <= topk::MaxNumTopK, "Routing kernel expects top K %d to be <= #warps %d",
-            data.mTopK, topk::MaxNumTopK);
-    }
+    // Note: routing-specific constraints (experts per group, topK limits) are checked later,
+    // only when routing is actually needed (data.mPtrTopKIds == nullptr).
     TLLM_CHECK_WITH_INFO(
         data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
     int const numBlocks = data.mNumTokens;
@@ -663,6 +649,25 @@ void run(Data& data, void* stream)
     int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
     if (data.mPtrTopKIds == nullptr)
     {
+        // Routing needs to be executed - validate routing kernel constraints.
+        if (data.mNumExpertGroups > 1)
+        {
+            TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups,
+                "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups);
+            TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
+                "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts,
+                data.mNumExpertGroups);
+            TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
+                "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts "
+                "per group",
+                WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups);
+        }
+        else
+        {
+            TLLM_CHECK_WITH_INFO(data.mTopK <= topk::MaxNumTopK, "Routing kernel expects top K %d to be <= max topk %d",
+                data.mTopK, topk::MaxNumTopK);
+        }
+        int const numThreadsMain = data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts;
         LAUNCH_ROUTING_DEEPSEEK(data, /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain,
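Moving the group/topK checks inside the data.mPtrTopKIds == nullptr branch is what lets EPLB configurations with redundant experts through: when precomputed topk ids are supplied, the routing kernel never launches, so its per-group limits need not hold. A small illustrative sketch of the arithmetic for the 288-expert, 8-group DeepSeekV3 case exercised by the new test below (the helper is ours, not kernel code; the warp size is 32 on NVIDIA GPUs):

WARP_SIZE = 32  # CUDA warp size; the routing kernel requires experts-per-group <= this


def routing_kernel_supported(num_experts: int, num_groups: int) -> bool:
    """Simplified mirror of the guarded TLLM_CHECK constraints above."""
    if num_groups > 1:
        return (num_experts % num_groups == 0
                and num_experts // num_groups <= WARP_SIZE)
    return True


# EPLB with redundant experts: 288 slots in 8 groups -> 36 experts per group.
assert 288 // 8 == 36 and not routing_kernel_supported(288, 8)
# The old top-level check rejected this even when data.mPtrTopKIds was set;
# now it is only enforced when the routing kernel actually has to run.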
diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index e9274e408da..82239d59c86 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -1044,13 +1044,25 @@ def _create_tensor_like(self, origin_tensor: torch.Tensor,
         dtype = origin_tensor.dtype
         device = origin_tensor.device
         shapes = []
-        for d in dims:
+        for i, d in enumerate(dims):
             if isinstance(d, StaticDim):
+                assert d.val == origin_tensor.shape[i]
                 shapes.append(d.val)
             else:
                 # TODO: how to make sure the created Tensor has the min/max info
                 assert isinstance(d, DynamicDim)
                 shapes.append(d.opt)
+
+        if len(dims) == 2 and isinstance(dims[0], DynamicDim) and isinstance(
+                dims[1], StaticDim) and (dtype == torch.int32
+                                         or dtype == torch.int64):
+            # We need to be careful with int values, since they might be indices like topk_index.
+            # We want to keep them valid, so we simply repeat the input tensor.
+            repeat_times = (shapes[0] + origin_tensor.shape[0] -
+                            1) // origin_tensor.shape[0]
+            dup_tensor = origin_tensor.repeat(repeat_times, 1)[:shapes[0]]
+            return dup_tensor
+
         # TODO: FIXME, sometimes the content of the tensor can affect the performance, e.g. for MoE.
         # One solution is to manipulate the tensor content to make it more like the real data
         # during the tuning process. This can be controlled in the preparation phase by the runner.
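The integer branch added above resizes index-like tensors by tiling real rows instead of synthesizing values, so entries such as topk ids stay within the valid expert range. A standalone sketch of the ceil-divide-and-slice pattern (tensor values here are arbitrary):

import torch


def extend_rows_by_repeat(t: torch.Tensor, num_rows: int) -> torch.Tensor:
    """Grow/shrink dim 0 to num_rows by tiling real rows, never inventing values."""
    repeat_times = (num_rows + t.shape[0] - 1) // t.shape[0]  # ceil division
    return t.repeat(repeat_times, 1)[:num_rows]


topk_ids = torch.tensor([[0, 3], [1, 2], [5, 7]], dtype=torch.int32)
extended = extend_rows_by_repeat(topk_ids, 8)
assert extended.shape == (8, 2)
# Every row is a copy of a real row, so ids stay within the valid expert range.
assert set(extended.flatten().tolist()) <= set(topk_ids.flatten().tolist())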
diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
index 11d6f86670b..27de4068bbc 100644
--- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import lru_cache
 from typing import List, Optional, Tuple, Union
 
@@ -13,6 +13,115 @@
                           OptimizationProfile, TunableRunner, TuningConfig)
 
 
+def prepare_dummy_topk_and_hook(
+    topk_weights: Optional[torch.Tensor],
+    topk_ids: Optional[torch.Tensor],
+    hidden_states: torch.Tensor,
+    routing_logits: Optional[torch.Tensor],
+    base_tuning_config: TuningConfig,
+    top_k: int,
+    num_experts: int,
+    local_num_experts: int,
+    hidden_states_index: int = 2,
+) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor, TuningConfig]:
+    """
+    Prepare dummy topk tensors and an input pre-hook for AutoTuner profiling.
+
+    This function handles attention DP scenarios where topk_weights/topk_ids are pre-computed.
+    It creates dummy tensors to prevent the routing kernel from being called during profiling,
+    and provides a hook to dynamically adjust tensor shapes when the AutoTuner tries different
+    token counts.
+
+    Args:
+        topk_weights: Pre-computed topk weights (None for normal routing scenario)
+        topk_ids: Pre-computed topk ids (None for normal routing scenario)
+        hidden_states: Hidden states tensor (used for shape and device)
+        routing_logits: Routing logits (None if not provided)
+        base_tuning_config: Base tuning config to add the hook to
+        top_k: Number of top experts to select
+        num_experts: Total number of experts
+        local_num_experts: Number of local experts
+        hidden_states_index: Index of hidden_states in the input_tensors list (default: 2)
+
+    Returns:
+        Tuple of (routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook)
+    """
+    # Determine if we need dummy topk tensors (attention DP scenario)
+    need_dummy_topk = (topk_weights is not None or topk_ids is not None)
+
+    # Create dummy topk tensors for attention DP scenario
+    if need_dummy_topk:
+        # Attention DP: topk is pre-computed, no routing needed
+        dummy_topk_weights = torch.randn(hidden_states.shape[0],
+                                         top_k,
+                                         dtype=torch.bfloat16,
+                                         device=hidden_states.device)
+        rand_num_for_argsort = torch.rand(hidden_states.shape[0],
+                                          local_num_experts,
+                                          device=hidden_states.device)
+        dummy_topk_ids = rand_num_for_argsort.argsort(dim=1)[:, :top_k].to(
+            torch.int32)
+        topk_weights_for_tuner = dummy_topk_weights
+        topk_ids_for_tuner = dummy_topk_ids
+        # Don't pass routing_logits, to avoid the C++ warning about all three being provided
+        routing_logits_for_tuner = None
+    else:
+        # Normal routing: need routing_logits; topk will be computed by the kernel
+        topk_weights_for_tuner = topk_weights
+        topk_ids_for_tuner = topk_ids
+        # Create dummy routing_logits if None (needed for shape constraints)
+        if routing_logits is None:
+            routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
+                                                   num_experts,
+                                                   dtype=torch.bfloat16,
+                                                   device=hidden_states.device)
+        else:
+            routing_logits_for_tuner = routing_logits
+
+    # Define a hook to recreate dummy tensors when the shape changes during profiling
+    def recreate_dummy_topk_if_needed(
+            inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        """Recreate dummy topk tensors if the token count changed during profiling."""
+        current_num_tokens = inputs[hidden_states_index].shape[0]
+
+        # Only recreate if we originally created dummies
+        if need_dummy_topk:
+            # Check if the shape changed
+            if inputs[-1] is not None and inputs[-1].shape[
+                    0] != current_num_tokens:
+                # Recreate with the new shape
+                inputs[-2] = torch.randn(
+                    current_num_tokens,
+                    top_k,
+                    dtype=torch.bfloat16,
+                    device=inputs[hidden_states_index].device)
+                rand_num_for_argsort = torch.rand(
+                    current_num_tokens,
+                    local_num_experts,
+                    device=inputs[hidden_states_index].device)
+                inputs[-1] = rand_num_for_argsort.argsort(dim=1)[:, :top_k].to(
+                    torch.int32)
+
+            # Note: routing_logits is None in attention DP, no need to adjust
+        else:
+            # Normal routing scenario: adjust routing_logits if it was originally None
+            if routing_logits is None and inputs[0] is not None and inputs[
+                    0].shape[0] != current_num_tokens:
+                inputs[0] = torch.randn(
+                    current_num_tokens,
+                    num_experts,
+                    dtype=torch.bfloat16,
+                    device=inputs[hidden_states_index].device)
+
+        return inputs
+
+    # Add inputs_pre_hook to handle shape changes during profiling
+    tuning_config_with_hook = replace(
+        base_tuning_config, inputs_pre_hook=recreate_dummy_topk_if_needed)
+
+    return routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook
+
+
 def calculate_tile_tokens_dim(
     num_tokens: int,
     num_experts: int,
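One detail worth noting in the helper above: argsort over a row of uniform random numbers yields a random permutation of range(local_num_experts), so taking the first top_k entries gives unique, in-range expert ids per token, which torch.randint would not guarantee. A minimal standalone check (sizes are arbitrary):

import torch

num_tokens, local_num_experts, top_k = 4, 16, 8

# Each row of argsort(rand) is a random permutation of range(local_num_experts),
# so the first top_k entries per token are distinct, in-range expert ids.
rand = torch.rand(num_tokens, local_num_experts)
dummy_topk_ids = rand.argsort(dim=1)[:, :top_k].to(torch.int32)

for row in dummy_topk_ids.tolist():
    assert len(set(row)) == top_k  # no duplicate experts within a token
    assert all(0 <= e < local_num_experts for e in row)  # ids stay in range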
@@ -203,14 +312,23 @@ def _constrain_fp4_linear_layout(shapes: Tuple[torch.Size]) -> int:
 
         ROUTER_LOGITS_IDX = 0
         CONSTRAINED_RL_DIM = 0
+        TOPK_WEIGHTS_IDX = 11
+        TOPK_IDS_IDX = 12
 
         constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
                                                    CONSTRAINED_RL_DIM,
                                                    _constrain_to_num_tokens)
+        constraint_topk_weights = ConstraintSpec(TOPK_WEIGHTS_IDX,
+                                                 CONSTRAINED_RL_DIM,
+                                                 _constrain_to_num_tokens)
+        constraint_topk_ids = ConstraintSpec(TOPK_IDS_IDX, CONSTRAINED_RL_DIM,
+                                             _constrain_to_num_tokens)
 
         constraint_specs_tuple = (
             constraint_hidden_states_scale,
             constraint_routing_logits,
+            constraint_topk_weights,
+            constraint_topk_ids,
         )
 
         return constraint_specs_tuple
 
@@ -268,15 +386,21 @@ def fp4_block_scale_moe_runner(
         do_finalize,
     )
 
-    # Use dummy routing logits for autotuner
-    if routing_logits is None:
-        routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
-                                               num_experts,
-                                               dtype=torch.bfloat16,
-                                               device=hidden_states.device)
-    else:
-        routing_logits_for_tuner = routing_logits
+    # Prepare dummy topk tensors and hook for AutoTuner profiling
+    routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
+        prepare_dummy_topk_and_hook(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            hidden_states=hidden_states,
+            routing_logits=routing_logits,
+            base_tuning_config=FP4BlockScaleMoERunner.get_tuning_config(),
+            top_k=top_k,
+            num_experts=num_experts,
+            local_num_experts=local_num_experts,
+            hidden_states_index=2,
+        )
 
+    # Build input_tensors_for_tuner
     input_tensors_for_tuner = [
         routing_logits_for_tuner,
         routing_bias,
@@ -289,18 +413,35 @@ def fp4_block_scale_moe_runner(
         output1_scale_scalar,
         output1_scale_gate_scalar,
         output2_scale_scalar,
+        topk_weights_for_tuner,  # Dummy if need_dummy_topk, else actual value
+        topk_ids_for_tuner,  # Dummy if need_dummy_topk, else actual value
     ]
 
     kernel_runner, best_tactic = tuner.choose_one(
         "trtllm::fp4_block_scale_moe_runner",
         [kernel_runner],
-        FP4BlockScaleMoERunner.get_tuning_config(),
+        tuning_config_with_hook,
        input_tensors_for_tuner,
    )
 
-    input_tensors = input_tensors_for_tuner + [topk_weights, topk_ids]
-    input_tensors[
-        0] = routing_logits  # replace dummy routing logits with actual routing logits
+    # Final execution: use ACTUAL parameters (not dummies)
+    # topk_weights/topk_ids can be None (routing scenario) or real values (attention DP)
+    input_tensors = [
+        routing_logits,  # Actual value (can be None for attention DP)
+        routing_bias,
+        hidden_states,
+        hidden_states_scale,
+        gemm1_weights,
+        gemm1_weights_scale,
+        gemm2_weights,
+        gemm2_weights_scale,
+        output1_scale_scalar,
+        output1_scale_gate_scalar,
+        output2_scale_scalar,
+        topk_weights,  # Actual value (None for routing, real tensor for attention DP)
+        topk_ids,  # Actual value (None for routing, real tensor for attention DP)
+    ]
+
     return kernel_runner(input_tensors,
                          tactic=[-1, -1] if best_tactic == -1 else best_tactic)
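The new ConstraintSpec entries pin dim 0 of the topk tensors to whatever token count the tuner is currently profiling, keeping them consistent with routing_logits. A toy mimic of that mechanism, using our own stand-in class rather than the tensorrt_llm one:

from dataclasses import dataclass
from typing import Callable, Tuple

import torch


@dataclass(frozen=True)
class ToyConstraintSpec:
    input_idx: int  # which input tensor is constrained
    dim_idx: int  # which dimension of that tensor is overridden
    infer: Callable[[Tuple[torch.Size, ...]], int]  # derives the value from all shapes


def constrain_to_num_tokens(shapes: Tuple[torch.Size, ...]) -> int:
    return shapes[0][0]  # follow dim 0 of routing_logits (input 0)


TOPK_IDS_IDX = 12  # position of topk_ids in the fp4 runner's input list
spec = ToyConstraintSpec(TOPK_IDS_IDX, 0, constrain_to_num_tokens)

shapes = (torch.Size([256, 288]),)  # routing_logits for 256 tokens, 288 experts
assert spec.infer(shapes) == 256  # topk_ids dim 0 is pinned to the same 256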
@@ -516,14 +657,23 @@ def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
 
         ROUTER_LOGITS_IDX = 0
         CONSTRAINED_RL_DIM = 0
+        TOPK_WEIGHTS_IDX = 8
+        TOPK_IDS_IDX = 9
 
         constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
                                                    CONSTRAINED_RL_DIM,
                                                    _constrain_to_num_tokens)
+        constraint_topk_weights = ConstraintSpec(TOPK_WEIGHTS_IDX,
+                                                 CONSTRAINED_RL_DIM,
+                                                 _constrain_to_num_tokens)
+        constraint_topk_ids = ConstraintSpec(TOPK_IDS_IDX, CONSTRAINED_RL_DIM,
+                                             _constrain_to_num_tokens)
 
         constraint_specs_tuple = (
             constraint_hidden_states_scale,
             constraint_routing_logits,
+            constraint_topk_weights,
+            constraint_topk_ids,
         )
 
         return constraint_specs_tuple
 
@@ -578,14 +728,19 @@ def fp8_block_scale_moe_runner(
         )
     ]
 
-    # Use dummy routing logits for autotuner
-    if routing_logits is None:
-        routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
-                                               num_experts,
-                                               dtype=torch.bfloat16,
-                                               device=hidden_states.device)
-    else:
-        routing_logits_for_tuner = routing_logits
+    # Prepare dummy topk tensors and hook for AutoTuner profiling
+    routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
+        prepare_dummy_topk_and_hook(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            hidden_states=hidden_states,
+            routing_logits=routing_logits,
+            base_tuning_config=FP8BlockScaleMoERunner.get_tuning_config(),
+            top_k=top_k,
+            num_experts=num_experts,
+            local_num_experts=local_num_experts,
+            hidden_states_index=2,
+        )
 
     input_tensors_for_tuner = [
         routing_logits_for_tuner,
@@ -596,18 +751,22 @@ def fp8_block_scale_moe_runner(
         gemm1_weights_scale,
         gemm2_weights,
         gemm2_weights_scale,
+        topk_weights_for_tuner,
+        topk_ids_for_tuner,
     ]
 
     kernel_runner, best_tactic = tuner.choose_one(
         "trtllm::fp8_block_scale_moe_runner",
         kernel_runners,
-        FP8BlockScaleMoERunner.get_tuning_config(),
+        tuning_config_with_hook,
         input_tensors_for_tuner,
     )
 
-    input_tensors = input_tensors_for_tuner + [topk_weights, topk_ids]
+    input_tensors = input_tensors_for_tuner
     input_tensors[
         0] = routing_logits  # replace dummy routing logits with actual routing logits
+    input_tensors[-2] = topk_weights  # replace dummy topk_weights with actual
+    input_tensors[-1] = topk_ids  # replace dummy topk_ids with actual
 
     return kernel_runner(input_tensors,
                          tactic=[-1, -1] if best_tactic == -1 else best_tactic)
@@ -802,13 +961,21 @@ def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int:
 
         ROUTER_LOGITS_IDX = 0
         CONSTRAINED_RL_DIM = 0
+        TOPK_WEIGHTS_IDX = 13
+        TOPK_IDS_IDX = 14
 
         constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
                                                    CONSTRAINED_RL_DIM,
                                                    _constrain_routing_logits)
+        constraint_topk_weights = ConstraintSpec(TOPK_WEIGHTS_IDX,
+                                                 CONSTRAINED_RL_DIM,
+                                                 _constrain_routing_logits)
+        constraint_topk_ids = ConstraintSpec(TOPK_IDS_IDX, CONSTRAINED_RL_DIM,
+                                             _constrain_routing_logits)
 
         constraint_specs_tuple = (constraint_hidden_states_scale,
-                                  constraint_routing_logits)
+                                  constraint_routing_logits,
+                                  constraint_topk_weights, constraint_topk_ids)
 
         return constraint_specs_tuple
 
@@ -871,14 +1038,19 @@ def mxe4m3_mxe2m1_block_scale_moe_runner(
         act_type,
     )
 
-    # Use dummy routing logits for autotuner
-    if routing_logits is None:
-        routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
-                                               num_experts,
-                                               dtype=torch.bfloat16,
-                                               device=hidden_states.device)
-    else:
-        routing_logits_for_tuner = routing_logits
+    # Prepare dummy topk tensors and hook for AutoTuner profiling
+    routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
+        prepare_dummy_topk_and_hook(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            hidden_states=hidden_states,
+            routing_logits=routing_logits,
+            base_tuning_config=MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
+            top_k=top_k,
+            num_experts=num_experts,
+            local_num_experts=local_num_experts,
+            hidden_states_index=2,
+        )
 
     input_tensors_for_tuner = [
         routing_logits_for_tuner,
@@ -894,18 +1066,22 @@ def mxe4m3_mxe2m1_block_scale_moe_runner(
         gemm2_weights,
         gemm2_weights_scale,
         gemm2_bias,
+        topk_weights_for_tuner,
+        topk_ids_for_tuner,
     ]
 
     kernel_runner, best_tactic = tuner.choose_one(
         "trtllm::mxe4m3_mxe2m1_block_scale_moe_runner",
         [kernel_runner],
-        MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
+        tuning_config_with_hook,
         input_tensors_for_tuner,
     )
 
-    input_tensors = input_tensors_for_tuner + [topk_weights, topk_ids]
+    input_tensors = input_tensors_for_tuner
     input_tensors[
         0] = routing_logits  # replace dummy routing logits with actual routing logits
+    input_tensors[-2] = topk_weights  # replace dummy topk_weights with actual
+    input_tensors[-1] = topk_ids  # replace dummy topk_ids with actual
 
     return kernel_runner(input_tensors,
                          tactic=[-1, -1] if best_tactic == -1 else best_tactic,
                          output=output)
@@ -1054,12 +1230,20 @@ def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int:
 
         ROUTER_LOGITS_IDX = 0
         CONSTRAINED_RL_DIM = 0
+        TOPK_WEIGHTS_IDX = 15
+        TOPK_IDS_IDX = 16
 
         constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
                                                    CONSTRAINED_RL_DIM,
                                                    _constrain_routing_logits)
+        constraint_topk_weights = ConstraintSpec(TOPK_WEIGHTS_IDX,
+                                                 CONSTRAINED_RL_DIM,
+                                                 _constrain_routing_logits)
+        constraint_topk_ids = ConstraintSpec(TOPK_IDS_IDX, CONSTRAINED_RL_DIM,
+                                             _constrain_routing_logits)
 
-        constraint_specs_tuple = (constraint_routing_logits, )
+        constraint_specs_tuple = (constraint_routing_logits,
+                                  constraint_topk_weights, constraint_topk_ids)
 
         return constraint_specs_tuple
 
@@ -1121,14 +1305,19 @@ def e4m3_mxe2m1_block_scale_moe_runner(
         act_type,
     )
 
-    # Use dummy routing logits for autotuner
-    if routing_logits is None:
-        routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
-                                               num_experts,
-                                               dtype=torch.bfloat16,
-                                               device=hidden_states.device)
-    else:
-        routing_logits_for_tuner = routing_logits
+    # Prepare dummy topk tensors and hook for AutoTuner profiling
+    routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
+        prepare_dummy_topk_and_hook(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            hidden_states=hidden_states,
+            routing_logits=routing_logits,
+            base_tuning_config=E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
+            top_k=top_k,
+            num_experts=num_experts,
+            local_num_experts=local_num_experts,
+            hidden_states_index=2,
+        )
 
     input_tensors_for_tuner = [
         routing_logits_for_tuner,
@@ -1146,19 +1335,23 @@ def e4m3_mxe2m1_block_scale_moe_runner(
         output1_scale_scalar,
         output1_scale_gate_scalar,
         output2_scale_scalar,
+        topk_weights_for_tuner,
+        topk_ids_for_tuner,
     ]
 
     kernel_runner, best_tactic = tuner.choose_one(
         "trtllm::e4m3_mxe2m1_block_scale_moe_runner",
         [kernel_runner],
-        E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
+        tuning_config_with_hook,
         input_tensors_for_tuner,
     )
 
-    # Add topk tensors for final execution
-    input_tensors = input_tensors_for_tuner + [topk_weights, topk_ids]
+    # Replace dummy tensors with actual ones for final execution
+    input_tensors = input_tensors_for_tuner
     input_tensors[
         0] = routing_logits  # replace dummy routing logits with actual routing logits
+    input_tensors[-2] = topk_weights  # replace dummy topk_weights with actual
+    input_tensors[-1] = topk_ids  # replace dummy topk_ids with actual
 
     return kernel_runner(input_tensors,
                          tactic=[-1, -1] if best_tactic == -1 else best_tactic)
@@ -1301,12 +1494,20 @@ def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int:
 
         ROUTER_LOGITS_IDX = 0
         CONSTRAINED_DIM = 0
+        TOPK_WEIGHTS_IDX = 12
+        TOPK_IDS_IDX = 13
 
         constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
                                                    CONSTRAINED_DIM,
                                                    _constrain_routing_logits)
+        constraint_topk_weights = ConstraintSpec(TOPK_WEIGHTS_IDX,
+                                                 CONSTRAINED_DIM,
+                                                 _constrain_routing_logits)
+        constraint_topk_ids = ConstraintSpec(TOPK_IDS_IDX, CONSTRAINED_DIM,
+                                             _constrain_routing_logits)
 
-        constraint_specs_tuple = (constraint_routing_logits, )
+        constraint_specs_tuple = (constraint_routing_logits,
+                                  constraint_topk_weights, constraint_topk_ids)
 
         return constraint_specs_tuple
 
@@ -1365,14 +1566,19 @@ def bf16_mxe2m1_block_scale_moe_runner(
         act_type,
     )
 
-    # Use dummy routing logits for autotuner
-    if routing_logits is None:
-        routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
-                                               num_experts,
-                                               dtype=torch.bfloat16,
-                                               device=hidden_states.device)
-    else:
-        routing_logits_for_tuner = routing_logits
+    # Prepare dummy topk tensors and hook for AutoTuner profiling
+    routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
+        prepare_dummy_topk_and_hook(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            hidden_states=hidden_states,
+            routing_logits=routing_logits,
+            base_tuning_config=Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(),
+            top_k=top_k,
+            num_experts=num_experts,
+            local_num_experts=local_num_experts,
+            hidden_states_index=2,
+        )
 
     input_tensors_for_tuner = [
         routing_logits_for_tuner,
@@ -1387,20 +1593,24 @@ def bf16_mxe2m1_block_scale_moe_runner(
         gemm2_weights,
         gemm2_weights_scale,
         gemm2_bias,
+        topk_weights_for_tuner,
+        topk_ids_for_tuner,
     ]
 
     # Choose best tactic using autotuner
     kernel_runner, best_tactic = tuner.choose_one(
         "trtllm::bf16_mxe2m1_block_scale_moe_runner",
         [kernel_runner],
-        Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(),
+        tuning_config_with_hook,
         input_tensors_for_tuner,
     )
 
-    # Add topk tensors for final execution
-    input_tensors = input_tensors_for_tuner + [topk_weights, topk_ids]
+    # Replace dummy tensors with actual ones for final execution
+    input_tensors = input_tensors_for_tuner
     input_tensors[
         0] = routing_logits  # replace dummy routing logits with actual routing logits
+    input_tensors[-2] = topk_weights  # replace dummy topk_weights with actual
+    input_tensors[-1] = topk_ids  # replace dummy topk_ids with actual
 
     return kernel_runner(input_tensors,
                          tactic=[-1, -1] if best_tactic == -1 else best_tactic)
 
@@ -1548,12 +1758,20 @@ def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
 
         ROUTER_LOGITS_IDX = 0
         CONSTRAINED_RL_DIM = 0
+        TOPK_WEIGHTS_IDX = 10
+        TOPK_IDS_IDX = 11
 
         constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
                                                    CONSTRAINED_RL_DIM,
                                                    _constrain_to_num_tokens)
+        constraint_topk_weights = ConstraintSpec(TOPK_WEIGHTS_IDX,
+                                                 CONSTRAINED_RL_DIM,
+                                                 _constrain_to_num_tokens)
+        constraint_topk_ids = ConstraintSpec(TOPK_IDS_IDX, CONSTRAINED_RL_DIM,
+                                             _constrain_to_num_tokens)
 
-        constraint_specs_tuple = (constraint_routing_logits, )
+        constraint_specs_tuple = (constraint_routing_logits,
+                                  constraint_topk_weights, constraint_topk_ids)
 
         return constraint_specs_tuple
 
@@ -1612,14 +1830,19 @@ def fp8_fp4_block_scale_moe_runner(
         act_type,
     )
 
-    # Use dummy routing logits for autotuner
-    if routing_logits is None:
-        routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
-                                               num_experts,
-                                               dtype=torch.bfloat16,
-                                               device=hidden_states.device)
-    else:
-        routing_logits_for_tuner = routing_logits
+    # Prepare dummy topk tensors and hook for AutoTuner profiling
+    routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
+        prepare_dummy_topk_and_hook(
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            hidden_states=hidden_states,
+            routing_logits=routing_logits,
+            base_tuning_config=FP8FP4BlockScaleMoERunner.get_tuning_config(),
+            top_k=top_k,
+            num_experts=num_experts,
+            local_num_experts=local_num_experts,
+            hidden_states_index=2,
+        )
 
     input_tensors_for_tuner = [
         routing_logits_for_tuner,
@@ -1632,18 +1855,23 @@ def fp8_fp4_block_scale_moe_runner(
         output1_scale_scalar,
         output1_scale_gate_scalar,
         output2_scale_scalar,
+        topk_weights_for_tuner,
+        topk_ids_for_tuner,
     ]
 
     kernel_runner, best_tactic = tuner.choose_one(
         "trtllm::fp8_fp4_block_scale_moe_runner",
         [kernel_runner],
-        FP8FP4BlockScaleMoERunner.get_tuning_config(),
+        tuning_config_with_hook,
         input_tensors_for_tuner,
    )
 
-    input_tensors = input_tensors_for_tuner + [topk_weights, topk_ids]
-    # replace dummy routing logits with actual routing logits
-    input_tensors[0] = routing_logits
+    # Replace dummy tensors with actual ones for final execution
+    input_tensors = input_tensors_for_tuner
+    input_tensors[
+        0] = routing_logits  # replace dummy routing logits with actual routing logits
+    input_tensors[-2] = topk_weights  # replace dummy topk_weights with actual
+    input_tensors[-1] = topk_ids  # replace dummy topk_ids with actual
 
     return kernel_runner(input_tensors,
                          tactic=[-1, -1] if best_tactic == -1 else best_tactic)
diff --git a/tests/unittest/_torch/thop/parallel/test_moe.py b/tests/unittest/_torch/thop/parallel/test_moe.py
index 3daa157ecdf..2b57234bb8b 100644
--- a/tests/unittest/_torch/thop/parallel/test_moe.py
+++ b/tests/unittest/_torch/thop/parallel/test_moe.py
@@ -1178,6 +1178,36 @@ def test_no_autotune_fp8_fp4(self, num_tokens, hidden_size,
                              use_autotune=False,
                              use_topk_as_input=use_topk_as_input)
 
+    @pytest.mark.parametrize("num_tokens", [1, 256, 1024])
+    @pytest.mark.parametrize("hidden_size", [1024])
+    @pytest.mark.parametrize("intermediate_size", [1024, 768])
+    @pytest.mark.parametrize(
+        "routing_info",
+        [
+            pytest.param(
+                {
+                    "num_experts": 288,
+                    "top_k": 8,
+                    "padding": 8,
+                    "n_groups": 8,
+                    "top_k_groups": 4,
+                    "routed_scaling": 2.5,
+                    "has_routing_bias": True,
+                    "routing_method_type": RoutingMethodType.DeepSeekV3
+                },
+                id="RoutingDSv3"),
+        ],
+    )
+    def test_online_eplb288_topk_input(self, num_tokens, hidden_size,
+                                       intermediate_size, routing_info):
+        # We don't need to run the router with num_experts=288, but MoE must run with num_slots=288 for EPLB with redundant experts.
+        self.run_moe_fp4_test(num_tokens,
+                              hidden_size,
+                              intermediate_size,
+                              routing_info,
+                              use_autotune=True,
+                              use_topk_as_input=True)
+
     def run_moe_fp4_test(self, num_tokens: int, hidden_size: int,
                          intermediate_size: int, routing_info: dict,
                          use_autotune: bool, use_topk_as_input: bool) -> None:
@@ -1209,9 +1239,9 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int,
         assert num_experts % 4 == 0
 
         if use_topk_as_input:
-            if routing_method_type != RoutingMethodType.DeepSeekV3 or num_tokens != 150 or use_autotune:
+            if routing_method_type != RoutingMethodType.DeepSeekV3:
                 pytest.skip(
-                    "use_topk_as_input is tested only with routing_method_type=DeepSeekV3 and num_tokens=150 and use_autotune=False"
+                    "use_topk_as_input is tested only with routing_method_type=DeepSeekV3"
                 )
 
         if are_groups_valid(top_k_groups, n_groups):
@@ -1469,9 +1499,9 @@ def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int,
         assert top_k < (top_k_groups * num_experts / n_groups)
 
         if use_topk_as_input:
-            if routing_method_type != RoutingMethodType.DeepSeekV3 or num_tokens != 150 or use_autotune:
+            if routing_method_type != RoutingMethodType.DeepSeekV3:
                 pytest.skip(
-                    "use_topk_as_input is tested only with routing_method_type=DeepSeekV3 and num_tokens=150 and use_autotune=False"
+                    "use_topk_as_input is tested only with routing_method_type=DeepSeekV3"
                )
 
         if routing_method_type == RoutingMethodType.DeepSeekV3:
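Taken together, each runner now follows the same pattern: profile with shape-consistent dummy inputs, then execute with the caller's real (possibly None) tensors. A condensed, framework-free sketch of that control flow; every name below is an illustrative stand-in, not the tensorrt_llm API:

from typing import Callable, List, Optional

import torch


def tune_then_run(
    run_kernel: Callable[[List[Optional[torch.Tensor]]], torch.Tensor],
    profile_token_counts: List[int],
    make_dummies: Callable[[int], List[Optional[torch.Tensor]]],
    real_inputs: List[Optional[torch.Tensor]],
) -> torch.Tensor:
    # Profiling phase: every candidate token count gets freshly built,
    # mutually consistent dummy inputs (the role of inputs_pre_hook above).
    for num_tokens in profile_token_counts:
        run_kernel(make_dummies(num_tokens))  # timing/tactic selection elided
    # Execution phase: the dummies are discarded and the caller's real tensors
    # (which may legitimately be None) are what the kernel finally sees.
    return run_kernel(real_inputs)


# Toy demo: the "kernel" just gathers rows by topk ids and reduces them.
hidden = torch.randn(6, 4)
topk_ids = torch.randint(0, 6, (6, 2))
kernel = lambda inputs: inputs[0][inputs[1].long()].sum(dim=1)
out = tune_then_run(
    kernel,
    profile_token_counts=[1, 8],
    make_dummies=lambda n: [torch.randn(n, 4),
                            torch.randint(0, n, (n, 2))],
    real_inputs=[hidden, topk_ids],
)
assert out.shape == (6, 4)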