diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py index b6671b2faf..ea66e27b23 100644 --- a/src/gluonnlp/models/transformer.py +++ b/src/gluonnlp/models/transformer.py @@ -196,7 +196,7 @@ def __init__(self, bias_initializer=bias_initializer, dtype=self._dtype) attention_layout = 'NTK' if self._layout == 'NT' else 'TNK' - self.attention_cell =\ + self.self_attention =\ MultiHeadAttentionCell( query_units=self._units, num_heads=self._num_heads,