Skip to content

Commit

Permalink
fix roberta
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 3, 2020
1 parent 4fc564c commit bd270f2
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/gluonnlp/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ def __init__(self,
)
self.encoder.hybridize()

if self.use_pooler:
# Construct pooler
self.pooler = nn.Dense(units=self.units,
in_units=self.units,
flatten=False,
activation=self.pooler_activation,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix='pooler_')

if self.use_mlm:
embed_weight = None if untie_weight else \
self.tokens_embed.collect_params('.*weight')
Expand Down Expand Up @@ -293,7 +303,7 @@ def apply_pooling(self, sequence):
Shape (batch_size, units)
"""
outputs = sequence[:, 0, :]
return outputs
return self.pooler(outputs)

@staticmethod
def get_cfg(key=None):
Expand Down

0 comments on commit bd270f2

Please sign in to comment.