diff --git a/README.md b/README.md index 088e9721e..4bc762e4d 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ - [Training](#training) - [Single GPU](#single-gpu) - [Multiple GPUs with FSDP](#multiple-gpus-with-fsdp) + - [Tips on Parameters to Set](#tips-on-parameters-to-set) - [Tuning Techniques](#tuning-techniques) - [LoRA Tuning Example](#lora-tuning-example) - [Prompt Tuning](#prompt-tuning) @@ -225,6 +226,50 @@ tuning/sft_trainer.py \ To summarize you can pick either python for single-GPU jobs or use accelerate launch for multi-GPU jobs. The following tuning techniques can be applied: +### Tips on Parameters to Set + +#### Saving checkpoints while training + +By default, [`save_strategy`](tuning/config/configs.py) is set to `"epoch"` in the TrainingArguments. This means that checkpoints will be saved on each epoch. This can also be set to `"steps"` to save on every `"save_steps"` or `"no"` to not save any checkpoints. + +Checkpoints are saved to the given `output_dir`, which is a required field. If `save_strategy="no"`, the `output_dir` will only contain the training logs with loss details. + +A useful flag to set to limit the number of checkpoints saved is [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit). Older checkpoints are deleted from the `output_dir` to limit the number of checkpoints, for example, if `save_total_limit=1`, this will only save the last checkpoint. However, while tuning, two checkpoints will exist in `output_dir` for a short time as the new checkpoint is created and then the older one will be deleted. If the user sets a validation dataset and [`load_best_model_at_end`](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments.load_best_model_at_end), then the best checkpoint will be saved. 
+ +#### Saving model after training + +`save_model_dir` can optionally be set to save the tuned model using `SFTTrainer.save_model()`. This can be used in tandem with `save_strategy="no"` to only save the designated checkpoint and not any intermediate checkpoints, which can help to save space. + +`save_model_dir` can be set to a different directory than `output_dir`. If set to the same directory, the designated checkpoint, training logs, and any intermediate checkpoints will all be saved to the same directory as seen below. + +
+Ways you can use `save_model_dir` and more tips: + +For example, if `save_model_dir` is set to a sub-directory of `output_dir` and `save_total_limit=1` with LoRA tuning, the directory would look like: + +```sh +$ ls /tmp/output_dir/ +checkpoint-35 save_model_dir training_logs.jsonl + +$ ls /tmp/output_dir/save_model_dir/ +README.md adapter_model.safetensors special_tokens_map.json tokenizer.model training_args.bin +adapter_config.json added_tokens.json tokenizer.json tokenizer_config.json +``` + +Here is a fine tuning example of how the directory would look if `output_dir` is set to the same value as `save_model_dir` and `save_total_limit=2`. Note the checkpoint directories as well as the `training_logs.jsonl`: + +```sh +$ ls /tmp/same_dir + +added_tokens.json model-00001-of-00006.safetensors model-00006-of-00006.safetensors tokenizer_config.json +checkpoint-16 model-00002-of-00006.safetensors model.safetensors.index.json training_args.bin +checkpoint-20 model-00003-of-00006.safetensors special_tokens_map.json training_logs.jsonl +config.json model-00004-of-00006.safetensors tokenizer.json +generation_config.json model-00005-of-00006.safetensors tokenizer.model +``` + +
+ ## Tuning Techniques: ### LoRA Tuning Example diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py index 2cfc9069f..d7753728c 100644 --- a/build/accelerate_launch.py +++ b/build/accelerate_launch.py @@ -23,8 +23,6 @@ import subprocess import sys import traceback -import tempfile -import shutil from pathlib import Path import json @@ -37,12 +35,9 @@ # Local from build.utils import ( process_accelerate_launch_args, - serialize_args, get_highest_checkpoint, - copy_checkpoint, ) from tuning.utils.config_utils import get_json_config -from tuning.config.tracker_configs import FileLoggingTrackerConfig from tuning.utils.error_logging import ( write_termination_log, USER_ERROR_EXIT_CODE, @@ -111,142 +106,111 @@ def main(): # Launch training # ########## - original_output_dir = job_config.get("output_dir") - with tempfile.TemporaryDirectory() as tempdir: - try: - # checkpoints outputted to tempdir, only final checkpoint copied to output dir - job_config["output_dir"] = tempdir - updated_args = serialize_args(job_config) - os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = updated_args - launch_command(args) - except subprocess.CalledProcessError as e: - # If the subprocess throws an exception, the base exception is hidden in the - # subprocess call and is difficult to access at this level. However, that is not - # an issue because sft_trainer.py would have already written the exception - # message to termination log. - logging.error(traceback.format_exc()) - # The exit code that sft_trainer.py threw is captured in e.returncode - - return_code = e.returncode - if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]: - return_code = INTERNAL_ERROR_EXIT_CODE - write_termination_log(f"Unhandled exception during training. 
{e}") - sys.exit(return_code) - except Exception as e: # pylint: disable=broad-except - logging.error(traceback.format_exc()) + output_dir = job_config.get("output_dir") + try: + # checkpoints outputted to tempdir, only final checkpoint copied to output dir + launch_command(args) + except subprocess.CalledProcessError as e: + # If the subprocess throws an exception, the base exception is hidden in the + # subprocess call and is difficult to access at this level. However, that is not + # an issue because sft_trainer.py would have already written the exception + # message to termination log. + logging.error(traceback.format_exc()) + # The exit code that sft_trainer.py threw is captured in e.returncode + + return_code = e.returncode + if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]: + return_code = INTERNAL_ERROR_EXIT_CODE write_termination_log(f"Unhandled exception during training. {e}") - sys.exit(INTERNAL_ERROR_EXIT_CODE) + sys.exit(return_code) + except Exception as e: # pylint: disable=broad-except + logging.error(traceback.format_exc()) + write_termination_log(f"Unhandled exception during training. 
{e}") + sys.exit(INTERNAL_ERROR_EXIT_CODE) - try: - last_checkpoint_dir = get_highest_checkpoint(tempdir) - last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir) + # remove lm_head from granite with llama arch models + try: + checkpoint_dir = job_config.get("save_model_dir") + if not checkpoint_dir: + checkpoint_dir = os.path.join( + output_dir, get_highest_checkpoint(output_dir) + ) + + use_flash_attn = job_config.get("use_flash_attn", True) + adapter_config_path = os.path.join(checkpoint_dir, "adapter_config.json") + tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir) - use_flash_attn = job_config.get("use_flash_attn", True) - adapter_config_path = os.path.join( - last_checkpoint_path, "adapter_config.json" + if os.path.exists(adapter_config_path): + base_model_path = get_base_model_from_adapter_config(adapter_config_path) + base_model = AutoModelForCausalLM.from_pretrained( + base_model_path, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, ) - tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path) - if os.path.exists(adapter_config_path): - base_model_path = get_base_model_from_adapter_config( - adapter_config_path - ) - base_model = AutoModelForCausalLM.from_pretrained( - base_model_path, - attn_implementation="flash_attention_2" if use_flash_attn else None, - torch_dtype=bfloat16 if use_flash_attn else None, - ) + # since the peft library (PEFTModelForCausalLM) does not handle cases + # where the model's layers are modified, in our case the embedding layer + # is modified, so we resize the backbone model's embedding layer with our own + # utility before passing it along to load the PEFT model. 
+ tokenizer_data_utils.tokenizer_and_embedding_resize( + {}, tokenizer=tokenizer, model=base_model + ) + model = PeftModel.from_pretrained( + base_model, + checkpoint_dir, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + checkpoint_dir, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, + ) - # since the peft library (PEFTModelForCausalLM) does not handle cases - # where the model's layers are modified, in our case the embedding layer - # is modified, so we resize the backbone model's embedding layer with our own - # utility before passing it along to load the PEFT model. - tokenizer_data_utils.tokenizer_and_embedding_resize( - {}, tokenizer=tokenizer, model=base_model + model_arch = model.config.model_type + # check that it is a granite model with llama architecture with tied weights + # ie. 
lm_head is duplicate of embeddings + + # a fine tuned model will have params_dict.get("model.embed_tokens.weight") + # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight") + # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight") + if model_arch == "llama" and hasattr(model, "lm_head"): + if ( + # lora tuned model has an addt model layer + ( + hasattr(model.model, "model") + and model.lm_head.weight.untyped_storage().data_ptr() + == model.model.model.embed_tokens.weight.untyped_storage().data_ptr() ) - model = PeftModel.from_pretrained( - base_model, - last_checkpoint_path, - attn_implementation="flash_attention_2" if use_flash_attn else None, - torch_dtype=bfloat16 if use_flash_attn else None, - ) - else: - model = AutoModelForCausalLM.from_pretrained( - last_checkpoint_path, - attn_implementation="flash_attention_2" if use_flash_attn else None, - torch_dtype=bfloat16 if use_flash_attn else None, + # prompt tuned model or fine tuned model + or ( + hasattr(model.model, "embed_tokens") + and model.lm_head.weight.untyped_storage().data_ptr() + == model.model.embed_tokens.weight.untyped_storage().data_ptr() ) + ): - model_arch = model.config.model_type - # check that it is a granite model with llama architecture with tied weights - # ie. 
lm_head is duplicate of embeddings - - # a fine tuned model will have params_dict.get("model.embed_tokens.weight") - # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight") - # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight") - copy_checkpoint_bool = True - if model_arch == "llama" and hasattr(model, "lm_head"): - if ( - # lora tuned model has an addt model layer - ( - hasattr(model.model, "model") - and model.lm_head.weight.untyped_storage().data_ptr() - == model.model.model.embed_tokens.weight.untyped_storage().data_ptr() - ) - # prompt tuned model or fine tuned model - or ( - hasattr(model.model, "embed_tokens") - and model.lm_head.weight.untyped_storage().data_ptr() - == model.model.embed_tokens.weight.untyped_storage().data_ptr() - ) - ): - - copy_checkpoint_bool = False - logging.info("Removing lm_head from checkpoint") - del model.lm_head.weight - - if hasattr(model, "lm_head.weight"): - logging.warning("Failed to delete lm_head.weight from model") - - logging.info("Saving checkpoint to %s", original_output_dir) - model.save_pretrained(original_output_dir) - # save tokenizer with model - tokenizer.save_pretrained(original_output_dir) - - # copy last checkpoint into mounted output dir - if copy_checkpoint_bool: - logging.info( - "Copying last checkpoint %s into output dir %s", - last_checkpoint_dir, - original_output_dir, - ) - copy_checkpoint(last_checkpoint_path, original_output_dir) - except Exception as e: # pylint: disable=broad-except - logging.error(traceback.format_exc()) - write_termination_log( - f"Exception encountered writing output model to storage: {e}" - ) - sys.exit(INTERNAL_ERROR_EXIT_CODE) + logging.info("Removing lm_head from checkpoint") + del model.lm_head.weight - # copy over any loss logs - try: - train_logs_filepath = os.path.join( - tempdir, - FileLoggingTrackerConfig.training_logs_filename, - ) - if os.path.exists(train_logs_filepath): - shutil.copy(train_logs_filepath, 
original_output_dir) - - # The .complete file will signal to users that we are finished copying - # files over - if os.path.exists(original_output_dir): - Path(os.path.join(original_output_dir, ".complete")).touch() - except Exception as e: # pylint: disable=broad-except - logging.error(traceback.format_exc()) - write_termination_log( - f"Exception encountered in capturing training logs: {e}" - ) - sys.exit(INTERNAL_ERROR_EXIT_CODE) + if hasattr(model, "lm_head.weight"): + logging.warning("Failed to delete lm_head.weight from model") + + logging.info("Saving checkpoint to %s", output_dir) + model.save_pretrained(checkpoint_dir) + # save tokenizer with model + tokenizer.save_pretrained(checkpoint_dir) + + except Exception as e: # pylint: disable=broad-except + logging.error(traceback.format_exc()) + write_termination_log(f"Exception encountered removing lm_head from model: {e}") + sys.exit(INTERNAL_ERROR_EXIT_CODE) + + # The .complete file will signal to users that we are finished copying + # files over + if os.path.exists(output_dir): + Path(os.path.join(output_dir, ".complete")).touch() return 0 diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index 421849b1f..927af3165 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -25,12 +25,13 @@ # First Party from build.accelerate_launch import main -from build.utils import serialize_args +from build.utils import serialize_args, get_highest_checkpoint from tests.data import TWITTER_COMPLAINTS_DATA from tuning.utils.error_logging import ( USER_ERROR_EXIT_CODE, INTERNAL_ERROR_EXIT_CODE, ) +from tuning.config.tracker_configs import FileLoggingTrackerConfig SCRIPT = "tuning/sft_trainer.py" MODEL_NAME = "Maykeye/TinyLLama-v0" @@ -97,11 +98,9 @@ def test_successful_ft(): os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args assert main() == 0 - # check termination log and .complete files - assert os.path.exists(tempdir + "/termination-log") is 
False - assert os.path.exists(os.path.join(tempdir, ".complete")) is True - assert os.path.exists(tempdir + "/adapter_config.json") is False - assert len(glob.glob(f"{tempdir}/model*.safetensors")) > 0 + _validate_termination_files_when_tuning_succeeds(tempdir) + checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint, "ft") def test_successful_pt(): @@ -113,11 +112,9 @@ def test_successful_pt(): os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args assert main() == 0 - # check termination log and .complete files - assert os.path.exists(tempdir + "/termination-log") is False - assert os.path.exists(os.path.join(tempdir, ".complete")) is True - assert os.path.exists(tempdir + "/adapter_model.safetensors") is True - assert os.path.exists(tempdir + "/adapter_config.json") is True + _validate_termination_files_when_tuning_succeeds(tempdir) + checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint, "pt") def test_successful_lora(): @@ -129,11 +126,92 @@ def test_successful_lora(): os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args assert main() == 0 - # check termination log and .complete files - assert os.path.exists(tempdir + "/termination-log") is False - assert os.path.exists(os.path.join(tempdir, ".complete")) is True - assert os.path.exists(tempdir + "/adapter_model.safetensors") is True - assert os.path.exists(tempdir + "/adapter_config.json") is True + _validate_termination_files_when_tuning_succeeds(tempdir) + checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint, "lora") + + +def test_lora_save_model_dir_separate_dirs(): + """Run LoRA tuning with separate save_model_dir and output_dir. + Verify model saved to save_model_dir and checkpoints saved to + output_dir. 
+ """ + with tempfile.TemporaryDirectory() as tempdir: + output_dir = os.path.join(tempdir, "output_dir") + save_model_dir = os.path.join(tempdir, "save_model_dir") + setup_env(tempdir) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{ + "output_dir": output_dir, + "save_model_dir": save_model_dir, + "save_total_limit": 1, + }, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + _validate_termination_files_when_tuning_succeeds(output_dir) + _validate_training_output(save_model_dir, "lora") + + assert len(os.listdir(output_dir)) == 3 + checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*")) + assert len(checkpoints) == 1 + + +def test_lora_save_model_dir_same_dir_as_output_dir(): + """Run LoRA tuning with same save_model_dir and output_dir. + Verify checkpoints, logs, and model saved to path. + """ + with tempfile.TemporaryDirectory() as tempdir: + setup_env(tempdir) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{ + "output_dir": tempdir, + "save_model_dir": tempdir, + "gradient_accumulation_steps": 1, + }, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + # check logs, checkpoint dir, and model exists in path + _validate_termination_files_when_tuning_succeeds(tempdir) + # check that model exists in output_dir and checkpoint dir + _validate_training_output(tempdir, "lora") + checkpoint_path = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint_path, "lora") + + # number of checkpoints should equal number of epochs + checkpoints = glob.glob(os.path.join(tempdir, "checkpoint-*")) + assert len(checkpoints) == TRAIN_KWARGS["num_train_epochs"] + + +def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no(): + """Run LoRA tuning with same save_model_dir and output_dir and + save_strategy=no. 
Verify no checkpoints created, only + logs and final model. + """ + with tempfile.TemporaryDirectory() as tempdir: + setup_env(tempdir) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{"output_dir": tempdir, "save_model_dir": tempdir, "save_strategy": "no"}, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + # check that model and logs exists in output_dir + _validate_termination_files_when_tuning_succeeds(tempdir) + _validate_training_output(tempdir, "lora") + + # no checkpoints should be created + checkpoints = glob.glob(os.path.join(tempdir, "checkpoint-*")) + assert len(checkpoints) == 0 def test_bad_script_path(): @@ -212,6 +290,27 @@ def test_config_parsing_error(): assert os.stat(tempdir + "/termination-log").st_size > 0 +def _validate_termination_files_when_tuning_succeeds(base_dir): + # check termination log and .complete files + assert os.path.exists(os.path.join(base_dir, "/termination-log")) is False + assert os.path.exists(os.path.join(base_dir, ".complete")) is True + assert ( + os.path.exists( + os.path.join(base_dir, FileLoggingTrackerConfig.training_logs_filename) + ) + is True + ) + + +def _validate_training_output(base_dir, tuning_technique): + if tuning_technique == "ft": + assert len(glob.glob(f"{base_dir}/model*.safetensors")) > 0 + assert os.path.exists(base_dir + "/adapter_config.json") is False + else: + assert os.path.exists(base_dir + "/adapter_config.json") is True + assert os.path.exists(base_dir + "/adapter_model.safetensors") is True + + def test_cleanup(): # This runs to unset env variables that could disrupt other tests cleanup_env() diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 55f8213a2..0264d3b3b 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -44,6 +44,7 @@ # Local from tuning import sft_trainer from tuning.config import configs, peft_config +from tuning.config.tracker_configs import 
FileLoggingTrackerConfig MODEL_ARGS = configs.ModelArguments( model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" @@ -444,10 +445,32 @@ def test_successful_lora_target_modules_default_from_main(): ############################# Finetuning Tests ############################# def test_run_causallm_ft_and_inference(): - """Check if we can bootstrap and finetune tune causallm models""" + """Check if we can bootstrap and finetune causallm models""" with tempfile.TemporaryDirectory() as tempdir: _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir) - _test_run_inference(tempdir=tempdir) + _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) + + +def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no(): + """Check if we can bootstrap and finetune causallm model with save_model_dir + and save_strategy=no. Verify no checkpoints created and can save model. + """ + with tempfile.TemporaryDirectory() as tempdir: + save_model_args = copy.deepcopy(TRAIN_ARGS) + save_model_args.save_strategy = "no" + save_model_args.output_dir = tempdir + + trainer = sft_trainer.train(MODEL_ARGS, DATA_ARGS, save_model_args, None) + logs_path = os.path.join( + tempdir, FileLoggingTrackerConfig.training_logs_filename + ) + _validate_logfile(logs_path) + # validate that no checkpoints created + assert not any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) + + sft_trainer.save(tempdir, trainer) + assert any(x.endswith(".safetensors") for x in os.listdir(tempdir)) + _test_run_inference(checkpoint_path=tempdir) def test_run_causallm_ft_pretokenized(): @@ -493,9 +516,7 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): _validate_training(tempdir) -def _test_run_inference(tempdir): - checkpoint_path = _get_checkpoint_path(tempdir) - +def _test_run_inference(checkpoint_path): # Load the model loaded_model = TunedCausalLM.load(checkpoint_path) @@ -512,12 +533,16 @@ def _validate_training( ): assert 
any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) train_logs_file_path = "{}/{}".format(tempdir, train_logs_file) + _validate_logfile(train_logs_file_path, check_eval) + + +def _validate_logfile(log_file_path, check_eval=False): train_log_contents = "" - with open(train_logs_file_path, encoding="utf-8") as f: + with open(log_file_path, encoding="utf-8") as f: train_log_contents = f.read() - assert os.path.exists(train_logs_file_path) is True - assert os.path.getsize(train_logs_file_path) > 0 + assert os.path.exists(log_file_path) is True + assert os.path.getsize(log_file_path) > 0 assert "training_loss" in train_log_contents if check_eval: diff --git a/tests/trackers/test_aim_tracker.py b/tests/trackers/test_aim_tracker.py index f0002e9ff..d2aa301b7 100644 --- a/tests/trackers/test_aim_tracker.py +++ b/tests/trackers/test_aim_tracker.py @@ -30,6 +30,7 @@ DATA_ARGS, MODEL_ARGS, TRAIN_ARGS, + _get_checkpoint_path, _test_run_inference, _validate_training, ) @@ -98,7 +99,7 @@ def test_e2e_run_with_aim_tracker(aimrepo): _validate_training(tempdir) # validate inference - _test_run_inference(tempdir) + _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) @pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed") diff --git a/tests/trackers/test_file_logging_tracker.py b/tests/trackers/test_file_logging_tracker.py index 2129e4927..e5e62ab8b 100644 --- a/tests/trackers/test_file_logging_tracker.py +++ b/tests/trackers/test_file_logging_tracker.py @@ -25,6 +25,7 @@ DATA_ARGS, MODEL_ARGS, TRAIN_ARGS, + _get_checkpoint_path, _test_run_causallm_ft, _test_run_inference, _validate_training, @@ -44,7 +45,7 @@ def test_run_with_file_logging_tracker(): train_args.trackers = ["file_logger"] _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir) - _test_run_inference(tempdir=tempdir) + _test_run_inference(_get_checkpoint_path(tempdir)) def test_sample_run_with_file_logger_updated_filename(): diff --git 
a/tuning/config/configs.py b/tuning/config/configs.py index 2990ef801..0db5b518e 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -97,6 +97,7 @@ class DataArguments: @dataclass class TrainingArguments(transformers.TrainingArguments): + # pylint: disable=too-many-instance-attributes cache_dir: Optional[str] = field(default=None) # optim: str = field(default=DEFAULT_OPTIMIZER) max_seq_length: int = field( @@ -119,6 +120,13 @@ class TrainingArguments(transformers.TrainingArguments): 'steps' (save is done every `save_steps`)" }, ) + save_model_dir: str = field( + default=None, + metadata={ + "help": "Directory where tuned model will be saved to \ + using SFTTrainer.save_model()." + }, + ) logging_strategy: str = field( default="epoch", metadata={ diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 2fce4ec37..45fea7ca6 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -16,6 +16,8 @@ from typing import Dict, List, Optional, Union import dataclasses import json +import logging +import os import sys import time import traceback @@ -360,6 +362,31 @@ def train( trainer.train() + return trainer + + +def save(path: str, trainer: SFTTrainer, log_level="WARNING"): + """Saves model and tokenizer to given path. + + Args: + path: str + Path to save the model to. + trainer: SFTTrainer + Instance of SFTTrainer used for training to save the model. 
+ """ + logger = logging.getLogger("sft_trainer_save") + # default value from TrainingArguments + if log_level == "passive": + log_level = "WARNING" + + logger.setLevel(log_level) + + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + + logger.info("Saving tuned model to path: %s", path) + trainer.save_model(path) + def get_parser(): """Get the command-line argument parser.""" @@ -545,7 +572,7 @@ def main(**kwargs): # pylint: disable=unused-argument combined_tracker_configs.aim_config = aim_config try: - train( + trainer = train( model_args=model_args, data_args=data_args, train_args=training_args, @@ -582,6 +609,21 @@ def main(**kwargs): # pylint: disable=unused-argument write_termination_log(f"Unhandled exception during training: {e}") sys.exit(INTERNAL_ERROR_EXIT_CODE) + # save model + if training_args.save_model_dir: + try: + save( + path=training_args.save_model_dir, + trainer=trainer, + log_level=training_args.log_level, + ) + except Exception as e: # pylint: disable=broad-except + logger.error(traceback.format_exc()) + write_termination_log( + f"Failed to save model to {training_args.save_model_dir}: {e}" + ) + sys.exit(INTERNAL_ERROR_EXIT_CODE) + if __name__ == "__main__": fire.Fire(main)