diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 30729a6a41..71173fa2a3 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -25,6 +25,7 @@
     AutoProcessor,
     AutoTokenizer,
     LlavaForConditionalGeneration,
+    TrainingArguments,
 )
 
 from trl import SFTConfig, SFTTrainer
@@ -213,6 +214,31 @@ def test_constant_length_dataset(self):
             decoded_text = self.tokenizer.decode(example["input_ids"])
             assert ("Question" in decoded_text) and ("Answer" in decoded_text)
 
+    def test_sft_trainer_backward_compatibility(self):
+        # Regression test: SFTTrainer must still accept a plain
+        # transformers.TrainingArguments (not only SFTConfig) and coerce it
+        # internally, so pre-SFTConfig user code keeps working.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                eval_strategy="steps",
+                max_steps=4,
+                eval_steps=2,
+                save_steps=2,
+                per_device_train_batch_size=2,
+            )
+
+            trainer = SFTTrainer(
+                model=self.model_id,
+                args=training_args,
+                train_dataset=self.train_dataset,
+                eval_dataset=self.eval_dataset,
+            )
+
+            trainer.train()
+
+            # Both train and eval logging must have fired despite the coercion.
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+            assert trainer.state.log_history[0]["eval_loss"] is not None
+
+            # save_steps=2 must have produced an intermediate checkpoint.
+            assert "model.safetensors" in os.listdir(os.path.join(tmp_dir, "checkpoint-2"))
+
     def test_sft_trainer(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = SFTConfig(
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 322a950177..2c2bc669a2 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -145,6 +145,8 @@ def __init__(
             output_dir = "tmp_trainer"
             warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.")
             args = SFTConfig(output_dir=output_dir)
+        # Backward compatibility: upgrade a plain TrainingArguments to SFTConfig.
+        # Name check (not isinstance) so SFTConfig subclasses are left untouched;
+        # `args` is provably non-None here (the `if args is None` branch handled None).
+        elif args.__class__.__name__ == "TrainingArguments":
+            args = SFTConfig(**args.to_dict())
         if model_init_kwargs is not None:
             warnings.warn(