Skip to content

Commit

Permalink
fix get_pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 13, 2020
1 parent 8ed8a72 commit cd0509d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
12 changes: 6 additions & 6 deletions scripts/conversion_toolkits/convert_fairseq_xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ def convert_fairseq_model(args):

ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
for is_mlm in [False, True]:
gluon_xlmr = convert_params(fairseq_roberta,
gluon_cfg,
ctx,
is_mlm=is_mlm,
gluon_prefix='roberta_')
gluon_xlmr = convert_params(fairseq_xlmr,
gluon_cfg,
ctx,
is_mlm=is_mlm,
gluon_prefix='roberta_')

if is_mlm:
if args.test:
test_model(fairseq_roberta, gluon_xlmr, args.gpu)
test_model(fairseq_xlmr, gluon_xlmr, args.gpu)

gluon_xlmr.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True)
logging.info('Convert the RoBERTa MLM model in {} to {}'.
Expand Down
19 changes: 15 additions & 4 deletions src/gluonnlp/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
'cfg': 'fairseq_roberta_base/model-565d1db7.yml',
'merges': 'fairseq_roberta_base/gpt2-396d4d8e.merges',
'vocab': 'fairseq_roberta_base/gpt2-f1335494.vocab',
'params': 'fairseq_roberta_base/model-09a1520a.params'
'params': 'fairseq_roberta_base/model-09a1520a.params',
'mlm_params': 'google_uncased_mobilebert/model_mlm-29889e2b.params',
},
'fairseq_roberta_large': {
Expand Down Expand Up @@ -495,7 +495,8 @@ def list_pretrained_roberta():

def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
root: str = get_model_zoo_home_dir(),
load_backbone: bool = True) \
load_backbone: bool = True,
load_mlm: bool = False) \
-> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
"""Get the pretrained RoBERTa weights
Expand All @@ -507,6 +508,8 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
The downloading root
load_backbone
Whether to load the weights of the backbone network
load_mlm
Whether to load the weights of MLM
Returns
-------
Expand All @@ -516,13 +519,16 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
The HuggingFaceByteBPETokenizer
params_path
Path to the parameters
mlm_params_path
Path to the parameter that includes both the backbone and the MLM
"""
assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
model_name, list_pretrained_roberta())
cfg_path = PRETRAINED_URL[model_name]['cfg']
merges_path = PRETRAINED_URL[model_name]['merges']
vocab_path = PRETRAINED_URL[model_name]['vocab']
params_path = PRETRAINED_URL[model_name]['params']
mlm_params_path = PRETRAINED_URL[model_name]['mlm_params']
local_paths = dict()
for k, path in [('cfg', cfg_path), ('vocab', vocab_path),
('merges', merges_path)]:
Expand All @@ -535,10 +541,15 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
sha1_hash=FILE_STATS[params_path])
else:
local_params_path = None

if load_mlm:
local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
path=os.path.join(root, mlm_params_path),
sha1_hash=FILE_STATS[mlm_params_path])
else:
local_mlm_params_path = None
tokenizer = HuggingFaceByteBPETokenizer(local_paths['merges'], local_paths['vocab'])
cfg = RobertaModel.get_cfg().clone_merge(local_paths['cfg'])
return cfg, tokenizer, local_params_path
return cfg, tokenizer, local_params_path, local_mlm_params_path


BACKBONE_REGISTRY.register('roberta', [RobertaModel,
Expand Down
22 changes: 17 additions & 5 deletions src/gluonnlp/models/xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from typing import Tuple
import os
from mxnet import use_np
from .roberta import RobertaModel, RobertaForMLM roberta_base, roberta_large
from .roberta import RobertaModel, RobertaForMLM, roberta_base, roberta_large
from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir
from ..utils.config import CfgNode as CN
from ..utils.registry import Registry
Expand Down Expand Up @@ -84,15 +84,16 @@ def get_cfg(key=None):
return xlmr_base()
@use_np
class XLMRForMLM(RobertaForMLM):
super().__init__()
pass

def list_pretrained_xlmr():
return sorted(list(PRETRAINED_URL.keys()))


def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
root: str = get_model_zoo_home_dir(),
load_backbone: bool = True) \
load_backbone: bool = True,
load_mlm: bool = False) \
-> Tuple[CN, SentencepieceTokenizer, str]:
"""Get the pretrained XLM-R weights
Expand All @@ -104,21 +105,26 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
The downloading root
load_backbone
Whether to load the weights of the backbone network
load_mlm
Whether to load the weights of MLM
Returns
-------
cfg
Network configuration
tokenizer
The SentencepieceTokenizer
The HuggingFaceByteBPETokenizer
params_path
Path to the parameters
mlm_params_path
Path to the parameter that includes both the backbone and the MLM
"""
assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
model_name, list_pretrained_xlmr())
cfg_path = PRETRAINED_URL[model_name]['cfg']
sp_model_path = PRETRAINED_URL[model_name]['sentencepiece.model']
params_path = PRETRAINED_URL[model_name]['params']
mlm_params_path = PRETRAINED_URL[model_name]['mlm_params']
local_paths = dict()
for k, path in [('cfg', cfg_path), ('sentencepiece.model', sp_model_path)]:
local_paths[k] = download(url=get_repo_model_zoo_url() + path,
Expand All @@ -130,10 +136,16 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
sha1_hash=FILE_STATS[params_path])
else:
local_params_path = None
if load_mlm:
local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
path=os.path.join(root, mlm_params_path),
sha1_hash=FILE_STATS[mlm_params_path])
else:
local_mlm_params_path = None

tokenizer = SentencepieceTokenizer(local_paths['sentencepiece.model'])
cfg = XLMRModel.get_cfg().clone_merge(local_paths['cfg'])
return cfg, tokenizer, local_params_path
return cfg, tokenizer, local_params_path, local_mlm_params_path


BACKBONE_REGISTRY.register('xlmr', [XLMRModel,
Expand Down

0 comments on commit cd0509d

Please sign in to comment.