-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
[FEAT][ROCm]: Support AITER MLA #15893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
vllm-bot
merged 27 commits into
vllm-project:main
from
EmbeddedLLM:aiter-mla-integration
Apr 22, 2025
Merged
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
f782c66
add AITER MLA implementation in attention backend
vllmellm 42d5c62
remove unused arguments in aiter mla decode fwd kernel
vllmellm 565a3fd
add unittest for AITER MLA backend in attention selector
vllmellm 645f400
add unittest for MLA attention backend selector
vllmellm 22c8726
code cleaning
vllmellm 5dc1348
update AITER version
vllmellm 12f8023
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm da8c69f
add ck flash attn in prefill mla computation
vllmellm 1ea5718
further code cleaning
vllmellm 681d777
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm 9ada055
fix mypy typing errors
vllmellm 1ceb3b9
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm 20a3f07
fix mypy error on Iterable typing error
vllmellm 194a42a
remove padding for v tensor in AITER MLA which improves performance
vllmellm a9a02d5
upgrade aiter package version
vllmellm 02a4fb3
only support AITER FA in AITER MLA backend to avoid latency caused by…
vllmellm 95213e2
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm 6e48433
add missing data types of arguments in aiter_mla_decode_fwd
vllmellm 8c2ed72
NIT
vllmellm c95cb02
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm 25d88d5
support block-size 1 for ROCM AITER MLA
vllmellm f38c4a9
fix mypy error
vllmellm 0027497
preserve the lines
vllmellm 78007d0
return back calling the tiron fa function to its original format
vllmellm cb4e861
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm 54817a1
fix fstring in error message
vllmellm 8fd039e
Update MLA attention backend selector for rocm attention selector and…
vllmellm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: why relocate these lines? Also can you please explain to me why we now need
self.__class__.BLOCK_TABLE_EXTENDERThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.class.BLOCK_TABLE_EXTENDER this is a static class variable since common had this hardcoded as "[]" in the line below:
self.block_tables.extend([] * cuda_graph_pad_size)
cuz in AiterMLAMetadataBuilder for capturing graph we need "[[]]" instead of "[]", by eliminating the hardcoded extender into class variable allows the subclass to implement itsown extender value or just inherit from parent.
to review this file is better to open the entire file, as the github interface is not representative enough what has been changed.
overall as explained in the PR descript for the summary of the changes to accommodate AITER MLA implementation and reduce the code duplication in the subclass some refactoring has been made in certain function to allow more flexibility in subclasses.
Implementation
ROCM_AITER_MLAis introduced as an additional attention backend type for ROCm platform.To support this backend the modules below are implemented
vllm/attention/backends/rocm_aiter_mla.pyAiterMLABackendinherits fromMLACommonBackend.AiterMLAMetadatainherits fromMLACommonMetadata: note that from this class theadvance_stepfunction utilizesadvance_step_flashinferfunction from VLLM cutom ops.AiterMLAMetadataBuilderinherits fromMLACommonMetadataBuilder.AiterMLAStateinherits fromMLACommonState.AiterMLAImplclass inherits fromCommonMLAImpl:Important notes for this class:
flash_attn_varlen_func(FA function) used in this class is AITER FA implementation (flash_attn_varlen_funcfrom AITER package)._forward_decodefunction in this class usesmla_decode_fwdkernel from AITER package.The MLACommon module has been refactored to reduce code duplication in its subclasses. This was achieved by separating the attention output computation into two dedicated functions named as
_get_fwd_prefill_attn_outputand_get_prefill_ctx_attn_outputthat are used in_compute_prefill_contextand_forward_prefillfunction respectively.Another refactoring is placed in
advance_stepfunction by separating out the pre assertion checks before calling an advance_step method to allowadvance_stepfunction to be overridden without code duplication in its subclasses.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@LucasWilkinson after resolving merge conflict for this file. the only changes in
common.pyare as below:invoking
ops.advance_step_flashattnin a separate function_ops_advance_stepthat can be overridden by subclass that is used inadvance_stepfunction.use of "static" class variable as
BLOCK_TABLE_EXTENDER: list[list[int]] = []that is used to updateself.block_tablesin graph mode which eliminates the hardcoded "[]" self.block_tables.extend([] * cuda_graph_pad_size) to allow flexibility for the subclasses to override this update based on the class variable.