diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index afdcf1ec..243c1306 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,6 +40,8 @@ jobs: run: | python -m pip install --no-cache-dir --upgrade pip python -m pip install --no-cache-dir ${{ matrix.requirements }} + python -m spacy download en_core_web_lg + python -m spacy download en_core_web_sm if: steps.restore-cache.outputs.cache-hit != 'true' - name: Install the checked-out setfit diff --git a/docs/source/en/api/main.mdx b/docs/source/en/api/main.mdx index ac2b77e4..a65b3db4 100644 --- a/docs/source/en/api/main.mdx +++ b/docs/source/en/api/main.mdx @@ -6,3 +6,7 @@ # SetFitHead [[autodoc]] SetFitHead + +# AbsaModel + +[[autodoc]] AbsaModel \ No newline at end of file diff --git a/docs/source/en/api/trainer.mdx b/docs/source/en/api/trainer.mdx index 4b605dc8..3e3d39d1 100644 --- a/docs/source/en/api/trainer.mdx +++ b/docs/source/en/api/trainer.mdx @@ -5,4 +5,8 @@ # DistillationTrainer -[[autodoc]] DistillationTrainer \ No newline at end of file +[[autodoc]] DistillationTrainer + +# AbsaTrainer + +[[autodoc]] AbsaTrainer \ No newline at end of file diff --git a/setup.py b/setup.py index dcd5a8ea..7079d145 100644 --- a/setup.py +++ b/setup.py @@ -10,11 +10,18 @@ MAINTAINER_EMAIL = "lewis@huggingface.co" INTEGRATIONS_REQUIRE = ["optuna"] -REQUIRED_PKGS = ["datasets>=2.3.0", "sentence-transformers>=2.2.1", "evaluate>=0.3.0"] +REQUIRED_PKGS = [ + "datasets>=2.3.0", + "sentence-transformers>=2.2.1", + "evaluate>=0.3.0", + "huggingface_hub>=0.11.0", + "scikit-learn", +] +ABSA_REQUIRE = ["spacy"] QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"] ONNX_REQUIRE = ["onnxruntime", "onnx", "skl2onnx"] OPENVINO_REQUIRE = ["hummingbird-ml<0.4.9", "openvino==2022.3.0"] -TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE +TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE + ABSA_REQUIRE DOCS_REQUIRE = 
["hf-doc-builder>=0.3.0"] EXTRAS_REQUIRE = { "optuna": INTEGRATIONS_REQUIRE, @@ -23,6 +30,7 @@ "onnx": ONNX_REQUIRE, "openvino": ONNX_REQUIRE + OPENVINO_REQUIRE, "docs": DOCS_REQUIRE, + "absa": ABSA_REQUIRE, } diff --git a/src/setfit/__init__.py b/src/setfit/__init__.py index c36d630d..f131eee0 100644 --- a/src/setfit/__init__.py +++ b/src/setfit/__init__.py @@ -4,6 +4,7 @@ from .data import get_templated_dataset, sample_dataset from .modeling import SetFitHead, SetFitModel +from .span import AbsaModel, AbsaTrainer, AspectExtractor, AspectModel, PolarityModel from .trainer import SetFitTrainer, Trainer from .trainer_distillation import DistillationSetFitTrainer, DistillationTrainer from .training_args import TrainingArguments diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 0662d2d3..793b2c72 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -17,6 +17,7 @@ import requests import torch from huggingface_hub import PyTorchModelHubMixin, hf_hub_download +from huggingface_hub.utils import validate_hf_hub_args from sentence_transformers import SentenceTransformer, models from sklearn.linear_model import LogisticRegression from sklearn.multiclass import OneVsRestClassifier @@ -74,14 +75,14 @@ ```bibtex @article{{https://doi.org/10.48550/arxiv.2209.11055, -doi = {{10.48550/ARXIV.2209.11055}}, -url = {{https://arxiv.org/abs/2209.11055}}, -author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}}, -keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}}, -title = {{Efficient Few-Shot Learning Without Prompts}}, -publisher = {{arXiv}}, -year = {{2022}}, -copyright = {{Creative Commons Attribution 4.0 International}} + doi = {{10.48550/ARXIV.2209.11055}}, + url = {{https://arxiv.org/abs/2209.11055}}, + author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and 
Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}}, + keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}}, + title = {{Efficient Few-Shot Learning Without Prompts}}, + publisher = {{arXiv}}, + year = {{2022}}, + copyright = {{Creative Commons Attribution 4.0 International}} }} ``` """ @@ -246,7 +247,6 @@ class SetFitModel(PyTorchModelHubMixin): model_body: Optional[SentenceTransformer] = (None,) model_head: Optional[Union[SetFitHead, LogisticRegression]] = None multi_target_strategy: Optional[str] = None - l2_weight: float = 1e-2 normalize_embeddings: bool = False @property @@ -372,7 +372,7 @@ def _prepare_optimizer( l2_weight: float, ) -> torch.optim.Optimizer: body_learning_rate = body_learning_rate or head_learning_rate - l2_weight = l2_weight or self.l2_weight + l2_weight = l2_weight or 1e-2 optimizer = torch.optim.AdamW( [ { @@ -519,6 +519,15 @@ def predict_proba( outputs = self.model_head.predict_proba(embeddings) return self._output_type_conversion(outputs, as_numpy=as_numpy) + @property + def device(self) -> torch.device: + """Get the Torch device that this model is on. + + Returns: + torch.device: The device that the model is on. + """ + return self.model_body.device + def to(self, device: Union[str, torch.device]) -> "SetFitModel": """Move this SetFitModel to `device`, and then return `self`. This method does not copy. 
@@ -589,6 +598,7 @@ def _save_pretrained(self, save_directory: Union[Path, str]) -> None: joblib.dump(self.model_head, str(Path(save_directory) / MODEL_HEAD_NAME)) @classmethod + @validate_hf_hub_args def _from_pretrained( cls, model_id: str, @@ -598,13 +608,13 @@ def _from_pretrained( proxies: Optional[Dict] = None, resume_download: Optional[bool] = None, local_files_only: Optional[bool] = None, - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, multi_target_strategy: Optional[str] = None, use_differentiable_head: bool = False, normalize_embeddings: bool = False, **model_kwargs, ) -> "SetFitModel": - model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=use_auth_token) + model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token) target_device = model_body._target_device model_body.to(target_device) # put `model_body` on the target device @@ -628,7 +638,7 @@ def _from_pretrained( force_download=force_download, proxies=proxies, resume_download=resume_download, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, ) except requests.exceptions.RequestException: @@ -641,7 +651,7 @@ def _from_pretrained( if model_head_file is not None: model_head = joblib.load(model_head_file) else: - head_params = model_kwargs.get("head_params", {}) + head_params = model_kwargs.pop("head_params", {}) if use_differentiable_head: if multi_target_strategy is None: use_multitarget = False @@ -677,9 +687,12 @@ def _from_pretrained( else: model_head = clf + # Remove the `transformers` config + model_kwargs.pop("config", None) return cls( model_body=model_body, model_head=model_head, multi_target_strategy=multi_target_strategy, normalize_embeddings=normalize_embeddings, + **model_kwargs, ) diff --git a/src/setfit/span/__init__.py b/src/setfit/span/__init__.py new file mode 100644 index 00000000..7fc6f9db --- /dev/null +++ b/src/setfit/span/__init__.py @@ -0,0 
+1,3 @@ +from .aspect_extractor import AspectExtractor +from .modeling import AbsaModel, AspectModel, PolarityModel +from .trainer import AbsaTrainer diff --git a/src/setfit/span/aspect_extractor.py b/src/setfit/span/aspect_extractor.py new file mode 100644 index 00000000..096b9bb6 --- /dev/null +++ b/src/setfit/span/aspect_extractor.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING, List, Tuple + + +if TYPE_CHECKING: + from spacy.tokens import Doc + + +class AspectExtractor: + def __init__(self, spacy_model: str) -> None: + super().__init__() + import spacy + + self.nlp = spacy.load(spacy_model) + + def find_groups(self, aspect_mask: List[bool]): + start = None + for idx, flag in enumerate(aspect_mask): + if flag: + if start is None: + start = idx + else: + if start is not None: + yield slice(start, idx) + start = None + if start is not None: + yield slice(start, idx) + + def __call__(self, texts: List[str]) -> Tuple[List["Doc"], List[slice]]: + aspects_list = [] + docs = list(self.nlp.pipe(texts)) + for doc in docs: + aspect_mask = [token.pos_ in ("NOUN", "PROPN") for token in doc] + aspects_list.append(list(self.find_groups(aspect_mask))) + return docs, aspects_list diff --git a/src/setfit/span/model_card_template.md b/src/setfit/span/model_card_template.md new file mode 100644 index 00000000..31ec618f --- /dev/null +++ b/src/setfit/span/model_card_template.md @@ -0,0 +1,64 @@ +--- +license: apache-2.0 +tags: +- setfit +- sentence-transformers +- absa +- token-classification +pipeline_tag: token-classification +--- + +# {{ model_name | default("SetFit ABSA Model", true) }} + +This is a [SetFit ABSA model](https://github.com/huggingface/setfit) that can be used for Aspect Based Sentiment Analysis (ABSA). \ +In particular, this model is in charge of {{ "filtering aspect span candidates" if is_aspect else "classifying aspect polarities"}}. +It has been trained using SetFit, an efficient few-shot learning technique that involves: + +1. 
Fine-tuning a [Sentence Transformer](https://www.sbert.net) with contrastive learning. +2. Training a classification head with features from the fine-tuned Sentence Transformer. + +This model was trained within the context of a larger system for ABSA, which looks like so: + +1. Use a spaCy model to select possible aspect span candidates. +2. {{ "**" if is_aspect else "" }}Use {{ "this" if is_aspect else "a" }} SetFit model to filter these possible aspect span candidates.{{ "**" if is_aspect else "" }} +3. {{ "**" if not is_aspect else "" }}Use {{ "this" if not is_aspect else "a" }} SetFit model to classify the filtered aspect span candidates.{{ "**" if not is_aspect else "" }} + +## Usage + +To use this model for inference, first install the SetFit library: + +```bash +pip install setfit +``` + +You can then run inference as follows: + +```python +from setfit import AbsaModel + +# Download from Hub and run inference +model = AbsaModel.from_pretrained( + "{{ aspect_model }}", + "{{ polarity_model }}", +) +# Run inference +preds = model([ + "The best pizza outside of Italy and really tasty.", + "The food here is great but the service is terrible", +]) +``` + +## BibTeX entry and citation info + +```bibtex +@article{https://doi.org/10.48550/arxiv.2209.11055, + doi = {10.48550/ARXIV.2209.11055}, + url = {https://arxiv.org/abs/2209.11055}, + author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}, + keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Efficient Few-Shot Learning Without Prompts}, + publisher = {arXiv}, + year = {2022}, + copyright = {Creative Commons Attribution 4.0 International} +} +``` \ No newline at end of file diff --git a/src/setfit/span/modeling.py b/src/setfit/span/modeling.py new file mode 100644 index 00000000..02b0b1dd --- /dev/null +++ b/src/setfit/span/modeling.py @@ -0,0 
+1,292 @@ +import json +import os +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import requests +import torch +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import SoftTemporaryDirectory, validate_hf_hub_args +from jinja2 import Environment, FileSystemLoader + +from .. import logging +from ..modeling import SetFitModel +from .aspect_extractor import AspectExtractor + + +if TYPE_CHECKING: + from spacy.tokens import Doc + +logger = logging.get_logger(__name__) + +CONFIG_NAME = "config_span_setfit.json" + + +@dataclass +class SpanSetFitModel(SetFitModel): + span_context: int = 0 + + def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[str]: + for doc, aspects in zip(docs, aspects_list): + for aspect_slice in aspects: + aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context] + # TODO: Investigate performance difference of different formats + yield aspect.text + ":" + doc.text + + def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[bool]: + inputs_list = list(self.prepend_aspects(docs, aspects_list)) + preds = self.predict(inputs_list, as_numpy=True) + iter_preds = iter(preds) + return [[next(iter_preds) for _ in aspects] for aspects in aspects_list] + + @classmethod + @validate_hf_hub_args + def _from_pretrained( + cls, + model_id: str, + span_context: Optional[int] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + force_download: Optional[bool] = None, + proxies: Optional[Dict] = None, + resume_download: Optional[bool] = None, + local_files_only: Optional[bool] = None, + token: Optional[Union[bool, str]] = None, + **model_kwargs, + ) -> "SpanSetFitModel": + config_file: Optional[str] = None + if os.path.isdir(model_id): + if CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, 
CONFIG_NAME) + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except requests.exceptions.RequestException: + pass + + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + model_kwargs.update(config) + + if span_context is not None: + model_kwargs["span_context"] = span_context + + return super(SpanSetFitModel, cls)._from_pretrained( + model_id, + revision, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + token, + **model_kwargs, + ) + + def _save_pretrained(self, save_directory: Union[Path, str]) -> None: + path = os.path.join(save_directory, CONFIG_NAME) + with open(path, "w") as f: + json.dump({"span_context": self.span_context}, f, indent=2) + + super()._save_pretrained(save_directory) + + def create_model_card(self, path: str, model_name: Optional[str] = None) -> None: + """Creates and saves a model card for a SetFit model. + + Args: + path (str): The path to save the model card to. + model_name (str, *optional*): The name of the model. Defaults to `SetFit Model`. + """ + if not os.path.exists(path): + os.makedirs(path) + + # If the model_path is a folder that exists locally, i.e. 
when create_model_card is called + # via push_to_hub, and the path is in a temporary folder, then we only take the last two + # directories + model_path = Path(model_name) + if model_path.exists() and Path(tempfile.gettempdir()) in model_path.resolve().parents: + model_name = "/".join(model_path.parts[-2:]) + + environment = Environment(loader=FileSystemLoader(Path(__file__).parent)) + template = environment.get_template("model_card_template.md") + is_aspect = isinstance(self, AspectModel) + aspect_model = "setfit-absa-aspect" + polarity_model = "setfit-absa-polarity" + if model_name is not None: + if is_aspect: + aspect_model = model_name + if model_name.endswith("-aspect"): + polarity_model = model_name[: -len("-aspect")] + "-polarity" + else: + polarity_model = model_name + if model_name.endswith("-polarity"): + aspect_model = model_name[: -len("-polarity")] + "-aspect" + + model_card_content = template.render( + model_name=model_name, is_aspect=is_aspect, aspect_model=aspect_model, polarity_model=polarity_model + ) + with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: + f.write(model_card_content) + + +class AspectModel(SpanSetFitModel): + # TODO: Assumes binary SetFitModel with 0 == no aspect, 1 == aspect + def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[bool]: + sentence_preds = super().__call__(docs, aspects_list) + return [ + [aspect for aspect, pred in zip(aspects, preds) if pred == 1] + for aspects, preds in zip(aspects_list, sentence_preds) + ] + + +@dataclass +class PolarityModel(SpanSetFitModel): + span_context: int = 3 + + +@dataclass +class AbsaModel: + aspect_extractor: AspectExtractor + aspect_model: AspectModel + polarity_model: PolarityModel + + def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]: + is_str = isinstance(inputs, str) + inputs_list = [inputs] if is_str else inputs + docs, aspects_list = self.aspect_extractor(inputs_list) + if sum(aspects_list, []) == []: + 
return aspects_list + + aspects_list = self.aspect_model(docs, aspects_list) + if sum(aspects_list, []) == []: + return aspects_list + + polarity_list = self.polarity_model(docs, aspects_list) + outputs = [] + for docs, aspects, polarities in zip(docs, aspects_list, polarity_list): + outputs.append( + [ + {"span": docs[aspect_slice].text, "polarity": polarity} + for aspect_slice, polarity in zip(aspects, polarities) + ] + ) + return outputs if not is_str else outputs[0] + + @property + def device(self) -> torch.device: + return self.aspect_model.device + + def to(self, device: Union[str, torch.device]) -> "AbsaModel": + self.aspect_model.to(device) + self.polarity_model.to(device) + + def __call__(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]: + return self.predict(inputs) + + def save_pretrained( + self, + save_directory: Union[str, Path], + polarity_save_directory: Optional[Union[str, Path]] = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + if polarity_save_directory is None: + base_save_directory = Path(save_directory) + save_directory = base_save_directory.parent / (base_save_directory.name + "-aspect") + polarity_save_directory = base_save_directory.parent / (base_save_directory.name + "-polarity") + self.aspect_model.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + self.polarity_model.save_pretrained(save_directory=polarity_save_directory, push_to_hub=push_to_hub, **kwargs) + + @classmethod + def from_pretrained( + cls, + model_id: str, + polarity_model_id: Optional[str] = None, + spacy_model: Optional[str] = "en_core_web_lg", + span_contexts: Tuple[Optional[int], Optional[int]] = (None, None), + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict] = None, + token: Optional[Union[str, bool]] = None, + cache_dir: Optional[str] = None, + local_files_only: bool = False, + use_differentiable_head: bool = False, + normalize_embeddings: bool = False, + 
**model_kwargs, + ) -> "AbsaModel": + revision = None + if len(model_id.split("@")) == 2: + model_id, revision = model_id.split("@") + aspect_model = AspectModel.from_pretrained( + model_id, + span_context=span_contexts[0], + revision=revision, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + use_differentiable_head=use_differentiable_head, + normalize_embeddings=normalize_embeddings, + **model_kwargs, + ) + if polarity_model_id: + model_id = polarity_model_id + revision = None + if len(model_id.split("@")) == 2: + model_id, revision = model_id.split("@") + polarity_model = PolarityModel.from_pretrained( + model_id, + span_context=span_contexts[1], + revision=revision, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + use_differentiable_head=use_differentiable_head, + normalize_embeddings=normalize_embeddings, + **model_kwargs, + ) + + aspect_extractor = AspectExtractor(spacy_model=spacy_model) + + return cls(aspect_extractor, aspect_model, polarity_model) + + def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None: + if "/" not in repo_id: + raise ValueError( + '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".' + ) + if polarity_repo_id is not None and "/" not in polarity_repo_id: + raise ValueError( + '`polarity_repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".' 
+ ) + commit_message = kwargs.pop("commit_message", "Add SetFit ABSA model") + + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp_dir: + save_directory = Path(tmp_dir) / repo_id + polarity_save_directory = None if polarity_repo_id is None else Path(tmp_dir) / polarity_repo_id + self.save_pretrained( + save_directory=save_directory, + polarity_save_directory=polarity_save_directory, + push_to_hub=True, + commit_message=commit_message, + **kwargs, + ) diff --git a/src/setfit/span/trainer.py b/src/setfit/span/trainer.py new file mode 100644 index 00000000..477cddf8 --- /dev/null +++ b/src/setfit/span/trainer.py @@ -0,0 +1,316 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import optuna +from datasets import Dataset +from transformers.trainer_callback import TrainerCallback + +from setfit.span.modeling import AbsaModel, AspectModel, PolarityModel, SpanSetFitModel +from setfit.training_args import TrainingArguments + +from .. import logging +from ..trainer import ColumnMappingMixin, Trainer + + +logger = logging.get_logger(__name__) + + +class AbsaTrainer(ColumnMappingMixin): + """Trainer to train a SetFit ABSA model. + + Args: + model (`AbsaModel`): + The AbsaModel model to train. + args (`TrainingArguments`, *optional*): + The training arguments to use. If `polarity_args` is not defined, then `args` is used for both + the aspect and the polarity model. + polarity_args (`TrainingArguments`, *optional*): + The training arguments to use for the polarity model. If not defined, `args` is used for both + the aspect and the polarity model. + train_dataset (`Dataset`): + The training dataset. The dataset must have "text", "span", "label" and "ordinal" columns. + eval_dataset (`Dataset`, *optional*): + The evaluation dataset. The dataset must have "text", "span", "label" and "ordinal" columns. 
+ metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): + The metric to use for evaluation. If a string is provided, we treat it as the metric + name and load it with default settings. + If a callable is provided, it must take two arguments (`y_pred`, `y_test`). + metric_kwargs (`Dict[str, Any]`, *optional*): + Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1". + For example useful for providing an averaging strategy for computing f1 in a multi-label setting. + callbacks (`List[~transformers.TrainerCallback]`, *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback). + If you want to remove one of the default callbacks used, use the `Trainer.remove_callback()` method. + column_mapping (`Dict[str, str]`, *optional*): + A mapping from the column names in the dataset to the column names expected by the model. + The expected format is a dictionary with the following format: + `{"text_column_name": "text", "span_column_name": "span", "label_column_name": "label", "ordinal_column_name": "ordinal"}`. 
+ """ + + _REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"} + + def __init__( + self, + model: AbsaModel, + args: Optional[TrainingArguments] = None, + polarity_args: Optional[TrainingArguments] = None, + train_dataset: Optional["Dataset"] = None, + eval_dataset: Optional["Dataset"] = None, + metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", + metric_kwargs: Optional[Dict[str, Any]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + column_mapping: Optional[Dict[str, str]] = None, + ) -> None: + self.model = model + self.aspect_extractor = model.aspect_extractor + + if train_dataset is not None and column_mapping: + train_dataset = self._apply_column_mapping(train_dataset, column_mapping) + aspect_train_dataset, polarity_train_dataset = self.preprocess_dataset( + model.aspect_model, model.polarity_model, train_dataset + ) + if eval_dataset is not None and column_mapping: + eval_dataset = self._apply_column_mapping(eval_dataset, column_mapping) + aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset( + model.aspect_model, model.polarity_model, eval_dataset + ) + + self.aspect_trainer = Trainer( + model.aspect_model, + args=args, + train_dataset=aspect_train_dataset, + eval_dataset=aspect_eval_dataset, + metric=metric, + metric_kwargs=metric_kwargs, + callbacks=callbacks, + ) + self.aspect_trainer._set_logs_mapper( + {"eval_embedding_loss": "eval_aspect_embedding_loss", "embedding_loss": "aspect_embedding_loss"} + ) + self.polarity_trainer = Trainer( + model.polarity_model, + args=polarity_args or args, + train_dataset=polarity_train_dataset, + eval_dataset=polarity_eval_dataset, + metric=metric, + metric_kwargs=metric_kwargs, + callbacks=callbacks, + ) + self.polarity_trainer._set_logs_mapper( + {"eval_embedding_loss": "eval_polarity_embedding_loss", "embedding_loss": "polarity_embedding_loss"} + ) + + def preprocess_dataset( + self, aspect_model: AspectModel, polarity_model: PolarityModel, 
dataset: Dataset + ) -> Dataset: + if dataset is None: + return dataset, dataset + + # Group by "text" + grouped_data = defaultdict(list) + for sample in dataset: + text = sample.pop("text") + grouped_data[text].append(sample) + + def index_ordinal(text: str, target: str, ordinal: int) -> Tuple[int, int]: + find_from = 0 + for _ in range(ordinal + 1): + start_idx = text.index(target, find_from) + find_from = start_idx + 1 + return start_idx, start_idx + len(target) + + docs, aspects_list = self.aspect_extractor(grouped_data.keys()) + intersected_aspect_list = [] + polarity_labels = [] + aspect_labels = [] + for doc, aspects, text in zip(docs, aspects_list, grouped_data): + gold_aspects = [] + gold_polarity_labels = [] + for annotation in grouped_data[text]: + try: + start, end = index_ordinal(text, annotation["span"], annotation["ordinal"]) + except ValueError: + logger.info( + f"The ordinal of {annotation['ordinal']} for span {annotation['span']!r} in {text!r} is too high. " + "Skipping this sample." + ) + continue + + gold_aspect_span = doc.char_span(start, end) + if gold_aspect_span is None: + continue + gold_aspects.append(slice(gold_aspect_span.start, gold_aspect_span.end)) + gold_polarity_labels.append(annotation["label"]) + + # The Aspect model uses all predicted aspects, with labels depending on whether + # the predicted aspects are indeed true/gold aspects. + aspect_labels.extend([aspect in gold_aspects for aspect in aspects]) + + # The Polarity model uses the intersection of pred and gold aspects, with labels for the gold label. 
+ intersected_aspects = [] + for gold_aspect, gold_label in zip(gold_aspects, gold_polarity_labels): + if gold_aspect in aspects: + intersected_aspects.append(gold_aspect) + polarity_labels.append(gold_label) + intersected_aspect_list.append(intersected_aspects) + + aspect_texts = list(aspect_model.prepend_aspects(docs, aspects_list)) + polarity_texts = list(polarity_model.prepend_aspects(docs, intersected_aspect_list)) + return Dataset.from_dict({"text": aspect_texts, "label": aspect_labels}), Dataset.from_dict( + {"text": polarity_texts, "label": polarity_labels} + ) + + def train( + self, + args: Optional[TrainingArguments] = None, + polarity_args: Optional[TrainingArguments] = None, + trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, + **kwargs, + ) -> None: + """ + Main training entry point. + + Args: + args (`TrainingArguments`, *optional*): + Temporarily change the aspect training arguments for this training call. + polarity_args (`TrainingArguments`, *optional*): + Temporarily change the polarity training arguments for this training call. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. + """ + self.train_aspect(args=args, trial=trial, **kwargs) + self.train_polarity(args=polarity_args, trial=trial, **kwargs) + + def train_aspect( + self, + args: Optional[TrainingArguments] = None, + trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, + **kwargs, + ) -> None: + """ + Train the aspect model only. + + Args: + args (`TrainingArguments`, *optional*): + Temporarily change the aspect training arguments for this training call. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. 
+ """ + self.aspect_trainer.train(args=args, trial=trial, **kwargs) + + def train_polarity( + self, + args: Optional[TrainingArguments] = None, + trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, + **kwargs, + ) -> None: + """ + Train the polarity model only. + + Args: + args (`TrainingArguments`, *optional*): + Temporarily change the aspect training arguments for this training call. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. + """ + self.polarity_trainer.train(args=args, trial=trial, **kwargs) + + def add_callback(self, callback): + """ + Add a callback to the current list of [`~transformer.TrainerCallback`]. + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will instantiate a member of that class. + """ + self.aspect_trainer.add_callback(callback) + self.polarity_trainer.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it. + + If the callback is not found, returns `None` (and no error is raised). + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will pop the first member of that class found in the list of callbacks. + + Returns: + [`Tuple[~transformer.TrainerCallback]`]: The callbacks removed from the aspect and polarity trainers, if found. + """ + return self.aspect_trainer.pop_callback(callback), self.polarity_trainer.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`]. 
+ + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will remove the first member of that class found in the list of callbacks. + """ + self.aspect_trainer.remove_callback(callback) + self.polarity_trainer.remove_callback(callback) + + def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None: + """Upload model checkpoint to the Hub using `huggingface_hub`. + + See the full list of parameters for your `huggingface_hub` version in the\ + [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub). + + Args: + repo_id (`str`): + The full repository ID of the aspect model to push to, e.g. `"tomaarsen/setfit-aspect"`. + polarity_repo_id (`str`, *optional*): + The full repository ID of the polarity model to push to, e.g. `"tomaarsen/setfit-polarity"`. + config (`dict`, *optional*): + Configuration object to be saved alongside the model weights. + commit_message (`str`, *optional*): + Message to commit while pushing. + private (`bool`, *optional*, defaults to `False`): + Whether the repository created should be private. + api_endpoint (`str`, *optional*): + The API endpoint to use when pushing the model to the hub. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. + If not set, will use the token set when logging in with + `transformers-cli login` (stored in `~/.huggingface`). + branch (`str`, *optional*): + The git branch on which to push the model. This defaults to + the default branch as specified in your repository, which + defaults to `"main"`. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `branch` with that commit. + Defaults to `False`. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are pushed. 
+ ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not pushed. + """ + return self.model.push_to_hub(repo_id=repo_id, polarity_repo_id=polarity_repo_id, **kwargs) + + def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, Dict[str, float]]: + """ + Computes the metrics for a given classifier. + + Args: + dataset (`Dataset`, *optional*): + The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via + the `eval_dataset` argument at `Trainer` initialization. + + Returns: + `Dict[str, Dict[str, float]]`: The evaluation metrics. + """ + aspect_eval_dataset = polarity_eval_dataset = None + if dataset: + aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset( + self.model.aspect_model, self.model.polarity_model, dataset + ) + return { + "aspect": self.aspect_trainer.evaluate(aspect_eval_dataset), + "polarity": self.polarity_trainer.evaluate(polarity_eval_dataset), + } diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index 848796b6..baecd053 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -12,6 +12,7 @@ from sentence_transformers.datasets import SentenceLabelDataset from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction from sentence_transformers.util import batch_to_device +from sklearn.preprocessing import LabelEncoder from torch import nn from torch.cuda.amp import autocast from torch.utils.data import DataLoader @@ -68,7 +69,70 @@ DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback -class Trainer: +class ColumnMappingMixin: + _REQUIRED_COLUMNS = {"text", "label"} + + def _validate_column_mapping(self, dataset: "Dataset") -> None: + """ + Validates the provided column mapping against the dataset. 
+ """ + column_names = set(dataset.column_names) + if self.column_mapping is None and not self._REQUIRED_COLUMNS.issubset(column_names): + # Issue #226: load_dataset will automatically assign points to "train" if no split is specified + if column_names == {"train"} and isinstance(dataset, DatasetDict): + raise ValueError( + "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. " + "Did you mean to select the training split with dataset['train']?" + ) + elif isinstance(dataset, DatasetDict): + raise ValueError( + f"SetFit expected a Dataset, but it got a DatasetDict with the splits {sorted(column_names)}. " + "Did you mean to select one of these splits from the dataset?" + ) + else: + raise ValueError( + f"SetFit expected the dataset to have the columns {sorted(self._REQUIRED_COLUMNS)}, " + f"but only the columns {sorted(column_names)} were found. " + "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer." + ) + if self.column_mapping is not None: + missing_columns = self._REQUIRED_COLUMNS.difference(self.column_mapping.values()) + if missing_columns: + raise ValueError( + f"The following columns are missing from the column mapping: {missing_columns}. Please provide a mapping for all required columns." + ) + if not set(self.column_mapping.keys()).issubset(column_names): + raise ValueError( + f"The column mapping expected the columns {sorted(self.column_mapping.keys())} in the dataset, " + f"but the dataset had the columns {sorted(column_names)}." + ) + + def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, str]) -> "Dataset": + """ + Applies the provided column mapping to the dataset, renaming columns accordingly. + Extra features not in the column mapping are prefixed with `"feat_"`. 
+ """ + dataset = dataset.rename_columns( + { + **column_mapping, + **{ + col: f"feat_{col}" + for col in dataset.column_names + if col not in column_mapping and col not in self._REQUIRED_COLUMNS + }, + } + ) + dset_format = dataset.format + dataset = dataset.with_format( + type=dset_format["type"], + columns=dataset.column_names, + output_all_columns=dset_format["output_all_columns"], + **dset_format["format_kwargs"], + ) + return dataset + + +class Trainer(ColumnMappingMixin): """Trainer to train a SetFit model. Args: @@ -91,14 +155,16 @@ class Trainer: metric_kwargs (`Dict[str, Any]`, *optional*): Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1". For example useful for providing an averaging strategy for computing f1 in a multi-label setting. + callbacks: (`List[~transformers.TrainerCallback]`, *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback). + If you want to remove one of the default callbacks used, use the `Trainer.remove_callback()` method. column_mapping (`Dict[str, str]`, *optional*): A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: `{"text_column_name": "text", "label_column_name: "label"}`. 
""" - _REQUIRED_COLUMNS = {"text", "label"} - def __init__( self, model: Optional["SetFitModel"] = None, @@ -111,6 +177,8 @@ def __init__( callbacks: Optional[List[TrainerCallback]] = None, column_mapping: Optional[Dict[str, str]] = None, ) -> None: + if args is not None and not isinstance(args, TrainingArguments): + raise ValueError("`args` must be a `TrainingArguments` instance imported from `setfit`.") self.args = args or TrainingArguments() self.train_dataset = train_dataset self.eval_dataset = eval_dataset @@ -118,6 +186,7 @@ def __init__( self.metric = metric self.metric_kwargs = metric_kwargs self.column_mapping = column_mapping + self.logs_mapper = {} # Seed must be set before instantiating the model when using model_init. set_seed(12) @@ -184,61 +253,6 @@ def remove_callback(self, callback): """ self.callback_handler.remove_callback(callback) - def _validate_column_mapping(self, dataset: "Dataset") -> None: - """ - Validates the provided column mapping against the dataset. - """ - column_names = set(dataset.column_names) - if self.column_mapping is None and not self._REQUIRED_COLUMNS.issubset(column_names): - # Issue #226: load_dataset will automatically assign points to "train" if no split is specified - if column_names == {"train"} and isinstance(dataset, DatasetDict): - raise ValueError( - "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. " - "Did you mean to select the training split with dataset['train']?" - ) - elif isinstance(dataset, DatasetDict): - raise ValueError( - f"SetFit expected a Dataset, but it got a DatasetDict with the splits {sorted(column_names)}. " - "Did you mean to select one of these splits from the dataset?" - ) - else: - raise ValueError( - f"SetFit expected the dataset to have the columns {sorted(self._REQUIRED_COLUMNS)}, " - f"but only the columns {sorted(column_names)} were found. " - "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer." 
- ) - if self.column_mapping is not None: - missing_columns = self._REQUIRED_COLUMNS.difference(self.column_mapping.values()) - if missing_columns: - raise ValueError( - f"The following columns are missing from the column mapping: {missing_columns}. Please provide a mapping for all required columns." - ) - if not set(self.column_mapping.keys()).issubset(column_names): - raise ValueError( - f"The column mapping expected the columns {sorted(self.column_mapping.keys())} in the dataset, " - f"but the dataset had the columns {sorted(column_names)}." - ) - - def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, str]) -> "Dataset": - """ - Applies the provided column mapping to the dataset, renaming columns accordingly. - Extra features not in the column mapping are prefixed with `"feat_"`. - """ - dataset = dataset.rename_columns( - { - **column_mapping, - **{col: f"feat_{col}" for col in dataset.column_names if col not in column_mapping}, - } - ) - dset_format = dataset.format - dataset = dataset.with_format( - type=dset_format["type"], - columns=dataset.column_names, - output_all_columns=dset_format["output_all_columns"], - **dset_format["format_kwargs"], - ) - return dataset - def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = False) -> None: """Applies a dictionary of hyperparameters to both the trainer and the model @@ -329,7 +343,7 @@ def train( args: Optional[TrainingArguments] = None, trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, **kwargs, - ): + ) -> None: """ Main training entry point. @@ -478,6 +492,7 @@ def log(self, args: TrainingArguments, logs: Dict[str, float]) -> None: logs (`Dict[str, float]`): The values to log. 
         """
+        logs = {self.logs_mapper.get(key, key): value for key, value in logs.items()}
         if self.state.epoch is not None:
             logs["epoch"] = round(self.state.epoch, 2)

@@ -485,6 +500,14 @@ def log(self, args: TrainingArguments, logs: Dict[str, float]) -> None:
         self.state.log_history.append(output)
         return self.callback_handler.on_log(args, self.state, self.control, logs)

+    def _set_logs_mapper(self, logs_mapper: Dict[str, str]) -> None:
+        """Set the logging mapper.
+
+        Args:
+            logs_mapper (Dict[str, str]): The logging mapper, e.g. {"eval_embedding_loss": "eval_aspect_embedding_loss"}.
+        """
+        self.logs_mapper = logs_mapper
+
     def _train_sentence_transformer(
         self,
         model_body: SentenceTransformer,
@@ -732,6 +755,8 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
         """
         eval_dataset = dataset or self.eval_dataset

+        if eval_dataset is None:
+            raise ValueError("No evaluation dataset provided to `Trainer.evaluate` nor the `Trainer` initialization.")
         self._validate_column_mapping(eval_dataset)

         if self.column_mapping is not None:
@@ -746,6 +771,13 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
         if isinstance(y_pred, torch.Tensor):
             y_pred = y_pred.cpu()

+        # Normalize string outputs
+        if y_test and isinstance(y_test[0], str):
+            encoder = LabelEncoder()
+            encoder.fit(list(y_test) + list(y_pred))
+            y_test = encoder.transform(y_test)
+            y_pred = encoder.transform(y_pred)
+
         if isinstance(self.metric, str):
             metric_config = "multilabel" if self.model.multi_target_strategy is not None else None
             metric_fn = evaluate.load(self.metric, config_name=metric_config)
@@ -843,7 +875,7 @@ def push_to_hub(self, repo_id: str, **kwargs) -> str:

         Args:
             repo_id (`str`):
-                The full repository ID to push to, e.g. `"tomaarsen/setfit_sst2"`.
+                The full repository ID to push to, e.g. `"tomaarsen/setfit-sst2"`.
             config (`dict`, *optional*):
                 Configuration object to be saved alongside the model weights.
commit_message (`str`, *optional*): @@ -873,7 +905,7 @@ def push_to_hub(self, repo_id: str, **kwargs) -> str: """ if "/" not in repo_id: raise ValueError( - '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit_sst2".' + '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-sst2".' ) commit_message = kwargs.pop("commit_message", "Add SetFit model") return self.model.push_to_hub(repo_id, commit_message=commit_message, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index acf5b825..11051223 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,29 @@ import pytest +from datasets import Dataset -from setfit import SetFitModel +from setfit import AbsaModel, SetFitModel @pytest.fixture() def model() -> SetFitModel: return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + + +@pytest.fixture() +def absa_model() -> AbsaModel: + return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + + +@pytest.fixture() +def absa_dataset() -> Dataset: + texts = [ + "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.", + "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.", + "Food is great and inexpensive.", + "Good bagels and good cream cheese.", + "Good bagels and good cream cheese.", + ] + spans = ["food", "ambiance", "Food", "bagels", "cream cheese"] + labels = ["negative", "negative", "positive", "positive", "positive"] + ordinals = [0, 0, 0, 0, 0] + return Dataset.from_dict({"text": texts, "span": spans, "label": labels, "ordinal": ordinals}) diff --git a/tests/span/test_modeling.py b/tests/span/test_modeling.py new file mode 100644 index 00000000..02fd7c3e --- /dev/null +++ b/tests/span/test_modeling.py @@ -0,0 +1,78 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest +import 
torch + +from setfit import AbsaModel +from setfit.span.aspect_extractor import AspectExtractor +from setfit.span.modeling import AspectModel, PolarityModel + + +def test_loading(): + model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + assert isinstance(model, AbsaModel) + assert isinstance(model.aspect_extractor, AspectExtractor) + assert isinstance(model.aspect_model, AspectModel) + assert isinstance(model.polarity_model, PolarityModel) + + model = AbsaModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2@6c91e73a51599e35bd1145dfdcd3289215225009", + "sentence-transformers/paraphrase-albert-small-v2", + ) + assert isinstance(model, AbsaModel) + + model = AbsaModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", + "sentence-transformers/paraphrase-albert-small-v2@6c91e73a51599e35bd1145dfdcd3289215225009", + ) + assert isinstance(model, AbsaModel) + + with pytest.raises(OSError): + model = AbsaModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", spacy_model="not_a_spacy_model" + ) + + model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", normalize_embeddings=True) + assert model.aspect_model.normalize_embeddings + assert model.polarity_model.normalize_embeddings + + aspect_model = AspectModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", span_context=12) + assert aspect_model.span_context == 12 + polarity_model = PolarityModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", span_context=12) + assert polarity_model.span_context == 12 + + model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", span_contexts=(12, None)) + assert model.aspect_model.span_context == 12 + assert model.polarity_model.span_context == 3 # <- default + + +def test_save_load(absa_model: AbsaModel) -> None: + absa_model.polarity_model.span_context = 5 + + with TemporaryDirectory() as 
tmp_dir: + tmp_dir = str(Path(tmp_dir) / "model") + absa_model.save_pretrained(tmp_dir) + assert (Path(tmp_dir + "-aspect") / "config_span_setfit.json").exists() + assert (Path(tmp_dir + "-polarity") / "config_span_setfit.json").exists() + + fresh_model = AbsaModel.from_pretrained(tmp_dir + "-aspect", tmp_dir + "-polarity") + assert fresh_model.polarity_model.span_context == 5 + + with TemporaryDirectory() as aspect_tmp_dir: + with TemporaryDirectory() as polarity_tmp_dir: + absa_model.save_pretrained(aspect_tmp_dir, polarity_tmp_dir) + assert (Path(aspect_tmp_dir) / "config_span_setfit.json").exists() + assert (Path(polarity_tmp_dir) / "config_span_setfit.json").exists() + + fresh_model = AbsaModel.from_pretrained(aspect_tmp_dir, polarity_tmp_dir) + assert fresh_model.polarity_model.span_context == 5 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to move a model between devices") +def test_to(absa_model: AbsaModel) -> None: + assert absa_model.device.type == "cuda" + absa_model.to("cpu") + assert absa_model.device.type == "cpu" + assert absa_model.aspect_model.device.type == "cpu" + assert absa_model.polarity_model.device.type == "cpu" diff --git a/tests/span/test_trainer.py b/tests/span/test_trainer.py new file mode 100644 index 00000000..f89044dc --- /dev/null +++ b/tests/span/test_trainer.py @@ -0,0 +1,75 @@ +from datasets import Dataset +from transformers import TrainerCallback + +from setfit import AbsaTrainer +from setfit.span.modeling import AbsaModel + + +def test_trainer(absa_model: AbsaModel, absa_dataset: Dataset) -> None: + trainer = AbsaTrainer(absa_model, train_dataset=absa_dataset, eval_dataset=absa_dataset) + trainer.train() + + metrics = trainer.evaluate() + assert "aspect" in metrics + assert "polarity" in metrics + assert "accuracy" in metrics["aspect"] + assert "accuracy" in metrics["polarity"] + assert metrics["aspect"]["accuracy"] > 0.0 + assert metrics["polarity"]["accuracy"] > 0.0 + new_metrics = 
trainer.evaluate(absa_dataset) + assert metrics == new_metrics + + predict = absa_model.predict("Best pizza outside of Italy and really tasty.") + assert {"span": "pizza", "polarity": "positive"} in predict + predict = absa_model.predict(["Best pizza outside of Italy and really tasty.", "This is another sentence"]) + assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list) + predict = absa_model(["Best pizza outside of Italy and really tasty.", "This is another sentence"]) + assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list) + + +def test_trainer_callbacks(absa_model: AbsaModel) -> None: + trainer = AbsaTrainer(absa_model) + assert len(trainer.aspect_trainer.callback_handler.callbacks) >= 2 + callback_names = {callback.__class__.__name__ for callback in trainer.aspect_trainer.callback_handler.callbacks} + assert {"DefaultFlowCallback", "ProgressCallback"} <= callback_names + + class TestCallback(TrainerCallback): + pass + + callback = TestCallback() + trainer.add_callback(callback) + assert len(trainer.aspect_trainer.callback_handler.callbacks) == len(callback_names) + 1 + assert len(trainer.polarity_trainer.callback_handler.callbacks) == len(callback_names) + 1 + assert trainer.aspect_trainer.callback_handler.callbacks[-1] == callback + assert trainer.polarity_trainer.callback_handler.callbacks[-1] == callback + + assert trainer.pop_callback(callback) == (callback, callback) + trainer.add_callback(callback) + assert trainer.aspect_trainer.callback_handler.callbacks[-1] == callback + assert trainer.polarity_trainer.callback_handler.callbacks[-1] == callback + trainer.remove_callback(callback) + assert callback not in trainer.aspect_trainer.callback_handler.callbacks + assert callback not in trainer.polarity_trainer.callback_handler.callbacks + + +def test_train_ordinal_too_high(absa_model: AbsaModel) -> None: + absa_dataset = Dataset.from_dict( + { + "text": [ + "It is about food and ambiance, and 
imagine how dreadful it will be it we only had to listen to an idle engine." + ], + "span": ["food"], + "label": ["negative"], + "ordinal": [1], + } + ) + AbsaTrainer(absa_model, train_dataset=absa_dataset) + # TODO: Capture warning and test against it. + + +def test_train_column_mapping(absa_model: AbsaModel, absa_dataset: Dataset) -> None: + absa_dataset = absa_dataset.rename_columns({"text": "sentence", "span": "aspect"}) + trainer = AbsaTrainer( + absa_model, train_dataset=absa_dataset, column_mapping={"sentence": "text", "aspect": "span"} + ) + trainer.train() diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 2c699ea2..8eee4d57 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,15 +1,17 @@ import os -import pathlib import re import tempfile +from pathlib import Path from unittest import TestCase import evaluate import pytest import torch from datasets import Dataset, load_dataset +from pytest import LogCaptureFixture from sentence_transformers import losses from transformers import TrainerCallback +from transformers import TrainingArguments as TransformersTrainingArguments from transformers.testing_utils import require_optuna from transformers.utils.hp_naming import TrialShortNamer @@ -132,7 +134,7 @@ def test_trainer_raises_error_when_dataset_not_split(self): def test_trainer_raises_error_when_dataset_is_dataset_dict_with_train(self): """Verify that a useful error is raised if we pass an unsplit dataset with only a `train` split to the trainer.""" with tempfile.TemporaryDirectory() as tmpdirname: - path = pathlib.Path(tmpdirname) / "test_dataset_dict_with_train.csv" + path = Path(tmpdirname) / "test_dataset_dict_with_train.csv" path.write_text("label,text\n1,good\n0,terrible\n") dataset = load_dataset("csv", data_files=str(path)) trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset) @@ -534,20 +536,20 @@ def test_trainer_warn_freeze(model: SetFitModel): trainer.freeze() -def 
test_train_with_kwargs(model: SetFitModel): +def test_train_with_kwargs(model: SetFitModel) -> None: train_dataset = Dataset.from_dict({"text": ["positive sentence", "negative sentence"], "label": [1, 0]}) trainer = Trainer(model, train_dataset=train_dataset) with pytest.warns(DeprecationWarning, match="`Trainer.train` does not accept keyword arguments anymore."): trainer.train(num_epochs=5) -def test_train_no_dataset(model: SetFitModel): +def test_train_no_dataset(model: SetFitModel) -> None: trainer = Trainer(model) with pytest.raises(ValueError, match="Training requires a `train_dataset` given to the `Trainer` initialization."): trainer.train() -def test_train_amp_save(model: SetFitModel, tmp_path): +def test_train_amp_save(model: SetFitModel, tmp_path: Path) -> None: args = TrainingArguments(output_dir=tmp_path, use_amp=True, save_steps=5, num_epochs=5) dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]}) trainer = Trainer(model, args=args, train_dataset=dataset, eval_dataset=dataset) @@ -556,7 +558,7 @@ def test_train_amp_save(model: SetFitModel, tmp_path): assert os.listdir(tmp_path) == ["step_5"] -def test_train_load_best(model: SetFitModel, tmp_path, caplog): +def test_train_load_best(model: SetFitModel, tmp_path: Path, caplog: LogCaptureFixture) -> None: args = TrainingArguments( output_dir=tmp_path, save_steps=5, @@ -571,3 +573,21 @@ def test_train_load_best(model: SetFitModel, tmp_path, caplog): trainer.train() assert any("Load pretrained SentenceTransformer" in text for _, _, text in caplog.record_tuples) + + +def test_evaluate_with_strings(model: SetFitModel) -> None: + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": ["positive", "positive", "negative"]}) + trainer = Trainer(model, train_dataset=dataset, eval_dataset=dataset) + trainer.train() + metrics = trainer.evaluate() + assert "accuracy" in metrics + + +def test_trainer_wrong_args(model: SetFitModel, tmp_path: Path) -> None: + args = 
TransformersTrainingArguments(output_dir=tmp_path) + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]}) + expected = "`args` must be a `TrainingArguments` instance imported from `setfit`." + with pytest.raises(ValueError, match=expected): + Trainer(model, args=args) + with pytest.raises(ValueError, match=expected): + Trainer(model, dataset)