Set an upper limit on CPU threads in distributed training #18677

Merged · 26 commits · Oct 4, 2023

Commits
e1d7c41
set num threads
awaelchli Sep 29, 2023
0977447
update
awaelchli Sep 29, 2023
0b040c4
todo
awaelchli Sep 29, 2023
f431d61
repro script
awaelchli Sep 29, 2023
ed0f68b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
5be3583
update
awaelchli Sep 29, 2023
b40c351
Merge remote-tracking branch 'origin/feature/set-num-threads' into fe…
awaelchli Sep 29, 2023
c4905d9
update repro
awaelchli Oct 2, 2023
5959ece
Merge branch 'master' into feature/set-num-threads
awaelchli Oct 3, 2023
837671e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2023
9fd9046
Merge remote-tracking branch 'origin/feature/set-num-threads' into fe…
awaelchli Oct 3, 2023
3ceb644
update
awaelchli Oct 3, 2023
aaba3f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2023
4ff08bc
if needed
awaelchli Oct 3, 2023
321d243
expected env var set
awaelchli Oct 3, 2023
066c109
Merge branch 'master' into feature/set-num-threads
awaelchli Oct 3, 2023
c3f2b69
test
awaelchli Oct 4, 2023
1d2ad3a
update test
awaelchli Oct 4, 2023
0d486c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
775c3b0
changelog
awaelchli Oct 4, 2023
6f883a1
Merge branch 'master' into feature/set-num-threads
awaelchli Oct 4, 2023
f05fbad
update tests
awaelchli Oct 4, 2023
05fc41d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2023
485ef20
Merge branch 'master' into feature/set-num-threads
awaelchli Oct 4, 2023
66d752b
unblock
awaelchli Oct 4, 2023
9b405c7
Merge branch 'master' into feature/set-num-threads
awaelchli Oct 4, 2023
1 change: 1 addition & 0 deletions docs/source-pytorch/conf.py
@@ -477,6 +477,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:class", "torch.utils.data.DistributedSampler"),
("py:class", "torch_xla.distributed.parallel_loader.MpDeviceLoader"),
("py:func", "torch_xla.distributed.xla_multiprocessing.spawn"),
("py:class", "torch._dynamo.OptimizedModule"),
("py:mod", "tqdm"),
("py:meth", "training_step"),
("py:meth", "transfer_batch_to_device"),
4 changes: 4 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -168,6 +168,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled launching via `torchrun` in a SLURM environment; the `TorchElasticEnvironment` now gets chosen over the `SLURMEnvironment` if both are detected ([#18618](https://github.com/Lightning-AI/lightning/pull/18618))


- If not set by the user, Lightning will set `OMP_NUM_THREADS` to `num_cpus / num_processes` when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks ([#18677](https://github.com/Lightning-AI/lightning/pull/18677))



### Deprecated

- Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition ([#17381](https://github.com/Lightning-AI/lightning/pull/17381))
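To make the changelog entry above concrete, here is a rough back-of-the-envelope sketch (not Lightning's internal code; `os.cpu_count()` stands in for the internal CPU-count helper) of the oversubscription this default avoids: without a cap, each of the N worker processes lets OpenMP/PyTorch default to one thread per visible core, so a single machine ends up running roughly N times `num_cpus` CPU threads.

```python
import os

# Illustrative arithmetic only; hypothetical values for an 8-process DDP run on one machine.
num_cpus = os.cpu_count() or 1
num_processes = 8

threads_without_cap = num_processes * num_cpus  # every worker defaults to all cores
threads_with_cap = num_processes * max(1, num_cpus // num_processes)  # roughly num_cpus in total

print(f"without cap: ~{threads_without_cap} CPU threads, with cap: ~{threads_with_cap} CPU threads")
```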
3 changes: 2 additions & 1 deletion src/lightning/fabric/cli.py
@@ -24,6 +24,7 @@
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.distributed import _suggested_max_num_threads

_log = logging.getLogger(__name__)

@@ -177,7 +178,7 @@ def _torchrun_launch(args: Namespace, script_args: List[str]) -> None:
torchrun_args.extend(script_args)

# set a good default number of threads for OMP to avoid warnings being emitted to the user
os.environ.setdefault("OMP_NUM_THREADS", str(max(1, (os.cpu_count() or 1) // num_processes)))
os.environ.setdefault("OMP_NUM_THREADS", str(_suggested_max_num_threads()))
torchrun.main(torchrun_args)


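Because `setdefault` is used here, a value exported by the user always wins; the suggested cap only fills the gap when `OMP_NUM_THREADS` is absent. A minimal sketch of that behavior (the numbers are made up for illustration):

```python
import os

suggested = 4  # stand-in for what _suggested_max_num_threads() might return on this machine

os.environ.setdefault("OMP_NUM_THREADS", str(suggested))
# Prints "4" when the variable was unset; a value the user exported before
# invoking the CLI (e.g. "16") is left untouched.
print(os.environ["OMP_NUM_THREADS"])
```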
src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -27,6 +27,7 @@
from lightning.fabric.accelerators.cpu import CPUAccelerator
from lightning.fabric.strategies.launchers.launcher import _Launcher
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.distributed import _set_num_threads_if_needed
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

@@ -129,10 +130,11 @@ def _wrapping_function(
) -> None:
if global_states:
global_states.restore()

if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
args, kwargs = _disable_module_memory_sharing((args, kwargs))

_set_num_threads_if_needed(num_processes=self._strategy.num_processes)

os.environ["LOCAL_RANK"] = str(process_idx)
results = function(*args, **kwargs)

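One plausible reason the cap is applied inside `_wrapping_function`, i.e. in each child, rather than once in the parent: with the `spawn` start method every worker is a fresh interpreter, so a `torch.set_num_threads()` call made in the parent does not carry over. A standalone sketch of the same pattern (the `os.cpu_count()`-based division is a stand-in for Lightning's helper, not its exact implementation):

```python
import os

import torch
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # Runs in the spawned child. Cap the intra-op thread pool here unless the
    # user already chose a value via the environment.
    if "OMP_NUM_THREADS" not in os.environ:
        num_threads = max(1, (os.cpu_count() or 1) // world_size)
        torch.set_num_threads(num_threads)
        os.environ["OMP_NUM_THREADS"] = str(num_threads)
    print(f"rank {rank}: using {torch.get_num_threads()} intra-op threads")


if __name__ == "__main__":
    world_size = 4
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```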
src/lightning/fabric/strategies/launchers/subprocess_script.py
@@ -24,6 +24,7 @@

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.strategies.launchers.launcher import _Launcher
from lightning.fabric.utilities.distributed import _set_num_threads_if_needed
from lightning.fabric.utilities.rank_zero import rank_prefixed_message

_logger = logging.getLogger(__name__)
@@ -98,6 +99,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
if not self.cluster_environment.creates_processes_externally:
self._call_children_scripts()
_launch_process_observer(self.procs)

_set_num_threads_if_needed(num_processes=self.num_processes)
return function(*args, **kwargs)

def _call_children_scripts(self) -> None:
14 changes: 14 additions & 0 deletions src/lightning/fabric/utilities/distributed.py
@@ -13,6 +13,7 @@
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler

from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp

@@ -359,3 +360,16 @@ def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any)
def __iter__(self) -> Iterator:
self.dataset.reset()
return (self.dataset[index] for index in super().__iter__())


def _suggested_max_num_threads(num_processes: int = 1) -> int:
if num_processes < 1:
raise ValueError(f"`num_processes` should be >= 1, got {num_processes}.")
return max(1, _num_cpus_available() // num_processes)


def _set_num_threads_if_needed(num_processes: int = 1) -> None:
if "OMP_NUM_THREADS" not in os.environ:
num_threads = _suggested_max_num_threads(num_processes)
torch.set_num_threads(num_threads)
os.environ["OMP_NUM_THREADS"] = str(num_threads)
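A quick illustration of how these two helpers behave; both are underscore-prefixed internals rather than public API, so the snippet below is only a sketch. The expected values match the unit test further down, which mocks a machine with 4 CPUs: the per-process budget never drops below 1, and `_set_num_threads_if_needed` does nothing once `OMP_NUM_THREADS` is already set.

```python
import os

from lightning.fabric.utilities.distributed import (
    _set_num_threads_if_needed,
    _suggested_max_num_threads,
)

# On a 4-CPU machine this prints 4, 2, 1, 1, 1: floor division with a lower bound of 1.
for num_processes in (1, 2, 3, 4, 8):
    print(num_processes, _suggested_max_num_threads(num_processes))

os.environ.pop("OMP_NUM_THREADS", None)
_set_num_threads_if_needed(num_processes=2)  # sets torch threads and exports OMP_NUM_THREADS
_set_num_threads_if_needed(num_processes=8)  # no-op: the variable is already set
print(os.environ["OMP_NUM_THREADS"])
```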
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -222,6 +222,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled launching via `torchrun` in a SLURM environment; the `TorchElasticEnvironment` now gets chosen over the `SLURMEnvironment` if both are detected ([#18618](https://github.com/Lightning-AI/lightning/pull/18618))


- If not set by the user, Lightning will set `OMP_NUM_THREADS` to `num_cpus / num_processes` when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks ([#18677](https://github.com/Lightning-AI/lightning/pull/18677))


### Deprecated

src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -33,6 +33,7 @@
_disable_module_memory_sharing,
)
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.distributed import _set_num_threads_if_needed
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import CPUAccelerator
@@ -154,6 +155,8 @@ def _wrapping_function(
if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
args, kwargs = _disable_module_memory_sharing((args, kwargs))

_set_num_threads_if_needed(num_processes=self._strategy.num_processes)

os.environ["LOCAL_RANK"] = str(process_idx)
results = function(*args, **kwargs)

src/lightning/pytorch/strategies/launchers/subprocess_script.py
@@ -25,6 +25,7 @@
_hydra_subprocess_cmd,
_launch_process_observer,
)
from lightning.fabric.utilities.distributed import _set_num_threads_if_needed
from lightning.pytorch.strategies.launchers.launcher import _Launcher
from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM

@@ -96,6 +97,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
if not self.cluster_environment.creates_processes_externally:
self._call_children_scripts()
_launch_process_observer(self.procs)

_set_num_threads_if_needed(num_processes=self.num_processes)
return function(*args, **kwargs)

def kill(self, signum: _SIGNUM) -> None:
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
@@ -56,6 +56,7 @@ def restore_env_variables():
"POPLAR_ENGINE_OPTIONS", # set by IPUStrategy
"CUDA_MODULE_LOADING", # leaked since PyTorch 1.13
"CRC32C_SW_MODE", # set by tensorboardX
"OMP_NUM_THREADS", # set by our launchers
# set by XLA FSDP on XRT
"XRT_TORCH_DIST_ROOT",
"XRT_MESH_SERVICE_ADDRESS",
32 changes: 31 additions & 1 deletion tests/tests_fabric/utilities/test_distributed.py
@@ -2,14 +2,21 @@
import os
from functools import partial
from pathlib import Path
from unittest import mock

import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import _gather_all_tensors, _sync_ddp, is_shared_filesystem
from lightning.fabric.utilities.distributed import (
_gather_all_tensors,
_set_num_threads_if_needed,
_suggested_max_num_threads,
_sync_ddp,
is_shared_filesystem,
)

from tests_fabric.helpers.runif import RunIf

@@ -158,3 +165,26 @@ def _test_is_shared_filesystem(strategy, tmp_path, monkeypatch):

# Remote path is considered shared
assert is_shared_filesystem(strategy, path="s3://my-bucket/data")


@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_threads(invalid):
with pytest.raises(ValueError, match="should be >= 1"):
_suggested_max_num_threads(invalid)


@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("lightning.fabric.utilities.distributed.torch.set_num_threads")
@mock.patch("lightning.fabric.utilities.distributed._num_cpus_available", return_value=4)
@pytest.mark.parametrize(("num_processes", "expected"), [(1, 4), (2, 2), (3, 1), (4, 1), (8, 1)])
def test_set_num_threads_if_needed(_, set_num_threads_mock, num_processes, expected):
assert "OMP_NUM_THREADS" not in os.environ
_set_num_threads_if_needed(num_processes)
set_num_threads_mock.assert_called_with(expected)
assert os.environ["OMP_NUM_THREADS"] == str(expected)

# if env variable is already set, no change
set_num_threads_mock.reset_mock()
_set_num_threads_if_needed(1)
set_num_threads_mock.assert_not_called()
assert os.environ["OMP_NUM_THREADS"] == str(expected)
1 change: 1 addition & 0 deletions tests/tests_pytorch/conftest.py
@@ -78,6 +78,7 @@ def restore_env_variables():
"KMP_DUPLICATE_LIB_OK", # leaked since PyTorch 1.13
"CRC32C_SW_MODE", # leaked by tensorboardX
"TRITON_CACHE_DIR", # leaked by torch.compile
"OMP_NUM_THREADS", # set by our launchers
# leaked by XLA
"ALLOW_MULTIPLE_LIBTPU_LOAD",
"GRPC_VERBOSITY",
2 changes: 1 addition & 1 deletion tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -337,7 +337,7 @@ def local_rank(self):
def node_rank(self):
return 0

ddp_strategy = DDPStrategy(cluster_environment=MyClusterEnvironment())
ddp_strategy = DDPStrategy(cluster_environment=MyClusterEnvironment(), parallel_devices=[torch.device("cpu")])
assert ddp_strategy.launcher is None
ddp_strategy._configure_launcher()
assert isinstance(ddp_strategy.launcher, _SubprocessScriptLauncher)