Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX/distributed] Small clean ups on warning emissions in jax.distributed. #24853

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ def auto_detect_unset_distributed_params(cls,
initialization_timeout: int | None,
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:

if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id,
local_device_ids)

# First, we check the spec detection method because it will ignore submitted values
# If if succeeds.
if cluster_detection_method is not None:
Expand Down
39 changes: 26 additions & 13 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
logger = logging.getLogger(__name__)


_CHECK_PROXY_ENVS = config.bool_flag(
name="jax_check_proxy_envs",
default=True,
help="Checks proxy vars in user envs and emit warnings.",
)


class State:
process_id: int = 0
num_processes: int = 1
Expand Down Expand Up @@ -55,16 +62,18 @@ def initialize(self,
if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
local_device_ids = list(map(int, env_ids.split(",")))

(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address,
num_processes,
process_id,
local_device_ids,
cluster_detection_method,
initialization_timeout,
)
)
_any_is_none = lambda *args: any(arg is None for arg in args)
if _any_is_none(coordinator_address, num_processes, process_id, local_device_ids):
(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address,
num_processes,
process_id,
local_device_ids,
cluster_detection_method,
initialization_timeout,
)
)

if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')
Expand Down Expand Up @@ -92,8 +101,10 @@ def initialize(self,

self.process_id = process_id

# Emit a warning about PROXY variables if they are in the user's env:
proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()]
proxy_vars = []
if _CHECK_PROXY_ENVS.value:
proxy_vars = [key for key in os.environ.keys()
if '_proxy' in key.lower()]

if len(proxy_vars) > 0:
vars = " ".join(proxy_vars) + ". "
Expand Down Expand Up @@ -179,7 +190,9 @@ def initialize(coordinator_address: str | None = None,
``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.

Otherwise, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
``num_processes``, ``process_id``, and ``local_device_ids`` arguments
to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster
envs auto detection will be skipped.

Please note: on some systems, particularly HPC clusters that only access external networks
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
Expand Down