From 66db4b93f1a2959da22c550ea3a5ebf37e9beffd Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 21 Jan 2025 19:10:33 +0000 Subject: [PATCH] fix gpt2 generation tests --- tests/models/gpt2/test_modeling_gpt2.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 88ccdc8ee45a..ba8eb90c5ea1 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -651,16 +651,18 @@ def test_batch_generation(self): outputs = model.generate( input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), + max_length=20, ) outputs_tt = model.generate( input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), token_type_ids=token_type_ids, + max_length=20, ) inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) - output_non_padded = model.generate(input_ids=inputs_non_padded) + output_non_padded = model.generate(input_ids=inputs_non_padded, max_length=20) num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) @@ -711,16 +713,18 @@ def test_batch_generation_2heads(self): outputs = model.generate( input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), + max_length=20, ) outputs_tt = model.generate( input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), token_type_ids=token_type_ids, + max_length=20, ) inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) - output_non_padded = model.generate(input_ids=inputs_non_padded) + output_non_padded = model.generate(input_ids=inputs_non_padded, max_length=20) num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) @@ -776,7 +780,7 @@ def _test_lm_generate_gpt2_helper( # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290,] # fmt: skip - output_ids = model.generate(input_ids, do_sample=False) + output_ids = model.generate(input_ids, do_sample=False, max_length=20) if verify_outputs: self.assertListEqual(output_ids[0].tolist(), expected_output_ids) @@ -805,13 +809,13 @@ def test_gpt2_sample(self): torch.manual_seed(0) tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) input_ids = tokenized.input_ids.to(torch_device) - output_ids = model.generate(input_ids, do_sample=True) + output_ids = model.generate(input_ids, do_sample=True, max_length=20) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) token_type_ids = tokenized.token_type_ids.to(torch_device) - output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5) + output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5, max_length=20) output_seq_tt = model.generate( - input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5 + input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5, max_length=20 ) output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True) output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)