Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion flash_attn/cute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,33 @@
# Opt-out switch read once at import time: the flag is on only when the
# FA_DISABLE_2CTA environment variable is exactly "1"; unset (default "0")
# or any other value leaves it off.
_fa_disable_2cta_enabled: bool = os.getenv("FA_DISABLE_2CTA", "0") == "1"


def _is_cuda_12() -> bool:
    """Return True when the CUDA version reported by PyTorch is 12.x.

    The version is read from ``torch.version.cuda``; only the component
    before the first ``.`` is compared against 12.

    Background: 2CTA forward non-causal has a codegen regression on CUDA 12
    that causes ~18% slowdown compared to 1CTA. This is fixed in CUDA 13.x.

    Returns:
        True if the major version parses to 12. False if the version is
        ``None``, is some other major version, or if anything goes wrong
        while importing torch or parsing the version string.
    """
    try:
        import torch

        reported = torch.version.cuda
        if reported is None:
            # No CUDA version available from this torch build.
            return False
        major_text, _, _ = reported.partition(".")
        return int(major_text) == 12
    except Exception:
        # Best-effort probe: any failure (torch missing, unexpected version
        # format, ...) is treated as "not CUDA 12" rather than an error.
        return False


# Result of the CUDA 12 probe, evaluated once when this module is imported
# so the check is not repeated on every lookup of the 2CTA default.
_fa_disable_2cta_cuda12: bool = _is_cuda_12()


def _get_use_clc_scheduler_default() -> bool:
    """Return the default for using the CLC scheduler.

    This is simply the current value of the module-level ``_fa_clc_enabled``
    flag (defined elsewhere in this module).
    """
    clc_default = _fa_clc_enabled
    return clc_default


def _get_disable_2cta_default() -> bool:
    """Return whether 2CTA should be disabled by default.

    2CTA is disabled when either of the module-level flags is true:

    * ``_fa_disable_2cta_enabled`` -- the ``FA_DISABLE_2CTA=1`` environment
      override.
    * ``_fa_disable_2cta_cuda12`` -- the import-time CUDA 12 probe.
    """
    # Both flags must be consulted in a single return; an earlier
    # ``return _fa_disable_2cta_enabled`` would make the CUDA 12 term
    # unreachable.
    return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12


def _compute_base_hash(func: Callable) -> str:
Expand Down