diff --git a/requirements/musa.txt b/requirements/musa.txt
index 112f3260465..c100c70cf05 100644
--- a/requirements/musa.txt
+++ b/requirements/musa.txt
@@ -1,4 +1,6 @@
 -r common.txt
 # MUSA platform dependencies
-torchada>=0.1.46
+torchada>=0.1.50
 onnxruntime>=1.23.2
+mate>=0.2.0
+flash_attn_3>=0.1.4
diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py
index b6ab3a57ad5..d38ea4f6eaa 100644
--- a/vllm_omni/diffusion/attention/backends/flash_attn.py
+++ b/vllm_omni/diffusion/attention/backends/flash_attn.py
@@ -96,7 +96,7 @@ def forward_cuda(
         value: torch.Tensor,
         attn_metadata: AttentionMetadata = None,
     ) -> torch.Tensor:
-        """CUDA/ROCm flash attention implementation."""
+        """CUDA/ROCm/MUSA flash attention implementation."""
         from vllm_omni.diffusion.attention.backends.utils.fa import (
             HAS_FLASH_ATTN,
             flash_attn_func,
@@ -209,13 +209,3 @@ def forward_npu(
             layout="BNSD",
         )
         return output
-
-    def forward_musa(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_metadata: AttentionMetadata = None,
-    ) -> torch.Tensor:
-        # XXX (MUSA): MUSA uses the same implementation as XPU (mate only provides flash_attn_varlen_func)
-        return self.forward_xpu(query, key, value, attn_metadata)
diff --git a/vllm_omni/diffusion/attention/backends/utils/fa.py b/vllm_omni/diffusion/attention/backends/utils/fa.py
index fe6051f8ba7..18886871a22 100644
--- a/vllm_omni/diffusion/attention/backends/utils/fa.py
+++ b/vllm_omni/diffusion/attention/backends/utils/fa.py
@@ -41,7 +41,7 @@
         pass
 elif current_omni_platform.is_musa():
     try:
-        from mate import flash_attn_varlen_func  # noqa: F401
+        from flash_attn_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
     except (ImportError, ModuleNotFoundError):
         pass
 else:
@@ -82,8 +82,9 @@
 
 @lru_cache(maxsize=1)
 def is_mate_available() -> bool:
-    """Check if MATE (MUSA Flash Attention) is available."""
-    return current_omni_platform.is_musa() and flash_attn_varlen_func is not None
+    """Check if MATE (MUSA AI Tensor Engine) attention is available; always False off MUSA."""
+    # Keep the parentheses: `and` binds tighter than `or`, so ungrouped this is True on any platform with a varlen kernel.
+    return current_omni_platform.is_musa() and (flash_attn_func is not None or flash_attn_varlen_func is not None)
 
 
 def _index_first_axis(tensor, indices):