Skip LM head during FlashInfer autotune dummy run #23796
Closed
Kangyan-Zhou wants to merge 2 commits into
The autotune cache only needs attention/MoE/GEMM kernel timings, but _dummy_run currently goes all the way through LogitsProcessor, where the [batch * dp_size, vocab] tensor-parallel all-gather buffer can OOM under DP attention on tight memory budgets (e.g. GLM-5.1-FP8 TP8+DP8 with --mem-fraction-static=0.9, which has been failing every B200 nightly run since the test was added).
Add a module-level autotune_dummy_run_mode() context manager (mirroring cuda_graph_runner's model_capture_mode pattern), wrap the autotune dummy run with it, and have LogitsProcessor.forward short-circuit to LogitsProcessorOutput(next_token_logits=None) when the flag is set. The return value is discarded by the autotune call site, so the stub output is safe.
Mirrors vLLM's split between _dummy_run (no lm_head, used by autotune) and _dummy_sampler_run (lm_head, profile-only).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment-analyzer review flagged three near-duplicate comment blocks explaining the same OOM mechanism. Keep the canonical explanation on _in_autotune_dummy_run, shrink the call-site comment to a one-liner that notes the dispatch ordering, and drop the redundant model_runner comment (the context-manager name is self-descriptive).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kpham-sgl
reviewed
May 1, 2026
kpham-sgl
left a comment
Collaborator
I think this is correct but cannot reproduce locally :'(. Move to another branch to trigger nightly CI
Summary
LogitsProcessor.forward now short-circuits during the FlashInfer autotune dummy run so that the LM head and the tensor-parallel all-gather are skipped.
This mirrors vLLM, where _dummy_run returns hidden states without calling compute_logits.
The bug
test/registered/8-gpu-models/test_glm_51_fp8.py::TestGlm51Fp8::test_glm51_fp8 (TP8+DP8 variant) has been failing every B200 nightly run since the test was added on 2026-04-09 (PR #22399). The failure signature is identical in every run.
Latest reproducer (failing run, job 72975439421 on 2026-04-25).
Why TP8+DP8 specifically OOMs
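To make the memory pressure in this section concrete, the sketch below estimates the size of a dense [batch × dp_size, vocab] buffer. The helper name logits_buffer_gib and every number passed to it (batch size, vocabulary size, bytes per element) are placeholders chosen for illustration, not values taken from the PR or measured on the failing job.

```python
def logits_buffer_gib(batch_size: int, dp_size: int, vocab_size: int,
                      bytes_per_elem: int) -> float:
    """Size in GiB of a dense [batch_size * dp_size, vocab_size] tensor."""
    rows = batch_size * dp_size  # DP attention gathers rows from every DP rank
    return rows * vocab_size * bytes_per_elem / 2**30


# Hypothetical inputs: 2048 rows per rank, 150k vocab, 2-byte elements.
without_dp = logits_buffer_gib(2048, dp_size=1, vocab_size=150_000, bytes_per_elem=2)
with_dp = logits_buffer_gib(2048, dp_size=8, vocab_size=150_000, bytes_per_elem=2)

print(f"dp_size=1: {without_dp:.2f} GiB, dp_size=8: {with_dp:.2f} GiB")
```

The point of the sketch is only the scaling: the buffer grows linearly with dp_size, so a size that is harmless without DP attention can exceed a headroom of a few GiB once it is multiplied by 8.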
The autotune dummy run uses batch_size = req_to_token_pool.size. With DP attention enabled (--tp 8 --dp 8 --enable-dp-attention), _get_logits first calls _gather_dp_attn_hidden_states, which multiplies the token count by dp_size. The LM head and all-gather then produce a [batch × dp_size, vocab] buffer.
GLM-5.1-FP8 sits at the unfortunate intersection:
the flashinfer_trtllm MoE backend (autotune fires), DP attention enabled, a large auto-resolved max_running_requests=2048, and an aggressive --mem-fraction-static=0.9 that leaves only ~3 GiB free after weights+KV. Comparable tests like test_qwen35.py use --mem-fraction-static=0.8, which gives ~17 GiB extra activation headroom and absorbs the 4.7 GiB buffer.
What vLLM does
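The gather-before-LM-head idea attributed to vLLM in this section can be illustrated with a toy example in plain Python. This is not vLLM code: the request lengths are made up, and logits_indices is assumed here to hold the position of the last token of each request in the flattened token dimension.

```python
# Toy batch: 3 requests with 4, 2, and 5 tokens; hidden size 2.
tokens_per_request = [4, 2, 5]
hidden_size = 2
num_tokens = sum(tokens_per_request)

# One row of hidden state per token, flattened across requests.
hidden_states = [[float(t)] * hidden_size for t in range(num_tokens)]

# Assumed meaning of logits_indices: last-token position of each request.
logits_indices = []
offset = 0
for n in tokens_per_request:
    offset += n
    logits_indices.append(offset - 1)

# Gather first; only these rows would be projected to vocabulary size.
lm_head_input = [hidden_states[i] for i in logits_indices]

print(f"rows before gather: {len(hidden_states)}, after: {len(lm_head_input)}")
```

Under these assumptions the LM head input has one row per request rather than one row per token, which is why the projected tensor stays at [num_reqs, vocab].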
vllm/v1/worker/gpu_model_runner.py:5615-5619: _dummy_run returns hidden states directly. compute_logits (lm_head) only runs in a separate _dummy_sampler_run, called from profile_run, never from flashinfer_autotune. So vLLM never allocates a [*, vocab] tensor during autotune. Even in production, lines 4086-4087 gather hidden_states[logits_indices] before compute_logits, so the lm_head only sees [num_reqs, hidden] (one row per sequence), never [num_tokens, hidden].
What this PR changes
python/sglang/srt/layers/logits_processor.py: add a module-level _in_autotune_dummy_run flag and a @contextmanager autotune_dummy_run_mode() (mirrors cuda_graph_runner.is_capture_mode / model_capture_mode). At the top of LogitsProcessor.forward, return LogitsProcessorOutput(next_token_logits=None) when the flag is set. The short-circuit sits before the MIS / DLLM / common dispatch, so all three LM-head paths are covered.
python/sglang/srt/model_executor/model_runner.py: wrap the _dummy_run call in _flashinfer_autotune with autotune_dummy_run_mode(). The autotune call site discards the return value (run_once() is called without consuming its result), so the stub output is safe. _dummy_run has only one caller (_flashinfer_autotune), so the bypass cannot leak into cuda graph capture, profiling, or production forward.
Test plan
_flashinfer_autotune) after this fix.
Notes
EAGLE speculative decoding hardcodes max_running_requests=48 in server_args.py:3370, then _resolve_max_num_reqs in model_runner_kv_cache_mixin.py:707 divides by dp_size. With dp_size=8 this becomes 6 per worker. Whether 48 is meant as system-wide or per-worker is ambiguous in the warning text; this is out of scope for this PR but worth a follow-up.
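The two readings of the hardcoded 48 can be spelled out with a small arithmetic sketch. The helper per_worker_limit is hypothetical and uses integer division as an assumption; it is not the body of _resolve_max_num_reqs.

```python
def per_worker_limit(max_running_requests: int, dp_size: int) -> int:
    """Per-DP-worker limit if the configured value is read as system-wide."""
    return max_running_requests // dp_size


dp_size = 8

# Reading 1: 48 is system-wide, so each worker gets a share of it.
system_wide_reading = per_worker_limit(48, dp_size)

# Reading 2: 48 is already per-worker, so the system total is larger.
total_if_per_worker = 48 * dp_size

print(f"per worker (system-wide reading): {system_wide_reading}")
print(f"system total (per-worker reading): {total_if_per_worker}")
```

The two interpretations differ by a factor of dp_size squared in effective system capacity (6 per worker versus 48 per worker), which is why the ambiguity seems worth resolving in a follow-up.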