Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/transformers/models/convbert/modeling_convbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def __init__(self, config):
if config.hidden_size % self.num_attention_heads != 0:
raise ValueError("hidden_size should be divisible by num_attention_heads")

self.attention_head_size = config.hidden_size // config.num_attention_heads
self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = nn.Linear(config.hidden_size, self.all_head_size)
Expand Down Expand Up @@ -413,7 +413,8 @@ def forward(
conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size])
context_layer = torch.cat([context_layer, conv_out], 2)

new_context_layer_shape = context_layer.size()[:-2] + (self.head_ratio * self.all_head_size,)
# conv and context
new_context_layer_shape = context_layer.size()[:-2] + (self.num_attention_heads * self.attention_head_size * 2,)
context_layer = context_layer.view(*new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
Expand Down
5 changes: 5 additions & 0 deletions tests/models/convbert/test_modeling_convbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,11 @@ def test_model_for_input_embeds(self):
result = model(inputs_embeds=inputs_embeds)
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))

def test_reducing_attention_heads(self):
config, *inputs_dict = self.model_tester.prepare_config_and_inputs()
config.head_ratio = 4
self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict)


@require_torch
class ConvBertModelIntegrationTest(unittest.TestCase):
Expand Down