diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 07ec629d7901..9a78ae61d186 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -282,7 +282,7 @@ def __init__(
         # Create output directory if needed
         if self.is_world_process_zero():
             os.makedirs(self.args.output_dir, exist_ok=True)
-        if is_torch_tpu_available():
+        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
             # Set an xla_device flag on the model's config.
             # We'll find a more elegant and not need to do this in the future.
             self.model.config.xla_device = True
@@ -490,11 +490,9 @@ def setup_wandb(self):
             logger.info(
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
             )
-            try:
-                combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
-            except AttributeError:
-                # in case the model has no config
-                combined_dict = {**self.args.to_sanitized_dict()}
+            combined_dict = {**self.args.to_sanitized_dict()}
+            if isinstance(self.model, PreTrainedModel):
+                combined_dict = {**self.model.config.to_dict(), **combined_dict}
             wandb.init(
                 project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
             )
@@ -533,7 +531,8 @@ def setup_comet(self):
         if experiment is not None:
             experiment._set_model_graph(self.model, framework="transformers")
             experiment._log_parameters(self.args, prefix="args/", framework="transformers")
-            experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
+            if isinstance(self.model, PreTrainedModel):
+                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

     def num_examples(self, dataloader: DataLoader) -> int:
         """
@@ -679,7 +678,11 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
+                find_unused_parameters=(
+                    not getattr(model.config, "gradient_checkpointing", False)
+                    if isinstance(model, PreTrainedModel)
+                    else True
+                ),
             )
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
@@ -707,15 +710,14 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

         self.global_step = 0
         self.epoch = 0
-        self.total_flos = 0
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
+
         # Check if continuing training from a checkpoint
         if model_path is not None:
             # set global_step to global_step of last saved checkpoint from model path
             try:
                 self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
-                self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
                 epochs_trained = self.global_step // num_update_steps_per_epoch
                 steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
@@ -723,14 +725,13 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                 logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                 logger.info("  Continuing training from epoch %d", epochs_trained)
                 logger.info("  Continuing training from global step %d", self.global_step)
-                logger.info("  Continuing training from %d non-embedding floating-point operations", self.total_flos)
                 logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
             except ValueError:
                 self.global_step = 0
-                self.total_flos = 0
                 logger.info("  Starting fine-tuning.")

         tr_loss = torch.tensor(0.0).to(self.args.device)
+        self.total_flos = self.state.total_flos
         logging_loss_scalar = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
@@ -1029,7 +1030,7 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
         else:
             total_flos = self.total_flos
         if total_flos > 0:
-            logs["total_flos"] = self.total_flos
+            logs["total_flos"] = total_flos
         if self.global_step is None:
             # when logging evaluation metrics without training
             self.global_step = 0
@@ -1245,11 +1246,9 @@ def store_flos(self):
         # Storing the number of floating-point operations that went into the model
         if self.total_flos is not None:
             if self.args.local_rank != -1:
-                total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
+                self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
             else:
-                total_flos = self.total_flos
-            if total_flos > 0:
-                self.model.config.total_flos = total_flos
+                self.state.total_flos = self.total_flos

     def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
         ordering_and_checkpoint_path = []
@@ -1363,13 +1362,6 @@ def prediction_loop(
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )

-        assert not getattr(
-            self.model.config, "output_attentions", False
-        ), "The prediction loop does not work with `output_attentions=True`."
-        assert not getattr(
-            self.model.config, "output_hidden_states", False
-        ), "The prediction loop does not work with `output_hidden_states=True`."
-
         model = self.model
         # multi-gpu eval
         if self.args.n_gpu > 1:
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index d93adda1862b..63a1ddfc33c6 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -224,6 +224,7 @@ class TrainerState:
     A class containing the `Trainer` fields that will be saved along the model and optimizer.
     """

+    total_flos: int = 0
     best_metric: Optional[float] = None
     best_model_checkpoint: Optional[str] = None
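
Note: the recurring change above is to treat `model.config` as optional, so that `Trainer` also accepts a plain `torch.nn.Module`, and to persist `total_flos` on the `TrainerState` instead of the model config. The snippet below is a minimal, illustrative sketch of that guard, mirroring the `setup_wandb` hunk; the helper name `collect_run_config` is hypothetical and not part of this patch:

# Illustrative sketch only -- not part of the diff above.
import torch
from transformers import PreTrainedModel


def collect_run_config(model: torch.nn.Module, sanitized_args: dict) -> dict:
    """Merge training arguments with the model config, when one exists."""
    combined = dict(sanitized_args)
    if isinstance(model, PreTrainedModel):
        # Only PreTrainedModel instances are guaranteed to carry a config.
        combined = {**model.config.to_dict(), **combined}
    return combined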