Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Fix attn bias #1296

Merged
merged 1 commit into from
Aug 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/gluonnlp/model/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,11 @@ def hybrid_forward(self, F, qkv, valid_len, query_bias, key_bias, value_bias,
value_weight = value_weight.reshape(shape=(self._num_heads, -1, 0), reverse=True)
in_weight = F.concat(query_weight, key_weight, value_weight, dim=-2)
in_weight = in_weight.reshape(shape=(-1, 0), reverse=True)
in_bias = F.concat(query_bias, key_bias, value_bias, dim=0)
# concat bias
query_bias = query_bias.reshape(shape=(self._num_heads, -1), reverse=True)
key_bias = key_bias.reshape(shape=(self._num_heads, -1), reverse=True)
value_bias = value_bias.reshape(shape=(self._num_heads, -1), reverse=True)
in_bias = F.stack(query_bias, key_bias, value_bias, axis=1).reshape(-1)

# qkv_proj shape = (seq_length, batch_size, num_heads * head_dim * 3)
qkv_proj = F.FullyConnected(data=qkv, weight=in_weight, bias=in_bias,
Expand Down