Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient accumulation for ShardedDataParallel #9122

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))


- Fixed gradient accumulation for `DDPShardedPlugin` ([#9122](https://github.com/PyTorchLightning/pytorch-lightning/pull/9122))


## [1.4.7] - 2021-09-14

- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
Expand Down
16 changes: 15 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from contextlib import contextmanager
from typing import Dict, Generator, Optional

import torch

Expand Down Expand Up @@ -100,6 +101,19 @@ def lightning_module(self) -> "pl.LightningModule":
def pre_backward(self, closure_loss: torch.Tensor) -> None:
pass

@contextmanager
def block_backward_sync(self) -> Generator:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""Blocks syncing gradients behaviour on backwards pass.

This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, ShardedDataParallel):
with self.model.no_sync():
yield None
else:
yield None

def post_training_step(self):
pass

Expand Down
16 changes: 15 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from contextlib import contextmanager
from typing import Dict, Generator, Optional

import torch

Expand Down Expand Up @@ -63,6 +64,19 @@ def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@contextmanager
def block_backward_sync(self) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.

This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, ShardedDataParallel):
with self.model.no_sync():
yield None
else:
yield None

@rank_zero_only
def _optim_state_dict(self, optimizer):
"""
Expand Down
10 changes: 10 additions & 0 deletions tests/plugins/test_sharded_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe
assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
else:
assert kwargs["reduce_buffer_size"] == expected_buffer_size


@RunIf(skip_windows=True, fairscale=True)
def test_block_backward_sync(tmpdir):
plugin = DDPShardedPlugin()
model = mock.MagicMock(spec=ShardedDataParallel)
with mock.patch.object(plugin, "_model", model):
with plugin.block_backward_sync():
pass
model.no_sync.assert_called_once()