From 53ecba243f2ee70f512d1c664c3470d76e54e790 Mon Sep 17 00:00:00 2001 From: Yeonsil Yoon Date: Tue, 13 Feb 2024 20:09:47 +0000 Subject: [PATCH 1/3] Adding a flag whether to save checkpoint or not --- optimum/habana/transformers/trainer.py | 2 +- optimum/habana/transformers/training_args.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 8d621db5e9..f336abd7b3 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1163,7 +1163,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for metric_to_check = f"eval_{metric_to_check}" self.lr_scheduler.step(metrics[metric_to_check]) - if self.control.should_save: + if self.control.should_save and self.args.should_save: self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 5979e00243..28568b40c4 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -283,6 +283,12 @@ class GaudiTrainingArguments(TrainingArguments): }, ) + # Overriding should_save from trainer callback + should_save: Optional[bool] = field( + default=True, + metadata={"help": "Whether to save models and checkpoints."}, + ) + def __post_init__(self): if self.use_hpu_graphs: warnings.warn( From a46442cdbfce1e3b0b99ec7aa9557e37da041866 Mon Sep 17 00:00:00 2001 From: Yeonsil Yoon Date: Wed, 14 Feb 2024 11:34:58 -0800 Subject: [PATCH 2/3] Adding a save last checkpoint flag to the training script --- examples/language-modeling/run_clm.py | 6 +++++- optimum/habana/transformers/trainer.py | 2 +- optimum/habana/transformers/training_args.py | 6 ------ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git 
a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index b68aba9677..7d63a69869 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -243,6 +243,9 @@ class DataTrainingArguments: keep_linebreaks: bool = field( default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} ) + save_last_ckpt: bool = field( + default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} + ) def __post_init__(self): if self.streaming: @@ -643,7 +646,8 @@ def compute_metrics(eval_preds): elif last_checkpoint is not None: checkpoint = last_checkpoint train_result = trainer.train(resume_from_checkpoint=checkpoint) - trainer.save_model() # Saves the tokenizer too for easy upload + if data_args.save_last_ckpt: #YSY + trainer.save_model() # Saves the tokenizer too for easy upload metrics = train_result.metrics diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index f336abd7b3..8d621db5e9 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1163,7 +1163,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for metric_to_check = f"eval_{metric_to_check}" self.lr_scheduler.step(metrics[metric_to_check]) - if self.control.should_save and self.args.should_save: + if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 28568b40c4..5979e00243 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -283,12 +283,6 @@ class GaudiTrainingArguments(TrainingArguments): }, ) - # Overriding should_save from trainer callback - should_save: Optional[bool] = field( - default=True, - 
metadata={"help": "Whether to save models and checkpoints."}, - ) - def __post_init__(self): if self.use_hpu_graphs: warnings.warn( From 5513b850cbdc97556fa7193eda2ab8dd485676b3 Mon Sep 17 00:00:00 2001 From: Yeonsil Yoon Date: Wed, 14 Feb 2024 11:37:52 -0800 Subject: [PATCH 3/3] Clean code --- examples/language-modeling/run_clm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 7d63a69869..4e7b439659 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -646,7 +646,7 @@ def compute_metrics(eval_preds): elif last_checkpoint is not None: checkpoint = last_checkpoint train_result = trainer.train(resume_from_checkpoint=checkpoint) - if data_args.save_last_ckpt: #YSY + if data_args.save_last_ckpt: trainer.save_model() # Saves the tokenizer too for easy upload metrics = train_result.metrics