Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
fix cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 29, 2020
1 parent 6c62a29 commit 5bab516
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/gluonnlp/models/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,19 @@ def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
model_name, list_pretrained_bart())
cfg_path = PRETRAINED_URL[model_name]['cfg']
if isinstance(cfg_path, CN):
cfg = cfg_path
else:
cfg = None
merges_path = PRETRAINED_URL[model_name]['merges']
vocab_path = PRETRAINED_URL[model_name]['vocab']
params_path = PRETRAINED_URL[model_name]['params']

local_paths = dict()
for k, path in [('cfg', cfg_path), ('vocab', vocab_path),
('merges', merges_path)]:
download_jobs = [('vocab', vocab_path), ('merges', merges_path)]
if cfg is None:
download_jobs.append(('cfg', cfg_path))
for k, path in download_jobs:
local_paths[k] = download(url=get_repo_model_zoo_url() + path,
path=os.path.join(root, path),
sha1_hash=FILE_STATS[path])
Expand All @@ -339,7 +345,8 @@ def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
merges_file=local_paths['merges'],
vocab_file=local_paths['vocab'],
lowercase=do_lower)
cfg = BartModel.get_cfg().clone_merge(local_paths['cfg'])
if cfg is None:
cfg = BartModel.get_cfg().clone_merge(local_paths['cfg'])
return cfg, tokenizer, local_params_path, local_mlm_params_path


Expand Down

0 comments on commit 5bab516

Please sign in to comment.