From 7f3e084f4fef91bc289698c21069910e40cc2329 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 9 May 2022 20:11:14 -0700 Subject: [PATCH 1/2] [trainer] sharded _load_best_model probably needs a test? --- src/transformers/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index aa54f2af1bb5..a760bc90f210 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1705,9 +1705,9 @@ def _load_best_model(self): # If the model is on the GPU, it still works! load_result = self.model.load_state_dict(state_dict, strict=False) self._issue_warnings_after_load(load_result) - elif os.path.exists(best_model_path, os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): # Best model is a sharded checkpoint - load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False) + load_result = load_sharded_checkpoint(self.state.best_model_checkpoint, strict=False) self._issue_warnings_after_load(load_result) else: logger.warning( From d5b6ced04d7c84cf9e48245a4dc7ee0685fb4f38 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 9 May 2022 20:13:29 -0700 Subject: [PATCH 2/2] undo delete --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a760bc90f210..dda278471811 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1707,7 +1707,7 @@ def _load_best_model(self): self._issue_warnings_after_load(load_result) elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): # Best model is a sharded checkpoint - load_result = load_sharded_checkpoint(self.state.best_model_checkpoint, strict=False) + load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False) self._issue_warnings_after_load(load_result) else: logger.warning(