diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py
index b6f7a7c32d..f4e70d1bc7 100644
--- a/src/gluonnlp/models/bart.py
+++ b/src/gluonnlp/models/bart.py
@@ -68,6 +68,7 @@ def fair_bart_base():
     cfg.MODEL.layer_norm_eps = 1E-5
     cfg.MODEL.pooler_activation = 'tanh'
     cfg.MODEL.layernorm_embedding = True
+    cfg.MODEL.layout = 'NT'
     cfg.MODEL.dtype = 'float32'
 
     # Parameters for the encoder
@@ -191,27 +192,43 @@ def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_leng
         Parameters
         ----------
         F
-        src_data :
-            Shape (batch_size, src_length)
-        src_valid_length :
+        src_data
+            - layout = 'NT'
+                Shape (batch_size, src_length)
+            - layout = 'TN'
+                Shape (src_length, batch_size)
+        src_valid_length
             Shape (batch_size,)
-        tgt_data :
-            Shape (batch_size, tgt_length)
-        tgt_valid_length :
+        tgt_data
+            - layout = 'NT'
+                Shape (batch_size, tgt_length)
+            - layout = 'TN'
+                Shape (tgt_length, batch_size)
+        tgt_valid_length
             Shape (batch_size,)
 
         Returns
         -------
-        out :
-            Shape (batch_size, tgt_length, tgt_vocab_size)
+        (contextual_embedding)
+            - layout = 'NT'
+                Shape (batch_size, tgt_length, units)
+            - layout = 'TN'
+                Shape (tgt_length, batch_size, units)
+        (pooled_output)
+            This is optional. Shape (batch_size, units)
+        (dec_out)
+            - layout = 'NT'
+                Shape (batch_size, tgt_length, tgt_vocab_size)
+            - layout = 'TN'
+                Shape (tgt_length, batch_size, tgt_vocab_size)
         """
         enc_out = self.encode(F, src_data, src_valid_length)
-        dec_out = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length)
+        contextual_embedding = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length)
         if self.use_pooler:
-            pooled_out = self.apply_pooling(dec_out)
-            return dec_out, pooled_out
+            pooled_output = self.apply_pooling(contextual_embedding)
+            return contextual_embedding, pooled_output
         else:
-            dec_out = self.tgt_final_layer(dec_out)
+            dec_out = self.tgt_final_layer(contextual_embedding)
             return dec_out
 
     def apply_pooling(self, sequence):
@@ -231,6 +248,10 @@ def apply_pooling(self, sequence):
         else:
             return outputs
 
+    @property
+    def layout(self) -> str:
+        return self._layout
+
     @property
     def vocab_size(self):
         return self._vocab_size
diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 063dd9c37b..971ba2af13 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -1206,7 +1206,6 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
         tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
             F.npx.arange_like(tgt_data, axis=0)), axis=1)
         dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
-        dec_out = self.tgt_final_layer(dec_out)
         return dec_out
 
     def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
diff --git a/tests/test_models_bart.py b/tests/test_models_bart.py
index 9ff66f7c0b..d6130b63fb 100644
--- a/tests/test_models_bart.py
+++ b/tests/test_models_bart.py
@@ -3,8 +3,8 @@
 import mxnet as mx
 import tempfile
 from gluonnlp.models.bart import BartModel, \
-    list_pretrained_bart, get_pretrained_bart
-from gluonnlp.loss import LabelSmoothCrossEntropyLoss
+    list_pretrained_bart, get_pretrained_bart, bart_cfg_reg
+from gluonnlp.utils.testing import verify_nmt_model
 mx.npx.set_np()
 
 
@@ -22,42 +22,31 @@ def test_bart(model_name):
         cfg, tokenizer, params_path, _ =\
             get_pretrained_bart(model_name, load_backbone=True, root=root)
         assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
-        # test backbone
+        # test standard bart encoder and decoder
         bart_model = BartModel.from_cfg(cfg)
         bart_model.load_parameters(params_path)
-        # test mlm model
+        # test bart encoder and decoder with pooler
        bart_model_with_pooler = BartModel.from_cfg(
             cfg, use_pooler=True, classifier_activation=False)
         bart_model_with_pooler.load_parameters(params_path)
-        # test forward
-        batch_size = 3
-        seq_length = 32
-        vocab_size = len(tokenizer.vocab)
-        input_ids = mx.np.array(
-            np.random.randint(
-                2,
-                vocab_size,
-                (batch_size, seq_length)
-            ),
-            dtype=np.int32
-        )
-        valid_length = mx.np.array(
-            np.random.randint(
-                seq_length // 2,
-                seq_length,
-                (batch_size,)
-            ),
-            dtype=np.int32
-        )
-        contextual_embeddings, pooled_out = bart_model_with_pooler(
-            input_ids, valid_length, input_ids, valid_length)
-        mx.npx.waitall()
-        # test backward
-        label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
-        with mx.autograd.record():
-            contextual_embeddings, pooled_out = bart_model_with_pooler(
-                input_ids, valid_length, input_ids, valid_length)
-            loss = label_smooth_loss(contextual_embeddings, input_ids)
-            loss.backward()
+
+def test_bart_cfg_registry():
+    assert len(bart_cfg_reg.list_keys()) > 0
+
+@pytest.mark.parametrize('cfg_key', bart_cfg_reg.list_keys())
+def test_bart_cfg(cfg_key):
+    cfg = BartModel.get_cfg(cfg_key)
+    cfg.defrost()
+    cfg.MODEL.vocab_size = 32
+    cfg.freeze()
+    model = BartModel.from_cfg(cfg)
+    model.initialize()
+    model.hybridize()
+    cfg.defrost()
+    cfg.MODEL.layout = 'TN'
+    cfg.freeze()
+    model_tn = BartModel.from_cfg(cfg)
+    model_tn.share_parameters(model.collect_params())
+    model_tn.hybridize()
     mx.npx.waitall()
 
diff --git a/tests/test_models_transformer.py b/tests/test_models_transformer.py
index e9b1cd6184..96cb60ee1d 100644
--- a/tests/test_models_transformer.py
+++ b/tests/test_models_transformer.py
@@ -3,8 +3,8 @@
 from numpy.testing import assert_allclose
 from gluonnlp.models.transformer import\
     TransformerEncoder, TransformerDecoder, \
-    TransformerNMTModel, TransformerNMTInference,\
-    transformer_nmt_cfg_reg
+    TransformerModel, TransformerNMTInference,\
+    transformer_cfg_reg
 from gluonnlp.attention_cell import gen_mem_attn_mask, gen_self_attn_mask
 from gluonnlp.utils.testing import verify_nmt_model, verify_nmt_inference
 mx.npx.set_np()
@@ -117,26 +117,26 @@ def test_transformer_nmt_model(train_hybridize, inference_hybridize,
         shared_embed = False
     else:
         shared_embed = True
-    model = TransformerNMTModel(src_vocab_size=src_vocab_size,
-                                tgt_vocab_size=tgt_vocab_size,
-                                max_src_length=src_seq_length,
-                                max_tgt_length=tgt_seq_length,
-                                enc_units=enc_units,
-                                enc_hidden_size=64,
-                                enc_num_heads=4,
-                                enc_num_layers=enc_num_layers,
-                                enc_pre_norm=enc_pre_norm,
-                                enc_recurrent=enc_recurrent,
-                                dec_units=dec_units,
-                                dec_hidden_size=64,
-                                dec_num_heads=4,
-                                dec_num_layers=dec_num_layers,
-                                dec_pre_norm=dec_pre_norm,
-                                dec_recurrent=dec_recurrent,
-                                shared_embed=shared_embed,
-                                tie_weights=tie_weights,
-                                dropout=0.0,
-                                layout=layout)
+    model = TransformerModel(src_vocab_size=src_vocab_size,
+                             tgt_vocab_size=tgt_vocab_size,
+                             max_src_length=src_seq_length,
+                             max_tgt_length=tgt_seq_length,
+                             enc_units=enc_units,
+                             enc_hidden_size=64,
+                             enc_num_heads=4,
+                             enc_num_layers=enc_num_layers,
+                             enc_pre_norm=enc_pre_norm,
+                             enc_recurrent=enc_recurrent,
+                             dec_units=dec_units,
+                             dec_hidden_size=64,
+                             dec_num_heads=4,
+                             dec_num_layers=dec_num_layers,
+                             dec_pre_norm=dec_pre_norm,
+                             dec_recurrent=dec_recurrent,
+                             shared_embed=shared_embed,
+                             tie_weights=tie_weights,
+                             dropout=0.0,
+                             layout=layout)
     inference_model = TransformerNMTInference(model=model)
     model.initialize()
     if train_hybridize:
@@ -148,23 +148,23 @@
 
 
 def test_transformer_cfg_registry():
-    assert len(transformer_nmt_cfg_reg.list_keys()) > 0
+    assert len(transformer_cfg_reg.list_keys()) > 0
 
 
-@pytest.mark.parametrize('cfg_key', transformer_nmt_cfg_reg.list_keys())
+@pytest.mark.parametrize('cfg_key', transformer_cfg_reg.list_keys())
 def test_transformer_cfg(cfg_key):
-    cfg = TransformerNMTModel.get_cfg(cfg_key)
+    cfg = TransformerModel.get_cfg(cfg_key)
     cfg.defrost()
     cfg.MODEL.src_vocab_size = 32
     cfg.MODEL.tgt_vocab_size = 32
     cfg.freeze()
-    model = TransformerNMTModel.from_cfg(cfg)
+    model = TransformerModel.from_cfg(cfg)
     model.initialize()
     model.hybridize()
     cfg.defrost()
     cfg.MODEL.layout = 'TN'
     cfg.freeze()
-    model_tn = TransformerNMTModel.from_cfg(cfg)
+    model_tn = TransformerModel.from_cfg(cfg)
     model_tn.share_parameters(model.collect_params())
     model_tn.hybridize()
     mx.npx.waitall()
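Usage sketch of the updated BartModel forward signature, adapted from the forward test that this patch removes. It is a minimal illustration and not part of the patch: it assumes 'fair_bart_base' is one of the keys registered in bart_cfg_reg and uses a randomly initialized, shrunken model purely for shape checking; with cfg.MODEL.layout = 'TN' the batch and time axes of the inputs and outputs are swapped, as described in the new docstring.

import numpy as np
import mxnet as mx
from gluonnlp.models.bart import BartModel

mx.npx.set_np()

# Assumed cfg key; any key from bart_cfg_reg.list_keys() would work here.
cfg = BartModel.get_cfg('fair_bart_base')
cfg.defrost()
cfg.MODEL.vocab_size = 32  # shrink the vocabulary so the sketch stays light
cfg.freeze()

batch_size, seq_length, vocab_size = 3, 32, cfg.MODEL.vocab_size
input_ids = mx.np.array(
    np.random.randint(2, vocab_size, (batch_size, seq_length)), dtype=np.int32)
valid_length = mx.np.array(
    np.random.randint(seq_length // 2, seq_length, (batch_size,)), dtype=np.int32)

# With use_pooler=True the model now returns the decoder's contextual embedding
# of shape (batch_size, tgt_length, units) plus a pooled vector of shape
# (batch_size, units); without the pooler it instead returns the tgt_final_layer
# logits of shape (batch_size, tgt_length, tgt_vocab_size).
model = BartModel.from_cfg(cfg, use_pooler=True, classifier_activation=False)
model.initialize()
model.hybridize()
contextual_embedding, pooled_output = model(
    input_ids, valid_length, input_ids, valid_length)
mx.npx.waitall()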