Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ jobs:
run: |
python -m pip install --no-cache-dir --upgrade pip
python -m pip install --no-cache-dir ${{ matrix.requirements }}
python -m spacy download en_core_web_lg
python -m spacy download en_core_web_sm
if: steps.restore-cache.outputs.cache-hit != 'true'

- name: Install the checked-out setfit
Expand Down
4 changes: 4 additions & 0 deletions docs/source/en/api/main.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@
# SetFitHead

[[autodoc]] SetFitHead

# AbsaModel

[[autodoc]] AbsaModel
6 changes: 5 additions & 1 deletion docs/source/en/api/trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@

# DistillationTrainer

[[autodoc]] DistillationTrainer
[[autodoc]] DistillationTrainer

# AbsaTrainer

[[autodoc]] AbsaTrainer
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@
MAINTAINER_EMAIL = "lewis@huggingface.co"

INTEGRATIONS_REQUIRE = ["optuna"]
REQUIRED_PKGS = ["datasets>=2.3.0", "sentence-transformers>=2.2.1", "evaluate>=0.3.0"]
REQUIRED_PKGS = [
"datasets>=2.3.0",
"sentence-transformers>=2.2.1",
"evaluate>=0.3.0",
"huggingface_hub>=0.11.0",
"scikit-learn",
]
ABSA_REQUIRE = ["spacy"]
QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"]
ONNX_REQUIRE = ["onnxruntime", "onnx", "skl2onnx"]
OPENVINO_REQUIRE = ["hummingbird-ml<0.4.9", "openvino==2022.3.0"]
TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE
TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE + ABSA_REQUIRE
DOCS_REQUIRE = ["hf-doc-builder>=0.3.0"]
EXTRAS_REQUIRE = {
"optuna": INTEGRATIONS_REQUIRE,
Expand All @@ -23,6 +30,7 @@
"onnx": ONNX_REQUIRE,
"openvino": ONNX_REQUIRE + OPENVINO_REQUIRE,
"docs": DOCS_REQUIRE,
"absa": ABSA_REQUIRE,
}


Expand Down
1 change: 1 addition & 0 deletions src/setfit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .data import get_templated_dataset, sample_dataset
from .modeling import SetFitHead, SetFitModel
from .span import AbsaModel, AbsaTrainer, AspectExtractor, AspectModel, PolarityModel
from .trainer import SetFitTrainer, Trainer
from .trainer_distillation import DistillationSetFitTrainer, DistillationTrainer
from .training_args import TrainingArguments
Expand Down
41 changes: 27 additions & 14 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import requests
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from sentence_transformers import SentenceTransformer, models
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
Expand Down Expand Up @@ -74,14 +75,14 @@

```bibtex
@article{{https://doi.org/10.48550/arxiv.2209.11055,
doi = {{10.48550/ARXIV.2209.11055}},
url = {{https://arxiv.org/abs/2209.11055}},
author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}},
keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}},
title = {{Efficient Few-Shot Learning Without Prompts}},
publisher = {{arXiv}},
year = {{2022}},
copyright = {{Creative Commons Attribution 4.0 International}}
doi = {{10.48550/ARXIV.2209.11055}},
url = {{https://arxiv.org/abs/2209.11055}},
author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}},
keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}},
title = {{Efficient Few-Shot Learning Without Prompts}},
publisher = {{arXiv}},
year = {{2022}},
copyright = {{Creative Commons Attribution 4.0 International}}
}}
```
"""
Expand Down Expand Up @@ -246,7 +247,6 @@ class SetFitModel(PyTorchModelHubMixin):
model_body: Optional[SentenceTransformer] = (None,)
model_head: Optional[Union[SetFitHead, LogisticRegression]] = None
multi_target_strategy: Optional[str] = None
l2_weight: float = 1e-2
normalize_embeddings: bool = False

@property
Expand Down Expand Up @@ -372,7 +372,7 @@ def _prepare_optimizer(
l2_weight: float,
) -> torch.optim.Optimizer:
body_learning_rate = body_learning_rate or head_learning_rate
l2_weight = l2_weight or self.l2_weight
l2_weight = l2_weight or 1e-2
optimizer = torch.optim.AdamW(
[
{
Expand Down Expand Up @@ -519,6 +519,15 @@ def predict_proba(
outputs = self.model_head.predict_proba(embeddings)
return self._output_type_conversion(outputs, as_numpy=as_numpy)

@property
def device(self) -> torch.device:
"""Get the Torch device that this model is on.

Returns:
torch.device: The device that the model is on.
"""
return self.model_body.device

def to(self, device: Union[str, torch.device]) -> "SetFitModel":
"""Move this SetFitModel to `device`, and then return `self`. This method does not copy.

Expand Down Expand Up @@ -589,6 +598,7 @@ def _save_pretrained(self, save_directory: Union[Path, str]) -> None:
joblib.dump(self.model_head, str(Path(save_directory) / MODEL_HEAD_NAME))

@classmethod
@validate_hf_hub_args
def _from_pretrained(
cls,
model_id: str,
Expand All @@ -598,13 +608,13 @@ def _from_pretrained(
proxies: Optional[Dict] = None,
resume_download: Optional[bool] = None,
local_files_only: Optional[bool] = None,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
multi_target_strategy: Optional[str] = None,
use_differentiable_head: bool = False,
normalize_embeddings: bool = False,
**model_kwargs,
) -> "SetFitModel":
model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=use_auth_token)
model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token)
target_device = model_body._target_device
model_body.to(target_device) # put `model_body` on the target device

