Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/auto.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ The following auto classes are available for the following computer vision tasks

[[autodoc]] AutoModelForInstanceSegmentation

### AutoModelForUniversalSegmentation

[[autodoc]] AutoModelForUniversalSegmentation

### AutoModelForZeroShotObjectDetection

[[autodoc]] AutoModelForZeroShotObjectDetection
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,7 @@
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
Expand Down Expand Up @@ -974,6 +975,7 @@
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForUniversalSegmentation",
"AutoModelForVideoClassification",
"AutoModelForVision2Seq",
"AutoModelForVisualQuestionAnswering",
Expand Down Expand Up @@ -4112,6 +4114,7 @@
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
Expand Down Expand Up @@ -4143,6 +4146,7 @@
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForUniversalSegmentation,
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
Expand Down Expand Up @@ -97,6 +98,7 @@
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForUniversalSegmentation",
"AutoModelForVideoClassification",
"AutoModelForVision2Seq",
"AutoModelForVisualQuestionAnswering",
Expand Down Expand Up @@ -222,6 +224,7 @@
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
Expand Down Expand Up @@ -253,6 +256,7 @@
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForUniversalSegmentation,
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
Expand Down
20 changes: 20 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,18 @@
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Instance Segmentation mapping
# MaskFormerForInstanceSegmentation can be removed from this mapping in v5
("maskformer", "MaskFormerForInstanceSegmentation"),
]
)

MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Universal Segmentation mapping
("maskformer", "MaskFormerForUniversalSegmentation"),
]
)

MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("timesformer", "TimesformerForVideoClassification"),
Expand Down Expand Up @@ -891,6 +899,9 @@
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
Expand Down Expand Up @@ -1082,6 +1093,15 @@ class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
)


class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


AutoModelForUniversalSegmentation = auto_class_update(
AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
)


class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/maskformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
else:
_import_structure["modeling_maskformer"] = [
"MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"MaskFormerForInstanceSegmentation",
"MaskFormerForUniversalSegmentation",
"MaskFormerModel",
"MaskFormerPreTrainedModel",
]
Expand Down Expand Up @@ -73,7 +73,7 @@
else:
from .modeling_maskformer import (
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
MaskFormerForInstanceSegmentation,
MaskFormerForUniversalSegmentation,
MaskFormerModel,
MaskFormerPreTrainedModel,
)
Expand Down
31 changes: 21 additions & 10 deletions src/transformers/models/maskformer/modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import math
import random
import warnings
from dataclasses import dataclass
from numbers import Number
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -188,9 +189,9 @@ class MaskFormerModelOutput(ModelOutput):


@dataclass
class MaskFormerForInstanceSegmentationOutput(ModelOutput):
class MaskFormerForUniversalSegmentationOutput(ModelOutput):
"""
Class for outputs of [`MaskFormerForInstanceSegmentation`].
Class for outputs of [`MaskFormerForUniversalSegmentation`].

This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or or
[`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
Expand Down Expand Up @@ -1633,7 +1634,7 @@ def forward(
return output


class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
class MaskFormerForUniversalSegmentation(MaskFormerPreTrainedModel):
def __init__(self, config: MaskFormerConfig):
super().__init__(config)
self.model = MaskFormerModel(config)
Expand Down Expand Up @@ -1715,7 +1716,7 @@ def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, Di
return class_queries_logits, masks_queries_logits, auxiliary_logits

@add_start_docstrings_to_model_forward(MASKFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskFormerForInstanceSegmentationOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=MaskFormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
Expand All @@ -1726,7 +1727,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> MaskFormerForInstanceSegmentationOutput:
) -> MaskFormerForUniversalSegmentationOutput:
r"""
mask_labels (`List[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
Expand All @@ -1741,13 +1742,13 @@ def forward(
Semantic segmentation example:

```python
>>> from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation
>>> from transformers import MaskFormerImageProcessor, MaskFormerForUniversalSegmentation
>>> from PIL import Image
>>> import requests

>>> # load MaskFormer fine-tuned on ADE20k semantic segmentation
>>> image_processor = MaskFormerImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade")
>>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
>>> model = MaskFormerForUniversalSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")

>>> url = (
... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
Expand All @@ -1774,13 +1775,13 @@ def forward(
Panoptic segmentation example:

```python
>>> from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation
>>> from transformers import MaskFormerImageProcessor, MaskFormerForUniversalSegmentation
>>> from PIL import Image
>>> import requests

>>> # load MaskFormer fine-tuned on COCO panoptic segmentation
>>> image_processor = MaskFormerImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
>>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-coco")
>>> model = MaskFormerForUniversalSegmentation.from_pretrained("facebook/maskformer-swin-base-coco")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
Expand Down Expand Up @@ -1832,7 +1833,7 @@ def forward(
if not output_auxiliary_logits:
auxiliary_logits = None

output = MaskFormerForInstanceSegmentationOutput(
output = MaskFormerForUniversalSegmentationOutput(
loss=loss,
**outputs,
class_queries_logits=class_queries_logits,
Expand All @@ -1845,3 +1846,13 @@ def forward(
if loss is not None:
output = ((loss)) + output
return output


class MaskFormerForInstanceSegmentation(MaskFormerForUniversalSegmentation):
def __init__(self, *args, **kwargs) -> None:
warnings.warn(
"The class MaskFormerForInstanceSegmentation is deprecated and will be removed in version 5 of"
" Transformers. Please use MaskFormerForUniversalSegmentation instead.",
FutureWarning,
)
super().__init__(*args, **kwargs)
2 changes: 2 additions & 0 deletions src/transformers/pipelines/image_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
)


Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(self, *args, **kwargs):
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.items()
)
)

Expand Down
10 changes: 10 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def __init__(self, *args, **kwargs):
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None


MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = None


MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = None


Expand Down Expand Up @@ -639,6 +642,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AutoModelForUniversalSegmentation(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AutoModelForVideoClassification(metaclass=DummyObject):
_backends = ["torch"]

Expand Down