
Commit

add test_models_bart
zheyuye committed Jul 29, 2020
1 parent a5a91e0 commit 3366cf3
Showing 2 changed files with 70 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/gluonnlp/models/bart.py
@@ -51,7 +51,7 @@
 
 
 @bart_cfg_reg.register()
-def bart_base():
+def fair_bart_base():
     cfg = CN()
     # Config for the bart base model
     cfg.MODEL = CN()
@@ -101,8 +101,8 @@ def bart_base():
 
 
 @bart_cfg_reg.register()
-def bart_large():
-    cfg = bart_base()
+def fair_bart_large():
+    cfg = fair_bart_base()
     cfg.defrost()
     cfg.MODEL.vocab_size = 50265
     cfg.MODEL.ENCODER.units = 1024
@@ -121,14 +121,14 @@ def bart_large():
 
 PRETRAINED_URL = {
     'fairseq_bart_base': {
-        'cfg': bart_base(),
+        'cfg': fair_bart_base(),
         'merges': 'fairseq_bart_base/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_base/gpt2-f4dedacb.vocab',
         'params': 'fairseq_bart_base/model-6dea1e11.params',
         'lowercase': False,
     },
     'fairseq_bart_large': {
-        'cfg': bart_large(),
+        'cfg': fair_bart_large(),
         'merges': 'fairseq_bart_large/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_bart_large/gpt2-f1335494.vocab',
         'params': 'fairseq_bart_large/model-38f35552.params',
@@ -209,7 +209,7 @@ def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_leng
         dec_out = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length)
         if self.use_pooler:
             pooled_out = self.apply_pooling(dec_out)
-            return pooled_out
+            return dec_out, pooled_out
         else:
             dec_out = self.tgt_final_layer(dec_out)
             return dec_out
@@ -238,7 +238,7 @@ def vocab_size(self):
     @classmethod
     def get_cfg(cls, key=None):
         if key is None:
-            return bart_base()
+            return fair_bart_base()
         else:
             return bart_cfg_reg.create(key)

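For reference, a minimal usage sketch of the renamed config entry points and the new two-value return (not part of the commit; assumes GluonNLP at this revision, with randomly initialized weights):

import mxnet as mx
from gluonnlp.models.bart import BartModel, fair_bart_base

mx.npx.set_np()
cfg = fair_bart_base()
# use_pooler=True makes the forward pass return both the raw decoder
# output and the pooled representation, per the hybrid_forward change above.
model = BartModel.from_cfg(cfg, use_pooler=True, classifier_activation=False)
model.initialize()
src = mx.np.ones((2, 8), dtype='int32')
src_valid_length = mx.np.array([8, 6], dtype='int32')
dec_out, pooled_out = model(src, src_valid_length, src, src_valid_length)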
63 changes: 63 additions & 0 deletions tests/test_models_bart.py
@@ -0,0 +1,63 @@
import pytest
import numpy as np
import mxnet as mx
import tempfile
from gluonnlp.models.bart import BartModel, \
    list_pretrained_bart, get_pretrained_bart
from gluonnlp.loss import LabelSmoothCrossEntropyLoss

mx.npx.set_np()


def test_list_pretrained_bart():
    assert len(list_pretrained_bart()) > 0


@pytest.mark.remote_required
@pytest.mark.parametrize('model_name', list_pretrained_bart())
def test_bart(model_name):
    # test from pretrained
    assert len(list_pretrained_bart()) > 0
    with tempfile.TemporaryDirectory() as root:
        cfg, tokenizer, params_path, _ =\
            get_pretrained_bart(model_name, load_backbone=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # test backbone
        bart_model = BartModel.from_cfg(cfg)
        bart_model.load_parameters(params_path)
        # test backbone with pooler
        bart_model_with_pooler = BartModel.from_cfg(
            cfg, use_pooler=True, classifier_activation=False)
        bart_model_with_pooler.load_parameters(params_path)

        # test forward
        batch_size = 3
        seq_length = 32
        vocab_size = len(tokenizer.vocab)
        input_ids = mx.np.array(
            np.random.randint(
                2,
                vocab_size,
                (batch_size, seq_length)
            ),
            dtype=np.int32
        )
        valid_length = mx.np.array(
            np.random.randint(
                seq_length // 2,
                seq_length,
                (batch_size,)
            ),
            dtype=np.int32
        )
        contextual_embeddings, pooled_out = bart_model_with_pooler(
            input_ids, valid_length, input_ids, valid_length)
        mx.npx.waitall()
        # test backward
        label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
        with mx.autograd.record():
            contextual_embeddings, pooled_out = bart_model_with_pooler(
                input_ids, valid_length, input_ids, valid_length)
            loss = label_smooth_loss(contextual_embeddings, input_ids)
            loss.backward()
        mx.npx.waitall()
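For reference, one way to invoke the new tests directly (a hedged sketch; the remote_required-marked case downloads pretrained weights, so it needs network access):

import pytest

# Select only the tests carrying the remote_required mark.
pytest.main(['-m', 'remote_required', 'tests/test_models_bart.py'])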
