diff --git a/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py index d31cf67c1e3f..c5919b94d42f 100644 --- a/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py @@ -18,7 +18,7 @@ import torch -from transformers import BartConfig, BartForConditionalGeneration +from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration from transformers.utils import logging @@ -81,8 +81,8 @@ def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_ """ model = torch.load(checkpoint_path, map_location="cpu") sd = model["model"] - cfg = BartConfig.from_json_file(config_json_path) - m = BartForConditionalGeneration(cfg) + cfg = BlenderbotConfig.from_json_file(config_json_path) + m = BlenderbotForConditionalGeneration(cfg) valid_keys = m.model.state_dict().keys() failures = [] mapping = {}