
Efficient gradient accumulation in LightningLite #14966

Merged Oct 19, 2022 · 33 commits

Commits
ccf59c7
define and implement interfaces
awaelchli Sep 30, 2022
7b2bf63
Lite api
awaelchli Sep 30, 2022
d7814de
format
awaelchli Sep 30, 2022
bbd33db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2022
f56979c
add args in docs
awaelchli Sep 30, 2022
8550487
update
awaelchli Sep 30, 2022
363a109
Merge remote-tracking branch 'origin/lite/gradient-sync' into lite/gr…
awaelchli Sep 30, 2022
9c5a3ef
Merge branch 'master' into lite/gradient-sync
awaelchli Oct 13, 2022
82696ee
alternative
awaelchli Oct 13, 2022
fc8dbf9
alternative
awaelchli Oct 13, 2022
060d843
rename
awaelchli Oct 13, 2022
a106b1f
rename
awaelchli Oct 16, 2022
b309451
Merge branch 'master' into lite/gradient-sync-alt
awaelchli Oct 16, 2022
2e9ac6d
update
awaelchli Oct 16, 2022
ba3dc14
test
awaelchli Oct 16, 2022
a5b327e
test
awaelchli Oct 16, 2022
3bcc77d
test
awaelchli Oct 16, 2022
a72780f
test
awaelchli Oct 16, 2022
e7827aa
deepspeed
awaelchli Oct 16, 2022
dbe32c2
changelog
awaelchli Oct 16, 2022
bfc83d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2022
a87b500
docs
awaelchli Oct 16, 2022
2b14f78
update
awaelchli Oct 16, 2022
5709685
update
awaelchli Oct 16, 2022
2c2c314
Update src/lightning_lite/lite.py
awaelchli Oct 19, 2022
1696cbe
Update src/lightning_lite/strategies/ddp.py
awaelchli Oct 19, 2022
680ed93
skip -> no
awaelchli Oct 19, 2022
59783e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2022
53cd439
warning instead of error
awaelchli Oct 19, 2022
cbcef78
update
awaelchli Oct 19, 2022
d73dff4
Merge branch 'master' into lite/gradient-sync-alt
awaelchli Oct 19, 2022
67fbda8
document which strategies don't support it
awaelchli Oct 19, 2022
d6a095a
Merge branch 'master' into lite/gradient-sync-alt
awaelchli Oct 19, 2022
32 changes: 32 additions & 0 deletions docs/source-pytorch/starter/lightning_lite.rst
@@ -725,3 +725,35 @@ the data is written to disk.
self.barrier()

# All processes are allowed to read the data now


no_backward_sync
================

Use this context manager when performing gradient accumulation and using a distributed strategy (e.g., DDP).
It will speed up your training loop by cutting redundant communication between processes during the accumulation phase.

.. code-block:: python

    # Accumulate gradients over 8 batches at a time
    is_accumulating = batch_idx % 8 != 0

    with self.no_backward_sync(model, enabled=is_accumulating):
        output = model(input)
        loss = ...
        self.backward(loss)
        ...

    # Step the optimizer every 8 batches
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()

Both the model's ``.forward()`` and the ``self.backward()`` call need to run under this context manager, as shown in the example above.
For single-device strategies, it is a no-op. The following strategies do not support skipping the backward synchronization:

- deepspeed
- dp
- xla

For these, the context manager falls back to a no-op and emits a warning.
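
The loop below expands the snippet above into a minimal, self-contained sketch of gradient accumulation inside a ``LightningLite`` subclass. ``MyModel``, ``MyDataset``, the ``SGD`` settings, and the launch arguments at the bottom are placeholder assumptions for illustration, not part of this change.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader

    from pytorch_lightning.lite import LightningLite


    class Lite(LightningLite):
        def run(self):
            model = MyModel()  # placeholder module
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            model, optimizer = self.setup(model, optimizer)
            dataloader = self.setup_dataloaders(DataLoader(MyDataset()))  # placeholder dataset

            model.train()
            for batch_idx, (input, target) in enumerate(dataloader):
                # Skip the gradient all-reduce while accumulating; sync on every 8th batch
                is_accumulating = batch_idx % 8 != 0

                with self.no_backward_sync(model, enabled=is_accumulating):
                    output = model(input)
                    loss = torch.nn.functional.cross_entropy(output, target)
                    self.backward(loss)

                # Step and reset only on the batches where gradients were synced
                if not is_accumulating:
                    optimizer.step()
                    optimizer.zero_grad()


    Lite(strategy="ddp", accelerator="cpu", devices=2).run()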
50 changes: 48 additions & 2 deletions src/lightning_lite/lite.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from functools import partial
from pathlib import Path
from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union
@@ -29,7 +29,7 @@
from lightning_lite.plugins import Precision # avoid circular imports: # isort: split
from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
from lightning_lite.strategies import DeepSpeedStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.utilities import move_data_to_device
from lightning_lite.utilities.apply_func import convert_to_tensors
@@ -345,6 +345,52 @@ def all_gather(
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
return self._strategy.broadcast(obj, src=src)

@contextmanager
def no_backward_sync(self, module: _LiteModule, enabled: bool = True) -> Generator:
"""Skip gradient synchronization during backward to avoid redundant communication overhead.

Use this context manager when performing gradient accumulation to speed up training with multiple devices.

Example::

# Accumulate gradients over 8 batches at a time
with self.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
    output = model(input)
    loss = ...
    self.backward(loss)
    ...

For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
Both the model's `.forward()` and the `self.backward()` call need to run under this context.

Args:
module: The module for which to control the gradient synchronization.
enabled: Whether the context manager is enabled. ``True`` means skip the gradient synchronization,
``False`` means synchronize as usual.
"""

if not isinstance(module, _LiteModule):
raise TypeError(
"You need to set up the model first before you can call `self.no_backward_sync()`:"
" `model = self.setup(model, ...)`"
)
if not enabled or isinstance(self._strategy, SingleDeviceStrategy):
context = nullcontext()
elif self._strategy._backward_sync_control is None:
rank_zero_warn(
f"The `{self._strategy.__class__.__name__}` does not support skipping the gradient synchronization."
f" Remove `.no_backward_sync()` from your code or choose a different strategy.",
category=PossibleUserWarning,
)
context = nullcontext()
else:
context = self._strategy._backward_sync_control.no_backward_sync( # type: ignore[assignment]
module._forward_module
)

with context:
yield

def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None:
"""Save checkpoint contents to a file.

21 changes: 19 additions & 2 deletions src/lightning_lite/strategies/ddp.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Generator, List, Optional, Union

import torch
import torch.distributed
@@ -27,7 +28,7 @@
from lightning_lite.plugins.precision import Precision
from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning_lite.strategies.parallel import ParallelStrategy
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.strategies.strategy import _BackwardSyncControl, TBroadcast
from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
@@ -59,6 +60,7 @@ def __init__(
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._backward_sync_control = _DDPBackwardSyncControl()
self._ddp_kwargs = kwargs

@property
@@ -178,3 +180,18 @@ def _determine_ddp_device_ids(self) -> Optional[List[int]]:
if self.root_device.type == "cpu":
return None
return [self.root_device.index]


class _DDPBackwardSyncControl(_BackwardSyncControl):
@contextmanager
def no_backward_sync(self, module: Module) -> Generator:
"""Blocks gradient synchronization inside the
:class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper."""
if not isinstance(module, DistributedDataParallel):
raise TypeError(
"Blocking backward sync is only possible if the module passed to"
f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `DistributedDataParallel`."
f" Got: {module.__class__.__name__}."
)
with module.no_sync(): # type: ignore[operator]
yield
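
For orientation, ``_DDPBackwardSyncControl`` is a thin wrapper around PyTorch's own ``DistributedDataParallel.no_sync()``. A rough plain-PyTorch sketch of the same accumulation pattern is shown below; the model, optimizer, dataloader, an already-initialized process group, and the 8-batch window are all assumptions for illustration.

.. code-block:: python

    import contextlib

    import torch
    from torch.nn.parallel import DistributedDataParallel

    # Assumes torch.distributed.init_process_group(...) was already called
    # and that `model`, `optimizer`, and `dataloader` exist.
    ddp_model = DistributedDataParallel(model)

    for batch_idx, (input, target) in enumerate(dataloader):
        is_accumulating = batch_idx % 8 != 0
        # no_sync() defers the gradient all-reduce; exiting the context restores it
        context = ddp_model.no_sync() if is_accumulating else contextlib.nullcontext()
        with context:
            loss = torch.nn.functional.cross_entropy(ddp_model(input), target)
            loss.backward()

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()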
2 changes: 2 additions & 0 deletions src/lightning_lite/strategies/ddp_spawn.py
@@ -26,6 +26,7 @@
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.plugins.io.checkpoint_io import CheckpointIO
from lightning_lite.plugins.precision import Precision
from lightning_lite.strategies.ddp import _DDPBackwardSyncControl
from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning_lite.strategies.parallel import ParallelStrategy
from lightning_lite.strategies.strategy import TBroadcast
@@ -68,6 +69,7 @@ def __init__(
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._backward_sync_control = _DDPBackwardSyncControl()
self._start_method = start_method
self._ddp_kwargs = kwargs
self._local_rank = 0
1 change: 1 addition & 0 deletions src/lightning_lite/strategies/deepspeed.py
@@ -233,6 +233,7 @@ def __init__(
precision=precision,
process_group_backend=process_group_backend,
)
self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally

self.config = self._load_config(config)
if self.config is None:
46 changes: 19 additions & 27 deletions src/lightning_lite/strategies/fairscale.py
@@ -27,6 +27,7 @@
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.strategies import DDPSpawnStrategy
from lightning_lite.strategies.ddp import DDPStrategy
from lightning_lite.strategies.strategy import _BackwardSyncControl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.imports import _IS_WINDOWS

@@ -65,6 +66,7 @@ def __init__(
timeout=timeout,
**kwargs,
)
self._backward_sync_control = _FairscaleBackwardSyncControl()
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
@@ -87,19 +89,6 @@ def setup_module_and_optimizers(
model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

@contextmanager
def block_backward_sync(self, module: Module) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.

This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(module, ShardedDataParallel):
with module.no_sync():
yield None
else:
yield None

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
@@ -141,13 +130,14 @@ def __init__(
timeout=timeout,
**kwargs,
)
self._backward_sync_control = _FairscaleBackwardSyncControl()
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
) -> Tuple[Module, List[Optimizer]]:
) -> Tuple["ShardedDataParallel", List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.

Return:
@@ -163,19 +153,6 @@ def setup_module_and_optimizers(
model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

@contextmanager
def block_backward_sync(self, module: Module) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.

This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(module, ShardedDataParallel):
with module.no_sync():
yield None
else:
yield None

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
@@ -204,3 +181,18 @@ def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precisio
optimizers[x] = zero_optimizer
del optimizer
return optimizers


class _FairscaleBackwardSyncControl(_BackwardSyncControl):
@contextmanager
def no_backward_sync(self, module: Module) -> Generator:
"""Blocks gradient synchronization inside the :class:`~fairscale.nn.data_parallel.ShardedDataParallel`
wrapper."""
if not isinstance(module, ShardedDataParallel):
raise TypeError(
"Blocking backward sync is only possible if the module passed to"
f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `ShardedDataParallel`."
f" Got: {module.__class__.__name__}."
)
with module.no_sync():
yield None
18 changes: 1 addition & 17 deletions src/lightning_lite/strategies/parallel.py
@@ -12,14 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, List, Optional

import torch
from torch import Tensor
from torch.nn import Module

import lightning_lite as lite
from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.plugins.io.checkpoint_io import CheckpointIO
@@ -94,19 +91,6 @@ def reduce_boolean_decision(self, decision: bool) -> bool:
decision = bool(decision == self.world_size)
return decision

@contextmanager
def block_backward_sync(self, module: Module) -> Generator:
"""Blocks ddp sync gradients behaviour on backwards pass.

This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(module, lite.utilities.types.DistributedDataParallel):
with module.no_sync():
yield None
else:
yield None

def teardown(self) -> None:
assert self.cluster_environment is not None
self.cluster_environment.teardown()
22 changes: 20 additions & 2 deletions src/lightning_lite/strategies/strategy.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
@@ -50,6 +50,7 @@ def __init__(
self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
self._precision: Optional[Precision] = precision
self._launcher: Optional[_Launcher] = None
self._backward_sync_control: Optional[_BackwardSyncControl] = None

@property
@abstractmethod
@@ -148,7 +149,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) ->
device = device or self.root_device
return move_data_to_device(batch, device)

@contextlib.contextmanager
@contextmanager
def module_sharded_context(self) -> Generator:
"""Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
shard the model instantly, which is useful for extremely large models which can save memory and
@@ -296,3 +297,20 @@ def teardown(self) -> None:
@classmethod
def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
pass


class _BackwardSyncControl(ABC):
"""Interface for any :class:`Strategy` that wants to offer a functionality to enable or disable gradient
synchronization during/after back-propagation.

The most common use-case is gradient accumulation. If a :class:`Strategy` implements this interface, the user can
implement their gradient accumulation loop very efficiently by disabling redundant gradient synchronization.
"""

@contextmanager
@abstractmethod
def no_backward_sync(self, module: Module) -> Generator:
"""Blocks the synchronization of gradients during the backward pass.

This is a context manager. It is only effective if it wraps a call to `.backward()`.
"""
1 change: 1 addition & 0 deletions src/lightning_lite/strategies/xla.py
@@ -58,6 +58,7 @@ def __init__(
start_method="fork",
)
self._checkpoint_io: Optional[CheckpointIO]
self._backward_sync_control = None # XLA synchronizes gradients in the optimizer.step() call
self._launched = False

@property
7 changes: 6 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -4,9 +4,14 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - 2022-MM-DD

### Added

- Added `LightningLite.no_backward_sync` for control over efficient gradient accumulation with distributed strategies ([#14966](https://github.com/Lightning-AI/lightning/pull/14966))



### Changed

- Moved the warning about saving nn.Module in `save_hyperparameters()` to before the deepcopy ([#15132](https://github.com/Lightning-AI/lightning/pull/15132))