This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit 6533601
fix comment
zheyuye committed Jul 29, 2020
1 parent a8853f9
Showing 2 changed files with 21 additions and 13 deletions.
src/gluonnlp/models/bart.py: 27 changes (17 additions & 10 deletions)
@@ -51,7 +51,7 @@


 @bart_cfg_reg.register()
-def fair_bart_base():
+def bart_base():
     cfg = CN()
     # Config for the bart base model
     cfg.MODEL = CN()
@@ -102,8 +102,8 @@ def fair_bart_base():


 @bart_cfg_reg.register()
-def fair_bart_large():
-    cfg = fair_bart_base()
+def bart_large():
+    cfg = bart_base()
     cfg.defrost()
     cfg.MODEL.vocab_size = 50265
     cfg.MODEL.ENCODER.units = 1024
@@ -122,14 +122,14 @@ def fair_bart_large():

 PRETRAINED_URL = {
     'fairseq_bart_base': {
-        'cfg': fair_bart_base(),
+        'cfg': bart_base(),
         'merges': 'fairseq_bart_base/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_base/gpt2-f4dedacb.vocab',
         'params': 'fairseq_bart_base/model-6dea1e11.params',
         'lowercase': False,
     },
     'fairseq_bart_large': {
-        'cfg': fair_bart_large(),
+        'cfg': bart_large(),
         'merges': 'fairseq_bart_large/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_large/gpt2-f1335494.vocab',
         'params': 'fairseq_bart_large/model-38f35552.params',
@@ -183,8 +183,9 @@ def __init__(self,
                 in_units=self.units,
                 flatten=False,
                 activation=pooler_activation,
-                weight_initializer=weight_initializer,
-                bias_initializer=bias_initializer)
+                weight_initializer=self.weight_initializer,
+                bias_initializer=self.bias_initializer,
+                dtype=self._dtype)

     def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
         """
@@ -242,7 +243,12 @@ def apply_pooling(self, sequence):
         return:
             Shape (batch_size, units)
         """
-        outputs = sequence[:, 0, :]
+        if self._layout == 'NT':
+            outputs = sequence[:, 0, :]
+        elif self._layout == 'TN':
+            outputs = sequence[0, :, :]
+        else:
+            raise NotImplementedError
         if self.classifier_activation:
             return self.pooler(outputs)
         else:
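
The new branch simply picks out the start-of-sequence position under either layout. A minimal NumPy sketch of the two indexing patterns (the shapes are illustrative and not taken from the diff):

    import numpy as np

    # 'NT' layout: (batch_size, seq_length, units); 'TN' layout: (seq_length, batch_size, units).
    batch_size, seq_length, units = 2, 5, 4
    seq_nt = np.random.rand(batch_size, seq_length, units)
    seq_tn = seq_nt.transpose(1, 0, 2)

    pooled_nt = seq_nt[:, 0, :]   # first token of every sample under 'NT'
    pooled_tn = seq_tn[0, :, :]   # first token of every sample under 'TN'
    assert pooled_nt.shape == pooled_tn.shape == (batch_size, units)
    assert np.allclose(pooled_nt, pooled_tn)
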
@@ -259,7 +265,7 @@ def vocab_size(self):
     @classmethod
     def get_cfg(cls, key=None):
         if key is None:
-            return fair_bart_base()
+            return bart_base()
         else:
             return bart_cfg_reg.create(key)

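For reference, a rough sketch of how the renamed config constructors are reached after this change; the import path and the registry key 'bart_base' are assumptions inferred from the hunks above rather than shown in the diff:

    from gluonnlp.models.bart import bart_base, bart_cfg_reg

    cfg = bart_base()                                # call the renamed constructor directly
    cfg_from_reg = bart_cfg_reg.create('bart_base')  # or resolve it through the registry (assumed key)
    print(cfg)                                       # yacs CfgNode with the MODEL.* fields set above
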
@@ -300,6 +306,7 @@ def from_cfg(cls, cfg,
            dec_recurrent=cfg.MODEL.DECODER.recurrent,
            dec_activation=cfg.MODEL.DECODER.activation,
            dec_pre_norm=cfg.MODEL.DECODER.pre_norm,
+           layout=cfg.MODEL.layout,
            embed_initializer=embed_initializer,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
@@ -310,7 +317,7 @@ def list_pretrained_bart():
     return sorted(list(PRETRAINED_URL.keys()))


-def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
+def get_pretrained_bart(model_name: str = 'fairseq_bart_base',
                         root: str = get_model_zoo_home_dir(),
                         load_backbone: bool = True) \
         -> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
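
With the corrected default, fetching the base checkpoint looks roughly like the sketch below. The three return values follow the Tuple[CN, HuggingFaceByteBPETokenizer, str] annotation above; the download location depends on get_model_zoo_home_dir() and is environment-specific:

    from gluonnlp.models.bart import get_pretrained_bart, list_pretrained_bart

    print(list_pretrained_bart())   # expected to list 'fairseq_bart_base' and 'fairseq_bart_large'
    cfg, tokenizer, local_params_path = get_pretrained_bart('fairseq_bart_base')
    print(local_params_path)        # local path of the downloaded .params file
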
src/gluonnlp/models/transformer.py: 7 changes (4 additions & 3 deletions)
@@ -196,7 +196,7 @@ def __init__(self,
             bias_initializer=bias_initializer,
             dtype=self._dtype)
         attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
-        self.self_attention = \
+        self.attention_cell = \
             MultiHeadAttentionCell(
                 query_units=self._units,
                 num_heads=self._num_heads,
@@ -252,7 +252,7 @@ def hybrid_forward(self, F, data, attn_mask):
         query = F.npx.reshape(query, (-2, -2, self._num_heads, -1))
         key = F.npx.reshape(key, (-2, -2, self._num_heads, -1))
         value = F.npx.reshape(value, (-2, -2, self._num_heads, -1))
-        out, [_, attn_weight] = self.self_attention(query, key, value, attn_mask)
+        out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
         out = self.attention_proj(out)
         out = self.dropout_layer(out)
         out = out + data
@@ -261,7 +261,6 @@ def hybrid_forward(self, F, data, attn_mask):
         out = self.ffn(out)
         return out, attn_weight

-
 @use_np
 class TransformerEncoder(HybridBlock):
     def __init__(self, num_layers=6, recurrent=False,
@@ -441,6 +440,7 @@ def __init__(self, units: int = 512,
                 num_heads=num_heads,
                 attention_dropout=self._attention_dropout,
                 dtype=dtype,
-                layout='NTK')
+                layout=attention_layout)
         self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=True,
                                 weight_initializer=weight_initializer,
@@ -1032,6 +1032,7 @@ def __init__(self, src_vocab_size: int,
         self.enc_units = enc_units
         self.dec_units = dec_units
         self.weight_initializer = weight_initializer
+        self.bias_initializer = bias_initializer
         self._layout = layout
         assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
                                        'Only "TN" and "NT" are accepted!'.format(layout)
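
For context on the layout plumbing touched in this file, a standalone illustration of the sequence-layout to attention-layout mapping; it mirrors the attention_layout = 'NTK' if self._layout == 'NT' else 'TNK' line in the first hunk and is not itself library code:

    def to_attention_layout(layout: str) -> str:
        # 'NT' (batch-major) sequences pair with the 'NTK' attention layout,
        # 'TN' (time-major) sequences with 'TNK', matching the encoder/decoder layers above.
        assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
                                       'Only "TN" and "NT" are accepted!'.format(layout)
        return 'NTK' if layout == 'NT' else 'TNK'

    print(to_attention_layout('NT'))  # NTK
    print(to_attention_layout('TN'))  # TNK
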
