From eb11c354a54553c0997863f24305cfbd826161ba Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sat, 1 Oct 2022 10:58:33 +0100 Subject: [PATCH] Add CLIP backbones for text / image classification (#1458) --- CHANGELOG.md | 2 + docs/source/quickstart.rst | 2 +- flash/core/registry.py | 14 +- flash/core/utilities/embedder.py | 32 +++- flash/core/utilities/providers.py | 1 + flash/image/classification/adapters.py | 2 +- .../classification/backbones/__init__.py | 4 +- flash/image/classification/backbones/clip.py | 58 ++++++ flash/image/segmentation/model.py | 2 +- flash/text/classification/adapters.py | 165 ++++++++++++++++++ .../text/classification/backbones/__init__.py | 5 + flash/text/classification/backbones/clip.py | 67 +++++++ .../huggingface.py} | 16 +- flash/text/classification/model.py | 56 ++---- .../text_classification/inference_server.py | 2 +- requirements/datatype_image.txt | 2 + requirements/datatype_text.txt | 3 + tests/core/test_model.py | 2 +- tests/image/classification/test_model.py | 3 + tests/text/classification/test_model.py | 21 ++- 20 files changed, 405 insertions(+), 54 deletions(-) create mode 100644 flash/image/classification/backbones/clip.py create mode 100644 flash/text/classification/adapters.py create mode 100644 flash/text/classification/backbones/__init__.py create mode 100644 flash/text/classification/backbones/clip.py rename flash/text/classification/{backbones.py => backbones/huggingface.py} (74%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7371fa0a48..201c120031 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for CLIP backbones to the `TextClassifier` and `ImageClassifier` tasks ([#1458](https://github.com/Lightning-AI/lightning-flash/pull/1458)) + ### Changed diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 02d5c54263..e3d94615cb 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -92,7 +92,7 @@ Here's an example of inference: from flash.text import TextClassifier, TextClassificationData # 1. Init the finetuned task from URL - model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/text_classification_model.pt") + model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.9.0/text_classification_model.pt") # 2. Perform inference from list of sequences trainer = Trainer() diff --git a/flash/core/registry.py b/flash/core/registry.py index 58098fc740..73453b0ee9 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -205,11 +205,13 @@ def __init__( name: str, providers: Optional[Union[Provider, List[Provider]]] = None, verbose: bool = False, + **metadata, ): super().__init__(name, verbose=verbose) self.getter = getter self.providers = providers if providers is None or isinstance(providers, list) else [providers] + self.metadata = metadata def __contains__(self, item): """Contains is always ``True`` for an ``ExternalRegistry`` as we can't know whether the getter will fail @@ -228,7 +230,10 @@ def get( fn = functools.partial(self.getter, key) if self.providers is not None: fn = print_provider_info(key, self.providers, fn) - return fn + + if not with_metadata: + return fn + return {"fn": fn, "metadata": self.metadata} def available_keys(self) -> List[str]: """Since we don't know the available keys, just give a generic message.""" @@ -242,7 +247,12 @@ class ConcatRegistry(FlashRegistry): def __init__(self, *registries: FlashRegistry): super().__init__( - ",".join({registry.name for registry in registries}), + ",".join( + { + registry.name + for registry in sorted(registries, key=lambda r: 1 if isinstance(r, ExternalRegistry) else 0) + } + ), verbose=any(registry._verbose for registry in registries), ) diff --git a/flash/core/utilities/embedder.py b/flash/core/utilities/embedder.py index dca570680e..1744bab393 100644 --- a/flash/core/utilities/embedder.py +++ b/flash/core/utilities/embedder.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional from pytorch_lightning import LightningModule +from torch.utils.data import DataLoader, Sampler +import flash +from flash.core.data.io.input import InputBase +from flash.core.data.io.input_transform import InputTransform from flash.core.model import Task @@ -33,6 +37,32 @@ def __init__(self, model: LightningModule, layer: str): self._handle = None self._out = None + def process_predict_dataset( + self, + dataset: InputBase, + batch_size: int, + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + persistent_workers: bool = False, + input_transform: Optional[InputTransform] = None, + trainer: Optional["flash.Trainer"] = None, + ) -> DataLoader: + return self.model.process_predict_dataset( + dataset, + batch_size, + num_workers, + pin_memory, + shuffle, + drop_last, + sampler, + persistent_workers, + input_transform, + trainer, + ) + def _make_hook(self): def hook(_, __, output): self._out = output diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index 942e86e429..b6eed0e8b4 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -31,6 +31,7 @@ def __str__(self): _TIMM = Provider("rwightman/pytorch-image-models", "https://github.com/rwightman/pytorch-image-models") _DINO = Provider("Facebook Research/dino", "https://github.com/facebookresearch/dino") +_CLIP = Provider("OpenAI/CLIP", "https://github.com/openai/CLIP") _ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision") _TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision") _ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5") diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index ad846a3279..03148f1e7f 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -530,7 +530,7 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: batch[DataKeys.PREDS] = Task.predict_step( - self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + self._task, batch[DataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx ) return batch diff --git a/flash/image/classification/backbones/__init__.py b/flash/image/classification/backbones/__init__.py index db068b42b5..5cfb0a65b8 100644 --- a/flash/image/classification/backbones/__init__.py +++ b/flash/image/classification/backbones/__init__.py @@ -1,4 +1,5 @@ -from flash.core.registry import FlashRegistry # noqa: F401 +from flash.core.registry import FlashRegistry +from flash.image.classification.backbones.clip import register_clip_backbones # noqa: F401 from flash.image.classification.backbones.resnet import register_resnet_backbones # noqa: F401 from flash.image.classification.backbones.timm import register_timm_backbones # noqa: F401 from flash.image.classification.backbones.torchvision import ( # noqa: F401 @@ -12,6 +13,7 @@ register_resnet_backbones(IMAGE_CLASSIFIER_BACKBONES) register_dino_backbones(IMAGE_CLASSIFIER_BACKBONES) +register_clip_backbones(IMAGE_CLASSIFIER_BACKBONES) register_mobilenet_vgg_backbones(IMAGE_CLASSIFIER_BACKBONES) register_resnext_model(IMAGE_CLASSIFIER_BACKBONES) diff --git a/flash/image/classification/backbones/clip.py b/flash/image/classification/backbones/clip.py new file mode 100644 index 0000000000..2d70088af2 --- /dev/null +++ b/flash/image/classification/backbones/clip.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import torch +from torch import nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.providers import _CLIP +from flash.core.utilities.url_error import catch_url_error + +# Paper: Learning Transferable Visual Models From Natural Language Supervision +# https://arxiv.org/abs/2103.00020 from Alec Radford et. al. (26 Feb 2021) +# weights from https://github.com/openai/CLIP + + +_CLIP_MODELS = { + "RN50": "resnet50", + "RN101": "resnet101", + "RN50x4": "resrnet50x4", + "RN50x16": "resrnet50x16", + "RN50x64": "resrnet50x64", + "ViT_B_32": "vitb32", + "ViT_B_16": "vitb16", + "ViT_L_14": "vitl14", + "ViT_L_14_336px": "vitl14@336px", +} + + +class _CLIPWrapper(nn.Module): + def __init__(self, clip_model: nn.Module): + super().__init__() + + self.clip_model = clip_model + + def forward(self, x): + return self.clip_model.encode_image(x) + + +def _load_clip(model_name: str, **kwargs): + backbone, _ = torch.hub.load("openai/CLIP:main", model_name) + return _CLIPWrapper(backbone), backbone.visual.output_dim + + +def register_clip_backbones(register: FlashRegistry): + for clip_model_name, flash_model_name in _CLIP_MODELS.items(): + register(catch_url_error(partial(_load_clip, clip_model_name)), f"clip_{flash_model_name}", providers=_CLIP) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index ccd1ce3b16..4ea7ac14e6 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -210,4 +210,4 @@ def serve( @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): """This function is used only for debugging usage with CI.""" - assert history[-1]["val_jaccardindex"] > 0.2 + assert history[-1]["val_jaccardindex"] > 0.1 diff --git a/flash/text/classification/adapters.py b/flash/text/classification/adapters.py new file mode 100644 index 0000000000..a6ebd23bf5 --- /dev/null +++ b/flash/text/classification/adapters.py @@ -0,0 +1,165 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import warnings +from dataclasses import dataclass +from types import FunctionType +from typing import Any, Callable, Dict + +import torch +from torch import Tensor + +from flash.core.adapter import Adapter, AdapterTask +from flash.core.data.io.input import DataKeys +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE +from flash.image.classification.heads import IMAGE_CLASSIFIER_HEADS +from flash.text.classification.collate import TextClassificationCollate + +if _TRANSFORMERS_AVAILABLE: + from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput + + +class HuggingFaceAdapter(Adapter): + def __init__(self, backbone, num_classes: int, max_length: int = 128): + super().__init__() + + os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" + # disable HF thousand warnings + warnings.simplefilter("ignore") + # set os environ variable for multiprocesses + os.environ["PYTHONWARNINGS"] = "ignore" + + self.model, tokenizer = backbone(num_classes) + self.collate_fn = TextClassificationCollate(tokenizer, max_length=max_length) + + @classmethod + def from_task( + cls, + task: AdapterTask, + backbone: str, + num_classes: int, + **kwargs, + ) -> Adapter: + adapter = cls(backbone, num_classes, **kwargs) + adapter.__dict__["_task"] = task + return adapter + + @property + def backbone(self): + return self.model.base_model + + def forward(self, batch: Dict[str, Tensor]): + result = self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None)) + if isinstance(result, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)): + result = result.logits + return result + + def training_step(self, batch: Any, batch_idx: int) -> Any: + target = batch.pop(DataKeys.TARGET) + batch = (batch, target) + return Task.training_step(self._task, batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> None: + target = batch.pop(DataKeys.TARGET) + batch = (batch, target) + return Task.validation_step(self._task, batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> None: + target = batch.pop(DataKeys.TARGET) + batch = (batch, target) + return Task.test_step(self._task, batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(batch) + + +@dataclass +class GenericCollate: + + tokenizer: Callable[[str], Any] + + @staticmethod + def to_tensor(sample: Dict[str, Any]) -> Dict[str, Any]: + tensor_sample = {} + for key in sample: + if key is DataKeys.METADATA: + tensor_sample[key] = sample[key] + else: + tensor_sample[key] = torch.tensor(sample[key]) + return tensor_sample + + def tokenize(self, sample): + sample[DataKeys.INPUT] = self.tokenizer(sample[DataKeys.INPUT]) + return sample + + def __call__(self, samples): + return self.to_tensor(self.tokenize({key: [sample[key] for sample in samples] for key in samples[0].keys()})) + + +class GenericAdapter(Adapter): + + # TODO: Move IMAGE_CLASSIFIIER_HEADS out for general classification tasks + heads: FlashRegistry = IMAGE_CLASSIFIER_HEADS + + def __init__(self, backbone, num_classes: int, max_length: int = 128, head="linear"): + super().__init__() + + self.backbone, tokenizer, num_features = backbone() + + self.collate_fn = GenericCollate(tokenizer) + + if isinstance(head, str): + head = self.heads.get(head)(num_features=num_features, num_classes=num_classes) + else: + head = head(num_features, num_classes) if isinstance(head, FunctionType) else head + + self.head = head + + @classmethod + def from_task( + cls, + task: AdapterTask, + backbone: str, + num_classes: int, + **kwargs, + ) -> Adapter: + adapter = cls(backbone, num_classes, **kwargs) + adapter.__dict__["_task"] = task + return adapter + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) + return Task.training_step(self._task, batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) + return Task.validation_step(self._task, batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) + return Task.test_step(self._task, batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch[DataKeys.PREDS] = Task.predict_step( + self._task, batch[DataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx + ) + return batch + + def forward(self, x) -> Tensor: + x = self.backbone(x) + if x.dim() == 4: + x = x.mean(-1).mean(-1) + return self.head(x) diff --git a/flash/text/classification/backbones/__init__.py b/flash/text/classification/backbones/__init__.py new file mode 100644 index 0000000000..90bb4ef015 --- /dev/null +++ b/flash/text/classification/backbones/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry +from flash.text.classification.backbones.clip import CLIP_BACKBONES +from flash.text.classification.backbones.huggingface import HUGGINGFACE_BACKBONES + +TEXT_CLASSIFIER_BACKBONES = FlashRegistry("backbones") + CLIP_BACKBONES + HUGGINGFACE_BACKBONES diff --git a/flash/text/classification/backbones/clip.py b/flash/text/classification/backbones/clip.py new file mode 100644 index 0000000000..5f9d98ea27 --- /dev/null +++ b/flash/text/classification/backbones/clip.py @@ -0,0 +1,67 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import torch +from torch import nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.providers import _CLIP +from flash.core.utilities.url_error import catch_url_error +from flash.text.classification.adapters import GenericAdapter + +# Paper: Learning Transferable Visual Models From Natural Language Supervision +# https://arxiv.org/abs/2103.00020 from Alec Radford et. al. (26 Feb 2021) +# weights from https://github.com/openai/CLIP + + +_CLIP_MODELS = { + "RN50": "resnet50", + "RN101": "resnet101", + "RN50x4": "resrnet50x4", + "RN50x16": "resrnet50x16", + "RN50x64": "resrnet50x64", + "ViT_B_32": "vitb32", + "ViT_B_16": "vitb16", + "ViT_L_14": "vitl14", + "ViT_L_14_336px": "vitl14@336px", +} + + +class _CLIPWrapper(nn.Module): + def __init__(self, clip_model: nn.Module): + super().__init__() + + self.clip_model = clip_model + + def forward(self, x): + return self.clip_model.encode_text(x) + + +def _load_clip(model_name: str, **kwargs): + backbone, _ = torch.hub.load("openai/CLIP:main", model_name) + tokenizer = torch.hub.load("openai/CLIP:main", "tokenize") + tokenizer = partial(tokenizer, truncate=True) + return _CLIPWrapper(backbone), tokenizer, backbone.visual.output_dim + + +CLIP_BACKBONES = FlashRegistry("backbones") + +for clip_model_name, flash_model_name in _CLIP_MODELS.items(): + CLIP_BACKBONES( + catch_url_error(partial(_load_clip, clip_model_name)), + f"clip_{flash_model_name}", + providers=_CLIP, + adapter=GenericAdapter, + ) diff --git a/flash/text/classification/backbones.py b/flash/text/classification/backbones/huggingface.py similarity index 74% rename from flash/text/classification/backbones.py rename to flash/text/classification/backbones/huggingface.py index 0a150feaf7..b9381405fb 100644 --- a/flash/text/classification/backbones.py +++ b/flash/text/classification/backbones/huggingface.py @@ -19,16 +19,24 @@ from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE +from flash.text.classification.adapters import HuggingFaceAdapter if _TRANSFORMERS_AVAILABLE: from transformers import AutoModelForSequenceClassification -TEXT_CLASSIFIER_BACKBONES = FlashRegistry("backbones") + +def load_hugingface(backbone: str, num_classes: int): + model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes) + return model, backbone + + +HUGGINGFACE_BACKBONES = FlashRegistry("backbones") if _TRANSFORMERS_AVAILABLE: - HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES = ExternalRegistry( - getter=AutoModelForSequenceClassification.from_pretrained, + + HUGGINGFACE_BACKBONES = ExternalRegistry( + getter=load_hugingface, name="backbones", providers=_HUGGINGFACE, + adapter=HuggingFaceAdapter, ) - TEXT_CLASSIFIER_BACKBONES += HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 971347c8ae..89a1273836 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -11,20 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import warnings from typing import Any, Dict, List, Optional, Type, Union from pytorch_lightning import Callback -from torch import Tensor -from flash.core.classification import ClassificationTask -from flash.core.data.io.input import DataKeys, ServeInput +from flash.core.classification import ClassificationAdapterTask +from flash.core.data.io.input import ServeInput from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.registry import FlashRegistry from flash.core.serve import Composition -from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE, requires +from flash.core.utilities.imports import requires from flash.core.utilities.types import ( INPUT_TRANSFORM_TYPE, LOSS_FN_TYPE, @@ -33,15 +30,11 @@ OPTIMIZER_TYPE, ) from flash.text.classification.backbones import TEXT_CLASSIFIER_BACKBONES -from flash.text.classification.collate import TextClassificationCollate from flash.text.input import TextDeserializer from flash.text.ort_callback import ORTCallback -if _TRANSFORMERS_AVAILABLE: - from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput - -class TextClassifier(ClassificationTask): +class TextClassifier(ClassificationAdapterTask): """The ``TextClassifier`` is a :class:`~flash.Task` for classifying text. For more details, see :ref:`text_classification`. The ``TextClassifier`` also supports multi-label classification with ``multi_label=True``. For more details, see :ref:`text_classification_multi_label`. @@ -78,51 +71,36 @@ def __init__( learning_rate: Optional[float] = None, multi_label: bool = False, enable_ort: bool = False, + **kwargs, ): self.save_hyperparameters() if labels is not None and num_classes is None: num_classes = len(labels) - os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" - # disable HF thousand warnings - warnings.simplefilter("ignore") - # set os environ variable for multiprocesses - os.environ["PYTHONWARNINGS"] = "ignore" + metadata = self.backbones.get(backbone, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + backbone=metadata["fn"], + num_classes=num_classes, + max_length=max_length, + **kwargs, + ) super().__init__( - num_classes=num_classes, - model=None, + adapter, loss_fn=loss_fn, + learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, metrics=metrics, - learning_rate=learning_rate, multi_label=multi_label, + num_classes=num_classes, labels=labels, ) + self.enable_ort = enable_ort self.max_length = max_length - self.collate_fn = TextClassificationCollate(backbone=backbone, max_length=max_length) - self.model = self.backbones.get(backbone)(num_labels=num_classes) - - @property - def backbone(self): - return self.model.base_model - - def forward(self, batch: Dict[str, Tensor]): - result = self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None)) - if isinstance(result, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)): - result = result.logits - return result - - def step(self, batch, batch_idx, metrics) -> dict: - target = batch.pop(DataKeys.TARGET) - batch = (batch, target) - return super().step(batch, batch_idx, metrics) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - return self(batch) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): """This function is used only for debugging usage with CI.""" diff --git a/flash_examples/serve/text_classification/inference_server.py b/flash_examples/serve/text_classification/inference_server.py index 40623dcb58..29baf51ad3 100644 --- a/flash_examples/serve/text_classification/inference_server.py +++ b/flash_examples/serve/text_classification/inference_server.py @@ -13,5 +13,5 @@ # limitations under the License. from flash.text import TextClassifier -model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/text_classification_model.pt") +model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.9.0/text_classification_model.pt") model.serve() diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index 8e083abf28..a3eac90c52 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -5,3 +5,5 @@ Pillow>=7.2 albumentations>=1.0 pystiche==1.* segmentation-models-pytorch>=0.2.0 +ftfy +regex diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt index 1c1c18535f..c61d6fb591 100644 --- a/requirements/datatype_text.txt +++ b/requirements/datatype_text.txt @@ -1,6 +1,9 @@ +torchvision sentencepiece>=0.1.95 filelock transformers>=4.5 torchmetrics[text]>=0.5.1 datasets>=1.8 sentence-transformers +ftfy +regex diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 08c80d9a41..5b701fdf4f 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -244,7 +244,7 @@ def test_classification_task_trainer_predict(tmpdir): ), pytest.param( TextClassifier, - "0.7.0/text_classification_model.pt", + "0.9.0/text_classification_model.pt", marks=pytest.mark.skipif( not _TEXT_TESTING, reason="text packages aren't installed", diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 8d6c2be563..ade0f45bc6 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -61,6 +61,7 @@ class TestImageClassifier(TaskTester): {"backbone": "vit_small_patch16_224"}, {"backbone": "resnet18", "head": "linear"}, {"backbone": "resnet18", "head": torch.nn.Linear(512, 2)}, + {"backbone": "clip_resnet50"}, ], ) ], @@ -72,6 +73,7 @@ class TestImageClassifier(TaskTester): {"backbone": "vit_small_patch16_224"}, {"backbone": "resnet18", "head": "linear"}, {"backbone": "resnet18", "head": torch.nn.Linear(512, 2)}, + {"backbone": "clip_resnet50"}, ], ) ], @@ -83,6 +85,7 @@ class TestImageClassifier(TaskTester): {"backbone": "vit_small_patch16_224"}, {"backbone": "resnet18", "head": "linear"}, {"backbone": "resnet18", "head": torch.nn.Linear(512, 2)}, + {"backbone": "clip_resnet50"}, ], ) ], diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index 8cbc632b77..0187483b41 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -51,11 +51,28 @@ class TestTextClassifier(TaskTester): {"enable_ort": True}, marks=pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed."), ), + {"backbone": "clip_resnet50"}, + ], + ) + ], + "test_val": [ + pytest.mark.parametrize( + "task_kwargs", + [ + {}, + {"backbone": "clip_resnet50"}, + ], + ) + ], + "test_test": [ + pytest.mark.parametrize( + "task_kwargs", + [ + {}, + {"backbone": "clip_resnet50"}, ], ) ], - "test_val": [pytest.mark.parametrize("task_kwargs", [{}])], - "test_test": [pytest.mark.parametrize("task_kwargs", [{}])], "test_cli": [pytest.mark.parametrize("extra_args", ([], ["from_toxic"]))], }