-
-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE #35180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e0013f6
9624497
ca9ec3f
cfa149e
83ff4bd
8a2c1ff
e512759
b5f2bb8
1d9ab8e
1139789
11177e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: | |
| ) | ||
|
|
||
|
|
||
| def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool: | ||
| """Enable if rotary embedding custom op is active and | ||
| use_inductor_graph_partition is enabled. | ||
| """ | ||
| from vllm._aiter_ops import rocm_aiter_ops | ||
|
|
||
| return ( | ||
| rocm_aiter_ops.is_enabled() | ||
| and cfg.compilation_config.is_custom_op_enabled("rotary_embedding") | ||
| and cfg.compilation_config.use_inductor_graph_partition | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So are you guys using inductor graph partition on rocm by default? Otherwise we should also return true here if dynamo partition is used and the kv cache op is not in splitting ops
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Somehow GH ate my original PR comment that explained this) This PR is necessary but not sufficient to actually enable this fusion by default. We also need:
https://github.com/vllm-project/vllm/blob/main/vllm/config/compilation.py#L1001 is called in https://github.com/vllm-project/vllm/blob/main/vllm/config/vllm.py#L961 after the defaults are set in https://github.com/vllm-project/vllm/blob/main/vllm/config/vllm.py#L807. So if inductor partition is not enabled, we would return true for this, then append kv cache to splitting ops, which would silently break the fusion.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Links are broken but I know what you mean - but if |
||
| ) | ||
|
|
||
|
|
||
| def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: | ||
| """Enable if using AITER RMSNorm and AITER Triton GEMMs | ||
| and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion.""" | ||
| from vllm._aiter_ops import rocm_aiter_ops | ||
|
|
||
| return ( | ||
| envs.VLLM_ROCM_USE_AITER | ||
| and envs.VLLM_ROCM_USE_AITER_RMSNORM | ||
| and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM | ||
| rocm_aiter_ops.is_rmsnorm_enabled() | ||
| and not rocm_aiter_ops.is_triton_gemm_enabled() | ||
| and cfg.model_config is not None | ||
| and cfg.model_config.get_hidden_size() == 2880 | ||
| ) | ||
|
|
@@ -149,6 +162,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: | |
| "enable_sp": False, | ||
| "fuse_gemm_comms": False, | ||
| "fuse_act_padding": False, | ||
| "fuse_rope_kvcache": False, | ||
| }, | ||
| "cudagraph_mode": CUDAGraphMode.NONE, | ||
| "use_inductor_graph_partition": False, | ||
|
|
@@ -167,6 +181,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: | |
| "enable_sp": False, | ||
| "fuse_gemm_comms": False, | ||
| "fuse_act_padding": enable_norm_pad_fusion, | ||
| "fuse_rope_kvcache": enable_rope_kvcache_fusion, | ||
| }, | ||
| "cudagraph_mode": CUDAGraphMode.PIECEWISE, | ||
| "use_inductor_graph_partition": False, | ||
|
|
@@ -185,6 +200,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: | |
| "enable_sp": IS_DENSE, | ||
| "fuse_gemm_comms": IS_DENSE, | ||
| "fuse_act_padding": enable_norm_pad_fusion, | ||
| "fuse_rope_kvcache": enable_rope_kvcache_fusion, | ||
| }, | ||
| "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, | ||
| "use_inductor_graph_partition": False, | ||
|
|
@@ -203,6 +219,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: | |
| "enable_sp": IS_DENSE, | ||
| "fuse_gemm_comms": IS_DENSE, | ||
| "fuse_act_padding": enable_norm_pad_fusion, | ||
| "fuse_rope_kvcache": enable_rope_kvcache_fusion, | ||
| }, | ||
| "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, | ||
| "use_inductor_graph_partition": False, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.