diff --git a/trl/experimental/gkd/gkd_config.py b/trl/experimental/gkd/gkd_config.py index ad0e854ba6c..622bb327a7c 100644 --- a/trl/experimental/gkd/gkd_config.py +++ b/trl/experimental/gkd/gkd_config.py @@ -15,8 +15,6 @@ from dataclasses import dataclass, field from typing import Any -from transformers import TrainingArguments - from ...trainer.sft_config import SFTConfig @@ -52,7 +50,7 @@ class GKDConfig(SFTConfig): teacher-generated output). """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"] + _VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"] temperature: float = field( default=0.9, diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 827b639dec8..5850dac084a 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -15,8 +15,6 @@ from dataclasses import dataclass, field from typing import Any -from transformers import TrainingArguments - from ...trainer.sft_config import SFTConfig @@ -94,7 +92,7 @@ class GOLDConfig(SFTConfig): low, but waking the engine adds host–device transfer latency. """ - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"] + _VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"] # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field(