
Commit

test bart
zheyuye committed Jul 29, 2020
1 parent 5bab516 commit 1f75b26
Showing 4 changed files with 83 additions and 74 deletions.
45 changes: 33 additions & 12 deletions src/gluonnlp/models/bart.py
@@ -68,6 +68,7 @@ def fair_bart_base():
cfg.MODEL.layer_norm_eps = 1E-5
cfg.MODEL.pooler_activation = 'tanh'
cfg.MODEL.layernorm_embedding = True
cfg.MODEL.layout = 'NT'
cfg.MODEL.dtype = 'float32'

# Parameters for the encoder
@@ -191,27 +192,43 @@ def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
Parameters
----------
F
src_data :
Shape (batch_size, src_length)
src_valid_length :
src_data
- layout = 'NT'
Shape (batch_size, src_length)
- layout = 'TN'
Shape (src_length, batch_size)
src_valid_length
Shape (batch_size,)
tgt_data :
Shape (batch_size, tgt_length)
tgt_valid_length :
tgt_data
- layout = 'NT'
Shape (batch_size, tgt_length)
- layout = 'TN'
Shape (tgt_length, batch_size)
tgt_valid_length
Shape (batch_size,)
Returns
-------
out :
Shape (batch_size, tgt_length, tgt_vocab_size)
(contextual_embedding)
- layout = 'NT'
Shape (batch_size, tgt_length, units)
- layout = 'TN'
Shape (tgt_length, batch_size, units)
(pooled_output)
This is optional. Shape (batch_size, units)
(dec_out)
- layout = 'NT'
Shape (batch_size, tgt_length, tgt_vocab_size)
- layout = 'TN'
Shape (tgt_length, batch_size, tgt_vocab_size)
"""
enc_out = self.encode(F, src_data, src_valid_length)
dec_out = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length)
contextual_embedding = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length)
if self.use_pooler:
pooled_out = self.apply_pooling(dec_out)
return dec_out, pooled_out
pooled_output = self.apply_pooling(contextual_embedding)
return contextual_embedding, pooled_output
else:
dec_out = self.tgt_final_layer(dec_out)
dec_out = self.tgt_final_layer(contextual_embedding)
return dec_out

def apply_pooling(self, sequence):
@@ -231,6 +248,10 @@ def apply_pooling(self, sequence):
else:
return outputs

@property
def layout(self) -> str:
return self._layout

@property
def vocab_size(self):
return self._vocab_size
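
For reference, a minimal sketch of the interface documented in the hunks above: with use_pooler=True the forward pass returns (contextual_embedding, pooled_output), while the plain model applies tgt_final_layer and returns vocabulary logits. The registry key 'fair_bart_base' is assumed from the config function name in the first hunk; the toy vocabulary size and random inputs are illustrative only, not taken from the tests.

import mxnet as mx
import numpy as np
from gluonnlp.models.bart import BartModel

mx.npx.set_np()

# 'fair_bart_base' is assumed to be the registered key of the config function above.
cfg = BartModel.get_cfg('fair_bart_base')
cfg.defrost()
cfg.MODEL.vocab_size = 64      # toy vocabulary for illustration
cfg.MODEL.layout = 'NT'        # batch-major layout, as set in fair_bart_base()
cfg.freeze()

model = BartModel.from_cfg(cfg, use_pooler=True, classifier_activation=False)
model.initialize()

batch_size, src_length, tgt_length = 2, 8, 6
src_data = mx.np.array(np.random.randint(2, 64, (batch_size, src_length)), dtype=np.int32)
tgt_data = mx.np.array(np.random.randint(2, 64, (batch_size, tgt_length)), dtype=np.int32)
src_valid_length = mx.np.array([8, 5], dtype=np.int32)
tgt_valid_length = mx.np.array([6, 4], dtype=np.int32)

# With use_pooler=True the model returns the raw decoder states and a pooled vector;
# without the pooler it would instead return (batch_size, tgt_length, tgt_vocab_size) logits.
contextual_embedding, pooled_output = model(src_data, src_valid_length, tgt_data, tgt_valid_length)
print(contextual_embedding.shape)  # (batch_size, tgt_length, units) for layout 'NT'
print(pooled_output.shape)         # (batch_size, units)

This also reflects the design choice in the transformer.py hunk below: decode_seq no longer applies the final vocabulary projection, so the pooler can consume the decoder's contextual embedding directly and the projection is applied only on the logits path.
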
1 change: 0 additions & 1 deletion src/gluonnlp/models/transformer.py
@@ -1206,7 +1206,6 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
F.npx.arange_like(tgt_data, axis=0)), axis=1)
dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
dec_out = self.tgt_final_layer(dec_out)
return dec_out

def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
57 changes: 23 additions & 34 deletions tests/test_models_bart.py
@@ -3,8 +3,8 @@
import mxnet as mx
import tempfile
from gluonnlp.models.bart import BartModel, \
list_pretrained_bart, get_pretrained_bart
from gluonnlp.loss import LabelSmoothCrossEntropyLoss
list_pretrained_bart, get_pretrained_bart, bart_cfg_reg
from gluonnlp.utils.testing import verify_nmt_model

mx.npx.set_np()

@@ -22,42 +22,31 @@ def test_bart(model_name):
cfg, tokenizer, params_path, _ =\
get_pretrained_bart(model_name, load_backbone=True, root=root)
assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
# test backbone
# test standard bart encoder and decoder
bart_model = BartModel.from_cfg(cfg)
bart_model.load_parameters(params_path)
# test mlm model
# test bart encoder and decoder with pooler
bart_model_with_pooler = BartModel.from_cfg(
cfg, use_pooler=True, classifier_activation=False)
bart_model_with_pooler.load_parameters(params_path)

# test forward
batch_size = 3
seq_length = 32
vocab_size = len(tokenizer.vocab)
input_ids = mx.np.array(
np.random.randint(
2,
vocab_size,
(batch_size, seq_length)
),
dtype=np.int32
)
valid_length = mx.np.array(
np.random.randint(
seq_length // 2,
seq_length,
(batch_size,)
),
dtype=np.int32
)
contextual_embeddings, pooled_out = bart_model_with_pooler(
input_ids, valid_length, input_ids, valid_length)
mx.npx.waitall()
# test backward
label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
with mx.autograd.record():
contextual_embeddings, pooled_out = bart_model_with_pooler(
input_ids, valid_length, input_ids, valid_length)
loss = label_smooth_loss(contextual_embeddings, input_ids)
loss.backward()

def test_bart_cfg_registry():
assert len(bart_cfg_reg.list_keys()) > 0

@pytest.mark.parametrize('cfg_key', bart_cfg_reg.list_keys())
def test_bart_cfg(cfg_key):
cfg = BartModel.get_cfg(cfg_key)
cfg.defrost()
cfg.MODEL.vocab_size = 32
cfg.freeze()
model = BartModel.from_cfg(cfg)
model.initialize()
model.hybridize()
cfg.defrost()
cfg.MODEL.layout = 'TN'
cfg.freeze()
model_tn = BartModel.from_cfg(cfg)
model_tn.share_parameters(model.collect_params())
model_tn.hybridize()
mx.npx.waitall()
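
The new layout test above only builds and hybridizes the 'NT' and 'TN' variants. Below is a hedged sketch of how the same parameter-sharing pattern can be extended to check numerical equivalence between the two layouts; it is an illustration of the technique rather than code from this commit, and the tolerances and toy sizes are assumptions.

import mxnet as mx
import numpy as np
from numpy.testing import assert_allclose
from gluonnlp.models.bart import BartModel, bart_cfg_reg

mx.npx.set_np()

cfg = BartModel.get_cfg(bart_cfg_reg.list_keys()[0])
cfg.defrost()
cfg.MODEL.vocab_size = 32
cfg.MODEL.layout = 'NT'
cfg.freeze()
model_nt = BartModel.from_cfg(cfg, use_pooler=True, classifier_activation=False)
model_nt.initialize()

cfg.defrost()
cfg.MODEL.layout = 'TN'
cfg.freeze()
model_tn = BartModel.from_cfg(cfg, use_pooler=True, classifier_activation=False)
# Reuse the 'NT' parameters so both layouts compute the same function.
model_tn.share_parameters(model_nt.collect_params())

batch_size, src_length, tgt_length = 2, 6, 5
src_data = mx.np.array(np.random.randint(2, 32, (batch_size, src_length)), dtype=np.int32)
tgt_data = mx.np.array(np.random.randint(2, 32, (batch_size, tgt_length)), dtype=np.int32)
src_valid_length = mx.np.array([6, 4], dtype=np.int32)
tgt_valid_length = mx.np.array([5, 3], dtype=np.int32)

# 'NT' consumes (batch, time) inputs, 'TN' consumes (time, batch); with shared
# parameters the contextual embeddings should agree up to a transpose.
emb_nt, pooled_nt = model_nt(src_data, src_valid_length, tgt_data, tgt_valid_length)
emb_tn, pooled_tn = model_tn(src_data.T, src_valid_length, tgt_data.T, tgt_valid_length)
assert_allclose(emb_nt.asnumpy(), np.swapaxes(emb_tn.asnumpy(), 0, 1), 1E-3, 1E-3)
assert_allclose(pooled_nt.asnumpy(), pooled_tn.asnumpy(), 1E-3, 1E-3)
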
54 changes: 27 additions & 27 deletions tests/test_models_transformer.py
@@ -3,8 +3,8 @@
from numpy.testing import assert_allclose
from gluonnlp.models.transformer import\
TransformerEncoder, TransformerDecoder, \
TransformerNMTModel, TransformerNMTInference,\
transformer_nmt_cfg_reg
TransformerModel, TransformerNMTInference,\
transformer_cfg_reg
from gluonnlp.attention_cell import gen_mem_attn_mask, gen_self_attn_mask
from gluonnlp.utils.testing import verify_nmt_model, verify_nmt_inference
mx.npx.set_np()
@@ -117,26 +117,26 @@ def test_transformer_nmt_model(train_hybridize, inference_hybridize,
shared_embed = False
else:
shared_embed = True
model = TransformerNMTModel(src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
max_src_length=src_seq_length,
max_tgt_length=tgt_seq_length,
enc_units=enc_units,
enc_hidden_size=64,
enc_num_heads=4,
enc_num_layers=enc_num_layers,
enc_pre_norm=enc_pre_norm,
enc_recurrent=enc_recurrent,
dec_units=dec_units,
dec_hidden_size=64,
dec_num_heads=4,
dec_num_layers=dec_num_layers,
dec_pre_norm=dec_pre_norm,
dec_recurrent=dec_recurrent,
shared_embed=shared_embed,
tie_weights=tie_weights,
dropout=0.0,
layout=layout)
model = TransformerModel(src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
max_src_length=src_seq_length,
max_tgt_length=tgt_seq_length,
enc_units=enc_units,
enc_hidden_size=64,
enc_num_heads=4,
enc_num_layers=enc_num_layers,
enc_pre_norm=enc_pre_norm,
enc_recurrent=enc_recurrent,
dec_units=dec_units,
dec_hidden_size=64,
dec_num_heads=4,
dec_num_layers=dec_num_layers,
dec_pre_norm=dec_pre_norm,
dec_recurrent=dec_recurrent,
shared_embed=shared_embed,
tie_weights=tie_weights,
dropout=0.0,
layout=layout)
inference_model = TransformerNMTInference(model=model)
model.initialize()
if train_hybridize:
@@ -148,23 +148,23 @@


def test_transformer_cfg_registry():
assert len(transformer_nmt_cfg_reg.list_keys()) > 0
assert len(transformer_cfg_reg.list_keys()) > 0


@pytest.mark.parametrize('cfg_key', transformer_nmt_cfg_reg.list_keys())
@pytest.mark.parametrize('cfg_key', transformer_cfg_reg.list_keys())
def test_transformer_cfg(cfg_key):
cfg = TransformerNMTModel.get_cfg(cfg_key)
cfg = TransformerModel.get_cfg(cfg_key)
cfg.defrost()
cfg.MODEL.src_vocab_size = 32
cfg.MODEL.tgt_vocab_size = 32
cfg.freeze()
model = TransformerNMTModel.from_cfg(cfg)
model = TransformerModel.from_cfg(cfg)
model.initialize()
model.hybridize()
cfg.defrost()
cfg.MODEL.layout = 'TN'
cfg.freeze()
model_tn = TransformerNMTModel.from_cfg(cfg)
model_tn = TransformerModel.from_cfg(cfg)
model_tn.share_parameters(model.collect_params())
model_tn.hybridize()
mx.npx.waitall()
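
The only substantive change in this file is the rename of the model class and config registry. The updated import, matching the header of this diff, is simply:

# Imports under the new names used throughout this test file.
from gluonnlp.models.transformer import TransformerModel, TransformerNMTInference, transformer_cfg_reg
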
