From 713797da88002dc670414d64a43ce154cbb74c5b Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Tue, 14 Apr 2026 10:47:01 +0800
Subject: [PATCH 1/3] [Feat] Upgrade MATE to match CUDA/ROCm behavior on FA

Signed-off-by: Xiaodong Ye
---
Review note (not part of the commit message): in is_mate_available() the
`or` between the two entry-point checks is parenthesized. Python binds
`and` tighter than `or`, so without the parentheses the function returned
True on any non-MUSA platform where flash_attn_varlen_func had been
imported. The blob id on the fa.py `index` line predates this one-line
adjustment.

 vllm_omni/diffusion/attention/backends/flash_attn.py | 12 +-----------
 vllm_omni/diffusion/attention/backends/utils/fa.py   |  6 +++---
 2 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py
index b6ab3a57ad5..d38ea4f6eaa 100644
--- a/vllm_omni/diffusion/attention/backends/flash_attn.py
+++ b/vllm_omni/diffusion/attention/backends/flash_attn.py
@@ -96,7 +96,7 @@ def forward_cuda(
         value: torch.Tensor,
         attn_metadata: AttentionMetadata = None,
     ) -> torch.Tensor:
-        """CUDA/ROCm flash attention implementation."""
+        """CUDA/ROCm/MUSA flash attention implementation."""
         from vllm_omni.diffusion.attention.backends.utils.fa import (
             HAS_FLASH_ATTN,
             flash_attn_func,
@@ -209,13 +209,3 @@ def forward_npu(
             layout="BNSD",
         )
         return output
-
-    def forward_musa(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_metadata: AttentionMetadata = None,
-    ) -> torch.Tensor:
-        # XXX (MUSA): MUSA uses the same implementation as XPU (mate only provides flash_attn_varlen_func)
-        return self.forward_xpu(query, key, value, attn_metadata)
diff --git a/vllm_omni/diffusion/attention/backends/utils/fa.py b/vllm_omni/diffusion/attention/backends/utils/fa.py
index fe6051f8ba7..18886871a22 100644
--- a/vllm_omni/diffusion/attention/backends/utils/fa.py
+++ b/vllm_omni/diffusion/attention/backends/utils/fa.py
@@ -41,7 +41,7 @@
         pass
 elif current_omni_platform.is_musa():
     try:
-        from mate import flash_attn_varlen_func  # noqa: F401
+        from flash_attn_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
     except (ImportError, ModuleNotFoundError):
         pass
 else:
@@ -82,8 +82,8 @@
 
 @lru_cache(maxsize=1)
 def is_mate_available() -> bool:
-    """Check if MATE (MUSA Flash Attention) is available."""
-    return current_omni_platform.is_musa() and flash_attn_varlen_func is not None
+    """Check if MATE (MUSA AI Tensor Engine) is available."""
+    return current_omni_platform.is_musa() and (flash_attn_func is not None or flash_attn_varlen_func is not None)
 
 
 def _index_first_axis(tensor, indices):

From cc1473635a5c4754d513b33ae718a17e5f3561b0 Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Tue, 14 Apr 2026 11:02:45 +0800
Subject: [PATCH 2/3] Bump versions

Signed-off-by: Xiaodong Ye
---
 requirements/musa.txt | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/requirements/musa.txt b/requirements/musa.txt
index 112f3260465..88a89a15ecf 100644
--- a/requirements/musa.txt
+++ b/requirements/musa.txt
@@ -1,4 +1,6 @@
 -r common.txt
 # MUSA platform dependencies
-torchada>=0.1.46
+torchada>=0.1.49
 onnxruntime>=1.23.2
+mate>=0.2.0
+flash_attn_3>=0.1.4

From 2843edcfc92a3e028522c655e77d0d7508a521ee Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Tue, 21 Apr 2026 11:41:30 +0800
Subject: [PATCH 3/3] Bump torchada version to 0.1.50

Signed-off-by: Xiaodong Ye
---
 requirements/musa.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/musa.txt b/requirements/musa.txt
index 88a89a15ecf..c100c70cf05 100644
--- a/requirements/musa.txt
+++ b/requirements/musa.txt
@@ -1,6 +1,6 @@
 -r common.txt
 # MUSA platform dependencies
-torchada>=0.1.49
+torchada>=0.1.50
 onnxruntime>=1.23.2
 mate>=0.2.0
 flash_attn_3>=0.1.4