diff --git a/trl/experimental/bco/bco_config.py b/trl/experimental/bco/bco_config.py index 1ed8735253b..b6a7d81ff06 100644 --- a/trl/experimental/bco/bco_config.py +++ b/trl/experimental/bco/bco_config.py @@ -15,8 +15,6 @@ from dataclasses import dataclass, field from typing import Any -from transformers import TrainingArguments - from ...trainer.base_config import _BaseConfig @@ -78,7 +76,7 @@ class BCOConfig(_BaseConfig): > - `learning_rate`: Defaults to `5e-7` instead of `5e-5`. """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"] # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field( diff --git a/trl/experimental/cpo/cpo_config.py b/trl/experimental/cpo/cpo_config.py index c50861f247b..5ca2f11a6b7 100644 --- a/trl/experimental/cpo/cpo_config.py +++ b/trl/experimental/cpo/cpo_config.py @@ -15,8 +15,6 @@ from dataclasses import dataclass, field from typing import Any -from transformers import TrainingArguments - from ...trainer.base_config import _BaseConfig @@ -91,7 +89,7 @@ class CPOConfig(_BaseConfig): > - `learning_rate`: Defaults to `1e-6` instead of `5e-5`. """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"] # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field( diff --git a/trl/experimental/kto/kto_config.py b/trl/experimental/kto/kto_config.py index 9974cce2a9c..9c796f110b9 100644 --- a/trl/experimental/kto/kto_config.py +++ b/trl/experimental/kto/kto_config.py @@ -15,8 +15,6 @@ from dataclasses import dataclass, field from typing import Any -from transformers import TrainingArguments - from ...trainer.base_config import _BaseConfig @@ -74,7 +72,7 @@ class KTOConfig(_BaseConfig): > - `learning_rate`: Defaults to `1e-6` instead of `5e-5`. """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"] # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field( diff --git a/trl/experimental/online_dpo/online_dpo_config.py b/trl/experimental/online_dpo/online_dpo_config.py index aaffac55aff..5794f347e49 100644 --- a/trl/experimental/online_dpo/online_dpo_config.py +++ b/trl/experimental/online_dpo/online_dpo_config.py @@ -16,8 +16,6 @@ from dataclasses import dataclass, field from typing import Any -from transformers import TrainingArguments - from ...trainer.base_config import _BaseConfig @@ -161,7 +159,7 @@ class may differ from those in [`~transformers.TrainingArguments`]. > - `learning_rate`: Defaults to `5e-7` instead of `5e-5`. """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"] # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field( diff --git a/trl/experimental/orpo/orpo_config.py b/trl/experimental/orpo/orpo_config.py index 473ef64d2a6..ca392079bf1 100644 --- a/trl/experimental/orpo/orpo_config.py +++ b/trl/experimental/orpo/orpo_config.py @@ -15,8 +15,6 @@ from dataclasses import dataclass, field from typing import Any -from transformers import TrainingArguments - from ...trainer.base_config import _BaseConfig @@ -71,7 +69,7 @@ class ORPOConfig(_BaseConfig): > - `learning_rate`: Defaults to `1e-6` instead of `5e-5`. """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"] # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field(