From 68c101fb698188a4f55e9c776078f1ed48b0b79e Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Tue, 10 Mar 2026 10:09:37 +0100
Subject: [PATCH] Add missing model_init_kwargs to _VALID_DICT_FIELDS

---
 trl/experimental/gkd/gkd_config.py   | 4 +---
 trl/experimental/gold/gold_config.py | 4 +---
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/trl/experimental/gkd/gkd_config.py b/trl/experimental/gkd/gkd_config.py
index ad0e854ba6c..622bb327a7c 100644
--- a/trl/experimental/gkd/gkd_config.py
+++ b/trl/experimental/gkd/gkd_config.py
@@ -15,8 +15,6 @@
 from dataclasses import dataclass, field
 from typing import Any
 
-from transformers import TrainingArguments
-
 from ...trainer.sft_config import SFTConfig
 
 
@@ -52,7 +50,7 @@ class GKDConfig(SFTConfig):
             teacher-generated output).
     """
 
-    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
+    _VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
 
     temperature: float = field(
         default=0.9,
diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py
index 827b639dec8..5850dac084a 100644
--- a/trl/experimental/gold/gold_config.py
+++ b/trl/experimental/gold/gold_config.py
@@ -15,8 +15,6 @@
 from dataclasses import dataclass, field
 from typing import Any
 
-from transformers import TrainingArguments
-
 from ...trainer.sft_config import SFTConfig
 
 
@@ -94,7 +92,7 @@ class GOLDConfig(SFTConfig):
             low, but waking the engine adds host–device transfer latency.
     """
 
-    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
+    _VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
 
     # Parameters whose default values are overridden from TrainingArguments
     learning_rate: float = field(