Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
SKIP_CK_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CK_BUILD", "TRUE") == "TRUE" if USE_TRITON_ROCM else False
NVCC_THREADS = os.getenv("NVCC_THREADS") or "4"

@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
Expand Down Expand Up @@ -186,8 +187,7 @@ def detect_hipify_v2():


def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args + ["--threads", NVCC_THREADS]


def rename_cpp_to_cu(cpp_files):
Expand Down Expand Up @@ -571,15 +571,23 @@ def __init__(self, *args, **kwargs) -> None:
if not os.environ.get("MAX_JOBS"):
import psutil

nvcc_threads = max(1, int(NVCC_THREADS))

# calculate the maximum allowed NUM_JOBS based on cores
max_num_jobs_cores = max(1, os.cpu_count() // 2)

# calculate the maximum allowed NUM_JOBS based on free memory
free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB
max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4
# Assume worst-case peak observed memory usage of ~5GB per NVCC thread.
# Limit: peak_threads = max_jobs * nvcc_threads and peak_threads * 5GB <= free_memory.
max_num_jobs_memory = max(1, int(free_memory_gb / (5 * nvcc_threads)))

# pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
print(
f"Auto set MAX_JOBS to `{max_jobs}`, NVCC_THREADS to `{nvcc_threads}`. "
"If you see memory pressure, please use a lower `MAX_JOBS=N` or `NVCC_THREADS=N` value."
)
os.environ["MAX_JOBS"] = str(max_jobs)

super().__init__(*args, **kwargs)
Expand Down