[FEAT][ROCm]: Support AITER MLA #15893
Conversation
Co-authored-by: qli88 <[email protected]> Signed-off-by: vllmellm <[email protected]>
Signed-off-by: vllmellm <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: vllmellm <[email protected]>
Co-authored-by: ArthurAMD <[email protected]> Signed-off-by: vllmellm <[email protected]>
Signed-off-by: vllmellm <[email protected]>
… if/else statements Signed-off-by: vllmellm <[email protected]>
Signed-off-by: vllmellm <[email protected]>
    "cpu": [],
}

DEVICE_NON_MLA_BACKENDS = {
nit: let's just call this `DEVICE_REGULAR_ATTN_BACKENDS` instead of MLA
@LucasWilkinson This has been addressed. Thanks.
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
nit: why relocate these lines? Also, can you please explain why we now need `self.__class__.BLOCK_TABLE_EXTENDER`?
`self.__class__.BLOCK_TABLE_EXTENDER` is a static class variable. The common builder previously hardcoded the extender as `[]` in the line below:

self.block_tables.extend([] * cuda_graph_pad_size)

In `AiterMLAMetadataBuilder`, graph capture needs `[[]]` instead of `[]`. Lifting the hardcoded extender into a class variable lets a subclass supply its own extender value or simply inherit the parent's.

To review this file it is better to open the entire file, as the GitHub diff view does not represent the changes well.

Overall, as explained in the PR description: to accommodate the AITER MLA implementation and reduce code duplication in the subclass, some refactoring was done in certain functions to allow more flexibility in subclasses.
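The class-variable refactor described in this comment can be sketched as follows. This is a minimal illustration with simplified, hypothetical class names, not the actual vLLM code; only the `BLOCK_TABLE_EXTENDER` pattern is the point:

```python
# Sketch of the BLOCK_TABLE_EXTENDER refactor: the parent keeps the
# previously hardcoded "[]" as an overridable class variable, so a
# subclass only has to change the extender value.

class CommonMetadataBuilder:
    # Static class variable replacing the hardcoded `[]` extender.
    BLOCK_TABLE_EXTENDER: list[list[int]] = []

    def __init__(self) -> None:
        self.block_tables: list[list[int]] = []

    def pad_for_cuda_graph(self, cuda_graph_pad_size: int) -> None:
        # `[] * n` is always `[]`, while `[[]] * n` yields n empty
        # inner lists -- one empty block table per padded slot.
        self.block_tables.extend(
            self.__class__.BLOCK_TABLE_EXTENDER * cuda_graph_pad_size)


class AiterMetadataBuilder(CommonMetadataBuilder):
    # AITER MLA graph capture needs "[[]]" instead of "[]".
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]


common = CommonMetadataBuilder()
common.pad_for_cuda_graph(3)

aiter = AiterMetadataBuilder()
aiter.pad_for_cuda_graph(3)

print(common.block_tables)  # []
print(aiter.block_tables)   # [[], [], []]
```

The parent's behavior is unchanged (extending by `[]` is a no-op), while the subclass gets one empty block table per padded slot without overriding the method itself.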
Implementation

`ROCM_AITER_MLA` is introduced as an additional attention backend type for the ROCm platform. To support this backend, the modules below are implemented in `vllm/attention/backends/rocm_aiter_mla.py`:

- `AiterMLABackend` inherits from `MLACommonBackend`.
- `AiterMLAMetadata` inherits from `MLACommonMetadata`: note that this class's `advance_step` function utilizes the `advance_step_flashinfer` function from vLLM custom ops.
- `AiterMLAMetadataBuilder` inherits from `MLACommonMetadataBuilder`.
- `AiterMLAState` inherits from `MLACommonState`.
- `AiterMLAImpl` inherits from `CommonMLAImpl`. Important notes for this class: the `flash_attn_varlen_func` (FA function) used in this class is the AITER FA implementation (`flash_attn_varlen_func` from the AITER package), and the `_forward_decode` function in this class uses the `mla_decode_fwd` kernel from the AITER package.

The MLACommon module has been refactored to reduce code duplication in its subclasses. This was achieved by separating the attention output computation into two dedicated functions, `_get_fwd_prefill_attn_output` and `_get_prefill_ctx_attn_output`, used in `_compute_prefill_context` and `_forward_prefill` respectively.

Another refactoring is in the `advance_step` function: the pre-assertion checks are separated out before the op call, so that `advance_step` can be overridden without code duplication in its subclasses.
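The `advance_step` refactor described above follows a template-method shape. The sketch below uses simplified, hypothetical class names and a stand-in `last_op` attribute instead of the real kernel calls, purely to show the override point:

```python
# Sketch of the advance_step refactor: shared pre-assertion checks stay
# in advance_step, while the op invocation lives in _ops_advance_step,
# which a subclass can override without duplicating the checks.

class CommonMetadata:
    def advance_step(self, num_seqs: int, num_queries: int) -> None:
        # Pre-assertion checks shared by all subclasses.
        assert num_seqs > 0
        assert num_queries > 0
        assert num_seqs >= num_queries
        self._ops_advance_step(num_seqs, num_queries)

    def _ops_advance_step(self, num_seqs: int, num_queries: int) -> None:
        # The base class would call ops.advance_step_flashattn here;
        # we record a tag instead of calling the real kernel.
        self.last_op = ("flashattn", num_seqs, num_queries)


class AiterMLAMetadata(CommonMetadata):
    def _ops_advance_step(self, num_seqs: int, num_queries: int) -> None:
        # The AITER MLA subclass swaps in the FlashInfer-style op.
        self.last_op = ("flashinfer", num_seqs, num_queries)


meta = AiterMLAMetadata()
meta.advance_step(num_seqs=4, num_queries=2)
print(meta.last_op)  # ('flashinfer', 4, 2)
```

Only the op invocation differs between parent and subclass; the validation logic is written once.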
@LucasWilkinson after resolving the merge conflict for this file, the only changes in common.py are as below:

- Invoking `ops.advance_step_flashattn` in a separate function `_ops_advance_step`, used by `advance_step`, which can be overridden by a subclass.
- Use of a "static" class variable `BLOCK_TABLE_EXTENDER: list[list[int]] = []` to update `self.block_tables` in graph mode. This eliminates the hardcoded `[]` in `self.block_tables.extend([] * cuda_graph_pad_size)` and gives subclasses the flexibility to override this update via the class variable.
Signed-off-by: vllmellm <[email protected]>
LucasWilkinson
left a comment
LGTM, thanks for the contribution
vllm/platforms/rocm.py
Outdated
else:
    raise ValueError(
        f" The selected backend, {selected_backend.name},"
        "does not support block size {block_size}.")
Suggested change:
-        "does not support block size {block_size}.")
+        f"does not support block size {block_size}.")
thanks for pointing this out. Have added the suggestion.
vllm/platforms/rocm.py
Outdated
else:
    raise ValueError(
        f" The selected backend, {selected_backend.name},"
        "does not support block size {block_size}."
Suggested change:
-        "does not support block size {block_size}."
+        f"does not support block size {block_size}."
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: vllmellm <[email protected]>
… handle wrong backend selection when MLA is requested. Signed-off-by: vllmellm <[email protected]>
Signed-off-by: vllmellm <[email protected]> Co-authored-by: qli88 <[email protected]> Signed-off-by: Frieda (Jingying) Huang <[email protected]>
Signed-off-by: vllmellm <[email protected]> Co-authored-by: qli88 <[email protected]>
Signed-off-by: vllmellm <[email protected]> Co-authored-by: qli88 <[email protected]> Signed-off-by: Agata Dobrzyniewicz <[email protected]>
Signed-off-by: vllmellm <[email protected]> Co-authored-by: qli88 <[email protected]> Signed-off-by: Mu Huai <[email protected]>
Description
This PR integrates AITER ops to improve MLA functionality, bringing AITER `flash_attn_varlen_func` and AITER `mla_decode_fwd` into vLLM, and allows any upcoming optimizations in the AITER kernels to be used and evaluated directly within the vLLM framework.
Implementation

`ROCM_AITER_MLA` is introduced as an additional attention backend type for the ROCm platform. To support this backend, the modules below are implemented in `vllm/attention/backends/rocm_aiter_mla.py`:

- `AiterMLABackend` inherits from `MLACommonBackend`.
- `AiterMLAMetadata` inherits from `MLACommonMetadata`: note that this class's `advance_step` function utilizes the `advance_step_flashinfer` function from vLLM custom ops.
- `AiterMLAMetadataBuilder` inherits from `MLACommonMetadataBuilder`.
- `AiterMLAState` inherits from `MLACommonState`.
- `AiterMLAImpl` inherits from `CommonMLAImpl`. Important notes for this class: the `flash_attn_varlen_func` (FA function) used in this class is the AITER FA implementation (`flash_attn_varlen_func` from the AITER package), and the `_forward_decode` function in this class uses the `mla_decode_fwd` kernel from the AITER package.

The MLACommon module has been refactored to reduce code duplication in its subclasses for the `advance_step` function, by invoking `ops.advance_step_flashattn` in a separate function `_ops_advance_step` that can be overridden by a subclass.

To enable the backend, the environment variable `VLLM_ATTN_BACKEND` can be set to `ROCM_AITER_MLA`. In case the backend is not specified, `rocm.py` in `vllm/platforms` verifies whether `VLLM_ROCM_USE_AITER` and `VLLM_ROCM_USE_AITER_MLA` are both enabled to decide whether to utilize this backend. Otherwise, the selected backend is `TRITON_MLA`.
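The backend-selection rule just described can be sketched as a small standalone function. This is illustrative only, not vLLM's actual `rocm.py` logic, and the flag values ("1" meaning enabled) are an assumption:

```python
# Illustrative sketch of the selection rule: an explicit
# VLLM_ATTN_BACKEND=ROCM_AITER_MLA forces the backend; otherwise it is
# chosen only when both AITER flags are enabled, else TRITON_MLA.

def select_mla_backend(env: dict) -> str:
    if env.get("VLLM_ATTN_BACKEND") == "ROCM_AITER_MLA":
        return "ROCM_AITER_MLA"
    if (env.get("VLLM_ROCM_USE_AITER") == "1"
            and env.get("VLLM_ROCM_USE_AITER_MLA") == "1"):
        return "ROCM_AITER_MLA"
    return "TRITON_MLA"


print(select_mla_backend({"VLLM_ATTN_BACKEND": "ROCM_AITER_MLA"}))  # ROCM_AITER_MLA
print(select_mla_backend({"VLLM_ROCM_USE_AITER": "1",
                          "VLLM_ROCM_USE_AITER_MLA": "1"}))         # ROCM_AITER_MLA
print(select_mla_backend({}))                                       # TRITON_MLA
```

Note that both AITER flags must be set: enabling `VLLM_ROCM_USE_AITER` alone still falls back to `TRITON_MLA`.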
Important Notes: this backend requires `block_size=1`, and the variable `max_model_len=32768` has to be set.

Testing
To ensure the correct attention backend is selected, the MLA backend environment settings have been added to the test cases in `tests/kernels/test_attention_selector.py`.

Performance
Benchmark Serving Results Comparison
LM Eval Results
Environment Setting
Updates in Dockerfile.rocm_base
Added AITER Package:
Additional notes on installing AITER
When setting up AITER, it is crucial to use the command git clone --recursive. This is because the package depends on a third-party package (Composable Kernel).
For building and installing the AITER Python package, you must use the PREBUILD_KERNELS=1 flag along with the command python3 setup.py develop. This ensures that all kernels in the AITER package are built successfully.
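Putting the two notes above together, the install steps would look roughly like this. The repository URL is an assumption (the PR text does not name it); the `--recursive` flag and the `PREBUILD_KERNELS=1` build command are as stated above:

```shell
# Clone AITER together with its third-party submodules
# (the package depends on Composable Kernel).
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter

# Build and install, prebuilding all kernels in the package.
PREBUILD_KERNELS=1 python3 setup.py develop
```

Without `--recursive`, the Composable Kernel dependency is missing and the build fails; without `PREBUILD_KERNELS=1`, not all kernels are built.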
The following branches were used as references for this integration:
https://github.com/ROCm/vllm/tree/dsv3_dev
https://github.com/ROCm/vllm/tree/aiter_integration_final
https://github.com/ROCm/vllm/tree/deepseek_v3_dev