diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py index 8cec07601094..a8bb1b8a287f 100644 --- a/jax/_src/clusters/slurm_cluster.py +++ b/jax/_src/clusters/slurm_cluster.py @@ -30,7 +30,8 @@ class SlurmCluster(clusters.ClusterEnv): @classmethod def is_env_present(cls) -> bool: - return _JOBID_PARAM in os.environ + return all(var in os.environ for var in + (_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT, _PROCESS_ID, _LOCAL_PROCESS_ID)) @classmethod def get_coordinator_address(cls, timeout_secs: int | None) -> str: