[Bugfix][V1][MoE] Warm up WNA16 MoE Triton kernels #42193
Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: lesj0610 <lesj0610@users.noreply.github.com>
Code Review
This pull request introduces a warmup mechanism for WNA16 MoE Triton kernels to ensure they are reliably exercised during initialization. It adds a new fused_moe_warmup.py module that calculates appropriate M values for dummy runs and integrates this into the kernel_warmup process. A review comment pointed out that the expert mapping logic should account for ROCm-specific binary masks to avoid including non-local experts during warmup.
@ZJY0516 @qiching @tdoublep @vadiklyutiy Hi again, me from #42165. Same area but different kernel this time: WNA16 fused MoE. Would appreciate your eyes on this since you know the warmup/monitor context from #40137. Thanks.
What this fixes
In V1, WNA16 fused MoE warmup has a gap. The startup dummy run uses a small batch, so `should_moe_wna16_use_cuda` picks the CUDA path and the Triton kernel never gets compiled. When the first real request arrives with a larger token count, `fused_moe_kernel_gptq_awq` compiles on the fly during inference.

What I changed

First, `fused_moe_kernel_gptq_awq` now has `do_not_specialize` for `EM` and `num_valid_tokens`. Otherwise every different token count can trigger a new compilation even after warmup.

Second part is the actual warmup. The new module `fused_moe_warmup.py` scans the model for `FusedMoE` layers that use `MoeWNA16Method`, figures out which M values will hit the Triton path based on the CUDA/Triton dispatch threshold, and calls `quant_method.apply()` with synthetic inputs.

One thing I had to be careful about: WNA16 dispatches two GEMMs with different `top_k`. Gate/up uses the model's `top_k`, but the down projection uses `top_k=1`, so the dispatch threshold is different for each. Both need separate warmup M values.

For expert parallelism, only local expert IDs from `expert_map` are used. Layers with the same weight shape and quant config get deduped so we don't repeat the same compilation.

No full model forward pass, just direct kernel-level warmup.
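The per-GEMM M selection can be sketched roughly as follows. This is a simplified stand-in, not vLLM's actual code: the real dispatch predicate is `should_moe_wna16_use_cuda` and takes more parameters, and `THRESHOLD` and `uses_cuda_path` here are illustrative assumptions.

```python
# Hypothetical sketch of picking warmup M values for the two WNA16 GEMMs.
# THRESHOLD and uses_cuda_path() are illustrative stand-ins for the real
# CUDA/Triton dispatch decision (should_moe_wna16_use_cuda in vLLM).

THRESHOLD = 256  # illustrative: CUDA path only handles small workloads


def uses_cuda_path(num_tokens: int, top_k: int) -> bool:
    """Stand-in for the CUDA/Triton dispatch decision."""
    return num_tokens * top_k <= THRESHOLD


def warmup_m_values(top_k: int) -> list[int]:
    """Smallest M per GEMM that falls through to the Triton kernel.

    The gate/up projection runs with the model's top_k, the down
    projection with top_k=1, so each GEMM needs its own warmup M.
    """
    ms = []
    for k in (top_k, 1):  # gate/up GEMM, then down GEMM
        m = 1
        while uses_cuda_path(m, k):
            m *= 2
        ms.append(m)
    return ms


print(warmup_m_values(top_k=8))  # [64, 512] with the stand-in threshold
```

With the stand-in threshold, the down projection (`top_k=1`) needs a much larger M than gate/up before it leaves the CUDA path, which is why a single warmup batch size can't cover both GEMMs.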
Checked open PRs, didn't find existing one for this.
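The expert-parallel filtering and dedup described above can be sketched like this. Names (`local_expert_ids`, `warmup_key`) and the `-1` sentinel convention are illustrative assumptions, not vLLM's exact API; note the review comment above points out that on ROCm `expert_map` can instead be a binary mask, which this sketch does not handle.

```python
# Hypothetical sketch of EP filtering and layer dedup for warmup.
# Assumes expert_map[global_id] is the local index, or -1 for experts
# owned by another rank (vLLM stores this as a tensor; -1 sentinel
# assumed here, and ROCm binary masks are not covered).


def local_expert_ids(expert_map: list[int]) -> list[int]:
    """Global expert IDs owned by this rank."""
    return [g for g, local in enumerate(expert_map) if local != -1]


def warmup_key(weight_shape: tuple[int, ...], bits: int, group_size: int):
    """Layers sharing this key would compile the same kernel,
    so each key is warmed up only once."""
    return (weight_shape, bits, group_size)


# Example: a rank owning experts 4..7 of an 8-expert model.
emap = [-1, -1, -1, -1, 0, 1, 2, 3]
print(local_expert_ids(emap))  # [4, 5, 6, 7]
```

Deduping on `(weight shape, quant config)` rather than on layer identity is what keeps the warmup cheap for deep models, since most MoE layers in a model share one shape.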
Test Plan
Test Result
pytest: 6 passed, 16 warnings.
ruff-format, ruff-check, mypy: all passed.
Pre-commit hooks passed on commit.
AI assistance: Codex, Claude, Gemini.