Changes from all commits (24 commits)
- 1166511 first commit: (younesbelkada, Jun 18, 2022)
- 563217b make quality (younesbelkada, Jun 18, 2022)
- 77220f7 implement pt_flax equivlence test to bypass `AttributeError: 'NoneTy… (ArthurZucker, Jun 18, 2022)
- 5f1341e add few changes: (younesbelkada, Jun 19, 2022)
- 86320ad add modification sequential (younesbelkada, Jun 20, 2022)
- 6b2c28c fix copies (younesbelkada, Jun 20, 2022)
- 926fb77 few fixes (younesbelkada, Jun 20, 2022)
- 72bf6a4 changes (younesbelkada, Jun 25, 2022)
- 7762a1e Merge branch 'main' of https://github.com/huggingface/transformers in… (ArthurZucker, Jun 25, 2022)
- 63a6144 style (ArthurZucker, Jun 25, 2022)
- 3d74f94 add FlaxViTPatchEmbeddings for consistency (ArthurZucker, Jun 25, 2022)
- ed7a4bc update tests (ArthurZucker, Jun 25, 2022)
- e2a61c9 consistency (ArthurZucker, Jun 25, 2022)
- e11b043 style (ArthurZucker, Jun 25, 2022)
- 05dcc85 all tests should pas (younesbelkada, Jun 25, 2022)
- 7c26f69 fixing few comments (younesbelkada, Aug 10, 2022)
- 3505582 Apply suggestions from code review (younesbelkada, Aug 10, 2022)
- 425508e Update src/transformers/models/dpt/gradient_convolution.py (younesbelkada, Aug 10, 2022)
- 24aeb4d add few comments (younesbelkada, Aug 10, 2022)
- 479d0e8 Apply suggestions from code review (younesbelkada, Aug 10, 2022)
- 791ea24 refactor a bit: (younesbelkada, Aug 10, 2022)
- 010a915 Merge remote-tracking branch 'upstream/main' into dpt-flax-younes (younesbelkada, Aug 10, 2022)
- a40f21d add comments on key naming strategy (younesbelkada, Aug 10, 2022)
- f945eef few modifications (younesbelkada, Aug 10, 2022)
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -225,7 +225,7 @@ Flax), PyTorch, and/or TensorFlow.
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
| DPT | ❌ | ❌ | ✅ | ❌ | ❌ |
| DPT | ❌ | ❌ | ✅ | ❌ | ✅ |
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
17 changes: 16 additions & 1 deletion docs/source/en/model_doc/dpt.mdx
@@ -54,4 +54,19 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
## DPTForSemanticSegmentation

[[autodoc]] DPTForSemanticSegmentation
- forward
- forward

## FlaxDPTForSemanticSegmentation

[[autodoc]] FlaxDPTForSemanticSegmentation
- __call__

## FlaxDPTForDepthEstimation

[[autodoc]] FlaxDPTForDepthEstimation
- __call__

## FlaxDPTModel

[[autodoc]] FlaxDPTModel
- __call__
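
For orientation, here is a minimal usage sketch of the new depth-estimation class. The checkpoint name and the `from_pt=True` weight conversion are assumptions for illustration; this diff does not ship Flax weights.

```python
# Hedged sketch: assumes the Intel/dpt-large checkpoint and on-the-fly
# conversion of its PyTorch weights (from_pt=True); neither is part of this diff.
from PIL import Image

from transformers import DPTFeatureExtractor, FlaxDPTForDepthEstimation

feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
model = FlaxDPTForDepthEstimation.from_pretrained("Intel/dpt-large", from_pt=True)

image = Image.open("example.jpg")  # any local RGB image
inputs = feature_extractor(images=image, return_tensors="np")

outputs = model(pixel_values=inputs["pixel_values"])
print(outputs.predicted_depth.shape)  # (batch_size, height, width)
```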
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
@@ -2745,6 +2745,14 @@
"FlaxDistilBertPreTrainedModel",
]
)
_import_structure["models.dpt"].extend(
[
"FlaxDPTModel",
"FlaxDPTPreTrainedModel",
"FlaxDPTForSemanticSegmentation",
"FlaxDPTForDepthEstimation",
]
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForCausalLM",
@@ -5101,6 +5109,12 @@
        FlaxDistilBertModel,
        FlaxDistilBertPreTrainedModel,
    )
    from .models.dpt import (
        FlaxDPTForDepthEstimation,
        FlaxDPTForSemanticSegmentation,
        FlaxDPTModel,
        FlaxDPTPreTrainedModel,
    )
    from .models.electra import (
        FlaxElectraForCausalLM,
        FlaxElectraForMaskedLM,
69 changes: 69 additions & 0 deletions src/transformers/modeling_flax_outputs.py
@@ -640,3 +640,72 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxDepthEstimatorOutput(ModelOutput):
    """
    Base class for outputs of depth estimation models.

    Args:
        loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Depth estimation loss.
        predicted_depth (`jnp.ndarray` of shape `(batch_size, height, width)`):
            Predicted depth for each pixel.

        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: jnp.ndarray = None
    predicted_depth: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxSemanticSegmenterOutput(ModelOutput):
    """
    Base class for outputs of semantic segmentation models.

    Args:
        loss (`jnp.ndarray` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Semantic segmentation loss (cross-entropy loss).
        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Semantic segmentation raw logits for each pixel.

            <Tip warning={true}>

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
            to avoid doing two interpolations and losing quality when a user needs to resize the logits to the
            original image size as post-processing. You should always check your logits shape and resize as needed.

            </Tip>

        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: jnp.ndarray = None
    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
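
As the Tip in `FlaxSemanticSegmenterOutput` notes, logits usually need one resize back to the input resolution. A minimal sketch of that post-processing, with illustrative shapes; the `jax.image.resize` call and bilinear method are assumptions, not code from this PR:

```python
import jax
import jax.numpy as jnp

# Illustrative shapes: 150 ADE20k-style classes, 96x96 logits, 384x384 input.
logits = jnp.zeros((1, 150, 96, 96))  # (batch, num_labels, logits_h, logits_w)

# One bilinear resize back to the original image size, then argmax per pixel.
upsampled = jax.image.resize(logits, shape=(1, 150, 384, 384), method="bilinear")
segmentation = upsampled.argmax(axis=1)  # (batch, height, width) class indices
```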
36 changes: 36 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -37,6 +37,7 @@
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("dpt", "FlaxDPTModel"),
("electra", "FlaxElectraModel"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
@@ -211,6 +212,17 @@
    ]
)

FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    [
        ("dpt", "FlaxDPTForDepthEstimation"),
    ]
)

FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        ("dpt", "FlaxDPTForSemanticSegmentation"),
    ]
)

FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
@@ -241,6 +253,12 @@
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)


class FlaxAutoModel(_BaseAutoModelClass):
@@ -344,3 +362,21 @@ class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
    FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class FlaxAutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_DEPTH_ESTIMATION_MAPPING


FlaxAutoModelForDepthEstimation = auto_class_update(
    FlaxAutoModelForDepthEstimation, head_doc="depth estimation modeling"
)


class FlaxAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


FlaxAutoModelForSemanticSegmentation = auto_class_update(
    FlaxAutoModelForSemanticSegmentation, head_doc="semantic segmentation modeling"
)
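
A short sketch of the new auto classes. This diff registers them in `modeling_flax_auto.py` but does not show a top-level re-export, so they are imported from the module path here; the checkpoint names and `from_pt=True` conversion are assumptions:

```python
# Hedged sketch: imports from the module path, since this diff does not show the
# auto classes being re-exported from the top-level transformers namespace.
from transformers.models.auto.modeling_flax_auto import (
    FlaxAutoModelForDepthEstimation,
    FlaxAutoModelForSemanticSegmentation,
)

# Checkpoint names are illustrative; from_pt=True converts PyTorch weights.
depth_model = FlaxAutoModelForDepthEstimation.from_pretrained("Intel/dpt-large", from_pt=True)
seg_model = FlaxAutoModelForSemanticSegmentation.from_pretrained(
    "Intel/dpt-large-ade", from_pt=True
)
```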
33 changes: 32 additions & 1 deletion src/transformers/models/dpt/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available
from ...file_utils import (
    _LazyModule,
    is_flax_available,
    is_tokenizers_available,
    is_torch_available,
    is_vision_available,
)
from ...utils import OptionalDependencyNotAvailable


@@ -45,6 +51,19 @@
"DPTPreTrainedModel",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_dpt"] = [
"FlaxDPTForSemanticSegmentation",
"FlaxDPTForDepthEstimation",
"FlaxDPTModel",
"FlaxDPTPreTrainedModel",
]


if TYPE_CHECKING:
    from .configuration_dpt import DPT_PRETRAINED_CONFIG_ARCHIVE_MAP, DPTConfig
@@ -71,6 +90,18 @@
            DPTPreTrainedModel,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_dpt import (
            FlaxDPTForDepthEstimation,
            FlaxDPTForSemanticSegmentation,
            FlaxDPTModel,
            FlaxDPTPreTrainedModel,
        )

else:
    import sys
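
The guarded block above means the Flax DPT classes are only registered when flax is installed. A small sketch of defensive usage under that assumption:

```python
# Minimal sketch: mirror the library's own guard before touching the Flax classes.
from transformers.utils import is_flax_available

if is_flax_available():
    from transformers import FlaxDPTModel  # resolved lazily via _import_structure
else:
    FlaxDPTModel = None  # without flax installed, the import raises a backend error
```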
2 changes: 2 additions & 0 deletions src/transformers/models/dpt/configuration_dpt.py
@@ -137,6 +137,7 @@ def __init__(
        auxiliary_loss_weight=0.4,
        semantic_loss_ignore_index=255,
        semantic_classifier_dropout=0.1,
        align_corners=True,
        **kwargs
    ):
        super().__init__(**kwargs)
@@ -168,3 +169,4 @@ def __init__(
        self.auxiliary_loss_weight = auxiliary_loss_weight
        self.semantic_loss_ignore_index = semantic_loss_ignore_index
        self.semantic_classifier_dropout = semantic_classifier_dropout
        self.align_corners = align_corners
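
The new `align_corners` option mirrors the `align_corners` argument of bilinear interpolation in the PyTorch DPT fusion stage; the Flax call site is not shown in this diff, so the sketch below uses the PyTorch reference semantics for illustration:

```python
import torch
import torch.nn.functional as F

from transformers import DPTConfig

config = DPTConfig()  # align_corners defaults to True per this change

# Hypothetical fusion-stage feature map; DPT upsamples these bilinearly.
features = torch.randn(1, 256, 24, 24)
upsampled = F.interpolate(
    features, scale_factor=2, mode="bilinear", align_corners=config.align_corners
)
print(upsampled.shape)  # torch.Size([1, 256, 48, 48])
```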