diff --git a/scripts/conversion_toolkits/convert_fairseq_bart.py b/scripts/conversion_toolkits/convert_fairseq_bart.py
index 257abf72a5..b90f0950d2 100644
--- a/scripts/conversion_toolkits/convert_fairseq_bart.py
+++ b/scripts/conversion_toolkits/convert_fairseq_bart.py
@@ -40,7 +40,7 @@ def convert_config(fairseq_cfg, vocab_size, cfg):
     cfg.MODEL.shared_embed = fairseq_cfg.share_all_embeddings
     cfg.MODEL.scale_embed = not fairseq_cfg.no_scale_embedding
     cfg.MODEL.tie_weights = fairseq_cfg.share_decoder_input_output_embed
-    cfg.MODEL.layernorm_embedding = fairseq_cfg.layernorm_embedding
+    cfg.MODEL.data_norm = fairseq_cfg.layernorm_embedding
     cfg.MODEL.pooler_activation = fairseq_cfg.pooler_activation_fn
     cfg.MODEL.layer_norm_eps = 1E-5
     cfg.MODEL.dropout = fairseq_cfg.dropout
@@ -111,26 +111,6 @@ def convert_attention(num_layers,
             gl_qkv_bias.set_data(
                 np.concatenate([fs_q_bias, fs_k_bias, fs_v_bias], axis=0))
 
-    def convert_embeddings(fairseq_prefix, gluon_prefix):
-        for k, v in [
-            ('.embed_tokens.weight', '_embed_layer.weight'),
-            ('.layernorm_embedding.weight', '_embed_ln.gamma'),
-            ('.layernorm_embedding.bias', '_embed_ln.beta'),
-        ]:
-            fs_name = fairseq_prefix + k
-            gl_name = gluon_prefix + v
-            all_keys.remove(gl_name)
-            gluon_params[gl_name].set_data(
-                fairseq_params[fs_name].cpu().numpy())
-
-        # position embed weight
-        padding_idx = fairseq_model.task.dictionary.pad_index
-        fs_pos_embed_name = fairseq_prefix + '.embed_positions.weight'
-        gl_pos_embed_name = gluon_prefix + '_pos_embed_layer._embed.weight'
-        all_keys.remove(gl_pos_embed_name)
-        gluon_params[gl_pos_embed_name].set_data(
-            fairseq_params[fs_pos_embed_name].cpu().numpy()[padding_idx + 1:, :])
-
     def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
         # convert feed forward layer in encoder
         for layer_id in range(num_layers):
@@ -150,11 +130,33 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
             gluon_params[gl_name].set_data(
                 fairseq_params[fs_name].cpu().numpy())
 
+    print('converting embedding params')
+    padding_idx = fairseq_model.task.dictionary.pad_index
+    for fs_name, gl_name in [
+        ('model.encoder.embed_tokens.weight', 'src_embed_layer.weight'),
+        ('model.encoder.embed_positions.weight', 'src_pos_embed_layer._embed.weight'),
+        ('model.encoder.layernorm_embedding.weight', 'encoder.ln_data.gamma'),
+        ('model.encoder.layernorm_embedding.bias', 'encoder.ln_data.beta'),
+        ('model.decoder.embed_tokens.weight', 'tgt_embed_layer.weight'),
+        ('model.decoder.embed_positions.weight', 'tgt_pos_embed_layer._embed.weight'),
+        ('model.decoder.layernorm_embedding.weight', 'decoder.ln_data.gamma'),
+        ('model.decoder.layernorm_embedding.bias', 'decoder.ln_data.beta'),
+        # final projection in decoder
+        ('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
+    ]:
+        all_keys.remove(gl_name)
+        if 'embed_positions' in fs_name:
+            # position embed weight
+            gluon_params[gl_name].set_data(
+                fairseq_params[fs_name].cpu().numpy()[padding_idx + 1:, :])
+        else:
+            gluon_params[gl_name].set_data(
+                fairseq_params[fs_name].cpu().numpy())
+
     print('converting encoder params')
     encoder_num_layers = gluon_cfg.MODEL.ENCODER.num_layers
     convert_attention(encoder_num_layers, 'model.encoder', 'encoder')
     convert_ffn(encoder_num_layers, 'model.encoder', 'encoder')
-    convert_embeddings('model.encoder', 'src')
     for layer_id in range(encoder_num_layers):
         for k, v in [
             ('self_attn.out_proj.weight', 'attention_proj.weight'),
@@ -170,6 +172,7 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
             gluon_params[gl_name].set_data(
                 fairseq_params[fs_name].cpu().numpy())
 
+    print('converting decoder params')
     decoder_num_layers = gluon_cfg.MODEL.DECODER.num_layers
     convert_attention(decoder_num_layers, 'model.decoder', 'decoder',
                       gluon_attn_prefix='attn_in_qkv')
@@ -201,14 +204,6 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
             gluon_params[gl_name].set_data(
                 fairseq_params[fs_name].cpu().numpy())
 
-    convert_embeddings('model.decoder', 'tgt')
-    # final projection in decoder
-    for fs_name, gl_name in [
-        ('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
-    ]:
-        all_keys.remove(gl_name)
-        gluon_params[gl_name].set_data(
-            fairseq_params[fs_name].cpu().numpy())
     assert len(all_keys) == 0, 'parameters missing from tensorflow checkpoint'
 
     # check parameters sharing if share_decoder_input_output_embed is true
diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py
index f4e70d1bc7..fe61fab37e 100644
--- a/src/gluonnlp/models/bart.py
+++ b/src/gluonnlp/models/bart.py
@@ -51,7 +51,7 @@
 
 
 @bart_cfg_reg.register()
-def fair_bart_base():
+def bart_base():
     cfg = CN()
     # Config for the bart base model
     cfg.MODEL = CN()
@@ -64,10 +64,10 @@ def fair_bart_base():
     cfg.MODEL.tie_weights = True
     cfg.MODEL.attention_dropout = 0.1
     cfg.MODEL.activation_dropout = 0.0
-    cfg.MODEL.dropout = 0.0
+    cfg.MODEL.dropout = 0.1
     cfg.MODEL.layer_norm_eps = 1E-5
     cfg.MODEL.pooler_activation = 'tanh'
-    cfg.MODEL.layernorm_embedding = True
+    cfg.MODEL.data_norm = True
     cfg.MODEL.layout = 'NT'
     cfg.MODEL.dtype = 'float32'
 
@@ -102,8 +102,8 @@ def fair_bart_base():
 
 
 @bart_cfg_reg.register()
-def fair_bart_large():
-    cfg = fair_bart_base()
+def bart_large():
+    cfg = bart_base()
     cfg.defrost()
     cfg.MODEL.vocab_size = 50265
     cfg.MODEL.ENCODER.units = 1024
@@ -122,14 +122,14 @@ def fair_bart_large():
 
 PRETRAINED_URL = {
     'fairseq_bart_base': {
-        'cfg': fair_bart_base(),
+        'cfg': bart_base(),
         'merges': 'fairseq_bart_base/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_base/gpt2-f4dedacb.vocab',
         'params': 'fairseq_bart_base/model-6dea1e11.params',
         'lowercase': False,
     },
     'fairseq_bart_large': {
-        'cfg': fair_bart_large(),
+        'cfg': bart_large(),
         'merges': 'fairseq_bart_large/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_large/gpt2-f1335494.vocab',
         'params': 'fairseq_bart_large/model-38f35552.params',
@@ -183,8 +183,9 @@ def __init__(self,
                                    in_units=self.units,
                                    flatten=False,
                                    activation=pooler_activation,
-                                   weight_initializer=weight_initializer,
-                                   bias_initializer=bias_initializer)
+                                   weight_initializer=self.weight_initializer,
+                                   bias_initializer=self.bias_initializer,
+                                   dtype=self._dtype)
 
     def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
         """
@@ -242,7 +243,12 @@ def apply_pooling(self, sequence):
         return:
             Shape (batch_size, units)
         """
-        outputs = sequence[:, 0, :]
+        if self._layout == 'NT':
+            outputs = sequence[:, 0, :]
+        elif self._layout == 'TN':
+            outputs = sequence[0, :, :]
+        else:
+            raise NotImplementedError
         if self.classifier_activation:
             return self.pooler(outputs)
         else:
@@ -259,18 +265,20 @@ def vocab_size(self):
     @classmethod
     def get_cfg(cls, key=None):
         if key is None:
-            return fair_bart_base()
+            return bart_base()
         else:
             return bart_cfg_reg.create(key)
 
     @classmethod
-    def from_cfg(cls, cfg,
+    def from_cfg(cls, cfg, dtype=None,
                  use_pooler=False,
                  classifier_activation=False):
         cfg = cls.get_cfg().clone_merge(cfg)
         embed_initializer = mx.init.create(*cfg.INITIALIZER.embed)
         weight_initializer = mx.init.create(*cfg.INITIALIZER.weight)
         bias_initializer = mx.init.create(*cfg.INITIALIZER.bias)
+        if dtype is None:
+            dtype = cfg.MODEL.dtype
         return cls(src_vocab_size=cfg.MODEL.vocab_size,
                    tgt_vocab_size=cfg.MODEL.vocab_size,
                    max_src_length=cfg.MODEL.max_src_length,
@@ -279,13 +287,13 @@ def from_cfg(cls, cfg,
                    pos_embed_type=cfg.MODEL.pos_embed_type,
                    shared_embed=cfg.MODEL.shared_embed,
                    tie_weights=cfg.MODEL.tie_weights,
+                   data_norm=cfg.MODEL.data_norm,
                    use_pooler=use_pooler,
                    attention_dropout=cfg.MODEL.attention_dropout,
                    activation_dropout=cfg.MODEL.activation_dropout,
                    dropout=cfg.MODEL.dropout,
                    pooler_activation=cfg.MODEL.pooler_activation,
                    layer_norm_eps=cfg.MODEL.layer_norm_eps,
-                   layernorm_embedding=cfg.MODEL.layernorm_embedding,
                    enc_num_layers=cfg.MODEL.ENCODER.num_layers,
                    enc_units=cfg.MODEL.ENCODER.units,
                    enc_num_heads=cfg.MODEL.ENCODER.num_heads,
@@ -300,17 +308,18 @@ def from_cfg(cls, cfg,
                    dec_recurrent=cfg.MODEL.DECODER.recurrent,
                    dec_activation=cfg.MODEL.DECODER.activation,
                    dec_pre_norm=cfg.MODEL.DECODER.pre_norm,
+                   layout=cfg.MODEL.layout,
                    embed_initializer=embed_initializer,
                    weight_initializer=weight_initializer,
                    bias_initializer=bias_initializer,
-                   dtype=cfg.MODEL.dtype)
+                   dtype=dtype)
 
 
 def list_pretrained_bart():
     return sorted(list(PRETRAINED_URL.keys()))
 
 
-def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
+def get_pretrained_bart(model_name: str = 'fairseq_bart_base',
                         root: str = get_model_zoo_home_dir(),
                         load_backbone: bool = True) \
         -> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index ea66e27b23..8e55111e19 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -196,7 +196,7 @@ def __init__(self,
                                        bias_initializer=bias_initializer,
                                        dtype=self._dtype)
         attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
-        self.self_attention =\
+        self.attention_cell = \
             MultiHeadAttentionCell(
                 query_units=self._units,
                 num_heads=self._num_heads,
@@ -252,7 +252,7 @@ def hybrid_forward(self, F, data, attn_mask):
         query = F.npx.reshape(query, (-2, -2, self._num_heads, -1))
         key = F.npx.reshape(key, (-2, -2, self._num_heads, -1))
         value = F.npx.reshape(value, (-2, -2, self._num_heads, -1))
-        out, [_, attn_weight] = self.self_attention(query, key, value, attn_mask)
+        out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
         out = self.attention_proj(out)
         out = self.dropout_layer(out)
         out = out + data
@@ -261,7 +261,6 @@ def hybrid_forward(self, F, data, attn_mask):
         out = self.ffn(out)
         return out, attn_weight
 
-
 @use_np
 class TransformerEncoder(HybridBlock):
     def __init__(self, num_layers=6, recurrent=False,
@@ -441,7 +440,7 @@ def __init__(self, units: int = 512,
                                                      num_heads=num_heads,
                                                      attention_dropout=self._attention_dropout,
                                                      dtype=dtype,
-                                                     layout='NTK')
+                                                     layout=attention_layout)
         self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=True,
                                 weight_initializer=weight_initializer, bias_initializer=bias_initializer,
                                 dtype=dtype)
@@ -484,6 +483,9 @@ def __init__(self, units: int = 512,
                                    hidden_size=hidden_size,
                                    dropout=dropout,
                                    activation_dropout=activation_dropout,
+                                   weight_initializer=weight_initializer,
+                                   bias_initializer=bias_initializer,
+                                   layer_norm_eps=layer_norm_eps,
                                    activation=activation,
                                    pre_norm=pre_norm,
                                    dtype=dtype)
@@ -673,7 +675,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma
         step_value = F.npx.reshape(step_value, (-2, -2, self._num_heads, -1))
         new_key = F.np.concatenate([prev_key, step_key], axis=time_axis)
         new_value = F.np.concatenate([prev_value, step_value], axis=time_axis)
-        out, _ = self.self_attention(step_query, new_key, new_value, None)
+        out, [_, attn_weight] = self.self_attention(step_query, new_key, new_value, None)
         out = self.proj_in(out)
         out = self.dropout_layer(out)
         out = out + data
@@ -914,7 +916,6 @@ def __init__(self, src_vocab_size: int,
                  max_tgt_length: Optional[int] = None,
                  scale_embed: bool = True,
                  pos_embed_type="sinusoidal",
-                 layernorm_embedding: bool = False,
                  shared_embed: bool = True,
                  tie_weights: bool = True,
                  activation_dropout: float = 0.0,
@@ -959,8 +960,6 @@ def __init__(self, src_vocab_size: int,
             Whether to multiply the src and dst embeddings by sqrt(units)
         pos_embed_type
             Type of the positional embedding
-        layernorm_embedding
-            Wether to layer normalize the embedding
         shared_embed
             Whether to share the embedding of the src and tgt language
         tie_weights
@@ -1027,11 +1026,11 @@ def __init__(self, src_vocab_size: int,
         self._tgt_vocab_size = tgt_vocab_size
         self.tie_weights = tie_weights
         self.pos_embed_type = pos_embed_type
-        self.layernorm_embedding = layernorm_embedding
        self.scaled_embed = scale_embed
         self.enc_units = enc_units
         self.dec_units = dec_units
         self.weight_initializer = weight_initializer
+        self.bias_initializer = bias_initializer
         self._layout = layout
         assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
                                        'Only "TN" and "NT" are accepted!'.format(layout)
@@ -1062,11 +1061,6 @@ def __init__(self, src_vocab_size: int,
                                                            max_length=max_tgt_length,
                                                            dtype=self._dtype,
                                                            method=pos_embed_type)
-        if layernorm_embedding:
-            self.src_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
-                                             in_channels=enc_units)
-            self.tgt_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
-                                             in_channels=dec_units)
         self.encoder = TransformerEncoder(num_layers=enc_num_layers,
                                           recurrent=enc_recurrent,
                                           units=enc_units,
@@ -1163,6 +1157,7 @@ def encode(self, F, src_data, src_valid_length):
             else:
                 src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer(
                     F.npx.arange_like(src_data, axis=0)), axis=1)
+
         enc_out = self.encoder(src_data, src_valid_length)
         return enc_out
 
@@ -1205,6 +1200,7 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
             else:
                 tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
                     F.npx.arange_like(tgt_data, axis=0)), axis=1)
+
         dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
         return dec_out
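
Minimal usage sketch (not part of the diff) of the refactored API introduced above. It assumes the backbone class exposed by gluonnlp.models.bart is named BartModel and that a working MXNet-2/GluonNLP numpy environment is available; everything else (bart_base/bart_large registration, get_pretrained_bart, list_pretrained_bart, the optional dtype argument to from_cfg) comes from the diff itself.

# Hypothetical sketch: load the converted fairseq BART checkpoint with the renamed configs.
import mxnet as mx
from gluonnlp.models.bart import BartModel, get_pretrained_bart, list_pretrained_bart

print(list_pretrained_bart())  # ['fairseq_bart_base', 'fairseq_bart_large'] per PRETRAINED_URL

# Default model name is now 'fairseq_bart_base' (previously 'fairseq_roberta_base').
# Returns cfg, tokenizer, and the local path to the converted parameters, per the signature above.
cfg, tokenizer, params_path = get_pretrained_bart('fairseq_bart_base')

# from_cfg accepts an optional dtype override and falls back to cfg.MODEL.dtype when None.
model = BartModel.from_cfg(cfg, dtype='float32')
model.load_parameters(params_path)
model.hybridize()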