33 changes: 23 additions & 10 deletions trl/commands/cli_utils.py
@@ -75,7 +75,9 @@ def merge_dataclasses(self, dataclasses):
                     field_name = data_class_field.name
                     field_value = getattr(dataclass, field_name)
 
-                    if not isinstance(dataclass, TrainingArguments):
+                    if not isinstance(dataclass, TrainingArguments) or not hasattr(
+                        self._dummy_training_args, field_name
+                    ):
                         default_value = data_class_field.default
                     else:
                         default_value = (
@@ -95,12 +97,13 @@ def merge_dataclasses(self, dataclasses):
                         setattr(dataclasses_copy[i], field_name, value_to_replace)
                     # Otherwise do nothing
 
-            # Re-init `TrainingArguments` to handle all post-processing correctly
+            # Re-init `TrainingArguments` or derived class to handle all post-processing correctly
             if is_hf_training_args:
-                init_signature = list(inspect.signature(TrainingArguments.__init__).parameters)
+                ArgCls = type(dataclass)
+                init_signature = list(inspect.signature(ArgCls.__init__).parameters)
                 dict_dataclass = asdict(dataclasses_copy[i])
                 new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature}
-                dataclasses_copy[i] = TrainingArguments(**new_dict_dataclass)
+                dataclasses_copy[i] = ArgCls(**new_dict_dataclass)
 
         return dataclasses_copy
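The two hunks above generalize the re-init step from a hard-coded `TrainingArguments` to whatever class the merged dataclass actually is, so subclasses with extra fields survive the round-trip through `asdict`. A minimal sketch of that pattern, using a hypothetical `MyTrainingArguments` subclass for illustration (not a class from this PR):

```python
import inspect
from dataclasses import asdict, dataclass, field

from transformers import TrainingArguments


@dataclass
class MyTrainingArguments(TrainingArguments):
    # Hypothetical extra field a user-defined subclass might add.
    my_extra_field: str = field(default="foo")


args = MyTrainingArguments(output_dir="out", my_extra_field="bar")

# Before this PR the re-init hard-coded TrainingArguments, which would drop
# `my_extra_field`; using type(args) keeps the derived class and its fields.
ArgCls = type(args)
init_signature = list(inspect.signature(ArgCls.__init__).parameters)
kwargs = {k: v for k, v in asdict(args).items() if k in init_signature}
rebuilt = ArgCls(**kwargs)
assert rebuilt.my_extra_field == "bar"
```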

@@ -141,12 +144,16 @@ def warning_handler(message, category, filename, lineno, file=None, line=None):
 
 @dataclass
 class SFTScriptArguments:
-    dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
+    dataset_name: str = field(
+        default="timdettmers/openassistant-guanaco",
+        metadata={"help": "the dataset name"},
+    )
     dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"})
     dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to evaluate on"})
     config: str = field(default=None, metadata={"help": "Path to the optional config file"})
     gradient_checkpointing_use_reentrant: bool = field(
-        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
+        default=False,
+        metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
     )


@@ -166,7 +173,8 @@ class DPOScriptArguments:
     )
     config: str = field(default=None, metadata={"help": "Path to the optional config file"})
     gradient_checkpointing_use_reentrant: bool = field(
-        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
+        default=False,
+        metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
     )
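The changes to these two script-argument dataclasses are purely cosmetic line wrapping. For context (not part of the diff), they are consumed through `TrlParser`, which merges CLI flags with the optional YAML file passed via `--config`. A minimal usage sketch, assuming `parse_args_and_config` keeps the signature it has in this version of trl:

```python
from transformers import TrainingArguments

from trl.commands.cli_utils import SFTScriptArguments, TrlParser

# Parse CLI flags and, if given, the YAML file behind --config.
parser = TrlParser((SFTScriptArguments, TrainingArguments))
script_args, training_args = parser.parse_args_and_config()

print(script_args.dataset_name, training_args.output_dir)
```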


@@ -229,10 +237,12 @@ class ChatArguments:
         },
     )
     load_in_8bit: bool = field(
-        default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
+        default=False,
+        metadata={"help": "use 8 bit precision for the base model - works only with LoRA"},
     )
     load_in_4bit: bool = field(
-        default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}
+        default=False,
+        metadata={"help": "use 4 bit precision for the base model - works only with LoRA"},
     )
 
     bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"})
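These quantization flags are only rewrapped here, but for orientation: flags like these typically feed a `transformers.BitsAndBytesConfig` when the chat CLI loads the base model. A sketch of that mapping (an assumption about the downstream wiring, not code from this file):

```python
import torch
from transformers import BitsAndBytesConfig

# Mirrors ChatArguments.load_in_4bit and ChatArguments.bnb_4bit_quant_type.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
```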
@@ -264,7 +274,10 @@ def post_process_dataclasses(self, dataclasses):
             if dataclass_obj.__class__.__name__ == "TrainingArguments":
                 training_args = dataclass_obj
                 training_args_index = i
-            elif dataclass_obj.__class__.__name__ in ("SFTScriptArguments", "DPOScriptArguments"):
+            elif dataclass_obj.__class__.__name__ in (
+                "SFTScriptArguments",
+                "DPOScriptArguments",
+            ):
                 trl_args = dataclass_obj
             else:
                 ...
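This last hunk only rewraps the name-based dispatch; the payoff is the post-processing it feeds, where the matched objects are combined. A hypothetical, self-contained illustration of that step (the `gradient_checkpointing_kwargs` wiring is an assumption based on the surrounding code, not copied from the diff):

```python
from transformers import TrainingArguments

from trl.commands.cli_utils import SFTScriptArguments

script_args = SFTScriptArguments(gradient_checkpointing_use_reentrant=True)
training_args = TrainingArguments(output_dir="out", gradient_checkpointing=True)

# Forward the reentrant flag from the script arguments into the
# Trainer-facing gradient checkpointing kwargs.
training_args.gradient_checkpointing_kwargs = {
    "use_reentrant": script_args.gradient_checkpointing_use_reentrant
}
```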