From 5bab5163b64bc3ed35069243dbcd1b5a634e0024 Mon Sep 17 00:00:00 2001 From: ZheyuYe Date: Wed, 29 Jul 2020 21:34:47 +0800 Subject: [PATCH] fix cfg --- src/gluonnlp/models/bart.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py index 3f2ae159c9..b6f7a7c32d 100644 --- a/src/gluonnlp/models/bart.py +++ b/src/gluonnlp/models/bart.py @@ -315,13 +315,19 @@ def get_pretrained_bart(model_name: str = 'fairseq_roberta_base', assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format( model_name, list_pretrained_bart()) cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None merges_path = PRETRAINED_URL[model_name]['merges'] vocab_path = PRETRAINED_URL[model_name]['vocab'] params_path = PRETRAINED_URL[model_name]['params'] local_paths = dict() - for k, path in [('cfg', cfg_path), ('vocab', vocab_path), - ('merges', merges_path)]: + download_jobs = [('vocab', vocab_path), ('merges', merges_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for k, path in download_jobs: local_paths[k] = download(url=get_repo_model_zoo_url() + path, path=os.path.join(root, path), sha1_hash=FILE_STATS[path]) @@ -339,7 +345,8 @@ def get_pretrained_bart(model_name: str = 'fairseq_roberta_base', merges_file=local_paths['merges'], vocab_file=local_paths['vocab'], lowercase=do_lower) - cfg = BartModel.get_cfg().clone_merge(local_paths['cfg']) + if cfg is None: + cfg = BartModel.get_cfg().clone_merge(local_paths['cfg']) return cfg, tokenizer, local_params_path, local_mlm_params_path