9 changes: 8 additions & 1 deletion xtuner/v1/ops/flash_attn/__init__.py
@@ -19,7 +19,14 @@ def get_flash_attn_varlen() -> FlashAttnVarlenProtocol:

     if os.environ.get("XTUNER_USE_FA3", "0") == "1":
         try:
-            from flash_attn_interface import flash_attn_3_cuda
+            # flash_attn_interface renamed its v3 entrypoint in newer releases:
+            # - old: flash_attn_3_cuda
+            # - new: flash_attn_3_gpu
+            # We only import it here as an availability check; the actual impl lives in `.gpu`.
+            try:
+                from flash_attn_interface import flash_attn_3_gpu  # noqa: F401
+            except ImportError:
+                from flash_attn_interface import flash_attn_3_cuda  # noqa: F401
         except ImportError as e:
             raise ImportError(f"Import FlashAttention 3 failed {e}, Please install it manually.")
         from .gpu import gpu_flash_varlen_attn_v3 as flash_attn_varlen_func
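For context, a minimal usage sketch (not part of this PR): the FA3 path is opted into via the XTUNER_USE_FA3 environment variable that get_flash_attn_varlen() checks above. The variable name, default, and import path come from the diff; the call site itself is illustrative.

# Illustrative sketch only: enabling the FlashAttention 3 backend selected in get_flash_attn_varlen().
# XTUNER_USE_FA3 and get_flash_attn_varlen come from the diff above; the usage pattern is assumed.
import os

os.environ["XTUNER_USE_FA3"] = "1"  # must be set before get_flash_attn_varlen() runs

from xtuner.v1.ops.flash_attn import get_flash_attn_varlen

# Raises ImportError ("Please install it manually.") if neither flash_attn_3_gpu nor
# flash_attn_3_cuda can be imported from flash_attn_interface.
flash_attn_varlen_func = get_flash_attn_varlen()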
9 changes: 8 additions & 1 deletion xtuner/v1/ops/flash_attn/gpu.py
@@ -5,7 +5,14 @@


 try:
-    from flash_attn_interface import flash_attn_3_cuda
+    # flash_attn_interface renamed its v3 entrypoint in newer releases:
+    # - old: flash_attn_3_cuda
+    # - new: flash_attn_3_gpu
+    # Keep the rest of this file stable by aliasing whichever exists to `flash_attn_3_cuda`.
+    try:
+        from flash_attn_interface import flash_attn_3_gpu as flash_attn_3_cuda
+    except ImportError:
+        from flash_attn_interface import flash_attn_3_cuda as flash_attn_3_cuda
     from flash_attn_interface import maybe_contiguous as maybe_contiguous_v3

 @torch.library.custom_op("flash_attn::_flash_attn_varlen_forward_v3", mutates_args=(), device_types="cuda")
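As a side note on the aliasing approach above, here is a small sketch that probes which FA3 entrypoint name the installed package exposes, mirroring the new-name-first fallback in gpu.py. The helper name is hypothetical and not part of this PR; it assumes only that flash_attn_interface may or may not be installed.

# Hypothetical helper (illustrative, not xtuner code): report which FA3 entrypoint
# the installed flash_attn_interface exposes, trying the newer name before the old one.
import importlib
from typing import Optional

def fa3_entrypoint_name() -> Optional[str]:
    try:
        mod = importlib.import_module("flash_attn_interface")
    except ImportError:
        return None  # FlashAttention 3 is not installed at all
    for name in ("flash_attn_3_gpu", "flash_attn_3_cuda"):  # newer name first, then the old one
        if hasattr(mod, name):
            return name
    return None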