diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 71173fa2a3f..7bd15451a20 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -223,6 +223,7 @@ def test_sft_trainer_backward_compatibility(self): eval_steps=2, save_steps=2, per_device_train_batch_size=2, + hub_token="not_a_real_token", ) trainer = SFTTrainer( @@ -232,6 +233,8 @@ def test_sft_trainer_backward_compatibility(self): eval_dataset=self.eval_dataset, ) + assert trainer.args.hub_token == training_args.hub_token + trainer.train() assert trainer.state.log_history[(-1)]["train_loss"] is not None diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index e739b2d92a1..0e304ab9a85 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -146,7 +146,10 @@ def __init__( warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.") args = SFTConfig(output_dir=output_dir) elif args is not None and args.__class__.__name__ == "TrainingArguments": - args = SFTConfig(**args.to_dict()) + args_as_dict = args.to_dict() + # Manually copy token values as TrainingArguments.to_dict() redacts them + args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")}) + args = SFTConfig(**args_as_dict) if model_init_kwargs is not None: warnings.warn(