From 6533601ddb1ef540d4a9622131f4fd8dee386131 Mon Sep 17 00:00:00 2001
From: ZheyuYe
Date: Thu, 30 Jul 2020 01:27:44 +0800
Subject: [PATCH] fix comment

---
 src/gluonnlp/models/bart.py        | 27 +++++++++++++++++----------
 src/gluonnlp/models/transformer.py |  8 ++++----
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py
index f4e70d1bc7..d55139f4ab 100644
--- a/src/gluonnlp/models/bart.py
+++ b/src/gluonnlp/models/bart.py
@@ -51,7 +51,7 @@
 
 
 @bart_cfg_reg.register()
-def fair_bart_base():
+def bart_base():
     cfg = CN()
     # Config for the bart base model
     cfg.MODEL = CN()
@@ -102,8 +102,8 @@ def fair_bart_base():
 
 
 @bart_cfg_reg.register()
-def fair_bart_large():
-    cfg = fair_bart_base()
+def bart_large():
+    cfg = bart_base()
     cfg.defrost()
     cfg.MODEL.vocab_size = 50265
     cfg.MODEL.ENCODER.units = 1024
@@ -122,14 +122,14 @@
 
 PRETRAINED_URL = {
     'fairseq_bart_base': {
-        'cfg': fair_bart_base(),
+        'cfg': bart_base(),
         'merges': 'fairseq_bart_base/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_base/gpt2-f4dedacb.vocab',
         'params': 'fairseq_bart_base/model-6dea1e11.params',
         'lowercase': False,
     },
     'fairseq_bart_large': {
-        'cfg': fair_bart_large(),
+        'cfg': bart_large(),
         'merges': 'fairseq_bart_large/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_large/gpt2-f1335494.vocab',
         'params': 'fairseq_bart_large/model-38f35552.params',
@@ -183,8 +183,9 @@ def __init__(self,
                                    in_units=self.units,
                                    flatten=False,
                                    activation=pooler_activation,
-                                   weight_initializer=weight_initializer,
-                                   bias_initializer=bias_initializer)
+                                   weight_initializer=self.weight_initializer,
+                                   bias_initializer=self.bias_initializer,
+                                   dtype=self._dtype)
 
     def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
         """
@@ -242,7 +243,12 @@ def apply_pooling(self, sequence):
         return:
             Shape (batch_size, units)
         """
-        outputs = sequence[:, 0, :]
+        if self._layout == 'NT':
+            outputs = sequence[:, 0, :]
+        elif self._layout == 'TN':
+            outputs = sequence[0, :, :]
+        else:
+            raise NotImplementedError
         if self.classifier_activation:
             return self.pooler(outputs)
         else:
@@ -259,7 +265,7 @@ def vocab_size(self):
     @classmethod
     def get_cfg(cls, key=None):
         if key is None:
-            return fair_bart_base()
+            return bart_base()
         else:
             return bart_cfg_reg.create(key)
 
@@ -300,6 +306,7 @@ def from_cfg(cls, cfg,
                    dec_recurrent=cfg.MODEL.DECODER.recurrent,
                    dec_activation=cfg.MODEL.DECODER.activation,
                    dec_pre_norm=cfg.MODEL.DECODER.pre_norm,
+                   layout=cfg.MODEL.layout,
                    embed_initializer=embed_initializer,
                    weight_initializer=weight_initializer,
                    bias_initializer=bias_initializer,
@@ -310,7 +317,7 @@ def list_pretrained_bart():
     return sorted(list(PRETRAINED_URL.keys()))
 
 
-def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
+def get_pretrained_bart(model_name: str = 'fairseq_bart_base',
                         root: str = get_model_zoo_home_dir(),
                         load_backbone: bool = True) \
         -> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index ea8940d8d3..194844ee53 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -196,7 +196,7 @@ def __init__(self,
                                        bias_initializer=bias_initializer,
                                        dtype=self._dtype)
         attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
-        self.self_attention = \
+        self.attention_cell = \
             MultiHeadAttentionCell(
                 query_units=self._units,
                 num_heads=self._num_heads,
@@ -252,7 +252,7 @@ def hybrid_forward(self, F, data, attn_mask):
         query = F.npx.reshape(query, (-2, -2, self._num_heads, -1))
         key = F.npx.reshape(key, (-2, -2, self._num_heads, -1))
         value = F.npx.reshape(value, (-2, -2, self._num_heads, -1))
-        out, [_, attn_weight] = self.self_attention(query, key, value, attn_mask)
+        out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
         out = self.attention_proj(out)
         out = self.dropout_layer(out)
         out = out + data
@@ -261,7 +261,6 @@ def hybrid_forward(self, F, data, attn_mask):
         out = self.ffn(out)
         return out, attn_weight
 
-
 @use_np
 class TransformerEncoder(HybridBlock):
     def __init__(self, num_layers=6, recurrent=False,
@@ -441,6 +440,6 @@ def __init__(self, units: int = 512,
             num_heads=num_heads,
             attention_dropout=self._attention_dropout,
             dtype=dtype,
-            layout='NTK')
+            layout=attention_layout)
         self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=True,
                                 weight_initializer=weight_initializer,
@@ -1032,6 +1032,7 @@ def __init__(self, src_vocab_size: int,
         self.enc_units = enc_units
         self.dec_units = dec_units
         self.weight_initializer = weight_initializer
+        self.bias_initializer = bias_initializer
         self._layout = layout
         assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
                                        'Only "TN" and "NT" are accepted!'.format(layout)
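
For reference, the main behaviour change in bart.py is that BartModel.apply_pooling now follows the model layout instead of assuming batch-major input: with layout 'NT' (batch, time, channels) it pools sequence[:, 0, :], while with 'TN' (time, batch, channels) it pools sequence[0, :, :]. The sketch below illustrates that indexing with plain NumPy; the standalone apply_pooling helper is a hypothetical stand-in written for illustration, not the GluonNLP method itself.

    # Minimal NumPy sketch (illustration only, not the GluonNLP implementation)
    # of layout-aware first-step pooling.
    import numpy as np

    def apply_pooling(sequence, layout):
        """Return the states of the first time step, shape (batch_size, units)."""
        if layout == 'NT':    # (batch, time, channels)
            return sequence[:, 0, :]
        elif layout == 'TN':  # (time, batch, channels)
            return sequence[0, :, :]
        else:
            raise NotImplementedError('Only "NT" and "TN" layouts are supported.')

    batch_size, seq_length, units = 4, 7, 16
    nt_states = np.random.rand(batch_size, seq_length, units)
    tn_states = nt_states.transpose(1, 0, 2)  # same data, time-major

    # Both layouts pool to the same (batch_size, units) representation.
    assert apply_pooling(nt_states, 'NT').shape == (batch_size, units)
    assert np.allclose(apply_pooling(nt_states, 'NT'),
                       apply_pooling(tn_states, 'TN'))

The layout flag used here is the same cfg.MODEL.layout that BartModel.from_cfg now forwards, so a model built from a config pools consistently whichever layout it was constructed with.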