diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 26067124a..55f8213a2 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -29,6 +29,7 @@ import transformers # First Party +from build.utils import serialize_args from scripts.run_inference import TunedCausalLM from tests.data import ( EMPTY_DATA, @@ -159,6 +160,8 @@ def test_parse_arguments_peft_method(job_config): parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) + assert not tune_config.target_modules + assert "target_modules" not in job_config_lora ############################# Prompt Tuning Tests ############################# @@ -403,6 +406,42 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected): assert "Simply put, the theory of relativity states that" in output_inference +def test_successful_lora_target_modules_default_from_main(): + """Check that if target_modules is not set, or set to None via JSON, the + default value by model type will be used in LoRA tuning. + The correct default target modules will be used for model type llama + and will exist in the resulting adapter_config.json. + https://github.com/huggingface/peft/blob/7b1c08d2b5e13d3c99b7d6ee83eab90e1216d4ba/ + src/peft/tuners/lora/model.py#L432 + """ + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **MODEL_ARGS.__dict__, + **TRAIN_ARGS.__dict__, + **DATA_ARGS.__dict__, + **PEFT_LORA_ARGS.__dict__, + **{"peft_method": "lora", "output_dir": tempdir}, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + sft_trainer.main() + + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "LORA") + + assert ( + "target_modules" in adapter_config + ), "target_modules not found in adapter_config.json." 
+ + assert set(adapter_config.get("target_modules")) == { + "q_proj", + "v_proj", + }, "target_modules are not set to the default values." + + ############################# Finetuning Tests ############################# def test_run_causallm_ft_and_inference(): """Check if we can bootstrap and finetune tune causallm models"""