diff --git a/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
index 77247495520c..8973c48c909f 100755
--- a/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -113,7 +113,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
         fsmt_folder_path, checkpoint_file, data_name_or_path, archive_map=models, **kwargs
     )
 
-    args = dict(vars(chkpt["args"]))
+    args = vars(chkpt["args"]["model"])
 
     src_lang = args["source_lang"]
     tgt_lang = args["target_lang"]
@@ -129,7 +129,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     src_vocab = rewrite_dict_keys(src_dict.indices)
     src_vocab_size = len(src_vocab)
     src_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-src.json")
-    print(f"Generating {src_vocab_file}")
+    print(f"Generating {src_vocab_file} of {src_vocab_size} of {src_lang} records")
     with open(src_vocab_file, "w", encoding="utf-8") as f:
         f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
 
@@ -137,13 +137,16 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
     tgt_vocab_size = len(tgt_vocab)
     tgt_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-tgt.json")
-    print(f"Generating {tgt_vocab_file}")
+    print(f"Generating {tgt_vocab_file} of {tgt_vocab_size} of {tgt_lang} records")
     with open(tgt_vocab_file, "w", encoding="utf-8") as f:
         f.write(json.dumps(tgt_vocab, ensure_ascii=False, indent=json_indent))
 
     # merges_file (bpecodes)
     merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
-    fsmt_merges_file = os.path.join(fsmt_folder_path, "bpecodes")
+    for fn in ["bpecodes", "code"]:  # older fairseq called the merges file "code"
+        fsmt_merges_file = os.path.join(fsmt_folder_path, fn)
+        if os.path.exists(fsmt_merges_file):
+            break
     with open(fsmt_merges_file, encoding="utf-8") as fin:
         merges = fin.read()
     merges = re.sub(r" \d+$", "", merges, 0, re.M)  # remove frequency number
@@ -248,10 +251,6 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     print("\nLast step is to upload the files to s3")
     print(f"cd {data_root}")
    print(f"transformers-cli upload {model_dir}")
-    print(
-        "Note: CDN caches files for up to 24h, so either use a local model path "
-        "or use `from_pretrained(mname, use_cdn=False)` to use the non-cached version."
-    )
 
 
 if __name__ == "__main__":
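
Note on the merges-file hunk above: per the inline comment in the diff, older fairseq releases named the BPE merges file "code" while newer ones use "bpecodes", so the conversion now probes both names. A minimal standalone sketch of that fallback, using a hypothetical helper name find_merges_file that is not part of the conversion script:

import os


def find_merges_file(fsmt_folder_path):
    # Try the newer fairseq filename first, then the older one ("code").
    for fn in ("bpecodes", "code"):
        candidate = os.path.join(fsmt_folder_path, fn)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"no merges file (bpecodes/code) found in {fsmt_folder_path}")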