Skip to content

Commit

Permalink
Add numa binding to init from env helper (#396)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #396

Differential Revision: D46036186

fbshipit-source-id: 1b38c0b5d14fcde4d944b3a55c9eb3862ea32b05
  • Loading branch information
ananthsub authored and facebook-github-bot committed May 19, 2023
1 parent e05c060 commit 0952d37
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torchtnt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def init_from_env(
pg_backend: T.Optional[str] = None,
pg_timeout: timedelta = default_pg_timeout,
float32_matmul_precision: str = "high",
bind_numa: bool = True,
) -> torch.device:
"""Utility function that initializes the device and process group, if applicable.
Expand All @@ -57,6 +58,7 @@ def init_from_env(
pg_timeout (timedelta, optional): Timeout for operations executed against the process
group. Default value equals 30 minutes
float32_matmul_precision (str, optional): The setting for torch's precision of matrix multiplications.
bind_numa (bool, optional): Whether to bind CPU sockets to GPUs.
Returns:
The current device.
Expand Down Expand Up @@ -86,4 +88,17 @@ def init_from_env(
)
torch.distributed.init_process_group(backend=pg_backend, timeout=pg_timeout)
maybe_enable_tf32(float32_matmul_precision)
if bind_numa and device.type == "cuda":
init_numa()

return device


def init_numa() -> None:
import numa

local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
num_sockets = numa.get_max_node() + 1
socket_id = torch.cuda.current_device() // (max(local_world_size // num_sockets, 1))
node_mask = {socket_id}
numa.bind(node_mask)

0 comments on commit 0952d37

Please sign in to comment.