32 changes: 32 additions & 0 deletions docs/source/en/model_doc/dinov3.md
@@ -73,6 +73,33 @@ pooled_output = outputs.pooler_output
print("Pooled output shape:", pooled_output.shape)
```

</hfoption>
<hfoption id="AutoModelForImageClassification">

```py
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from transformers.image_utils import load_image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = load_image(url)

checkpoint = "dimidagd/dinov3-vit7b16-pretrain-lvd1689m-imagenet1k-lc"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    dtype=torch.bfloat16,
    device_map="auto",
)

inputs = processor(images=image, return_tensors="pt").to(model.device)
with torch.inference_mode():
    outputs = model(**inputs)

predicted_class_idx = outputs.logits.argmax(-1).item()
print(model.config.id2label[predicted_class_idx])
```

</hfoption>
</hfoptions>
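
For a quick check, the same checkpoint can also be run through the image-classification [`Pipeline`]. The following is a minimal sketch, assuming the community checkpoint above; `dtype` and `device_map` are optional and only shown to keep memory use manageable for the 7B backbone.

```py
import torch
from transformers import pipeline

classifier = pipeline(
    task="image-classification",
    model="dimidagd/dinov3-vit7b16-pretrain-lvd1689m-imagenet1k-lc",
    dtype=torch.bfloat16,
    device_map="auto",
)
print(classifier("http://images.cocodataset.org/val2017/000000039769.jpg"))
```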

@@ -173,6 +200,11 @@ print("Pooled output shape:", pooled_output.shape)

[[autodoc]] DINOv3ViTBackbone

## DINOv3ViTForImageClassification

[[autodoc]] DINOv3ViTForImageClassification
- forward

## DINOv3ConvNextModel

[[autodoc]] DINOv3ConvNextModel
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -835,6 +835,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("dinat", "DinatForImageClassification"),
("dinov2", "Dinov2ForImageClassification"),
("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
("dinov3_vit", "DINOv3ViTForImageClassification"),
("donut-swin", "DonutSwinForImageClassification"),
("efficientnet", "EfficientNetForImageClassification"),
("focalnet", "FocalNetForImageClassification"),
62 changes: 60 additions & 2 deletions src/transformers/models/dinov3_vit/modeling_dinov3_vit.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...backbone_utils import BackboneMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutputWithPooling
from ...modeling_outputs import BackboneOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import compile_compatible_method_lru_cache
@@ -611,4 +611,62 @@ def forward(
return output


__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel", "DINOv3ViTBackbone"]
@auto_docstring(
    custom_intro="""
    DINOv3ViT Model transformer with an image classification head on top (a linear layer on top of the concatenation of
    the final hidden state of the [CLS] token and the mean of the patch token hidden states), e.g. for ImageNet.
    """
)
class DINOv3ViTForImageClassification(DINOv3ViTPreTrainedModel):
    def __init__(self, config: DINOv3ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov3_vit = DINOv3ViTModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.dinov3_vit.embeddings.patch_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs: BaseModelOutputWithPooling = self.dinov3_vit(pixel_values, **kwargs)

        sequence_output = outputs.last_hidden_state  # batch_size, sequence_length, hidden_size
        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config, **kwargs)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel", "DINOv3ViTBackbone", "DINOv3ViTForImageClassification"]
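
Note that the new head does not classify from the `[CLS]` token alone: following the code above, it concatenates the `[CLS]` embedding with the mean of the patch-token embeddings (register tokens excluded), so the classifier input width is `2 * hidden_size`. A self-contained sketch of that pooling step, with made-up shapes for illustration:

```py
import torch

batch_size, hidden_size, num_register_tokens, num_patches = 2, 16, 4, 9

# last_hidden_state layout: [CLS] token, register tokens, then patch tokens
last_hidden_state = torch.randn(batch_size, 1 + num_register_tokens + num_patches, hidden_size)

cls_token = last_hidden_state[:, 0]                               # (batch, hidden)
patch_tokens = last_hidden_state[:, 1 + num_register_tokens :]    # (batch, num_patches, hidden)
pooled = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)  # (batch, 2 * hidden)

classifier = torch.nn.Linear(2 * hidden_size, 1000)  # e.g. 1000 ImageNet-1k labels
logits = classifier(pooled)
print(logits.shape)  # torch.Size([2, 1000])
```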
62 changes: 60 additions & 2 deletions src/transformers/models/dinov3_vit/modular_dinov3_vit.py
@@ -34,7 +34,7 @@
from ... import initialization as init
from ...backbone_utils import BackboneMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutputWithPooling
from ...modeling_outputs import BackboneOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...pytorch_utils import compile_compatible_method_lru_cache
@@ -507,4 +507,62 @@ def forward(
return output


__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel", "DINOv3ViTBackbone"]
@auto_docstring(
    custom_intro="""
    DINOv3ViT Model transformer with an image classification head on top (a linear layer on top of the concatenation of
    the final hidden state of the [CLS] token and the mean of the patch token hidden states), e.g. for ImageNet.
    """
)
class DINOv3ViTForImageClassification(DINOv3ViTPreTrainedModel):
    def __init__(self, config: DINOv3ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov3_vit = DINOv3ViTModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.dinov3_vit.embeddings.patch_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs: BaseModelOutputWithPooling = self.dinov3_vit(pixel_values, **kwargs)

        sequence_output = outputs.last_hidden_state  # batch_size, sequence_length, hidden_size
        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config, **kwargs)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel", "DINOv3ViTBackbone", "DINOv3ViTForImageClassification"]
67 changes: 64 additions & 3 deletions tests/models/dinov3_vit/test_modeling_dinov3_vit.py
@@ -17,7 +17,13 @@
from functools import cached_property

from transformers import DINOv3ViTConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
    require_torch,
    require_torch_large_accelerator,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
@@ -29,7 +35,7 @@
import torch
from torch import nn

from transformers import DINOv3ViTBackbone, DINOv3ViTModel
from transformers import DINOv3ViTBackbone, DINOv3ViTForImageClassification, DINOv3ViTModel


if is_vision_available():
@@ -169,6 +175,25 @@ def create_and_check_model(self, config, pixel_values, labels):
(self.batch_size, self.seq_length, self.hidden_size),
)

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        config.num_labels = self.type_sequence_label_size
        torch_device_override = "cpu"  # Run on CPU, otherwise VRAM may not be sufficient.
        config.device_map = torch_device_override
        model = DINOv3ViTForImageClassification(config)
        model.eval()
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

        # test greyscale images
        config.num_channels = 1

        model = DINOv3ViTForImageClassification(config)
        model.eval()

        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]).to(torch_device_override)
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -187,7 +212,9 @@ class Dinov3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""

all_model_classes = (DINOv3ViTModel, DINOv3ViTBackbone) if is_torch_available() else ()
all_model_classes = (
(DINOv3ViTModel, DINOv3ViTBackbone, DINOv3ViTForImageClassification) if is_torch_available() else ()
)
pipeline_model_mapping = (
{
"image-feature-extraction": DINOv3ViTModel,
@@ -224,6 +251,10 @@ def test_model_get_set_embeddings(self):
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@@ -256,6 +287,36 @@ def default_image_processor(self):
else None
)

    @require_torch_large_accelerator
    @slow
    def test_inference_lc_head_imagenet(self):
        torch_device_override = "cpu"
        model = DINOv3ViTForImageClassification.from_pretrained(
            "dimidagd/dinov3-vit7b16-pretrain-lvd1689m-imagenet1k-lc", device_map=torch_device_override
        )

        ground_truth_class_imagenet1 = "tabby, tabby cat"
        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(image, return_tensors="pt").to(torch_device_override)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # Verify logits
        expected_logits = torch.tensor([-1.0708860159, -0.7589257956, -1.1738269329, -0.9263097048, -1.0259437561]).to(
            torch_device_override
        )

        torch.testing.assert_close(outputs.logits[0, : len(expected_logits)], expected_logits, rtol=1e-4, atol=1e-4)

        # Test correct class prediction
        predicted_class_idx = outputs.logits.argmax(-1).item()
        predicted_class_str = model.config.id2label[predicted_class_idx]

        self.assertEqual(predicted_class_str, ground_truth_class_imagenet1)

@slow
def test_inference_no_head(self):
model = DINOv3ViTModel.from_pretrained("facebook/dinov3-vits16-pretrain-lvd1689m").to(torch_device)