diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 559a03e3e6c0..6b5b2f093328 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -366,6 +366,47 @@ def test_generation_pre_attn_layer_norm(self): self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS) + def test_batch_generation(self): + model_id = "facebook/opt-350m" + + tokenizer = GPT2Tokenizer.from_pretrained(model_id) + model = OPTForCausalLM.from_pretrained(model_id) + model.to(torch_device) + + tokenizer.padding_side = "left" + + # use different length sentences to test batching + sentences = [ + "Hello, my dog is a little", + "Today, I", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch_device) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + ) + + inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) + output_non_padded = model.generate(input_ids=inputs_non_padded) + + num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() + inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) + output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) + padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) + + expected_output_sentence = [ + "Hello, my dog is a little bit of a dork.\nI'm a little bit", + "Today, I was in the middle of a conversation with a friend about the", + ] + self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence]) + def test_generation_post_attn_layer_norm(self): model_id = "facebook/opt-350m"