From 5f90befe3cba34b354cac6e341d044be0ec4d323 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Thu, 8 Sep 2022 23:55:25 +0000 Subject: [PATCH 1/2] Temporary WA for laoding bast model at end with DeepSpeed --- optimum/habana/trainer.py | 41 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/optimum/habana/trainer.py b/optimum/habana/trainer.py index 4e3c99fa07..4d15c2fd1e 100644 --- a/optimum/habana/trainer.py +++ b/optimum/habana/trainer.py @@ -1464,3 +1464,44 @@ def _push_from_checkpoint(self, checkpoint_folder): if self.args.hub_strategy == HubStrategy.CHECKPOINT: # Move back the checkpoint to its place shutil.move(tmp_checkpoint, checkpoint_folder) + + def _load_best_model(self): + logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + model = self.model + if os.path.exists(best_model_path): + # TODO: uncomment the code below when Habana DeepSpeed >= 0.6.5 + # if self.deepspeed: + + # if self.model_wrapped is not None: + # # this removes the pre-hooks from the previous engine + # self.model_wrapped.destroy() + # self.model_wrapped = None + + # # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping + # deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( + # self, + # num_training_steps=self.args.max_steps, + # resume_from_checkpoint=self.state.best_model_checkpoint, + # ) + # self.model = deepspeed_engine.module + # self.model_wrapped = deepspeed_engine + # self.deepspeed = deepspeed_engine + # self.optimizer = optimizer + # self.lr_scheduler = lr_scheduler + # else: + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(best_model_path, map_location="cpu") + # If the model is on the GPU, it still works! + load_result = model.load_state_dict(state_dict, strict=False) + self._issue_warnings_after_load(load_result) + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + load_result = load_sharded_checkpoint( + model, self.state.best_model_checkpoint, strict=False + ) + self._issue_warnings_after_load(load_result) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) From e1c655ea4a37550e4afbdfbf942402b450920fb0 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Fri, 9 Sep 2022 00:09:02 +0000 Subject: [PATCH 2/2] Make style --- optimum/habana/trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/optimum/habana/trainer.py b/optimum/habana/trainer.py index 4d15c2fd1e..813216dbee 100644 --- a/optimum/habana/trainer.py +++ b/optimum/habana/trainer.py @@ -1496,9 +1496,7 @@ def _load_best_model(self): load_result = model.load_state_dict(state_dict, strict=False) self._issue_warnings_after_load(load_result) elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): - load_result = load_sharded_checkpoint( - model, self.state.best_model_checkpoint, strict=False - ) + load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=False) self._issue_warnings_after_load(load_result) else: logger.warning(