Fix the GPTBigCode attention scaling factor and refresh the generation tests.

In modeling_gpt_bigcode.py, `self.scaling` becomes `self.head_dim**-0.5`
(previously `self.head_dim**0.5`) when `config.scale_attn_weights` is truthy;
it stays 1.0 otherwise.

In test_modeling_gpt_bigcode.py, `AutoTokenizer` is imported, the expected
strings of `test_generate_simple` and `test_generate_batched` are updated, and
`test_newline_regression` is added.

NOTE(review): this patch was rebuilt from a copy whose line breaks and
whitespace runs had been collapsed. Indentation and blank context lines were
reconstructed to satisfy the hunk line counts; the spacing after `\n` inside
the expected-output string literals is kept exactly as it appeared in that
copy and may have been wider originally. Confirm against the upstream commit
(or use `git apply --ignore-whitespace`) before applying.

diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index b45a2810cc03..0d21f30f490c 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -144,7 +144,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
             )
 
         self.scale_attn_weights = config.scale_attn_weights
-        self.scaling = self.head_dim**0.5 if config.scale_attn_weights else 1.0
+        self.scaling = self.head_dim**-0.5 if config.scale_attn_weights else 1.0
         self.is_cross_attention = is_cross_attention
 
         self.layer_idx = layer_idx
diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
index 21d1764c76e0..a57dc883f3ca 100644
--- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
+++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
@@ -29,6 +29,7 @@
     import torch
 
     from transformers import (
+        AutoTokenizer,
         GPT2TokenizerFast,
         GPTBigCodeForCausalLM,
         GPTBigCodeForSequenceClassification,
@@ -510,7 +511,7 @@ def test_generate_simple(self):
         output_sequence = model.generate(input_ids)
         output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
 
-        expected_output = """def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_"""
+        expected_output = 'def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_world_with_args(name'  # fmt: skip
         self.assertEqual(output_sentence, expected_output)
 
     def test_generate_batched(self):
@@ -527,11 +528,27 @@ def test_generate_batched(self):
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         expected_output = [
-            'def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_',
-            'def say_hello():\n print("Hello, World!")\n\n\nsay_hello()',
+            'def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_world_with_args(name',
+            'def say_hello():\n print("Hello, World!")\n\n\nsay_hello()\n',
         ]
         self.assertListEqual(outputs, expected_output)
 
+    def test_newline_regression(self):
+        """Added to prevent regressions regarding attention (scaling) indicated by excessive newlines"""
+        tokenizer = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
+        model = GPTBigCodeForCausalLM.from_pretrained("bigcode/tiny_starcoder_py").to(torch_device)
+
+        input_ids = tokenizer(
+            "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n",
+            return_tensors="pt",
+        ).input_ids.to(torch_device)
+
+        output_sequence = model.generate(input_ids, max_new_tokens=20, do_sample=False)
+        output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
+
+        expected_output = 'Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n\nThe impact of the COVID-19 pandemic on global economic structures and future business'  # fmt: skip
+        self.assertEqual(output_sentence, expected_output)
+
 
 @require_torch
 class GPTBigCodeMQATest(unittest.TestCase):