diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a342f17059d5..2d04f82b893e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2426,15 +2426,28 @@ def _inner_training_loop(
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
             self._load_callback_state()
-            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
-            if not args.ignore_data_skip:
-                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            if num_update_steps_per_epoch is not None:
+                epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+                if not args.ignore_data_skip:
+                    steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                    steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+                else:
+                    steps_trained_in_current_epoch = 0
             else:
-                steps_trained_in_current_epoch = 0
+                # If the dataloader does not have a length, we cannot restore the number of trained epochs.
+                # In the following loop, we repeatedly iterate over the dataloader to skip the first
+                # `steps_trained_in_current_epoch` steps and increment `epochs_trained` accordingly.
+                epochs_trained = 0
+                steps_trained_in_current_epoch = self.state.global_step * args.gradient_accumulation_steps
+                if args.ignore_data_skip:
+                    raise ValueError(
+                        "The dataloader does not have a length, so it is impossible to restore the number of trained"
+                        " epochs. Please disable the `ignore_data_skip` option."
+                    )

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            if num_update_steps_per_epoch is not None:
+                logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
                 logger.info(
@@ -2467,6 +2480,32 @@ def _inner_training_loop(
             if hasattr(epoch_dataloader, "set_epoch"):
                 epoch_dataloader.set_epoch(epoch)

+            steps_skipped = 0
+            rng_to_sync = False
+            epoch_iterator = None
+            if steps_trained_in_current_epoch > 0 and num_update_steps_per_epoch is None:
+                # Since the dataloader does not have a length, we just loop until the required number of steps.
+                # Every time we reach the end of the dataloader, we increment epoch and reset the iterator.
+                epoch_iterator = iter(epoch_dataloader)
+                epoch_over = False
+                while steps_trained_in_current_epoch > 0:
+                    try:
+                        # If the dataloader yields N batches and N is not divisible by `args.gradient_accumulation_steps`,
+                        # the update loop ignores the last `N % args.gradient_accumulation_steps` batches of an epoch.
+                        # To replicate the same behavior when resuming training, we ignore such batches from skipped epochs.
+                        for _ in range(args.gradient_accumulation_steps):
+                            next(epoch_iterator)
+                        steps_trained_in_current_epoch -= args.gradient_accumulation_steps
+                        steps_skipped += args.gradient_accumulation_steps
+                    except StopIteration:
+                        epoch_over = True
+                        break
+                if epoch_over:
+                    epochs_trained += 1
+                    continue
+                assert steps_trained_in_current_epoch == 0
+                rng_to_sync = True
+
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None
@@ -2481,8 +2520,6 @@ def _inner_training_loop(
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)

-            rng_to_sync = False
-            steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
                 epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
@@ -2490,7 +2527,8 @@ def _inner_training_loop(
                 rng_to_sync = True

             step = -1
-            epoch_iterator = iter(epoch_dataloader)
+            if epoch_iterator is None:
+                epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
             if remainder == 0:
@@ -2648,13 +2686,6 @@ def _inner_training_loop(
                     if is_torch_xla_available():
                         xm.mark_step()
                     break
-            if step < 0:
-                logger.warning(
-                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
-                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
-                    f" num_steps ({max_steps}) higher than the number of available samples."
-                )
-                self.control.should_training_stop = True

             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(
@@ -5385,7 +5416,7 @@ def set_initial_training_values(
         elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
             num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
+            num_update_steps_per_epoch = None
             num_examples = total_train_batch_size * args.max_steps
             num_train_samples = args.max_steps * total_train_batch_size
         else:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 2594edcdef8c..8c46428b2a6d 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -3421,6 +3421,44 @@ def test_resume_training_with_frozen_params(self):
             self.assertEqual(b, b1)
             self.check_trainer_state_are_the_same(state, state1)

+    @parameterized.expand([(9, 1), (10, 1), (11, 1), (20, 1), (21, 1), (9, 3), (9, 2)])
+    def test_resume_training_with_iterable_dataset(self, dataset_length, gradient_accumulation_steps):
+        with tempfile.TemporaryDirectory() as tmpdir:
+
+            def get_trainer():
+                config = RegressionModelConfig()
+                train_dataset = SampleIterableDataset(length=dataset_length)
+                model = RegressionRandomPreTrainedModel(config)
+                args = RegressionTrainingArguments(
+                    output_dir=tmpdir,
+                    learning_rate=0.1,
+                    max_steps=20,
+                    save_steps=10,
+                    per_device_train_batch_size=1,
+                    gradient_accumulation_steps=gradient_accumulation_steps,
+                )
+                return Trainer(model=model, args=args, train_dataset=train_dataset)
+
+            # Train from scratch.
+            trainer = get_trainer()
+            trainer.train()
+            self.assertEqual(trainer.state.global_step, 20)
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            # Train from a checkpoint.
+            checkpoint = os.path.join(tmpdir, "checkpoint-10")
+            trainer = get_trainer()
+            trainer.train(resume_from_checkpoint=checkpoint)
+            self.assertEqual(trainer.state.global_step, 20)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+
+            # Check that the resumed model is the same as the original one.
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
         with tempfile.TemporaryDirectory() as tmpdir:
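
Note (not part of the patch): the following is a minimal, standalone Python sketch of the batch-skipping behaviour the new block in `_inner_training_loop` applies when the dataloader has no length. `CountingDataset` and `skip_already_trained_batches` are hypothetical names used only for illustration, and the sketch assumes every epoch yields at least one full gradient-accumulation window. Only the control flow mirrors the patch: whole accumulation windows are consumed, partial windows at the end of an epoch are dropped, and each exhausted epoch increments the trained-epoch counter.

from torch.utils.data import DataLoader, IterableDataset


class CountingDataset(IterableDataset):
    """Toy sized-less dataset: yields `length` items per epoch (illustration only)."""

    def __init__(self, length):
        self.length = length

    def __iter__(self):
        return iter(range(self.length))


def skip_already_trained_batches(dataloader, global_step, gradient_accumulation_steps):
    """Hypothetical helper mirroring the patch's resume loop for length-less dataloaders."""
    # The checkpoint stores optimizer updates; each update consumed one full window of
    # `gradient_accumulation_steps` batches.
    steps_to_skip = global_step * gradient_accumulation_steps
    epochs_trained = 0
    iterator = iter(dataloader)
    while steps_to_skip > 0:
        try:
            # Consume one full accumulation window; a partial window at the end of an
            # epoch is dropped, matching what the update loop does during training.
            for _ in range(gradient_accumulation_steps):
                next(iterator)
            steps_to_skip -= gradient_accumulation_steps
        except StopIteration:
            # Epoch exhausted before reaching the checkpointed step: count it and restart.
            epochs_trained += 1
            iterator = iter(dataloader)
    return epochs_trained, iterator


loader = DataLoader(CountingDataset(9), batch_size=1)
epochs_trained, it = skip_already_trained_batches(loader, global_step=10, gradient_accumulation_steps=2)
# 9 batches -> 4 full accumulation windows per epoch (the 9th batch is dropped), so resuming
# at global step 10 skips two whole epochs plus two updates: epochs_trained == 2.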