diff --git a/src/gluonnlp/model/bert.py b/src/gluonnlp/model/bert.py index ed29c8aaae..5bd8b1a8ef 100644 --- a/src/gluonnlp/model/bert.py +++ b/src/gluonnlp/model/bert.py @@ -134,7 +134,11 @@ def hybrid_forward(self, F, qkv, valid_len, query_bias, key_bias, value_bias, value_weight = value_weight.reshape(shape=(self._num_heads, -1, 0), reverse=True) in_weight = F.concat(query_weight, key_weight, value_weight, dim=-2) in_weight = in_weight.reshape(shape=(-1, 0), reverse=True) - in_bias = F.concat(query_bias, key_bias, value_bias, dim=0) + # concat bias + query_bias = query_bias.reshape(shape=(self._num_heads, -1), reverse=True) + key_bias = key_bias.reshape(shape=(self._num_heads, -1), reverse=True) + value_bias = value_bias.reshape(shape=(self._num_heads, -1), reverse=True) + in_bias = F.stack(query_bias, key_bias, value_bias, axis=1).reshape(-1) # qkv_proj shape = (seq_length, batch_size, num_heads * head_dim * 3) qkv_proj = F.FullyConnected(data=qkv, weight=in_weight, bias=in_bias,