diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index 72ea1f4b46..c7ba0ebeae 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -248,6 +248,9 @@ class DataArguments: default=False, metadata={"help": "Whether to have a SQL style prompt"}, ) + save_last_ckpt: bool = field( + default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} + ) @dataclass @@ -700,7 +703,8 @@ def compute_metrics(eval_preds): if training_args.do_train: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) - trainer.save_model() + if data_args.save_last_ckpt: + trainer.save_model() metrics = train_result.metrics trainer.log_metrics("train", metrics)