[compile] Allow strings in custom ops without regressing compilation times#38123
zou3519 wants to merge 1 commit into vllm-project:main.
Conversation
Code Review
This pull request introduces a ModuleName opaque type to improve torch.compile behavior by hoisting layer names as graph inputs, preventing per-layer recompilation for custom operations. This change involves updating various torch.ops.vllm calls across attention, KV cache, and Mamba mixer modules to use _encode_layer_name and _resolve_layer_name for consistent handling of layer names. A critical issue was identified where the kv_cache_dummy_dep variable could be undefined, leading to an UnboundLocalError.
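A minimal sketch of the `UnboundLocalError` pattern the review flags (the function and variable shapes here are hypothetical simplifications; the real code lives in vLLM's KV-cache setup): a variable assigned only inside one conditional branch but read unconditionally afterwards, alongside the straightforward fix of initializing it before the branch.

```python
def build_deps_buggy(kv_cache_enabled: bool):
    if kv_cache_enabled:
        kv_cache_dummy_dep = object()  # assigned only on this branch
    # Reading it here raises UnboundLocalError when kv_cache_enabled is False.
    return kv_cache_dummy_dep


def build_deps_fixed(kv_cache_enabled: bool):
    kv_cache_dummy_dep = None  # initialize unconditionally before the branch
    if kv_cache_enabled:
        kv_cache_dummy_dep = object()
    return kv_cache_dummy_dep
```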
```diff
  qkvz_output_size: int,
  ba_output_size: int,
- layer_name: str,
+ layer_name: _layer_name_type,
```
Nit: can we call this LayerName or LayerNameType?
This pull request has merge conflicts that must be resolved before it can be merged.
[compile] Allow strings in custom ops without regressing compilation times

This is a follow-up to vllm-project#35475 to extend the fix to all custom operators, not just the MoE custom ops.

Previously, string inputs to custom ops would regress compilation times. The problem goes:
- a transformer model (e.g. llama3-70b) has 80 identical layers
- we capture a full graph and then split the graph on the attention operations
- this produces 81 subgraphs: the middle 79 graphs are all identical (aside from graph inputs: parameters and buffers)
- vLLM-compile produces 1 compiled artifact for all of the middle 79 subgraphs
- if a custom operator with a layer_name string appears in the graph, the middle 79 subgraphs become unique, so vLLM-compile ends up producing 79 compiled artifacts for them

In PyTorch 2.11, we have added a special class (the OpaqueObject type). The idea is that instead of passing strings to custom operators, we can pass a special LayerName object to the custom operator. This signifies to the compiler that it should "lift" the LayerName object to a graph input rather than bake the value directly into the graph.

More notes:
- the LayerName object used to be called ModuleName. I renamed it here; LayerName seemed more appropriate.
- VLLM_USE_LAYERNAME=0 turns this feature off. This option is here just in case something breaks. I'll probably remove it in the next month.

Signed-off-by: Richard Zou <zou3519@gmail.com>
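The caching effect described above can be illustrated with a toy model (this is a conceptual sketch, not vLLM's actual cache-key logic): treat a compiled artifact's cache key as the subgraph code plus any constants baked into it. Baked-in layer-name strings make every middle subgraph's key unique; lifting the name to a graph input collapses them into one.

```python
# Toy illustration: a "graph key" is the subgraph code plus baked-in constants.
def graph_key(code: str, baked_constants: tuple) -> tuple:
    return (code, baked_constants)


layers = [f"model.layers.{i}.self_attn" for i in range(79)]

# Strings baked into the graph: each layer name yields a distinct key,
# so 79 compiled artifacts are produced.
baked = {graph_key("attn_subgraph", (name,)) for name in layers}

# Layer name lifted to a graph input: no baked constant, so every layer
# shares one key and one compiled artifact.
lifted = {graph_key("attn_subgraph", ()) for _ in layers}

print(len(baked), len(lifted))  # 79 artifacts vs 1 artifact
```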