[Bugfix][V1][MoE] Warm up WNA16 MoE Triton kernels #42193

Open

lesj0610 wants to merge 2 commits into vllm-project:main from lesj0610:lesj/wna16-moe-jit-warmup-20260510

Conversation

lesj0610 (Contributor) commented May 10, 2026

What this fixes

In V1, the WNA16 fused MoE warmup has a gap. The startup dummy run uses a small batch, so should_moe_wna16_use_cuda picks the CUDA path and the Triton kernel never gets compiled. When the first real request arrives with a larger token count, fused_moe_kernel_gptq_awq is compiled on the fly during inference.

What I changed

First, fused_moe_kernel_gptq_awq now marks EM and num_valid_tokens with do_not_specialize. Without this, every new token count can trigger a fresh compilation even after warmup.
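
A minimal sketch (toy kernel, not the vLLM one) of what do_not_specialize looks like on a Triton JIT function; it keeps Triton from baking the concrete values of EM and num_valid_tokens into the compiled binary:

import triton
import triton.language as tl

# Toy kernel for illustration only; the real target is fused_moe_kernel_gptq_awq
# in vllm/model_executor/layers/fused_moe/fused_moe.py.
@triton.jit(do_not_specialize=["EM", "num_valid_tokens"])
def toy_moe_kernel(out_ptr, EM, num_valid_tokens, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Without do_not_specialize, Triton specializes on these integer values
    # (e.g. 1, or divisibility by 16), so each distinct token count could
    # trigger another JIT compilation even after warmup.
    mask = (offs < EM) & (offs < num_valid_tokens)
    tl.store(out_ptr + offs, offs.to(tl.float32), mask=mask)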

Second is the actual warmup. A new module, fused_moe_warmup.py, scans the model for FusedMoE layers that use MoeWNA16Method, figures out which M values will hit the Triton path based on the CUDA/Triton dispatch threshold, and calls quant_method.apply() with synthetic inputs.
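
A hedged sketch of that flow; the attribute names (w13_qweight, w2_qweight, hidden_size, global_num_experts, top_k), the _triton_warmup_m helper, and the simplified apply() call are illustrative assumptions, not the actual fused_moe_warmup.py API:

import torch
from torch import nn

def _triton_warmup_m(module: nn.Module) -> list[int]:
    # Placeholder: the real code derives these M values from the CUDA/Triton
    # dispatch threshold and the layer's top_k values (see the next sketch).
    return [1024]

def warmup_wna16_moe(model: nn.Module, device: torch.device) -> None:
    seen: set[tuple] = set()
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is None or type(quant_method).__name__ != "MoeWNA16Method":
            continue
        # Dedupe layers whose kernels would compile to the same variant.
        key = (tuple(module.w13_qweight.shape), tuple(module.w2_qweight.shape))
        if key in seen:
            continue
        seen.add(key)
        for m in _triton_warmup_m(module):  # M values that land on the Triton path
            hidden = torch.randn(m, module.hidden_size, dtype=torch.float16, device=device)
            logits = torch.randn(m, module.global_num_experts, dtype=torch.float32, device=device)
            # Simplified call; the real apply() takes additional routing arguments.
            quant_method.apply(module, hidden, logits, module.top_k)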

One thing I had to be careful about: WNA16 dispatches two GEMMs with different top_k values. Gate/up uses the model's top_k, but the down projection uses top_k=1, so the dispatch threshold is different for each. Both need separate warmup M values.
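
Roughly, the warmup M for each GEMM can be derived as below; the threshold value and the num_valid_tokens = M * top_k relation are illustrative stand-ins for the actual should_moe_wna16_use_cuda rule:

def triton_warmup_m(cuda_token_threshold: int, top_k: int) -> int:
    # The dispatch decision sees num_valid_tokens = M * top_k, so the smallest
    # batch size M that crosses onto the Triton path depends on top_k.
    return cuda_token_threshold // top_k + 1

model_top_k = 8   # example value; read from the FusedMoE layer in practice
threshold = 256   # illustrative, not the real dispatch constant

m_gate_up = triton_warmup_m(threshold, model_top_k)  # gate/up GEMM: model's top_k
m_down = triton_warmup_m(threshold, 1)                # down projection: top_k = 1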

For expert parallelism, only the local expert IDs from expert_map are used. Layers with the same weight shapes and quant config are deduplicated so the same compilation isn't repeated.
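
A small sketch of the expert-map handling, assuming the usual vLLM convention that expert_map maps global expert id to local id and marks experts owned by other EP ranks with -1:

import torch

def local_expert_ids(expert_map: torch.Tensor | None, num_experts: int) -> torch.Tensor:
    # Without expert parallelism every expert is local.
    if expert_map is None:
        return torch.arange(num_experts)
    # Otherwise keep only the global ids this rank actually owns.
    return torch.nonzero(expert_map >= 0, as_tuple=False).flatten()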

There is no full model forward pass, just direct kernel-level warmup.

I checked the open PRs and didn't find an existing one for this.

Test Plan

.venv/bin/python -m pytest tests/model_executor/test_fused_moe_warmup.py -v

pre-commit run ruff-format --files \
  vllm/model_executor/layers/fused_moe/fused_moe.py \
  vllm/model_executor/warmup/kernel_warmup.py \
  vllm/model_executor/warmup/fused_moe_warmup.py \
  tests/model_executor/test_fused_moe_warmup.py

pre-commit run ruff-check --files \
  vllm/model_executor/layers/fused_moe/fused_moe.py \
  vllm/model_executor/warmup/kernel_warmup.py \
  vllm/model_executor/warmup/fused_moe_warmup.py \
  tests/model_executor/test_fused_moe_warmup.py

pre-commit run mypy-3.10 --files \
  vllm/model_executor/layers/fused_moe/fused_moe.py \
  vllm/model_executor/warmup/kernel_warmup.py \
  vllm/model_executor/warmup/fused_moe_warmup.py \
  tests/model_executor/test_fused_moe_warmup.py \
  --hook-stage manual

Test Result

pytest: 6 passed, 16 warnings.
ruff-format, ruff-check, mypy: all passed.
Pre-commit hooks passed on commit.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR
  • The test plan
  • The test results
  • AI assistance disclosed
  • Assisted-by trailers in commit

AI assistance: Codex, Claude, Gemini.

Co-authored-by: OpenAI Codex <codex@openai.com>

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: lesj0610 <lesj0610@users.noreply.github.com>
mergify bot added the bug (Something isn't working) label May 10, 2026
lesj0610 marked this pull request as ready for review May 10, 2026 01:19
claude bot left a comment
Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

gemini-code-assist bot left a comment
Code Review

This pull request introduces a warmup mechanism for WNA16 MoE Triton kernels to ensure they are reliably exercised during initialization. It adds a new fused_moe_warmup.py module that calculates appropriate M values for dummy runs and integrates this into the kernel_warmup process. A review comment pointed out that the expert mapping logic should account for ROCm-specific binary masks to avoid including non-local experts during warmup.

Comment thread on vllm/model_executor/warmup/fused_moe_warmup.py (outdated)
Co-authored-by: OpenAI Codex <codex@openai.com>

Co-authored-by: Claude <noreply@anthropic.com>

Co-authored-by: Gemini <gemini-code-assist@users.noreply.github.com>

Signed-off-by: lesj0610 <lesj0610@users.noreply.github.com>
lesj0610 (Contributor, Author) commented

@ZJY0516 @qiching @tdoublep @vadiklyutiy Hi again, it's me from #42165.

Same area but a different kernel this time. WNA16 fused MoE (fused_moe_kernel_gptq_awq) also has a JIT compile problem in V1. The startup warmup batch is small enough that dispatch takes the CUDA path, so the Triton kernel is never compiled before the JIT monitor starts. The first larger request then compiles it during serving.

I added do_not_specialize for EM and num_valid_tokens to stop recompilation on different token counts, plus a small kernel-level warmup that exercises the Triton path directly. It is not a full model forward, just the MoE kernels.

Would appreciate your eyes on this since you know the warmup/monitor context from #40137. Thanks.
