diff --git a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
index c6169c0ff5e7..bcf155ad7257 100644
--- a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
+++ b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
@@ -300,6 +300,10 @@ def main():
     if args.config_file == "":
         # Default config of megatron-bert 345m
        config = MegatronBertConfig()
+
+        # different megatron-bert-*-345m models have different vocab sizes, so override the default
+        # config (which is for megatron-bert-cased-345m) with the actual vocab dimension
+        config.vocab_size = input_state_dict["model"]["lm_head"]["bias"].numel()
     else:
         config = MegatronBertConfig.from_json_file(args.config_file)
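
As a side note, the vocab dimension the converter now reads can be checked independently before running the conversion. Below is a minimal sketch, assuming a raw Megatron-BERT checkpoint laid out as in the diff above; the path `checkpoint.pt` is only a placeholder.

```python
# Minimal sketch (assumed checkpoint layout): verify the vocab size baked into a
# raw Megatron-BERT checkpoint, e.g. to confirm it matches your tokenizer.
import torch

# "checkpoint.pt" is a hypothetical path to the raw Megatron checkpoint file.
input_state_dict = torch.load("checkpoint.pt", map_location="cpu")

# The LM head bias has one entry per vocabulary token, so its element count
# equals the (possibly padded) vocab size the checkpoint was trained with.
vocab_size = input_state_dict["model"]["lm_head"]["bias"].numel()
print(f"checkpoint vocab size: {vocab_size}")
```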