Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions open_instruct/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ class FlatArguments:
" Use only when tokenizer does not add bos token by default."
},
)
additional_model_arguments: Optional[list[str]] = field(
default=None,
metadata={"help": "A list of key:val to be passed as additional model args."},
)
clip_grad_norm: float = field(
default=-1,
metadata={
Expand Down Expand Up @@ -448,8 +452,8 @@ def __post_init__(self):
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["json", "jsonl"], (
"`train_file` should be a json or a jsonl file."
assert extension in ["json", "jsonl", "parquet"], (
"`train_file` should be a json or a jsonl or parquet file."
)
if (
(
Expand All @@ -472,6 +476,30 @@ def __post_init__(self):
"Cannot launch Beaker evaluation jobs without pushing to the Hub."
)

# Normalise `additional_model_arguments` from a list of "key:val" strings
# into a dict (empty when nothing was supplied) so that it can always be
# unpacked with `**` by later code.
if self.additional_model_arguments is not None:
    import re

    def _coerce_model_arg(raw):
        """Convert a plain decimal literal to int/float; return other text as-is.

        Only simple forms are converted ("12", "-3", "0.5", "-.25"). Anything
        else (including "1e5", "1.", "abc.5", "true") stays a string, so a
        value that merely ends in ".<digits>" can no longer crash `float()`.
        """
        if re.fullmatch(r"-?\d+", raw):
            return int(raw)
        if re.fullmatch(r"-?\d*\.\d+", raw):
            return float(raw)
        return raw

    parsed_model_arguments = {}
    for item in self.additional_model_arguments:
        # Split on the FIRST colon only, so values that themselves contain
        # ":" (paths, URLs, ...) are preserved instead of being rejected.
        key, separator, raw_value = item.partition(":")
        if not separator or not key:
            # The previous implementation caught IndexError here, which tuple
            # unpacking never raises; malformed items therefore surfaced as an
            # unrelated "not enough values to unpack" error.
            raise ValueError(
                "Malformed additional model arguments. "
                "Should be space-delimited list of key:val."
            )
        parsed_model_arguments[key] = _coerce_model_arg(raw_value)
    self.additional_model_arguments = parsed_model_arguments
else:
    self.additional_model_arguments = {}


def encode_sft_example(example, tokenizer, max_seq_length):
"""
Expand Down Expand Up @@ -650,9 +678,12 @@ def main(args: FlatArguments):
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
data_type = "json"
if args.train_file.endswith('.parquet'):
data_type = "parquet"
with accelerator.main_process_first():
raw_datasets = load_dataset(
"json",
data_type,
data_files=data_files,
**dataset_args,
)
Expand All @@ -663,12 +694,14 @@ def main(args: FlatArguments):
args.config_name,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
**args.additional_model_arguments,
)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
**args.additional_model_arguments,
)
else:
raise ValueError(
Expand Down