diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index 90be24348eb..c1bcd97b713 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -107,6 +107,18 @@ def test_collate_with_margin(self):
 
 
 class TestRewardTrainer(TrlTestCase):
+    def test_raises_error_when_model_num_labels_not_one(self):
+        """Test that RewardTrainer raises ValueError when model doesn't have num_labels=1."""
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            dtype="float32",
+            # `num_labels` is intentionally left unset: it defaults to 2, which the trainer must reject
+        )
+
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        with pytest.raises(ValueError, match=r"reward models require `num_labels=1`"):
+            RewardTrainer(model=model, args=training_args)
+
     @pytest.mark.parametrize(
         "model_id",
         [
@@ -176,6 +188,7 @@ def test_train_model(self):
         # Instantiate the model
         model = AutoModelForSequenceClassification.from_pretrained(
             "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            num_labels=1,  # required for reward models
             dtype="float32",
         )
 
@@ -341,7 +354,11 @@ def test_train_moe_with_peft_config(self):
     def test_train_peft_model(self):
         # Get the base model
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        model = AutoModelForSequenceClassification.from_pretrained(model_id, dtype="float32")
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_id,
+            num_labels=1,  # required for reward models
+            dtype="float32",
+        )
 
         # Get the base model parameter names
         base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 2677901726b..9c309060e3c 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -354,6 +354,14 @@ def __init__(
                     "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
                     "The `model_init_kwargs` will be ignored."
                 )
+            # Validate that the model has num_labels = 1 (required for reward models). The value is read once so
+            # the error message cannot raise AttributeError when the config does not define `num_labels` at all.
+            num_labels = getattr(model.config, "num_labels", None)
+            if num_labels != 1:
+                raise ValueError(
+                    f"The model has `num_labels={num_labels}`, but reward models require `num_labels=1` "
+                    "to output a single scalar reward per sequence. Please instantiate your model with `num_labels=1` "
+                    "or pass a model name as a string to have it configured automatically."
+                )
 
         # Processing class
         if processing_class is None: