
Commit 9623240 ("fix")

zheyuye committed Jul 29, 2020
1 parent d9c4140

Showing 1 changed file with 6 additions and 1 deletion.
src/gluonnlp/models/transformer.py (7 changes: 6 additions & 1 deletion)
@@ -196,7 +196,7 @@ def __init__(self,
             bias_initializer=bias_initializer,
             dtype=self._dtype)
         attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
-        self.self_attention =\
+        self.self_attention = \
             MultiHeadAttentionCell(
                 query_units=self._units,
                 num_heads=self._num_heads,
@@ -1163,6 +1163,9 @@ def encode(self, F, src_data, src_valid_length):
         else:
             src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer(
                 F.npx.arange_like(src_data, axis=0)), axis=1)
+        if self.layernorm_embedding:
+            src_data = self.src_embed_ln(src_data)
+
         enc_out = self.encoder(src_data, src_valid_length)
         return enc_out
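
The new branch normalizes the sum of the token and positional embeddings before it enters the encoder stack, the embedding LayerNorm used by BART-style models. Below is a minimal NumPy sketch of what src_embed_ln computes in the time-major layout used above; the layer_norm helper and the shapes are illustrative assumptions, not the library's API:

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the channel axis, then apply the learned scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

# Illustrative shapes: time-major layout, T=4 steps, B=2 sequences, C=8 units.
T, B, C = 4, 2, 8
embed = np.random.randn(T, B, C)        # token embeddings
pos = np.random.randn(T, C)             # positional embeddings for steps 0..T-1
gamma, beta = np.ones(C), np.zeros(C)   # learned LayerNorm parameters

src_data = embed + pos[:, None, :]      # mirrors expand_dims(..., axis=1) above
src_data = layer_norm(src_data, gamma, beta)  # the layernorm_embedding step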

@@ -1205,6 +1208,8 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
         else:
             tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
                 F.npx.arange_like(tgt_data, axis=0)), axis=1)
+        if self.layernorm_embedding:
+            tgt_data = self.tgt_embed_ln(tgt_data)
         dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
         return dec_out
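
One subtlety: decode_seq mirrors encode but must normalize with the target-side tgt_embed_ln rather than reusing the encoder's src_embed_ln, since each side holds its own learned parameters (assuming __init__ creates one LayerNorm per side when layernorm_embedding is set). A minimal sketch of that structure; the EmbedLN class is an illustrative stand-in, not the library's API:

import numpy as np

class EmbedLN:
    # Each instance owns an independent set of LayerNorm parameters.
    def __init__(self, units, eps=1e-5):
        self.gamma, self.beta, self.eps = np.ones(units), np.zeros(units), eps

    def __call__(self, x):
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return self.gamma * (x - mean) / np.sqrt(var + self.eps) + self.beta

units = 8
src_embed_ln = EmbedLN(units)  # applied to src_data in encode()
tgt_embed_ln = EmbedLN(units)  # applied to tgt_data in decode_seq(); distinct parameters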
