Squashed commit of the following:
commit 510d991
Author: ZheyuYe <[email protected]>
Date:   Thu Jul 30 02:33:22 2020 +0800

    test

commit 1b5fa7b
Author: ZheyuYe <[email protected]>
Date:   Thu Jul 30 01:48:01 2020 +0800

    fix comment1

commit 6533601
Author: ZheyuYe <[email protected]>
Date:   Thu Jul 30 01:27:44 2020 +0800

    fix comment

commit a8853f9
Author: ZheyuYe <[email protected]>
Date:   Thu Jul 30 01:10:06 2020 +0800

    Squashed commit of the following:

    commit 232e0b6
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 30 01:05:17 2020 +0800

        update

    commit 995e5d7
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 30 01:01:56 2020 +0800

        fix

    commit 9623240
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 30 00:52:17 2020 +0800

        fix

    commit d9c4140
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 29 23:07:10 2020 +0800

        fix transformer

    commit e49fbe1
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 29 22:18:12 2020 +0800

        update

    commit 1f75b26
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 29 22:04:08 2020 +0800

        test bart

    commit 5bab516
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 29 21:34:47 2020 +0800

        fix cfg

    commit 6c62a29
    Merge: 3366cf3 033214e
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 29 21:33:10 2020 +0800

        Merge remote-tracking branch 'upstream/numpy' into bart

    commit 033214e
    Author: Xingjian Shi <[email protected]>
    Date:   Wed Jul 29 00:36:57 2020 -0700

        [Numpy] Fix SQuAD + Fix GLUE downloading (dmlc#1280)

        * Update run_squad.py

        * Update run_squad.py

        * Update prepare_glue.py

    commit 3c87457
    Author: Xingjian Shi <[email protected]>
    Date:   Tue Jul 28 18:03:21 2020 -0700

        Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTA, XLMR (dmlc#1258)

        * Add layout support

        * fix test

        * Update transformer.py

        * Update transformer.py

        * Update README.md

        * try to add set_layout

        * update test case

        * fix

        * update

        * update

        * update

        * Update bert.py

        * fix bug

        * update

        * Update test_models_bert.py

        * Update tokenizers.py

        * add compute layout

        * Update xlmr.py

        * Update test_models_bert.py

        * revise test cases

        * Update layers.py

        * move jieba to try import

        * fix

        * Update transformer.py

        * fix

        * Update bert.py

        * Update setup.py

        * Update test_models_bert.py

        * Update test_models_bert.py

        * fix

        * update

        * Revise

        * Update electra.py

        * Update electra.py

        * Update test_models_electra.py

        * fix

        * fix bug

        * Update test_models_albert.py

        * add more testcases

        * fix

        * Update albert.py

        * Update albert.py

        * fix bug

        * fix testcase

        * Update test_models_electra.py

        * Update bert.py

        * update

        * Update test_models_electra.py

        * Update mobilebert.py

        * Update mobilebert.py

        * update mobilebert

        * Update test_models_mobilebert.py

        * Update mobilebert.py

        * fix bug

        * Update roberta.py

        * fix roberta

        * update

        * update

        * fix import

        * fix bug

        * update

        * reduce test workloads

        * address comment

        * address comment

    commit 4d43f82
    Author: Sheng Zha <[email protected]>
    Date:   Mon Jul 27 20:21:00 2020 -0700

        add subversion/wget to docker, add readme (dmlc#1279)

    commit d76897b
    Author: phile <[email protected]>
    Date:   Tue Jul 28 10:10:13 2020 +0800

        Add embedding related methods in numpy version (dmlc#1263)

        * A draft for embedding

        * fix embed_loader

        * add hyperbolic space and some updates

        * revise evaluation

        * fix

        * simple fixes

        * move l2norm to op.py

        * new features

        * fix

        * update

        * add tests, update

        * newline
zheyuye committed Jul 29, 2020
1 parent d9c4140 commit a53b9f4
Showing 3 changed files with 59 additions and 59 deletions.
55 changes: 25 additions & 30 deletions scripts/conversion_toolkits/convert_fairseq_bart.py
@@ -40,7 +40,7 @@ def convert_config(fairseq_cfg, vocab_size, cfg):
cfg.MODEL.shared_embed = fairseq_cfg.share_all_embeddings
cfg.MODEL.scale_embed = not fairseq_cfg.no_scale_embedding
cfg.MODEL.tie_weights = fairseq_cfg.share_decoder_input_output_embed
cfg.MODEL.layernorm_embedding = fairseq_cfg.layernorm_embedding
cfg.MODEL.data_norm = fairseq_cfg.layernorm_embedding
cfg.MODEL.pooler_activation = fairseq_cfg.pooler_activation_fn
cfg.MODEL.layer_norm_eps = 1E-5
cfg.MODEL.dropout = fairseq_cfg.dropout
@@ -111,26 +111,6 @@ def convert_attention(num_layers,
gl_qkv_bias.set_data(
np.concatenate([fs_q_bias, fs_k_bias, fs_v_bias], axis=0))

def convert_embeddings(fairseq_prefix, gluon_prefix):
for k, v in [
('.embed_tokens.weight', '_embed_layer.weight'),
('.layernorm_embedding.weight', '_embed_ln.gamma'),
('.layernorm_embedding.bias', '_embed_ln.beta'),
]:
fs_name = fairseq_prefix + k
gl_name = gluon_prefix + v
all_keys.remove(gl_name)
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

# position embed weight
padding_idx = fairseq_model.task.dictionary.pad_index
fs_pos_embed_name = fairseq_prefix + '.embed_positions.weight'
gl_pos_embed_name = gluon_prefix + '_pos_embed_layer._embed.weight'
all_keys.remove(gl_pos_embed_name)
gluon_params[gl_pos_embed_name].set_data(
fairseq_params[fs_pos_embed_name].cpu().numpy()[padding_idx + 1:, :])

def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
# convert feed forward layer in encoder
for layer_id in range(num_layers):
@@ -150,11 +130,33 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

print('converting embedding params')
padding_idx = fairseq_model.task.dictionary.pad_index
for fs_name, gl_name in [
('model.encoder.embed_tokens.weight', 'src_embed_layer.weight'),
('model.encoder.embed_positions.weight', 'src_pos_embed_layer._embed.weight'),
('model.encoder.layernorm_embedding.weight', 'encoder.ln_data.gamma'),
('model.encoder.layernorm_embedding.bias', 'encoder.ln_data.beta'),
('model.decoder.embed_tokens.weight', 'tgt_embed_layer.weight'),
('model.decoder.embed_positions.weight', 'tgt_pos_embed_layer._embed.weight'),
('model.decoder.layernorm_embedding.weight', 'decoder.ln_data.gamma'),
('model.decoder.layernorm_embedding.bias', 'decoder.ln_data.beta'),
# final projection in decoder
('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
]:
all_keys.remove(gl_name)
if 'embed_positions' in fs_name:
# position embed weight
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy()[padding_idx + 1:, :])
else:
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

print('converting encoder params')
encoder_num_layers = gluon_cfg.MODEL.ENCODER.num_layers
convert_attention(encoder_num_layers, 'model.encoder', 'encoder')
convert_ffn(encoder_num_layers, 'model.encoder', 'encoder')
convert_embeddings('model.encoder', 'src')
for layer_id in range(encoder_num_layers):
for k, v in [
('self_attn.out_proj.weight', 'attention_proj.weight'),
@@ -170,6 +172,7 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

print('converting decoder params')
decoder_num_layers = gluon_cfg.MODEL.DECODER.num_layers
convert_attention(decoder_num_layers, 'model.decoder', 'decoder',
gluon_attn_prefix='attn_in_qkv')
@@ -201,14 +204,6 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

convert_embeddings('model.decoder', 'tgt')
# final projection in decoder
for fs_name, gl_name in [
('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
]:
all_keys.remove(gl_name)
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())
assert len(all_keys) == 0, 'parameters missing from fairseq checkpoint'

# check parameters sharing if share_decoder_input_output_embed is true
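A note on the [padding_idx + 1:, :] slice in the embedding loop above: fairseq's learned positional-embedding table reserves its first padding_idx + 1 rows (the padding row plus the offset fairseq adds when computing positions), whereas the Gluon position-embedding weight starts at position 0, so the converter drops the reserved rows. A minimal standalone sketch of that slicing with illustrative shapes (plain NumPy, not the conversion script itself):

    import numpy as np

    padding_idx = 1            # fairseq's default pad index
    max_positions = 1024       # illustrative table size
    units = 8

    # fairseq allocates padding_idx + 1 extra rows ahead of the real positions.
    fs_pos_embed = np.random.randn(max_positions + padding_idx + 1, units)

    # The Gluon position-embedding weight indexes positions from 0, so drop the reserved rows.
    gl_pos_embed = fs_pos_embed[padding_idx + 1:, :]
    assert gl_pos_embed.shape == (max_positions, units)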
39 changes: 24 additions & 15 deletions src/gluonnlp/models/bart.py
@@ -51,7 +51,7 @@


@bart_cfg_reg.register()
def fair_bart_base():
def bart_base():
cfg = CN()
# Config for the bart base model
cfg.MODEL = CN()
@@ -64,10 +64,10 @@ def fair_bart_base():
cfg.MODEL.tie_weights = True
cfg.MODEL.attention_dropout = 0.1
cfg.MODEL.activation_dropout = 0.0
cfg.MODEL.dropout = 0.0
cfg.MODEL.dropout = 0.1
cfg.MODEL.layer_norm_eps = 1E-5
cfg.MODEL.pooler_activation = 'tanh'
cfg.MODEL.layernorm_embedding = True
cfg.MODEL.data_norm = True
cfg.MODEL.layout = 'NT'
cfg.MODEL.dtype = 'float32'

@@ -102,8 +102,8 @@ def fair_bart_base():


@bart_cfg_reg.register()
def fair_bart_large():
cfg = fair_bart_base()
def bart_large():
cfg = bart_base()
cfg.defrost()
cfg.MODEL.vocab_size = 50265
cfg.MODEL.ENCODER.units = 1024
@@ -122,14 +122,14 @@ def fair_bart_large():

PRETRAINED_URL = {
'fairseq_bart_base': {
'cfg': fair_bart_base(),
'cfg': bart_base(),
'merges': 'fairseq_bart_base/gpt2-396d4d8e.merges',
'vocab': 'fairseq_bart_base/gpt2-f4dedacb.vocab',
'params': 'fairseq_bart_base/model-6dea1e11.params',
'lowercase': False,
},
'fairseq_bart_large': {
'cfg': fair_bart_large(),
'cfg': bart_large(),
'merges': 'fairseq_bart_large/gpt2-396d4d8e.merges',
'vocab': 'fairseq_bart_large/gpt2-f1335494.vocab',
'params': 'fairseq_bart_large/model-38f35552.params',
@@ -183,8 +183,9 @@ def __init__(self,
in_units=self.units,
flatten=False,
activation=pooler_activation,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
weight_initializer=self.weight_initializer,
bias_initializer=self.bias_initializer,
dtype=self._dtype)

def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
"""
@@ -242,7 +243,12 @@ def apply_pooling(self, sequence):
return:
Shape (batch_size, units)
"""
outputs = sequence[:, 0, :]
if self._layout == 'NT':
outputs = sequence[:, 0, :]
elif self._layout == 'TN':
outputs = sequence[0, :, :]
else:
raise NotImplementedError
if self.classifier_activation:
return self.pooler(outputs)
else:
@@ -259,18 +265,20 @@ def vocab_size(self):
@classmethod
def get_cfg(cls, key=None):
if key is None:
return fair_bart_base()
return bart_base()
else:
return bart_cfg_reg.create(key)

@classmethod
def from_cfg(cls, cfg,
def from_cfg(cls, cfg, dtype=None,
use_pooler=False,
classifier_activation=False):
cfg = cls.get_cfg().clone_merge(cfg)
embed_initializer = mx.init.create(*cfg.INITIALIZER.embed)
weight_initializer = mx.init.create(*cfg.INITIALIZER.weight)
bias_initializer = mx.init.create(*cfg.INITIALIZER.bias)
if dtype is None:
dtype = cfg.MODEL.dtype
return cls(src_vocab_size=cfg.MODEL.vocab_size,
tgt_vocab_size=cfg.MODEL.vocab_size,
max_src_length=cfg.MODEL.max_src_length,
@@ -279,13 +287,13 @@ def from_cfg(cls, cfg,
pos_embed_type=cfg.MODEL.pos_embed_type,
shared_embed=cfg.MODEL.shared_embed,
tie_weights=cfg.MODEL.tie_weights,
data_norm=cfg.MODEL.data_norm,
use_pooler=use_pooler,
attention_dropout=cfg.MODEL.attention_dropout,
activation_dropout=cfg.MODEL.activation_dropout,
dropout=cfg.MODEL.dropout,
pooler_activation=cfg.MODEL.pooler_activation,
layer_norm_eps=cfg.MODEL.layer_norm_eps,
layernorm_embedding=cfg.MODEL.layernorm_embedding,
enc_num_layers=cfg.MODEL.ENCODER.num_layers,
enc_units=cfg.MODEL.ENCODER.units,
enc_num_heads=cfg.MODEL.ENCODER.num_heads,
@@ -300,17 +308,18 @@ def from_cfg(cls, cfg,
dec_recurrent=cfg.MODEL.DECODER.recurrent,
dec_activation=cfg.MODEL.DECODER.activation,
dec_pre_norm=cfg.MODEL.DECODER.pre_norm,
layout=cfg.MODEL.layout,
embed_initializer=embed_initializer,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dtype=cfg.MODEL.dtype)
dtype=dtype)


def list_pretrained_bart():
return sorted(list(PRETRAINED_URL.keys()))


def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
def get_pretrained_bart(model_name: str = 'fairseq_bart_base',
root: str = get_model_zoo_home_dir(),
load_backbone: bool = True) \
-> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
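The apply_pooling change above is the layout-aware part of the BART model update: with the 'NT' layout the encoded sequence is (batch, time, units) and the pooled vector is the first time step of each sample, while with 'TN' it is (time, batch, units) and the same tokens sit on the leading axis. A minimal sketch of the equivalence, using plain NumPy rather than the GluonNLP blocks:

    import numpy as np

    batch_size, seq_len, units = 2, 4, 3
    sequence_nt = np.arange(batch_size * seq_len * units).reshape(batch_size, seq_len, units)

    pooled_nt = sequence_nt[:, 0, :]              # 'NT' layout: (batch, time, units)

    sequence_tn = sequence_nt.transpose(1, 0, 2)  # same data in 'TN' layout: (time, batch, units)
    pooled_tn = sequence_tn[0, :, :]

    assert np.array_equal(pooled_nt, pooled_tn)   # both select the first token of every sample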
24 changes: 10 additions & 14 deletions src/gluonnlp/models/transformer.py
@@ -196,7 +196,7 @@ def __init__(self,
bias_initializer=bias_initializer,
dtype=self._dtype)
attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
self.self_attention =\
self.attention_cell = \
MultiHeadAttentionCell(
query_units=self._units,
num_heads=self._num_heads,
@@ -252,7 +252,7 @@ def hybrid_forward(self, F, data, attn_mask):
query = F.npx.reshape(query, (-2, -2, self._num_heads, -1))
key = F.npx.reshape(key, (-2, -2, self._num_heads, -1))
value = F.npx.reshape(value, (-2, -2, self._num_heads, -1))
out, [_, attn_weight] = self.self_attention(query, key, value, attn_mask)
out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
out = self.attention_proj(out)
out = self.dropout_layer(out)
out = out + data
@@ -261,7 +261,6 @@ def hybrid_forward(self, F, data, attn_mask):
out = self.ffn(out)
return out, attn_weight


@use_np
class TransformerEncoder(HybridBlock):
def __init__(self, num_layers=6, recurrent=False,
@@ -441,7 +440,7 @@ def __init__(self, units: int = 512,
num_heads=num_heads,
attention_dropout=self._attention_dropout,
dtype=dtype,
layout='NTK')
layout=attention_layout)
self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=True,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
@@ -484,6 +483,9 @@ def __init__(self, units: int = 512,
hidden_size=hidden_size,
dropout=dropout,
activation_dropout=activation_dropout,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
layer_norm_eps=layer_norm_eps,
activation=activation,
pre_norm=pre_norm,
dtype=dtype)
@@ -673,7 +675,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma
step_value = F.npx.reshape(step_value, (-2, -2, self._num_heads, -1))
new_key = F.np.concatenate([prev_key, step_key], axis=time_axis)
new_value = F.np.concatenate([prev_value, step_value], axis=time_axis)
out, _ = self.self_attention(step_query, new_key, new_value, None)
out, [_, attn_weight] = self.self_attention(step_query, new_key, new_value, None)
out = self.proj_in(out)
out = self.dropout_layer(out)
out = out + data
@@ -914,7 +916,6 @@ def __init__(self, src_vocab_size: int,
max_tgt_length: Optional[int] = None,
scale_embed: bool = True,
pos_embed_type="sinusoidal",
layernorm_embedding: bool = False,
shared_embed: bool = True,
tie_weights: bool = True,
activation_dropout: float = 0.0,
@@ -959,8 +960,6 @@ def __init__(self, src_vocab_size: int,
Whether to multiply the src and dst embeddings by sqrt(units)
pos_embed_type
Type of the positional embedding
layernorm_embedding
Wether to layer normalize the embedding
shared_embed
Whether to share the embedding of the src and tgt language
tie_weights
@@ -1027,11 +1026,11 @@ def __init__(self, src_vocab_size: int,
self._tgt_vocab_size = tgt_vocab_size
self.tie_weights = tie_weights
self.pos_embed_type = pos_embed_type
self.layernorm_embedding = layernorm_embedding
self.scaled_embed = scale_embed
self.enc_units = enc_units
self.dec_units = dec_units
self.weight_initializer = weight_initializer
self.bias_initializer = bias_initializer
self._layout = layout
assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
'Only "TN" and "NT" are accepted!'.format(layout)
@@ -1062,11 +1061,6 @@ def __init__(self, src_vocab_size: int,
max_length=max_tgt_length,
dtype=self._dtype,
method=pos_embed_type)
if layernorm_embedding:
self.src_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
in_channels=enc_units)
self.tgt_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
in_channels=dec_units)
self.encoder = TransformerEncoder(num_layers=enc_num_layers,
recurrent=enc_recurrent,
units=enc_units,
Expand Down Expand Up @@ -1163,6 +1157,7 @@ def encode(self, F, src_data, src_valid_length):
else:
src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer(
F.npx.arange_like(src_data, axis=0)), axis=1)

enc_out = self.encoder(src_data, src_valid_length)
return enc_out

@@ -1205,6 +1200,7 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
else:
tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
F.npx.arange_like(tgt_data, axis=0)), axis=1)

dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
return dec_out

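The incremental_decode hunk above relies on the usual key/value cache: the key and value projected for the current step are appended to the cached ones along the time axis before attention runs, so each decoding call only projects one new token. A minimal sketch of that caching pattern with illustrative shapes (plain NumPy, 'NT' layout, not the actual MultiHeadAttentionCell API):

    import numpy as np

    batch, num_heads, head_units = 2, 4, 16
    time_axis = 1                                            # 'NT' layout: (batch, time, heads, units)

    prev_key = np.zeros((batch, 3, num_heads, head_units))   # cache from three decoded steps
    step_key = np.ones((batch, 1, num_heads, head_units))    # projection of the current step only

    new_key = np.concatenate([prev_key, step_key], axis=time_axis)
    assert new_key.shape[1] == 4                             # the cache grows by one step per call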
