diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 344523842343..93a12f1f6587 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1831,6 +1831,7 @@ def _inner_training_loop(
                     # AT THE VERY END!
                     _ = list(train_dataloader.sampler)
 
+        total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -1867,6 +1868,7 @@ def _inner_training_loop(
 
             step = -1
             for step, inputs in enumerate(epoch_iterator):
+                total_batched_samples += 1
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False
@@ -1887,7 +1889,7 @@ def _inner_training_loop(
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
 
                 if (
-                    ((step + 1) % args.gradient_accumulation_steps != 0)
+                    (total_batched_samples % args.gradient_accumulation_steps != 0)
                     and args.local_rank != -1
                     and args._no_sync_in_gradient_accumulation
                 ):
@@ -1913,7 +1915,7 @@ def _inner_training_loop(
                 if self.deepspeed:
                     self.deepspeed.step()
 
-                if (step + 1) % args.gradient_accumulation_steps == 0 or (
+                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch