Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion requirements/musa.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
-r common.txt
# MUSA platform dependencies
torchada>=0.1.46
torchada>=0.1.50
onnxruntime>=1.23.2
mate>=0.2.0
flash_attn_3>=0.1.4
12 changes: 1 addition & 11 deletions vllm_omni/diffusion/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def forward_cuda(
value: torch.Tensor,
attn_metadata: AttentionMetadata = None,
) -> torch.Tensor:
"""CUDA/ROCm flash attention implementation."""
"""CUDA/ROCm/MUSA flash attention implementation."""
from vllm_omni.diffusion.attention.backends.utils.fa import (
HAS_FLASH_ATTN,
flash_attn_func,
Expand Down Expand Up @@ -209,13 +209,3 @@ def forward_npu(
layout="BNSD",
)
return output

def forward_musa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata = None,
) -> torch.Tensor:
# XXX (MUSA): MUSA uses the same implementation as XPU (mate only provides flash_attn_varlen_func)
return self.forward_xpu(query, key, value, attn_metadata)
6 changes: 3 additions & 3 deletions vllm_omni/diffusion/attention/backends/utils/fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
pass
elif current_omni_platform.is_musa():
try:
from mate import flash_attn_varlen_func # noqa: F401
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401
except (ImportError, ModuleNotFoundError):
pass
else:
Expand Down Expand Up @@ -82,8 +82,8 @@

@lru_cache(maxsize=1)
def is_mate_available() -> bool:
"""Check if MATE (MUSA Flash Attention) is available."""
return current_omni_platform.is_musa() and flash_attn_varlen_func is not None
"""Check if MATE (MUSA AI Tensor Engine) is available."""
return current_omni_platform.is_musa() and flash_attn_func is not None or flash_attn_varlen_func is not None


def _index_first_axis(tensor, indices):
Expand Down
Loading