diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index ede1b916175..a1d4503540f 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -18,7 +18,8 @@ import pytest import torch -from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM +from parameterized import parameterized +from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, create_reference_model @@ -248,16 +249,17 @@ def test_dropout_kwargs(self): # Check if v head of the model has the same dropout as the config assert model.v_head.dropout.p == 0.5 - def test_generate(self): + @parameterized.expand(ALL_CAUSAL_LM_MODELS) + def test_generate(self, model_name): r""" Test if `generate` works for every model """ - for model_name in self.all_model_names: - model = self.trl_model_class.from_pretrained(model_name) - input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + generation_config = GenerationConfig(max_new_tokens=9) + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) - # Just check if the generation works - _ = model.generate(input_ids) + # Just check if the generation works + _ = model.generate(input_ids, generation_config=generation_config) def test_raise_error_not_causallm(self): # Test with a model without a LM head @@ -370,17 +372,18 @@ def test_dropout_kwargs(self): # Check if v head of the model has the same dropout as the config assert model.v_head.dropout.p == 0.5 - def test_generate(self): + @parameterized.expand(ALL_SEQ2SEQ_MODELS) + def test_generate(self, model_name): r""" Test if `generate` works for every model """ - for model_name in self.all_model_names: - model = self.trl_model_class.from_pretrained(model_name) - input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) - decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + generation_config = GenerationConfig(max_new_tokens=9) + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) - # Just check if the generation works - _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids) + # Just check if the generation works + _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config) def test_raise_error_not_causallm(self): # Test with a model without a LM head