diff --git a/torchtnt/utils/env.py b/torchtnt/utils/env.py index 49dda0826e..984ab0a06a 100644 --- a/torchtnt/utils/env.py +++ b/torchtnt/utils/env.py @@ -38,6 +38,7 @@ def init_from_env( pg_backend: T.Optional[str] = None, pg_timeout: timedelta = default_pg_timeout, float32_matmul_precision: str = "high", + bind_numa: bool = True, ) -> torch.device: """Utility function that initializes the device and process group, if applicable. @@ -57,6 +58,7 @@ def init_from_env( pg_timeout (timedelta, optional): Timeout for operations executed against the process group. Default value equals 30 minutes float32_matmul_precision (str, optional): The setting for torch's precision of matrix multiplications. + bind_numa (bool, optional): Whether to bind CPU sockets to GPUs. Returns: The current device. @@ -86,4 +88,17 @@ def init_from_env( ) torch.distributed.init_process_group(backend=pg_backend, timeout=pg_timeout) maybe_enable_tf32(float32_matmul_precision) + if bind_numa and device.type == "cuda": + init_numa() + return device + + +def init_numa() -> None: + import numa + + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + num_sockets = numa.get_max_node() + 1 + socket_id = torch.cuda.current_device() // (max(local_world_size // num_sockets, 1)) + node_mask = {socket_id} + numa.bind(node_mask)