Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion tests/test_reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ def test_collate_with_margin(self):


class TestRewardTrainer(TrlTestCase):
def test_raises_error_when_model_num_labels_not_one(self):
    """RewardTrainer must refuse a model whose config does not have `num_labels=1`."""
    # `num_labels` is deliberately NOT passed to `from_pretrained`: for this causal
    # checkpoint it then defaults to 2, which is the misconfiguration under test.
    classifier = AutoModelForSequenceClassification.from_pretrained(
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32"
    )
    reward_config = RewardConfig(output_dir=self.tmp_dir, report_to="none")

    # Building the trainer is expected to fail before any training happens.
    expected_message = r"reward models require `num_labels=1`"
    with pytest.raises(ValueError, match=expected_message):
        RewardTrainer(model=classifier, args=reward_config)

@pytest.mark.parametrize(
"model_id",
[
Expand Down Expand Up @@ -176,6 +188,7 @@ def test_train_model(self):
# Instantiate the model
model = AutoModelForSequenceClassification.from_pretrained(
"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
num_labels=1, # required for reward models
dtype="float32",
)

Expand Down Expand Up @@ -341,7 +354,11 @@ def test_train_moe_with_peft_config(self):
def test_train_peft_model(self):
# Get the base model
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForSequenceClassification.from_pretrained(model_id, dtype="float32")
model = AutoModelForSequenceClassification.from_pretrained(
model_id,
num_labels=1, # required for reward models
dtype="float32",
)

# Get the base model parameter names
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
Expand Down
7 changes: 7 additions & 0 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,13 @@ def __init__(
"You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
# Validate that the model has num_labels = 1 (required for reward models).
# The value is read once, with a default, and reused in the message so that a
# config lacking the attribute still produces the ValueError below instead of an
# AttributeError raised while the message is being formatted.
num_labels = getattr(model.config, "num_labels", None)
if num_labels != 1:
    raise ValueError(
        f"The model has `num_labels={num_labels}`, but reward models require `num_labels=1` "
        "to output a single scalar reward per sequence. Please instantiate your model with `num_labels=1` "
        "or pass a model name as a string to have it configured automatically."
    )

# Processing class
if processing_class is None:
Expand Down
Loading