Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import transformers

# First Party
from build.utils import serialize_args
from scripts.run_inference import TunedCausalLM
from tests.data import (
EMPTY_DATA,
Expand Down Expand Up @@ -159,6 +160,8 @@ def test_parse_arguments_peft_method(job_config):
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
assert not tune_config.target_modules
assert "target_modules" not in job_config_lora
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can also assert not hasattr(tune_config, "target_modules") so that it verifies the input from job_config_lora, and this above assertion verifies that after parse_arguments runs, the value is still None for target_modules

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I add this check I get the following error:

>       assert not hasattr(tune_config, "target_modules")
E       AssertionError: assert not True
E        +  where True = hasattr(LoraConfig(r=8, lora_alpha=32, target_modules=None, lora_dropout=0.05), 'target_modules')

I think this is because the attribute target_modules exists but is None. Would this suffice?
assert tune_config.target_modules is None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that makes sense and yes that check looks good! i think you can refactor it down to assert not tune_config.target_modules



############################# Prompt Tuning Tests #############################
Expand Down Expand Up @@ -403,6 +406,42 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
assert "Simply put, the theory of relativity states that" in output_inference


def test_successful_lora_target_modules_default_from_main():
    """Check that if target_modules is not set, or set to None via JSON, the
    default value by model type will be used in LoRA tuning.

    The correct default target modules will be used for model type llama
    and will exist in the resulting adapter_config.json.
    https://github.com/huggingface/peft/blob/7b1c08d2b5e13d3c99b7d6ee83eab90e1216d4ba/
    src/peft/tuners/lora/model.py#L432
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # Merge the shared arg namespaces into one kwargs dict, overriding
        # peft_method/output_dir. target_modules is deliberately left unset
        # so the model-type default is applied by PEFT.
        TRAIN_KWARGS = {
            **MODEL_ARGS.__dict__,
            **TRAIN_ARGS.__dict__,
            **DATA_ARGS.__dict__,
            **PEFT_LORA_ARGS.__dict__,
            **{"peft_method": "lora", "output_dir": tempdir},
        }
        serialized_args = serialize_args(TRAIN_KWARGS)
        # main() reads its config from this env var when set.
        os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

        sft_trainer.main()

        _validate_training(tempdir)
        checkpoint_path = _get_checkpoint_path(tempdir)
        adapter_config = _get_adapter_config(checkpoint_path)
        _validate_adapter_config(adapter_config, "LORA")

        assert (
            "target_modules" in adapter_config
        ), "target_modules not found in adapter_config.json."

        # For llama-type models PEFT defaults to the attention projections.
        assert set(adapter_config.get("target_modules")) == {
            "q_proj",
            "v_proj",
        }, "target_modules are not set to the default values."


############################# Finetuning Tests #############################
def test_run_causallm_ft_and_inference():
"""Check if we can bootstrap and finetune tune causallm models"""
Expand Down