Skip to content

Commit

Permalink
Skips DDP parameter sync (#4301)
Browse files Browse the repository at this point in the history
* ddp no-sync

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: ananthsub <[email protected]>

* Update training_loop.py

* factor __enter__ and __exit__ out to separate context manager

* delete _updated_model_last_step

Co-authored-by: justusschock <[email protected]>
Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
6 people authored Oct 29, 2020
1 parent b459fd2 commit bbd81df
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from contextlib import contextmanager
from copy import copy, deepcopy

import numpy as np
Expand Down Expand Up @@ -655,6 +655,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# checks if backward or backward + optimizer step (via closure)
accumulation_done = self._accumulated_batches_reached()
is_final_batch = self._num_training_batches_reached()
should_accumulate = not (accumulation_done or is_final_batch)

# lightning module hook
splits = self.tbptt_split_batch(batch)
Expand All @@ -675,13 +676,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
model = self.trainer.get_model()
model.toggle_optimizer(optimizer, opt_idx)

if not (accumulation_done or is_final_batch):
if should_accumulate:
# For gradient accumulation

# -------------------
# calculate loss (train step + train step end)
# -------------------
self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

# perform dpp sync only when performing optimizer_step
with self.block_ddp_sync_behaviour():
self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

batch_outputs = self._process_closure_result(
batch_callback_metrics=batch_callback_metrics,
batch_log_metrics=batch_log_metrics,
Expand All @@ -695,7 +700,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# gradient update with accumulated gradients

else:

if self.automatic_optimization:

def train_step_and_backward_closure():
Expand Down Expand Up @@ -760,6 +764,13 @@ def train_step_and_backward_closure():
)
return result

@contextmanager
def block_ddp_sync_behaviour(self):
if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
yield from self.trainer.model.no_sync()
else:
yield

def _process_closure_result(
self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int
) -> list:
Expand Down

0 comments on commit bbd81df

Please sign in to comment.