From 3359c4dd9fc5f1849da06e9021b13fc93a477670 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 31 Oct 2025 20:11:07 -0400 Subject: [PATCH 1/3] restore autotuning functionality Signed-off-by: Varun Sundar Rabindranath --- tests/quantization/test_blackwell_moe.py | 16 ++--------- .../layers/fused_moe/trtllm_moe.py | 2 +- .../layers/quantization/mxfp4.py | 4 +-- vllm/model_executor/warmup/kernel_warmup.py | 27 +------------------ 4 files changed, 6 insertions(+), 43 deletions(-) diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py index 3cae6f46147b..8dd4551ff4b9 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -172,21 +172,9 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT) -def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") - monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") - can_initialize( - "openai/gpt-oss-20b", - extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"], - hf_overrides=HF_OVERRIDE_TEXT, - ) - - -def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1") - monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") +def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch): can_initialize( "openai/gpt-oss-20b", - extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"], hf_overrides=HF_OVERRIDE_TEXT, + extra_args=["--enforce-eager"], ) diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index e305483eb17d..c35c18be4bbb 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -127,7 +127,7 @@ def apply( "routing_method_type": 1, "do_finalize": True, "output": output, - "tune_max_num_tokens": self.max_capture_size, + "tune_max_num_tokens": max(self.max_capture_size, 1), } from flashinfer import trtllm_fp4_block_scale_routed_moe diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 597ee1b6bafe..bf34ec0f3899 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1047,7 +1047,7 @@ def apply( None, 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize - tune_max_num_tokens=self.max_capture_size, + tune_max_num_tokens=max(self.max_capture_size, 1), )[0] return trtllm_gen_output elif ( @@ -1122,7 +1122,7 @@ def apply( tp_rank=self.moe.tp_rank, ep_size=self.moe.ep_size, ep_rank=self.moe.ep_rank, - tune_max_num_tokens=self.max_capture_size, + tune_max_num_tokens=max(self.max_capture_size, 1), **extra_kwargs, ) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index ffa3bc8f021e..28792338f036 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -11,7 +11,6 @@ import torch import vllm.envs as envs -from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup from vllm.platforms import current_platform @@ -25,26 +24,6 @@ logger = init_logger(__name__) -def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool: - """ - Record known issues with vllm + flashinfer autotune here. Return True if - and only if flashinfer autotune will run through without issues. - """ - is_tp_or_dp = (vllm_config.parallel_config.data_parallel_size > 1) or ( - vllm_config.parallel_config.tensor_parallel_size > 1 - ) - is_fi_mxfp4_backend = ( - envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS - ) or ( - current_platform.is_cuda() and current_platform.is_device_capability(100) - ) # on >=sm100, default mxfp4 backend is flashinfer - is_eager = vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE - - return not (is_tp_or_dp and is_fi_mxfp4_backend and is_eager) - - def kernel_warmup(worker: "Worker"): # Deep GEMM warmup do_deep_gemm_warmup = ( @@ -58,11 +37,7 @@ def kernel_warmup(worker: "Worker"): deep_gemm_warmup(model, max_tokens) # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs - if ( - has_flashinfer() - and current_platform.has_device_capability(90) - and flashinfer_autotune_supported(worker.vllm_config) - ): + if has_flashinfer() and current_platform.has_device_capability(90): flashinfer_autotune(worker.model_runner) # FlashInfer attention warmup From de805928bbc947d31ad35bc52d1ac36fac16d1a4 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 1 Nov 2025 12:22:27 -0400 Subject: [PATCH 2/3] fix dp Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/trtllm_moe.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index c35c18be4bbb..6baa1cbded46 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -132,5 +132,11 @@ def apply( from flashinfer import trtllm_fp4_block_scale_routed_moe - trtllm_fp4_block_scale_routed_moe(**kwargs) + from vllm.utils.flashinfer import autotune + + with autotune(False): + # Skipping flashinfer auto-tune for this function as the autotuner + # throws a "Cannot pack tensors on meta" error. + trtllm_fp4_block_scale_routed_moe(**kwargs) + return output From 0fec4d3c71152ab3d344b7d9c1298ff14db73167 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 1 Nov 2025 13:08:26 -0400 Subject: [PATCH 3/3] Add fi issue Signed-off-by: Varun Sundar Rabindranath --- vllm/model_executor/layers/fused_moe/trtllm_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 6baa1cbded46..132d35e65aba 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -135,8 +135,9 @@ def apply( from vllm.utils.flashinfer import autotune with autotune(False): - # Skipping flashinfer auto-tune for this function as the autotuner - # throws a "Cannot pack tensors on meta" error. + # Enable autotune when, + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is + # resolved. trtllm_fp4_block_scale_routed_moe(**kwargs) return output