diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 9d1c33900c10..e49c5ba31b04 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -27,151 +27,150 @@
title: Generation with LLMs
title: Tutorials
- sections:
- - sections:
- - local: tasks/sequence_classification
- title: Text classification
- - local: tasks/token_classification
- title: Token classification
- - local: tasks/question_answering
- title: Question answering
- - local: tasks/language_modeling
- title: Causal language modeling
- - local: tasks/masked_language_modeling
- title: Masked language modeling
- - local: tasks/translation
- title: Translation
- - local: tasks/summarization
- title: Summarization
- - local: tasks/multiple_choice
- title: Multiple choice
+ - isExpanded: false
+ sections:
+ - local: tasks/sequence_classification
+ title: Text classification
+ - local: tasks/token_classification
+ title: Token classification
+ - local: tasks/question_answering
+ title: Question answering
+ - local: tasks/language_modeling
+ title: Causal language modeling
+ - local: tasks/masked_language_modeling
+ title: Masked language modeling
+ - local: tasks/translation
+ title: Translation
+ - local: tasks/summarization
+ title: Summarization
+ - local: tasks/multiple_choice
+ title: Multiple choice
title: Natural Language Processing
- isExpanded: false
- - sections:
- - local: tasks/audio_classification
- title: Audio classification
- - local: tasks/asr
- title: Automatic speech recognition
+ - isExpanded: false
+ sections:
+ - local: tasks/audio_classification
+ title: Audio classification
+ - local: tasks/asr
+ title: Automatic speech recognition
title: Audio
- isExpanded: false
- - sections:
- - local: tasks/image_classification
- title: Image classification
- - local: tasks/semantic_segmentation
- title: Semantic segmentation
- - local: tasks/video_classification
- title: Video classification
- - local: tasks/object_detection
- title: Object detection
- - local: tasks/zero_shot_object_detection
- title: Zero-shot object detection
- - local: tasks/zero_shot_image_classification
- title: Zero-shot image classification
- - local: tasks/monocular_depth_estimation
- title: Depth estimation
+ - isExpanded: false
+ sections:
+ - local: tasks/image_classification
+ title: Image classification
+ - local: tasks/semantic_segmentation
+ title: Semantic segmentation
+ - local: tasks/video_classification
+ title: Video classification
+ - local: tasks/object_detection
+ title: Object detection
+ - local: tasks/zero_shot_object_detection
+ title: Zero-shot object detection
+ - local: tasks/zero_shot_image_classification
+ title: Zero-shot image classification
+ - local: tasks/monocular_depth_estimation
+ title: Depth estimation
title: Computer Vision
- isExpanded: false
- - sections:
- - local: tasks/image_captioning
- title: Image captioning
- - local: tasks/document_question_answering
- title: Document Question Answering
- - local: tasks/visual_question_answering
- title: Visual Question Answering
- - local: tasks/text-to-speech
- title: Text to speech
+ - isExpanded: false
+ sections:
+ - local: tasks/image_captioning
+ title: Image captioning
+ - local: tasks/document_question_answering
+ title: Document Question Answering
+ - local: tasks/visual_question_answering
+ title: Visual Question Answering
+ - local: tasks/text-to-speech
+ title: Text to speech
title: Multimodal
- isExpanded: false
- - sections:
- - local: generation_strategies
- title: Customize the generation strategy
+ - isExpanded: false
+ sections:
+ - local: generation_strategies
+ title: Customize the generation strategy
title: Generation
- isExpanded: false
title: Task Guides
- sections:
- - local: fast_tokenizers
- title: Use fast tokenizers from 🤗 Tokenizers
- - local: multilingual
- title: Run inference with multilingual models
- - local: create_a_model
- title: Use model-specific APIs
- - local: custom_models
- title: Share a custom model
- - local: sagemaker
- title: Run training on Amazon SageMaker
- - local: serialization
- title: Export to ONNX
- - local: tflite
- title: Export to TFLite
- - local: torchscript
- title: Export to TorchScript
- - local: benchmarks
- title: Benchmarks
- - local: notebooks
- title: Notebooks with examples
- - local: community
- title: Community resources
- - local: custom_tools
- title: Custom Tools and Prompts
- - local: troubleshooting
- title: Troubleshoot
+ - local: fast_tokenizers
+ title: Use fast tokenizers from 🤗 Tokenizers
+ - local: multilingual
+ title: Run inference with multilingual models
+ - local: create_a_model
+ title: Use model-specific APIs
+ - local: custom_models
+ title: Share a custom model
+ - local: sagemaker
+ title: Run training on Amazon SageMaker
+ - local: serialization
+ title: Export to ONNX
+ - local: tflite
+ title: Export to TFLite
+ - local: torchscript
+ title: Export to TorchScript
+ - local: benchmarks
+ title: Benchmarks
+ - local: notebooks
+ title: Notebooks with examples
+ - local: community
+ title: Community resources
+ - local: custom_tools
+ title: Custom Tools and Prompts
+ - local: troubleshooting
+ title: Troubleshoot
title: Developer guides
- sections:
- - local: performance
- title: Overview
- - sections:
- - local: perf_train_gpu_one
- title: Methods and tools for efficient training on a single GPU
- - local: perf_train_gpu_many
- title: Multiple GPUs and parallelism
- - local: perf_train_cpu
- title: Efficient training on CPU
- - local: perf_train_cpu_many
- title: Distributed CPU training
- - local: perf_train_tpu
- title: Training on TPUs
- - local: perf_train_tpu_tf
- title: Training on TPU with TensorFlow
- - local: perf_train_special
- title: Training on Specialized Hardware
- - local: perf_hardware
- title: Custom hardware for training
- - local: hpo_train
- title: Hyperparameter Search using Trainer API
- title: Efficient training techniques
- - sections:
- - local: perf_infer_cpu
- title: Inference on CPU
- - local: perf_infer_gpu_one
- title: Inference on one GPU
- - local: perf_infer_gpu_many
- title: Inference on many GPUs
- - local: perf_infer_special
- title: Inference on Specialized Hardware
- title: Optimizing inference
- - local: big_models
- title: Instantiating a big model
- - local: debugging
- title: Troubleshooting
- - local: tf_xla
- title: XLA Integration for TensorFlow Models
- - local: perf_torch_compile
- title: Optimize inference using `torch.compile()`
+ - local: performance
+ title: Overview
+ - sections:
+ - local: perf_train_gpu_one
+ title: Methods and tools for efficient training on a single GPU
+ - local: perf_train_gpu_many
+ title: Multiple GPUs and parallelism
+ - local: perf_train_cpu
+ title: Efficient training on CPU
+ - local: perf_train_cpu_many
+ title: Distributed CPU training
+ - local: perf_train_tpu
+ title: Training on TPUs
+ - local: perf_train_tpu_tf
+ title: Training on TPU with TensorFlow
+ - local: perf_train_special
+ title: Training on Specialized Hardware
+ - local: perf_hardware
+ title: Custom hardware for training
+ - local: hpo_train
+ title: Hyperparameter Search using Trainer API
+ title: Efficient training techniques
+ - sections:
+ - local: perf_infer_cpu
+ title: Inference on CPU
+ - local: perf_infer_gpu_one
+ title: Inference on one GPU
+ - local: perf_infer_gpu_many
+ title: Inference on many GPUs
+ - local: perf_infer_special
+ title: Inference on Specialized Hardware
+ title: Optimizing inference
+ - local: big_models
+ title: Instantiating a big model
+ - local: debugging
+ title: Troubleshooting
+ - local: tf_xla
+ title: XLA Integration for TensorFlow Models
+ - local: perf_torch_compile
+ title: Optimize inference using `torch.compile()`
title: Performance and scalability
- sections:
- - local: contributing
- title: How to contribute to transformers?
- - local: add_new_model
- title: How to add a model to 🤗 Transformers?
- - local: add_tensorflow_model
- title: How to convert a 🤗 Transformers model to TensorFlow?
- - local: add_new_pipeline
- title: How to add a pipeline to 🤗 Transformers?
- - local: testing
- title: Testing
- - local: pr_checks
- title: Checks on a Pull Request
+ - local: contributing
+ title: How to contribute to transformers?
+ - local: add_new_model
+ title: How to add a model to 🤗 Transformers?
+ - local: add_tensorflow_model
+ title: How to convert a 🤗 Transformers model to TensorFlow?
+ - local: add_new_pipeline
+ title: How to add a pipeline to 🤗 Transformers?
+ - local: testing
+ title: Testing
+ - local: pr_checks
+ title: Checks on a Pull Request
title: Contribute
-
- sections:
- local: philosophy
title: Philosophy
@@ -533,6 +532,8 @@
title: ResNet
- local: model_doc/segformer
title: SegFormer
+ - local: model_doc/superglue
+ title: SuperGlue
- local: model_doc/swiftformer
title: SwiftFormer
- local: model_doc/swin
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 695e51fbe293..dd68c08aa3b9 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -123,6 +123,7 @@
],
"models": [],
# Models
+    "models.superglue": ["SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP", "SuperGlueConfig"],
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
"models.align": [
"ALIGN_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -478,6 +479,7 @@
"models.regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"],
"models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
"models.resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"],
+ "models.superglue": ["SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP", "SuperGlueConfig"],
"models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
"models.roberta_prelayernorm": ["ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaPreLayerNormConfig"],
"models.roc_bert": ["ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoCBertConfig", "RoCBertTokenizer"],
@@ -789,6 +791,7 @@
]
else:
# Fast tokenizers structure
+ _import_structure["models.superglue"].append("SuperGlueTokenizerFast")
_import_structure["models.albert"].append("AlbertTokenizerFast")
_import_structure["models.bart"].append("BartTokenizerFast")
_import_structure["models.barthez"].append("BarthezTokenizerFast")
@@ -1027,6 +1030,22 @@
_import_structure["modeling_utils"] = ["PreTrainedModel"]
# PyTorch models structure
+
+ _import_structure["models.superglue"].extend(
+ [
+ "SUPERGLUE_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "SuperGlueForMaskedLM",
+ "SuperGlueForCausalLM",
+ "SuperGlueForMultipleChoice",
+ "SuperGlueForQuestionAnswering",
+ "SuperGlueForSequenceClassification",
+ "SuperGlueForTokenClassification",
+ "SuperGlueLayer",
+ "SuperGlueModel",
+ "SuperGluePreTrainedModel",
+ "load_tf_weights_in_superglue",
+ ]
+ )
_import_structure["models.albert"].extend(
[
"ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2505,6 +2524,15 @@
"ResNetPreTrainedModel",
]
)
+ _import_structure["models.superglue"].extend(
+ [
+ "SUPERGLUE_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "SuperGlueBackbone",
+ "SuperGlueForImageClassification",
+ "SuperGlueModel",
+ "SuperGluePreTrainedModel",
+ ]
+ )
_import_structure["models.roberta"].extend(
[
"ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4129,6 +4157,7 @@
load_tf2_weights_in_pytorch_model,
)
from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
+    from .models.superglue import SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP, SuperGlueConfig
from .models.align import (
ALIGN_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlignConfig,
@@ -4462,6 +4491,7 @@
from .models.regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from .models.resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
+ from .models.superglue import SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP, SuperGlueConfig
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
from .models.roberta_prelayernorm import (
ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -4749,6 +4779,7 @@
from .utils.dummy_tokenizers_objects import *
else:
# Fast tokenizers imports
+ from .models.superglue import SuperGlueTokenizerFast
from .models.albert import AlbertTokenizerFast
from .models.bart import BartTokenizerFast
from .models.barthez import BarthezTokenizerFast
@@ -4948,6 +4979,20 @@
from .modeling_utils import PreTrainedModel
# PyTorch model imports
+
+ from .models.superglue import (
+ SUPERGLUE_PRETRAINED_MODEL_ARCHIVE_LIST,
+ SuperGlueForMaskedLM,
+ SuperGlueForCausalLM,
+ SuperGlueForMultipleChoice,
+ SuperGlueForQuestionAnswering,
+ SuperGlueForSequenceClassification,
+ SuperGlueForTokenClassification,
+ SuperGlueLayer,
+ SuperGlueModel,
+ SuperGluePreTrainedModel,
+ load_tf_weights_in_superglue,
+ )
from .models.albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
AlbertForMaskedLM,
@@ -6162,6 +6207,13 @@
ResNetModel,
ResNetPreTrainedModel,
)
+ from .models.superglue import (
+ SUPERGLUE_PRETRAINED_MODEL_ARCHIVE_LIST,
+ SuperGlueBackbone,
+ SuperGlueForImageClassification,
+ SuperGlueModel,
+ SuperGluePreTrainedModel,
+ )
from .models.roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
RobertaForCausalLM,
diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py
index aceec7abd406..86856268bca0 100755
--- a/src/transformers/modeling_outputs.py
+++ b/src/transformers/modeling_outputs.py
@@ -1660,3 +1660,14 @@ def logits(self):
FutureWarning,
)
return self.reconstruction
+
+@dataclass
+class ImageMatchingOutput(ModelOutput):
+ """
+    Outputs of image matching models (e.g. SuperGlue): predicted keypoint matches and their matching scores for each of the two input images.
+ """
+
+ image0_matches: torch.FloatTensor = None
+ image1_matches: torch.FloatTensor = None
+ image0_matching_scores: torch.FloatTensor = None
+    image1_matching_scores: torch.FloatTensor = None
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 7af9ff766aed..accf79a8cfdf 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from . import (
+ superglue,
albert,
align,
altclip,
@@ -165,6 +166,7 @@
regnet,
rembert,
resnet,
+ superglue,
roberta,
roberta_prelayernorm,
roc_bert,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 7230c3f1fa19..04173788cece 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -30,6 +30,7 @@
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
+ ("superglue", "SuperGlueConfig"),
("albert", "AlbertConfig"),
("align", "AlignConfig"),
("altclip", "AltCLIPConfig"),
@@ -170,6 +171,7 @@
("regnet", "RegNetConfig"),
("rembert", "RemBertConfig"),
("resnet", "ResNetConfig"),
+ ("superglue", "SuperGlueConfig"),
("retribert", "RetriBertConfig"),
("roberta", "RobertaConfig"),
("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
@@ -236,6 +238,7 @@
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here)
+ ("superglue", "SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("align", "ALIGN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("altclip", "ALTCLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -366,6 +369,7 @@
("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("superglue", "SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta-prelayernorm", "ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -422,6 +426,7 @@
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
+ ("superglue", "SuperGlue"),
("albert", "ALBERT"),
("align", "ALIGN"),
("altclip", "AltCLIP"),
@@ -584,6 +589,7 @@
("regnet", "RegNet"),
("rembert", "RemBERT"),
("resnet", "ResNet"),
+ ("superglue", "SuperGlue"),
("retribert", "RetriBERT"),
("roberta", "RoBERTa"),
("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 9b3ab2b1705b..a59057029ffa 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -75,6 +75,7 @@
("poolformer", "PoolFormerFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
+ ("superglue", "ConvNextFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("sew", "Wav2Vec2FeatureExtractor"),
("sew-d", "Wav2Vec2FeatureExtractor"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 075fe0c96db0..dfa63fb9e121 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -89,6 +89,7 @@
("pvt", "PvtImageProcessor"),
("regnet", "ConvNextImageProcessor"),
("resnet", "ConvNextImageProcessor"),
+ ("superglue", "ConvNextImageProcessor"),
("sam", "SamImageProcessor"),
("segformer", "SegformerImageProcessor"),
("swiftformer", "ViTImageProcessor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index aec9eacc2a7a..8eb481b71a12 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -28,6 +28,7 @@
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
+ ("superglue", "SuperGlueModel"),
("albert", "AlbertModel"),
("align", "AlignModel"),
("altclip", "AltCLIPModel"),
@@ -162,6 +163,7 @@
("regnet", "RegNetModel"),
("rembert", "RemBertModel"),
("resnet", "ResNetModel"),
+ ("superglue", "SuperGlueModel"),
("retribert", "RetriBertModel"),
("roberta", "RobertaModel"),
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
@@ -290,6 +292,7 @@
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
+        ("superglue", "SuperGlueForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("bert", "BertForMaskedLM"),
@@ -372,6 +375,7 @@
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
+ ("superglue", "SuperGlueForCausalLM"),
("bart", "BartForCausalLM"),
("bert", "BertLMHeadModel"),
("bert-generation", "BertGenerationDecoder"),
@@ -490,6 +494,7 @@
("pvt", "PvtForImageClassification"),
("regnet", "RegNetForImageClassification"),
("resnet", "ResNetForImageClassification"),
+ ("superglue", "SuperGlueForImageClassification"),
("segformer", "SegformerForImageClassification"),
("swiftformer", "SwiftFormerForImageClassification"),
("swin", "SwinForImageClassification"),
@@ -563,6 +568,7 @@
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
+        ("superglue", "SuperGlueForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("bert", "BertForMaskedLM"),
@@ -678,6 +684,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
+ ("superglue", "SuperGlueForSequenceClassification"),
("albert", "AlbertForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("bert", "BertForSequenceClassification"),
@@ -757,6 +764,7 @@
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
+ ("superglue", "SuperGlueForQuestionAnswering"),
("albert", "AlbertForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
@@ -846,6 +854,7 @@
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
+        ("superglue", "SuperGlueForTokenClassification"),
("albert", "AlbertForTokenClassification"),
("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
@@ -906,6 +915,7 @@
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
+        ("superglue", "SuperGlueForMultipleChoice"),
("albert", "AlbertForMultipleChoice"),
("bert", "BertForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
@@ -1037,6 +1047,7 @@
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),
+ ("superglue", "SuperGlueBackbone"),
("swin", "SwinBackbone"),
("timm_backbone", "TimmBackbone"),
]
diff --git a/src/transformers/models/superglue/__init__.py b/src/transformers/models/superglue/__init__.py
new file mode 100644
index 000000000000..d7db1e2c9944
--- /dev/null
+++ b/src/transformers/models/superglue/__init__.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_superglue": ["SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP", "SuperGlueConfig", "SuperGlueOnnxConfig"]
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_superglue"] = [
+ "SUPERGLUE_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "SuperGlueForImageClassification",
+ "SuperGlueModel",
+ "SuperGluePreTrainedModel",
+ "SuperGlueBackbone",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_superglue import SUPERGLUE_PRETRAINED_CONFIG_ARCHIVE_MAP, SuperGlueConfig, SuperGlueOnnxConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_superglue import (
+ SUPERGLUE_PRETRAINED_MODEL_ARCHIVE_LIST,
+ SuperGlueBackbone,
+ SuperGlueForImageClassification,
+ SuperGlueModel,
+ SuperGluePreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
diff --git a/src/transformers/models/superglue/configuration_superglue.py b/src/transformers/models/superglue/configuration_superglue.py
new file mode 100644
index 000000000000..99cc2d4ccfdf
--- /dev/null
+++ b/src/transformers/models/superglue/configuration_superglue.py
@@ -0,0 +1,34 @@
+from typing import List
+
+from transformers import PretrainedConfig
+
+
+class SuperGlueConfig(PretrainedConfig):
+
+ def __init__(
+ self,
+ descriptor_dim: int = 256,
+ keypoint_encoder_sizes: List[int] = [32, 64, 128, 256],
+        gnn_layers_types: List[str] = ["self", "cross"] * 9,
+ num_heads: int = 4,
+ sinkhorn_iterations: int = 100,
+ matching_threshold: float = 0.2,
+ model_version: str = "indoor",
+ **kwargs,
+ ):
+ # Check whether all gnn_layers_types are either 'self' or 'cross'
+        if not all(layer_type in ["self", "cross"] for layer_type in gnn_layers_types):
+            raise ValueError("All gnn_layers_types must be either 'self' or 'cross'")
+
+        if model_version not in ("indoor", "outdoor"):
+ raise ValueError("model_version must be either 'indoor' or 'outdoor'")
+
+ self.descriptor_dim = descriptor_dim
+ self.keypoint_encoder_sizes = keypoint_encoder_sizes
+ self.gnn_layers_types = gnn_layers_types
+ self.num_heads = num_heads
+ self.sinkhorn_iterations = sinkhorn_iterations
+ self.matching_threshold = matching_threshold
+ self.model_version = model_version
+
+ super().__init__(**kwargs)
diff --git a/src/transformers/models/superglue/convert_superglue_to_pytorch.py b/src/transformers/models/superglue/convert_superglue_to_pytorch.py
new file mode 100644
index 000000000000..dd318435aae7
--- /dev/null
+++ b/src/transformers/models/superglue/convert_superglue_to_pytorch.py
@@ -0,0 +1,121 @@
+import argparse
+
+import torch
+
+from transformers import SuperGlueConfig, SuperGlueModel
+
+
+def get_superglue_config(checkpoint_url):
+ config = SuperGlueConfig(
+ descriptor_dim=256,
+ keypoint_encoder_sizes=[32, 64, 128, 256],
+ gnn_layers_types=['self', 'cross'] * 9,
+ sinkhorn_iterations=100,
+ matching_threshold=0.2,
+ )
+
+ if "superglue_indoor" in checkpoint_url:
+ config.model_version = "indoor"
+ elif "superglue_outdoor" in checkpoint_url:
+ config.model_version = "outdoor"
+
+ return config
+
+
+def create_rename_keys(config, state_dict):
+ rename_keys = []
+
+ # keypoint encoder
+ n = len([3] + config.keypoint_encoder_sizes + [config.descriptor_dim])
+ for i in range(n * 2 + 1):
+ if ((i + 1) % 3) != 0:
+ rename_keys.append((f"kenc.encoder.{i}.weight", f"keypoint_encoder.encoder.layers.{i}.weight"))
+ rename_keys.append((f"kenc.encoder.{i}.bias", f"keypoint_encoder.encoder.layers.{i}.bias"))
+ if ((i % 3) - 1) == 0:
+ rename_keys.append((f"kenc.encoder.{i}.running_mean",
+ f"keypoint_encoder.encoder.layers.{i}.running_mean"))
+ rename_keys.append((f"kenc.encoder.{i}.running_var",
+ f"keypoint_encoder.encoder.layers.{i}.running_var"))
+ rename_keys.append((f"kenc.encoder.{i}.num_batches_tracked",
+ f"keypoint_encoder.encoder.layers.{i}.num_batches_tracked"))
+
+ # gnn
+ for i in range(len(config.gnn_layers_types)):
+ rename_keys.append((f"gnn.layers.{i}.attn.merge.weight", f"gnn.layers.{i}.attention.merge.weight"))
+ rename_keys.append((f"gnn.layers.{i}.attn.merge.bias", f"gnn.layers.{i}.attention.merge.bias"))
+ for j in range(3):
+ rename_keys.append((f"gnn.layers.{i}.attn.proj.{j}.weight", f"gnn.layers.{i}.attention.proj.{j}.weight"))
+ rename_keys.append((f"gnn.layers.{i}.attn.proj.{j}.bias", f"gnn.layers.{i}.attention.proj.{j}.bias"))
+ for j in range(len([config.descriptor_dim * 2, config.descriptor_dim * 2, config.descriptor_dim]) + 1):
+            if j != 2:
+ rename_keys.append((f"gnn.layers.{i}.mlp.{j}.weight", f"gnn.layers.{i}.mlp.layers.{j}.weight"))
+ rename_keys.append((f"gnn.layers.{i}.mlp.{j}.bias", f"gnn.layers.{i}.mlp.layers.{j}.bias"))
+ if j == 1:
+ rename_keys.append((f"gnn.layers.{i}.mlp.{j}.running_mean",
+ f"gnn.layers.{i}.mlp.layers.{j}.running_mean"))
+ rename_keys.append((f"gnn.layers.{i}.mlp.{j}.running_var",
+ f"gnn.layers.{i}.mlp.layers.{j}.running_var"))
+ rename_keys.append((f"gnn.layers.{i}.mlp.{j}.num_batches_tracked",
+ f"gnn.layers.{i}.mlp.layers.{j}.num_batches_tracked"))
+ return rename_keys
+
+
+# Copied from transformers.models.dinov2.convert_dinov2_to_hf
+def rename_key(dct, old, new):
+ val = dct.pop(old)
+ dct[new] = val
+
+
+@torch.no_grad()
+def convert_superglue_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub):
+ """
+    Convert an original MagicLeap SuperGlue checkpoint into a Hugging Face SuperGlueModel and verify that the renamed weights load.
+ """
+
+ print("Downloading original model from checkpoint...")
+ config = get_superglue_config(checkpoint_url)
+
+ original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
+ print(original_state_dict)
+
+ print("Converting model parameters...")
+ rename_keys = create_rename_keys(config, original_state_dict)
+ new_state_dict = original_state_dict.copy()
+ for src, dest in rename_keys:
+ rename_key(new_state_dict, src, dest)
+
+ for key in new_state_dict.copy().keys():
+ val = new_state_dict.pop(key)
+ if not key.startswith("superglue"):
+ key = "superglue." + key
+ new_state_dict[key] = val
+
+ model = SuperGlueModel(config)
+ model.load_state_dict(new_state_dict)
+ model.eval()
+ print("Successfully loaded weights in the model")
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--checkpoint_url",
+ default="https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superglue_indoor.pth",
+ type=str,
+ help="URL of the original SuperGlue checkpoint you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default="model",
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument("--save_model", action="store_true", help="Save model to local")
+ parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub")
+
+ args = parser.parse_args()
+ convert_superglue_checkpoint(
+ args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub
+ )
diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py
new file mode 100644
index 000000000000..d362759236d8
--- /dev/null
+++ b/src/transformers/models/superglue/modeling_superglue.py
@@ -0,0 +1,364 @@
+from copy import deepcopy
+from typing import List, Tuple, Optional, Union
+
+import torch
+from torch import nn, Tensor
+
+from transformers import PreTrainedModel
+from transformers.models.superglue.configuration_superglue import SuperGlueConfig
+from transformers.modeling_outputs import ImageMatchingOutput
+
+
+class SuperGlueMultiLayerPerceptron(nn.Module):
+    """
+    Channel-wise MLP built from 1x1 Conv1d layers, acting on tensors of shape
+    (batch, channels, num_points). Every layer except the last is followed by
+    an optional BatchNorm1d and a ReLU; the final Conv1d's bias is
+    zero-initialized.
+    """
+    def __init__(
+        self,
+        channels: List[int],
+        do_batch_norm: bool = True,
+    ):
+        super().__init__()
+        num_layers = len(channels)
+        layers = []
+        for i in range(1, num_layers):
+            layers.append(
+                nn.Conv1d(
+                    channels[i - 1],
+                    channels[i],
+                    kernel_size=1,
+                    bias=True
+                )
+            )
+            # No norm/activation after the final convolution.
+            if i < (num_layers - 1):
+                if do_batch_norm:
+                    layers.append(nn.BatchNorm1d(channels[i]))
+                layers.append(nn.ReLU())
+        self.layers = nn.Sequential(*layers)
+        # layers[-1] is always the last Conv1d (see the loop condition above).
+        nn.init.constant_(self.layers[-1].bias, 0.0)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self.layers(input)
+
+
+class SuperGlueKeypointEncoder(nn.Module):
+    """Embeds keypoint locations and detection scores into the descriptor space."""
+
+    def __init__(
+        self,
+        feature_dim: int = 256,
+        layers_sizes: List[int] = [32, 64, 128, 256],
+    ):
+        super().__init__()
+        # 3 input channels per keypoint: x, y and the detection score.
+        mlp_channels = [3] + layers_sizes + [feature_dim]
+        self.encoder = SuperGlueMultiLayerPerceptron(channels=mlp_channels)
+
+    def forward(self, keypoints: Tensor, scores: Tensor) -> Tensor:
+        # (batch, num_points, 2) -> (batch, 2, num_points); score as a third channel.
+        encoder_input = torch.cat(
+            [keypoints.transpose(1, 2), scores.unsqueeze(1)], dim=1
+        )
+        return self.encoder(encoder_input)
+
+
+class SuperGlueMultiHeadAttention(nn.Module):
+    """
+    Multi-head attention over descriptor maps of shape (batch, feature_dim, num_points).
+
+    Args:
+        feature_dim: total channel dimension; must be divisible by `num_heads`.
+        num_heads: number of attention heads.
+    """
+    def __init__(
+        self,
+        feature_dim: int,
+        num_heads: int
+    ):
+        super().__init__()
+        # Validate with an explicit exception instead of `assert`, which is
+        # silently stripped when Python runs with -O.
+        if feature_dim % num_heads != 0:
+            raise ValueError(f"feature_dim ({feature_dim}) must be divisible by num_heads ({num_heads})")
+        self.feature_dim = feature_dim
+        self.num_heads = num_heads
+        self.dim = feature_dim // num_heads
+        self.merge = nn.Conv1d(feature_dim, feature_dim, kernel_size=1)
+        # Independent q/k/v projections, cloned from `merge` (deep copies, so
+        # the weights are not shared after initialization).
+        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor
+    ) -> torch.Tensor:
+        """Project q/k/v, attend per head, then merge heads back to feature_dim channels."""
+        batch_dim = query.size(0)
+        # Split channels into (head_dim, num_heads) for per-head attention.
+        query, key, value = [
+            layer(x).view(batch_dim, self.dim, self.num_heads, -1)
+            for layer, x in zip(self.proj, (query, key, value))
+        ]
+        x, _ = self.attention(query, key, value)
+        output = self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
+        return output
+
+    def attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Scaled dot-product attention; returns (output, attention probabilities)."""
+        dim = query.shape[1]
+        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5
+        prob = torch.nn.functional.softmax(scores, dim=-1)
+        output = torch.einsum('bhnm,bdhm->bdhn', prob, value)
+        return output, prob
+
+
+class SuperGlueAttentionalPropagation(nn.Module):
+    """One GNN message-passing step: attend over `source`, then fuse with an MLP."""
+    def __init__(
+        self,
+        feature_dim: int,
+        num_heads: int
+    ):
+        super().__init__()
+        self.feature_dim = feature_dim
+        self.num_heads = num_heads
+        self.attention = SuperGlueMultiHeadAttention(
+            feature_dim=feature_dim,
+            num_heads=num_heads
+        )
+        # MLP input is [x ; attention message], hence feature_dim * 2 channels.
+        self.mlp = SuperGlueMultiLayerPerceptron(
+            [feature_dim * 2, feature_dim * 2, feature_dim]
+        )
+        # SuperGlueMultiLayerPerceptron already zero-initializes its final bias,
+        # so the redundant re-initialization that used to follow here is removed.
+
+    def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
+        """Compute the residual message for `x` attending over `source`."""
+        message = self.attention(x, source, source)
+        message = torch.cat([x, message], dim=1)
+        message = self.mlp(message)
+        return message
+
+
+class SuperGlueAttentionalGNN(nn.Module):
+    """Alternating self-/cross-attention GNN over the two images' descriptors."""
+    def __init__(
+        self,
+        feature_dim: int,
+        num_heads: int,
+        layers_types: List[str],
+    ):
+        super().__init__()
+        self.feature_dim = feature_dim
+        self.num_heads = num_heads
+        self.layers_types = layers_types
+        self.num_layers = len(self.layers_types)
+        self.layers = nn.ModuleList(
+            [
+                SuperGlueAttentionalPropagation(
+                    self.feature_dim,
+                    self.num_heads
+                )
+                for _ in range(self.num_layers)
+            ]
+        )
+
+    def forward(self, descriptors_0: torch.Tensor, descriptors_1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # `layer_type` instead of `type` to avoid shadowing the builtin.
+        for gnn_layer, layer_type in zip(self.layers, self.layers_types):
+            # 'cross' layers attend to the other image; 'self' layers to their own.
+            if layer_type == 'cross':
+                source_0, source_1 = descriptors_1, descriptors_0
+            else:  # layer_type == 'self'
+                source_0, source_1 = descriptors_0, descriptors_1
+            delta0, delta1 = gnn_layer(descriptors_0, source_0), gnn_layer(descriptors_1, source_1)
+            # Bug fix: this assignment previously wrote to a typo name
+            # `descriptor_0`, silently discarding the residual update of
+            # descriptors_0 at every layer.
+            descriptors_0, descriptors_1 = (descriptors_0 + delta0), (descriptors_1 + delta1)
+        return descriptors_0, descriptors_1
+
+
+class SuperGlue(nn.Module):
+    """SuperGlue feature matching middle-end
+
+    Given two sets of keypoints and locations, we determine the
+    correspondences by:
+      1. Keypoint Encoding (normalization + visual feature and location fusion)
+      2. Graph Neural Network with multiple self and cross-attention layers
+      3. Final projection layer
+      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
+      5. Thresholding matrix based on mutual exclusivity and a match_threshold
+
+    The correspondence ids use -1 to indicate non-matching points.
+
+    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
+    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
+    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
+
+    """
+
+    def __init__(
+        self,
+        descriptor_dim: int = 256,
+        keypoint_encoder_sizes: List[int] = [32, 64, 128, 256],
+        gnn_layers_types: List[str] = ['self', 'cross'] * 9,
+        num_heads: int = 4,
+        sinkhorn_iterations: int = 100,
+        matching_threshold: float = 0.2,
+    ):
+        super().__init__()
+
+        self.descriptor_dim = descriptor_dim
+        self.keypoint_encoder_sizes = keypoint_encoder_sizes
+        self.gnn_layers_types = gnn_layers_types
+        self.num_heads = num_heads
+        self.sinkhorn_iterations = sinkhorn_iterations
+        self.match_threshold = matching_threshold
+
+        self.keypoint_encoder = SuperGlueKeypointEncoder(
+            feature_dim=self.descriptor_dim,
+            layers_sizes=self.keypoint_encoder_sizes
+        )
+
+        self.gnn = SuperGlueAttentionalGNN(
+            feature_dim=self.descriptor_dim,
+            num_heads=self.num_heads,
+            layers_types=self.gnn_layers_types
+        )
+
+        self.final_proj = nn.Conv1d(
+            self.descriptor_dim, self.descriptor_dim,
+            kernel_size=1, bias=True)
+
+        # Learnable score of the "dustbin" row/column used by optimal transport
+        # to absorb unmatched keypoints.
+        bin_score = torch.nn.Parameter(torch.tensor(1.))
+        self.register_parameter('bin_score', bin_score)
+
+    def forward(
+        self,
+        keypoints_0: Tensor,
+        scores_0: Tensor,
+        descriptors_0: Tensor,
+        keypoints_1: Tensor,
+        scores_1: Tensor,
+        descriptors_1: Tensor
+    ):
+        """Run SuperGlue on a pair of keypoints and descriptors"""
+        if keypoints_0.shape[1] == 0 or keypoints_1.shape[1] == 0:  # no keypoints
+            shape0, shape1 = keypoints_0.shape[:-1], keypoints_1.shape[:-1]
+            return tuple([
+                keypoints_0.new_full(shape0, -1, dtype=torch.int),
+                keypoints_1.new_full(shape1, -1, dtype=torch.int),
+                keypoints_0.new_zeros(shape0),
+                keypoints_1.new_zeros(shape1)
+            ])
+
+        # Keypoint MLP encoder.
+        descriptors_0 = descriptors_0 + self.keypoint_encoder(keypoints_0, scores_0)
+        descriptors_1 = descriptors_1 + self.keypoint_encoder(keypoints_1, scores_1)
+
+        # Multi-layer Transformer network.
+        descriptors_0, descriptors_1 = self.gnn(descriptors_0, descriptors_1)
+
+        # Final MLP projection.
+        projected_descriptors_0, projected_descriptors_1 = self.final_proj(descriptors_0), self.final_proj(descriptors_1)
+
+        # Compute matching descriptor distance.
+        scores = torch.einsum('bdn,bdm->bnm', projected_descriptors_0, projected_descriptors_1)
+        scores = scores / self.descriptor_dim ** .5
+
+        # Run the optimal transport.
+        scores = self.log_optimal_transport(
+            scores,
+            self.bin_score,
+            iters=self.sinkhorn_iterations
+        )
+
+        # Get the matches with score above "match_threshold".
+        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
+        indices0, indices1 = max0.indices, max1.indices
+        # Keep only mutually-best matches (i -> j and j -> i agree).
+        mutual0 = self.arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
+        mutual1 = self.arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
+        zero = scores.new_tensor(0)
+        matching_scores_0 = torch.where(mutual0, max0.values.exp(), zero)
+        matching_scores_1 = torch.where(mutual1, matching_scores_0.gather(1, indices1), zero)
+        valid0 = mutual0 & (matching_scores_0 > self.match_threshold)
+        valid1 = mutual1 & valid0.gather(1, indices1)
+        matches_0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
+        matches_1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
+
+        return matches_0, matches_1, matching_scores_0, matching_scores_1
+
+    @staticmethod
+    def normalize_keypoints(
+        kpts: Tensor,
+        height: int,
+        width: int
+    ):
+        """ Normalize keypoints locations based on image image_shape"""
+        # NOTE(review): not called anywhere in this module — presumably invoked
+        # by callers before forward(); verify against the image processor.
+        one = kpts.new_tensor(1)
+        size = torch.stack([one * width, one * height])[None]
+        center = size / 2
+        scaling = size.max(1, keepdim=True).values * 0.7
+        return (kpts - center[:, None, :]) / scaling[:, None, :]
+
+    @staticmethod
+    def log_sinkhorn_iterations(
+        Z: torch.Tensor,
+        log_mu: torch.Tensor,
+        log_nu: torch.Tensor,
+        iters: int
+    ) -> torch.Tensor:
+        """ Perform Sinkhorn Normalization in Log-space for stability"""
+        u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
+        for _ in range(iters):
+            u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
+            v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
+        return Z + u.unsqueeze(2) + v.unsqueeze(1)
+
+    @staticmethod
+    def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
+        """ Perform Differentiable Optimal Transport in Log-space for stability"""
+        b, m, n = scores.shape
+        one = scores.new_tensor(1)
+        ms, ns = (m * one).to(scores), (n * one).to(scores)
+
+        # Augment the score matrix with dustbin row/column scored by alpha.
+        bins0 = alpha.expand(b, m, 1)
+        bins1 = alpha.expand(b, 1, n)
+        alpha = alpha.expand(b, 1, 1)
+
+        couplings = torch.cat([torch.cat([scores, bins0], -1),
+                               torch.cat([bins1, alpha], -1)], 1)
+
+        norm = - (ms + ns).log()
+        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
+        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
+        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
+
+        # Bug fix: this staticmethod is defined on SuperGlue, not on
+        # SuperGlueModel — the previous reference raised AttributeError.
+        Z = SuperGlue.log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
+        Z = Z - norm  # multiply probabilities by M+N
+        return Z
+
+    @staticmethod
+    def arange_like(x, dim: int):
+        return x.new_ones(x.shape[dim]).cumsum(0) - 1
+
+
+class SuperGlueModel(PreTrainedModel):
+    """
+    Transformers wrapper around the SuperGlue matcher: consumes keypoints,
+    detection scores and descriptors for an image pair and returns mutual
+    matches with their matching scores.
+    """
+    config_class = SuperGlueConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        # NOTE(review): num_heads is not forwarded from the config, so the GNN
+        # always uses SuperGlue's default of 4 heads — confirm against
+        # SuperGlueConfig.
+        self.superglue = SuperGlue(
+            descriptor_dim=config.descriptor_dim,
+            keypoint_encoder_sizes=config.keypoint_encoder_sizes,
+            gnn_layers_types=config.gnn_layers_types,
+            sinkhorn_iterations=config.sinkhorn_iterations,
+            matching_threshold=config.matching_threshold,
+        )
+
+    def forward(
+        self,
+        image0_keypoints: Tensor = None,
+        image0_scores: Tensor = None,
+        image0_descriptors: Tensor = None,
+        image1_keypoints: Tensor = None,
+        image1_scores: Tensor = None,
+        image1_descriptors: Tensor = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ImageMatchingOutput]:
+        # Bug fix: the submodule is stored as `self.superglue` (there is no
+        # `self.model` attribute), and SuperGlue.forward takes positional
+        # (keypoints, scores, descriptors) pairs — the previous `image0_*`
+        # keywords would have raised a TypeError.
+        image0_matches, image1_matches, image0_matching_scores, image1_matching_scores = self.superglue(
+            image0_keypoints,
+            image0_scores,
+            image0_descriptors,
+            image1_keypoints,
+            image1_scores,
+            image1_descriptors,
+        )
+        if not return_dict:
+            return image0_matches, image1_matches, image0_matching_scores, image1_matching_scores
+
+        return ImageMatchingOutput(
+            image0_matches=image0_matches,
+            image1_matches=image1_matches,
+            image0_matching_scores=image0_matching_scores,
+            image1_matching_scores=image1_matching_scores,
+        )
diff --git a/src/transformers/models/superglue/tokenization_superglue.py b/src/transformers/models/superglue/tokenization_superglue.py
new file mode 100644
index 000000000000..c079afeb5a91
--- /dev/null
+++ b/src/transformers/models/superglue/tokenization_superglue.py
@@ -0,0 +1,251 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for SuperGlue."""
+from typing import List, Optional
+
+from tokenizers import ByteLevelBPETokenizer
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+# NOTE(review): "brand-new-bert" entries below are unmodified add-new-model
+# template placeholders — they must be replaced with real SuperGlue resources
+# (if a tokenizer is needed at all for this vision model).
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "brand-new-bert-base-cased": "https://huggingface.co/brand-new-bert-base-cased/resolve/main/vocab.txt",
+    },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "brand-new-bert-base-cased": 1024,
+}
+
+class SuperGlueTokenizer(PreTrainedTokenizer):
+    """
+    Construct a SuperGlue tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+    """
+
+    # NOTE(review): this class looks like an unmodified add-new-model template
+    # leftover — SuperGlue is a keypoint-matching vision model with no text
+    # input. Confirm whether this file should exist at all; as written, the
+    # stub methods below implicitly return None.
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<|endoftext|>",
+        bos_token="<|endoftext|>",
+        eos_token="<|endoftext|>",
+        **kwargs
+    ):
+        # Wrap plain-string special tokens in AddedToken so whitespace handling
+        # around them is explicit.
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+
+    # NOTE(review): stray template placeholder string below — it is a no-op
+    # class-level expression and should be deleted.
+    """ Initialisation """
+
+    # NOTE(review): the following methods are unimplemented template stubs;
+    # each implicitly returns None and will break any caller.
+    @property
+    def vocab_size(self):
+        """ Returns vocab size """
+
+    def get_vocab(self):
+        """ Returns vocab as a dict """
+
+    def _tokenize(self, text):
+        """ Returns a tokenized string. """
+
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str) in an id using the vocab. """
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+
+    def save_vocabulary(self, save_directory):
+        """
+        Save the vocabulary and special tokens file to a directory.
+
+        Args:
+            save_directory (`str`):
+                The directory in which to save the vocabulary.
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+
+    # NOTE(review): relies on cls/sep tokens that __init__ never configures
+    # (only bos/eos/unk are set), so cls_token_id/sep_token_id will be None.
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+        by concatenating and adding special tokens.
+        A SuperGlue sequence has the following format:
+
+        - single sequence: ` X `
+        - pair of sequences: ` A B `
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+        SuperGlue does not make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    # NOTE(review): reads self.add_prefix_space, which __init__ never sets —
+    # this would raise AttributeError at runtime.
+    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
+            text = " " + text
+        return (text, kwargs)
+
+class SuperGlueTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" SuperGlue tokenizer (backed by HuggingFace's *tokenizers* library).
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+    """
+
+    # NOTE(review): duplicate definition — the same class is also declared in
+    # tokenization_superglue_fast.py; a fast tokenizer does not belong in the
+    # slow-tokenizer module and this copy should be removed.
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        unk_token="<|endoftext|>",
+        bos_token="<|endoftext|>",
+        eos_token="<|endoftext|>",
+        add_prefix_space=False,
+        trim_offsets=True,
+        **kwargs
+    ):
+        # NOTE(review): current PreTrainedTokenizerFast expects the backing
+        # tokenizer via the `tokenizer_object` keyword — confirm this
+        # positional form is supported by the targeted transformers version.
+        super().__init__(
+            ByteLevelBPETokenizer(
+                vocab_file=vocab_file,
+                merges_file=merges_file,
+                add_prefix_space=add_prefix_space,
+                trim_offsets=trim_offsets,
+            ),
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            **kwargs,
+        )
+        self.add_prefix_space = add_prefix_space
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+        if token_ids_1 is None:
+            return output
+
+        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+
+    # NOTE(review): relies on cls/sep tokens that __init__ never configures
+    # (only bos/eos/unk are set), so cls_token_id/sep_token_id will be None.
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+        SuperGlue does not make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+
diff --git a/src/transformers/models/superglue/tokenization_superglue_fast.py b/src/transformers/models/superglue/tokenization_superglue_fast.py
new file mode 100644
index 000000000000..319a3d977839
--- /dev/null
+++ b/src/transformers/models/superglue/tokenization_superglue_fast.py
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for SuperGlue."""
+from typing import List, Optional
+
+from tokenizers import ByteLevelBPETokenizer
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_superglue import SuperGlueTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+# NOTE(review): "brand-new-bert" entries below are unmodified add-new-model
+# template placeholders — they must be replaced with real SuperGlue resources
+# (if a tokenizer is needed at all for this vision model).
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "brand-new-bert-base-cased": "https://huggingface.co/brand-new-bert-base-cased/resolve/main/vocab.txt",
+    },
+    "tokenizer_file": {
+        "brand-new-bert-base-cased": "https://huggingface.co/brand-new-bert-base-cased/resolve/main/tokenizer.json",
+    },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "brand-new-bert-base-cased": 1024,
+}
+
+class SuperGlueTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" SuperGlue tokenizer (backed by HuggingFace's *tokenizers* library).
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+    """
+
+    # NOTE(review): this class looks like an unmodified add-new-model template
+    # leftover — SuperGlue is a keypoint-matching vision model with no text
+    # input. Confirm whether this file should exist at all.
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    slow_tokenizer_class = SuperGlueTokenizer
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        unk_token="<|endoftext|>",
+        bos_token="<|endoftext|>",
+        eos_token="<|endoftext|>",
+        add_prefix_space=False,
+        trim_offsets=True,
+        **kwargs
+    ):
+        # NOTE(review): current PreTrainedTokenizerFast expects the backing
+        # tokenizer via the `tokenizer_object` keyword — confirm this
+        # positional form is supported by the targeted transformers version.
+        super().__init__(
+            ByteLevelBPETokenizer(
+                vocab_file=vocab_file,
+                merges_file=merges_file,
+                add_prefix_space=add_prefix_space,
+                trim_offsets=trim_offsets,
+            ),
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            **kwargs,
+        )
+        self.add_prefix_space = add_prefix_space
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+        if token_ids_1 is None:
+            return output
+
+        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+
+    # NOTE(review): relies on cls/sep tokens that __init__ never configures
+    # (only bos/eos/unk are set), so cls_token_id/sep_token_id will be None.
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+        SuperGlue does not make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+
+