Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/de/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ Flax), PyTorch, und/oder TensorFlow haben.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ Flax), PyTorch, and/or TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
14 changes: 13 additions & 1 deletion docs/source/en/model_doc/regnet.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,16 @@ If you're interested in submitting a resource to be included here, please feel f
## TFRegNetForImageClassification

[[autodoc]] TFRegNetForImageClassification
- call
- call


## FlaxRegNetModel

[[autodoc]] FlaxRegNetModel
- __call__


## FlaxRegNetForImageClassification

[[autodoc]] FlaxRegNetForImageClassification
- __call__
2 changes: 1 addition & 1 deletion docs/source/es/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Flax), PyTorch y/o TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | |
| RegNet | ❌ | ❌ | ✅ | ❌ | |
Comment thread
Shubhamai marked this conversation as resolved.
Outdated
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/fr/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ Le tableau ci-dessous représente la prise en charge actuelle dans la bibliothè
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/it/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | |
| RegNet | ❌ | ❌ | ✅ | ❌ | |
Comment thread
Shubhamai marked this conversation as resolved.
Outdated
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/ja/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ specific language governing permissions and limitations under the License.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/ko/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ specific language governing permissions and limitations under the License.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/pt/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | |
| RegNet | ❌ | ❌ | ✅ | ❌ | |
Comment thread
Shubhamai marked this conversation as resolved.
Outdated
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/zh/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ Flax), PyTorch, 和/或者 TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RegNet | ❌ | ❌ | ✅ | ✅ | |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3561,6 +3561,9 @@
"FlaxPegasusPreTrainedModel",
]
)
_import_structure["models.regnet"].extend(
["FlaxRegNetForImageClassification", "FlaxRegNetModel", "FlaxRegNetPreTrainedModel"]
)
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForCausalLM",
Expand Down Expand Up @@ -6548,6 +6551,7 @@
from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.regnet import FlaxRegNetForImageClassification, FlaxRegNetModel, FlaxRegNetPreTrainedModel
from .models.roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
Expand Down
57 changes: 57 additions & 0 deletions src/transformers/modeling_flax_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,63 @@ class FlaxBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
Comment thread
Shubhamai marked this conversation as resolved.
class FlaxBaseModelOutputWithNoAttention(ModelOutput):
"""
Args:
Base class for model's outputs, with potential hidden states.
last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
`config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
model at the output of each layer plus the optional initial embedding outputs.
"""

last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
"""
Args:
Base class for model's outputs that also contains a pooling of the last hidden states.
last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a pooling operation on the spatial dimensions.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
`config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
model at the output of each layer plus the optional initial embedding outputs.
"""

last_hidden_state: jnp.ndarray = None
pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
"""
Args:
Base class for outputs of image classification models.
logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
`config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
called feature maps) of the model at the output of each stage.
"""

logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxBaseModelOutputWithPast(ModelOutput):
"""
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
("mt5", "FlaxMT5Model"),
("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"),
("regnet", "FlaxRegNetModel"),
("roberta", "FlaxRobertaModel"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
("roformer", "FlaxRoFormerModel"),
Expand Down Expand Up @@ -119,6 +120,7 @@
[
# Model for Image-classsification
("beit", "FlaxBeitForImageClassification"),
("regnet", "FlaxRegNetForImageClassification"),
("vit", "FlaxViTForImageClassification"),
]
)
Expand Down
32 changes: 31 additions & 1 deletion src/transformers/models/regnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)


_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
Expand Down Expand Up @@ -44,6 +50,18 @@
"TFRegNetPreTrainedModel",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_regnet"] = [
"FlaxRegNetForImageClassification",
"FlaxRegNetModel",
"FlaxRegNetPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
Expand Down Expand Up @@ -74,6 +92,18 @@
TFRegNetPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_regnet import (
FlaxRegNetForImageClassification,
FlaxRegNetModel,
FlaxRegNetPreTrainedModel,
)


else:
import sys
Expand Down
Loading