Expand All @@ -628,7 +638,7 @@ def _from_pretrained(
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
use_auth_token=use_auth_token,
token=token,
local_files_only=local_files_only,
)
except requests.exceptions.RequestException:
Expand All @@ -641,7 +651,7 @@ def _from_pretrained(
if model_head_file is not None:
model_head = joblib.load(model_head_file)
else:
head_params = model_kwargs.get("head_params", {})
head_params = model_kwargs.pop("head_params", {})
if use_differentiable_head:
if multi_target_strategy is None:
use_multitarget = False
Expand Down Expand Up @@ -677,9 +687,12 @@ def _from_pretrained(
else:
model_head = clf

# Remove the `transformers` config
model_kwargs.pop("config", None)
return cls(
model_body=model_body,
model_head=model_head,
multi_target_strategy=multi_target_strategy,
normalize_embeddings=normalize_embeddings,
**model_kwargs,
)
3 changes: 3 additions & 0 deletions src/setfit/span/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .aspect_extractor import AspectExtractor
from .modeling import AbsaModel, AspectModel, PolarityModel
from .trainer import AbsaTrainer
34 changes: 34 additions & 0 deletions src/setfit/span/aspect_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import TYPE_CHECKING, List, Tuple


if TYPE_CHECKING:
from spacy.tokens import Doc


class AspectExtractor:
    """Extracts candidate aspect spans from raw texts using a spaCy model.

    An "aspect candidate" is any maximal run of consecutive noun tokens
    (``NOUN`` or ``PROPN``), represented as a ``slice`` over the spaCy ``Doc``.
    """

    def __init__(self, spacy_model: str) -> None:
        """
        Args:
            spacy_model (str): Name of the spaCy pipeline to load,
                e.g. ``"en_core_web_lg"``. spaCy is imported lazily so the
                package stays importable without the optional dependency.
        """
        super().__init__()
        import spacy

        self.nlp = spacy.load(spacy_model)

    def find_groups(self, aspect_mask: List[bool]) -> Iterator[slice]:
        """Yield a ``slice`` for every maximal run of ``True`` values in `aspect_mask`.

        Args:
            aspect_mask (List[bool]): One flag per token; ``True`` marks tokens
                that belong to an aspect candidate.

        Yields:
            slice: Half-open ``slice(start, stop)`` covering one contiguous run.
        """
        start = None
        for idx, flag in enumerate(aspect_mask):
            if flag:
                if start is None:
                    start = idx
            else:
                if start is not None:
                    yield slice(start, idx)
                    start = None
        # Close a run that extends to the end of the mask. The end index must be
        # `idx + 1` (not `idx`): `idx` is the *inclusive* index of the last True
        # token, and slices are half-open. The original `slice(start, idx)`
        # dropped the final token (e.g. `[True]` yielded the empty slice(0, 0)).
        if start is not None:
            yield slice(start, idx + 1)

    def __call__(self, texts: List[str]) -> Tuple[List["Doc"], List[List[slice]]]:
        """Run the spaCy pipeline over `texts` and extract aspect candidates.

        Args:
            texts (List[str]): Input sentences/documents.

        Returns:
            Tuple[List[Doc], List[List[slice]]]: The parsed spaCy docs, and for
            each doc the list of aspect-candidate slices (runs of noun tokens).
        """
        aspects_list = []
        docs = list(self.nlp.pipe(texts))
        for doc in docs:
            # Nouns and proper nouns are the candidate aspect tokens.
            aspect_mask = [token.pos_ in ("NOUN", "PROPN") for token in doc]
            aspects_list.append(list(self.find_groups(aspect_mask)))
        return docs, aspects_list
64 changes: 64 additions & 0 deletions src/setfit/span/model_card_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
---
license: apache-2.0
tags:
- setfit
- sentence-transformers
- absa
- token-classification
pipeline_tag: token-classification
---

# {{ model_name | default("SetFit ABSA Model", true) }}

This is a [SetFit ABSA model](https://github.com/huggingface/setfit) that can be used for Aspect Based Sentiment Analysis (ABSA). \
In particular, this model is in charge of {{ "filtering aspect span candidates" if is_aspect else "classifying aspect polarities"}}.
It has been trained using SetFit, an efficient few-shot learning technique that involves:

1. Fine-tuning a [Sentence Transformer](https://www.sbert.net) with contrastive learning.
2. Training a classification head with features from the fine-tuned Sentence Transformer.

This model was trained within the context of a larger system for ABSA, which looks like so:

1. Use a spaCy model to select possible aspect span candidates.
2. {{ "**" if is_aspect else "" }}Use {{ "this" if is_aspect else "a" }} SetFit model to filter these possible aspect span candidates.{{ "**" if is_aspect else "" }}
3. {{ "**" if not is_aspect else "" }}Use {{ "this" if not is_aspect else "a" }} SetFit model to classify the filtered aspect span candidates.{{ "**" if not is_aspect else "" }}

## Usage

To use this model for inference, first install the SetFit library:

```bash
pip install setfit
```

You can then run inference as follows:

```python
from setfit import AbsaModel

# Download from Hub and run inference
model = AbsaModel.from_pretrained(
"{{ aspect_model }}",
"{{ polarity_model }}",
)
# Run inference
preds = model([
"The best pizza outside of Italy and really tasty.",
    "The food here is great but the service is terrible.",
])
```

## BibTeX entry and citation info

```bibtex
@article{https://doi.org/10.48550/arxiv.2209.11055,
doi = {10.48550/ARXIV.2209.11055},
url = {https://arxiv.org/abs/2209.11055},
author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Efficient Few-Shot Learning Without Prompts},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
```
Loading