Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions open_instruct/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
except Exception:
pass
# isort: on
import json
import logging
import math
import os
Expand Down Expand Up @@ -54,6 +55,7 @@
DataCollatorForSeq2Seq,
get_scheduler,
)
from transformers.training_args import _convert_str_dict

from open_instruct.dataset_transformation import (
INPUT_IDS_KEY,
Expand Down Expand Up @@ -84,6 +86,13 @@ class FlatArguments:
"""
Full arguments class for all fine-tuning jobs.
"""
# Sometimes users will pass in a `str` repr of a dict in the CLI
# We need to track what fields those can be. Each time a new arg
# has a dict type, it must be added to this list.
# Important: These should be typed with Optional[Union[dict,str,...]]
_VALID_DICT_FIELDS = [
"additional_model_arguments",
]

exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""The name of this experiment"""
Expand Down Expand Up @@ -379,6 +388,10 @@ class FlatArguments:
},
)
add_seed_and_date_to_exp_name: bool = True
additional_model_arguments: Optional[Union[dict, str]] = field(
default_factory=dict,
metadata={"help": "A dictionary of additional model args used to construct the model."},
)

def __post_init__(self):
if self.reduce_loss not in ["mean", "sum"]:
Expand All @@ -399,6 +412,17 @@ def __post_init__(self):
if not (1.0 >= self.final_lr_ratio >= 0.0):
raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")

# Parse in args that could be `dict` sent in from the CLI as a string
for dict_feld in self._VALID_DICT_FIELDS:
passed_value = getattr(self, dict_feld)
# Only parse when the string starts with an opening brace ("{"), which indicates a `dict`;
# otherwise it is likely a filename (where supported) and is left as-is.
if isinstance(passed_value, str) and passed_value.startswith("{"):
loaded_dict = json.loads(passed_value)
# Convert str values to types if applicable
loaded_dict = _convert_str_dict(loaded_dict)
setattr(self, dict_feld, loaded_dict)


def main(args: FlatArguments, tc: TokenizerConfig):
# ------------------------------------------------------------
Expand Down Expand Up @@ -549,12 +573,14 @@ def main(args: FlatArguments, tc: TokenizerConfig):
args.config_name,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion open_instruct/ground_truth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import requests
from litellm import acompletion

from IFEvalG import instructions_registry
from open_instruct.if_functions import IF_FUNCTIONS_MAP
from open_instruct.IFEvalG import instructions_registry
from open_instruct.judge_utils import (
EXTRACTOR_MAP,
JUDGE_PROMPT_MAP,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_flat_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from open_instruct.finetune import FlatArguments
from open_instruct.utils import ArgumentParserPlus


class TestFlatArguments:
    """Checks how `FlatArguments` handles `--additional_model_arguments` from the CLI."""

    def test_additional_model_args(self) -> None:
        """A JSON object given on the command line ends up as a typed `dict`."""
        parser = ArgumentParserPlus(FlatArguments)
        # NOTE: the boolean must be lower case, true not True
        (args,) = parser.parse_args_into_dataclasses(
            ["--additional_model_arguments", '{"int": 1, "bool": true, "float": 0.0}']
        )
        parsed = args.additional_model_arguments
        assert isinstance(parsed, dict)
        # `bool` is a subclass of `int`, so rule it out explicitly for the int entry
        assert isinstance(parsed["int"], int) and not isinstance(parsed["int"], bool)
        assert isinstance(parsed["bool"], bool)
        assert isinstance(parsed["float"], float)
        # Check the values as well as the types
        assert parsed == {"int": 1, "bool": True, "float": 0.0}

    def test_no_additional_model_args(self) -> None:
        """Omitting the flag leaves the field as an empty `dict`."""
        parser = ArgumentParserPlus(FlatArguments)
        (args,) = parser.parse_args_into_dataclasses(["--exp_name", "test"])
        # Should get an empty dict
        assert isinstance(args.additional_model_arguments, dict)
        assert not args.additional_model_arguments