From a55e26e0e7e20c2329b936980bb42026306e19b9 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Thu, 6 Jun 2024 22:53:47 +0000
Subject: [PATCH] fix yaml parser for derived config classes

fixes #1712

reformatted cli_utils with ruff
---
 trl/commands/cli_utils.py | 33 +++++++++++++++++++++++----------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py
index 7f421efa9c3..ce08d7b6309 100644
--- a/trl/commands/cli_utils.py
+++ b/trl/commands/cli_utils.py
@@ -75,7 +75,9 @@ def merge_dataclasses(self, dataclasses):
                     field_name = data_class_field.name
                     field_value = getattr(dataclass, field_name)

-                    if not isinstance(dataclass, TrainingArguments):
+                    if not isinstance(dataclass, TrainingArguments) or not hasattr(
+                        self._dummy_training_args, field_name
+                    ):
                         default_value = data_class_field.default
                     else:
                         default_value = (
@@ -95,12 +97,13 @@ def merge_dataclasses(self, dataclasses):
                         setattr(dataclasses_copy[i], field_name, value_to_replace)
                     # Otherwise do nothing

-                # Re-init `TrainingArguments` to handle all post-processing correctly
+                # Re-init `TrainingArguments` or derived class to handle all post-processing correctly
                 if is_hf_training_args:
-                    init_signature = list(inspect.signature(TrainingArguments.__init__).parameters)
+                    ArgCls = type(dataclass)
+                    init_signature = list(inspect.signature(ArgCls.__init__).parameters)
                     dict_dataclass = asdict(dataclasses_copy[i])
                     new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature}
-                    dataclasses_copy[i] = TrainingArguments(**new_dict_dataclass)
+                    dataclasses_copy[i] = ArgCls(**new_dict_dataclass)

         return dataclasses_copy

@@ -141,12 +144,16 @@ def warning_handler(message, category, filename, lineno, file=None, line=None):

 @dataclass
 class SFTScriptArguments:
-    dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
+    dataset_name: str = field(
+        default="timdettmers/openassistant-guanaco",
+        metadata={"help": "the dataset name"},
+    )
     dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"})
     dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to evaluate on"})
     config: str = field(default=None, metadata={"help": "Path to the optional config file"})
     gradient_checkpointing_use_reentrant: bool = field(
-        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
+        default=False,
+        metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
     )


@@ -166,7 +173,8 @@ class DPOScriptArguments:
     )
     config: str = field(default=None, metadata={"help": "Path to the optional config file"})
     gradient_checkpointing_use_reentrant: bool = field(
-        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
+        default=False,
+        metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
     )


@@ -229,10 +237,12 @@ class ChatArguments:
         },
     )
     load_in_8bit: bool = field(
-        default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
+        default=False,
+        metadata={"help": "use 8 bit precision for the base model - works only with LoRA"},
     )
     load_in_4bit: bool = field(
-        default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}
+        default=False,
+        metadata={"help": "use 4 bit precision for the base model - works only with LoRA"},
     )

     bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"})
@@ -264,7 +274,10 @@ def post_process_dataclasses(self, dataclasses):
             if dataclass_obj.__class__.__name__ == "TrainingArguments":
                 training_args = dataclass_obj
                 training_args_index = i
-            elif dataclass_obj.__class__.__name__ in ("SFTScriptArguments", "DPOScriptArguments"):
+            elif dataclass_obj.__class__.__name__ in (
+                "SFTScriptArguments",
+                "DPOScriptArguments",
+            ):
                 trl_args = dataclass_obj
             else:
                 ...