Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@
- local: model_doc/dinat
title: DiNAT
- local: model_doc/dinov2
title: DINO V2
title: DINOV2
- local: model_doc/dit
title: DiT
- local: model_doc/dpt
Expand Down
63 changes: 53 additions & 10 deletions src/transformers/models/dinov2/convert_dinov2_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@


import argparse
import json
from pathlib import Path

import requests
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from transformers import BitImageProcessor, Dinov2Config, Dinov2Model
from transformers import BitImageProcessor, Dinov2Config, Dinov2ForImageClassification, Dinov2Model
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
from transformers.utils import logging

Expand All @@ -35,7 +38,7 @@
logger = logging.get_logger(__name__)


def get_dinov2_config(model_name):
def get_dinov2_config(model_name, image_classifier=False):
config = Dinov2Config(image_size=518, patch_size=14)

# size of the architecture
Expand All @@ -56,6 +59,13 @@ def get_dinov2_config(model_name):
else:
raise ValueError("Model not supported")

if image_classifier:
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
config.num_labels = 1000
config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
config.id2label = {int(k): v for k, v in config.id2label.items()}

return config


Expand Down Expand Up @@ -140,10 +150,11 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
"""

# define default Dinov2 configuration
config = get_dinov2_config(model_name)
image_classifier = "1layer" in model_name
config = get_dinov2_config(model_name, image_classifier=image_classifier)

# load original model from torch hub
original_model = torch.hub.load("facebookresearch/dinov2", model_name)
original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", ""))
original_model.eval()

# load state_dict of original model, remove and rename some keys
Expand All @@ -162,8 +173,22 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
state_dict[key] = val

# load HuggingFace model
model = Dinov2Model(config, add_pooling_layer=False).eval()
model.load_state_dict(state_dict)
if image_classifier:
model = Dinov2ForImageClassification(config).eval()
model.dinov2.load_state_dict(state_dict)
model_name_to_classifier_dict_url = {
"dinov2_vits14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth",
"dinov2_vitb14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth",
"dinov2_vitl14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth",
"dinov2_vitg14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth",
}
url = model_name_to_classifier_dict_url[model_name]
classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
model.classifier.bias = nn.Parameter(classifier_state_dict["bias"])
else:
model = Dinov2Model(config).eval()
model.load_state_dict(state_dict)

# load image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
Expand Down Expand Up @@ -195,12 +220,17 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
assert torch.allclose(original_pixel_values, pixel_values)

with torch.no_grad():
outputs = model(pixel_values)
outputs = model(pixel_values, output_hidden_states=True)
original_outputs = original_model(pixel_values)

# assert values
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
if image_classifier:
print("Predicted class:")
class_idx = outputs.logits.argmax(-1).item()
print(model.config.id2label[class_idx])
else:
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
print("Looks ok!")

if pytorch_dump_folder_path is not None:
Expand All @@ -216,6 +246,10 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
"dinov2_vitb14": "dinov2-base",
"dinov2_vitl14": "dinov2-large",
"dinov2_vitg14": "dinov2-giant",
"dinov2_vits14_1layer": "dinov2-small-imagenet1k-1-layer",
"dinov2_vitb14_1layer": "dinov2-base-imagenet1k-1-layer",
"dinov2_vitl14_1layer": "dinov2-large-imagenet1k-1-layer",
"dinov2_vitg14_1layer": "dinov2-giant-imagenet1k-1-layer",
}

name = model_name_to_hf_name[model_name]
Expand All @@ -230,7 +264,16 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
"--model_name",
default="dinov2_vitb14",
type=str,
choices=["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"],
choices=[
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vitg14",
"dinov2_vits14_1layer",
"dinov2_vitb14_1layer",
"dinov2_vitl14_1layer",
"dinov2_vitg14_1layer",
],
help="Name of the model you'd like to convert.",
)
parser.add_argument(
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/models/dinov2/modeling_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
Expand Down Expand Up @@ -693,6 +694,7 @@ def __init__(self, config: Dinov2Config) -> None:
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
Expand Down