From 60bb298d05e4df50250ed2f22e6200cc910b80c0 Mon Sep 17 00:00:00 2001 From: drisspg Date: Sun, 8 Feb 2026 02:34:15 +0000 Subject: [PATCH] pytest-dist round robin to gpus stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2241, branch: drisspg/stack/11 --- tests/cute/conftest.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/cute/conftest.py diff --git a/tests/cute/conftest.py b/tests/cute/conftest.py new file mode 100644 index 00000000000..6ee05e9a3a4 --- /dev/null +++ b/tests/cute/conftest.py @@ -0,0 +1,31 @@ +import os +import subprocess + + +def _get_gpu_ids(): + visible = os.environ.get("CUDA_VISIBLE_DEVICES") + if visible: + return [g.strip() for g in visible.split(",")] + + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + return result.stdout.strip().splitlines() + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + return ["0"] + + +def pytest_configure(config): + worker_id = os.environ.get("PYTEST_XDIST_WORKER") + if not worker_id: + return + worker_num = int(worker_id.replace("gw", "")) + gpu_ids = _get_gpu_ids() + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[worker_num % len(gpu_ids)]