diff --git a/src/gluonnlp/model/bert.py b/src/gluonnlp/model/bert.py
index 022d1103b2..876749f8d2 100644
--- a/src/gluonnlp/model/bert.py
+++ b/src/gluonnlp/model/bert.py
@@ -318,11 +318,12 @@ def __init__(self, *, num_layers=2, units=512, hidden_size=2048,
         self._output_attention = output_attention
         self._output_all_encodings = output_all_encodings
         self._dropout = dropout
+        self._layer_norm_eps = layer_norm_eps
 
         with self.name_scope():
             if dropout:
                 self.dropout_layer = nn.Dropout(rate=dropout)
-            self.layer_norm = nn.LayerNorm(in_channels=units, epsilon=1e-12)
+            self.layer_norm = nn.LayerNorm(in_channels=units, epsilon=self._layer_norm_eps)
             self.position_weight = self.params.get('position_weight', shape=(max_length, units),
                                                    init=weight_initializer)
             self.transformer_cells = nn.HybridSequential()
@@ -550,7 +551,7 @@ def _get_decoder(self, units, vocab_size, embed, prefix):
         decoder = nn.HybridSequential(prefix=prefix)
         decoder.add(nn.Dense(units, flatten=False))
         decoder.add(GELU())
-        decoder.add(nn.LayerNorm(in_channels=units, epsilon=1e-12))
+        decoder.add(nn.LayerNorm(in_channels=units, epsilon=self.encoder._layer_norm_eps))
         decoder.add(nn.Dense(vocab_size, flatten=False, params=embed.collect_params()))
         assert decoder[3].weight == list(embed.collect_params().values())[0], \
             'The weights of word embedding are not tied with those of decoder'
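
For reference, a minimal usage sketch (hypothetical, not part of the patch): it assumes the BERTEncoder constructor exposes the new layer_norm_eps keyword that this diff threads through to nn.LayerNorm, so an encoder can be built with a non-default epsilon (for example the 1e-5 used by RoBERTa) instead of the previously hard-coded 1e-12.

# Hypothetical usage sketch: build a BERTEncoder with a custom layer-norm
# epsilon, assuming the constructor accepts a layer_norm_eps keyword that is
# forwarded to nn.LayerNorm as in the diff above.
import mxnet as mx
from gluonnlp.model.bert import BERTEncoder

encoder = BERTEncoder(num_layers=2, units=512, hidden_size=2048,
                      layer_norm_eps=1e-5)  # e.g. RoBERTa-style epsilon
encoder.initialize()

inputs = mx.nd.random.uniform(shape=(2, 8, 512))  # (batch_size, seq_length, units)
valid_length = mx.nd.array([8, 5])                # true length of each sequence
outputs, additional_outputs = encoder(inputs, valid_length=valid_length)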