diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index b0e398d7b9..4c5d4946f8 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -27,6 +27,7 @@ except Exception: pass # isort: on +import json import logging import math import os @@ -54,6 +55,7 @@ DataCollatorForSeq2Seq, get_scheduler, ) +from transformers.training_args import _convert_str_dict from open_instruct.dataset_transformation import ( INPUT_IDS_KEY, @@ -84,6 +86,13 @@ class FlatArguments: """ Full arguments class for all fine-tuning jobs. """ + # Sometimes users will pass in a `str` repr of a dict in the CLI + # We need to track what fields those can be. Each time a new arg + # has a dict type, it must be added to this list. + # Important: These should be typed with Optional[Union[dict,str,...]] + _VALID_DICT_FIELDS = [ + "additional_model_arguments", + ] exp_name: str = os.path.basename(__file__)[: -len(".py")] """The name of this experiment""" @@ -379,6 +388,10 @@ class FlatArguments: }, ) add_seed_and_date_to_exp_name: bool = True + additional_model_arguments: Optional[Union[dict, str]] = field( + default_factory=dict, + metadata={"help": "A dictionary of additional model args used to construct the model."}, + ) def __post_init__(self): if self.reduce_loss not in ["mean", "sum"]: @@ -399,6 +412,17 @@ def __post_init__(self): if not (1.0 >= self.final_lr_ratio >= 0.0): raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}") + # Parse in args that could be `dict` sent in from the CLI as a string + for dict_feld in self._VALID_DICT_FIELDS: + passed_value = getattr(self, dict_feld) + # We only want to do this if the str starts with a bracket to indicate a `dict` + # else its likely a filename if supported + if isinstance(passed_value, str) and passed_value.startswith("{"): + loaded_dict = json.loads(passed_value) + # Convert str values to types if applicable + loaded_dict = _convert_str_dict(loaded_dict) + setattr(self, dict_feld, loaded_dict) + def main(args: FlatArguments, tc: TokenizerConfig): # ------------------------------------------------------------ @@ -549,12 +573,14 @@ def main(args: FlatArguments, tc: TokenizerConfig): args.config_name, revision=args.model_revision, trust_remote_code=tc.trust_remote_code, + **args.additional_model_arguments, ) elif args.model_name_or_path: config = AutoConfig.from_pretrained( args.model_name_or_path, revision=args.model_revision, trust_remote_code=tc.trust_remote_code, + **args.additional_model_arguments, ) else: raise ValueError( diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index 33da7cf75b..6cab76648c 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -21,8 +21,8 @@ import requests from litellm import acompletion -from IFEvalG import instructions_registry from open_instruct.if_functions import IF_FUNCTIONS_MAP +from open_instruct.IFEvalG import instructions_registry from open_instruct.judge_utils import ( EXTRACTOR_MAP, JUDGE_PROMPT_MAP, diff --git a/tests/test_flat_args.py b/tests/test_flat_args.py new file mode 100644 index 0000000000..7261228f6a --- /dev/null +++ b/tests/test_flat_args.py @@ -0,0 +1,23 @@ +from open_instruct.finetune import FlatArguments +from open_instruct.utils import ArgumentParserPlus + + +class TestFlatArguments: + def test_additional_model_args(self) -> None: + parser = ArgumentParserPlus(FlatArguments) + # NOTE: the boolean must be lower case, true not True + (args,) = parser.parse_args_into_dataclasses( + ["--additional_model_arguments", '{"int": 1, "bool": true, "float": 0.0}'] + ) + assert isinstance(args.additional_model_arguments, dict) + assert isinstance(args.additional_model_arguments["int"], int) + assert isinstance(args.additional_model_arguments["bool"], bool) + assert isinstance(args.additional_model_arguments["float"], float) + + def test_no_additional_model_args(self) -> None: + parser = ArgumentParserPlus(FlatArguments) + # NOTE: the boolean must be lower case, true not True + (args,) = parser.parse_args_into_dataclasses(["--exp_name", "test"]) + # Should get a empty dict + assert isinstance(args.additional_model_arguments, dict) + assert not args.additional_model_arguments