diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 9f1b9e57d7..ea8940d8d3 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -1402,6 +1402,9 @@ def hybrid_forward(self, F, step_data, states):
             step_data = step_data * np.sqrt(self.model.dec_units)
         if self.model.pos_embed_type is not None:
             step_data = step_data + self.model.tgt_pos_embed_layer(position)
+        # Normalize the step embedding with the wrapped model's LayerNorm when enabled.
+        if self.model.layernorm_embedding:
+            step_data = self.model.tgt_embed_ln(step_data)
         out, new_states =\
             self.model.decoder.incremental_decode(F, step_data, dec_states,
                                                   mem_data, mem_valid_length)