diff --git a/scripts/language_model/model/qa.py b/scripts/language_model/model/qa.py index bcaa5be436..73619efff6 100644 --- a/scripts/language_model/model/qa.py +++ b/scripts/language_model/model/qa.py @@ -43,7 +43,7 @@ def __init__(self, units=768, is_eval=False, prefix=None, params=None): with self.name_scope(): self.dense_0 = nn.Dense(units, activation='tanh', flatten=False) self.dense_1 = nn.Dense(1, flatten=False) - self.layernorm = nn.LayerNorm(epsilon=1e-12, in_channels=768) + self.layernorm = nn.LayerNorm(epsilon=1e-12, in_channels=units) def __call__(self, hidden_states,