tests/test_modeling_value_head.py (31 changes: 17 additions & 14 deletions)
@@ -18,7 +18,8 @@

 import pytest
 import torch
-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from parameterized import parameterized
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig

 from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, create_reference_model

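Side note for reviewers unfamiliar with `parameterized`: `parameterized.expand` turns one test method into a separate test case per entry, which is why the `for model_name in ...` loops disappear in the hunks below. A minimal sketch, assuming a hypothetical `MODEL_NAMES` list in place of the suite's `ALL_CAUSAL_LM_MODELS`:

```python
import unittest

from parameterized import parameterized

# Hypothetical stand-in for the suite's ALL_CAUSAL_LM_MODELS constant.
MODEL_NAMES = ["model-a", "model-b"]


class ExampleTest(unittest.TestCase):
    @parameterized.expand(MODEL_NAMES)
    def test_name_is_nonempty(self, model_name):
        # Each entry yields its own test case, so one failing model
        # no longer masks failures in the rest of the list.
        self.assertTrue(len(model_name) > 0)
```

The practical win over the old loop is per-model reporting: each expansion shows up as its own test (named roughly `test_name_is_nonempty_0_model_a`, and so on), so a failure pinpoints the offending model instead of aborting the loop.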
@@ -248,16 +249,17 @@ def test_dropout_kwargs(self):
         # Check if v head of the model has the same dropout as the config
         assert model.v_head.dropout.p == 0.5

-    def test_generate(self):
+    @parameterized.expand(ALL_CAUSAL_LM_MODELS)
+    def test_generate(self, model_name):
         r"""
         Test if `generate` works for every model
         """
-        for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        generation_config = GenerationConfig(max_new_tokens=9)
+        model = self.trl_model_class.from_pretrained(model_name)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

-            # Just check if the generation works
-            _ = model.generate(input_ids)
+        # Just check if the generation works
+        _ = model.generate(input_ids, generation_config=generation_config)

     def test_raise_error_not_causallm(self):
         # Test with a model without a LM head
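For context on the `GenerationConfig(max_new_tokens=9)` change: passing a `generation_config` bounds decoding explicitly instead of relying on each model's defaults, which keeps the test fast and deterministic in length. A minimal standalone sketch; `sshleifer/tiny-gpt2` is just an illustrative tiny checkpoint, not necessarily one of the models this suite covers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Illustrative tiny checkpoint; swap in any causal LM.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
# Cap decoding at 9 new tokens, mirroring the test's bound, so the call
# stays fast and the output length is predictable.
generation_config = GenerationConfig(max_new_tokens=9)
output_ids = model.generate(**inputs, generation_config=generation_config)
print(output_ids.shape)  # (1, prompt_length + up to 9)
```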
@@ -370,17 +372,18 @@ def test_dropout_kwargs(self):
         # Check if v head of the model has the same dropout as the config
         assert model.v_head.dropout.p == 0.5

-    def test_generate(self):
+    @parameterized.expand(ALL_SEQ2SEQ_MODELS)
+    def test_generate(self, model_name):
         r"""
         Test if `generate` works for every model
         """
-        for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
-            decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        generation_config = GenerationConfig(max_new_tokens=9)
+        model = self.trl_model_class.from_pretrained(model_name)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

-            # Just check if the generation works
-            _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids)
+        # Just check if the generation works
+        _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config)

     def test_raise_error_not_causallm(self):
         # Test with a model without a LM head
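One note on the seq2seq variant: encoder-decoder models accept `decoder_input_ids` to seed the decoder, and the test passes them alongside the new `generation_config`. A minimal sketch of the same call pattern, using `t5-small` as an illustrative checkpoint (not taken from this PR):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Illustrative seq2seq checkpoint; the suite iterates over its own tiny models.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: hello", return_tensors="pt").input_ids
# Seed the decoder explicitly, as the test does; generation continues from
# these ids instead of the default decoder start token alone.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
output_ids = model.generate(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    generation_config=GenerationConfig(max_new_tokens=9),
)
print(output_ids.shape)
```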