From c9046a3bafa171e630e88bcaa46bede02a04ce8c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 29 Mar 2022 11:31:02 +0200 Subject: [PATCH] use blenderbot config and model --- ...ert_blenderbot_original_pytorch_checkpoint_to_pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py index d31cf67c1e3f..c5919b94d42f 100644 --- a/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py @@ -18,7 +18,7 @@ import torch -from transformers import BartConfig, BartForConditionalGeneration +from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration from transformers.utils import logging @@ -81,8 +81,8 @@ def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_ """ model = torch.load(checkpoint_path, map_location="cpu") sd = model["model"] - cfg = BartConfig.from_json_file(config_json_path) - m = BartForConditionalGeneration(cfg) + cfg = BlenderbotConfig.from_json_file(config_json_path) + m = BlenderbotForConditionalGeneration(cfg) valid_keys = m.model.state_dict().keys() failures = [] mapping = {}