[torch.compile] Improve Cold Start for MoEs #32805
Conversation
Code Review
This pull request introduces a workaround to speed up cold start time for MoE models when using torch.compile. The changes avoid hard-coding layer name strings into the compiled graph by storing them in the ForwardContext and retrieving them at runtime. This relies on the assumption that MoE layers are executed in a fixed order.
The implementation is sound, but I have one suggestion to improve the robustness and type safety of the new get_layer_from_name function. This will help with debugging if the execution-order assumption is ever violated, and it improves code maintainability.
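A minimal sketch of what such a robustness check could look like. The `get_layer_from_name` name comes from this PR, but the error messages, the placeholder `FusedMoE` class, and the exact signature below are illustrative assumptions, not the actual implementation:

```python
# Illustrative sketch only; the real function lives in vllm/forward_context.py.
# `FusedMoE` here is a placeholder standing in for vLLM's MoE layer class.
class FusedMoE:
    pass


def get_layer_from_name(static_forward_context: dict, layer_name: str) -> FusedMoE:
    """Resolve a layer by name, failing loudly if the execution-order
    assumption is violated (e.g. torch.compile reordered the custom ops)."""
    layer = static_forward_context.get(layer_name)
    if layer is None:
        raise AssertionError(
            f"MoE layer {layer_name!r} not found in the forward context; "
            "was the execution order of the custom operators changed?"
        )
    if not isinstance(layer, FusedMoE):
        raise TypeError(
            f"Expected FusedMoE for {layer_name!r}, got {type(layer).__name__}"
        )
    return layer
```

Raising a descriptive error here costs nothing on the happy path but makes an ordering violation immediately diagnosable instead of surfacing as silently corrupted outputs.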
vllm/forward_context.py (outdated)

```python
# There are longer-term solutions, like unwrapping the moe custom operator,
# and/or treating the string as a "symbolic input" to the graph that
# aren't ready yet.
remaining_moe_layers: list[str]
```
@ProExpertProg I could try to clean up how no_compile_layers works (it has both attention layers and MoE layers) by instead:
- making it a list
- popping from it instead of the new remaining_moe_layers list
- deleting the "string" arguments to the moe_forward / moe_forward_shared operators (and maybe unified_attention if it has them too, I don't know)

but I'm not sure how long this change will really stick around. I expect the MoE refactor to expose the MoE internals and obviate this change.
Yeah this isn't meant to be permanent. Do you mind adding a TODO with a link to #31985?
ProExpertProg left a comment:
Looks good, thanks for doing this!
I tried out the PR on my GB200 machine and got an OOM. I don't get an OOM on main, so it seems this PR does cause the memory issue.
Discussed with Woosuk online; it turns out that the model also OOMs on main, so this PR is unrelated.
```python
no_compile_layers = vllm_config.compilation_config.static_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

remaining_moe_layers = [
    name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
]
remaining_moe_layers.reverse()
```
Do we have profiling to show how long this takes? Doing it on every forward pass can be time-consuming.
Maybe we can keep an all_moe_layers list and only increment a current_index inside the forward context.
I can change it to all_moe_layers and current_index. We should be able to cache all_moe_layers on the static_forward_context somewhere so that it also isn't recomputed on each forward pass.
This is a follow-up to the comments on vllm-project#32805. It contains the following two perf optimizations:
- We don't need to recompute all of the MoE layer names on every forward pass; instead, we can collect all of the layer names when the model is being initialized.
- Stop popping strings from a list; instead, maintain a counter.

Signed-off-by: Richard Zou <zou3519@gmail.com>
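The counter-based scheme from the follow-up can be sketched as below. The names `all_moe_layers` and the `next_moe_layer_name` helper follow the discussion above but are simplified assumptions; the real ForwardContext carries much more state:

```python
# Hedged sketch of the counter-based follow-up, not vLLM's actual code.
from dataclasses import dataclass


@dataclass
class ForwardContext:
    # Computed once at model init time, never recomputed per forward pass.
    all_moe_layers: list[str]
    # Reset to 0 at the start of each forward pass.
    moe_layer_index: int = 0

    def next_moe_layer_name(self) -> str:
        # Each MoE custom op calls this in execution order; no string is
        # baked into the compiled graph and no list is mutated.
        name = self.all_moe_layers[self.moe_layer_index]
        self.moe_layer_index += 1
        return name
```

Compared with popping from a reversed copy of the list, this avoids both the per-forward list construction and the per-op list mutation; only an integer changes during the forward pass.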
Avoid hard-coding attention layer name strings into the compiled graph in unified_kv_cache_update. Each layer having a different name prevents Inductor from reusing piecewise graphs across layers, increasing cold start compilation time.

Apply the same approach used for MoE layers (vllm-project#32805, vllm-project#33184): store the list of all KV cache update layer names at model init time and resolve them at runtime via a counter in ForwardContext.

Fixes vllm-project#33267

Signed-off-by: Varun Chawla <varun_6april@hotmail.com>
@zou3519 This PR caused garbled outputs when running the ERNIE-4.5-VL-28B-A3B-PT model. Since it contains two experts but only one expert is used during decode, this leads to misalignment and results in corrupted outputs. Could this be fixed? One possible approach I can think of is to add an input flag to FusedMoE to control this behavior, but that doesn't seem very elegant.
@CSWYF3634076 yes, I believe this PR is the problem. Could you try testing with -cc.fast_moe_cold_start=False as a workaround, please?
@CSWYF3634076 is that an in-tree model? If yes, can the model config init override this field to false? See
@zou3519 @ProExpertProg Thank you for your answers; that worked. I solved the problem by setting the field to false in Ernie4.5-VL directly, so there is no need for users to specify any configuration and it gets fixed without them noticing. Could you please help review #35587?
Purpose
Fixes #29992
For torch.compile cold start times, we need to avoid hard-coding any strings into the graph. Right now, the vllm.moe_forward and vllm.moe_forward_shared custom operators hard-code strings into the graph.
The workaround is to store a list of the strings that each of those custom ops needs, in reverse order, in the ForwardContext. The ForwardContext object is alive for the duration of the forward pass. When the custom op needs the string, pop the string from this list.
This assumes that the custom operators will always be executed in order and that torch.compile will not try to reorder these operations with respect to each other.
There are longer-term solutions, such as unwrapping the moe custom operator and/or treating the string as a "symbolic input" to the graph, that aren't ready yet.
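The pop-based workaround described above can be sketched as follows. This is a simplified illustration under stated assumptions, not the actual vLLM code: the real ForwardContext lives in vllm/forward_context.py, and the layer names below are hypothetical:

```python
# Simplified illustration of the pop-based workaround; names and
# structure are assumptions, not vLLM's actual implementation.
class ForwardContext:
    """Alive for the duration of one forward pass."""

    def __init__(self, moe_layer_names: list[str]):
        # Stored in reverse so that list.pop() (O(1), removes from the
        # end) yields the layers in execution order.
        self.remaining_moe_layers = list(reversed(moe_layer_names))


def moe_forward(ctx: ForwardContext) -> str:
    # Inside the custom op: recover the layer name at runtime instead of
    # reading a string baked into the compiled graph. This relies on the
    # ops executing in a fixed order and never being reordered.
    return ctx.remaining_moe_layers.pop()
```

Because the compiled graph never sees a concrete layer name, Inductor can reuse one compiled artifact across all MoE layers instead of specializing per layer, which is where the cold-start savings come from.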
Test Plan & Test Result
This PR speeds up the torch.compile piece of gpt-oss-120b from 46s to 16s.
I also tested gpt-oss-120b locally with some inputs to sanity check that it was still correct.
Wait for CI