def _load_best_model(self):
    """
    Reload into `self.model` the weights stored in `self.state.best_model_checkpoint`.

    Three cases are handled, in this order:

    - a single weights file (`WEIGHTS_NAME`) exists in the checkpoint folder: it is read with `torch.load` on the
      CPU and applied with `load_state_dict(..., strict=False)`;
    - otherwise, if a sharded-checkpoint index (`WEIGHTS_INDEX_NAME`) exists: the folder is handed to
      `load_sharded_checkpoint(..., strict=False)`;
    - otherwise nothing is loaded and a warning is logged.

    In the two loading cases, the object returned by the loader is forwarded to
    `self._issue_warnings_after_load`.
    """
    checkpoint_dir = self.state.best_model_checkpoint
    logger.info(f"Loading best model from {checkpoint_dir} (score: {self.state.best_metric}).")

    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    target_model = self.model

    if os.path.exists(weights_file):
        # TODO: uncomment the code below when Habana DeepSpeed >= 0.6.5
        # if self.deepspeed:
        #
        #     if self.model_wrapped is not None:
        #         # this removes the pre-hooks from the previous engine
        #         self.model_wrapped.destroy()
        #         self.model_wrapped = None
        #
        #     # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
        #     deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
        #         self,
        #         num_training_steps=self.args.max_steps,
        #         resume_from_checkpoint=self.state.best_model_checkpoint,
        #     )
        #     self.model = deepspeed_engine.module
        #     self.model_wrapped = deepspeed_engine
        #     self.deepspeed = deepspeed_engine
        #     self.optimizer = optimizer
        #     self.lr_scheduler = lr_scheduler
        # else:

        # The state dict is materialized on the CPU first so that loading it does not trigger an OOM error.
        weights = torch.load(weights_file, map_location="cpu")
        # Per the original note, this also works when the model itself lives on an accelerator device.
        outcome = target_model.load_state_dict(weights, strict=False)
    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
        outcome = load_sharded_checkpoint(target_model, checkpoint_dir, strict=False)
    else:
        logger.warning(
            f"Could not locate the best model at {weights_file}, if you are running a distributed training "
            "on multiple nodes, you should activate `--save_on_each_node`."
        )
        # Nothing was loaded, so there is no load result to report.
        return

    # Shared tail of both loading branches.
    self._issue_warnings_after_load(outcome)