diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6e339e8652cd..1a8cac0722e8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1373,6 +1373,7 @@ def train(
                     # AT THE VERY END!
                     _ = list(train_dataloader.sampler)
 
+        start_train_stable_time = 0
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -1402,6 +1403,9 @@ def train(
 
             step = -1
             for step, inputs in enumerate(epoch_iterator):
+                if (self.state.global_step == 10):
+                    start_train_stable_time = time.time()
+
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
@@ -1549,6 +1553,11 @@ def train(
         train_loss = self._total_loss_scalar / self.state.global_step
 
         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+
+        total_samples = args.max_steps*total_train_batch_size if args.max_steps > 0 else num_examples*num_train_epochs
+        perf_samples = total_samples - 10*total_train_batch_size
+        stable_train_metrics = speed_metrics("stable_train", start_train_stable_time, perf_samples)
+
        self.store_flos()
         metrics["total_flos"] = self.state.total_flos
         metrics["train_loss"] = train_loss
@@ -1559,6 +1568,8 @@ def train(
 
         self.log(metrics)
 
+        self.log(stable_train_metrics)
+
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
 
         return TrainOutput(self.state.global_step, train_loss, metrics)
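For context, the patch measures "stable" throughput by recording a timestamp at global step 10 and then dividing the number of samples processed after that point by the elapsed time; `speed_metrics` in the Trainer returns keys such as `<prefix>_runtime` and `<prefix>_samples_per_second`. Below is a minimal, self-contained sketch of that arithmetic, not the library implementation: the helper name `speed_metrics_sketch` and the batch-size/step-count numbers are assumptions for illustration only.

```python
import time

def speed_metrics_sketch(split, start_time, num_samples=None):
    # Rough stand-in for transformers' speed_metrics (names assumed):
    # elapsed runtime since start_time plus samples-per-second over that window.
    runtime = time.time() - start_time
    metrics = {f"{split}_runtime": round(runtime, 4)}
    if num_samples is not None:
        metrics[f"{split}_samples_per_second"] = round(num_samples / runtime, 3)
    return metrics

# Hypothetical run: max_steps=100, per-device batch 8, gradient accumulation 4,
# single process -> total_train_batch_size = 32.
total_train_batch_size = 8 * 4 * 1
total_samples = 100 * total_train_batch_size                 # samples over the whole run
perf_samples = total_samples - 10 * total_train_batch_size   # drop the 10 warm-up steps

start_train_stable_time = time.time() - 90.0  # pretend steps 10..100 took ~90 s
print(speed_metrics_sketch("stable_train", start_train_stable_time, perf_samples))
# roughly: {'stable_train_runtime': 90.0, 'stable_train_samples_per_second': 32.0}
```

The point of subtracting `10 * total_train_batch_size` is that the first 10 steps (compilation, data-loader warm-up, cache population) are excluded from both the numerator and the timed window, so the reported `stable_train_samples_per_second` reflects steady-state throughput rather than startup cost.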