diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py
index 3f2ae159c9..b6f7a7c32d 100644
--- a/src/gluonnlp/models/bart.py
+++ b/src/gluonnlp/models/bart.py
@@ -315,13 +315,19 @@ def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
     assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
         model_name, list_pretrained_bart())
     cfg_path = PRETRAINED_URL[model_name]['cfg']
+    if isinstance(cfg_path, CN):
+        cfg = cfg_path
+    else:
+        cfg = None
     merges_path = PRETRAINED_URL[model_name]['merges']
     vocab_path = PRETRAINED_URL[model_name]['vocab']
     params_path = PRETRAINED_URL[model_name]['params']
 
     local_paths = dict()
-    for k, path in [('cfg', cfg_path), ('vocab', vocab_path),
-                    ('merges', merges_path)]:
+    download_jobs = [('vocab', vocab_path), ('merges', merges_path)]
+    if cfg is None:
+        download_jobs.append(('cfg', cfg_path))
+    for k, path in download_jobs:
         local_paths[k] = download(url=get_repo_model_zoo_url() + path,
                                   path=os.path.join(root, path),
                                   sha1_hash=FILE_STATS[path])
@@ -339,7 +345,8 @@ def get_pretrained_bart(model_name: str = 'fairseq_roberta_base',
         merges_file=local_paths['merges'],
         vocab_file=local_paths['vocab'],
         lowercase=do_lower)
-    cfg = BartModel.get_cfg().clone_merge(local_paths['cfg'])
+    if cfg is None:
+        cfg = BartModel.get_cfg().clone_merge(local_paths['cfg'])
     return cfg, tokenizer, local_params_path, local_mlm_params_path
 
 