From 9a55c618039dd47fd6f30ea131db321e67b84173 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Sun, 22 Jan 2023 20:51:02 +0100 Subject: [PATCH 1/2] Extend conversion script --- .../models/git/convert_git_to_pytorch.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/git/convert_git_to_pytorch.py b/src/transformers/models/git/convert_git_to_pytorch.py index f072db7cd9a9..2f826e62572a 100644 --- a/src/transformers/models/git/convert_git_to_pytorch.py +++ b/src/transformers/models/git/convert_git_to_pytorch.py @@ -246,6 +246,11 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal "git-large-msrvtt-qa": ( "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_MSRVTT_QA/snapshot/model.pt" ), + "git-large-r": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R/snapshot/model.pt", + "git-large-r-coco": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_COCO/snapshot/model.pt", + "git-large-r-textcaps": ( + "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_TEXTCAPS/snapshot/model.pt" + ), } model_name_to_path = { @@ -258,7 +263,7 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal # define GIT configuration based on model name config, image_size, is_video = get_git_config(model_name) - if "large" in model_name and not is_video: + if "large" in model_name and not is_video and "large-r" not in model_name: # large checkpoints take way too long to download checkpoint_path = model_name_to_path[model_name] state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] @@ -320,6 +325,7 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal outputs = model(input_ids, pixel_values=pixel_values) logits = outputs.logits print("Logits:", logits[0, -1, :3]) + print("Model name:", model_name) if model_name == "git-base": expected_slice_logits = torch.tensor([-1.2832, -1.2835, -1.2840]) @@ -349,6 +355,12 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal expected_slice_logits = torch.tensor([-1.0113, -1.0114, -1.0113]) elif model_name == "git-large-msrvtt-qa": expected_slice_logits = torch.tensor([0.0130, 0.0134, 0.0131]) + elif model_name == "git-large-r": + expected_slice_logits = torch.tensor([-1.1283, -1.1285, -1.1286]) + elif model_name == "git-large-r-coco": + expected_slice_logits = torch.tensor([-0.9641, -0.9641, -0.9641]) + elif model_name == "git-large-r-textcaps": + expected_slice_logits = torch.tensor([-1.1121, -1.1120, -1.1124]) assert torch.allclose(logits[0, -1, :3], expected_slice_logits, atol=1e-4) print("Looks ok!") From 4f34e050339154b364d7a03e404a30c78c6d73df Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Sun, 22 Jan 2023 21:16:03 +0100 Subject: [PATCH 2/2] Remove print statement --- src/transformers/models/git/convert_git_to_pytorch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/git/convert_git_to_pytorch.py b/src/transformers/models/git/convert_git_to_pytorch.py index 2f826e62572a..e4ebf645ae7a 100644 --- a/src/transformers/models/git/convert_git_to_pytorch.py +++ b/src/transformers/models/git/convert_git_to_pytorch.py @@ -325,7 +325,6 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal outputs = model(input_ids, pixel_values=pixel_values) logits = outputs.logits print("Logits:", logits[0, -1, :3]) - print("Model name:", model_name) if model_name == "git-base": expected_slice_logits = torch.tensor([-1.2832, -1.2835, -1.2840])