diff --git a/scripts/conversion_toolkits/README.md b/scripts/conversion_toolkits/README.md
index 2c29e87db7..ea2430d367 100644
--- a/scripts/conversion_toolkits/README.md
+++ b/scripts/conversion_toolkits/README.md
@@ -12,6 +12,8 @@ The testing step mentioned above are controlled by the flag `--test`, in which t
 tolerance of 1e-3 between gluon model with converted weights and original tensorflow model.
 In addition, we can use GPU in all converting scripts by adding `--gpu 0`.
+For the RoBERTa, XLM-R, and BART models, please install the [fairseq](https://github.com/pytorch/fairseq#requirements-and-installation) package locally via `pip install git+https://github.com/pytorch/fairseq.git@master`.
+
 ## BERT
 Convert model from [BERT LIST](https://tfhub.dev/google/collections/bert/1).
@@ -37,25 +39,42 @@ do
 done
 ```
-## RoBERTa
+## ELECTRA
+The ELECTRA model is not available on TF Hub currently.
+Thus, you will need to clone the [electra repository](https://github.com/ZheyuYe/electra)
+and download the checkpoint. The parameters are converted from local checkpoints.
+By running the following command, you can convert and verify the ELECTRA model with both the discriminator and the generator.
+
+Notice: please set `--electra_path` to the cloned path ~~or get this electra repository packaged by `pip install -e .`.~~
+
+```bash
+# Need to use TF 1.13.2 to use contrib layer
+pip uninstall tensorflow
+pip install tensorflow==1.13.2
+
+# Actual conversion
+bash convert_electra.sh
+```
+## Mobile Bert
 ```bash
-pip install fairseq==0.9.0
+bash convert_mobilebert.sh
+```
+## RoBERTa
+```bash
 for model in base large
 do
     mkdir roberta_${model}
     wget "https://dl.fbaipublicfiles.com/fairseq/models/roberta.${model}.tar.gz"
     tar zxf roberta.${model}.tar.gz --directory roberta_${model}
-    python convert_fairseq_roberta.py --fairseq_model_path roberta_${model}/roberta.${model} --model_size ${model} --test
+    python convert_fairseq_roberta.py --fairseq_model_path roberta_${model}/roberta.${model} --test
 done
 ```
 ## XLM-R
 ```bash
-pip install fairseq==0.9.0
-
 for model in base large
 do
     mkdir xlmr_${model}
@@ -65,23 +84,13 @@ do
 done
 ```
-## ELECTRA
-The TF Hub is not available for ELECTRA model currently.
-Thus, you will need to clone the [electra repository](https://github.com/ZheyuYe/electra)
-and download the checkpoint. The parameters are converted from local checkpoints.
-By running the following command, you can convert + verify the ELECTRA model with both the discriminator and the generator.
-
-Notice: pleas set up the `--electra_path` with the cloned path or get this electra repository packaged by `pip install -e .`.
- +## BART ```bash -# Need to use TF 1.13.2 to use contrib layer -pip install tensorflow==1.13.2 --upgrade --force-reinstall - -# Actual conversion -bash convert_electra.sh -``` - -## Mobile Bert -```bash -bash convert_mobilebert.sh +for model in base large +do + mkdir bart_${model} + wget "https://dl.fbaipublicfiles.com/fairseq/models/bart.${model}.tar.gz" + tar zxf bart.${model}.tar.gz --directory bart_${model} + python convert_fairseq_bart.py --fairseq_model_path bart_${model}/bart.${model} --test +done ``` diff --git a/scripts/conversion_toolkits/convert_bart.sh b/scripts/conversion_toolkits/convert_bart.sh new file mode 100644 index 0000000000..e6c3db3d07 --- /dev/null +++ b/scripts/conversion_toolkits/convert_bart.sh @@ -0,0 +1,7 @@ +for model in base large +do + mkdir bart_${model} + wget "https://dl.fbaipublicfiles.com/fairseq/models/bart.${model}.tar.gz" + tar zxf bart.${model}.tar.gz --directory bart_${model} + python convert_fairseq_bart.py --fairseq_model_path bart_${model}/bart.${model} --test +done diff --git a/scripts/conversion_toolkits/convert_electra.py b/scripts/conversion_toolkits/convert_electra.py index 57d79c8e0b..4d76b4ab7b 100644 --- a/scripts/conversion_toolkits/convert_electra.py +++ b/scripts/conversion_toolkits/convert_electra.py @@ -53,7 +53,9 @@ def read_tf_checkpoint(path): return tensors -def get_dict_config(model_size, electra_dir): +def get_dict_config(model_size, electra_path): + sys.path.append(electra_path) + electra_dir = os.path.abspath(os.path.join(os.path.dirname(electra_path), os.path.pardir)) sys.path.append(electra_dir) from electra.util.training_utils import get_bert_config from electra.configure_pretraining import PretrainingConfig @@ -100,7 +102,7 @@ def convert_tf_config(config_dict, vocab_size): return cfg -def convert_tf_assets(tf_assets_dir, model_size, electra_dir): +def convert_tf_assets(tf_assets_dir, model_size, electra_path): """Convert the assets file including config, vocab and tokenizer model""" file_names = os.listdir(tf_assets_dir) vocab_path = None @@ -113,7 +115,7 @@ def convert_tf_assets(tf_assets_dir, model_size, electra_dir): if vocab_path: vocab_path = os.path.join(tf_assets_dir, vocab_path) vocab_size = len(open(vocab_path, 'rU').readlines()) - config_dict = get_dict_config(model_size, electra_dir) + config_dict = get_dict_config(model_size, electra_path) cfg = convert_tf_config(config_dict, vocab_size) return cfg, vocab_path @@ -190,12 +192,12 @@ def get_name_map(tf_names, convert_type='backbone'): return name_map -def convert_tf_model(model_dir, save_dir, test_conversion, model_size, gpu, electra_dir): +def convert_tf_model(model_dir, save_dir, test_conversion, model_size, gpu, electra_path): ctx = mx.gpu(gpu) if gpu is not None else mx.cpu() if not os.path.exists(save_dir): os.makedirs(save_dir) - cfg, vocab_path = convert_tf_assets(model_dir, model_size, electra_dir) + cfg, vocab_path = convert_tf_assets(model_dir, model_size, electra_path) with open(os.path.join(save_dir, 'model.yml'), 'w') as of: of.write(cfg.dump()) new_vocab = HuggingFaceWordPieceTokenizer( @@ -234,6 +236,8 @@ def convert_tf_model(model_dir, save_dir, test_conversion, model_size, gpu, elec tf_names = list(tf_names) # reload the electra module for this local scope + sys.path.append(electra_path) + electra_dir = os.path.abspath(os.path.join(os.path.dirname(electra_path), os.path.pardir)) sys.path.append(electra_dir) from electra.util.training_utils import get_bert_config from electra.configure_pretraining import PretrainingConfig @@ -426,11 +430,10 
@@ def convert_qkv_weights(tf_prefix, mx_prefix): logging_config() save_dir = args.save_dir if args.save_dir is not None else os.path.basename( args.tf_model_path) + '_gluon' - electra_dir = os.path.abspath(os.path.join(os.path.dirname(args.electra_path), os.path.pardir)) convert_tf_model( args.tf_model_path, save_dir, args.test, args.model_size, args.gpu, - electra_dir) + args.electra_path) diff --git a/scripts/conversion_toolkits/convert_fairseq_bart.py b/scripts/conversion_toolkits/convert_fairseq_bart.py new file mode 100644 index 0000000000..4c78fff23c --- /dev/null +++ b/scripts/conversion_toolkits/convert_fairseq_bart.py @@ -0,0 +1,321 @@ +import os +import shutil +import logging +import argparse + +import mxnet as mx +import numpy as np +from numpy.testing import assert_allclose + +import torch +from fairseq.models.bart import BARTModel as fairseq_BARTModel +from gluonnlp.utils.misc import sha1sum, logging_config, naming_convention +from gluonnlp.models.bart import BartModel +from convert_fairseq_roberta import convert_vocab + +mx.npx.set_np() + + +def parse_args(): + parser = argparse.ArgumentParser(description='Convert the fairseq BART Model to Gluon.') + parser.add_argument('--fairseq_model_path', type=str, required=True, + help='Directory of the fairseq BART model.') + parser.add_argument('--save_dir', type=str, default=None, + help='Directory path to save the converted BART model.') + parser.add_argument('--gpu', type=int, default=None, + help='The single gpu to run mxnet, (e.g. --gpu 0) the default device is cpu.') + parser.add_argument('--test', action='store_true', + help='Whether to test the conversion.') + return parser.parse_args() + + +def convert_config(fairseq_cfg, vocab_size, cfg): + print('converting config') + cfg.defrost() + # Config for the bart base model + cfg.MODEL.vocab_size = vocab_size + cfg.MODEL.max_src_length = fairseq_cfg.max_source_positions + cfg.MODEL.max_tgt_length = fairseq_cfg.max_target_positions + cfg.MODEL.pos_embed_type = 'learned' + cfg.MODEL.shared_embed = fairseq_cfg.share_all_embeddings + cfg.MODEL.scale_embed = not fairseq_cfg.no_scale_embedding + cfg.MODEL.tie_weights = fairseq_cfg.share_decoder_input_output_embed + cfg.MODEL.data_norm = fairseq_cfg.layernorm_embedding + cfg.MODEL.pooler_activation = fairseq_cfg.pooler_activation_fn + cfg.MODEL.layer_norm_eps = 1E-5 + cfg.MODEL.dropout = fairseq_cfg.dropout + cfg.MODEL.activation_dropout = fairseq_cfg.activation_dropout + cfg.MODEL.attention_dropout = fairseq_cfg.attention_dropout + cfg.MODEL.dtype = 'float32' + + # Parameters for the encoder + cfg.MODEL.ENCODER.pre_norm = fairseq_cfg.encoder_normalize_before + cfg.MODEL.ENCODER.num_layers = fairseq_cfg.encoder_layers + cfg.MODEL.ENCODER.units = fairseq_cfg.encoder_embed_dim + cfg.MODEL.ENCODER.num_heads = fairseq_cfg.encoder_attention_heads + cfg.MODEL.ENCODER.hidden_size = fairseq_cfg.encoder_ffn_embed_dim + cfg.MODEL.ENCODER.activation = fairseq_cfg.activation_fn + + # Parameters for the decoder + cfg.MODEL.DECODER.pre_norm = fairseq_cfg.decoder_normalize_before + cfg.MODEL.DECODER.num_layers = fairseq_cfg.decoder_layers + cfg.MODEL.DECODER.units = fairseq_cfg.decoder_embed_dim + cfg.MODEL.DECODER.num_heads = fairseq_cfg.decoder_attention_heads + cfg.MODEL.DECODER.hidden_size = fairseq_cfg.decoder_ffn_embed_dim + cfg.MODEL.DECODER.activation = fairseq_cfg.activation_fn + + cfg.INITIALIZER.embed = ['xavier', 'gaussian', 'in', 1.0] + cfg.INITIALIZER.weight = ['xavier', 'uniform', 'avg', 1.0] + cfg.INITIALIZER.bias = ['zeros'] + 
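+    # NOTE (editorial): the initializer entries above mainly serve as metadata recorded in the
+    # dumped model.yml; convert_params() copies every trained weight directly from the fairseq
+    # state_dict, so they only affect parameters that would not be overwritten by the conversion.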
cfg.VERSION = 1 + cfg.freeze() + return cfg + + +def convert_params(fairseq_model, + gluon_cfg, + ctx): + fairseq_params = fairseq_model.state_dict() + # apply a linear mapping to vocab dictionary + gluon_model = BartModel.from_cfg(gluon_cfg, use_pooler=False) + gluon_model.initialize(ctx=ctx) + gluon_model.hybridize() + gluon_params = gluon_model.collect_params() + all_keys = set(gluon_params.keys()) + + def convert_attention(num_layers, + fairseq_prefix, + gluon_prefix, + fairseq_attn_prefix='self_attn', + gluon_attn_prefix='attn_qkv'): + for layer_id in range(num_layers): + fs_atten_prefix = \ + '{}.layers.{}.{}.' \ + .format(fairseq_prefix, layer_id, fairseq_attn_prefix) + fs_q_weight = fairseq_params[fs_atten_prefix + 'q_proj.weight'].cpu().numpy() + fs_k_weight = fairseq_params[fs_atten_prefix + 'k_proj.weight'].cpu().numpy() + fs_v_weight = fairseq_params[fs_atten_prefix + 'v_proj.weight'].cpu().numpy() + fs_q_bias = fairseq_params[fs_atten_prefix + 'q_proj.bias'].cpu().numpy() + fs_k_bias = fairseq_params[fs_atten_prefix + 'k_proj.bias'].cpu().numpy() + fs_v_bias = fairseq_params[fs_atten_prefix + 'v_proj.bias'].cpu().numpy() + gl_qkv_prefix = \ + '{}.layers.{}.{}.' \ + .format(gluon_prefix, layer_id, gluon_attn_prefix) + gl_qkv_weight = gluon_params[gl_qkv_prefix + 'weight'] + gl_qkv_bias = gluon_params[gl_qkv_prefix + 'bias'] + all_keys.remove(gl_qkv_prefix + 'weight') + all_keys.remove(gl_qkv_prefix + 'bias') + gl_qkv_weight.set_data( + np.concatenate([fs_q_weight, fs_k_weight, fs_v_weight], axis=0)) + gl_qkv_bias.set_data( + np.concatenate([fs_q_bias, fs_k_bias, fs_v_bias], axis=0)) + + def convert_ffn(num_layers, fairseq_prefix, gluon_prefix): + # convert feed forward layer in encoder + for layer_id in range(num_layers): + for k, v in [ + ('fc1.weight', 'ffn.ffn_1.weight'), + ('fc1.bias', 'ffn.ffn_1.bias'), + ('fc2.weight', 'ffn.ffn_2.weight'), + ('fc2.bias', 'ffn.ffn_2.bias'), + ('final_layer_norm.weight', 'ffn.layer_norm.gamma'), + ('final_layer_norm.bias', 'ffn.layer_norm.beta') + ]: + fs_name = '{}.layers.{}.{}' \ + .format(fairseq_prefix, layer_id, k) + gl_name = '{}.layers.{}.{}' \ + .format(gluon_prefix, layer_id, v) + all_keys.remove(gl_name) + gluon_params[gl_name].set_data( + fairseq_params[fs_name].cpu().numpy()) + + print('converting embedding params') + padding_idx = fairseq_model.task.dictionary.pad_index + for fs_name, gl_name in [ + ('model.encoder.embed_tokens.weight', 'src_embed_layer.weight'), + ('model.encoder.embed_positions.weight', 'src_pos_embed_layer._embed.weight'), + ('model.encoder.layernorm_embedding.weight', 'encoder.ln_data.gamma'), + ('model.encoder.layernorm_embedding.bias', 'encoder.ln_data.beta'), + ('model.decoder.embed_tokens.weight', 'tgt_embed_layer.weight'), + ('model.decoder.embed_positions.weight', 'tgt_pos_embed_layer._embed.weight'), + ('model.decoder.layernorm_embedding.weight', 'decoder.ln_data.gamma'), + ('model.decoder.layernorm_embedding.bias', 'decoder.ln_data.beta'), + # final projection in decoder + ('model.decoder.output_projection.weight', 'tgt_final_layer.weight'), + ]: + all_keys.remove(gl_name) + if 'embed_positions' in fs_name: + # position embed weight + gluon_params[gl_name].set_data( + fairseq_params[fs_name].cpu().numpy()[padding_idx + 1:, :]) + else: + gluon_params[gl_name].set_data( + fairseq_params[fs_name].cpu().numpy()) + + print('converting encoder params') + encoder_num_layers = gluon_cfg.MODEL.ENCODER.num_layers + convert_attention(encoder_num_layers, 'model.encoder', 'encoder') + 
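+    # convert_attention above concatenates fairseq's separate q/k/v projection weights and
+    # biases into the fused attn_qkv parameter expected by the Gluon layers; the FFN,
+    # out_proj and layer-norm parameters are then copied one-to-one below.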
convert_ffn(encoder_num_layers, 'model.encoder', 'encoder') + for layer_id in range(encoder_num_layers): + for k, v in [ + ('self_attn.out_proj.weight', 'attention_proj.weight'), + ('self_attn.out_proj.bias', 'attention_proj.bias'), + ('self_attn_layer_norm.weight', 'layer_norm.gamma'), + ('self_attn_layer_norm.bias', 'layer_norm.beta'), + ]: + fs_name = 'model.encoder.layers.{}.{}' \ + .format(layer_id, k) + gl_name = 'encoder.layers.{}.{}' \ + .format(layer_id, v) + all_keys.remove(gl_name) + gluon_params[gl_name].set_data( + fairseq_params[fs_name].cpu().numpy()) + + print('converting decoder params') + decoder_num_layers = gluon_cfg.MODEL.DECODER.num_layers + convert_attention(decoder_num_layers, 'model.decoder', 'decoder', + gluon_attn_prefix='attn_in_qkv') + convert_ffn(decoder_num_layers, 'model.decoder', 'decoder') + + for layer_id in range(decoder_num_layers): + for k, v in [ + ('self_attn.out_proj.weight', 'proj_in.weight'), + ('self_attn.out_proj.bias', 'proj_in.bias'), + ('self_attn_layer_norm.weight', 'ln_in.gamma'), + ('self_attn_layer_norm.bias', 'ln_in.beta'), + ('encoder_attn.out_proj.weight', 'proj_inter.weight'), + ('encoder_attn.out_proj.bias', 'proj_inter.bias'), + ('encoder_attn_layer_norm.weight', 'ln_inter.gamma'), + ('encoder_attn_layer_norm.bias', 'ln_inter.beta'), + ('encoder_attn.q_proj.weight', 'attn_inter_q.weight'), + ('encoder_attn.q_proj.bias', 'attn_inter_q.bias'), + ('encoder_attn.k_proj.weight', 'attn_inter_k.weight'), + ('encoder_attn.k_proj.bias', 'attn_inter_k.bias'), + ('encoder_attn.v_proj.weight', 'attn_inter_v.weight'), + ('encoder_attn.v_proj.bias', 'attn_inter_v.bias'), + + ]: + fs_name = 'model.decoder.layers.{}.{}' \ + .format(layer_id, k) + gl_name = 'decoder.layers.{}.{}' \ + .format(layer_id, v) + all_keys.remove(gl_name) + gluon_params[gl_name].set_data( + fairseq_params[fs_name].cpu().numpy()) + + assert len(all_keys) == 0, 'parameters missing from tensorflow checkpoint' + + # check parameters sharing if share_decoder_input_output_embed is true + assert np.array_equal( + fairseq_params['model.decoder.embed_tokens.weight'].cpu().numpy(), + fairseq_params['model.decoder.output_projection.weight'].cpu().numpy() + ) + return gluon_model + + +def test_model(fairseq_model, gluon_model, gpu): + print('testing model') + ctx = mx.gpu(gpu) if gpu is not None else mx.cpu() + batch_size = 3 + seq_length = 32 + vocab_size = len(fairseq_model.task.dictionary) + padding_id = fairseq_model.model.decoder.padding_idx + input_ids = np.random.randint( # skip padding_id + padding_id + 1, + vocab_size, + (batch_size, seq_length) + ) + valid_length = np.random.randint( + seq_length // 2, + seq_length, + (batch_size,) + ) + + for i in range(batch_size): # add padding, for fairseq padding mask + input_ids[i, valid_length[i]:] = padding_id + + gl_input_ids = mx.np.array(input_ids, dtype=np.int32, ctx=ctx) + gl_valid_length = mx.np.array(valid_length, dtype=np.int32, ctx=ctx) + gl_dec_out = \ + gluon_model(gl_input_ids, gl_valid_length, gl_input_ids, gl_valid_length) + + fs_input_ids = torch.from_numpy(input_ids).cuda(gpu) + fairseq_model.model.eval() + fs_dec_out, fs_extra = \ + fairseq_model.model.cuda(gpu)( + fs_input_ids, + valid_length, + fs_input_ids, + return_all_hiddens=True) + + # checking decoder output + gl_dec_out = gl_dec_out.asnumpy() + fs_dec_out = fs_dec_out.detach().cpu().numpy() + for j in range(batch_size): + assert_allclose( + gl_dec_out[j, :valid_length[j], :], + fs_dec_out[j, :valid_length[j], :], + 1E-3, + 1E-3 + ) + + +def rename(save_dir): 
+ """Rename converted files with hash""" + old_names = os.listdir(save_dir) + for old_name in old_names: + old_path = os.path.join(save_dir, old_name) + long_hash = sha1sum(old_path) + file_prefix, file_sufix = old_name.split('.') + new_name = '{file_prefix}-{short_hash}.{file_sufix}'.format( + file_prefix=file_prefix, + short_hash=long_hash[:8], + file_sufix=file_sufix) + new_path = os.path.join(save_dir, new_name) + shutil.move(old_path, new_path) + file_size = os.path.getsize(new_path) + logging.info('\t{} {} {}'.format(new_path, long_hash, file_size)) + + +def convert_fairseq_model(args): + if not args.save_dir: + args.save_dir = os.path.basename(args.fairseq_model_path) + '_gluon' + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + fairseq_bart = fairseq_BARTModel.from_pretrained(args.fairseq_model_path, + checkpoint_file='model.pt') + vocab_size = convert_vocab(args, fairseq_bart) + gluon_cfg = convert_config(fairseq_bart.args, vocab_size, + BartModel.get_cfg().clone()) + with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of: + of.write(gluon_cfg.dump()) + + ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu() + gluon_bart = convert_params(fairseq_bart, + gluon_cfg, + ctx) + if args.test: + test_model(fairseq_bart, gluon_bart, args.gpu) + + gluon_bart.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True) + logging.info('Convert the BART MLM model in {} to {}'. + format(os.path.join(args.fairseq_model_path, 'model.pt'), + os.path.join(args.save_dir, 'model.params'))) + + logging.info('Conversion finished!') + logging.info('Statistics:') + old_names = os.listdir(args.save_dir) + for old_name in old_names: + new_name, long_hash = naming_convention(args.save_dir, old_name) + old_path = os.path.join(args.save_dir, old_name) + new_path = os.path.join(args.save_dir, new_name) + shutil.move(old_path, new_path) + file_size = os.path.getsize(new_path) + logging.info('\t{}/{} {} {}'.format(args.save_dir, new_name, long_hash, file_size)) + + +if __name__ == '__main__': + args = parse_args() + logging_config() + convert_fairseq_model(args) diff --git a/scripts/conversion_toolkits/convert_fairseq_roberta.py b/scripts/conversion_toolkits/convert_fairseq_roberta.py index f20664af4e..bcdac44436 100644 --- a/scripts/conversion_toolkits/convert_fairseq_roberta.py +++ b/scripts/conversion_toolkits/convert_fairseq_roberta.py @@ -1,4 +1,5 @@ import os +import re import sys import json import shutil @@ -11,7 +12,7 @@ import torch from gluonnlp.data.vocab import Vocab as gluon_Vocab -from gluonnlp.utils.misc import sha1sum, logging_config +from gluonnlp.utils.misc import sha1sum, logging_config, naming_convention from fairseq.models.roberta import RobertaModel as fairseq_RobertaModel from gluonnlp.models.roberta import RobertaModel, RobertaForMLM from gluonnlp.data.tokenizers import HuggingFaceByteBPETokenizer @@ -23,8 +24,6 @@ def parse_args(): parser = argparse.ArgumentParser(description='Convert the fairseq RoBERTa Model to Gluon.') parser.add_argument('--fairseq_model_path', type=str, required=True, help='Directory of the fairseq RoBERTa model.') - parser.add_argument('--model_size', type=str, choices=['base', 'large'], default='base', - help='Size of RoBERTa model.') parser.add_argument('--save_dir', type=str, default=None, help='Directory path to save the converted RoBERTa model.') parser.add_argument('--gpu', type=int, default=None, @@ -69,16 +68,12 @@ def convert_vocab(args, fairseq_model): inter_vocab = sorted(inter_vocab, key=lambda x: 
x[1]) tokens = [e[0] for e in inter_vocab] - tail = [fairseq_vocab[-4], - fairseq_vocab[-3], - fairseq_vocab[-2], - fairseq_vocab[-1]] - assert tail == ['madeupword0000', - 'madeupword0001', - 'madeupword0002', - ''] + tail = [ + vocab for vocab in fairseq_vocab.indices.keys() if re.match( + r'^madeupword[\d]{4}$', + vocab) is not None] all_tokens = ['', '', '', ''] + \ - tokens + tail + tokens + tail + [''] gluon_vocab = gluon_Vocab(all_tokens, unk_token=fairseq_vocab.unk_word, @@ -172,7 +167,7 @@ def convert_params(fairseq_model, gluon_cfg, ctx): fairseq_params = fairseq_model.state_dict() - fairseq_prefix = 'model.decoder.' + fairseq_prefix = 'model.encoder.' gluon_prefix = 'backbone_model.' print('converting {} params'.format(gluon_prefix)) @@ -265,7 +260,7 @@ def test_model(fairseq_model, gluon_model, gpu): batch_size = 3 seq_length = 32 vocab_size = len(fairseq_model.task.dictionary) - padding_id = fairseq_model.model.decoder.sentence_encoder.padding_idx + padding_id = fairseq_model.model.encoder.sentence_encoder.padding_idx input_ids = np.random.randint( # skip padding_id padding_id + 1, vocab_size, @@ -315,7 +310,6 @@ def test_model(fairseq_model, gluon_model, gpu): ) # checking masked_language_scores gl_mlm_scores = gl_mlm_scores.asnumpy() - fs_mlm_scores = fs_mlm_scores.transpose(0, 1) fs_mlm_scores = fs_mlm_scores.detach().cpu().numpy() for j in range(batch_size): assert_allclose( @@ -377,7 +371,14 @@ def convert_fairseq_model(args): logging.info('Conversion finished!') logging.info('Statistics:') - rename(args.save_dir) + old_names = os.listdir(args.save_dir) + for old_name in old_names: + new_name, long_hash = naming_convention(args.save_dir, old_name) + old_path = os.path.join(args.save_dir, old_name) + new_path = os.path.join(args.save_dir, new_name) + shutil.move(old_path, new_path) + file_size = os.path.getsize(new_path) + logging.info('\t{}/{} {} {}'.format(args.save_dir, new_name, long_hash, file_size)) if __name__ == '__main__': diff --git a/scripts/conversion_toolkits/convert_mobilebert.py b/scripts/conversion_toolkits/convert_mobilebert.py index 3617b082be..8be50f672e 100644 --- a/scripts/conversion_toolkits/convert_mobilebert.py +++ b/scripts/conversion_toolkits/convert_mobilebert.py @@ -306,11 +306,11 @@ def convert_tf_model(model_dir, save_dir, test_conversion, gpu, mobilebert_dir): tf_pooled_output = tf_token_outputs_np['pooled_output'] contextual_embedding, pooled_output = model.backbone_model( mx_input_ids, mx_token_types, mx_valid_length) - assert_allclose(pooled_output.asnumpy(), tf_pooled_output, 1E-3, 1E-3) + assert_allclose(pooled_output.asnumpy(), tf_pooled_output, 1E-2, 1E-2) for i in range(batch_size): ele_valid_length = valid_length[i] assert_allclose(contextual_embedding[i, :ele_valid_length, :].asnumpy(), - tf_contextual_embedding[i, :ele_valid_length, :], 1E-3, 1E-3) + tf_contextual_embedding[i, :ele_valid_length, :], 1E-2, 1E-2) model.backbone_model.save_parameters(os.path.join(save_dir, 'model.params'), deduplicate=True) logging.info('Convert the backbone model in {} to {}/{}'.format(model_dir, save_dir, 'model.params')) model.save_parameters(os.path.join(save_dir, 'model_mlm.params'), deduplicate=True) diff --git a/scripts/conversion_toolkits/convert_roberta.sh b/scripts/conversion_toolkits/convert_roberta.sh new file mode 100644 index 0000000000..83e6636fef --- /dev/null +++ b/scripts/conversion_toolkits/convert_roberta.sh @@ -0,0 +1,7 @@ +for model in base large +do + mkdir roberta_${model} + wget 
"https://dl.fbaipublicfiles.com/fairseq/models/roberta.${model}.tar.gz" + tar zxf roberta.${model}.tar.gz --directory roberta_${model} + python convert_fairseq_roberta.py --fairseq_model_path roberta_${model}/roberta.${model} --test +done diff --git a/scripts/conversion_toolkits/convert_xlmr.sh b/scripts/conversion_toolkits/convert_xlmr.sh new file mode 100644 index 0000000000..f7f4832996 --- /dev/null +++ b/scripts/conversion_toolkits/convert_xlmr.sh @@ -0,0 +1,7 @@ +for model in base large +do + mkdir xlmr_${model} + wget "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.${model}.tar.gz" + tar zxf xlmr.${model}.tar.gz --directory xlmr_${model} + python convert_fairseq_xlmr.py --fairseq_model_path xlmr_${model}/xlmr.${model} --model_size ${model} --test +done diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index 8468eb22ad..8b5d0695f1 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -11,7 +11,7 @@ bash wmt2014_ende.sh yttm ``` Then, you can run the experiment, we use the -"transformer_nmt_base" configuration. +"transformer_base" configuration. ```bash SUBWORD_MODEL=yttm @@ -25,7 +25,7 @@ python train_transformer.py \ --tgt_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \ --tgt_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \ --save_dir transformer_wmt2014_ende_${SUBWORD_MODEL} \ - --cfg transformer_nmt_base \ + --cfg transformer_base \ --lr 0.002 \ --warmup_steps 4000 \ --warmup_init_lr 0.0 \ @@ -55,7 +55,7 @@ python evaluate_transformer.py \ Test BLEU score with 3 seeds (evaluated via sacre BLEU): -- transformer_nmt_base +- transformer_base | Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | MeanĀ±std | |---------------|------------|-------------|-------------|--------------|-------------| diff --git a/scripts/machine_translation/evaluate_transformer.py b/scripts/machine_translation/evaluate_transformer.py index 5cbd065050..9010b83384 100644 --- a/scripts/machine_translation/evaluate_transformer.py +++ b/scripts/machine_translation/evaluate_transformer.py @@ -7,7 +7,7 @@ import logging import time from gluonnlp.utils.misc import logging_config -from gluonnlp.models.transformer import TransformerNMTModel,\ +from gluonnlp.models.transformer import TransformerModel,\ TransformerNMTInference from gluonnlp.data.batchify import Tuple, Pad, Stack from gluonnlp.data.filtering import MosesNormalizer @@ -144,16 +144,16 @@ def evaluate(args): src_vocab = src_tokenizer.vocab tgt_vocab = tgt_tokenizer.vocab if args.cfg.endswith('.yml'): - cfg = TransformerNMTModel.get_cfg().clone_merge(args.cfg) + cfg = TransformerModel.get_cfg().clone_merge(args.cfg) else: - cfg = TransformerNMTModel.get_cfg(args.cfg) + cfg = TransformerModel.get_cfg(args.cfg) cfg.defrost() cfg.MODEL.src_vocab_size = len(src_vocab) cfg.MODEL.tgt_vocab_size = len(tgt_vocab) if args.fp16: cfg.MODEL.dtype = 'float16' cfg.freeze() - model = TransformerNMTModel.from_cfg(cfg) + model = TransformerModel.from_cfg(cfg) model.hybridize() model.load_parameters(args.param_path, ctx=ctx_l) inference_model = TransformerNMTInference(model=model) diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index dfa9a74cfb..089d0267c0 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -42,7 +42,7 @@ import numpy as np import mxnet as mx from mxnet 
import gluon -from gluonnlp.models.transformer import TransformerNMTModel +from gluonnlp.models.transformer import TransformerModel from gluonnlp.utils.misc import logging_config, AverageSGDTracker, count_parameters,\ md5sum, grouper from gluonnlp.data.sampler import ( @@ -112,7 +112,7 @@ def parse_args(): 'each update step contains gpu_num * num_accumulated batches.') parser.add_argument('--save_interval_update', type=int, default=500, help='Update interval of saving checkpoints while using max_update.') - parser.add_argument('--cfg', type=str, default='transformer_nmt_base', + parser.add_argument('--cfg', type=str, default='transformer_base', help='Configuration of the transformer model. ' 'You may select a yml file or use the prebuild configurations.') parser.add_argument('--label_smooth_alpha', type=float, default=0.1, @@ -176,7 +176,7 @@ def validation(model, data_loader, ctx_l): Parameters ---------- - model : TransformerNMTModel + model : TransformerModel The transformer model data_loader : DataLoader DataLoader @@ -308,9 +308,9 @@ def train(args): else [mx.gpu(int(x)) for x in args.gpus.split(',')] # Construct the model + loss function if args.cfg.endswith('.yml'): - cfg = TransformerNMTModel.get_cfg().clone_merge(args.cfg) + cfg = TransformerModel.get_cfg().clone_merge(args.cfg) else: - cfg = TransformerNMTModel.get_cfg(args.cfg) + cfg = TransformerModel.get_cfg(args.cfg) cfg.defrost() cfg.MODEL.src_vocab_size = len(src_vocab) cfg.MODEL.tgt_vocab_size = len(tgt_vocab) @@ -318,7 +318,7 @@ def train(args): raise NotImplementedError # cfg.MODEL.dtype = 'float16' cfg.freeze() - model = TransformerNMTModel.from_cfg(cfg) + model = TransformerModel.from_cfg(cfg) model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() diff --git a/scripts/machine_translation/wmt2014_back_translation.sh b/scripts/machine_translation/wmt2014_back_translation.sh index 29ad1dfd33..9e12f8be3c 100644 --- a/scripts/machine_translation/wmt2014_back_translation.sh +++ b/scripts/machine_translation/wmt2014_back_translation.sh @@ -37,7 +37,7 @@ python train_transformer.py \ --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ --save_dir transformer_wmt2014_de_en_${SUBWORD_ALGO} \ - --cfg transformer_nmt_base \ + --cfg transformer_base \ --lr 0.002 \ --warmup_steps 4000 \ --warmup_init_lr 0.0 \ @@ -63,7 +63,7 @@ for NUM in ` seq -f %03g 0 193 `; do --param_path transformer_wmt2014_de_en_${SUBWORD_ALGO}/average.params \ --src_lang ${TGT} \ --tgt_lang ${SRC} \ - --cfg transformer_nmt_base \ + --cfg transformer_base \ --src_tokenizer ${SUBWORD_ALGO} \ --tgt_tokenizer ${SUBWORD_ALGO} \ --src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ @@ -125,7 +125,7 @@ python train_transformer.py \ --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ --save_dir backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO} \ - --cfg transformer_nmt_base \ + --cfg transformer_base \ --lr 0.002 \ --batch_size 2700 \ --max_update 60000 \ @@ -145,7 +145,7 @@ python evaluate_transformer.py \ --param_path backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO}/average.params \ --src_lang ${SRC} \ --tgt_lang ${TGT} \ - --cfg transformer_nmt_base \ + --cfg transformer_base \ --src_tokenizer ${SUBWORD_ALGO} \ --tgt_tokenizer ${SUBWORD_ALGO} \ --src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \ diff 
--git a/src/gluonnlp/layers.py b/src/gluonnlp/layers.py index a6ea6b181e..d68e99d8ca 100644 --- a/src/gluonnlp/layers.py +++ b/src/gluonnlp/layers.py @@ -356,7 +356,8 @@ def __init__(self, mode='erf'): def hybrid_forward(self, F, x): if self._mode == 'erf': - return F.npx.leaky_relu(x, act_type='gelu') + # TODO Investigate the precision of F.npx.leaky_relu(x, act_type='gelu') + return x * 0.5 * (1.0 + F.npx.erf(x / math.sqrt(2.0))) elif self._mode == 'tanh': return 0.5 * x\ * (1.0 + F.np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x ** 3)))) diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py new file mode 100644 index 0000000000..463b5b1037 --- /dev/null +++ b/src/gluonnlp/models/bart.py @@ -0,0 +1,386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +BART Model + +@article{lewis2019bart, + title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural +Language Generation, Translation, and Comprehension}, + author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and + Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov + and Luke Zettlemoyer }, + journal={arXiv preprint arXiv:1910.13461}, + year = {2019}, +} + +""" + +__all__ = ['BartModel', 'list_pretrained_bart', 'get_pretrained_bart'] + +import os +from typing import Tuple + +import mxnet as mx +from mxnet import use_np +from mxnet.gluon import nn + +from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, \ + get_model_zoo_checksum_dir +from ..registry import BACKBONE_REGISTRY +from ..utils.misc import download, load_checksum_stats +from .transformer import TransformerModel +from ..utils.config import CfgNode as CN +from ..utils.registry import Registry +from ..data.tokenizers import HuggingFaceByteBPETokenizer + +bart_cfg_reg = Registry('bart_cfg') + + +@bart_cfg_reg.register() +def bart_base(): + cfg = CN() + # Config for the bart base model + cfg.MODEL = CN() + cfg.MODEL.vocab_size = 51201 + cfg.MODEL.max_src_length = 1024 + cfg.MODEL.max_tgt_length = 1024 + cfg.MODEL.scale_embed = False + cfg.MODEL.pos_embed_type = 'learned' + cfg.MODEL.shared_embed = True + cfg.MODEL.tie_weights = True + cfg.MODEL.attention_dropout = 0.1 + cfg.MODEL.activation_dropout = 0.0 + cfg.MODEL.dropout = 0.1 + cfg.MODEL.layer_norm_eps = 1E-5 + cfg.MODEL.pooler_activation = 'tanh' + cfg.MODEL.data_norm = True + cfg.MODEL.layout = 'NT' + cfg.MODEL.dtype = 'float32' + + # Parameters for the encoder + cfg.MODEL.ENCODER = CN() + cfg.MODEL.ENCODER.num_layers = 6 + cfg.MODEL.ENCODER.units = 768 + cfg.MODEL.ENCODER.num_heads = 12 + cfg.MODEL.ENCODER.hidden_size = 3072 + cfg.MODEL.ENCODER.recurrent = False + cfg.MODEL.ENCODER.pre_norm = False + cfg.MODEL.ENCODER.activation = 'gelu' + cfg.MODEL.ENCODER.use_qkv_bias = True + + # Parameters for the decoder + 
cfg.MODEL.DECODER = CN() + cfg.MODEL.DECODER.num_layers = 6 + cfg.MODEL.DECODER.units = 768 + cfg.MODEL.DECODER.num_heads = 12 + cfg.MODEL.DECODER.hidden_size = 3072 + cfg.MODEL.DECODER.recurrent = False + cfg.MODEL.DECODER.pre_norm = False + cfg.MODEL.DECODER.activation = 'gelu' + cfg.MODEL.DECODER.use_qkv_bias = True + + # Parameters for the initializer + cfg.INITIALIZER = CN() + cfg.INITIALIZER.embed = ['xavier', 'gaussian', 'in', 1.0] + cfg.INITIALIZER.weight = ['xavier', 'uniform', 'avg', 1.0] + cfg.INITIALIZER.bias = ['zeros'] + cfg.VERSION = 1 + cfg.freeze() + return cfg + + +@bart_cfg_reg.register() +def bart_large(): + cfg = bart_base() + cfg.defrost() + cfg.MODEL.vocab_size = 50265 + cfg.MODEL.ENCODER.units = 1024 + cfg.MODEL.ENCODER.hidden_size = 4096 + cfg.MODEL.ENCODER.num_heads = 16 + cfg.MODEL.ENCODER.num_layers = 12 + cfg.MODEL.DECODER.units = 1024 + cfg.MODEL.DECODER.hidden_size = 4096 + cfg.MODEL.DECODER.num_heads = 16 + cfg.MODEL.DECODER.num_layers = 12 + cfg.freeze() + return cfg + + +PRETRAINED_URL = { + 'fairseq_bart_base': { + 'cfg': bart_base(), + 'merges': 'fairseq_bart_base/gpt2-396d4d8e.merges', + 'vocab': 'fairseq_bart_base/gpt2-f4dedacb.vocab', + 'params': 'fairseq_bart_base/model-8f4929b5.params', + 'lowercase': False, + }, + 'fairseq_bart_large': { + 'cfg': bart_large(), + 'merges': 'fairseq_bart_large/gpt2-396d4d8e.merges', + 'vocab': 'fairseq_bart_large/gpt2-f1335494.vocab', + 'params': 'fairseq_bart_large/model-862277b1.params', + 'lowercase': False, + } +} + + +FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'bart.txt')) + + +@use_np +class BartModel(TransformerModel): + def __init__(self, + use_pooler: bool = False, + classifier_activation: bool = False, + pooler_activation='tanh', + **kwargs): + """ + + Parameters + ---------- + use_pooler + classifier_activation + pooler_activation + **kwargs + """ + super().__init__(**kwargs) + assert self._src_vocab_size == self._tgt_vocab_size, \ + 'Vocab size mismatch between encoder and decoder' + self._vocab_size = self._src_vocab_size + self.use_pooler = use_pooler + self.classifier_activation = classifier_activation + if not use_pooler: + if self.tie_weights: + self.tgt_final_layer = \ + nn.Dense(self._tgt_vocab_size, flatten=False, + use_bias=False, + dtype=self._dtype) + self.tgt_final_layer.weight = self.tgt_embed_layer.weight + else: + self.tgt_final_layer = \ + nn.Dense(self._tgt_vocab_size, + flatten=False, + weight_initializer=self.weight_initializer, + use_bias=False, + dtype=self._dtype) + elif classifier_activation: + # Construct pooler + self.pooler = nn.Dense(units=self.units, + in_units=self.units, + flatten=False, + activation=pooler_activation, + weight_initializer=self.weight_initializer, + bias_initializer=self.bias_initializer, + dtype=self._dtype) + + def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length): + """ + + Parameters + ---------- + F + src_data + - layout = 'NT' + Shape (batch_size, src_length) + - layout = 'TN' + Shape (src_length, batch_size) + src_valid_length + Shape (batch_size,) + tgt_data + - layout = 'NT' + Shape (batch_size, tgt_length) + - layout = 'TN' + Shape (tgt_length, batch_size) + tgt_valid_length + Shape (batch_size,) + + Returns + ------- + (contextual_embedding) + - layout = 'NT' + Shape (batch_size, tgt_length, units) + - layout = 'TN' + Shape (tgt_length, batch_size, units) + (pooled_output) + This is optional. 
Shape (batch_size, units) + (dec_out) + - layout = 'NT' + Shape (batch_size, tgt_length, tgt_vocab_size) + - layout = 'TN' + Shape (tgt_length, batch_size, tgt_vocab_size) + """ + enc_out = self.encode(F, src_data, src_valid_length) + contextual_embedding = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length) + if self.use_pooler: + pooled_output = self.apply_pooling(contextual_embedding) + return contextual_embedding, pooled_output + else: + dec_out = self.tgt_final_layer(contextual_embedding) + return dec_out + + def apply_pooling(self, sequence): + """Generate the representation given the inputs. + + This is used for pre-training or fine-tuning a mobile bert model. + Get the first token of the whole sequence which is [CLS] + + sequence: + Shape (batch_size, sequence_length, units) + return: + Shape (batch_size, units) + """ + if self._layout == 'NT': + outputs = sequence[:, 0, :] + elif self._layout == 'TN': + outputs = sequence[0, :, :] + else: + raise NotImplementedError + if self.classifier_activation: + return self.pooler(outputs) + else: + return outputs + + @property + def layout(self) -> str: + return self._layout + + @property + def vocab_size(self): + return self._vocab_size + + @classmethod + def get_cfg(cls, key=None): + if key is None: + return bart_base() + else: + return bart_cfg_reg.create(key) + + @classmethod + def from_cfg(cls, cfg, dtype=None, + use_pooler=False, + classifier_activation=False): + cfg = cls.get_cfg().clone_merge(cfg) + embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) + weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) + bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype + return cls(src_vocab_size=cfg.MODEL.vocab_size, + tgt_vocab_size=cfg.MODEL.vocab_size, + max_src_length=cfg.MODEL.max_src_length, + max_tgt_length=cfg.MODEL.max_tgt_length, + scale_embed=cfg.MODEL.scale_embed, + pos_embed_type=cfg.MODEL.pos_embed_type, + shared_embed=cfg.MODEL.shared_embed, + tie_weights=cfg.MODEL.tie_weights, + data_norm=cfg.MODEL.data_norm, + use_pooler=use_pooler, + classifier_activation=classifier_activation, + attention_dropout=cfg.MODEL.attention_dropout, + activation_dropout=cfg.MODEL.activation_dropout, + dropout=cfg.MODEL.dropout, + pooler_activation=cfg.MODEL.pooler_activation, + layer_norm_eps=cfg.MODEL.layer_norm_eps, + enc_num_layers=cfg.MODEL.ENCODER.num_layers, + enc_units=cfg.MODEL.ENCODER.units, + enc_num_heads=cfg.MODEL.ENCODER.num_heads, + enc_hidden_size=cfg.MODEL.ENCODER.hidden_size, + enc_recurrent=cfg.MODEL.ENCODER.recurrent, + enc_activation=cfg.MODEL.ENCODER.activation, + enc_pre_norm=cfg.MODEL.ENCODER.pre_norm, + dec_num_layers=cfg.MODEL.DECODER.num_layers, + dec_units=cfg.MODEL.DECODER.units, + dec_num_heads=cfg.MODEL.DECODER.num_heads, + dec_hidden_size=cfg.MODEL.DECODER.hidden_size, + dec_recurrent=cfg.MODEL.DECODER.recurrent, + dec_activation=cfg.MODEL.DECODER.activation, + dec_pre_norm=cfg.MODEL.DECODER.pre_norm, + layout=cfg.MODEL.layout, + embed_initializer=embed_initializer, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + dtype=dtype) + + +def list_pretrained_bart(): + return sorted(list(PRETRAINED_URL.keys())) + + +def get_pretrained_bart(model_name: str = 'fairseq_bart_base', + root: str = get_model_zoo_home_dir(), + load_backbone: bool = True) \ + -> Tuple[CN, HuggingFaceByteBPETokenizer, str]: + """Get the pretrained RoBERTa weights + + Parameters + ---------- + model_name + The name of the RoBERTa model. 
+ root + The downloading root + load_backbone + Whether to load the weights of the backbone network + Returns + ------- + cfg + Network configuration + tokenizer + The HuggingFaceByteBPETokenizer + params_path + Path to the parameters + """ + assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format( + model_name, list_pretrained_bart()) + cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None + merges_path = PRETRAINED_URL[model_name]['merges'] + vocab_path = PRETRAINED_URL[model_name]['vocab'] + params_path = PRETRAINED_URL[model_name]['params'] + + local_paths = dict() + download_jobs = [('vocab', vocab_path), ('merges', merges_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for k, path in download_jobs: + local_paths[k] = download(url=get_repo_model_zoo_url() + path, + path=os.path.join(root, path), + sha1_hash=FILE_STATS[path]) + if load_backbone: + local_params_path = download(url=get_repo_model_zoo_url() + params_path, + path=os.path.join(root, params_path), + sha1_hash=FILE_STATS[params_path]) + else: + local_params_path = None + + local_mlm_params_path = None + do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\ + and PRETRAINED_URL[model_name]['lowercase'] else False + tokenizer = HuggingFaceByteBPETokenizer( + merges_file=local_paths['merges'], + vocab_file=local_paths['vocab'], + lowercase=do_lower) + if cfg is None: + cfg = BartModel.get_cfg().clone_merge(local_paths['cfg']) + return cfg, tokenizer, local_params_path, local_mlm_params_path + + +BACKBONE_REGISTRY.register('bart', [BartModel, + get_pretrained_bart, + list_pretrained_bart]) diff --git a/src/gluonnlp/models/model_zoo_checksums/bart.txt b/src/gluonnlp/models/model_zoo_checksums/bart.txt new file mode 100644 index 0000000000..75e61f9ef8 --- /dev/null +++ b/src/gluonnlp/models/model_zoo_checksums/bart.txt @@ -0,0 +1,8 @@ +fairseq_bart_base/model-8f4929b5.params 8f4929b54f2f77619885cea9f3bd7dba51a27f38 560560748 +fairseq_bart_base/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318 +fairseq_bart_base/model-251bf089.yml 251bf08944d18cc29b59a4a854bdbccf601dabb5 754 +fairseq_bart_base/gpt2-f4dedacb.vocab f4dedacb076b1df441c9c7398ed9acd3c19865f3 575079 +fairseq_bart_large/model-862277b1.params 862277b1489ed95140cb63279fbd0098ef2dea90 1625180962 +fairseq_bart_large/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318 +fairseq_bart_large/model-a2932dea.yml a2932deaf9737d95891755841fae3e388f3d698a 746 +fairseq_bart_large/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231 diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py index da18447f07..8e30f1048a 100644 --- a/src/gluonnlp/models/transformer.py +++ b/src/gluonnlp/models/transformer.py @@ -12,13 +12,13 @@ from ..sequence_sampler import BaseStepDecoder __all__ = ['TransformerEncoderLayer', 'TransformerDecoderLayer', 'TransformerEncoder', 'TransformerDecoder', - 'TransformerNMTModel', 'TransformerNMTInference'] + 'TransformerModel', 'TransformerNMTInference'] -transformer_nmt_cfg_reg = Registry('transformer_nmt_cfg') +transformer_cfg_reg = Registry('transformer_cfg') -@transformer_nmt_cfg_reg.register() -def transformer_nmt_base(): +@transformer_cfg_reg.register() +def transformer_base(): """Configuration of Transformer WMT EN-DE Base""" cfg = CN() cfg.MODEL = CN() @@ -45,6 +45,7 @@ def transformer_nmt_base(): cfg.MODEL.ENCODER.recurrent = False cfg.MODEL.ENCODER.activation = 
'relu' cfg.MODEL.ENCODER.pre_norm = False + cfg.MODEL.ENCODER.use_qkv_bias = True # Parameters for the decoder cfg.MODEL.DECODER = CN() @@ -55,6 +56,7 @@ def transformer_nmt_base(): cfg.MODEL.DECODER.recurrent = False cfg.MODEL.DECODER.activation = 'relu' cfg.MODEL.DECODER.pre_norm = False + cfg.MODEL.DECODER.use_qkv_bias = False # Parameters for the initializer cfg.INITIALIZER = CN() @@ -66,9 +68,9 @@ def transformer_nmt_base(): return cfg -@transformer_nmt_cfg_reg.register() -def transformer_nmt_base_prenorm(): - cfg = transformer_nmt_base() +@transformer_cfg_reg.register() +def transformer_base_prenorm(): + cfg = transformer_base() cfg.defrost() cfg.MODEL.ENCODER.pre_norm = True cfg.MODEL.DECODER.pre_norm = True @@ -76,9 +78,9 @@ def transformer_nmt_base_prenorm(): return cfg -@transformer_nmt_cfg_reg.register() +@transformer_cfg_reg.register() def transformer_iwslt_de_en(): - cfg = TransformerNMTModel.get_cfg() + cfg = TransformerModel.get_cfg() cfg.defrost() cfg.MODEL.ENCODER.units = 512 cfg.MODEL.ENCODER.hidden_size = 1024 @@ -92,10 +94,10 @@ def transformer_iwslt_de_en(): return cfg -@transformer_nmt_cfg_reg.register() +@transformer_cfg_reg.register() def transformer_wmt_en_de_big(): """Same wmt_en_de_big architecture as in FairSeq""" - cfg = TransformerNMTModel.get_cfg() + cfg = TransformerModel.get_cfg() cfg.defrost() cfg.MODEL.attention_dropout = 0.1 cfg.MODEL.dropout = 0.3 @@ -111,7 +113,7 @@ def transformer_wmt_en_de_big(): return cfg -@transformer_nmt_cfg_reg.register() +@transformer_cfg_reg.register() def transformer_wmt_en_de_big_t2t(): """Parameter used in the T2T implementation""" cfg = transformer_wmt_en_de_big() @@ -161,6 +163,7 @@ def __init__(self, data -> attn -> norm(res(+data)) -> ffn use_qkv_bias + Wether to use bias for self attention weight_initializer bias_initializer activation @@ -196,7 +199,7 @@ def __init__(self, bias_initializer=bias_initializer, dtype=self._dtype) attention_layout = 'NTK' if self._layout == 'NT' else 'TNK' - self.attention_cell =\ + self.attention_cell = \ MultiHeadAttentionCell( query_units=self._units, num_heads=self._num_heads, @@ -261,12 +264,11 @@ def hybrid_forward(self, F, data, attn_mask): out = self.ffn(out) return out, attn_weight - @use_np class TransformerEncoder(HybridBlock): def __init__(self, num_layers=6, recurrent=False, units=512, hidden_size=2048, num_heads=8, - activation_dropout=0.0, dropout=0.1, + activation_dropout=0.0, dropout=0.1, use_qkv_bias=True, attention_dropout=0.1, layer_norm_eps=1E-5, data_norm=False, pre_norm=False, weight_initializer=None, bias_initializer='zeros', activation='relu', dtype='float32', layout='NT'): @@ -320,6 +322,7 @@ def __init__(self, num_layers=6, recurrent=False, hidden_dropout_prob=dropout, attention_dropout_prob=attention_dropout, activation_dropout_prob=activation_dropout, + use_qkv_bias=use_qkv_bias, layer_norm_eps=layer_norm_eps, weight_initializer=weight_initializer, bias_initializer=bias_initializer, @@ -385,6 +388,7 @@ def __init__(self, units: int = 512, layer_norm_eps: float = 1E-5, activation: str = 'relu', pre_norm: bool = False, + use_qkv_bias: bool = True, weight_initializer=None, bias_initializer='zeros', dtype='float32', @@ -406,6 +410,8 @@ def __init__(self, units: int = 512, activation pre_norm Whether to apply normalization before the attention layer + use_qkv_bias + Wether to use bias for both self attention and contextual attention weight_initializer bias_initializer dtype @@ -432,7 +438,7 @@ def __init__(self, units: int = 512, raise ValueError('In Transformer, 
units should be divided exactly by the number of ' 'heads. Received units={}, num_heads={}'.format(units, num_heads)) self.attn_in_qkv = nn.Dense(3 * units, in_units=units, - use_bias=False, + use_bias=use_qkv_bias, flatten=False, weight_initializer=weight_initializer, bias_initializer=bias_initializer, @@ -442,25 +448,25 @@ def __init__(self, units: int = 512, attention_dropout=self._attention_dropout, dtype=dtype, layout=attention_layout) - self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=False, + self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=True, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype) self.attn_inter_q = nn.Dense(units, in_units=units, - use_bias=False, + use_bias=use_qkv_bias, flatten=False, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype) self.attn_inter_k = nn.Dense(units, in_units=mem_units, - use_bias=False, + use_bias=use_qkv_bias, flatten=False, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype) self.attn_inter_v = nn.Dense(units, in_units=mem_units, - use_bias=False, + use_bias=use_qkv_bias, flatten=False, weight_initializer=weight_initializer, bias_initializer=bias_initializer, @@ -471,7 +477,7 @@ def __init__(self, units: int = 512, dtype=dtype, layout=attention_layout) self.proj_inter = nn.Dense(units=units, in_units=units, - flatten=False, use_bias=False, + flatten=False, use_bias=True, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype) @@ -484,6 +490,9 @@ def __init__(self, units: int = 512, hidden_size=hidden_size, dropout=dropout, activation_dropout=activation_dropout, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + layer_norm_eps=layer_norm_eps, activation=activation, pre_norm=pre_norm, dtype=dtype) @@ -543,10 +552,11 @@ def hybrid_forward(self, F, data, mem, self_causal_mask, mem_attn_mask): if self._pre_norm: data = self.ln_in(data) self_query, self_key, self_value = F.np.split(self.attn_in_qkv(data), 3, axis=-1) - out, _ = self.self_attention(F.npx.reshape(self_query, (-2, -2, self._num_heads, -1)), - F.npx.reshape(self_key, (-2, -2, self._num_heads, -1)), - F.npx.reshape(self_value, (-2, -2, self._num_heads, -1)), - self_causal_mask) + out, [_, self_attn_weight] = self.self_attention( + F.npx.reshape(self_query, (-2, -2, self._num_heads, -1)), + F.npx.reshape(self_key, (-2, -2, self._num_heads, -1)), + F.npx.reshape(self_value, (-2, -2, self._num_heads, -1)), + self_causal_mask) out = self.proj_in(out) out = self.dropout_layer(out) out = out + data @@ -556,13 +566,11 @@ def hybrid_forward(self, F, data, mem, self_causal_mask, mem_attn_mask): data = out if self._pre_norm: data = self.ln_inter(data) - out, _ = self.inter_attention(F.npx.reshape(self.attn_inter_q(data), - (-2, -2, self._num_heads, -1)), - F.npx.reshape(self.attn_inter_k(mem), - (-2, -2, self._num_heads, -1)), - F.npx.reshape(self.attn_inter_v(mem), - (-2, -2, self._num_heads, -1)), - mem_attn_mask) + out, [_, context_attn_weight] = self.inter_attention( + F.npx.reshape(self.attn_inter_q(data), (-2, -2, self._num_heads, -1)), + F.npx.reshape(self.attn_inter_k(mem), (-2, -2, self._num_heads, -1)), + F.npx.reshape(self.attn_inter_v(mem), (-2, -2, self._num_heads, -1)), + mem_attn_mask) out = self.proj_inter(out) out = self.dropout_layer(out) out = out + data @@ -674,7 +682,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma 
step_value = F.npx.reshape(step_value, (-2, -2, self._num_heads, -1)) new_key = F.np.concatenate([prev_key, step_key], axis=time_axis) new_value = F.np.concatenate([prev_value, step_value], axis=time_axis) - out, _ = self.self_attention(step_query, new_key, new_value, None) + out, [_, attn_weight] = self.self_attention(step_query, new_key, new_value, None) out = self.proj_in(out) out = self.dropout_layer(out) out = out + data @@ -705,8 +713,8 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma @use_np class TransformerDecoder(HybridBlock): def __init__(self, num_layers=6, recurrent=False, - units=512, mem_units=None, hidden_size=2048, - num_heads=8, max_shift=None, rel_pos_embed=False, activation_dropout=0.0, + units=512, mem_units=None, hidden_size=2048, use_qkv_bias=True, + num_heads=8, max_shift=None, activation_dropout=0.0, dropout=0.1, attention_dropout=0.1, layer_norm_eps=1E-5, data_norm=False, pre_norm=False, weight_initializer=None, bias_initializer=None, activation='relu', dtype='float32', @@ -718,7 +726,6 @@ def __init__(self, num_layers=6, recurrent=False, self.num_layers = num_layers self.recurrent = recurrent self.max_shift = max_shift - self.rel_pos_embed = rel_pos_embed self._data_norm = data_norm self._pre_norm = pre_norm self._layout = layout @@ -740,6 +747,7 @@ def __init__(self, num_layers=6, recurrent=False, hidden_size=hidden_size, num_heads=num_heads, activation_dropout=activation_dropout, + use_qkv_bias=use_qkv_bias, dropout=dropout, attention_dropout=attention_dropout, layer_norm_eps=layer_norm_eps, @@ -909,7 +917,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length): @use_np -class TransformerNMTModel(HybridBlock): +class TransformerModel(HybridBlock): def __init__(self, src_vocab_size: int, tgt_vocab_size: int, max_src_length: Optional[int] = None, @@ -930,6 +938,7 @@ def __init__(self, src_vocab_size: int, enc_recurrent: bool = False, enc_activation='relu', enc_pre_norm: bool = False, + enc_use_qkv_bias: bool = True, dec_units: int = 512, dec_hidden_size: int = 2048, dec_num_heads: int = 8, @@ -937,6 +946,7 @@ def __init__(self, src_vocab_size: int, dec_recurrent: bool = False, dec_activation='relu', dec_pre_norm: bool = False, + dec_use_qkv_bias: bool = True, embed_initializer=mx.init.Xavier('gaussian', 'in', 1), weight_initializer=mx.init.Xavier('uniform', 'avg', 3), bias_initializer='zeros', @@ -988,6 +998,8 @@ def __init__(self, src_vocab_size: int, Activation of the encoder layer enc_pre_norm Whether to add layer_norm before self-attention in the encoder + enc_use_qkv_bias + Wether to use bias for attention layer in the encoder dec_units Units of the decoder dec_hidden_size @@ -1002,6 +1014,8 @@ def __init__(self, src_vocab_size: int, Activation of the decoder layer dec_pre_norm Whether to add layer_norm before self-attention in the decoder + dec_use_qkv_bias + Wether to use bias for attention layer in the decoder embed_initializer Initializer of the embedding layer weight_initializer @@ -1017,17 +1031,20 @@ def __init__(self, src_vocab_size: int, assert src_vocab_size > 0 and tgt_vocab_size > 0,\ 'Cannot set "src_vocab_size" and "tgt_vocab_size" to negative numbers. ' \ 'Are you creating ' \ - 'the model with the config from TransformerNMTModel.get_cfg()? If that is ' \ + 'the model with the config from TransformerModel.get_cfg()? If that is ' \ 'the case, you will need to set the cfg.MODEL.src_vocab_size and ' \ 'cfg.MODEL.tgt_vocab_size manually before passing to ' \ - 'TransformerNMTModel.from_cfg().' 
+            'TransformerModel.from_cfg().'
         self._dtype = dtype
         self._src_vocab_size = src_vocab_size
         self._tgt_vocab_size = tgt_vocab_size
+        self.tie_weights = tie_weights
         self.pos_embed_type = pos_embed_type
         self.scaled_embed = scale_embed
         self.enc_units = enc_units
         self.dec_units = dec_units
+        self.weight_initializer = weight_initializer
+        self.bias_initializer = bias_initializer
         self._layout = layout
         assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
                                        'Only "TN" and "NT" are accepted!'.format(layout)
@@ -1064,6 +1081,7 @@ def __init__(self, src_vocab_size: int,
                 hidden_size=enc_hidden_size,
                 num_heads=enc_num_heads,
                 activation_dropout=activation_dropout,
+                use_qkv_bias=enc_use_qkv_bias,
                 dropout=dropout,
                 attention_dropout=attention_dropout,
                 layer_norm_eps=layer_norm_eps,
@@ -1081,6 +1099,7 @@ def __init__(self, src_vocab_size: int,
                 hidden_size=dec_hidden_size,
                 num_heads=dec_num_heads,
                 activation_dropout=activation_dropout,
+                use_qkv_bias=dec_use_qkv_bias,
                 dropout=dropout,
                 attention_dropout=attention_dropout,
                 layer_norm_eps=layer_norm_eps,
@@ -1092,7 +1111,7 @@ def __init__(self, src_vocab_size: int,
                 dtype=self._dtype,
                 layout=layout)
         if tie_weights:
-            self.tgt_final_layer =\
+            self.tgt_final_layer = \
                 nn.Dense(tgt_vocab_size, flatten=False,
                          bias_initializer=bias_initializer,
                          use_bias=False,
@@ -1154,6 +1173,7 @@ def encode(self, F, src_data, src_valid_length):
         else:
             src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer(
                 F.npx.arange_like(src_data, axis=0)), axis=1)
+
         enc_out = self.encoder(src_data, src_valid_length)
         return enc_out
 
@@ -1196,8 +1216,8 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
         else:
             tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
                 F.npx.arange_like(tgt_data, axis=0)), axis=1)
+
         dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
-        dec_out = self.tgt_final_layer(dec_out)
         return dec_out
 
     def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_length):
@@ -1231,15 +1251,16 @@ def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_leng
         """
         enc_out = self.encode(F, src_data, src_valid_length)
         dec_out = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length)
+        dec_out = self.tgt_final_layer(dec_out)
         return dec_out
 
     @classmethod
     def get_cfg(cls, key=None):
         if key is None:
             # Use Transformer WMT EN-DE Base
-            return transformer_nmt_base()
+            return transformer_base()
         else:
-            return transformer_nmt_cfg_reg.create(key)
+            return transformer_cfg_reg.create(key)
 
     @classmethod
     def from_cfg(cls, cfg, dtype=None):
@@ -1267,6 +1288,7 @@ def from_cfg(cls, cfg, dtype=None):
                    enc_recurrent=cfg.MODEL.ENCODER.recurrent,
                    enc_activation=cfg.MODEL.ENCODER.activation,
                    enc_pre_norm=cfg.MODEL.ENCODER.pre_norm,
+                   enc_use_qkv_bias=cfg.MODEL.ENCODER.use_qkv_bias,
                    dec_num_layers=cfg.MODEL.DECODER.num_layers,
                    dec_units=cfg.MODEL.DECODER.units,
                    dec_num_heads=cfg.MODEL.DECODER.num_heads,
@@ -1274,6 +1296,7 @@ def from_cfg(cls, cfg, dtype=None):
                    dec_recurrent=cfg.MODEL.DECODER.recurrent,
                    dec_activation=cfg.MODEL.DECODER.activation,
                    dec_pre_norm=cfg.MODEL.DECODER.pre_norm,
+                   dec_use_qkv_bias=cfg.MODEL.DECODER.use_qkv_bias,
                    layout=cfg.MODEL.layout,
                    embed_initializer=embed_initializer,
                    weight_initializer=weight_initializer,
@@ -1296,7 +1319,7 @@ def __init__(self, model):
     def initialize(self, **kwargs):
         # Manually disable the initialize
         raise NotImplementedError('You can not initialize a TransformerNMTFastInference Model! '
-                                  'The correct approach is to create a TransformerNMTModel and '
+                                  'The correct approach is to create a TransformerModel and '
                                   'then build the TransformerNMTInference with the given model.')
 
     @property
diff --git a/tests/test_models_bart.py b/tests/test_models_bart.py
new file mode 100644
index 0000000000..d6130b63fb
--- /dev/null
+++ b/tests/test_models_bart.py
@@ -0,0 +1,52 @@
+import pytest
+import numpy as np
+import mxnet as mx
+import tempfile
+from gluonnlp.models.bart import BartModel, \
+    list_pretrained_bart, get_pretrained_bart, bart_cfg_reg
+from gluonnlp.utils.testing import verify_nmt_model
+
+mx.npx.set_np()
+
+
+def test_list_pretrained_bart():
+    assert len(list_pretrained_bart()) > 0
+
+
+@pytest.mark.remote_required
+@pytest.mark.parametrize('model_name', list_pretrained_bart())
+def test_bart(model_name):
+    # test from pretrained
+    assert len(list_pretrained_bart()) > 0
+    with tempfile.TemporaryDirectory() as root:
+        cfg, tokenizer, params_path, _ =\
+            get_pretrained_bart(model_name, load_backbone=True, root=root)
+        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
+        # test standard bart encoder and decoder
+        bart_model = BartModel.from_cfg(cfg)
+        bart_model.load_parameters(params_path)
+        # test bart encoder and decoder with pooler
+        bart_model_with_pooler = BartModel.from_cfg(
+            cfg, use_pooler=True, classifier_activation=False)
+        bart_model_with_pooler.load_parameters(params_path)
+
+
+def test_bart_cfg_registry():
+    assert len(bart_cfg_reg.list_keys()) > 0
+
+@pytest.mark.parametrize('cfg_key', bart_cfg_reg.list_keys())
+def test_bart_cfg(cfg_key):
+    cfg = BartModel.get_cfg(cfg_key)
+    cfg.defrost()
+    cfg.MODEL.vocab_size = 32
+    cfg.freeze()
+    model = BartModel.from_cfg(cfg)
+    model.initialize()
+    model.hybridize()
+    cfg.defrost()
+    cfg.MODEL.layout = 'TN'
+    cfg.freeze()
+    model_tn = BartModel.from_cfg(cfg)
+    model_tn.share_parameters(model.collect_params())
+    model_tn.hybridize()
+    mx.npx.waitall()
diff --git a/tests/test_models_transformer.py b/tests/test_models_transformer.py
index e9b1cd6184..96cb60ee1d 100644
--- a/tests/test_models_transformer.py
+++ b/tests/test_models_transformer.py
@@ -3,8 +3,8 @@
 from numpy.testing import assert_allclose
 from gluonnlp.models.transformer import\
     TransformerEncoder, TransformerDecoder, \
-    TransformerNMTModel, TransformerNMTInference,\
-    transformer_nmt_cfg_reg
+    TransformerModel, TransformerNMTInference,\
+    transformer_cfg_reg
 from gluonnlp.attention_cell import gen_mem_attn_mask, gen_self_attn_mask
 from gluonnlp.utils.testing import verify_nmt_model, verify_nmt_inference
 mx.npx.set_np()
@@ -117,26 +117,26 @@ def test_transformer_nmt_model(train_hybridize, inference_hybridize,
         shared_embed = False
     else:
         shared_embed = True
-    model = TransformerNMTModel(src_vocab_size=src_vocab_size,
-                                tgt_vocab_size=tgt_vocab_size,
-                                max_src_length=src_seq_length,
-                                max_tgt_length=tgt_seq_length,
-                                enc_units=enc_units,
-                                enc_hidden_size=64,
-                                enc_num_heads=4,
-                                enc_num_layers=enc_num_layers,
-                                enc_pre_norm=enc_pre_norm,
-                                enc_recurrent=enc_recurrent,
-                                dec_units=dec_units,
-                                dec_hidden_size=64,
-                                dec_num_heads=4,
-                                dec_num_layers=dec_num_layers,
-                                dec_pre_norm=dec_pre_norm,
-                                dec_recurrent=dec_recurrent,
-                                shared_embed=shared_embed,
-                                tie_weights=tie_weights,
-                                dropout=0.0,
-                                layout=layout)
+    model = TransformerModel(src_vocab_size=src_vocab_size,
+                             tgt_vocab_size=tgt_vocab_size,
+                             max_src_length=src_seq_length,
+                             max_tgt_length=tgt_seq_length,
+                             enc_units=enc_units,
+                             enc_hidden_size=64,
+                             enc_num_heads=4,
+                             enc_num_layers=enc_num_layers,
+                             enc_pre_norm=enc_pre_norm,
+                             enc_recurrent=enc_recurrent,
+                             dec_units=dec_units,
+                             dec_hidden_size=64,
+                             dec_num_heads=4,
+                             dec_num_layers=dec_num_layers,
+                             dec_pre_norm=dec_pre_norm,
+                             dec_recurrent=dec_recurrent,
+                             shared_embed=shared_embed,
+                             tie_weights=tie_weights,
+                             dropout=0.0,
+                             layout=layout)
     inference_model = TransformerNMTInference(model=model)
     model.initialize()
    if train_hybridize:
@@ -148,23 +148,23 @@ def test_transformer_nmt_model(train_hybridize, inference_hybridize,
 
 
 def test_transformer_cfg_registry():
-    assert len(transformer_nmt_cfg_reg.list_keys()) > 0
+    assert len(transformer_cfg_reg.list_keys()) > 0
 
 
-@pytest.mark.parametrize('cfg_key', transformer_nmt_cfg_reg.list_keys())
+@pytest.mark.parametrize('cfg_key', transformer_cfg_reg.list_keys())
 def test_transformer_cfg(cfg_key):
-    cfg = TransformerNMTModel.get_cfg(cfg_key)
+    cfg = TransformerModel.get_cfg(cfg_key)
     cfg.defrost()
     cfg.MODEL.src_vocab_size = 32
     cfg.MODEL.tgt_vocab_size = 32
     cfg.freeze()
-    model = TransformerNMTModel.from_cfg(cfg)
+    model = TransformerModel.from_cfg(cfg)
     model.initialize()
     model.hybridize()
     cfg.defrost()
     cfg.MODEL.layout = 'TN'
     cfg.freeze()
-    model_tn = TransformerNMTModel.from_cfg(cfg)
+    model_tn = TransformerModel.from_cfg(cfg)
     model_tn.share_parameters(model.collect_params())
     model_tn.hybridize()
     mx.npx.waitall()
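
As a quick sanity check of the renamed API, the model can be built from its registered config in the same way the tests above do. This is a minimal sketch and not part of the patch: the vocabulary sizes are placeholder values, and it assumes the registered base config already carries `use_qkv_bias` entries under `MODEL.ENCODER` and `MODEL.DECODER`, which is what `from_cfg()` reads after this change.

```python
import mxnet as mx
from gluonnlp.models.transformer import TransformerModel

mx.npx.set_np()

# Start from the registered base config (Transformer WMT EN-DE Base).
cfg = TransformerModel.get_cfg()
cfg.defrost()
# src/tgt vocab sizes must be set manually before calling from_cfg(),
# as the assertion message in TransformerModel.__init__ points out.
cfg.MODEL.src_vocab_size = 32000   # placeholder value
cfg.MODEL.tgt_vocab_size = 32000   # placeholder value
# Assumed config keys: the q/k/v bias switches introduced in this patch.
cfg.MODEL.ENCODER.use_qkv_bias = True
cfg.MODEL.DECODER.use_qkv_bias = False
cfg.freeze()

model = TransformerModel.from_cfg(cfg)
model.initialize()
model.hybridize()
```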