diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 655ea55eeb56..59f81d31459c 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -316,7 +316,7 @@ def __init__(self, config): if config.hidden_size % self.num_attention_heads != 0: raise ValueError("hidden_size should be divisible by num_attention_heads") - self.attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2 self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) @@ -413,7 +413,10 @@ def forward( conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size]) context_layer = torch.cat([context_layer, conv_out], 2) - new_context_layer_shape = context_layer.size()[:-2] + (self.head_ratio * self.all_head_size,) + # conv and context + new_context_layer_shape = context_layer.size()[:-2] + ( + self.num_attention_heads * self.attention_head_size * 2, + ) context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/tests/models/convbert/test_modeling_convbert.py b/tests/models/convbert/test_modeling_convbert.py index 28762ebcfb46..773b37633729 100644 --- a/tests/models/convbert/test_modeling_convbert.py +++ b/tests/models/convbert/test_modeling_convbert.py @@ -446,6 +446,11 @@ def test_model_for_input_embeds(self): result = model(inputs_embeds=inputs_embeds) self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size)) + def test_reducing_attention_heads(self): + config, *inputs_dict = self.model_tester.prepare_config_and_inputs() + config.head_ratio = 4 + self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict) + @require_torch class ConvBertModelIntegrationTest(unittest.TestCase):