This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
1 parent 1ddd556 · commit eb09e26
Showing 12 changed files with 208 additions and 0 deletions.
@@ -0,0 +1 @@
from flash.text.embedding.model import TextEmbedder  # noqa: F401
@@ -0,0 +1,14 @@
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

if _TEXT_AVAILABLE:
    from transformers import AutoModel

    HUGGINGFACE_BACKBONES = ExternalRegistry(
        AutoModel.from_pretrained,
        "backbones",
        _HUGGINGFACE,
    )
else:
    HUGGINGFACE_BACKBONES = FlashRegistry("backbones")
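As a rough usage sketch of this registry (not from the commit itself; it assumes the text extras are installed and the backbone can be downloaded), ``get`` returns a loader bound to the requested identifier, mirroring the ``self.backbones.get(backbone)()`` call in the model file below:

from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES

# `get` wraps AutoModel.from_pretrained with the requested identifier;
# calling the returned loader downloads and builds the Hugging Face model.
build_backbone = HUGGINGFACE_BACKBONES.get("sentence-transformers/all-MiniLM-L6-v2")
backbone = build_backbone()
print(backbone.config.hidden_size)  # 384 for this backbone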
@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
from typing import Any, Dict, List, Optional

import torch
from pytorch_lightning import Callback

from flash.core.integrations.transformers.states import TransformersBackboneState
from flash.core.model import Task
from flash.core.registry import FlashRegistry, print_provider_info
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _SENTENCE_TRANSFORMERS
from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES
from flash.text.ort_callback import ORTCallback

if _TEXT_AVAILABLE:
    from sentence_transformers.models import Pooling

    Pooling = print_provider_info("Pooling", _SENTENCE_TRANSFORMERS, Pooling)

logger = logging.getLogger(__name__)


class TextEmbedder(Task):
    """The ``TextEmbedder`` is a :class:`~flash.Task` for generating sentence embeddings. For more details, see
    `embeddings`.

    You can change the backbone to any sentence embedding model from `UKPLab/sentence-transformers
    <https://github.com/UKPLab/sentence-transformers>`_ using the ``backbone`` argument.

    Args:
        backbone: backbone model to use for the task.
        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
    """

    required_extras: str = "text"

    backbones: FlashRegistry = HUGGINGFACE_BACKBONES

    def __init__(
        self,
        backbone: str = "sentence-transformers/all-MiniLM-L6-v2",
        tokenizer_backbone: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        enable_ort: bool = False,
    ):
        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # disable HF thousand warnings
        warnings.simplefilter("ignore")
        # set os environ variable for multiprocesses
        os.environ["PYTHONWARNINGS"] = "ignore"
        super().__init__()

        if tokenizer_backbone is None:
            tokenizer_backbone = backbone
        self.set_state(TransformersBackboneState(tokenizer_backbone, tokenizer_kwargs=tokenizer_kwargs))
        self.model = self.backbones.get(backbone)()
        self.pooling = Pooling(self.model.config.hidden_size)
        self.enable_ort = enable_ort

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError("Training a `TextEmbedder` is not supported. Use a different text task instead.")

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError("Validating a `TextEmbedder` is not supported. Use a different text task instead.")

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        raise NotImplementedError("Testing a `TextEmbedder` is not supported. Use a different text task instead.")

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Adapted from sentence-transformers:

        https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Transformer.py#L45
        """
        trans_features = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
        if "token_type_ids" in batch:
            trans_features["token_type_ids"] = batch["token_type_ids"]

        output_states = self.model(**trans_features, return_dict=False)
        output_tokens = output_states[0]

        batch.update({"token_embeddings": output_tokens, "attention_mask": batch["attention_mask"]})

        return self.pooling(batch)["sentence_embedding"]

    def configure_callbacks(self) -> List[Callback]:
        callbacks = super().configure_callbacks() or []
        if self.enable_ort:
            callbacks.append(ORTCallback())
        return callbacks
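To make the expected input of ``forward`` concrete, here is a minimal sketch (not from the commit; it assumes the text extras and ``transformers`` are installed and that the backbone can be downloaded) that tokenizes two sentences by hand and pools them into one embedding each:

import torch
from transformers import AutoTokenizer

from flash.text import TextEmbedder

model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Build the keys that `forward` reads: input_ids, attention_mask (and token_type_ids when present).
batch = tokenizer(["A first sentence.", "A second sentence."], padding=True, return_tensors="pt")

with torch.no_grad():
    embeddings = model(dict(batch))

print(embeddings.shape)  # expected (2, 384) for all-MiniLM-L6-v2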
@@ -0,0 +1,34 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Create the DataModule
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
    ]
)

# 2. Create the TextEmbedder with a pretrained backbone
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

# 3. Generate embeddings for the three sentences
trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
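Each entry of ``predictions`` corresponds to one predict batch and holds one embedding tensor per input sentence, so (assuming the default batch size keeps all three sentences in a single batch) they can be stacked for downstream use:

# Stack the per-sentence embeddings from the first (and only) predict batch.
embeddings = torch.stack(list(predictions[0]))
print(embeddings.shape)  # expected torch.Size([3, 384]) for all-MiniLM-L6-v2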
@@ -3,3 +3,4 @@ filelock
transformers>=4.5
torchmetrics[text]>=0.5.1
datasets>=1.8,<1.13
sentence-transformers
Empty file.
@@ -0,0 +1,43 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder
from tests.helpers.utils import _TEXT_TESTING

# ======== Mock data ========

predict_data = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
# ==============================

TEST_BACKBONE = "sentence-transformers/all-MiniLM-L6-v2"  # super small model for testing


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_predict(tmpdir):
    datamodule = TextClassificationData.from_lists(predict_data=predict_data)
    model = TextEmbedder(backbone=TEST_BACKBONE)

    trainer = flash.Trainer(gpus=torch.cuda.device_count())
    predictions = trainer.predict(model, datamodule=datamodule)
    assert [t.size() for t in predictions[0]] == [torch.Size([384]), torch.Size([384]), torch.Size([384])]