[MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine)#22051
froststeam wants to merge 1 commit into sgl-project:main
Conversation
Code Review
This pull request introduces support for the MUSA (Moore Threads GPU) hardware backend, specifically focusing on Flash Attention integration. It adds necessary dependencies, configuration parameters, and a new MUSA-specific attention module that wraps the mate library's flash attention functions. The implementation uses a thread-local context manager to automatically inject scheduler metadata into attention calls. Key changes include updates to the attention registry, the FlashAttentionBackend to handle MUSA-specific logic, and server argument adjustments for MUSA compatibility. Feedback highlights potential issues with global buffer safety in multi-GPU environments, metadata cache collisions due to non-unique keys, and the implications of ignoring cu_seqlens_k_new in the MUSA implementation.
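The thread-local context-manager pattern described above could look roughly like the following. This is a hypothetical sketch, not the actual MATE code: the names `attention_metadata` and `current_metadata` are invented for illustration, and the metadata is an opaque object the wrapped flash-attention functions would read.

```python
import threading
from contextlib import contextmanager

# Thread-local storage so concurrent scheduler threads do not
# clobber each other's metadata (illustrative names, not the MATE API).
_local = threading.local()

@contextmanager
def attention_metadata(metadata):
    """Make `metadata` visible to attention calls made in this thread."""
    prev = getattr(_local, "forward_metadata", None)
    _local.forward_metadata = metadata
    try:
        yield
    finally:
        # Restore the previous value so nested contexts compose safely.
        _local.forward_metadata = prev

def current_metadata():
    """What a wrapped flash-attention function would call to fetch metadata."""
    return getattr(_local, "forward_metadata", None)
```

A wrapped kernel would call `current_metadata()` internally, so callers never pass scheduler metadata explicitly.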
yeahdongcn left a comment
I think it would be better to split this into two commits: one carrying over changes from the previous PR, and another fixing the regression in selecting FA kernels for different NVIDIA GPU architectures. This should make it easier for the SGLang core team to review.
Force-pushed 3369ebb to ba20eee
Force-pushed 08c699e to 9cb257c
Force-pushed 9cb257c to 0af5fe5
```python
flash_attn_varlen_func = self.flash_attn_varlen_func
flash_attn_with_kvcache = self.flash_attn_with_kvcache
```
The key updates to resolve the previous regression in FA3/FA4 kernel wiring on NVIDIA GPUs should be here, since upstream main now selects and stores the kernels during __init__.
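For context, the select-and-store-during-`__init__` pattern this comment refers to can be sketched as follows. This is a simplified illustration, not SGLang's actual code: the `_fa3_*`/`_fa4_*` functions are stand-ins for the real FA3/FA4 (and MUSA) kernels, and the real backend takes many more constructor arguments.

```python
# Stub kernel callables standing in for the real flash-attention kernels
# (in SGLang these are imported from sgl-kernel or, on MUSA, from mate).
def _fa3_varlen(*args, **kwargs): return "fa3_varlen"
def _fa3_kvcache(*args, **kwargs): return "fa3_kvcache"
def _fa4_varlen(*args, **kwargs): return "fa4_varlen"
def _fa4_kvcache(*args, **kwargs): return "fa4_kvcache"

class FlashAttentionBackend:
    """Selects the kernel pair once, at construction time."""

    def __init__(self, fa_impl_ver: int = 3):
        if fa_impl_ver == 4:
            self.flash_attn_varlen_func = _fa4_varlen
            self.flash_attn_with_kvcache = _fa4_kvcache
        else:  # default to the FA3 kernels
            self.flash_attn_varlen_func = _fa3_varlen
            self.flash_attn_with_kvcache = _fa3_kvcache

    def _forward_extend_impl(self, *args, **kwargs):
        # Uses the stored kernel rather than a hard-coded default,
        # which was the bug in the original MUSA adaptation.
        return self.flash_attn_varlen_func(*args, **kwargs)

    def _forward_decode_impl(self, *args, **kwargs):
        return self.flash_attn_with_kvcache(*args, **kwargs)
```

The design point is that version dispatch happens exactly once, so the per-token forward paths stay branch-free.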
I'll go ahead and approve this PR to trigger CI. Could @Fridge003 and @Kangyan-Zhou please take a final look? Thanks!
/tag-and-rerun-ci
Motivation
This PR fixes the Flash Attention backend support that was previously merged in PR #17985 but later reverted in PR #22002 due to a bug. The original commit 2373552 caused CI failures (see failed CI job).
Previously, the MUSA-adapted flash attention implementation had a bug in the `_forward_extend_impl` method. The code was missing a proper mechanism to select the correct kernel implementation based on the `fa_impl_ver` parameter, causing it to always use the default FA3 implementation regardless of the specified version.

Fix Applied
After rebasing onto the latest main branch, the kernel selection logic has been refactored and moved into the `FlashAttentionBackend.__init__` method. This ensures that the appropriate flash attention implementation is selected during initialization based on the `fa_impl_ver` parameter.

- Moved kernel selection to `__init__`: the logic to select the correct flash attention kernel (including the MUSA-specific implementations) is now handled in the `FlashAttentionBackend.__init__` method, where two instance variables are initialized:
  - `self.flash_attn_with_kvcache`: for cached attention operations
  - `self.flash_attn_varlen_func`: for variable-length attention operations
- Updated forward methods: both `_forward_extend_impl` and `_forward_decode_impl` now use these instance variables instead of calling the default implementations directly, ensuring the correct kernel is used based on the initialized configuration.

Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci

Related Links: