From 962324004a60058464d8e55def76b1f5cea6e0bc Mon Sep 17 00:00:00 2001
From: ZheyuYe
Date: Thu, 30 Jul 2020 00:52:17 +0800
Subject: [PATCH] fix

---
 src/gluonnlp/models/transformer.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index ea66e27b23..90927ccd24 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -196,7 +196,7 @@ def __init__(self,
                              bias_initializer=bias_initializer,
                              dtype=self._dtype)
         attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
-        self.self_attention =\
+        self.self_attention = \
             MultiHeadAttentionCell(
                 query_units=self._units,
                 num_heads=self._num_heads,
@@ -1163,6 +1163,9 @@ def encode(self, F, src_data, src_valid_length):
         else:
             src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer(
                 F.npx.arange_like(src_data, axis=0)), axis=1)
+        if self.layernorm_embedding:
+            src_data = self.src_embed_ln(src_data)
+
         enc_out = self.encoder(src_data, src_valid_length)
         return enc_out
 
@@ -1205,6 +1208,8 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
         else:
             tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
                 F.npx.arange_like(tgt_data, axis=0)), axis=1)
+        if self.layernorm_embedding:
+            tgt_data = self.tgt_embed_ln(tgt_data)
         dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data,
                                mem_valid_length)
         return dec_out
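
Note for reviewers: the sketch below shows where the new normalization sits in the embedding pipeline touched by the last two hunks. It is a minimal plain-NumPy illustration under stated assumptions, not the library's actual code: layer_norm, embed, embed_table, and pos_table are hypothetical stand-ins, the (T, N, C) layout corresponds to the 'TN' layout branch shown in the diff, and the learnable LayerNorm gain/bias are omitted.

# Minimal NumPy sketch of the embedding pipeline this patch changes.
# Names mirror the diff where possible; everything else is a simplified
# stand-in, not gluonnlp's actual implementation.
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) axis, as nn.LayerNorm does by default;
    # the learnable gain/bias parameters are omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def embed(src_tokens, embed_table, pos_table, layernorm_embedding=True):
    # Token embedding plus positional embedding, layout (T, N, C).
    src_data = embed_table[src_tokens]                        # (T, N, C)
    src_data = src_data + pos_table[:src_tokens.shape[0], None, :]
    if layernorm_embedding:
        # The patch inserts this step *after* the positional embedding is
        # added and *before* the encoder/decoder stack runs.
        src_data = layer_norm(src_data)
    return src_data

# Toy usage: vocab of 10, hidden size 4, sequence length 3, batch of 2.
rng = np.random.default_rng(0)
embed_table = rng.normal(size=(10, 4))
pos_table = rng.normal(size=(8, 4))
tokens = np.array([[1, 2], [3, 4], [5, 6]])                   # (T, N)
print(embed(tokens, embed_table, pos_table).shape)            # (3, 2, 4)

The placement is the point of the change: normalization applies to the sum of token and positional embeddings, so both encoder and decoder inputs are normalized before entering their stacks, akin to fairseq's layernorm_embedding option.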