From 9c5783dcca267fda8e29a6d2fc60e74bf0559e0b Mon Sep 17 00:00:00 2001 From: Julien Lin Date: Wed, 5 Nov 2025 06:47:11 +0000 Subject: [PATCH] use maximum number of batched tokens to autotune Signed-off-by: Julien Lin --- vllm/model_executor/layers/fused_moe/trtllm_moe.py | 6 +++--- vllm/model_executor/layers/quantization/mxfp4.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 132d35e65aba..b641c44829b0 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -21,14 +21,14 @@ def __init__( gemm1_alpha, gemm1_beta, gemm1_clamp_limit, - max_capture_size, + tune_max_num_tokens, ): super().__init__(quant_config) self.moe = moe self.gemm1_alpha = gemm1_alpha self.gemm1_beta = gemm1_beta self.gemm1_clamp_limit = gemm1_clamp_limit - self.max_capture_size = max_capture_size + self.tune_max_num_tokens = tune_max_num_tokens @property def activation_formats( @@ -127,7 +127,7 @@ def apply( "routing_method_type": 1, "do_finalize": True, "output": output, - "tune_max_num_tokens": max(self.max_capture_size, 1), + "tune_max_num_tokens": self.tune_max_num_tokens, } from flashinfer import trtllm_fp4_block_scale_routed_moe diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index b95d1a6b3a1f..b033b9b44de6 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -217,8 +217,10 @@ def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN - self.max_capture_size = ( - get_current_vllm_config().compilation_config.max_cudagraph_capture_size + # Be conservative and tune for the most extreme imbalance for MoE, + # i.e., one expert receives 
all the tokens. + self.tune_max_num_tokens = ( + get_current_vllm_config().scheduler_config.max_num_batched_tokens ) assert self.mxfp4_backend != Mxfp4Backend.NONE, ( @@ -842,7 +844,7 @@ def select_gemm_impl( "gemm1_beta": layer.gemm1_beta, "gemm1_clamp_limit": layer.gemm1_clamp_limit, # TODO(bnell): part of quant_config - "max_capture_size": self.max_capture_size, + "tune_max_num_tokens": self.tune_max_num_tokens, } return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) elif self.mxfp4_backend == Mxfp4Backend.MARLIN: @@ -978,7 +980,7 @@ def apply( None, 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize - tune_max_num_tokens=max(self.max_capture_size, 1), + tune_max_num_tokens=self.tune_max_num_tokens, )[0] return trtllm_gen_output elif ( @@ -1053,7 +1055,7 @@ def apply( tp_rank=self.moe.tp_rank, ep_size=self.moe.ep_size, ep_rank=self.moe.ep_rank, - tune_max_num_tokens=max(self.max_capture_size, 1), + tune_max_num_tokens=self.tune_max_num_tokens, **extra_kwargs, )