
[torch.compile] Improve Cold Start for MoEs#32805

Merged
zou3519 merged 1 commit into vllm-project:main from zou3519:fix_moe_cold_start
Jan 22, 2026

Conversation

@zou3519 (Collaborator) commented Jan 21, 2026

Purpose

Fixes #29992

To improve torch.compile cold start times, we need to avoid hard-coding strings into the graph: when each layer bakes its own name string into the graph, the compiled piecewise graphs cannot be reused across layers. Right now, the vllm.moe_forward and vllm.moe_forward_shared custom operators hard-code such strings into the graph.

The workaround is to store, in reverse order in the ForwardContext, the list of strings that those custom ops need. The ForwardContext object lives for the duration of the forward pass. When a custom op needs its string, it pops it from this list.

This assumes that the custom operators will always be executed in order and that torch.compile will not try to reorder these operations with respect to each other.
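As a rough sketch, the pattern described above looks like the following. This is not vLLM's actual code: the class and function names mirror the description, but the bodies are illustrative stand-ins.

```python
# Hypothetical sketch of the workaround: layer-name strings live on a
# per-forward context and are popped in execution order, so they never
# appear as constants inside the compiled graph.
from dataclasses import dataclass, field


@dataclass
class ForwardContext:
    # Layer names stored in reverse order; this object is alive for the
    # duration of one forward pass.
    remaining_moe_layers: list[str] = field(default_factory=list)


def moe_forward(ctx: ForwardContext, layers: dict, x: float) -> float:
    # The layer name is resolved at runtime by popping from the context,
    # instead of being baked into the graph as a string constant.
    name = ctx.remaining_moe_layers.pop()
    return layers[name](x)  # stand-in for the real MoE computation


# One forward pass: build the reversed name list, then call the ops in order.
layers = {"moe.0": lambda x: x + 1.0, "moe.1": lambda x: x * 2.0}
ctx = ForwardContext(remaining_moe_layers=list(reversed(list(layers))))
out = moe_forward(ctx, layers, 3.0)   # resolves "moe.0"
out = moe_forward(ctx, layers, out)   # resolves "moe.1"
```

The list is reversed once up front so that each call can use a cheap pop from the end while still consuming names in execution order.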

There are longer-term solutions that aren't ready yet, such as unwrapping the MoE custom operator and/or treating the string as a "symbolic input" to the graph.

Test Plan & Test Result

This PR speeds up:

  • the torch.compile piece of gpt-oss-120b from 46s to 16s.
  • the torch.compile piece of GLM-4.7-FP8 from 197.62s to 46s.

I also tested gpt-oss-120b locally with some inputs to sanity check that it was still correct.

Wait for CI


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a workaround to speed up the cold start time for MOE models when using torch.compile. The changes avoid hard-coding layer name strings into the compiled graph by storing them in the ForwardContext and retrieving them at runtime. This relies on the assumption that MoE layers are executed in a fixed order.

The implementation is sound, but I have one suggestion to improve the robustness and type safety of the new get_layer_from_name function. This will help in debugging if the execution order assumption is violated in the future and improves code maintainability.

# There are longer-term solutions, like unwrapping the moe custom operator,
# and/or treating the string as a "symbolic input" to the graph that
# aren't ready yet.
remaining_moe_layers: list[str]
zou3519 (author) replied:

@ProExpertProg I could try to clean up how no_compile_layers works (it has both attention layers and MOE layers) by instead:

  • making it a list
  • popping from it instead of the new remaining_moe_layers list
  • deleting the "string" arguments to the moe_forward / moe_forward_shared operators (and maybe unified_attention if it has them too, I don't know)

but I'm not sure how long this change will really stick around. I expect the MOE refactor to expose the MOE internals and obviate this change.

ProExpertProg (Collaborator) replied:

Yeah this isn't meant to be permanent. Do you mind adding a TODO with a link to #31985?

@cursor (bot) left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.


zou3519 force-pushed the fix_moe_cold_start branch 2 times, most recently from 517304e to a2cd62a, on January 21, 2026 19:49
zou3519 requested a review from youkaichao on January 21, 2026 19:52
robertgshaw2-redhat changed the title from "[torch.compile] Speed up cold start time for MOE models" to "[torch.compile] Improve Cold Start for MoEs" on Jan 21, 2026
@ProExpertProg (Collaborator) left a comment:

Looks good, thanks for doing this!


ProExpertProg added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Jan 21, 2026
zou3519 force-pushed the fix_moe_cold_start branch from a2cd62a to f7d4588 on January 21, 2026 20:26
zou3519 force-pushed the fix_moe_cold_start branch from 1920eb3 to c94988d on January 22, 2026 01:07
@WoosukKwon (Collaborator) commented:

I tried out the PR on my GB200 machine and got OOM. I don't get OOM on main, so it seems this PR causes the memory issue.

@zou3519 (author) commented Jan 22, 2026:

Discussed with Woosuk online; it turns out that the model also OOMs on main, so this PR is unrelated.

@zou3519 zou3519 merged commit 654a71f into vllm-project:main Jan 22, 2026
52 checks passed
monajafi-amd pushed a commit to monajafi-amd/vllm that referenced this pull request Jan 23, 2026
Signed-off-by: Richard Zou <zou3519@gmail.com>
Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
cwazai pushed a commit to cwazai/vllm that referenced this pull request Jan 25, 2026
Signed-off-by: Richard Zou <zou3519@gmail.com>
Signed-off-by: 陈建华 <1647430658@qq.com>
Comment on lines +272 to +278
no_compile_layers = vllm_config.compilation_config.static_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

remaining_moe_layers = [
name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
]
remaining_moe_layers.reverse()
A project member commented:

Do we have profiling to show how long this takes? Doing it on every forward pass can be time-consuming.

Maybe we can keep all_moe_layers and only increment a current_index inside the forward context.
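The suggested alternative (a precomputed all_moe_layers list plus an index advanced per call) might look like this sketch. The names are illustrative, not vLLM's actual code:

```python
# Hypothetical sketch of the counter-based variant suggested in review:
# compute the layer-name list once at model init, then advance an index
# per custom-op call instead of rebuilding and popping a reversed list
# on every forward pass.
class ForwardContext:
    def __init__(self, all_moe_layers: list[str]):
        self.all_moe_layers = all_moe_layers  # computed once, reused
        self.current_index = 0                # reset at the start of each forward

    def next_moe_layer_name(self) -> str:
        # Same in-order assumption as the pop-based version: ops must
        # execute in a fixed order and must not be reordered.
        name = self.all_moe_layers[self.current_index]
        self.current_index += 1
        return name


ctx = ForwardContext(["moe.0", "moe.1"])
first = ctx.next_moe_layer_name()
second = ctx.next_moe_layer_name()
```

This avoids both the per-forward list construction and the list mutation; only an integer increment happens on the hot path.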

zou3519 (author) replied on Jan 27, 2026:

I can change it to all_moe_layers and current_index. We should be able to cache all_moe_layers on the static_forward_context somewhere so it also isn't recomputed on each forward pass.

zou3519 (author) replied:

Updated in #33184; please take a look.

zou3519 added a commit to zou3519/vllm that referenced this pull request Jan 27, 2026
This is a follow-up to the comments on vllm-project#32805.

It contains the following two perf optimizations:
- We don't need to recompute all of the MoE layer names on every forward
  pass; instead, we can collect all of the layer names when the model is
  initialized.
- Stop popping strings from a list; instead, maintain a counter.

Signed-off-by: Richard Zou <zou3519@gmail.com>
lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026
Signed-off-by: Richard Zou <zou3519@gmail.com>
veeceey added a commit to veeceey/vllm that referenced this pull request Feb 7, 2026
Avoid hard-coding attention layer name strings into the compiled graph
in unified_kv_cache_update. Each layer having a different name prevents
Inductor from reusing piecewise graphs across layers, increasing cold
start compilation time.

Apply the same approach used for MOE layers (vllm-project#32805, vllm-project#33184): store the
list of all KV cache update layer names at model init time and resolve
them at runtime via a counter in ForwardContext.

Fixes vllm-project#33267

Signed-off-by: Varun Chawla <varun_6april@hotmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: Richard Zou <zou3519@gmail.com>
@CSWYF3634076 (Contributor) commented:

@zou3519 This PR caused garbled outputs when running the ERNIE-4.5-VL-28B-A3B-PT model. The model contains two experts, but only one expert is used during decode, which leads to misalignment and corrupted outputs. Could this be fixed?

One possible approach I can think of is to add an input flag to FusedMoE to control this behavior, but that doesn’t seem very elegant.

@zou3519 (author) commented Feb 27, 2026:

@CSWYF3634076 Yes, I believe this PR is the problem. Could you try -cc.fast_moe_cold_start=False as a workaround, please?

@ProExpertProg (Collaborator) commented:

@CSWYF3634076 Is that an in-tree model? If so, can the model config init override this field to false? See vllm/model_executor/models/config.py

@CSWYF3634076 (Contributor) replied:

@zou3519 @ProExpertProg Thank you for your answers; that worked. I solved the problem by setting the field to false directly in Ernie4.5-VL, so users don't need to specify any configuration and the fix is transparent to them. Could you please help review #35587?

SongyouZhong pushed a commit to SongyouZhong/vllm that referenced this pull request Mar 6, 2026


Development

Successfully merging this pull request may close these issues.

[Bug]: vLLM cold start on MOE models not optimal

5 participants