
[Bugfix] Add autotuning guard to all unprotected FlashInfer MoE kernels #37091

Open

haosdent wants to merge 1 commit into vllm-project:main from haosdent:fix-36999

Conversation

@haosdent (Contributor) commented Mar 15, 2026

Purpose

Fixes #36999 - CPU weight offloading produces garbage output when the FlashInfer autotuner is enabled on Blackwell GPUs.

Root Cause

During FlashInfer autotuning, kernel_warmup.py:flashinfer_autotune() sets _is_fi_autotuning = True and runs a dummy forward pass of the model, which triggers MoE kernel calls. Certain FlashInfer MoE kernels are incompatible with FlashInfer's autotuning mechanism (upstream FlashInfer bug flashinfer-ai/flashinfer#2023); calling one of them during the dummy pass corrupts CUDA state, producing garbage output or crashes during subsequent inference.

The modular kernel implementations (TrtLlmNvFp4ExpertsModular in #32564, TrtLlmFp8ExpertsModular in #36307) already have the _is_fi_autotuning guard. However, the monolithic counterparts and other FlashInfer MoE paths were missing protection.

For single-GPU Kimi K2.5-NVFP4 (the original reporter's config), the kernel oracle selects TrtLlmNvFp4ExpertsMonolithic (monolithic preferred over modular when no EP/EPLB), which calls flashinfer.fused_moe.trtllm_fp4_block_scale_moe() during autotuning without protection.

Fix

Wrap all unprotected FlashInfer MoE kernel calls in `with autotune(False):`, following the existing pattern in trtllm_moe.py:178. This tells FlashInfer not to autotune these specific kernel calls while still letting them execute normally, avoiding the shape/dtype mismatches that the previous dummy-return approach caused.

| File | Kernel(s) wrapped |
| --- | --- |
| experts/trtllm_nvfp4_moe.py | trtllm_fp4_block_scale_moe |
| experts/trtllm_fp8_moe.py | trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe |
| flashinfer_trtllm_moe.py | flashinfer_trtllm_bf16_moe |
| flashinfer_cutlass_moe.py | flashinfer_cutlass_fused_moe |
| flashinfer_cutedsl_moe.py | flashinfer_cutedsl_moe_masked |
| quantization/utils/flashinfer_mxint4_moe.py | trtllm_mxint4_block_scale_moe |

Test Plan

  • E2E reproduction on Blackwell GPU: Verified on NVIDIA GB10 (SM121) with nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4 (FLASHINFER_CUTLASS backend)
  • Verify existing MoE tests pass: pytest tests/kernels/moe/ -v -s
  • Verify linting: pre-commit run --all-files

Test Result

E2E Verification (NVIDIA GB10, SM121)

WITHOUT the fix (guard removed from flashinfer_cutlass_moe.py:FlashInferExperts.apply()):

  • Served nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4 (NVFP4 MoE, FLASHINFER_CUTLASS backend)
  • FlashInfer autotuner called the CUTLASS MoE kernel during warmup
  • Result: CUDA error: an illegal instruction was encountered — server crashed
  • Stack trace shows the crash in dispatchMoeGemmSelectTileShapeTmaWarpSpecialized during FlashInfer autotuning

WITH the fix (autotune(False) wrapping the kernel call):

  • Same model, same configuration
  • Server starts successfully, kernel runs normally with autotuning disabled
  • Result: correct output "The capital of France is Paris." for the prompt "What is the capital of France?"

mergify bot added the nvidia and bug (Something isn't working) labels on Mar 15, 2026
@gemini-code-assist bot left a comment

Code Review

This pull request correctly adds autotuning guards to several FlashInfer MoE kernels to prevent state corruption during the autotuning dummy pass. The fix is applied consistently across all identified unprotected kernel paths. My main feedback is regarding the code duplication introduced by adding the same guard logic in six different files. I've suggested refactoring this logic into a centralized helper function to improve maintainability and adhere to the DRY principle. This would make the codebase more robust to future changes in the autotuning mechanism.

Comment on lines +411 to +416
# trtllm_fp8 monolithic kernels do not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
    return torch.zeros_like(hidden_states)
Severity: high

This autotuning guard logic is duplicated in 6 different files in this pull request. To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, this logic could be centralized into a helper function.

For example, you could add a helper in vllm/utils/flashinfer.py:

import torch

def skip_if_autotuning(output_tensor_for_shape=None):
    """If autotuning, returns (True, dummy_output). Otherwise (False, None)."""
    if _is_fi_autotuning:
        if output_tensor_for_shape is None:
            return True, None
        return True, torch.zeros_like(output_tensor_for_shape)
    return False, None

Then, this apply method could be simplified to:

import vllm.utils.flashinfer as fi_utils

should_skip, retval = fi_utils.skip_if_autotuning(hidden_states)
if should_skip:
    return retval

This would make the code cleaner and easier to manage if the autotuning check logic changes in the future. This suggestion applies to all files changed in this PR.

Comment on lines +274 to +279
# trtllm_fp4_block_scale_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
    return torch.zeros_like(hidden_states)
Severity: high

Similar to other files in this PR, this autotuning guard introduces code duplication. To enhance maintainability, this logic could be refactored into a shared helper function within vllm.utils.flashinfer. Centralizing the check for _is_fi_autotuning and the creation of a dummy return value would make the codebase more robust to future changes in the autotuning mechanism.

Comment on lines +143 to +148
# flashinfer CuteDSL MoE does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
    return
Severity: high

This autotuning guard is repeated across multiple locations. To avoid this duplication, consider creating a single helper function in vllm.utils.flashinfer. This function could handle the check and return logic. For this specific case, since the method returns None, a centralized helper could be designed to handle this gracefully (e.g., by being called with an argument indicating no return value is needed). This would centralize the logic and improve code quality.

Comment on lines +257 to +262
# flashinfer cutlass MoE does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
    return
Severity: high

The logic to guard against autotuning is duplicated here and in other files. A refactoring to a common helper function in vllm.utils.flashinfer would be beneficial for long-term maintenance. This would ensure that any future modifications to the autotuning guard only need to be made in one place. Since this apply method has no return value, the helper could be designed to handle this case gracefully.

Comment on lines +97 to +102
# flashinfer bf16 monolithic MoE does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
    return torch.zeros_like(hidden_states)
Severity: high

This change introduces the same autotuning guard logic seen in other files in this PR. To follow the DRY principle, it would be better to abstract this logic into a reusable helper function located in vllm.utils.flashinfer. This would consolidate the autotuning check and make the overall implementation cleaner and more maintainable.

Comment on lines +194 to +199
# flashinfer mxint4 monolithic MoE does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
    return torch.zeros_like(x)
Severity: high

The addition of this autotuning guard results in duplicated code across several files. I recommend refactoring this logic into a centralized helper function in vllm.utils.flashinfer. This function would encapsulate the check for _is_fi_autotuning and the logic for returning a correctly shaped zero tensor (in this case, based on the x tensor). This would improve code maintainability.

haosdent changed the title from "[WIP][Bugfix] Add autotuning guard to all unprotected FlashInfer MoE kernels" to "[Bugfix] Add autotuning guard to all unprotected FlashInfer MoE kernels" on Mar 15, 2026
haosdent marked this pull request as ready for review on March 15, 2026 10:23

wzhao18 commented Mar 16, 2026

Hi @haosdent, thanks for the fix. I got the following error when trying out the fix on the offloading use case. Could you take a look?

 [core.py:1088]   File "/vllm/vllm/v1/engine/core.py", line 1078, in run_engine_core
 [core.py:1088]     engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
 [core.py:1088]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
 [core.py:1088]     return func(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/v1/engine/core.py", line 830, in __init__
 [core.py:1088]     super().__init__(
 [core.py:1088]   File "/vllm/vllm/v1/engine/core.py", line 120, in __init__
 [core.py:1088]     kv_cache_config = self._initialize_kv_caches(vllm_config)
 [core.py:1088]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
 [core.py:1088]     return func(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/v1/engine/core.py", line 276, in _initialize_kv_caches
 [core.py:1088]     self.model_executor.initialize_from_config(kv_cache_configs)
 [core.py:1088]   File "/vllm/vllm/v1/executor/abstract.py", line 118, in initialize_from_config
 [core.py:1088]     compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
 [core.py:1088]                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/v1/executor/uniproc_executor.py", line 78, in collective_rpc
 [core.py:1088]     result = run_method(self.driver_worker, method, args, kwargs)
 [core.py:1088]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/v1/serial_utils.py", line 459, in run_method
 [core.py:1088]     return func(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
 [core.py:1088]     return func(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/v1/worker/gpu_worker.py", line 601, in compile_or_warm_up_model
 [core.py:1088]     kernel_warmup(self)
 [core.py:1088]   File "/vllm/vllm/model_executor/warmup/kernel_warmup.py", line 46, in kernel_warmup
 [core.py:1088]     flashinfer_autotune(worker.model_runner)
 [core.py:1088]   File "/vllm/vllm/model_executor/warmup/kernel_warmup.py", line 103, in flashinfer_autotune
 [core.py:1088]     runner._dummy_run(
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
 [core.py:1088]     return func(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/v1/worker/gpu_model_runner.py", line 5228, in _dummy_run
 [core.py:1088]     outputs = self.model(
 [core.py:1088]               ^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/compilation/cuda_graph.py", line 241, in __call__
 [core.py:1088]     return self.runnable(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
 [core.py:1088]     return self._call_impl(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
 [core.py:1088]     return forward_call(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/model_executor/models/kimi_k25.py", line 477, in forward
 [core.py:1088]     hidden_states = self.language_model(
 [core.py:1088]                     ^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
 [core.py:1088]     return self._call_impl(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
 [core.py:1088]     return forward_call(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/model_executor/models/deepseek_v2.py", line 1386, in forward
 [core.py:1088]     hidden_states = self.model(
 [core.py:1088]                     ^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/compilation/decorators.py", line 452, in __call__
 [core.py:1088]     return self.aot_compiled_fn(self, *args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/aot_compile.py", line 240, in __call__
 [core.py:1088]     return self.fn(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/model_executor/models/deepseek_v2.py", line 1189, in forward
 [core.py:1088]     def forward(
 [core.py:1088]   File "/vllm/vllm/compilation/caching.py", line 206, in __call__
 [core.py:1088]     return self.optimized_call(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 949, in call_wrapped
 [core.py:1088]     return self._wrapped_call(self, *args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 461, in __call__
 [core.py:1088]     raise e
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 447, in __call__
 [core.py:1088]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
 [core.py:1088]     return self._call_impl(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
 [core.py:1088]     return forward_call(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "<eval_with_key>.194", line 581, in forward
 [core.py:1088]     submod_8 = self.submod_8(getitem_8, s59, l_self_modules_layers_modules_1_modules_self_attn_modules_mla_attn_modules_o_proj_parameters_weight_, l_self_modules_layers_modules_1_modules_post_attention_layernorm_parameters_weight_, getitem_9, synthetic_local_tmp_0_, l_self_modules_layers_modules_2_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_fused_qkv_a_proj_parameters_weight_, l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_q_a_layernorm_parameters_weight_, l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_q_b_proj_parameters_weight_, l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_kv_a_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_mla_attn_modules_rotary_emb_buffers_cos_sin_cache_, l_positions_);  getitem_8 = l_self_modules_layers_modules_1_modules_self_attn_modules_mla_attn_modules_o_proj_parameters_weight_ = l_self_modules_layers_modules_1_modules_post_attention_layernorm_parameters_weight_ = getitem_9 = synthetic_local_tmp_0_ = l_self_modules_layers_modules_2_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_fused_qkv_a_proj_parameters_weight_ = l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_q_a_layernorm_parameters_weight_ = l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_q_b_proj_parameters_weight_ = l_self_modules_layers_modules_2_modules_self_attn_modules_mla_attn_modules_kv_a_layernorm_parameters_weight_ = None
 [core.py:1088]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/compilation/cuda_graph.py", line 241, in __call__
 [core.py:1088]     return self.runnable(*args, **kwargs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/compilation/piecewise_backend.py", line 367, in __call__
 [core.py:1088]     return range_entry.runnable(*args)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/vllm/compilation/compiler_interface.py", line 445, in compiled_graph_wrapper
 [core.py:1088]     graph_output = inductor_compiled_graph(*args)
 [core.py:1088]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/standalone_compile.py", line 122, in __call__
 [core.py:1088]     return self._compiled_fn(*args)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/standalone_compile.py", line 236, in <lambda>
 [core.py:1088]     return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)
 [core.py:1088]                                                ^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 583, in runtime_wrapper
 [core.py:1088]     all_outs = call_func_at_runtime_with_args(
 [core.py:1088]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
 [core.py:1088]     out = normalize_as_list(f(args))
 [core.py:1088]                             ^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 786, in wrapper
 [core.py:1088]     return compiled_fn(runtime_args)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 671, in __call__
 [core.py:1088]     return self.current_callable(inputs)
 [core.py:1088]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/utils.py", line 3398, in run
 [core.py:1088]     out = model(new_inputs)
 [core.py:1088]           ^^^^^^^^^^^^^^^^^
 [core.py:1088]   File "/tmp/torchinductor_python_user/vw/cvw4j2pwhxyvma72xw4dkrxkwqd6ss4yxcft6cwl3irlhtky2bvp.py", line 1244, in call
 [core.py:1088]     assert_size_stride(buf8, (s59, 7168), (7168, 1), 'torch.ops.vllm.moe_forward_shared.default')
 [core.py:1088] AssertionError: expected size 8192==8192, stride 3584==7168 at dim=0; expected size 3584==7168, stride 1==1 at dim=1
 [core.py:1088] Error in op: torch.ops.vllm.moe_forward_shared.default
 [core.py:1088] This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
 [core.py:1088] Use torch.library.opcheck to test your custom op.
 [core.py:1088] See https://pytorch.org/docs/stable/library.html#torch.library.opcheck


wzhao18 commented Mar 16, 2026

Also, I wonder how we know whether a kernel from FlashInfer is incompatible with autotuning. It seems that even the same kernel only causes problems in some setups. Do we have a reliable way to tell whether things are working or not?

@haosdent (Contributor, Author)

Hi @wzhao18, thanks a lot for your test! I just fixed the issue you reported. Can you try the latest version and then test again?


wzhao18 commented Mar 17, 2026

@haosdent I tried wrapping the trtllm nvfp4 moe with the following and it works.

from vllm.utils.flashinfer import autotune

with autotune(False):
    ...

Would this be cleaner? This pattern is already used in trtllm_moe.py.

@haosdent (Contributor, Author)

@wzhao18 yes, many thanks for your feedback, let me update later


mergify bot commented Mar 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @haosdent.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@haosdent (Contributor, Author)

Have updated, thanks @wzhao18; the new way looks much better.

mergify bot removed the needs-rebase label on Mar 18, 2026
Commit message:

Use `with autotune(False):` to disable FlashInfer autotuning for MoE
kernels that are incompatible with it (upstream flashinfer#2023). This
follows the existing pattern in trtllm_moe.py and avoids shape/dtype
mismatches from dummy return values.

Kernels wrapped:
- TrtLlmNvFp4ExpertsMonolithic (trtllm_fp4_block_scale_moe)
- TrtLlmFp8ExpertsMonolithic (trtllm_fp8_block_scale_moe,
  trtllm_fp8_per_tensor_scale_moe)
- flashinfer_fused_moe_bf16 (flashinfer_trtllm_bf16_moe)
- FlashInferExperts (flashinfer_cutlass_fused_moe)
- FlashInferCuteDSLExperts (flashinfer_cutedsl_moe_masked)
- flashinfer_trtllm_mxint4_moe (trtllm_mxint4_block_scale_moe)

Signed-off-by: haosdent <haosdent@gmail.com>

Labels

bug (Something isn't working), nvidia


Development

Successfully merging this pull request may close these issues.

[Bug]: CPU offloading produces gibberish output with flashinfer autotuner
