diff --git a/xtuner/v1/ops/flash_attn/__init__.py b/xtuner/v1/ops/flash_attn/__init__.py
index ad9712de7..226810283 100644
--- a/xtuner/v1/ops/flash_attn/__init__.py
+++ b/xtuner/v1/ops/flash_attn/__init__.py
@@ -19,7 +19,14 @@ def get_flash_attn_varlen() -> FlashAttnVarlenProtocol:
     if os.environ.get("XTUNER_USE_FA3", "0") == "1":
         try:
-            from flash_attn_interface import flash_attn_3_cuda
+            # flash_attn_interface renamed its v3 entrypoint in newer releases:
+            #   - old: flash_attn_3_cuda
+            #   - new: flash_attn_3_gpu
+            # We only import it here as an availability check; the actual impl lives in `.gpu`.
+            try:
+                from flash_attn_interface import flash_attn_3_gpu  # noqa: F401
+            except ImportError:
+                from flash_attn_interface import flash_attn_3_cuda  # noqa: F401
         except ImportError as e:
             raise ImportError(f"Import FlashAttention 3 failed {e}, Please install it manually.")
         from .gpu import gpu_flash_varlen_attn_v3 as flash_attn_varlen_func
diff --git a/xtuner/v1/ops/flash_attn/gpu.py b/xtuner/v1/ops/flash_attn/gpu.py
index 6681d73c1..beb1f796b 100644
--- a/xtuner/v1/ops/flash_attn/gpu.py
+++ b/xtuner/v1/ops/flash_attn/gpu.py
@@ -5,7 +5,14 @@
 try:
-    from flash_attn_interface import flash_attn_3_cuda
+    # flash_attn_interface renamed its v3 entrypoint in newer releases:
+    #   - old: flash_attn_3_cuda
+    #   - new: flash_attn_3_gpu
+    # Keep the rest of this file stable by aliasing whichever exists to `flash_attn_3_cuda`.
+    try:
+        from flash_attn_interface import flash_attn_3_gpu as flash_attn_3_cuda
+    except ImportError:
+        from flash_attn_interface import flash_attn_3_cuda
     from flash_attn_interface import maybe_contiguous as maybe_contiguous_v3

     @torch.library.custom_op("flash_attn::_flash_attn_varlen_forward_v3", mutates_args=(), device_types="cuda")
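
Reviewer note: to check which v3 entrypoint a locally installed flash_attn_interface build actually exposes, a minimal standalone sketch like the following mirrors the fallback order used in both hunks (it assumes flash_attn_interface is importable; resolve_fa3_entrypoint is a hypothetical helper for illustration, not part of this patch):

    import importlib

    def resolve_fa3_entrypoint():
        # Probe the installed flash_attn_interface for the FA3 entrypoint,
        # preferring the newer name, in the same order as the patch:
        # flash_attn_3_gpu first, then the legacy flash_attn_3_cuda.
        mod = importlib.import_module("flash_attn_interface")
        for name in ("flash_attn_3_gpu", "flash_attn_3_cuda"):
            if hasattr(mod, name):
                return name
        raise ImportError(
            "flash_attn_interface exposes neither flash_attn_3_gpu nor "
            "flash_attn_3_cuda; please install FlashAttention 3 manually."
        )

    if __name__ == "__main__":
        print("FlashAttention 3 entrypoint:", resolve_fa3_entrypoint())

Probing attributes on the imported module is equivalent to the patch's nested try/except import fallback; the patch keeps the try/except form so the surrounding error handling and the rest of gpu.py stay unchanged.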