diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index b53f5c12b6..595c566f5d 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1358,6 +1358,65 @@ def _maybe_log_save_evaluate( timer.step() self.log_evaluate_save_time += timer.last_duration + def _save_checkpoint(self, model, trial): + # Copied from https://github.com/huggingface/transformers/blob/v4.51-release/src/transformers/trainer.py#L3187 + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + + # NOTE(pbielak): In a multi-card scenario, the model saving is done by the main process (rank zero), + # whereas all other ranks continue processing. When checking for the `best_checkpoint_dir` below, + # a race condition occurs. This barrier forces other processes to wait till the model is saved. + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.distributed.barrier() + + if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step: + best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}" + best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder) + + if os.path.exists(best_checkpoint_dir): + self.state.best_model_checkpoint = best_checkpoint_dir + + if not self.args.save_only_model: + # Save optimizer and scheduler + self._save_optimizer_and_scheduler(output_dir) + self._save_scaler(output_dir) + # Save RNG state + self._save_rng_state(output_dir) + + # Save the Trainer state + if self.args.should_save: + # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently + for cb in [ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ]: + cb_name = cb.__class__.__name__ + cb_state = cb.state() + if isinstance(self.state.stateful_callbacks[cb_name], list): + self.state.stateful_callbacks[cb_name].append(cb_state) + else: + self.state.stateful_callbacks[cb_name] = cb_state + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + # Solely rely on numerical checkpoint id for rotation. + # mtime is not reliable especially on some fuse fs in cloud environments. + self._rotate_checkpoints(use_mtime=False, output_dir=run_dir) + def _load_rng_state(self, checkpoint): # Load RNG states from `checkpoint` if checkpoint is None: