Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add text embedder (#996)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
  • Loading branch information
3 people authored Dec 9, 2021
1 parent 1ddd556 commit eb09e26
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased] - YYYY-DD-MM

### Added
- Added `TextEmbedder` task ([#996](https://github.com/PyTorchLightning/lightning-flash/pull/996))

- Added predict_kwargs in `ObjectDetector`, `InstanceSegmentation`, `KeypointDetector` ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990))

Expand Down
2 changes: 2 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _compare_version(package: str, op, version) -> bool:
_ALBUMENTATIONS_AVAILABLE = _module_available("albumentations")
_BAAL_AVAILABLE = _module_available("baal")
_TORCH_OPTIMIZER_AVAILABLE = _module_available("torch_optimizer")
_SENTENCE_TRANSFORMERS_AVAILABLE = _module_available("sentence_transformers")


if _PIL_AVAILABLE:
Expand All @@ -130,6 +131,7 @@ class Image:
_SENTENCEPIECE_AVAILABLE,
_DATASETS_AVAILABLE,
_TM_TEXT_AVAILABLE,
_SENTENCE_TRANSFORMERS_AVAILABLE,
]
)
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __str__(self):
_LEARN2LEARN = Provider("learnables/learn2learn", "https://github.com/learnables/learn2learn")
_PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche")
_HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers")
_SENTENCE_TRANSFORMERS = Provider("UKPLab/sentence-transformers", "https://github.com/UKPLab/sentence-transformers")
_FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq")
_OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML")
_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
Expand Down
1 change: 1 addition & 0 deletions flash/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401
from flash.text.embedding import TextEmbedder # noqa: F401
from flash.text.question_answering import QuestionAnsweringData, QuestionAnsweringTask # noqa: F401
from flash.text.seq2seq import ( # noqa: F401
Seq2SeqData,
Expand Down
1 change: 1 addition & 0 deletions flash/text/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.text.embedding.model import TextEmbedder # noqa: F401
14 changes: 14 additions & 0 deletions flash/text/embedding/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

if _TEXT_AVAILABLE:
from transformers import AutoModel

HUGGINGFACE_BACKBONES = ExternalRegistry(
AutoModel.from_pretrained,
"backbones",
_HUGGINGFACE,
)
else:
HUGGINGFACE_BACKBONES = FlashRegistry("backbones")
106 changes: 106 additions & 0 deletions flash/text/embedding/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
from typing import Any, Dict, List, Optional

import torch
from pytorch_lightning import Callback

from flash.core.integrations.transformers.states import TransformersBackboneState
from flash.core.model import Task
from flash.core.registry import FlashRegistry, print_provider_info
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _SENTENCE_TRANSFORMERS
from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES
from flash.text.ort_callback import ORTCallback

if _TEXT_AVAILABLE:
from sentence_transformers.models import Pooling

Pooling = print_provider_info("Pooling", _SENTENCE_TRANSFORMERS, Pooling)

logger = logging.getLogger(__name__)


class TextEmbedder(Task):
"""The ``TextEmbedder`` is a :class:`~flash.Task` for generating sentence embeddings, training and validation.
For more details, see `embeddings`.
You can change the backbone to any question answering model from `UKPLab/sentence-transformers
<https://github.com/UKPLab/sentence-transformers>`_ using the ``backbone``
argument.
Args:
backbone: backbone model to use for the task.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

required_extras: str = "text"

backbones: FlashRegistry = HUGGINGFACE_BACKBONES

def __init__(
self,
backbone: str = "sentence-transformers/all-MiniLM-L6-v2",
tokenizer_backbone: Optional[str] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
enable_ort: bool = False,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
warnings.simplefilter("ignore")
# set os environ variable for multiprocesses
os.environ["PYTHONWARNINGS"] = "ignore"
super().__init__()

if tokenizer_backbone is None:
tokenizer_backbone = backbone
self.set_state(TransformersBackboneState(tokenizer_backbone, tokenizer_kwargs=tokenizer_kwargs))
self.model = self.backbones.get(backbone)()
self.pooling = Pooling(self.model.config.hidden_size)
self.enable_ort = enable_ort

def training_step(self, batch: Any, batch_idx: int) -> Any:
raise NotImplementedError("Training a `TextEmbedder` is not supported. Use a different text task instead.")

def validation_step(self, batch: Any, batch_idx: int) -> Any:
raise NotImplementedError("Validating a `TextEmbedder` is not supported. Use a different text task instead.")

def test_step(self, batch: Any, batch_idx: int) -> Any:
raise NotImplementedError("Testing a `TextEmbedder` is not supported. Use a different text task instead.")

def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Adapted from sentence-transformers:
https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Transformer.py#L45
"""

trans_features = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
if "token_type_ids" in batch:
trans_features["token_type_ids"] = batch["token_type_ids"]

output_states = self.model(**trans_features, return_dict=False)
output_tokens = output_states[0]

batch.update({"token_embeddings": output_tokens, "attention_mask": batch["attention_mask"]})

return self.pooling(batch)["sentence_embedding"]

def configure_callbacks(self) -> List[Callback]:
callbacks = super().configure_callbacks() or []
if self.enable_ort:
callbacks.append(ORTCallback())
return callbacks
34 changes: 34 additions & 0 deletions flash_examples/text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Create the DataModule
datamodule = TextClassificationData.from_lists(
predict_data=[
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado.",
]
)

# 2. Load a previously trained TextEmbedder
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

# 3. Generate embeddings for the first 3 graphs
trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
1 change: 1 addition & 0 deletions requirements/datatype_text.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ filelock
transformers>=4.5
torchmetrics[text]>=0.5.1
datasets>=1.8,<1.13
sentence-transformers
4 changes: 4 additions & 0 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@
"text_classification.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
),
pytest.param(
"text_embedder.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
),
# pytest.param(
# "text_classification_multi_label.py",
# marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
Expand Down
Empty file.
43 changes: 43 additions & 0 deletions tests/text/embedding/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder
from tests.helpers.utils import _TEXT_TESTING

# ======== Mock data ========

predict_data = [
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado.",
]
# ==============================

TEST_BACKBONE = "sentence-transformers/all-MiniLM-L6-v2" # super small model for testing


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_predict(tmpdir):
datamodule = TextClassificationData.from_lists(predict_data=predict_data)
model = TextEmbedder(backbone=TEST_BACKBONE)

trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(model, datamodule=datamodule)
assert [t.size() for t in predictions[0]] == [torch.Size([384]), torch.Size([384]), torch.Size([384])]

0 comments on commit eb09e26

Please sign in to comment.