diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 9ced4d45..b3cdcd6b 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -5,9 +5,12 @@ on: branches: - main - v*-release + - v*-pre pull_request: branches: - main + - v*-pre + workflow_dispatch: jobs: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index afdcf1ec..45dccb7f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,9 +5,12 @@ on: branches: - main - v*-release + - v*-pre pull_request: branches: - main + - v*-pre + workflow_dispatch: jobs: @@ -40,6 +43,8 @@ jobs: run: | python -m pip install --no-cache-dir --upgrade pip python -m pip install --no-cache-dir ${{ matrix.requirements }} + python -m spacy download en_core_web_lg + python -m spacy download en_core_web_sm if: steps.restore-cache.outputs.cache-hit != 'true' - name: Install the checked-out setfit diff --git a/.gitignore b/.gitignore index a13745c3..6e89ff50 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,7 @@ scripts/tfew/run_tmux.sh # macOS .DS_Store .vscode/settings.json + +# Common SetFit Trainer logging folders +wandb +runs/ \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..69617566 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include src/setfit/span/model_card_template.md \ No newline at end of file diff --git a/README.md b/README.md index f79a1001..4a8156e6 100644 --- a/README.md +++ b/README.md @@ -39,16 +39,14 @@ The examples below provide a quick overview on the various features supported in `setfit` is integrated with the [Hugging Face Hub](https://huggingface.co/) and provides two main classes: * `SetFitModel`: a wrapper that combines a pretrained body from `sentence_transformers` and a classification head from either [`scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) or [`SetFitHead`](https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py) (a differentiable head built upon `PyTorch` with similar APIs to `sentence_transformers`). -* `SetFitTrainer`: a helper class that wraps the fine-tuning process of SetFit. +* `Trainer`: a helper class that wraps the fine-tuning process of SetFit. 
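Which head a `SetFitModel` carries is decided when the model is loaded. As a minimal sketch (using the `use_differentiable_head` flag and `head_params` that appear later in this README; the two-class setting is assumed):

```python
from setfit import SetFitModel

# Default: a scikit-learn LogisticRegression head
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Alternative: the differentiable SetFitHead
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},
)
```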
Here is an end-to-end example using a classification head from `scikit-learn`: ```python from datasets import load_dataset -from sentence_transformers.losses import CosineSimilarityLoss - -from setfit import SetFitModel, SetFitTrainer, sample_dataset +from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset # Load a dataset from the Hugging Face Hub @@ -61,17 +59,19 @@ eval_dataset = dataset["validation"] # Load a SetFit model from Hub model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2") -# Create trainer -trainer = SetFitTrainer( +args = TrainingArguments( + batch_size=16, + num_iterations=20, # The number of text pairs to generate for contrastive learning + num_epochs=1 # The number of epochs to use for contrastive learning +) + +trainer = Trainer( model=model, + args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, metric="accuracy", - batch_size=16, - num_iterations=20, # The number of text pairs to generate for contrastive learning - num_epochs=1, # The number of epochs to use for contrastive learning - column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer + column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer ) # Train and evaluate @@ -81,7 +81,7 @@ metrics = trainer.evaluate() # Push model to the Hub trainer.push_to_hub("my-awesome-setfit-model") -# Download from Hub and run inference +# Download from Hub model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model") # Run inference preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]) @@ -92,9 +92,7 @@ Here is an end-to-end example using `SetFitHead`: ```python from datasets import load_dataset -from sentence_transformers.losses import CosineSimilarityLoss - -from setfit import SetFitModel, SetFitTrainer, sample_dataset +from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset # Load a dataset from the Hugging Face Hub @@ -103,6 +101,7 @@ dataset = load_dataset("sst2") # Simulate the few-shot regime by sampling 8 examples per class train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8) eval_dataset = dataset["validation"] +num_classes = 2 # Load a SetFit model from Hub model = SetFitModel.from_pretrained( @@ -111,36 +110,26 @@ model = SetFitModel.from_pretrained( head_params={"out_features": num_classes}, ) -# Create trainer -trainer = SetFitTrainer( +args = TrainingArguments( + body_learning_rate=2e-5, + head_learning_rate=1e-2, + batch_size=16, + num_iterations=20, # The number of text pairs to generate for contrastive learning + num_epochs=(1, 25), # For finetuning the embeddings and training the classifier, respectively + l2_weight=0.0, + end_to_end=False, # Don't train the classifier end-to-end, i.e. 
only train the head +) + +trainer = Trainer( model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, metric="accuracy", - batch_size=16, - num_iterations=20, # The number of text pairs to generate for contrastive learning - num_epochs=1, # The number of epochs to use for contrastive learning - column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer + column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer ) # Train and evaluate -trainer.freeze() # Freeze the head -trainer.train() # Train only the body - -# Unfreeze the head and freeze the body -> head-only training -trainer.unfreeze(keep_body_frozen=True) -# or -# Unfreeze the head and unfreeze the body -> end-to-end training -trainer.unfreeze(keep_body_frozen=False) - -trainer.train( - num_epochs=25, # The number of epochs to train the head or the whole model (body and head) - batch_size=16, - body_learning_rate=1e-5, # The body's learning rate - learning_rate=1e-2, # The head's learning rate - l2_weight=0.0, # Weight decay on **both** the body and head. If `None`, will use 0.01. -) +trainer.train() metrics = trainer.evaluate() # Push model to the Hub @@ -175,7 +164,7 @@ This will initialise a multilabel classification head from `sklearn` - the follo * `multi-output`: uses a `MultiOutputClassifier` head. * `classifier-chain`: uses a `ClassifierChain` head. -From here, you can instantiate a `SetFitTrainer` using the same example above, and train it as usual. +From here, you can instantiate a `Trainer` using the same example above, and train it as usual. #### Example using the differentiable `SetFitHead`: @@ -196,7 +185,6 @@ model = SetFitModel.from_pretrained( SetFit can also be applied to scenarios where no labels are available. To do so, create a synthetic dataset of training examples: ```python -from datasets import Dataset from setfit import get_templated_dataset candidate_labels = ["negative", "positive"] @@ -206,22 +194,22 @@ train_dataset = get_templated_dataset(candidate_labels=candidate_labels, sample_ This will create examples of the form `"This sentence is {}"`, where the `{}` is filled in with one of the candidate labels. From here you can train a SetFit model as usual: ```python -from setfit import SetFitModel, SetFitTrainer +from setfit import SetFitModel, Trainer model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2") -trainer = SetFitTrainer( +trainer = Trainer( model=model, train_dataset=train_dataset ) trainer.train() ``` -We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with Bart), while being 5x faster to generate predictions with. +We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with BART), while being 5x faster to generate predictions with. ### Running hyperparameter search -`SetFitTrainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend: +`Trainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. 
To use this feature, first install the `optuna` backend: ```bash python -m pip install setfit[optuna] @@ -267,23 +255,23 @@ def hp_space(trial): # Training parameters **Note:** In practice, we found `num_iterations` to be the most important hyperparameter for the contrastive learning process. -The next step is to instantiate a `SetFitTrainer` and call `hyperparameter_search()`: +The next step is to instantiate a `Trainer` and call `hyperparameter_search()`: ```python from datasets import Dataset -from setfit import SetFitTrainer +from setfit import Trainer dataset = Dataset.from_dict( - {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} - ) + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} +) -trainer = SetFitTrainer( +trainer = Trainer( train_dataset=dataset, eval_dataset=dataset, model_init=model_init, column_mapping={"text_new": "text", "label_new": "label"}, ) -best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=20) +best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=5) ``` Finally, you can apply the hyperparameters you found to the trainer, and lock in the optimal model, before training for @@ -300,9 +288,8 @@ If you have access to unlabeled data, you can use knowledge distillation to comp ```python from datasets import load_dataset -from sentence_transformers.losses import CosineSimilarityLoss - -from setfit import SetFitModel, SetFitTrainer, DistillationSetFitTrainer, sample_dataset +from setfit import SetFitModel, Trainer, DistillationTrainer, sample_dataset +from setfit.training_args import TrainingArguments # Load a dataset from the Hugging Face Hub dataset = load_dataset("ag_news") @@ -320,34 +307,37 @@ teacher_model = SetFitModel.from_pretrained( ) # Create trainer for teacher model -teacher_trainer = SetFitTrainer( +teacher_trainer = Trainer( model=teacher_model, train_dataset=train_dataset_teacher, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, ) # Train teacher model teacher_trainer.train() +teacher_metrics = teacher_trainer.evaluate() # Load small student model student_model = SetFitModel.from_pretrained("paraphrase-MiniLM-L3-v2") +args = TrainingArguments( + batch_size=16, + num_iterations=20, + num_epochs=1 +) + # Create trainer for knowledge distillation -student_trainer = DistillationSetFitTrainer( +student_trainer = DistillationTrainer( teacher_model=teacher_model, - train_dataset=train_dataset_student, student_model=student_model, + args=args, + train_dataset=train_dataset_student, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, - metric="accuracy", - batch_size=16, - num_iterations=20, - num_epochs=1, ) # Train student with knowledge distillation student_trainer.train() +student_metrics = student_trainer.evaluate() ``` @@ -403,7 +393,8 @@ make style && make quality ## Citation -```@misc{https://doi.org/10.48550/arxiv.2209.11055, +``` +@misc{https://doi.org/10.48550/arxiv.2209.11055, doi = {10.48550/ARXIV.2209.11055}, url = {https://arxiv.org/abs/2209.11055}, author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}, @@ -411,5 +402,6 @@ make style && make quality title = {Efficient Few-Shot Learning Without Prompts}, publisher = {arXiv}, year = {2022}, - copyright = {Creative Commons Attribution 4.0 International}} + copyright = {Creative Commons Attribution 4.0 International} +} ``` diff --git 
a/docs/source/en/api/main.mdx b/docs/source/en/api/main.mdx index ac2b77e4..a65b3db4 100644 --- a/docs/source/en/api/main.mdx +++ b/docs/source/en/api/main.mdx @@ -6,3 +6,7 @@ # SetFitHead [[autodoc]] SetFitHead + +# AbsaModel + +[[autodoc]] AbsaModel \ No newline at end of file diff --git a/docs/source/en/api/trainer.mdx b/docs/source/en/api/trainer.mdx index a51df833..3e3d39d1 100644 --- a/docs/source/en/api/trainer.mdx +++ b/docs/source/en/api/trainer.mdx @@ -1,8 +1,12 @@ -# SetFitTrainer +# Trainer -[[autodoc]] SetFitTrainer +[[autodoc]] Trainer -# DistillationSetFitTrainer +# DistillationTrainer -[[autodoc]] DistillationSetFitTrainer \ No newline at end of file +[[autodoc]] DistillationTrainer + +# AbsaTrainer + +[[autodoc]] AbsaTrainer \ No newline at end of file diff --git a/docs/source/en/quickstart.mdx b/docs/source/en/quickstart.mdx index cc10ba5b..9e46933b 100644 --- a/docs/source/en/quickstart.mdx +++ b/docs/source/en/quickstart.mdx @@ -11,16 +11,14 @@ The examples below provide a quick overview on the various features supported in `setfit` is integrated with the [Hugging Face Hub](https://huggingface.co/) and provides two main classes: * `SetFitModel`: a wrapper that combines a pretrained body from `sentence_transformers` and a classification head from either [`scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) or [`SetFitHead`](https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py) (a differentiable head built upon `PyTorch` with similar APIs to `sentence_transformers`). -* `SetFitTrainer`: a helper class that wraps the fine-tuning process of SetFit. +* `Trainer`: a helper class that wraps the fine-tuning process of SetFit. Here is an end-to-end example using a classification head from `scikit-learn`: ```python from datasets import load_dataset -from sentence_transformers.losses import CosineSimilarityLoss - -from setfit import SetFitModel, SetFitTrainer, sample_dataset +from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset # Load a dataset from the Hugging Face Hub @@ -33,17 +31,19 @@ eval_dataset = dataset["validation"] # Load a SetFit model from Hub model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2") -# Create trainer -trainer = SetFitTrainer( +args = TrainingArguments( + batch_size=16, + num_iterations=20, # The number of text pairs to generate for contrastive learning + num_epochs=1 # The number of epochs to use for contrastive learning +) + +trainer = Trainer( model=model, + args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, metric="accuracy", - batch_size=16, - num_iterations=20, # The number of text pairs to generate for contrastive learning - num_epochs=1, # The number of epochs to use for contrastive learning - column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer + column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer ) # Train and evaluate @@ -53,7 +53,7 @@ metrics = trainer.evaluate() # Push model to the Hub trainer.push_to_hub("my-awesome-setfit-model") -# Download from Hub and run inference +# Download from Hub model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model") # Run inference preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]) @@ -64,9 +64,7 @@ Here is an end-to-end example using `SetFitHead`: ```python from datasets 
import load_dataset -from sentence_transformers.losses import CosineSimilarityLoss - -from setfit import SetFitModel, SetFitTrainer, sample_dataset +from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset # Load a dataset from the Hugging Face Hub @@ -75,6 +73,7 @@ dataset = load_dataset("sst2") # Simulate the few-shot regime by sampling 8 examples per class train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8) eval_dataset = dataset["validation"] +num_classes = 2 # Load a SetFit model from Hub model = SetFitModel.from_pretrained( @@ -83,36 +82,26 @@ model = SetFitModel.from_pretrained( head_params={"out_features": num_classes}, ) -# Create trainer -trainer = SetFitTrainer( +args = TrainingArguments( + body_learning_rate=2e-5, + head_learning_rate=1e-2, + batch_size=16, + num_iterations=20, # The number of text pairs to generate for contrastive learning + num_epochs=(1, 25), # For finetuning the embeddings and training the classifier, respectively + l2_weight=0.0, + end_to_end=False, # Don't train the classifier end-to-end, i.e. only train the head +) + +trainer = Trainer( model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, metric="accuracy", - batch_size=16, - num_iterations=20, # The number of text pairs to generate for contrastive learning - num_epochs=1, # The number of epochs to use for contrastive learning - column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer + column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer ) # Train and evaluate -trainer.freeze() # Freeze the head -trainer.train() # Train only the body - -# Unfreeze the head and freeze the body -> head-only training -trainer.unfreeze(keep_body_frozen=True) -# or -# Unfreeze the head and unfreeze the body -> end-to-end training -trainer.unfreeze(keep_body_frozen=False) - -trainer.train( - num_epochs=25, # The number of epochs to train the head or the whole model (body and head) - batch_size=16, - body_learning_rate=1e-5, # The body's learning rate - learning_rate=1e-2, # The head's learning rate - l2_weight=0.0, # Weight decay on **both** the body and head. If `None`, will use 0.01. -) +trainer.train() metrics = trainer.evaluate() # Push model to the Hub @@ -147,7 +136,7 @@ This will initialise a multilabel classification head from `sklearn` - the follo * `multi-output`: uses a `MultiOutputClassifier` head. * `classifier-chain`: uses a `ClassifierChain` head. -From here, you can instantiate a `SetFitTrainer` using the same example above, and train it as usual. +From here, you can instantiate a `Trainer` using the same example above, and train it as usual. #### Example using the differentiable `SetFitHead`: @@ -168,7 +157,6 @@ model = SetFitModel.from_pretrained( SetFit can also be applied to scenarios where no labels are available. To do so, create a synthetic dataset of training examples: ```python -from datasets import Dataset from setfit import get_templated_dataset candidate_labels = ["negative", "positive"] @@ -178,22 +166,22 @@ train_dataset = get_templated_dataset(candidate_labels=candidate_labels, sample_ This will create examples of the form `"This sentence is {}"`, where the `{}` is filled in with one of the candidate labels. 
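Concretely, the resulting synthetic dataset looks roughly as follows (an illustrative sketch; the exact label encoding is an assumption):

```python
# Two candidate labels * sample_size=8 -> 16 templated examples
print(len(train_dataset))  # 16
print(train_dataset[0])    # e.g. {'text': 'This sentence is negative', 'label': 0}
```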
From here you can train a SetFit model as usual: ```python -from setfit import SetFitModel, SetFitTrainer +from setfit import SetFitModel, Trainer model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2") -trainer = SetFitTrainer( +trainer = Trainer( model=model, train_dataset=train_dataset ) trainer.train() ``` -We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with Bart), while being 5x faster to generate predictions with. +We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with BART), while being 5x faster to generate predictions with. ### Running hyperparameter search -`SetFitTrainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend: +`Trainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend: ```bash python -m pip install setfit[optuna] @@ -239,23 +227,23 @@ def hp_space(trial): # Training parameters **Note:** In practice, we found `num_iterations` to be the most important hyperparameter for the contrastive learning process. -The next step is to instantiate a `SetFitTrainer` and call `hyperparameter_search()`: +The next step is to instantiate a `Trainer` and call `hyperparameter_search()`: ```python from datasets import Dataset -from setfit import SetFitTrainer +from setfit import Trainer dataset = Dataset.from_dict( - {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} - ) + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} +) -trainer = SetFitTrainer( +trainer = Trainer( train_dataset=dataset, eval_dataset=dataset, model_init=model_init, column_mapping={"text_new": "text", "label_new": "label"}, ) -best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=20) +best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=5) ``` Finally, you can apply the hyperparameters you found to the trainer, and lock in the optimal model, before training for @@ -272,9 +260,8 @@ If you have access to unlabeled data, you can use knowledge distillation to comp ```python from datasets import load_dataset -from sentence_transformers.losses import CosineSimilarityLoss - -from setfit import SetFitModel, SetFitTrainer, DistillationSetFitTrainer, sample_dataset +from setfit import SetFitModel, Trainer, DistillationTrainer, sample_dataset +from setfit.training_args import TrainingArguments # Load a dataset from the Hugging Face Hub dataset = load_dataset("ag_news") @@ -292,32 +279,35 @@ teacher_model = SetFitModel.from_pretrained( ) # Create trainer for teacher model -teacher_trainer = SetFitTrainer( +teacher_trainer = Trainer( model=teacher_model, train_dataset=train_dataset_teacher, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, ) # Train teacher model teacher_trainer.train() +teacher_metrics = teacher_trainer.evaluate() # Load small student model student_model = SetFitModel.from_pretrained("paraphrase-MiniLM-L3-v2") +args = TrainingArguments( + batch_size=16, + 
num_iterations=20, + num_epochs=1 +) + # Create trainer for knowledge distillation -student_trainer = DistillationSetFitTrainer( +student_trainer = DistillationTrainer( teacher_model=teacher_model, - train_dataset=train_dataset_student, student_model=student_model, + args=args, + train_dataset=train_dataset_student, eval_dataset=eval_dataset, - loss_class=CosineSimilarityLoss, - metric="accuracy", - batch_size=16, - num_iterations=20, - num_epochs=1, ) # Train student with knowledge distillation student_trainer.train() +student_metrics = student_trainer.evaluate() ``` \ No newline at end of file diff --git a/scripts/setfit/run_fewshot.py b/scripts/setfit/run_fewshot.py index 1248fddc..08f7023e 100644 --- a/scripts/setfit/run_fewshot.py +++ b/scripts/setfit/run_fewshot.py @@ -59,6 +59,7 @@ def parse_args(): parser.add_argument("--override_results", default=False, action="store_true") parser.add_argument("--keep_body_frozen", default=False, action="store_true") parser.add_argument("--add_data_augmentation", default=False) + parser.add_argument("--evaluation_strategy", default=False) args = parser.parse_args() @@ -148,6 +149,8 @@ def main(): num_epochs=args.num_epochs, num_iterations=args.num_iterations, ) + if not args.evaluation_strategy: + trainer.args.evaluation_strategy = "no" if args.classifier == "pytorch": trainer.freeze() trainer.train() diff --git a/setup.py b/setup.py index dcd5a8ea..bdc32252 100644 --- a/setup.py +++ b/setup.py @@ -10,11 +10,18 @@ MAINTAINER_EMAIL = "lewis@huggingface.co" INTEGRATIONS_REQUIRE = ["optuna"] -REQUIRED_PKGS = ["datasets>=2.3.0", "sentence-transformers>=2.2.1", "evaluate>=0.3.0"] +REQUIRED_PKGS = [ + "datasets>=2.3.0", + "sentence-transformers>=2.2.1", + "evaluate>=0.3.0", + "huggingface_hub>=0.13.0", + "scikit-learn", +] +ABSA_REQUIRE = ["spacy"] QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"] ONNX_REQUIRE = ["onnxruntime", "onnx", "skl2onnx"] OPENVINO_REQUIRE = ["hummingbird-ml<0.4.9", "openvino==2022.3.0"] -TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE +TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE + ABSA_REQUIRE DOCS_REQUIRE = ["hf-doc-builder>=0.3.0"] EXTRAS_REQUIRE = { "optuna": INTEGRATIONS_REQUIRE, @@ -23,6 +30,7 @@ "onnx": ONNX_REQUIRE, "openvino": ONNX_REQUIRE + OPENVINO_REQUIRE, "docs": DOCS_REQUIRE, + "absa": ABSA_REQUIRE, } @@ -46,11 +54,12 @@ def combine_requirements(base_keys): long_description_content_type="text/markdown", maintainer=MAINTAINER, maintainer_email=MAINTAINER_EMAIL, - url="https://github.com/SetFit/setfit", - download_url="https://github.com/SetFit/setfit/tags", + url="https://github.com/huggingface/setfit", + download_url="https://github.com/huggingface/setfit/tags", license="Apache 2.0", package_dir={"": "src"}, packages=find_packages("src"), + include_package_data=True, install_requires=REQUIRED_PKGS, extras_require=EXTRAS_REQUIRE, classifiers=[ diff --git a/src/setfit/__init__.py b/src/setfit/__init__.py index 287d89c5..f131eee0 100644 --- a/src/setfit/__init__.py +++ b/src/setfit/__init__.py @@ -1,6 +1,15 @@ __version__ = "0.8.0.dev0" -from .data import add_templated_examples, get_templated_dataset, sample_dataset +import warnings + +from .data import get_templated_dataset, sample_dataset from .modeling import SetFitHead, SetFitModel -from .trainer import SetFitTrainer -from .trainer_distillation import DistillationSetFitTrainer +from .span import AbsaModel, AbsaTrainer, AspectExtractor, AspectModel, PolarityModel +from .trainer import SetFitTrainer, 
Trainer +from .trainer_distillation import DistillationSetFitTrainer, DistillationTrainer +from .training_args import TrainingArguments + + +# Ensure that DeprecationWarnings are shown by default, as recommended by +# https://docs.python.org/3/library/warnings.html#overriding-the-default-filter +warnings.filterwarnings("default", category=DeprecationWarning) diff --git a/src/setfit/data.py b/src/setfit/data.py index ff5a0c33..2d9cd5f8 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -1,4 +1,3 @@ -import warnings from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import pandas as pd @@ -21,15 +20,6 @@ SAMPLE_SIZES = [2, 4, 8, 16, 32, 64] -def get_augmented_samples(*args, **kwargs) -> None: - warnings.warn( - "`get_augmented_samples` has been deprecated and will be removed in v1.0.0 of SetFit. " - "Please use `get_templated_dataset` instead.", - DeprecationWarning, - stacklevel=2, - ) - - def get_templated_dataset( dataset: Optional[Dataset] = None, candidate_labels: Optional[List[str]] = None, @@ -115,15 +105,6 @@ def get_templated_dataset( return dataset -def add_templated_examples(*args, **kwargs) -> None: - warnings.warn( - "`add_templated_examples` has been deprecated and will be removed in v1.0.0 of SetFit. " - "Please use `get_templated_dataset` instead.", - DeprecationWarning, - stacklevel=2, - ) - - def get_candidate_labels(dataset_name: str, label_names_column: str = "label_text") -> List[str]: dataset = load_dataset(dataset_name, split="train") @@ -170,7 +151,7 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i df = df.groupby(label_column) # sample num_samples, or at least as much as possible - df = df.apply(lambda x: x.sample(min(num_samples, len(x)))) + df = df.apply(lambda x: x.sample(min(num_samples, len(x)), random_state=seed)) df = df.reset_index(drop=True) all_samples = Dataset.from_pandas(df, features=dataset.features) diff --git a/src/setfit/integrations.py b/src/setfit/integrations.py index 94d7161e..a847d753 100644 --- a/src/setfit/integrations.py +++ b/src/setfit/integrations.py @@ -5,10 +5,10 @@ if TYPE_CHECKING: - from .trainer import SetFitTrainer + from .trainer import Trainer -def is_optuna_available(): +def is_optuna_available() -> bool: return importlib.util.find_spec("optuna") is not None @@ -17,7 +17,7 @@ def default_hp_search_backend(): return "optuna" -def run_hp_search_optuna(trainer: "SetFitTrainer", n_trials: int, direction: str, **kwargs) -> BestRun: +def run_hp_search_optuna(trainer: "Trainer", n_trials: int, direction: str, **kwargs) -> BestRun: import optuna # Heavily inspired by transformers.integrations.run_hp_search_optuna diff --git a/src/setfit/losses.py b/src/setfit/losses.py new file mode 100644 index 00000000..369c8451 --- /dev/null +++ b/src/setfit/losses.py @@ -0,0 +1,100 @@ +import torch +from torch import nn + + +class SupConLoss(nn.Module): + """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. + + It also supports the unsupervised contrastive loss in SimCLR. + """ + + def __init__(self, model, temperature=0.07, contrast_mode="all", base_temperature=0.07): + super(SupConLoss, self).__init__() + self.model = model + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + def forward(self, sentence_features, labels=None, mask=None): + """Computes loss for model. 
+
+        If both `labels` and `mask` are None, it degenerates to SimCLR unsupervised loss:
+        https://arxiv.org/pdf/2002.05709.pdf
+
+        Args:
+            sentence_features: tokenized sentence features; the first element is embedded
+                by `self.model` into hidden vectors of shape [bsz, n_views, ...].
+            labels: ground truth of shape [bsz].
+            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
+                has the same class as sample i. Can be asymmetric.
+
+        Returns:
+            A loss scalar.
+        """
+        features = self.model(sentence_features[0])["sentence_embedding"]
+
+        # Normalize embeddings
+        features = torch.nn.functional.normalize(features, p=2, dim=1)
+
+        # Add n_views dimension
+        features = torch.unsqueeze(features, 1)
+
+        device = features.device
+
+        if len(features.shape) < 3:
+            raise ValueError("`features` needs to be [bsz, n_views, ...], at least 3 dimensions are required")
+        if len(features.shape) > 3:
+            features = features.view(features.shape[0], features.shape[1], -1)
+
+        batch_size = features.shape[0]
+        if labels is not None and mask is not None:
+            raise ValueError("Cannot define both `labels` and `mask`")
+        elif labels is None and mask is None:
+            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
+        elif labels is not None:
+            labels = labels.contiguous().view(-1, 1)
+            if labels.shape[0] != batch_size:
+                raise ValueError("Num of labels does not match num of features")
+            mask = torch.eq(labels, labels.T).float().to(device)
+        else:
+            mask = mask.float().to(device)
+
+        contrast_count = features.shape[1]
+        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
+        if self.contrast_mode == "one":
+            anchor_feature = features[:, 0]
+            anchor_count = 1
+        elif self.contrast_mode == "all":
+            anchor_feature = contrast_feature
+            anchor_count = contrast_count
+        else:
+            raise ValueError("Unknown mode: {}".format(self.contrast_mode))
+
+        # Compute logits
+        anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
+        # For numerical stability
+        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+        logits = anchor_dot_contrast - logits_max.detach()
+
+        # Tile mask
+        mask = mask.repeat(anchor_count, contrast_count)
+        # Mask-out self-contrast cases
+        logits_mask = torch.scatter(
+            torch.ones_like(mask),
+            1,
+            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
+            0,
+        )
+        mask = mask * logits_mask
+
+        # Compute log_prob
+        exp_logits = torch.exp(logits) * logits_mask
+        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+        # Compute mean of log-likelihood over positive
+        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
+
+        # Loss
+        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
+        loss = loss.view(anchor_count, batch_size).mean()
+
+        return loss
diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py
index f971d952..793b2c72 100644
--- a/src/setfit/modeling.py
+++ b/src/setfit/modeling.py
@@ -1,8 +1,9 @@
 import os
 import tempfile
+import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union


 # Google Colab runs on Python 3.7, so we need this to be compatible
@@ -15,23 +16,20 @@
 import numpy as np
 import requests
 import torch
-import torch.nn as nn
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
-from sentence_transformers import InputExample, SentenceTransformer, models
+from huggingface_hub.utils import validate_hf_hub_args
+from sentence_transformers import SentenceTransformer, models
 from sklearn.linear_model import
LogisticRegression from sklearn.multiclass import OneVsRestClassifier from sklearn.multioutput import ClassifierChain, MultiOutputClassifier +from torch import nn from torch.utils.data import DataLoader -from tqdm.auto import trange +from tqdm.auto import tqdm, trange from . import logging from .data import SetFitDataset -if TYPE_CHECKING: - from numpy import ndarray - - logging.set_verbosity_info() logger = logging.get_logger(__name__) @@ -77,28 +75,19 @@ ```bibtex @article{{https://doi.org/10.48550/arxiv.2209.11055, -doi = {{10.48550/ARXIV.2209.11055}}, -url = {{https://arxiv.org/abs/2209.11055}}, -author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}}, -keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}}, -title = {{Efficient Few-Shot Learning Without Prompts}}, -publisher = {{arXiv}}, -year = {{2022}}, -copyright = {{Creative Commons Attribution 4.0 International}} + doi = {{10.48550/ARXIV.2209.11055}}, + url = {{https://arxiv.org/abs/2209.11055}}, + author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}}, + keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}}, + title = {{Efficient Few-Shot Learning Without Prompts}}, + publisher = {{arXiv}}, + year = {{2022}}, + copyright = {{Creative Commons Attribution 4.0 International}} }} ``` """ -class SetFitBaseModel: - def __init__(self, model, max_seq_length: int, add_normalization_layer: bool) -> None: - self.model = SentenceTransformer(model) - self.model.max_seq_length = max_seq_length - - if add_normalization_layer: - self.model._modules["2"] = models.Normalize() - - class SetFitHead(models.Dense): """ A SetFit head that supports multi-class classification for end-to-end training. 
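+
+    Example (an illustrative sketch; the head is normally constructed for you via
+    `SetFitModel.from_pretrained(..., use_differentiable_head=True)`, and the
+    `in_features=768` below is an assumed embedding size):
+
+        >>> head = SetFitHead(in_features=768, out_features=2)
+        >>> head.predict(torch.rand(4, 768)).shape
+        torch.Size([4])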
@@ -217,7 +206,7 @@ def predict(self, x_test: torch.Tensor) -> torch.Tensor:
             return torch.where(probs >= 0.5, 1, 0)
         return torch.argmax(probs, dim=-1)

-    def get_loss_fn(self):
+    def get_loss_fn(self) -> nn.Module:
         if self.multitarget:  # if sigmoid output
             return torch.nn.BCEWithLogitsLoss()
         return torch.nn.CrossEntropyLoss()
@@ -241,13 +230,13 @@ def get_config_dict(self) -> Dict[str, Optional[Union[int, float, bool]]]:
         }

     @staticmethod
-    def _init_weight(module):
+    def _init_weight(module) -> None:
         if isinstance(module, nn.Linear):
-            torch.nn.init.xavier_uniform_(module.weight)
+            nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
-                torch.nn.init.constant_(module.bias, 1e-2)
+                nn.init.constant_(module.bias, 1e-2)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "SetFitHead({})".format(self.get_config_dict())


@@ -255,22 +244,10 @@ def __repr__(self):
 class SetFitModel(PyTorchModelHubMixin):
     """A SetFit model with integration to the Hugging Face Hub."""

-    def __init__(
-        self,
-        model_body: Optional[SentenceTransformer] = None,
-        model_head: Optional[Union[SetFitHead, LogisticRegression]] = None,
-        multi_target_strategy: Optional[str] = None,
-        l2_weight: float = 1e-2,
-        normalize_embeddings: bool = False,
-    ) -> None:
-        super(SetFitModel, self).__init__()
-        self.model_body = model_body
-        self.model_head = model_head
-
-        self.multi_target_strategy = multi_target_strategy
-        self.l2_weight = l2_weight
-
-        self.normalize_embeddings = normalize_embeddings
+    model_body: Optional[SentenceTransformer] = None
+    model_head: Optional[Union[SetFitHead, LogisticRegression]] = None
+    multi_target_strategy: Optional[str] = None
+    normalize_embeddings: bool = False

     @property
     def has_differentiable_head(self) -> bool:
@@ -283,23 +260,46 @@ def fit(
         y_train: Union[List[int], List[List[int]]],
         num_epochs: int,
         batch_size: Optional[int] = None,
-        learning_rate: Optional[float] = None,
         body_learning_rate: Optional[float] = None,
+        head_learning_rate: Optional[float] = None,
+        end_to_end: bool = False,
         l2_weight: Optional[float] = None,
         max_length: Optional[int] = None,
-        show_progress_bar: Optional[bool] = None,
+        show_progress_bar: bool = True,
     ) -> None:
+        """Train the classifier head; this is only used if a differentiable PyTorch head is used.
+
+        Args:
+            x_train (`List[str]`): A list of training sentences.
+            y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
+            num_epochs (`int`): The number of epochs to train for.
+            batch_size (`int`, *optional*): The batch size to use.
+            body_learning_rate (`float`, *optional*): The learning rate for the `SentenceTransformer` body
+                in the `AdamW` optimizer. Disregarded if `end_to_end=False`.
+            head_learning_rate (`float`, *optional*): The learning rate for the differentiable torch head
+                in the `AdamW` optimizer.
+            end_to_end (`bool`, defaults to `False`): If True, train the entire model end-to-end.
+                Otherwise, freeze the `SentenceTransformer` body and only train the head.
+            l2_weight (`float`, *optional*): The l2 weight for both the model body and head
+                in the `AdamW` optimizer.
+            max_length (`int`, *optional*): The maximum token length a tokenizer can generate. If not provided,
+                the maximum length for the `SentenceTransformer` body is used.
+            show_progress_bar (`bool`, defaults to `True`): Whether to display a progress bar for the training
+                epochs and iterations.
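+
+        Example (an illustrative sketch, assuming a model with a differentiable head):
+
+            >>> model.fit(["very good", "quite bad"], [1, 0], num_epochs=25, batch_size=2)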
+ """ if self.has_differentiable_head: # train with pyTorch device = self.model_body.device self.model_body.train() self.model_head.train() + if not end_to_end: + self.freeze("body") dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length) criterion = self.model_head.get_loss_fn() - optimizer = self._prepare_optimizer(learning_rate, body_learning_rate, l2_weight) + optimizer = self._prepare_optimizer(head_learning_rate, body_learning_rate, l2_weight) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) for epoch_idx in trange(num_epochs, desc="Epoch", disable=not show_progress_bar): - for batch in dataloader: + for batch in tqdm(dataloader, desc="Iteration", disable=not show_progress_bar, leave=False): features, labels = batch optimizer.zero_grad() @@ -309,15 +309,18 @@ def fit( outputs = self.model_body(features) if self.normalize_embeddings: - outputs = torch.nn.functional.normalize(outputs, p=2, dim=1) + outputs = nn.functional.normalize(outputs, p=2, dim=1) outputs = self.model_head(outputs) logits = outputs["logits"] - loss = criterion(logits, labels) + loss: torch.Tensor = criterion(logits, labels) loss.backward() optimizer.step() scheduler.step() + + if not end_to_end: + self.unfreeze("body") else: # train with sklearn embeddings = self.model_body.encode(x_train, normalize_embeddings=self.normalize_embeddings) self.model_head.fit(embeddings, y_train) @@ -364,12 +367,12 @@ def _prepare_dataloader( def _prepare_optimizer( self, - learning_rate: float, + head_learning_rate: float, body_learning_rate: Optional[float], l2_weight: float, ) -> torch.optim.Optimizer: - body_learning_rate = body_learning_rate or learning_rate - l2_weight = l2_weight or self.l2_weight + body_learning_rate = body_learning_rate or head_learning_rate + l2_weight = l2_weight or 1e-2 optimizer = torch.optim.AdamW( [ { @@ -377,37 +380,79 @@ def _prepare_optimizer( "lr": body_learning_rate, "weight_decay": l2_weight, }, - { - "params": self.model_head.parameters(), - "lr": learning_rate, - "weight_decay": l2_weight, - }, + {"params": self.model_head.parameters(), "lr": head_learning_rate, "weight_decay": l2_weight}, ], ) return optimizer def freeze(self, component: Optional[Literal["body", "head"]] = None) -> None: + """Freeze the model body and/or the head, preventing further training on that component until unfrozen. + + Args: + component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to freeze that component. + If no component is provided, freeze both. Defaults to None. + """ if component is None or component == "body": self._freeze_or_not(self.model_body, to_freeze=True) - if component is None or component == "head": + if (component is None or component == "head") and self.has_differentiable_head: self._freeze_or_not(self.model_head, to_freeze=True) - def unfreeze(self, component: Optional[Literal["body", "head"]] = None) -> None: + def unfreeze( + self, component: Optional[Literal["body", "head"]] = None, keep_body_frozen: Optional[bool] = None + ) -> None: + """Unfreeze the model body and/or the head, allowing further training on that component. + + Args: + component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to unfreeze that component. + If no component is provided, unfreeze both. Defaults to None. + keep_body_frozen (`bool`, *optional*): Deprecated argument, use `component` instead. + """ + if keep_body_frozen is not None: + warnings.warn( + "`keep_body_frozen` is deprecated and will be removed in v2.0.0 of SetFit. 
" + 'Please either pass "head", "body" or no arguments to unfreeze both.', + DeprecationWarning, + stacklevel=2, + ) + # If the body must stay frozen, only unfreeze the head. Eventually, this entire if-branch + # can be removed. + if keep_body_frozen and not component: + component = "head" + if component is None or component == "body": self._freeze_or_not(self.model_body, to_freeze=False) - if component is None or component == "head": + if (component is None or component == "head") and self.has_differentiable_head: self._freeze_or_not(self.model_head, to_freeze=False) - def _freeze_or_not(self, model: torch.nn.Module, to_freeze: bool) -> None: + def _freeze_or_not(self, model: nn.Module, to_freeze: bool) -> None: + """Set `requires_grad=not to_freeze` for all parameters in `model`""" for param in model.parameters(): param.requires_grad = not to_freeze + def encode(self, inputs: List[str], show_progress_bar: Optional[bool] = None) -> Union[torch.Tensor, np.ndarray]: + """Convert input sentences to embeddings using the `SentenceTransformer` body. + + Args: + inputs (`List[str]`): The input sentences to embed. + show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding. + + Returns: + Union[torch.Tensor, np.ndarray]: A matrix with shape [INPUT_LENGTH, EMBEDDING_SIZE], as a + torch Tensor if this model has a differentiable Torch head, or otherwise as a numpy array. + """ + return self.model_body.encode( + inputs, + normalize_embeddings=self.normalize_embeddings, + convert_to_tensor=self.has_differentiable_head, + show_progress_bar=show_progress_bar, + ) + def _output_type_conversion( - self, outputs: Union[torch.Tensor, "ndarray"], as_numpy: bool = False - ) -> Union[torch.Tensor, "ndarray"]: + self, outputs: Union[torch.Tensor, np.ndarray], as_numpy: bool = False + ) -> Union[torch.Tensor, np.ndarray]: """Return `outputs` in the desired type: * Numpy array if no differentiable head is used. * Torch tensor if a differentiable head is used. @@ -427,37 +472,74 @@ def _output_type_conversion( return outputs def predict( - self, x_test: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None - ) -> Union[torch.Tensor, "ndarray"]: - embeddings = self.model_body.encode( - x_test, - normalize_embeddings=self.normalize_embeddings, - convert_to_tensor=self.has_differentiable_head, - show_progress_bar=show_progress_bar, - ) + self, inputs: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None + ) -> Union[torch.Tensor, np.ndarray]: + """Predict the various classes. + + Args: + inputs (`List[str]`): The input sentences to predict classes for. + as_numpy (`bool`, defaults to `False`): Whether to output as numpy array instead. + show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding. + + Example: + >>> model = SetFitModel.from_pretrained(...) + >>> model.predict(["What a boring display", "Exhilarating through and through", "I'm wowed!"]) + tensor([0, 1, 1], dtype=torch.int32) + Returns: + `Union[torch.Tensor, np.ndarray]`: A vector with equal length to the inputs, denoting + to which class each input is predicted to belong. 
+ """ + embeddings = self.encode(inputs, show_progress_bar=show_progress_bar) outputs = self.model_head.predict(embeddings) return self._output_type_conversion(outputs, as_numpy=as_numpy) def predict_proba( - self, x_test: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None - ) -> Union[torch.Tensor, "ndarray"]: - embeddings = self.model_body.encode( - x_test, - normalize_embeddings=self.normalize_embeddings, - convert_to_tensor=self.has_differentiable_head, - show_progress_bar=show_progress_bar, - ) + self, inputs: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None + ) -> Union[torch.Tensor, np.ndarray]: + """Predict the probabilities of the various classes. + Args: + inputs (`List[str]`): The input sentences to predict class probabilities for. + as_numpy (`bool`, defaults to `False`): Whether to output as numpy array instead. + show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding. + + Example: + >>> model = SetFitModel.from_pretrained(...) + >>> model.predict_proba(["What a boring display", "Exhilarating through and through", "I'm wowed!"]) + tensor([[0.9367, 0.0633], + [0.0627, 0.9373], + [0.0890, 0.9110]], dtype=torch.float64) + + Returns: + `Union[torch.Tensor, np.ndarray]`: A matrix with shape [INPUT_LENGTH, NUM_CLASSES] denoting + probabilities of predicting an input as a class. + """ + embeddings = self.encode(inputs, show_progress_bar=show_progress_bar) outputs = self.model_head.predict_proba(embeddings) return self._output_type_conversion(outputs, as_numpy=as_numpy) + @property + def device(self) -> torch.device: + """Get the Torch device that this model is on. + + Returns: + torch.device: The device that the model is on. + """ + return self.model_body.device + def to(self, device: Union[str, torch.device]) -> "SetFitModel": """Move this SetFitModel to `device`, and then return `self`. This method does not copy. Args: device (Union[str, torch.device]): The identifier of the device to move the model to. + Example: + + >>> model = SetFitModel.from_pretrained(...) + >>> model.to("cpu") + >>> model(["cats are cute", "dogs are loyal"]) + Returns: SetFitModel: Returns the original model, but now on the desired device. """ @@ -492,7 +574,21 @@ def create_model_card(self, path: str, model_name: Optional[str] = "SetFit Model with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: f.write(model_card_content) - def __call__(self, inputs): + def __call__(self, inputs: List[str]) -> torch.Tensor: + """Predict the various classes. + + Args: + inputs (`List[str]`): The input sentences to predict classes for. + + Example: + >>> model = SetFitModel.from_pretrained(...) + >>> model(["What a boring display", "Exhilarating through and through", "I'm wowed!"]) + tensor([0, 1, 1], dtype=torch.int32) + + Returns: + `torch.Tensor`: A vector with equal length to the inputs, denoting to which class each + input is predicted to belong. 
+ """ return self.predict(inputs) def _save_pretrained(self, save_directory: Union[Path, str]) -> None: @@ -502,6 +598,7 @@ def _save_pretrained(self, save_directory: Union[Path, str]) -> None: joblib.dump(self.model_head, str(Path(save_directory) / MODEL_HEAD_NAME)) @classmethod + @validate_hf_hub_args def _from_pretrained( cls, model_id: str, @@ -511,13 +608,13 @@ def _from_pretrained( proxies: Optional[Dict] = None, resume_download: Optional[bool] = None, local_files_only: Optional[bool] = None, - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, multi_target_strategy: Optional[str] = None, use_differentiable_head: bool = False, normalize_embeddings: bool = False, **model_kwargs, ) -> "SetFitModel": - model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=use_auth_token) + model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token) target_device = model_body._target_device model_body.to(target_device) # put `model_body` on the target device @@ -541,7 +638,7 @@ def _from_pretrained( force_download=force_download, proxies=proxies, resume_download=resume_download, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, ) except requests.exceptions.RequestException: @@ -554,7 +651,7 @@ def _from_pretrained( if model_head_file is not None: model_head = joblib.load(model_head_file) else: - head_params = model_kwargs.get("head_params", {}) + head_params = model_kwargs.pop("head_params", {}) if use_differentiable_head: if multi_target_strategy is None: use_multitarget = False @@ -590,207 +687,12 @@ def _from_pretrained( else: model_head = clf + # Remove the `transformers` config + model_kwargs.pop("config", None) return cls( model_body=model_body, model_head=model_head, multi_target_strategy=multi_target_strategy, normalize_embeddings=normalize_embeddings, + **model_kwargs, ) - - -class SupConLoss(nn.Module): - """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. - - It also supports the unsupervised contrastive loss in SimCLR. - """ - - def __init__(self, model, temperature=0.07, contrast_mode="all", base_temperature=0.07): - super(SupConLoss, self).__init__() - self.model = model - self.temperature = temperature - self.contrast_mode = contrast_mode - self.base_temperature = base_temperature - - def forward(self, sentence_features, labels=None, mask=None): - """Computes loss for model. - - If both `labels` and `mask` are None, it degenerates to SimCLR unsupervised loss: - https://arxiv.org/pdf/2002.05709.pdf - - Args: - features: hidden vector of shape [bsz, n_views, ...]. - labels: ground truth of shape [bsz]. - mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j - has the same class as sample i. Can be asymmetric. - - Returns: - A loss scalar. 
- """ - features = self.model(sentence_features[0])["sentence_embedding"] - - # Normalize embeddings - features = torch.nn.functional.normalize(features, p=2, dim=1) - - # Add n_views dimension - features = torch.unsqueeze(features, 1) - - device = features.device - - if len(features.shape) < 3: - raise ValueError("`features` needs to be [bsz, n_views, ...]," "at least 3 dimensions are required") - if len(features.shape) > 3: - features = features.view(features.shape[0], features.shape[1], -1) - - batch_size = features.shape[0] - if labels is not None and mask is not None: - raise ValueError("Cannot define both `labels` and `mask`") - elif labels is None and mask is None: - mask = torch.eye(batch_size, dtype=torch.float32).to(device) - elif labels is not None: - labels = labels.contiguous().view(-1, 1) - if labels.shape[0] != batch_size: - raise ValueError("Num of labels does not match num of features") - mask = torch.eq(labels, labels.T).float().to(device) - else: - mask = mask.float().to(device) - - contrast_count = features.shape[1] - contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) - if self.contrast_mode == "one": - anchor_feature = features[:, 0] - anchor_count = 1 - elif self.contrast_mode == "all": - anchor_feature = contrast_feature - anchor_count = contrast_count - else: - raise ValueError("Unknown mode: {}".format(self.contrast_mode)) - - # Compute logits - anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature) - # For numerical stability - logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) - logits = anchor_dot_contrast - logits_max.detach() - - # Tile mask - mask = mask.repeat(anchor_count, contrast_count) - # Mask-out self-contrast cases - logits_mask = torch.scatter( - torch.ones_like(mask), - 1, - torch.arange(batch_size * anchor_count).view(-1, 1).to(device), - 0, - ) - mask = mask * logits_mask - - # Compute log_prob - exp_logits = torch.exp(logits) * logits_mask - log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) - - # Compute mean of log-likelihood over positive - mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) - - # Loss - loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos - loss = loss.view(anchor_count, batch_size).mean() - - return loss - - -def sentence_pairs_generation(sentences, labels, pairs): - # Initialize two empty lists to hold the (sentence, sentence) pairs and - # labels to indicate if a pair is positive or negative - - num_classes = np.unique(labels) - label_to_idx = {x: i for i, x in enumerate(num_classes)} - positive_idxs = [np.where(labels == i)[0] for i in num_classes] - negative_idxs = [np.where(labels != i)[0] for i in num_classes] - - for first_idx in range(len(sentences)): - current_sentence = sentences[first_idx] - label = labels[first_idx] - second_idx = np.random.choice(positive_idxs[label_to_idx[label]]) - positive_sentence = sentences[second_idx] - # Prepare a positive pair and update the sentences and labels - # lists, respectively - pairs.append(InputExample(texts=[current_sentence, positive_sentence], label=1.0)) - - third_idx = np.random.choice(negative_idxs[label_to_idx[label]]) - negative_sentence = sentences[third_idx] - # Prepare a negative pair of sentences and update our lists - pairs.append(InputExample(texts=[current_sentence, negative_sentence], label=0.0)) - # Return a 2-tuple of our sentence pairs and labels - return pairs - - -def sentence_pairs_generation_multilabel(sentences, labels, pairs): - # 
Initialize two empty lists to hold the (sentence, sentence) pairs and - # labels to indicate if a pair is positive or negative - for first_idx in range(len(sentences)): - current_sentence = sentences[first_idx] - sample_labels = np.where(labels[first_idx, :] == 1)[0] - if len(np.where(labels.dot(labels[first_idx, :].T) == 0)[0]) == 0: - continue - else: - for _label in sample_labels: - second_idx = np.random.choice(np.where(labels[:, _label] == 1)[0]) - positive_sentence = sentences[second_idx] - # Prepare a positive pair and update the sentences and labels - # lists, respectively - pairs.append(InputExample(texts=[current_sentence, positive_sentence], label=1.0)) - - # Search for sample that don't have a label in common with current - # sentence - negative_idx = np.where(labels.dot(labels[first_idx, :].T) == 0)[0] - negative_sentence = sentences[np.random.choice(negative_idx)] - # Prepare a negative pair of sentences and update our lists - pairs.append(InputExample(texts=[current_sentence, negative_sentence], label=0.0)) - # Return a 2-tuple of our sentence pairs and labels - return pairs - - -def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix): - # initialize two empty lists to hold the (sentence, sentence) pairs and - # labels to indicate if a pair is positive or negative - - idx = list(range(len(sentences))) - - for first_idx in range(len(sentences)): - current_sentence = sentences[first_idx] - second_idx = int(np.random.choice([x for x in idx if x != first_idx])) - - cos_sim = float(cos_sim_matrix[first_idx][second_idx]) - paired_sentence = sentences[second_idx] - pairs.append(InputExample(texts=[current_sentence, paired_sentence], label=cos_sim)) - - third_idx = np.random.choice([x for x in idx if x != first_idx]) - cos_sim = float(cos_sim_matrix[first_idx][third_idx]) - paired_sentence = sentences[third_idx] - pairs.append(InputExample(texts=[current_sentence, paired_sentence], label=cos_sim)) - - return pairs - - -class SKLearnWrapper: - def __init__(self, st_model=None, clf=None): - self.st_model = st_model - self.clf = clf - - def fit(self, x_train, y_train): - embeddings = self.st_model.encode(x_train) - self.clf.fit(embeddings, y_train) - - def predict(self, x_test): - embeddings = self.st_model.encode(x_test) - return self.clf.predict(embeddings) - - def predict_proba(self, x_test): - embeddings = self.st_model.encode(x_test) - return self.clf.predict_proba(embeddings) - - def save(self, path): - self.st_model.save(path=path) - joblib.dump(self.clf, f"{path}/setfit_head.pkl") - - def load(self, path): - self.st_model = SentenceTransformer(model_name_or_path=path) - self.clf = joblib.load(f"{path}/setfit_head.pkl") diff --git a/src/setfit/pipeline.py b/src/setfit/pipeline.py deleted file mode 100644 index 51e551ff..00000000 --- a/src/setfit/pipeline.py +++ /dev/null @@ -1,12 +0,0 @@ -from .modeling import SKLearnWrapper - - -class SetFitPipeline: - def __init__(self, model_name_or_path) -> None: - base_model = SKLearnWrapper() - base_model.load(model_name_or_path) - self.model = base_model - - def __call__(self, inputs, *args, **kwargs): - model_outputs = self.model.predict(inputs) - return model_outputs diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py new file mode 100644 index 00000000..0eceba1e --- /dev/null +++ b/src/setfit/sampler.py @@ -0,0 +1,156 @@ +from itertools import zip_longest +from typing import Generator, Iterable, List, Optional + +import numpy as np +import torch +from sentence_transformers import InputExample +from 
torch.utils.data import IterableDataset
+
+from . import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
+    """Generates shuffled pair combinations for any iterable data provided.
+
+    Args:
+        iterable: data to generate pair combinations from
+        replacement: if True, also include pairs of a sample with itself,
+            equivalent to itertools.combinations_with_replacement
+
+    Returns:
+        Generator of shuffled pairs as a tuple
+    """
+    n = len(iterable)
+    k = 1 if not replacement else 0
+    idxs = np.stack(np.triu_indices(n, k), axis=-1)
+    for i in np.random.RandomState(seed=42).permutation(len(idxs)):
+        _idx, idx = idxs[i, :]
+        yield iterable[_idx], iterable[idx]
+
+
+class ContrastiveDataset(IterableDataset):
+    def __init__(
+        self,
+        examples: List[InputExample],
+        multilabel: bool,
+        num_iterations: Optional[int] = None,
+        sampling_strategy: str = "oversampling",
+    ) -> None:
+        """Generates positive and negative text pairs for contrastive learning.
+
+        Args:
+            examples (InputExample): text and labels in a text transformer dataclass
+            multilabel: set to process "multilabel" labels array
+            num_iterations: if provided, explicitly sets the number of pairs to be generated,
+                where n_pairs = n_iterations * n_sentences * 2 (for pos & neg pairs)
+            sampling_strategy: "unique", "oversampling", or "undersampling"
+        """
+        super().__init__()
+        self.pos_index = 0
+        self.neg_index = 0
+        self.pos_pairs = []
+        self.neg_pairs = []
+        self.sentences = np.array([s.texts[0] for s in examples])
+        self.labels = np.array([s.label for s in examples])
+        self.sentence_labels = list(zip(self.sentences, self.labels))
+
+        if multilabel:
+            self.generate_multilabel_pairs()
+        else:
+            self.generate_pairs()
+
+        if num_iterations is not None and num_iterations > 0:
+            self.len_pos_pairs = num_iterations * len(self.sentences)
+            self.len_neg_pairs = num_iterations * len(self.sentences)
+
+        elif sampling_strategy == "unique":
+            self.len_pos_pairs = len(self.pos_pairs)
+            self.len_neg_pairs = len(self.neg_pairs)
+
+        elif sampling_strategy == "undersampling":
+            self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
+            self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
+
+        elif sampling_strategy == "oversampling":
+            self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
+            self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
+
+        else:
+            raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")
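The three strategies interact with `num_iterations` as follows; a minimal sketch, not part of the diff, assuming the import path of the new `setfit/sampler.py` module. Note that `shuffle_combinations` pairs with replacement by default, so self-pairs count as positives:

```python
from sentence_transformers import InputExample
from setfit.sampler import ContrastiveDataset

examples = [
    InputExample(texts=["great movie"], label=1),
    InputExample(texts=["loved it"], label=1),
    InputExample(texts=["terrible plot"], label=0),
]
# 3 sentences yield 6 raw pairs: 4 positive (3 self-pairs plus the two label-1
# sentences paired together) and 2 negative.
assert len(ContrastiveDataset(examples, multilabel=False, sampling_strategy="unique")) == 6
# "oversampling" (the default) repeats the rarer pair type until both match: 4 + 4.
assert len(ContrastiveDataset(examples, multilabel=False, sampling_strategy="oversampling")) == 8
# "undersampling" truncates the more common pair type instead: 2 + 2.
assert len(ContrastiveDataset(examples, multilabel=False, sampling_strategy="undersampling")) == 4
# An explicit num_iterations overrides the strategy: 20 * 3 sentences * 2 = 120.
assert len(ContrastiveDataset(examples, multilabel=False, num_iterations=20)) == 120
```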
+    def generate_pairs(self) -> None:
+        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
+            if _label == label:
+                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
+            else:
+                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
+
+    def generate_multilabel_pairs(self) -> None:
+        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
+            if any(np.logical_and(_label, label)):
+                # logical_and checks if labels are both set for each class
+                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
+            else:
+                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
+
+    def get_positive_pairs(self) -> List[InputExample]:
+        pairs = []
+        for _ in range(self.len_pos_pairs):
+            if self.pos_index >= len(self.pos_pairs):
+                self.pos_index = 0
+            pairs.append(self.pos_pairs[self.pos_index])
+            self.pos_index += 1
+        return pairs
+
+    def get_negative_pairs(self) -> List[InputExample]:
+        pairs = []
+        for _ in range(self.len_neg_pairs):
+            if self.neg_index >= len(self.neg_pairs):
+                self.neg_index = 0
+            pairs.append(self.neg_pairs[self.neg_index])
+            self.neg_index += 1
+        return pairs
+
+    def __iter__(self):
+        for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
+            if pos_pair is not None:
+                yield pos_pair
+            if neg_pair is not None:
+                yield neg_pair
+
+    def __len__(self) -> int:
+        return self.len_pos_pairs + self.len_neg_pairs
+
+
+class ContrastiveDistillationDataset(ContrastiveDataset):
+    def __init__(
+        self,
+        examples: List[InputExample],
+        cos_sim_matrix: torch.Tensor,
+        num_iterations: Optional[int] = None,
+        sampling_strategy: str = "oversampling",
+    ) -> None:
+        self.cos_sim_matrix = cos_sim_matrix
+        super().__init__(
+            examples,
+            multilabel=False,
+            num_iterations=num_iterations,
+            sampling_strategy=sampling_strategy,
+        )
+        # Internally we store all pairs in pos_pairs, regardless of sampling strategy.
+        # After all, without labels, there isn't much of a strategy.
+        self.sentence_labels = list(enumerate(self.sentences))
+
+        self.len_neg_pairs = 0
+        if num_iterations is not None and num_iterations > 0:
+            self.len_pos_pairs = num_iterations * len(self.sentences)
+        else:
+            self.len_pos_pairs = len(self.pos_pairs)
+
+    def generate_pairs(self) -> None:
+        for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
+            self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
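For the multilabel case, two samples form a positive pair as soon as their label vectors share at least one active class. A hedged sketch with made-up class names:

```python
import numpy as np
from sentence_transformers import InputExample
from setfit.sampler import ContrastiveDataset

# One-hot labels over three hypothetical classes: [price, quality, service]
examples = [
    InputExample(texts=["cheap and tasty"], label=np.array([1, 1, 0])),
    InputExample(texts=["good value"], label=np.array([1, 0, 0])),
    InputExample(texts=["rude staff"], label=np.array([0, 0, 1])),
]
dataset = ContrastiveDataset(examples, multilabel=True, sampling_strategy="unique")
# "cheap and tasty" / "good value" share the price class -> positive pair (label 1.0);
# any pair involving "rude staff" shares no class -> negative pair (label 0.0).
for pair in dataset:
    print(pair.texts, pair.label)
```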
diff --git a/src/setfit/span/__init__.py b/src/setfit/span/__init__.py
new file mode 100644
index 00000000..7fc6f9db
--- /dev/null
+++ b/src/setfit/span/__init__.py
@@ -0,0 +1,3 @@
+from .aspect_extractor import AspectExtractor
+from .modeling import AbsaModel, AspectModel, PolarityModel
+from .trainer import AbsaTrainer
diff --git a/src/setfit/span/aspect_extractor.py b/src/setfit/span/aspect_extractor.py
new file mode 100644
index 00000000..096b9bb6
--- /dev/null
+++ b/src/setfit/span/aspect_extractor.py
@@ -0,0 +1,34 @@
+from typing import TYPE_CHECKING, List, Tuple
+
+
+if TYPE_CHECKING:
+    from spacy.tokens import Doc
+
+
+class AspectExtractor:
+    def __init__(self, spacy_model: str) -> None:
+        super().__init__()
+        import spacy
+
+        self.nlp = spacy.load(spacy_model)
+
+    def find_groups(self, aspect_mask: List[bool]):
+        start = None
+        for idx, flag in enumerate(aspect_mask):
+            if flag:
+                if start is None:
+                    start = idx
+            else:
+                if start is not None:
+                    yield slice(start, idx)
+                    start = None
+        if start is not None:
+            yield slice(start, idx + 1)
+
+    def __call__(self, texts: List[str]) -> Tuple[List["Doc"], List[List[slice]]]:
+        aspects_list = []
+        docs = list(self.nlp.pipe(texts))
+        for doc in docs:
+            aspect_mask = [token.pos_ in ("NOUN", "PROPN") for token in doc]
+            aspects_list.append(list(self.find_groups(aspect_mask)))
+        return docs, aspects_list
diff --git a/src/setfit/span/model_card_template.md b/src/setfit/span/model_card_template.md
new file mode 100644
index 00000000..31ec618f
--- /dev/null
+++ b/src/setfit/span/model_card_template.md
@@ -0,0 +1,64 @@
+---
+license: apache-2.0
+tags:
+- setfit
+- sentence-transformers
+- absa
+- token-classification
+pipeline_tag: token-classification
+---
+
+# {{ model_name | default("SetFit ABSA Model", true) }}
+
+This is a [SetFit ABSA model](https://github.com/huggingface/setfit) that can be used for Aspect Based Sentiment Analysis (ABSA). \
+In particular, this model is in charge of {{ "filtering aspect span candidates" if is_aspect else "classifying aspect polarities"}}.
+It has been trained using SetFit, an efficient few-shot learning technique that involves:
+
+1. Fine-tuning a [Sentence Transformer](https://www.sbert.net) with contrastive learning.
+2. Training a classification head with features from the fine-tuned Sentence Transformer.
+
+This model was trained within the context of a larger system for ABSA, which looks like so:
+
+1. Use a spaCy model to select possible aspect span candidates.
+2. {{ "**" if is_aspect else "" }}Use {{ "this" if is_aspect else "a" }} SetFit model to filter these possible aspect span candidates.{{ "**" if is_aspect else "" }}
+3. 
{{ "**" if not is_aspect else "" }}Use {{ "this" if not is_aspect else "a" }} SetFit model to classify the filtered aspect span candidates.{{ "**" if not is_aspect else "" }} + +## Usage + +To use this model for inference, first install the SetFit library: + +```bash +pip install setfit +``` + +You can then run inference as follows: + +```python +from setfit import AbsaModel + +# Download from Hub and run inference +model = AbsaModel.from_pretrained( + "{{ aspect_model }}", + "{{ polarity_model }}", +) +# Run inference +preds = model([ + "The best pizza outside of Italy and really tasty.", + "The food here is great but the service is terrible", +]) +``` + +## BibTeX entry and citation info + +```bibtex +@article{https://doi.org/10.48550/arxiv.2209.11055, + doi = {10.48550/ARXIV.2209.11055}, + url = {https://arxiv.org/abs/2209.11055}, + author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}, + keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Efficient Few-Shot Learning Without Prompts}, + publisher = {arXiv}, + year = {2022}, + copyright = {Creative Commons Attribution 4.0 International} +} +``` \ No newline at end of file diff --git a/src/setfit/span/modeling.py b/src/setfit/span/modeling.py new file mode 100644 index 00000000..f25a72c1 --- /dev/null +++ b/src/setfit/span/modeling.py @@ -0,0 +1,292 @@ +import json +import os +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import requests +import torch +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import SoftTemporaryDirectory, validate_hf_hub_args +from jinja2 import Environment, FileSystemLoader + +from .. 
import logging +from ..modeling import SetFitModel +from .aspect_extractor import AspectExtractor + + +if TYPE_CHECKING: + from spacy.tokens import Doc + +logger = logging.get_logger(__name__) + +CONFIG_NAME = "config_span_setfit.json" + + +@dataclass +class SpanSetFitModel(SetFitModel): + span_context: int = 0 + + def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[str]: + for doc, aspects in zip(docs, aspects_list): + for aspect_slice in aspects: + aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context] + # TODO: Investigate performance difference of different formats + yield aspect.text + ":" + doc.text + + def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[bool]: + inputs_list = list(self.prepend_aspects(docs, aspects_list)) + preds = self.predict(inputs_list, as_numpy=True) + iter_preds = iter(preds) + return [[next(iter_preds) for _ in aspects] for aspects in aspects_list] + + @classmethod + @validate_hf_hub_args + def _from_pretrained( + cls, + model_id: str, + span_context: Optional[int] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + force_download: Optional[bool] = None, + proxies: Optional[Dict] = None, + resume_download: Optional[bool] = None, + local_files_only: Optional[bool] = None, + token: Optional[Union[bool, str]] = None, + **model_kwargs, + ) -> "SpanSetFitModel": + config_file: Optional[str] = None + if os.path.isdir(model_id): + if CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, CONFIG_NAME) + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except requests.exceptions.RequestException: + pass + + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + model_kwargs.update(config) + + if span_context is not None: + model_kwargs["span_context"] = span_context + + return super(SpanSetFitModel, cls)._from_pretrained( + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + **model_kwargs, + ) + + def _save_pretrained(self, save_directory: Union[Path, str]) -> None: + path = os.path.join(save_directory, CONFIG_NAME) + with open(path, "w") as f: + json.dump({"span_context": self.span_context}, f, indent=2) + + super()._save_pretrained(save_directory) + + def create_model_card(self, path: str, model_name: Optional[str] = None) -> None: + """Creates and saves a model card for a SetFit model. + + Args: + path (str): The path to save the model card to. + model_name (str, *optional*): The name of the model. Defaults to `SetFit Model`. + """ + if not os.path.exists(path): + os.makedirs(path) + + # If the model_path is a folder that exists locally, i.e. 
when create_model_card is called
+        # via push_to_hub, and the path is in a temporary folder, then we only take the last two
+        # directories
+        model_path = Path(model_name)
+        if model_path.exists() and Path(tempfile.gettempdir()) in model_path.resolve().parents:
+            model_name = "/".join(model_path.parts[-2:])
+
+        environment = Environment(loader=FileSystemLoader(Path(__file__).parent))
+        template = environment.get_template("model_card_template.md")
+        is_aspect = isinstance(self, AspectModel)
+        aspect_model = "setfit-absa-aspect"
+        polarity_model = "setfit-absa-polarity"
+        if model_name is not None:
+            if is_aspect:
+                aspect_model = model_name
+                if model_name.endswith("-aspect"):
+                    polarity_model = model_name[: -len("-aspect")] + "-polarity"
+            else:
+                polarity_model = model_name
+                if model_name.endswith("-polarity"):
+                    aspect_model = model_name[: -len("-polarity")] + "-aspect"
+
+        model_card_content = template.render(
+            model_name=model_name, is_aspect=is_aspect, aspect_model=aspect_model, polarity_model=polarity_model
+        )
+        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
+            f.write(model_card_content)
+
+
+class AspectModel(SpanSetFitModel):
+    # TODO: Assumes binary SetFitModel with 0 == no aspect, 1 == aspect
+    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[List[slice]]:
+        sentence_preds = super().__call__(docs, aspects_list)
+        return [
+            [aspect for aspect, pred in zip(aspects, preds) if pred == 1]
+            for aspects, preds in zip(aspects_list, sentence_preds)
+        ]
+
+
+@dataclass
+class PolarityModel(SpanSetFitModel):
+    span_context: int = 3
+
+
+@dataclass
+class AbsaModel:
+    aspect_extractor: AspectExtractor
+    aspect_model: AspectModel
+    polarity_model: PolarityModel
+
+    def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
+        is_str = isinstance(inputs, str)
+        inputs_list = [inputs] if is_str else inputs
+        docs, aspects_list = self.aspect_extractor(inputs_list)
+        if sum(aspects_list, []) == []:
+            return aspects_list
+
+        aspects_list = self.aspect_model(docs, aspects_list)
+        if sum(aspects_list, []) == []:
+            return aspects_list
+
+        polarity_list = self.polarity_model(docs, aspects_list)
+        outputs = []
+        for doc, aspects, polarities in zip(docs, aspects_list, polarity_list):
+            outputs.append(
+                [
+                    {"span": doc[aspect_slice].text, "polarity": polarity}
+                    for aspect_slice, polarity in zip(aspects, polarities)
+                ]
+            )
+        return outputs if not is_str else outputs[0]
+
+    @property
+    def device(self) -> torch.device:
+        return self.aspect_model.device
+
+    def to(self, device: Union[str, torch.device]) -> None:
+        self.aspect_model.to(device)
+        self.polarity_model.to(device)
+
+    def __call__(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
+        return self.predict(inputs)
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, Path],
+        polarity_save_directory: Optional[Union[str, Path]] = None,
+        push_to_hub: bool = False,
+        **kwargs,
+    ) -> None:
+        if polarity_save_directory is None:
+            base_save_directory = Path(save_directory)
+            save_directory = base_save_directory.parent / (base_save_directory.name + "-aspect")
+            polarity_save_directory = base_save_directory.parent / (base_save_directory.name + "-polarity")
+        self.aspect_model.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+        self.polarity_model.save_pretrained(save_directory=polarity_save_directory, push_to_hub=push_to_hub, **kwargs)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id: str,
+        polarity_model_id: 
Optional[str] = None, + spacy_model: Optional[str] = "en_core_web_lg", + span_contexts: Tuple[Optional[int], Optional[int]] = (None, None), + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict] = None, + token: Optional[Union[str, bool]] = None, + cache_dir: Optional[str] = None, + local_files_only: bool = False, + use_differentiable_head: bool = False, + normalize_embeddings: bool = False, + **model_kwargs, + ) -> "AbsaModel": + revision = None + if len(model_id.split("@")) == 2: + model_id, revision = model_id.split("@") + aspect_model = AspectModel.from_pretrained( + model_id, + span_context=span_contexts[0], + revision=revision, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + use_differentiable_head=use_differentiable_head, + normalize_embeddings=normalize_embeddings, + **model_kwargs, + ) + if polarity_model_id: + model_id = polarity_model_id + revision = None + if len(model_id.split("@")) == 2: + model_id, revision = model_id.split("@") + polarity_model = PolarityModel.from_pretrained( + model_id, + span_context=span_contexts[1], + revision=revision, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + use_differentiable_head=use_differentiable_head, + normalize_embeddings=normalize_embeddings, + **model_kwargs, + ) + + aspect_extractor = AspectExtractor(spacy_model=spacy_model) + + return cls(aspect_extractor, aspect_model, polarity_model) + + def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None: + if "/" not in repo_id: + raise ValueError( + '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".' + ) + if polarity_repo_id is not None and "/" not in polarity_repo_id: + raise ValueError( + '`polarity_repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".' + ) + commit_message = kwargs.pop("commit_message", "Add SetFit ABSA model") + + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp_dir: + save_directory = Path(tmp_dir) / repo_id + polarity_save_directory = None if polarity_repo_id is None else Path(tmp_dir) / polarity_repo_id + self.save_pretrained( + save_directory=save_directory, + polarity_save_directory=polarity_save_directory, + push_to_hub=True, + commit_message=commit_message, + **kwargs, + ) diff --git a/src/setfit/span/trainer.py b/src/setfit/span/trainer.py new file mode 100644 index 00000000..1d362616 --- /dev/null +++ b/src/setfit/span/trainer.py @@ -0,0 +1,324 @@ +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +from datasets import Dataset +from transformers.trainer_callback import TrainerCallback + +from setfit.span.modeling import AbsaModel, AspectModel, PolarityModel +from setfit.training_args import TrainingArguments + +from .. import logging +from ..trainer import ColumnMappingMixin, Trainer + + +if TYPE_CHECKING: + import optuna + +logger = logging.get_logger(__name__) + + +class AbsaTrainer(ColumnMappingMixin): + """Trainer to train a SetFit ABSA model. + + Args: + model (`AbsaModel`): + The AbsaModel model to train. + args (`TrainingArguments`, *optional*): + The training arguments to use. 
If `polarity_args` is not defined, then `args` is used for both
+            the aspect and the polarity model.
+        polarity_args (`TrainingArguments`, *optional*):
+            The training arguments to use for the polarity model. If not defined, `args` is used for both
+            the aspect and the polarity model.
+        train_dataset (`Dataset`):
+            The training dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
+        eval_dataset (`Dataset`, *optional*):
+            The evaluation dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
+        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
+            The metric to use for evaluation. If a string is provided, we treat it as the metric
+            name and load it with default settings.
+            If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
+        metric_kwargs (`Dict[str, Any]`, *optional*):
+            Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
+            For example useful for providing an averaging strategy for computing f1 in a multi-label setting.
+        callbacks: (`List[~transformers.TrainerCallback]`, *optional*):
+            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
+            detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback).
+            If you want to remove one of the default callbacks used, use the `Trainer.remove_callback()` method.
+        column_mapping (`Dict[str, str]`, *optional*):
+            A mapping from the column names in the dataset to the column names expected by the model.
+            The expected format is a dictionary with the following format:
+            `{"text_column_name": "text", "span_column_name": "span", "label_column_name": "label", "ordinal_column_name": "ordinal"}`.
+    """
+
+    _REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"}
+
+    def __init__(
+        self,
+        model: AbsaModel,
+        args: Optional[TrainingArguments] = None,
+        polarity_args: Optional[TrainingArguments] = None,
+        train_dataset: Optional["Dataset"] = None,
+        eval_dataset: Optional["Dataset"] = None,
+        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+        metric_kwargs: Optional[Dict[str, Any]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        column_mapping: Optional[Dict[str, str]] = None,
+    ) -> None:
+        self.model = model
+        self.aspect_extractor = model.aspect_extractor
+
+        if train_dataset is not None and column_mapping:
+            train_dataset = self._apply_column_mapping(train_dataset, column_mapping)
+        aspect_train_dataset, polarity_train_dataset = self.preprocess_dataset(
+            model.aspect_model, model.polarity_model, train_dataset
+        )
+        if eval_dataset is not None and column_mapping:
+            eval_dataset = self._apply_column_mapping(eval_dataset, column_mapping)
+        aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset(
+            model.aspect_model, model.polarity_model, eval_dataset
+        )
+
+        self.aspect_trainer = Trainer(
+            model.aspect_model,
+            args=args,
+            train_dataset=aspect_train_dataset,
+            eval_dataset=aspect_eval_dataset,
+            metric=metric,
+            metric_kwargs=metric_kwargs,
+            callbacks=callbacks,
+        )
+        self.aspect_trainer._set_logs_mapper(
+            {
+                "eval_embedding_loss": "eval_aspect_embedding_loss",
+                "embedding_loss": "aspect_embedding_loss",
+            }
+        )
+        self.polarity_trainer = Trainer(
+            model.polarity_model,
+            args=polarity_args or args,
+            train_dataset=polarity_train_dataset,
+            eval_dataset=polarity_eval_dataset,
+            metric=metric,
+            metric_kwargs=metric_kwargs,
+            callbacks=callbacks,
+        )
+        self.polarity_trainer._set_logs_mapper(
+            {
+                "eval_embedding_loss": "eval_polarity_embedding_loss",
+                "embedding_loss": "polarity_embedding_loss",
+            }
+        )
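For reference, a hypothetical end-to-end sketch of the expected dataset shape, with one row per annotated aspect (all values are illustrative, and the default spaCy model `en_core_web_lg` must be installed):

```python
from datasets import Dataset
from setfit.span import AbsaModel, AbsaTrainer

train_dataset = Dataset.from_dict(
    {
        "text": [
            "The food here is great but the service is terrible",
            "The food here is great but the service is terrible",
        ],
        "span": ["food", "service"],
        "label": ["positive", "negative"],
        # "ordinal" disambiguates repeated spans: 0 selects the first occurrence.
        "ordinal": [0, 0],
    }
)

# Both the aspect and the polarity model start from the same pretrained body.
model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
trainer = AbsaTrainer(model=model, train_dataset=train_dataset)
trainer.train()
```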
+    def preprocess_dataset(
+        self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Dataset
+    ) -> Tuple[Dataset, Dataset]:
+        if dataset is None:
+            return dataset, dataset
+
+        # Group by "text"
+        grouped_data = defaultdict(list)
+        for sample in dataset:
+            text = sample.pop("text")
+            grouped_data[text].append(sample)
+
+        def index_ordinal(text: str, target: str, ordinal: int) -> Tuple[int, int]:
+            find_from = 0
+            for _ in range(ordinal + 1):
+                start_idx = text.index(target, find_from)
+                find_from = start_idx + 1
+            return start_idx, start_idx + len(target)
+
+        docs, aspects_list = self.aspect_extractor(grouped_data.keys())
+        intersected_aspect_list = []
+        polarity_labels = []
+        aspect_labels = []
+        for doc, aspects, text in zip(docs, aspects_list, grouped_data):
+            gold_aspects = []
+            gold_polarity_labels = []
+            for annotation in grouped_data[text]:
+                try:
+                    start, end = index_ordinal(text, annotation["span"], annotation["ordinal"])
+                except ValueError:
+                    logger.info(
+                        f"The ordinal of {annotation['ordinal']} for span {annotation['span']!r} in {text!r} is too high. "
+                        "Skipping this sample."
+                    )
+                    continue
+
+                gold_aspect_span = doc.char_span(start, end)
+                if gold_aspect_span is None:
+                    continue
+                gold_aspects.append(slice(gold_aspect_span.start, gold_aspect_span.end))
+                gold_polarity_labels.append(annotation["label"])
+
+            # The Aspect model uses all predicted aspects, with labels depending on whether
+            # the predicted aspects are indeed true/gold aspects.
+            aspect_labels.extend([aspect in gold_aspects for aspect in aspects])
+
+            # The Polarity model uses the intersection of pred and gold aspects, with labels for the gold label.
+            intersected_aspects = []
+            for gold_aspect, gold_label in zip(gold_aspects, gold_polarity_labels):
+                if gold_aspect in aspects:
+                    intersected_aspects.append(gold_aspect)
+                    polarity_labels.append(gold_label)
+            intersected_aspect_list.append(intersected_aspects)
+
+        aspect_texts = list(aspect_model.prepend_aspects(docs, aspects_list))
+        polarity_texts = list(polarity_model.prepend_aspects(docs, intersected_aspect_list))
+        return Dataset.from_dict({"text": aspect_texts, "label": aspect_labels}), Dataset.from_dict(
+            {"text": polarity_texts, "label": polarity_labels}
+        )
+
+    def train(
+        self,
+        args: Optional[TrainingArguments] = None,
+        polarity_args: Optional[TrainingArguments] = None,
+        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Main training entry point.
+
+        Args:
+            args (`TrainingArguments`, *optional*):
+                Temporarily change the aspect training arguments for this training call.
+            polarity_args (`TrainingArguments`, *optional*):
+                Temporarily change the polarity training arguments for this training call.
+            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+                The trial run or the hyperparameter dictionary for hyperparameter search.
+        """
+        self.train_aspect(args=args, trial=trial, **kwargs)
+        self.train_polarity(args=polarity_args, trial=trial, **kwargs)
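Continuing the sketch above, the aspect and polarity phases can be driven with different hyperparameters (the values here are illustrative):

```python
from setfit import TrainingArguments

aspect_args = TrainingArguments(num_epochs=1, batch_size=16)
polarity_args = TrainingArguments(num_epochs=2, batch_size=8)

# One call trains both models...
trainer.train(args=aspect_args, polarity_args=polarity_args)
# ...or each phase can be run (and re-run) on its own:
trainer.train_aspect(args=aspect_args)
trainer.train_polarity(args=polarity_args)
```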
+    def train_aspect(
+        self,
+        args: Optional[TrainingArguments] = None,
+        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Train the aspect model only.
+
+        Args:
+            args (`TrainingArguments`, *optional*):
+                Temporarily change the aspect training arguments for this training call.
+            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+                The trial run or the hyperparameter dictionary for hyperparameter search.
+        """
+        self.aspect_trainer.train(args=args, trial=trial, **kwargs)
+
+    def train_polarity(
+        self,
+        args: Optional[TrainingArguments] = None,
+        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Train the polarity model only.
+
+        Args:
+            args (`TrainingArguments`, *optional*):
+                Temporarily change the polarity training arguments for this training call.
+            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+                The trial run or the hyperparameter dictionary for hyperparameter search.
+        """
+        self.polarity_trainer.train(args=args, trial=trial, **kwargs)
+
+    def add_callback(self, callback):
+        """
+        Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will instantiate a member of that class.
+        """
+        self.aspect_trainer.add_callback(callback)
+        self.polarity_trainer.add_callback(callback)
+
+    def pop_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.
+
+        If the callback is not found, returns `None` (and no error is raised).
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will pop the first member of that class found in the list of callbacks.
+
+        Returns:
+            [`Tuple[~transformers.TrainerCallback]`]: The callbacks removed from the aspect and polarity trainers, if found.
+        """
+        return self.aspect_trainer.pop_callback(callback), self.polarity_trainer.pop_callback(callback)
+
+    def remove_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will remove the first member of that class found in the list of callbacks.
+        """
+        self.aspect_trainer.remove_callback(callback)
+        self.polarity_trainer.remove_callback(callback)
+
+    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
+        """Upload model checkpoint to the Hub using `huggingface_hub`.
+
+        See the full list of parameters for your `huggingface_hub` version in the\
+        [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub).
+
+        Args:
+            repo_id (`str`):
+                The full repository ID to push the aspect model to, e.g. `"tomaarsen/setfit-aspect"`.
+            polarity_repo_id (`str`, *optional*):
+                The full repository ID to push the polarity model to, e.g. `"tomaarsen/setfit-polarity"`.
+            config (`dict`, *optional*):
+                Configuration object to be saved alongside the model weights.
+            commit_message (`str`, *optional*):
+                Message to commit while pushing.
+            private (`bool`, *optional*, defaults to `False`):
+                Whether the repository created should be private.
+            api_endpoint (`str`, *optional*):
+                The API endpoint to use when pushing the model to the hub.
+            token (`str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+                If not set, will use the token set when logging in with
+                `transformers-cli login` (stored in `~/.huggingface`).
+ branch (`str`, *optional*): + The git branch on which to push the model. This defaults to + the default branch as specified in your repository, which + defaults to `"main"`. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `branch` with that commit. + Defaults to `False`. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are pushed. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not pushed. + """ + return self.model.push_to_hub(repo_id=repo_id, polarity_repo_id=polarity_repo_id, **kwargs) + + def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, Dict[str, float]]: + """ + Computes the metrics for a given classifier. + + Args: + dataset (`Dataset`, *optional*): + The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via + the `eval_dataset` argument at `Trainer` initialization. + + Returns: + `Dict[str, Dict[str, float]]`: The evaluation metrics. + """ + aspect_eval_dataset = polarity_eval_dataset = None + if dataset: + aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset( + self.model.aspect_model, self.model.polarity_model, dataset + ) + return { + "aspect": self.aspect_trainer.evaluate(aspect_eval_dataset), + "polarity": self.polarity_trainer.evaluate(polarity_eval_dataset), + } diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index 6304ce5b..baecd053 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -1,23 +1,56 @@ import math -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +import shutil +import time +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import evaluate -import numpy as np import torch from datasets import Dataset, DatasetDict -from sentence_transformers import InputExample, losses +from sentence_transformers import InputExample, SentenceTransformer, losses from sentence_transformers.datasets import SentenceLabelDataset from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction +from sentence_transformers.util import batch_to_device +from sklearn.preprocessing import LabelEncoder +from torch import nn +from torch.cuda.amp import autocast from torch.utils.data import DataLoader -from tqdm.auto import trange -from transformers.trainer_utils import HPSearchBackend, default_compute_objective, number_of_arguments, set_seed +from tqdm.autonotebook import tqdm +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from transformers.trainer_utils import ( + HPSearchBackend, + default_compute_objective, + number_of_arguments, + set_seed, + speed_metrics, +) +from transformers.utils.import_utils import is_in_notebook from . 
import logging from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna -from .modeling import SupConLoss, sentence_pairs_generation, sentence_pairs_generation_multilabel +from .losses import SupConLoss +from .sampler import ContrastiveDataset +from .training_args import TrainingArguments from .utils import BestRun, default_hp_space_optuna +# Google Colab runs on Python 3.7, so we need this to be compatible +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + + if TYPE_CHECKING: import optuna @@ -27,123 +60,24 @@ logger = logging.get_logger(__name__) -class SetFitTrainer: - """Trainer to train a SetFit model. +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback - Args: - model (`SetFitModel`, *optional*): - The model to train. If not provided, a `model_init` must be passed. - train_dataset (`Dataset`): - The training dataset. - eval_dataset (`Dataset`, *optional*): - The evaluation dataset. - model_init (`Callable[[], SetFitModel]`, *optional*): - A function that instantiates the model to be used. If provided, each call to [`~SetFitTrainer.train`] will start - from a new instance of the model as given by this function when a `trial` is passed. - metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): - The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings. - If a callable is provided, it must take two arguments (`y_pred`, `y_test`). - metric_kwargs (`Dict[str, Any]`, *optional*): - Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1". - For example useful for providing an averaging strategy for computing f1 in a multi-label setting. - loss_class (`nn.Module`, *optional*, defaults to `CosineSimilarityLoss`): - The loss function to use for contrastive training. - num_iterations (`int`, *optional*, defaults to `20`): - The number of iterations to generate sentence pairs for. - This argument is ignored if triplet loss is used. - It is only used in conjunction with `CosineSimilarityLoss`. - num_epochs (`int`, *optional*, defaults to `1`): - The number of epochs to train the Sentence Transformer body for. - learning_rate (`float`, *optional*, defaults to `2e-5`): - The learning rate to use for contrastive training. - batch_size (`int`, *optional*, defaults to `16`): - The batch size to use for contrastive training. - seed (`int`, *optional*, defaults to 42): - Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the - [`~SetTrainer.model_init`] function to instantiate the model if it has some randomly initialized parameters. - column_mapping (`Dict[str, str]`, *optional*): - A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: {"text_column_name": "text", "label_column_name: "label"}. - use_amp (`bool`, *optional*, defaults to `False`): - Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0 - warmup_proportion (`float`, *optional*, defaults to `0.1`): - Proportion of the warmup in the total training steps. - Must be greater than or equal to 0.0 and less than or equal to 1.0. - distance_metric (`Callable`, defaults to `BatchHardTripletLossDistanceFunction.cosine_distance`): - Function that returns a distance between two embeddings. 
- It is set for the triplet loss and - is ignored for `CosineSimilarityLoss` and `SupConLoss`. - margin (`float`, defaults to `0.25`): Margin for the triplet loss. - Negative samples should be at least margin further apart from the anchor than the positive. - This is ignored for `CosineSimilarityLoss`, `BatchHardSoftMarginTripletLoss` and `SupConLoss`. - samples_per_label (`int`, defaults to `2`): Number of consecutive, random and unique samples drawn per label. - This is only relevant for triplet loss and ignored for `CosineSimilarityLoss`. - Batch size should be a multiple of samples_per_label. - """ +if is_in_notebook(): + from transformers.utils.notebook import NotebookProgressCallback - def __init__( - self, - model: Optional["SetFitModel"] = None, - train_dataset: Optional["Dataset"] = None, - eval_dataset: Optional["Dataset"] = None, - model_init: Optional[Callable[[], "SetFitModel"]] = None, - metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", - metric_kwargs: Optional[Dict[str, Any]] = None, - loss_class=losses.CosineSimilarityLoss, - num_iterations: int = 20, - num_epochs: int = 1, - learning_rate: float = 2e-5, - batch_size: int = 16, - seed: int = 42, - column_mapping: Optional[Dict[str, str]] = None, - use_amp: bool = False, - warmup_proportion: float = 0.1, - distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance, - margin: float = 0.25, - samples_per_label: int = 2, - ) -> None: - if (warmup_proportion < 0.0) or (warmup_proportion > 1.0): - raise ValueError( - f"warmup_proportion must be greater than or equal to 0.0 and less than or equal to 1.0! But it was: {warmup_proportion}" - ) + DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.model_init = model_init - self.metric = metric - self.metric_kwargs = metric_kwargs - self.loss_class = loss_class - self.num_iterations = num_iterations - self.num_epochs = num_epochs - self.learning_rate = learning_rate - self.batch_size = batch_size - self.seed = seed - self.column_mapping = column_mapping - self.use_amp = use_amp - self.warmup_proportion = warmup_proportion - self.distance_metric = distance_metric - self.margin = margin - self.samples_per_label = samples_per_label - if model is None: - if model_init is not None: - model = self.call_model_init() - else: - raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument") - else: - if model_init is not None: - raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument, but not both") - - self.model = model - self.hp_search_backend = None - self._freeze = True # If True, will train the body only; otherwise, train the body and head +class ColumnMappingMixin: + _REQUIRED_COLUMNS = {"text", "label"} def _validate_column_mapping(self, dataset: "Dataset") -> None: """ Validates the provided column mapping against the dataset. 
""" - required_columns = {"text", "label"} column_names = set(dataset.column_names) - if self.column_mapping is None and not required_columns.issubset(column_names): + if self.column_mapping is None and not self._REQUIRED_COLUMNS.issubset(column_names): # Issue #226: load_dataset will automatically assign points to "train" if no split is specified if column_names == {"train"} and isinstance(dataset, DatasetDict): raise ValueError( @@ -157,12 +91,12 @@ def _validate_column_mapping(self, dataset: "Dataset") -> None: ) else: raise ValueError( - f"SetFit expected the dataset to have the columns {sorted(required_columns)}, " + f"SetFit expected the dataset to have the columns {sorted(self._REQUIRED_COLUMNS)}, " f"but only the columns {sorted(column_names)} were found. " - "Either make sure these columns are present, or specify which columns to use with column_mapping in SetFitTrainer." + "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer." ) if self.column_mapping is not None: - missing_columns = required_columns.difference(self.column_mapping.values()) + missing_columns = self._REQUIRED_COLUMNS.difference(self.column_mapping.values()) if missing_columns: raise ValueError( f"The following columns are missing from the column mapping: {missing_columns}. Please provide a mapping for all required columns." @@ -181,7 +115,11 @@ def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, st dataset = dataset.rename_columns( { **column_mapping, - **{col: f"feat_{col}" for col in dataset.column_names if col not in column_mapping}, + **{ + col: f"feat_{col}" + for col in dataset.column_names + if col not in column_mapping and col not in self._REQUIRED_COLUMNS + }, } ) dset_format = dataset.format @@ -193,6 +131,128 @@ def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, st ) return dataset + +class Trainer(ColumnMappingMixin): + """Trainer to train a SetFit model. + + Args: + model (`SetFitModel`, *optional*): + The model to train. If not provided, a `model_init` must be passed. + args (`TrainingArguments`, *optional*): + The training arguments to use. + train_dataset (`Dataset`): + The training dataset. + eval_dataset (`Dataset`, *optional*): + The evaluation dataset. + model_init (`Callable[[], SetFitModel]`, *optional*): + A function that instantiates the model to be used. If provided, each call to + [`~Trainer.train`] will start from a new instance of the model as given by this + function when a `trial` is passed. + metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): + The metric to use for evaluation. If a string is provided, we treat it as the metric + name and load it with default settings. + If a callable is provided, it must take two arguments (`y_pred`, `y_test`). + metric_kwargs (`Dict[str, Any]`, *optional*): + Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1". + For example useful for providing an averaging strategy for computing f1 in a multi-label setting. + callbacks: (`List[~transformers.TrainerCallback]`, *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback). + If you want to remove one of the default callbacks used, use the `Trainer.remove_callback()` method. 
+        column_mapping (`Dict[str, str]`, *optional*):
+            A mapping from the column names in the dataset to the column names expected by the model.
+            The expected format is a dictionary with the following format:
+            `{"text_column_name": "text", "label_column_name": "label"}`.
+    """
+
+    def __init__(
+        self,
+        model: Optional["SetFitModel"] = None,
+        args: Optional[TrainingArguments] = None,
+        train_dataset: Optional["Dataset"] = None,
+        eval_dataset: Optional["Dataset"] = None,
+        model_init: Optional[Callable[[], "SetFitModel"]] = None,
+        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+        metric_kwargs: Optional[Dict[str, Any]] = None,
+        callbacks: Optional[List[TrainerCallback]] = None,
+        column_mapping: Optional[Dict[str, str]] = None,
+    ) -> None:
+        if args is not None and not isinstance(args, TrainingArguments):
+            raise ValueError("`args` must be a `TrainingArguments` instance imported from `setfit`.")
+        self.args = args or TrainingArguments()
+        self.train_dataset = train_dataset
+        self.eval_dataset = eval_dataset
+        self.model_init = model_init
+        self.metric = metric
+        self.metric_kwargs = metric_kwargs
+        self.column_mapping = column_mapping
+        self.logs_mapper = {}
+
+        # Seed must be set before instantiating the model when using model_init.
+        set_seed(12)
+
+        if model is None:
+            if model_init is not None:
+                model = self.call_model_init()
+            else:
+                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument.")
+        else:
+            if model_init is not None:
+                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument, but not both.")
+
+        self.model = model
+        self.hp_search_backend = None
+
+        # Setup the callbacks
+        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+        # TODO: Observe optimizer and scheduler by wrapping SentenceTransformer._get_scheduler
+        self.callback_handler = CallbackHandler(
+            callbacks, self.model.model_body, self.model.model_body.tokenizer, None, None
+        )
+        self.state = TrainerState()
+        self.control = TrainerControl()
+        self.add_callback(DEFAULT_PROGRESS_CALLBACK if self.args.show_progress_bar else PrinterCallback)
+        self.control = self.callback_handler.on_init_end(args, self.state, self.control)
+
+    def add_callback(self, callback):
+        """
+        Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will instantiate a member of that class.
+        """
+        self.callback_handler.add_callback(callback)
+
+    def pop_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.
+
+        If the callback is not found, returns `None` (and no error is raised).
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will pop the first member of that class found in the list of callbacks.
+
+        Returns:
+            [`~transformers.TrainerCallback`]: The callback removed, if found.
+        """
+        return self.callback_handler.pop_callback(callback)
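A minimal sketch of hooking into these callback methods, assuming an instantiated `trainer` (the logged keys depend on the active losses and metrics):

```python
from transformers import TrainerCallback

class LogPrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        # Fired by Trainer.log, e.g. with an "embedding_loss" entry during finetuning.
        print(f"step {state.global_step}: {logs}")

trainer.add_callback(LogPrinterCallback)     # passing the class instantiates it
trainer.remove_callback(LogPrinterCallback)  # removes the first instance of that class
```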
+    def remove_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will remove the first member of that class found in the list of callbacks.
+        """
+        self.callback_handler.remove_callback(callback)
+
     def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = False) -> None:
         """Applies a dictionary of hyperparameters to both the trainer and the model.
@@ -200,18 +260,11 @@ def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = Fals
             params (`Dict[str, Any]`): The parameters, usually from `BestRun.hyperparameters`
             final_model (`bool`, *optional*, defaults to `False`): If `True`, replace the `model_init()` function with a fixed model based on the parameters.
         """
-        for key, value in params.items():
-            if hasattr(self, key):
-                old_attr = getattr(self, key, None)
-                # Casting value to the proper type
-                if old_attr is not None:
-                    value = type(old_attr)(value)
-                setattr(self, key, value)
-            elif number_of_arguments(self.model_init) == 0:  # we do not warn if model_init could be using it
-                logger.warning(
-                    f"Trying to set {key!r} in the hyperparameter search but there is no corresponding field in "
-                    "`SetFitTrainer`, and `model_init` does not take any arguments."
-                )
+
+        if self.args is not None:
+            self.args = self.args.update(params, ignore_extra=True)
+        else:
+            self.args = TrainingArguments.from_dict(params, ignore_extra=True)
 
         self.model = self.model_init(params)
         if final_model:
@@ -248,172 +301,445 @@ def call_model_init(self, params: Optional[Dict[str, Any]] = None) -> "SetFitMod
         return model
 
-    def freeze(self) -> None:
-        """
-        Freeze SetFitModel's differentiable head.
-        Note: call this function only when using the differentiable head.
-        """
-        if not self.model.has_differentiable_head:
-            raise ValueError("Please use the differentiable head in `SetFitModel` when calling this function.")
+    def freeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
+        """Freeze the model body and/or the head, preventing further training on that component until unfrozen.
 
-        self._freeze = True  # Currently use self._freeze as a switch
-        self.model.freeze("head")
+        This method is deprecated, use `SetFitModel.freeze` instead.
 
-    def unfreeze(self, keep_body_frozen: bool = False) -> None:
+        Args:
+            component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to freeze that component.
+                If no component is provided, freeze both. Defaults to None.
         """
-        Unfreeze SetFitModel's differentiable head.
-        Note: call this function only when using the differentiable head.
+        warnings.warn(
+            f"`{self.__class__.__name__}.freeze` is deprecated and will be removed in v2.0.0 of SetFit. "
+            "Please use `SetFitModel.freeze` directly instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.model.freeze(component)
+
+    def unfreeze(
+        self, component: Optional[Literal["body", "head"]] = None, keep_body_frozen: Optional[bool] = None
+    ) -> None:
+        """Unfreeze the model body and/or the head, allowing further training on that component.
+
+        This method is deprecated, use `SetFitModel.unfreeze` instead.
 
         Args:
-            keep_body_frozen (`bool`, *optional*, defaults to `False`):
-                Whether to freeze the body when unfreeze the head.
+            component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to unfreeze that component.
+                If no component is provided, unfreeze both. Defaults to None.
+ keep_body_frozen (`bool`, *optional*): Deprecated argument, use `component` instead. """ - if not self.model.has_differentiable_head: - raise ValueError("Please use the differentiable head in `SetFitModel` when calling this function.") - - self._freeze = False # Currently use self._freeze as a switch - self.model.unfreeze("head") - if keep_body_frozen: - self.model.freeze("body") - else: # ensure to unfreeze the body - self.model.unfreeze("body") + warnings.warn( + f"`{self.__class__.__name__}.unfreeze` is deprecated and will be removed in v2.0.0 of SetFit. " + "Please use `SetFitModel.unfreeze` directly instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.model.unfreeze(component, keep_body_frozen=keep_body_frozen) def train( self, - num_epochs: Optional[int] = None, - batch_size: Optional[int] = None, - learning_rate: Optional[float] = None, - body_learning_rate: Optional[float] = None, - l2_weight: Optional[float] = None, - max_length: Optional[int] = None, + args: Optional[TrainingArguments] = None, trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, - show_progress_bar: bool = True, + **kwargs, ) -> None: """ Main training entry point. Args: - num_epochs (`int`, *optional*): - Temporary change the number of epochs to train the Sentence Transformer body/head for. - If ignore, will use the value given in initialization. - batch_size (`int`, *optional*): - Temporary change the batch size to use for contrastive training or logistic regression. - If ignore, will use the value given in initialization. - learning_rate (`float`, *optional*): - Temporary change the learning rate to use for contrastive training or SetFitModel's head in logistic regression. - If ignore, will use the value given in initialization. - body_learning_rate (`float`, *optional*): - Temporary change the learning rate to use for SetFitModel's body in logistic regression only. - If ignore, will be the same as `learning_rate`. - l2_weight (`float`, *optional*): - Temporary change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression. - max_length (int, *optional*, defaults to `None`): - The maximum number of tokens for one data sample. Currently only for training the differentiable head. - If `None`, will use the maximum number of tokens the model body can accept. - If `max_length` is greater than the maximum number of acceptable tokens the model body can accept, it will be set to the maximum number of acceptable tokens. + args (`TrainingArguments`, *optional*): + Temporarily change the training arguments for this training call. trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): The trial run or the hyperparameter dictionary for hyperparameter search. - show_progress_bar (`bool`, *optional*, defaults to `True`): - Whether to show a bar that indicates training progress. """ - set_seed(self.seed) # Seed must be set before instantiating the model when using model_init. + if len(kwargs): + warnings.warn( + f"`{self.__class__.__name__}.train` does not accept keyword arguments anymore. " + f"Please provide training arguments via a `TrainingArguments` instance to the `{self.__class__.__name__}` " + f"initialisation or the `{self.__class__.__name__}.train` method.", + DeprecationWarning, + stacklevel=2, + ) + + args = args or self.args or TrainingArguments() + + # Seed must be set before instantiating the model when using model_init. 
+ set_seed(args.seed) if trial: # Trial and model initialization self._hp_search_setup(trial) # sets trainer parameters and initializes model if self.train_dataset is None: - raise ValueError("Training requires a `train_dataset` given to the `SetFitTrainer` initialization.") + raise ValueError( + f"Training requires a `train_dataset` given to the `{self.__class__.__name__}` initialization." + ) - self._validate_column_mapping(self.train_dataset) - train_dataset = self.train_dataset - if self.column_mapping is not None: - logger.info("Applying column mapping to training dataset") - train_dataset = self._apply_column_mapping(self.train_dataset, self.column_mapping) - - x_train = train_dataset["text"] - y_train = train_dataset["label"] - if self.loss_class is None: - logger.warning("No `loss_class` detected! Using `CosineSimilarityLoss` as the default.") - self.loss_class = losses.CosineSimilarityLoss - - num_epochs = num_epochs or self.num_epochs - batch_size = batch_size or self.batch_size - learning_rate = learning_rate or self.learning_rate - - if not self.model.has_differentiable_head or self._freeze: - # sentence-transformers adaptation - if self.loss_class in [ - losses.BatchAllTripletLoss, - losses.BatchHardTripletLoss, - losses.BatchSemiHardTripletLoss, - losses.BatchHardSoftMarginTripletLoss, - SupConLoss, - ]: - train_examples = [InputExample(texts=[text], label=label) for text, label in zip(x_train, y_train)] - train_data_sampler = SentenceLabelDataset(train_examples, samples_per_label=self.samples_per_label) - - batch_size = min(batch_size, len(train_data_sampler)) - train_dataloader = DataLoader(train_data_sampler, batch_size=batch_size, drop_last=True) - - if self.loss_class is losses.BatchHardSoftMarginTripletLoss: - train_loss = self.loss_class( - model=self.model.model_body, - distance_metric=self.distance_metric, - ) - elif self.loss_class is SupConLoss: - train_loss = self.loss_class(model=self.model.model_body) + parameters = [] + for dataset, dataset_name in [(self.train_dataset, "training"), (self.eval_dataset, "evaluation")]: + if dataset is None: + continue + + self._validate_column_mapping(dataset) + if self.column_mapping is not None: + logger.info(f"Applying column mapping to {dataset_name} dataset") + dataset = self._apply_column_mapping(dataset, self.column_mapping) + + parameters.extend(self.dataset_to_parameters(dataset)) + + self.train_embeddings(*parameters, args=args) + training_parameters = parameters[: len(parameters) // 2] if self.eval_dataset else parameters + self.train_classifier(*training_parameters, args=args) + + def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]: + return [dataset["text"], dataset["label"]] + + def train_embeddings( + self, + x_train: List[str], + y_train: Optional[Union[List[int], List[List[int]]]] = None, + x_eval: Optional[List[str]] = None, + y_eval: Optional[Union[List[int], List[List[int]]]] = None, + args: Optional[TrainingArguments] = None, + ) -> None: + """ + Method to perform the embedding phase: finetuning the `SentenceTransformer` body. + + Args: + x_train (`List[str]`): A list of training sentences. + y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences. + args (`TrainingArguments`, *optional*): + Temporarily change the training arguments for this training call. 
+ """ + args = args or self.args or TrainingArguments() + # Since transformers v4.32.0, the log/eval/save steps should be saved on the state instead + self.state.logging_steps = args.logging_steps + self.state.eval_steps = args.eval_steps + self.state.save_steps = args.save_steps + + train_dataloader, loss_func, batch_size = self.get_dataloader(x_train, y_train, args=args) + if x_eval is not None: + eval_dataloader, _, _ = self.get_dataloader(x_eval, y_eval, args=args) + else: + eval_dataloader = None + + total_train_steps = len(train_dataloader) * args.embedding_num_epochs + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataloader)}") + logger.info(f" Num epochs = {args.embedding_num_epochs}") + logger.info(f" Total optimization steps = {total_train_steps}") + logger.info(f" Total train batch size = {batch_size}") + + warmup_steps = math.ceil(total_train_steps * args.warmup_proportion) + self._train_sentence_transformer( + self.model.model_body, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + args=args, + loss_func=loss_func, + warmup_steps=warmup_steps, + ) + + def get_dataloader( + self, x: List[str], y: Union[List[int], List[List[int]]], args: TrainingArguments + ) -> Tuple[DataLoader, nn.Module, int]: + # sentence-transformers adaptation + input_data = [InputExample(texts=[text], label=label) for text, label in zip(x, y)] + + if args.loss in [ + losses.BatchAllTripletLoss, + losses.BatchHardTripletLoss, + losses.BatchSemiHardTripletLoss, + losses.BatchHardSoftMarginTripletLoss, + SupConLoss, + ]: + data_sampler = SentenceLabelDataset(input_data, samples_per_label=args.samples_per_label) + batch_size = min(args.embedding_batch_size, len(data_sampler)) + dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=True) + + if args.loss is losses.BatchHardSoftMarginTripletLoss: + loss = args.loss( + model=self.model.model_body, + distance_metric=args.distance_metric, + ) + elif args.loss is SupConLoss: + loss = args.loss(model=self.model.model_body) + else: + loss = args.loss( + model=self.model.model_body, + distance_metric=args.distance_metric, + margin=args.margin, + ) + else: + data_sampler = ContrastiveDataset( + input_data, self.model.multi_target_strategy, args.num_iterations, args.sampling_strategy + ) + # shuffle_sampler = True can be dropped in for further 'randomising' + shuffle_sampler = True if args.sampling_strategy == "unique" else False + batch_size = min(args.embedding_batch_size, len(data_sampler)) + dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False) + loss = args.loss(self.model.model_body) + + return dataloader, loss, batch_size + + def log(self, args: TrainingArguments, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + logs = {self.logs_mapper.get(key, key): value for key, value in logs.items()} + if self.state.epoch is not None: + logs["epoch"] = round(self.state.epoch, 2) + + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) + return self.callback_handler.on_log(args, self.state, self.control, logs) + + def _set_logs_mapper(self, logs_mapper: Dict[str, str]) -> None: + """Set the logging mapper. + + Args: + logs_mapper (str): The logging mapper, e.g. {"eval_embedding_loss": "eval_aspect_embedding_loss"}. 
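As a usage note for `get_dataloader` above: the branch taken is decided entirely by `args.loss`, so both samplers are reachable purely through `TrainingArguments`. A sketch of the relevant settings (constructing the arguments only, names as introduced in this patch):

```python
from sentence_transformers import losses
from setfit import TrainingArguments

# Triplet-style losses route through SentenceLabelDataset: batches are drawn as
# `samples_per_label` examples per class, so batch size should be a multiple of it.
triplet_args = TrainingArguments(loss=losses.BatchHardTripletLoss, samples_per_label=2, margin=0.25)

# All other losses route through ContrastiveDataset, controlled either by a sampling
# strategy or, for pre-1.0 compatibility, by a fixed number of pair-generation iterations.
pair_args = TrainingArguments(loss=losses.CosineSimilarityLoss, sampling_strategy="undersampling")
legacy_args = TrainingArguments(loss=losses.CosineSimilarityLoss, num_iterations=20)
```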
+ """ + self.logs_mapper = logs_mapper + + def _train_sentence_transformer( + self, + model_body: SentenceTransformer, + train_dataloader: DataLoader, + eval_dataloader: Optional[DataLoader], + args: TrainingArguments, + loss_func: nn.Module, + warmup_steps: int = 10000, + ) -> None: + """ + Train the model with the given training objective + Each training objective is sampled in turn for one batch. + We sample only as many batches from each objective as there are in the smallest one + to make sure of equal training with each dataset. + """ + # TODO: args.gradient_accumulation_steps + # TODO: fp16/bf16, etc. + # TODO: Safetensors + + # Hardcoded training arguments + max_grad_norm = 1 + weight_decay = 0.01 + + self.state.epoch = 0 + start_time = time.time() + if args.max_steps > 0: + self.state.max_steps = args.max_steps + else: + self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + + model_body.to(model_body._target_device) + loss_func.to(model_body._target_device) + + # Use smart batching + train_dataloader.collate_fn = model_body.smart_batching_collate + if eval_dataloader: + eval_dataloader.collate_fn = model_body.smart_batching_collate + + steps_per_epoch = len(train_dataloader) + num_train_steps = int(steps_per_epoch * args.embedding_num_epochs) + + # Prepare optimizers + param_optimizer = list(loss_func.named_parameters()) + + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + "weight_decay": weight_decay, + }, + {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] + + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **{"lr": args.body_embedding_learning_rate}) + scheduler_obj = model_body._get_scheduler( + optimizer, scheduler="WarmupLinear", warmup_steps=warmup_steps, t_total=num_train_steps + ) + + data_iterator = iter(train_dataloader) + skip_scheduler = False + for epoch in range(args.embedding_num_epochs): + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + loss_func.zero_grad() + loss_func.train() + + for step in range(steps_per_epoch): + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + try: + data = next(data_iterator) + except StopIteration: + data_iterator = iter(train_dataloader) + data = next(data_iterator) + + features, labels = data + labels = labels.to(model_body._target_device) + features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features)) + + if args.use_amp: + with autocast(): + loss_value = loss_func(features, labels) + + scale_before_step = scaler.get_scale() + scaler.scale(loss_value).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm) + scaler.step(optimizer) + scaler.update() + + skip_scheduler = scaler.get_scale() != scale_before_step else: - train_loss = self.loss_class( - model=self.model.model_body, - distance_metric=self.distance_metric, - margin=self.margin, + loss_value = loss_func(features, labels) + loss_value.backward() + torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm) + optimizer.step() + + optimizer.zero_grad() + + if not skip_scheduler: + scheduler_obj.step() + + self.state.global_step += 1 + 
self.state.epoch = epoch + (step + 1) / steps_per_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + if self.control.should_log: + learning_rate = scheduler_obj.get_last_lr()[0] + metrics = {"embedding_loss": round(loss_value.item(), 4), "learning_rate": learning_rate} + self.control = self.log(args, metrics) + + eval_loss = None + if self.control.should_evaluate and eval_dataloader is not None: + eval_loss = self._evaluate_with_loss(model_body, eval_dataloader, args, loss_func) + learning_rate = scheduler_obj.get_last_lr()[0] + metrics = {"eval_embedding_loss": round(eval_loss, 4), "learning_rate": learning_rate} + self.control = self.log(args, metrics) + + self.control = self.callback_handler.on_evaluate(args, self.state, self.control, metrics) + + loss_func.zero_grad() + loss_func.train() + + if self.control.should_save: + checkpoint_dir = self._checkpoint( + self.args.output_dir, args.save_total_limit, self.state.global_step ) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + if eval_loss is not None and ( + self.state.best_metric is None or eval_loss < self.state.best_metric + ): + self.state.best_metric = eval_loss + self.state.best_model_checkpoint = checkpoint_dir + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + + if self.control.should_training_stop: + break + + if self.args.load_best_model_at_end and self.state.best_model_checkpoint: + dir_name = Path(self.state.best_model_checkpoint).name + if dir_name.startswith("step_"): + logger.info(f"Loading best SentenceTransformer model from step {dir_name[5:]}.") + self.model.model_body = SentenceTransformer(self.state.best_model_checkpoint, device=model_body.device) + + # Ensure logging the speed metrics + num_train_samples = self.state.max_steps * args.embedding_batch_size # * args.gradient_accumulation_steps + metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + self.control.should_log = True + self.log(args, metrics) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + def _evaluate_with_loss( + self, + model_body: SentenceTransformer, + eval_dataloader: DataLoader, + args: TrainingArguments, + loss_func: nn.Module, + ) -> float: + model_body.eval() + losses = [] + for data in tqdm( + iter(eval_dataloader), total=len(eval_dataloader), leave=False, disable=not args.show_progress_bar + ): + features, labels = data + labels = labels.to(model_body._target_device) + features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features)) + + if args.use_amp: + with autocast(): + loss_value = loss_func(features, labels) + + losses.append(loss_value.item()) else: - train_examples = [] - - for _ in trange(self.num_iterations, desc="Generating Training Pairs", disable=not show_progress_bar): - if self.model.multi_target_strategy is not None: - train_examples = sentence_pairs_generation_multilabel( - np.array(x_train), np.array(y_train), train_examples - ) - else: - train_examples = sentence_pairs_generation( - np.array(x_train), np.array(y_train), train_examples - ) - - train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size) - train_loss = self.loss_class(self.model.model_body) - - total_train_steps = len(train_dataloader) * num_epochs - logger.info("***** Running training *****") - logger.info(f" Num 
examples = {len(train_examples)}")
-            logger.info(f"  Num epochs = {num_epochs}")
-            logger.info(f"  Total optimization steps = {total_train_steps}")
-            logger.info(f"  Total train batch size = {batch_size}")
-
-            warmup_steps = math.ceil(total_train_steps * self.warmup_proportion)
-            self.model.model_body.fit(
-                train_objectives=[(train_dataloader, train_loss)],
-                epochs=num_epochs,
-                optimizer_params={"lr": learning_rate},
-                warmup_steps=warmup_steps,
-                show_progress_bar=show_progress_bar,
-                use_amp=self.use_amp,
-            )
+            losses.append(loss_func(features, labels).item())
+
+        model_body.train()
+        return sum(losses) / len(losses)
+
+    def _checkpoint(self, checkpoint_path: str, checkpoint_save_total_limit: int, step: int) -> str:
+        # Delete old checkpoints
+        if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0:
+            old_checkpoints = []
+            for subdir in Path(checkpoint_path).glob("step_*"):
+                if subdir.name[5:].isdigit() and (
+                    self.state.best_model_checkpoint is None or subdir != Path(self.state.best_model_checkpoint)
+                ):
+                    old_checkpoints.append({"step": int(subdir.name[5:]), "path": str(subdir)})
+
+            if len(old_checkpoints) > checkpoint_save_total_limit - 1:
+                old_checkpoints = sorted(old_checkpoints, key=lambda x: x["step"])
+                shutil.rmtree(old_checkpoints[0]["path"])
+
+        checkpoint_file_path = str(Path(checkpoint_path) / f"step_{step}")
+        self.model.save_pretrained(checkpoint_file_path)
+        return checkpoint_file_path
+
+    def train_classifier(
+        self, x_train: List[str], y_train: Union[List[int], List[List[int]]], args: Optional[TrainingArguments] = None
+    ) -> None:
+        """
+        Method to perform the classifier phase: fitting a classifier head.

-        if not self.model.has_differentiable_head or not self._freeze:
-            # Train the final classifier
-            self.model.fit(
-                x_train,
-                y_train,
-                num_epochs=num_epochs,
-                batch_size=batch_size,
-                learning_rate=learning_rate,
-                body_learning_rate=body_learning_rate,
-                l2_weight=l2_weight,
-                max_length=max_length,
-                show_progress_bar=True,
-            )
+        Args:
+            x_train (`List[str]`): A list of training sentences.
+            y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
+            args (`TrainingArguments`, *optional*):
+                Temporarily change the training arguments for this training call.
+        """
+        args = args or self.args or TrainingArguments()
+
+        self.model.fit(
+            x_train,
+            y_train,
+            num_epochs=args.classifier_num_epochs,
+            batch_size=args.classifier_batch_size,
+            body_learning_rate=args.body_classifier_learning_rate,
+            head_learning_rate=args.head_learning_rate,
+            l2_weight=args.l2_weight,
+            max_length=args.max_length,
+            show_progress_bar=args.show_progress_bar,
+            end_to_end=args.end_to_end,
+        )

     def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
         """
@@ -421,13 +747,16 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:

         Args:
             dataset (`Dataset`, *optional*):
-                The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed in the eval_dataset argument at `SetFitTrainer` initialization.
+                The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via
+                the `eval_dataset` argument at `Trainer` initialization.

         Returns:
             `Dict[str, float]`: The evaluation metrics.
         """
         eval_dataset = dataset or self.eval_dataset
+        if eval_dataset is None:
+            raise ValueError("No evaluation dataset provided to `Trainer.evaluate`, nor to the `Trainer` initialization.")

         self._validate_column_mapping(eval_dataset)
         if self.column_mapping is not None:
@@ -442,6 +771,13 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
         if isinstance(y_pred, torch.Tensor):
             y_pred = y_pred.cpu()

+        # Normalize string outputs
+        if y_test and isinstance(y_test[0], str):
+            encoder = LabelEncoder()
+            encoder.fit(list(y_test) + list(y_pred))
+            y_test = encoder.transform(y_test)
+            y_pred = encoder.transform(y_pred)
+
         if isinstance(self.metric, str):
             metric_config = "multilabel" if self.model.multi_target_strategy is not None else None
             metric_fn = evaluate.load(self.metric, config_name=metric_config)
@@ -472,7 +808,7 @@ def hyperparameter_search(

-    To use this method, you need to have provided a `model_init` when initializing your [`SetFitTrainer`]: we need to
+    To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
     reinitialize the model at each new run.

@@ -507,7 +843,7 @@ def hyperparameter_search(
         if backend is None:
             backend = default_hp_search_backend()
         if backend is None:
-            raise RuntimeError("optuna should be installed. " "To install optuna run `pip install optuna`. ")
+            raise RuntimeError("optuna should be installed. To install optuna run `pip install optuna`.")
         backend = HPSearchBackend(backend)
         if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
             raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
@@ -539,7 +875,7 @@ def push_to_hub(self, repo_id: str, **kwargs) -> str:

         Args:
             repo_id (`str`):
-                The full repository ID to push to, e.g. `"tomaarsen/setfit_sst2"`.
+                The full repository ID to push to, e.g. `"tomaarsen/setfit-sst2"`.
             config (`dict`, *optional*):
                 Configuration object to be saved alongside the model weights.
             commit_message (`str`, *optional*):
@@ -569,7 +905,66 @@ def push_to_hub(self, repo_id: str, **kwargs) -> str:
         """
         if "/" not in repo_id:
             raise ValueError(
-                '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit_sst2".'
+                '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-sst2".'
             )
         commit_message = kwargs.pop("commit_message", "Add SetFit model")
         return self.model.push_to_hub(repo_id, commit_message=commit_message, **kwargs)
+
+
+class SetFitTrainer(Trainer):
+    """
+    `SetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.
+    Please use `Trainer` instead.
+    """
+
+    def __init__(
+        self,
+        model: Optional["SetFitModel"] = None,
+        train_dataset: Optional["Dataset"] = None,
+        eval_dataset: Optional["Dataset"] = None,
+        model_init: Optional[Callable[[], "SetFitModel"]] = None,
+        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+        metric_kwargs: Optional[Dict[str, Any]] = None,
+        loss_class=losses.CosineSimilarityLoss,
+        num_iterations: int = 20,
+        num_epochs: int = 1,
+        learning_rate: float = 2e-5,
+        batch_size: int = 16,
+        seed: int = 42,
+        column_mapping: Optional[Dict[str, str]] = None,
+        use_amp: bool = False,
+        warmup_proportion: float = 0.1,
+        distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance,
+        margin: float = 0.25,
+        samples_per_label: int = 2,
+    ):
+        warnings.warn(
+            "`SetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. 
" + "Please use `Trainer` instead.", + DeprecationWarning, + stacklevel=2, + ) + args = TrainingArguments( + num_iterations=num_iterations, + num_epochs=num_epochs, + body_learning_rate=learning_rate, + head_learning_rate=learning_rate, + batch_size=batch_size, + seed=seed, + use_amp=use_amp, + warmup_proportion=warmup_proportion, + distance_metric=distance_metric, + margin=margin, + samples_per_label=samples_per_label, + loss=loss_class, + ) + super().__init__( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + model_init=model_init, + metric=metric, + metric_kwargs=metric_kwargs, + column_mapping=column_mapping, + ) diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py index ca194066..da5ec4b8 100644 --- a/src/setfit/trainer_distillation.py +++ b/src/setfit/trainer_distillation.py @@ -1,245 +1,162 @@ -import math -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +import warnings +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union -import numpy as np import torch +from datasets import Dataset from sentence_transformers import InputExample, losses, util -from sentence_transformers.datasets import SentenceLabelDataset -from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction +from torch import nn from torch.utils.data import DataLoader -from transformers.trainer_utils import set_seed -from . import SetFitTrainer, logging -from .modeling import SupConLoss, sentence_pairs_generation_cos_sim +from . import logging +from .sampler import ContrastiveDistillationDataset +from .trainer import Trainer +from .training_args import TrainingArguments if TYPE_CHECKING: - import optuna - from datasets import Dataset - from .modeling import SetFitModel logging.set_verbosity_info() logger = logging.get_logger(__name__) -class DistillationSetFitTrainer(SetFitTrainer): +class DistillationTrainer(Trainer): """Trainer to compress a SetFit model with knowledge distillation. Args: teacher_model (`SetFitModel`): The teacher model to mimic. + student_model (`SetFitModel`, *optional*): + The model to train. If not provided, a `model_init` must be passed. + args (`TrainingArguments`, *optional*): + The training arguments to use. train_dataset (`Dataset`): The training dataset. - student_model (`SetFitModel`): - The student model to train. If not provided, a `model_init` must be passed. eval_dataset (`Dataset`, *optional*): The evaluation dataset. model_init (`Callable[[], SetFitModel]`, *optional*): - A function that instantiates the model to be used. If provided, each call to [`~DistillationSetFitTrainer.train`] will start - from a new instance of the model as given by this function when a `trial` is passed. + A function that instantiates the model to be used. If provided, each call to + [`~DistillationTrainer.train`] will start from a new instance of the model as given by this + function when a `trial` is passed. metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): - The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings. + The metric to use for evaluation. If a string is provided, we treat it as the metric + name and load it with default settings. If a callable is provided, it must take two arguments (`y_pred`, `y_test`). - loss_class (`nn.Module`, *optional*, defaults to `CosineSimilarityLoss`): - The loss function to use for contrastive training. 
- num_iterations (`int`, *optional*, defaults to `20`): - The number of iterations to generate sentence pairs for. - num_epochs (`int`, *optional*, defaults to `1`): - The number of epochs to train the Sentence Transformer body for. - learning_rate (`float`, *optional*, defaults to `2e-5`): - The learning rate to use for contrastive training. - batch_size (`int`, *optional*, defaults to `16`): - The batch size to use for contrastive training. - seed (`int`, *optional*, defaults to 42): - Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the - [`~SetTrainer.model_init`] function to instantiate the model if it has some randomly initialized parameters. column_mapping (`Dict[str, str]`, *optional*): - A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: {"text_column_name": "text", "label_column_name: "label"}. - use_amp (`bool`, *optional*, defaults to `False`): - Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0 - warmup_proportion (`float`, *optional*, defaults to `0.1`): - Proportion of the warmup in the total training steps. - Must be greater than or equal to 0.0 and less than or equal to 1.0. + A mapping from the column names in the dataset to the column names expected by the model. + The expected format is a dictionary with the following format: + `{"text_column_name": "text", "label_column_name: "label"}`. """ + _REQUIRED_COLUMNS = {"text"} + def __init__( self, teacher_model: "SetFitModel", student_model: Optional["SetFitModel"] = None, + args: TrainingArguments = None, train_dataset: Optional["Dataset"] = None, eval_dataset: Optional["Dataset"] = None, model_init: Optional[Callable[[], "SetFitModel"]] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", - loss_class: torch.nn.Module = losses.CosineSimilarityLoss, - num_iterations: int = 20, - num_epochs: int = 1, - learning_rate: float = 2e-5, - batch_size: int = 16, - seed: int = 42, column_mapping: Optional[Dict[str, str]] = None, - use_amp: bool = False, - warmup_proportion: float = 0.1, ) -> None: - super(DistillationSetFitTrainer, self).__init__( + super().__init__( model=student_model, + args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, model_init=model_init, metric=metric, - loss_class=loss_class, - num_iterations=num_iterations, - num_epochs=num_epochs, - learning_rate=learning_rate, - batch_size=batch_size, - seed=seed, column_mapping=column_mapping, - use_amp=use_amp, - warmup_proportion=warmup_proportion, ) self.teacher_model = teacher_model self.student_model = self.model - def train( - self, - num_epochs: Optional[int] = None, - batch_size: Optional[int] = None, - learning_rate: Optional[float] = None, - body_learning_rate: Optional[float] = None, - l2_weight: Optional[float] = None, - trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, - show_progress_bar: bool = True, - ): + def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]: + return [dataset["text"]] + + def get_dataloader( + self, x: List[str], y: Optional[Union[List[int], List[List[int]]]], args: TrainingArguments + ) -> Tuple[DataLoader, nn.Module, int]: + x_embd_student = self.teacher_model.model_body.encode( + x, convert_to_tensor=self.teacher_model.has_differentiable_head + ) + cos_sim_matrix = util.cos_sim(x_embd_student, x_embd_student) + + input_data = [InputExample(texts=[text]) for text in x] + data_sampler = 
ContrastiveDistillationDataset( + input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy + ) + # shuffle_sampler = True can be dropped in for further 'randomising' + shuffle_sampler = True if args.sampling_strategy == "unique" else False + batch_size = min(args.embedding_batch_size, len(data_sampler)) + dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False) + loss = args.loss(self.model.model_body) + return dataloader, loss, batch_size + + def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None: """ - Main training entry point. + Method to perform the classifier phase: fitting the student classifier head. Args: - num_epochs (`int`, *optional*): - Temporary change the number of epochs to train the Sentence Transformer body/head for. - If ignore, will use the value given in initialization. - batch_size (`int`, *optional*): - Temporary change the batch size to use for contrastive training or logistic regression. - If ignore, will use the value given in initialization. - learning_rate (`float`, *optional*): - Temporary change the learning rate to use for contrastive training or SetFitModel's head in logistic regression. - If ignore, will use the value given in initialization. - body_learning_rate (`float`, *optional*): - Temporary change the learning rate to use for SetFitModel's body in logistic regression only. - If ignore, will be the same as `learning_rate`. - l2_weight (`float`, *optional*): - Temporary change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression. - trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): - The trial run or the hyperparameter dictionary for hyperparameter search. - show_progress_bar (`bool`, *optional*, defaults to `True`): - Whether to show a bar that indicates training progress. + x_train (`List[str]`): A list of training sentences. + args (`TrainingArguments`, *optional*): + Temporarily change the training arguments for this training call. """ - set_seed(self.seed) # Seed must be set before instantiating the model when using model_init. - - if trial: # Trial and model initialization - self._hp_search_setup(trial) # sets trainer parameters and initializes model - - if self.train_dataset is None: - raise ValueError( - "Training requires a `train_dataset` given to the `DistillationSetFitTrainer` initialization." - ) - - self._validate_column_mapping(self.train_dataset) - train_dataset = self.train_dataset - if self.column_mapping is not None: - logger.info("Applying column mapping to training dataset") - train_dataset = self._apply_column_mapping(self.train_dataset, self.column_mapping) - - x_train = train_dataset["text"] - y_train = train_dataset["label"] - if self.loss_class is None: - logger.warning("No `loss_class` detected! 
Using `CosineSimilarityLoss` as the default.") - self.loss_class = losses.CosineSimilarityLoss - - num_epochs = num_epochs or self.num_epochs - batch_size = batch_size or self.batch_size - learning_rate = learning_rate or self.learning_rate - - if not self.student_model.has_differentiable_head or self._freeze: - # sentence-transformers adaptation - if self.loss_class in [ - losses.BatchAllTripletLoss, - losses.BatchHardTripletLoss, - losses.BatchSemiHardTripletLoss, - losses.BatchHardSoftMarginTripletLoss, - SupConLoss, - ]: - train_examples = [InputExample(texts=[text], label=label) for text, label in zip(x_train, y_train)] - train_data_sampler = SentenceLabelDataset(train_examples) - - batch_size = min(batch_size, len(train_data_sampler)) - train_dataloader = DataLoader(train_data_sampler, batch_size=batch_size, drop_last=True) - - if self.loss_class is losses.BatchHardSoftMarginTripletLoss: - train_loss = self.loss_class( - model=self.student_model, - distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance, - ) - elif self.loss_class is SupConLoss: - train_loss = self.loss_class(model=self.student_model) - else: - train_loss = self.loss_class( - model=self.student_model, - distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance, - margin=0.25, - ) - else: - train_examples = [] - - # **************** student training **************** - x_train_embd_student = self.teacher_model.model_body.encode( - x_train, convert_to_tensor=self.teacher_model.has_differentiable_head - ) - y_train = self.teacher_model.model_head.predict(x_train_embd_student) - if not self.teacher_model.has_differentiable_head and self.student_model.has_differentiable_head: - y_train = torch.from_numpy(y_train) - elif self.teacher_model.has_differentiable_head and not self.student_model.has_differentiable_head: - y_train = y_train.detach().cpu().numpy() - - cos_sim_matrix = util.cos_sim(x_train_embd_student, x_train_embd_student) - - train_examples = [] - for _ in range(self.num_iterations): - train_examples = sentence_pairs_generation_cos_sim( - np.array(x_train), train_examples, cos_sim_matrix - ) - - # **************** student training END **************** - - train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size) - train_loss = self.loss_class(self.student_model.model_body) - - total_train_steps = len(train_dataloader) * num_epochs - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_examples)}") - logger.info(f" Num epochs = {num_epochs}") - logger.info(f" Total optimization steps = {total_train_steps}") - logger.info(f" Total train batch size = {batch_size}") - - warmup_steps = math.ceil(total_train_steps * self.warmup_proportion) - self.student_model.model_body.fit( - train_objectives=[(train_dataloader, train_loss)], - epochs=num_epochs, - optimizer_params={"lr": learning_rate}, - warmup_steps=warmup_steps, - show_progress_bar=show_progress_bar, - use_amp=self.use_amp, - ) - - if not self.student_model.has_differentiable_head or not self._freeze: - # Train the final classifier - self.student_model.fit( - x_train, - y_train, - num_epochs=num_epochs, - batch_size=batch_size, - learning_rate=learning_rate, - body_learning_rate=body_learning_rate, - l2_weight=l2_weight, - show_progress_bar=show_progress_bar, - ) + y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head) + return super().train_classifier(x_train, y_train, args) + + +class DistillationSetFitTrainer(DistillationTrainer): 
+    """
+    `DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.
+    Please use `DistillationTrainer` instead.
+    """
+
+    def __init__(
+        self,
+        teacher_model: "SetFitModel",
+        student_model: Optional["SetFitModel"] = None,
+        train_dataset: Optional["Dataset"] = None,
+        eval_dataset: Optional["Dataset"] = None,
+        model_init: Optional[Callable[[], "SetFitModel"]] = None,
+        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+        loss_class: torch.nn.Module = losses.CosineSimilarityLoss,
+        num_iterations: int = 20,
+        num_epochs: int = 1,
+        learning_rate: float = 2e-5,
+        batch_size: int = 16,
+        seed: int = 42,
+        column_mapping: Optional[Dict[str, str]] = None,
+        use_amp: bool = False,
+        warmup_proportion: float = 0.1,
+    ) -> None:
+        warnings.warn(
+            "`DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
+            "Please use `DistillationTrainer` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        args = TrainingArguments(
+            num_iterations=num_iterations,
+            num_epochs=num_epochs,
+            body_learning_rate=learning_rate,
+            head_learning_rate=learning_rate,
+            batch_size=batch_size,
+            seed=seed,
+            use_amp=use_amp,
+            warmup_proportion=warmup_proportion,
+            loss=loss_class,
+        )
+        super().__init__(
+            teacher_model=teacher_model,
+            student_model=student_model,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            model_init=model_init,
+            metric=metric,
+            column_mapping=column_mapping,
+        )
diff --git a/src/setfit/training_args.py b/src/setfit/training_args.py
new file mode 100644
index 00000000..9ed24fb7
--- /dev/null
+++ b/src/setfit/training_args.py
@@ -0,0 +1,344 @@
+from __future__ import annotations
+
+import inspect
+import json
+from copy import copy
+from dataclasses import dataclass, field, fields
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from sentence_transformers import losses
+from transformers import IntervalStrategy
+from transformers.integrations import get_available_reporting_integrations
+from transformers.training_args import default_logdir
+from transformers.utils import is_torch_available
+
+from . import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TrainingArguments:
+    """
+    TrainingArguments is the subset of the arguments which relate to the training loop itself.
+
+    Parameters:
+        output_dir (`str`, defaults to `"checkpoints"`):
+            The output directory where the model predictions and checkpoints will be written.
+        batch_size (`Union[int, Tuple[int, int]]`, defaults to `(16, 2)`):
+            Set the batch sizes for the embedding and classifier training phases respectively,
+            or set both if an integer is provided.
+            Note that the batch size for the classifier is only used with a differentiable PyTorch head.
+        num_epochs (`Union[int, Tuple[int, int]]`, defaults to `(1, 16)`):
+            Set the number of epochs for the embedding and classifier training phases respectively,
+            or set both if an integer is provided.
+            Note that the number of epochs for the classifier is only used with a differentiable PyTorch head.
+        max_steps (`int`, *optional*, defaults to `-1`):
+            If set to a positive number, the total number of training steps to perform. Overrides `num_epochs`.
+            The training may stop before reaching the set number of steps when all data is exhausted.
+        sampling_strategy (`str`, defaults to `"oversampling"`):
+            The sampling strategy of how to draw pairs in training. Possible values are:
+
+            - `"oversampling"`: Draws an even number of positive/negative sentence pairs until every
+              sentence pair has been drawn.
+            - `"undersampling"`: Draws the minimum number of positive/negative sentence pairs until
+              every sentence pair in the minority class has been drawn.
+            - `"unique"`: Draws every sentence pair combination (likely resulting in an unbalanced
+              number of positive/negative sentence pairs).
+
+            The default is set to `"oversampling"`, ensuring all sentence pairs are drawn at least once.
+            Alternatively, setting `num_iterations` will override this argument and determine the number
+            of generated sentence pairs.
+        num_iterations (`int`, *optional*):
+            If not set, the `sampling_strategy` will determine the number of sentence pairs to generate.
+            This argument sets the number of iterations to generate sentence pairs for
+            and provides compatibility with SetFit <v1.0.0.
+        body_learning_rate (`Union[float, Tuple[float, float]]`, defaults to `(2e-5, 1e-5)`):
+            Set the learning rate for the `SentenceTransformer` body for the embedding and classifier
+            training phases respectively, or set both if a float is provided.
+            Note that the body learning rate for the classifier is only used with a differentiable PyTorch
+            head *and* if `end_to_end=True`.
+        head_learning_rate (`float`, defaults to `1e-2`):
+            Set the learning rate for the head for the classifier training phase.
+            Only used with a differentiable PyTorch head.
+        loss (`nn.Module`, defaults to `CosineSimilarityLoss`):
+            The loss function to use for contrastive training in the embedding training phase.
+        distance_metric (`Callable`, defaults to `BatchHardTripletLossDistanceFunction.cosine_distance`):
+            Function that returns a distance between two embeddings.
+            It is set for the triplet loss and ignored for `CosineSimilarityLoss` and `SupConLoss`.
+        margin (`float`, defaults to `0.25`):
+            Margin for the triplet loss. Negative samples should be at least margin further apart
+            from the anchor than the positive. It is ignored for `CosineSimilarityLoss`,
+            `BatchHardSoftMarginTripletLoss` and `SupConLoss`.
+        end_to_end (`bool`, defaults to `False`):
+            If `True`, train the entire model end-to-end during the classifier training phase.
+            Otherwise, freeze the `SentenceTransformer` body and only train the head.
+            Only used with a differentiable PyTorch head.
+        use_amp (`bool`, defaults to `False`):
+            Whether to use Automatic Mixed Precision (AMP) during the embedding training phase.
+            Only for Pytorch >= 1.6.0
+        warmup_proportion (`float`, defaults to `0.1`):
+            Proportion of the warmup in the total training steps.
+            Must be greater than or equal to 0.0 and less than or equal to 1.0.
+        l2_weight (`float`, *optional*):
+            Optional l2 weight for both the model body and head, passed to the `AdamW` optimizer in the
+            classifier training phase if a differentiable PyTorch head is used.
+        max_length (`int`, *optional*):
+            The maximum token length a tokenizer can generate. If not provided, the maximum length for
+            the `SentenceTransformer` body is used.
+        samples_per_label (`int`, defaults to `2`): Number of consecutive, random and unique samples drawn per label.
+            This is only relevant for triplet loss and ignored for `CosineSimilarityLoss`.
+            Batch size should be a multiple of `samples_per_label`.
+        show_progress_bar (`bool`, defaults to `True`):
+            Whether to display a progress bar for the training epochs and iterations.
+        seed (`int`, defaults to `42`):
+            Random seed that will be set at the beginning of training. To ensure reproducibility across
+            runs, use the [`~Trainer.model_init`] function to instantiate the model if it has some
+            randomly initialized parameters.
+        report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
+            The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
+            `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. Use `"all"` to report to
+            all integrations installed, `"none"` for no integrations.
+        run_name (`str`, *optional*):
+            A descriptor for the run. Typically used for [wandb](https://www.wandb.com/) and
+            [mlflow](https://www.mlflow.org/) logging.
+        logging_dir (`str`, *optional*):
+            [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+            *runs/**CURRENT_DATETIME_HOSTNAME***.
+        logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+            The logging strategy to adopt during training. Possible values are:
+
+                - `"no"`: No logging is done during training.
+                - `"epoch"`: Logging is done at the end of each epoch.
+                - `"steps"`: Logging is done every `logging_steps`.
+
+        logging_first_step (`bool`, *optional*, defaults to `True`):
+            Whether to log and evaluate the first `global_step` or not.
+        logging_steps (`int`, *optional*, defaults to 50):
+            Number of update steps between two logs if `logging_strategy="steps"`.
+        evaluation_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+            The evaluation strategy to adopt during training. Possible values are:
+
+                - `"no"`: No evaluation is done during training.
+                - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
+                - `"epoch"`: Evaluation is done at the end of each epoch.
+
+        eval_steps (`int`, *optional*):
+            Number of update steps between two evaluations if `evaluation_strategy="steps"`. Will default to the same
+            value as `logging_steps` if not set.
+        eval_delay (`float`, *optional*):
+            Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
+            `evaluation_strategy`.
+
+        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+            The checkpoint save strategy to adopt during training. Possible values are:
+
+                - `"no"`: No save is done during training.
+                - `"epoch"`: Save is done at the end of each epoch.
+                - `"steps"`: Save is done every `save_steps`.
+        save_steps (`int`, *optional*, defaults to 500):
+            Number of update steps between two checkpoint saves if `save_strategy="steps"`.
+        save_total_limit (`int`, *optional*, defaults to `1`):
+            If a value is passed, will limit the total number of checkpoints. Deletes the older checkpoints in
+            `output_dir`. Note that the best model is always preserved if the `evaluation_strategy` is not `"no"`.
+        load_best_model_at_end (`bool`, *optional*, defaults to `False`):
+            Whether or not to load the best model found during training at the end of training.
+
+            <Tip>
+
+            When set to `True`, the parameter `save_strategy` needs to be the same as `evaluation_strategy`, and in
+            the case it is "steps", `save_steps` must be a round multiple of `eval_steps`.
+
+            </Tip>
+    """
+
+    output_dir: str = "checkpoints"
+
+    # batch_size is only used to conveniently set `embedding_batch_size` and `classifier_batch_size`,
+    # which are used in practice
+    batch_size: Union[int, Tuple[int, int]] = field(default=(16, 2), repr=False)
+    embedding_batch_size: int = None
+    classifier_batch_size: int = None
+
+    # num_epochs is only used to conveniently set `embedding_num_epochs` and `classifier_num_epochs`,
+    # which are used in practice
+    num_epochs: Union[int, Tuple[int, int]] = field(default=(1, 16), repr=False)
+    embedding_num_epochs: int = None
+    classifier_num_epochs: int = None
+
+    max_steps: int = -1
+
+    sampling_strategy: str = "oversampling"
+    num_iterations: Optional[int] = None
+
+    # As with batch_size and num_epochs, the first value in the tuple is the learning rate
+    # for the embeddings step, while the second value is the learning rate for the classifier step.
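The same tuple convention applies to all three paired settings; a quick illustrative sketch of how `__post_init__` (below) expands them, using the documented defaults:

```python
from setfit import TrainingArguments

# One value applies to both phases; a 2-tuple splits into
# (embedding phase, classifier phase). These are the documented defaults.
args = TrainingArguments(batch_size=(16, 2), num_epochs=(1, 16), body_learning_rate=(2e-5, 1e-5))
assert args.embedding_batch_size == 16 and args.classifier_batch_size == 2
assert args.embedding_num_epochs == 1 and args.classifier_num_epochs == 16
assert args.body_embedding_learning_rate == 2e-5 and args.body_classifier_learning_rate == 1e-5
```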
+ body_learning_rate: Union[float, Tuple[float, float]] = field(default=(2e-5, 1e-5), repr=False) + body_embedding_learning_rate: float = None + body_classifier_learning_rate: float = None + head_learning_rate: float = 1e-2 + + # Loss-related arguments + loss: Callable = losses.CosineSimilarityLoss + distance_metric: Callable = losses.BatchHardTripletLossDistanceFunction.cosine_distance + margin: float = 0.25 + + end_to_end: bool = field(default=False) + + use_amp: bool = False + warmup_proportion: float = 0.1 + l2_weight: Optional[float] = None + max_length: Optional[int] = None + samples_per_label: int = 2 + + # Arguments that do not affect performance + show_progress_bar: bool = True + seed: int = 42 + + # Logging & callbacks + report_to: str = "all" + run_name: Optional[str] = None + logging_dir: Optional[str] = None + logging_strategy: str = "steps" + logging_first_step: bool = True + logging_steps: int = 50 + + evaluation_strategy: str = "no" + eval_steps: Optional[int] = None + eval_delay: int = 0 + + save_strategy: str = "steps" + save_steps: int = 500 + save_total_limit: Optional[int] = 1 + + load_best_model_at_end: bool = False + metric_for_best_model: str = field(default="embedding_loss", repr=False) + greater_is_better: bool = field(default=False, repr=False) + + def __post_init__(self) -> None: + # Set `self.embedding_batch_size` and `self.classifier_batch_size` using values from `self.batch_size` + if isinstance(self.batch_size, int): + self.batch_size = (self.batch_size, self.batch_size) + if self.embedding_batch_size is None: + self.embedding_batch_size = self.batch_size[0] + if self.classifier_batch_size is None: + self.classifier_batch_size = self.batch_size[1] + + # Set `self.embedding_num_epochs` and `self.classifier_num_epochs` using values from `self.num_epochs` + if isinstance(self.num_epochs, int): + self.num_epochs = (self.num_epochs, self.num_epochs) + if self.embedding_num_epochs is None: + self.embedding_num_epochs = self.num_epochs[0] + if self.classifier_num_epochs is None: + self.classifier_num_epochs = self.num_epochs[1] + + # Set `self.body_embedding_learning_rate` and `self.body_classifier_learning_rate` using + # values from `self.body_learning_rate` + if isinstance(self.body_learning_rate, float): + self.body_learning_rate = (self.body_learning_rate, self.body_learning_rate) + if self.body_embedding_learning_rate is None: + self.body_embedding_learning_rate = self.body_learning_rate[0] + if self.body_classifier_learning_rate is None: + self.body_classifier_learning_rate = self.body_learning_rate[1] + + if self.warmup_proportion < 0.0 or self.warmup_proportion > 1.0: + raise ValueError( + f"warmup_proportion must be greater than or equal to 0.0 and less than or equal to 1.0! 
But it was: {self.warmup_proportion}" + ) + + if self.report_to in (None, "all", ["all"]): + self.report_to = get_available_reporting_integrations() + elif self.report_to in ("none", ["none"]): + self.report_to = [] + elif not isinstance(self.report_to, list): + self.report_to = [self.report_to] + + if self.logging_dir is None: + self.logging_dir = default_logdir() + + self.logging_strategy = IntervalStrategy(self.logging_strategy) + self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy) + + if self.eval_steps is not None and self.evaluation_strategy == IntervalStrategy.NO: + logger.info('Using `evaluation_strategy="steps"` as `eval_steps` is defined.') + self.evaluation_strategy = IntervalStrategy.STEPS + + # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero + if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): + if self.logging_steps > 0: + self.eval_steps = self.logging_steps + else: + raise ValueError( + f"evaluation strategy {self.evaluation_strategy} requires either non-zero `eval_steps` or" + " `logging_steps`" + ) + + # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. + if self.load_best_model_at_end: + if self.evaluation_strategy != self.save_strategy: + raise ValueError( + "`load_best_model_at_end` requires the save and eval strategy to match, but found\n- Evaluation " + f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}" + ) + if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: + if self.eval_steps < 1 or self.save_steps < 1: + if not (self.eval_steps < 1 and self.save_steps < 1): + raise ValueError( + "`load_best_model_at_end` requires the saving steps to be a multiple of the evaluation " + "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps" + f"{self.save_steps} and eval_steps {self.eval_steps}." + ) + # Work around floating point precision issues + LARGE_MULTIPLIER = 1_000_000 + if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0: + raise ValueError( + "`load_best_model_at_end` requires the saving steps to be a multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}." + ) + raise ValueError( + "`load_best_model_at_end` requires the saving steps to be a round multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." 
+ ) + + # logging_steps must be non-zero for logging_strategy that is other than 'no' + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0: + raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps") + + def to_dict(self) -> Dict[str, Any]: + # filter out fields that are defined as field(init=False) + return {field.name: getattr(self, field.name) for field in fields(self) if field.init} + + @classmethod + def from_dict(cls, arguments: Dict[str, Any], ignore_extra: bool = False) -> TrainingArguments: + if ignore_extra: + return cls(**{key: value for key, value in arguments.items() if key in inspect.signature(cls).parameters}) + return cls(**arguments) + + def copy(self) -> TrainingArguments: + return copy(self) + + def update(self, arguments: Dict[str, Any], ignore_extra: bool = False) -> TrainingArguments: + return TrainingArguments.from_dict({**self.to_dict(), **arguments}, ignore_extra=ignore_extra) + + def to_json_string(self): + """ + Serializes this instance to a JSON string. + """ + # TODO: This needs to be improved + return json.dumps({key: str(value) for key, value in self.to_dict().items()}, indent=2) + + def to_sanitized_dict(self) -> Dict[str, Any]: + """ + Sanitized serialization to use with TensorBoard’s hparams + """ + d = self.to_dict() + d = {**d, **{"train_batch_size": self.embedding_batch_size, "eval_batch_size": self.embedding_batch_size}} + + valid_types = [bool, int, float, str] + if is_torch_available(): + valid_types.append(torch.Tensor) + + return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} diff --git a/src/setfit/utils.py b/src/setfit/utils.py index 57fb31d4..d75dc7cf 100644 --- a/src/setfit/utils.py +++ b/src/setfit/utils.py @@ -7,7 +7,7 @@ from sentence_transformers import losses from .data import create_fewshot_splits, create_fewshot_splits_multilabel -from .modeling import SupConLoss +from .losses import SupConLoss SEC_TO_NS_SCALE = 1000000000 @@ -135,7 +135,7 @@ def summary(self) -> None: class BestRun(NamedTuple): """ - The best run found by a hyperparameter search (see [`~SetFitTrainer.hyperparameter_search`]). + The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]). 
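A brief usage sketch for the serialization helpers just defined; `update` builds a fresh instance from `to_dict` rather than mutating in place, and `extraneous_key` is a deliberately unknown key illustrating `ignore_extra`:

```python
from setfit import TrainingArguments

base = TrainingArguments(num_epochs=1)
tuned = base.update({"seed": 13, "extraneous_key": None}, ignore_extra=True)
assert tuned.seed == 13 and base.seed == 42  # the original instance is untouched
```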
Parameters: run_id (`str`): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..f92a81d8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,29 @@ +import pytest +from datasets import Dataset + +from setfit import AbsaModel, SetFitModel + + +@pytest.fixture() +def model() -> SetFitModel: + return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + + +@pytest.fixture() +def absa_model() -> AbsaModel: + return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm") + + +@pytest.fixture() +def absa_dataset() -> Dataset: + texts = [ + "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.", + "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.", + "Food is great and inexpensive.", + "Good bagels and good cream cheese.", + "Good bagels and good cream cheese.", + ] + spans = ["food", "ambiance", "Food", "bagels", "cream cheese"] + labels = ["negative", "negative", "positive", "positive", "positive"] + ordinals = [0, 0, 0, 0, 0] + return Dataset.from_dict({"text": texts, "span": spans, "label": labels, "ordinal": ordinals}) diff --git a/tests/exporters/test_onnx.py b/tests/exporters/test_onnx.py index 6c132d43..6e515d74 100644 --- a/tests/exporters/test_onnx.py +++ b/tests/exporters/test_onnx.py @@ -8,7 +8,8 @@ from setfit import SetFitModel from setfit.data import get_templated_dataset from setfit.exporters.onnx import export_onnx -from setfit.trainer import SetFitTrainer +from setfit.trainer import Trainer +from setfit.training_args import TrainingArguments @pytest.mark.parametrize( @@ -71,25 +72,23 @@ def test_export_onnx_torch_head(out_features): model_path, use_differentiable_head=True, head_params={"out_features": out_features} ) - trainer = SetFitTrainer( + args = TrainingArguments( + num_iterations=15, + num_epochs=(1, 15), + batch_size=16, + body_learning_rate=(2e-5, 1e-5), + head_learning_rate=1e-2, + l2_weight=0.0, + end_to_end=True, + ) + trainer = Trainer( model=model, + args=args, train_dataset=dataset, eval_dataset=dataset, - num_iterations=15, column_mapping={"text": "text", "label": "label"}, ) - # Train and evaluate - trainer.freeze() # Freeze the head - trainer.train() # Train only the body - # Unfreeze the head and unfreeze the body -> end-to-end training - trainer.unfreeze(keep_body_frozen=False) - trainer.train( - num_epochs=15, - batch_size=16, - body_learning_rate=1e-5, - learning_rate=1e-2, - l2_weight=0.0, - ) + trainer.train() # Export the sklearn based model output_path = "model.onnx" diff --git a/tests/span/test_modeling.py b/tests/span/test_modeling.py new file mode 100644 index 00000000..0bc3ccb8 --- /dev/null +++ b/tests/span/test_modeling.py @@ -0,0 +1,86 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest +import torch + +from setfit import AbsaModel +from setfit.span.aspect_extractor import AspectExtractor +from setfit.span.modeling import AspectModel, PolarityModel + + +def test_loading(): + model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm") + assert isinstance(model, AbsaModel) + assert isinstance(model.aspect_extractor, AspectExtractor) + assert isinstance(model.aspect_model, AspectModel) + assert isinstance(model.polarity_model, PolarityModel) + + model = AbsaModel.from_pretrained( + 
"sentence-transformers/paraphrase-albert-small-v2@6c91e73a51599e35bd1145dfdcd3289215225009", + "sentence-transformers/paraphrase-albert-small-v2", + spacy_model="en_core_web_sm", + ) + assert isinstance(model, AbsaModel) + + model = AbsaModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", + "sentence-transformers/paraphrase-albert-small-v2@6c91e73a51599e35bd1145dfdcd3289215225009", + spacy_model="en_core_web_sm", + ) + assert isinstance(model, AbsaModel) + + with pytest.raises(OSError): + model = AbsaModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", spacy_model="not_a_spacy_model" + ) + + model = AbsaModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm", normalize_embeddings=True + ) + assert model.aspect_model.normalize_embeddings + assert model.polarity_model.normalize_embeddings + + aspect_model = AspectModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", span_context=12) + assert aspect_model.span_context == 12 + polarity_model = PolarityModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", span_context=12) + assert polarity_model.span_context == 12 + + model = AbsaModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm", span_contexts=(12, None) + ) + assert model.aspect_model.span_context == 12 + assert model.polarity_model.span_context == 3 # <- default + + +def test_save_load(absa_model: AbsaModel) -> None: + absa_model.polarity_model.span_context = 5 + + with TemporaryDirectory() as tmp_dir: + tmp_dir = str(Path(tmp_dir) / "model") + absa_model.save_pretrained(tmp_dir) + assert (Path(tmp_dir + "-aspect") / "config_span_setfit.json").exists() + assert (Path(tmp_dir + "-polarity") / "config_span_setfit.json").exists() + + fresh_model = AbsaModel.from_pretrained( + tmp_dir + "-aspect", tmp_dir + "-polarity", spacy_model="en_core_web_sm" + ) + assert fresh_model.polarity_model.span_context == 5 + + with TemporaryDirectory() as aspect_tmp_dir: + with TemporaryDirectory() as polarity_tmp_dir: + absa_model.save_pretrained(aspect_tmp_dir, polarity_tmp_dir) + assert (Path(aspect_tmp_dir) / "config_span_setfit.json").exists() + assert (Path(polarity_tmp_dir) / "config_span_setfit.json").exists() + + fresh_model = AbsaModel.from_pretrained(aspect_tmp_dir, polarity_tmp_dir, spacy_model="en_core_web_sm") + assert fresh_model.polarity_model.span_context == 5 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to move a model between devices") +def test_to(absa_model: AbsaModel) -> None: + assert absa_model.device.type == "cuda" + absa_model.to("cpu") + assert absa_model.device.type == "cpu" + assert absa_model.aspect_model.device.type == "cpu" + assert absa_model.polarity_model.device.type == "cpu" diff --git a/tests/span/test_trainer.py b/tests/span/test_trainer.py new file mode 100644 index 00000000..f89044dc --- /dev/null +++ b/tests/span/test_trainer.py @@ -0,0 +1,75 @@ +from datasets import Dataset +from transformers import TrainerCallback + +from setfit import AbsaTrainer +from setfit.span.modeling import AbsaModel + + +def test_trainer(absa_model: AbsaModel, absa_dataset: Dataset) -> None: + trainer = AbsaTrainer(absa_model, train_dataset=absa_dataset, eval_dataset=absa_dataset) + trainer.train() + + metrics = trainer.evaluate() + assert "aspect" in metrics + assert "polarity" in metrics + assert "accuracy" in metrics["aspect"] + assert "accuracy" in 
metrics["polarity"] + assert metrics["aspect"]["accuracy"] > 0.0 + assert metrics["polarity"]["accuracy"] > 0.0 + new_metrics = trainer.evaluate(absa_dataset) + assert metrics == new_metrics + + predict = absa_model.predict("Best pizza outside of Italy and really tasty.") + assert {"span": "pizza", "polarity": "positive"} in predict + predict = absa_model.predict(["Best pizza outside of Italy and really tasty.", "This is another sentence"]) + assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list) + predict = absa_model(["Best pizza outside of Italy and really tasty.", "This is another sentence"]) + assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list) + + +def test_trainer_callbacks(absa_model: AbsaModel) -> None: + trainer = AbsaTrainer(absa_model) + assert len(trainer.aspect_trainer.callback_handler.callbacks) >= 2 + callback_names = {callback.__class__.__name__ for callback in trainer.aspect_trainer.callback_handler.callbacks} + assert {"DefaultFlowCallback", "ProgressCallback"} <= callback_names + + class TestCallback(TrainerCallback): + pass + + callback = TestCallback() + trainer.add_callback(callback) + assert len(trainer.aspect_trainer.callback_handler.callbacks) == len(callback_names) + 1 + assert len(trainer.polarity_trainer.callback_handler.callbacks) == len(callback_names) + 1 + assert trainer.aspect_trainer.callback_handler.callbacks[-1] == callback + assert trainer.polarity_trainer.callback_handler.callbacks[-1] == callback + + assert trainer.pop_callback(callback) == (callback, callback) + trainer.add_callback(callback) + assert trainer.aspect_trainer.callback_handler.callbacks[-1] == callback + assert trainer.polarity_trainer.callback_handler.callbacks[-1] == callback + trainer.remove_callback(callback) + assert callback not in trainer.aspect_trainer.callback_handler.callbacks + assert callback not in trainer.polarity_trainer.callback_handler.callbacks + + +def test_train_ordinal_too_high(absa_model: AbsaModel) -> None: + absa_dataset = Dataset.from_dict( + { + "text": [ + "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine." + ], + "span": ["food"], + "label": ["negative"], + "ordinal": [1], + } + ) + AbsaTrainer(absa_model, train_dataset=absa_dataset) + # TODO: Capture warning and test against it. 
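For reference alongside these assertions, a standalone inference sketch; the checkpoint names are placeholders, and the aspect and polarity heads are assumed to have been fitted beforehand (e.g. with `AbsaTrainer` as in `test_trainer` above):

```python
from setfit import AbsaModel

model = AbsaModel.from_pretrained(
    "a-user/setfit-absa-aspect",  # placeholder for a trained aspect model
    "a-user/setfit-absa-polarity",  # placeholder for a trained polarity model
    spacy_model="en_core_web_sm",
)
# A single sentence yields one list of {"span": ..., "polarity": ...} dicts, one per
# extracted aspect; a list of sentences yields a list of such lists.
preds = model.predict("Best pizza outside of Italy and really tasty.")
```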
+ + +def test_train_column_mapping(absa_model: AbsaModel, absa_dataset: Dataset) -> None: + absa_dataset = absa_dataset.rename_columns({"text": "sentence", "span": "aspect"}) + trainer = AbsaTrainer( + absa_model, train_dataset=absa_dataset, column_mapping={"sentence": "text", "aspect": "span"} + ) + trainer.train() diff --git a/tests/test_deprecated_trainer.py b/tests/test_deprecated_trainer.py new file mode 100644 index 00000000..8e1ce1d5 --- /dev/null +++ b/tests/test_deprecated_trainer.py @@ -0,0 +1,531 @@ +import pathlib +import re +import tempfile +from unittest import TestCase + +import evaluate +import pytest +import torch +from datasets import Dataset, load_dataset +from sentence_transformers import losses +from transformers.testing_utils import require_optuna +from transformers.utils.hp_naming import TrialShortNamer + +from setfit import logging +from setfit.losses import SupConLoss +from setfit.modeling import SetFitModel +from setfit.trainer import SetFitTrainer +from setfit.utils import BestRun + + +logging.set_verbosity_warning() +logging.enable_propagation() + + +class SetFitTrainerTest(TestCase): + def setUp(self): + self.model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + self.num_iterations = 1 + + def test_trainer_works_with_model_init(self): + def get_model(): + model_name = "sentence-transformers/paraphrase-albert-small-v2" + return SetFitModel.from_pretrained(model_name) + + dataset = Dataset.from_dict( + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + trainer = SetFitTrainer( + model_init=get_model, + train_dataset=dataset, + eval_dataset=dataset, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + trainer.train() + metrics = trainer.evaluate() + self.assertEqual(metrics["accuracy"], 1.0) + + def test_trainer_works_with_column_mapping(self): + dataset = Dataset.from_dict( + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + trainer = SetFitTrainer( + model=self.model, + train_dataset=dataset, + eval_dataset=dataset, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + trainer.train() + metrics = trainer.evaluate() + self.assertEqual(metrics["accuracy"], 1.0) + + def test_trainer_works_with_default_columns(self): + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) + trainer = SetFitTrainer( + model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations + ) + trainer.train() + metrics = trainer.evaluate() + self.assertEqual(metrics["accuracy"], 1.0) + + def test_trainer_works_with_alternate_dataset_for_evaluate(self): + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) + alternate_dataset = Dataset.from_dict( + {"text": ["x", "y", "z"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + trainer = SetFitTrainer( + model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations + ) + trainer.train() + metrics = trainer.evaluate(alternate_dataset) + self.assertNotEqual(metrics["accuracy"], 1.0) + + def test_trainer_raises_error_with_missing_label(self): + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) + trainer = SetFitTrainer( + model=self.model, train_dataset=dataset, eval_dataset=dataset, 
num_iterations=self.num_iterations + ) + with pytest.raises(ValueError): + trainer.train() + + def test_trainer_raises_error_with_missing_text(self): + """If the required columns are missing from the dataset, the library should throw an error and list the columns found.""" + dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) + trainer = SetFitTrainer( + model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations + ) + expected_message = re.escape( + "SetFit expected the dataset to have the columns ['label', 'text'], " + "but only the columns ['extra_column', 'label'] were found. " + "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer." + ) + with pytest.raises(ValueError, match=expected_message): + trainer._validate_column_mapping(trainer.train_dataset) + + def test_column_mapping_raises_error_when_mapped_columns_missing(self): + """If the columns specified in the column mapping are missing from the dataset, the library should throw an error and list the columns found.""" + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) + trainer = SetFitTrainer( + model=self.model, + train_dataset=dataset, + eval_dataset=dataset, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + expected_message = re.escape( + "The column mapping expected the columns ['label_new', 'text_new'] in the dataset, " + "but the dataset had the columns ['extra_column', 'text'].", + ) + with pytest.raises(ValueError, match=expected_message): + trainer._validate_column_mapping(trainer.train_dataset) + + def test_trainer_raises_error_when_dataset_not_split(self): + """Verify that an error is raised if we pass an unsplit dataset to the trainer.""" + dataset = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 0, 1, 1]}).train_test_split( + test_size=0.5 + ) + trainer = SetFitTrainer( + model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations + ) + expected_message = re.escape( + "SetFit expected a Dataset, but it got a DatasetDict with the splits ['test', 'train']. " + "Did you mean to select one of these splits from the dataset?", + ) + with pytest.raises(ValueError, match=expected_message): + trainer._validate_column_mapping(trainer.train_dataset) + + def test_trainer_raises_error_when_dataset_is_dataset_dict_with_train(self): + """Verify that a useful error is raised if we pass an unsplit dataset with only a `train` split to the trainer.""" + with tempfile.TemporaryDirectory() as tmpdirname: + path = pathlib.Path(tmpdirname) / "test_dataset_dict_with_train.csv" + path.write_text("label,text\n1,good\n0,terrible\n") + dataset = load_dataset("csv", data_files=str(path)) + trainer = SetFitTrainer( + model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations + ) + expected_message = re.escape( + "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. 
" + "Did you mean to select the training split with dataset['train']?", + ) + with pytest.raises(ValueError, match=expected_message): + trainer._validate_column_mapping(trainer.train_dataset) + + def test_column_mapping_multilabel(self): + dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]}) + + trainer = SetFitTrainer( + model=self.model, + train_dataset=dataset, + eval_dataset=dataset, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + + trainer._validate_column_mapping(trainer.train_dataset) + formatted_dataset = trainer._apply_column_mapping(trainer.train_dataset, trainer.column_mapping) + + assert formatted_dataset.column_names == ["text", "label"] + + assert formatted_dataset[0]["text"] == "a" + assert formatted_dataset[0]["label"] == [0, 1] + + assert formatted_dataset[1]["text"] == "b" + + def test_trainer_support_callable_as_metric(self): + dataset = Dataset.from_dict( + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + + f1_metric = evaluate.load("f1") + accuracy_metric = evaluate.load("accuracy") + + def compute_metrics(y_pred, y_test): + return { + "f1": f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"], + "accuracy": accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"], + } + + trainer = SetFitTrainer( + model=self.model, + train_dataset=dataset, + eval_dataset=dataset, + metric=compute_metrics, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + + trainer.train() + metrics = trainer.evaluate() + + self.assertEqual( + { + "f1": 1.0, + "accuracy": 1.0, + }, + metrics, + ) + + def test_raise_when_metric_value_is_invalid(self): + dataset = Dataset.from_dict( + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + + trainer = SetFitTrainer( + model=self.model, + train_dataset=dataset, + eval_dataset=dataset, + metric="this-metric-does-not-exist", # invalid metric value + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + + trainer.train() + + with self.assertRaises(FileNotFoundError): + trainer.evaluate() + + def test_trainer_raises_error_with_wrong_warmup_proportion(self): + # warmup_proportion must not be > 1.0 + with pytest.raises(ValueError): + SetFitTrainer(warmup_proportion=1.1) + + # warmup_proportion must not be < 0.0 + with pytest.raises(ValueError): + SetFitTrainer(warmup_proportion=-0.1) + + +class SetFitTrainerDifferentiableHeadTest(TestCase): + def setUp(self): + self.dataset = Dataset.from_dict( + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + self.model = SetFitModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", + use_differentiable_head=True, + head_params={"out_features": 3}, + ) + self.num_iterations = 1 + + @pytest.mark.skip(reason="The `trainer.train` arguments are now ignored, causing this test to fail.") + def test_trainer_max_length_exceeds_max_acceptable_length(self): + trainer = SetFitTrainer( + model=self.model, + train_dataset=self.dataset, + eval_dataset=self.dataset, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + trainer.unfreeze(keep_body_frozen=True) + with self.assertLogs(level=logging.WARNING) as cm: + max_length = 4096 + max_acceptable_length = self.model.model_body.get_max_seq_length() + 
trainer.train( + num_epochs=1, + batch_size=3, + learning_rate=1e-2, + l2_weight=0.0, + max_length=max_length, + ) + self.assertEqual( + cm.output, + [ + ( + f"WARNING:setfit.modeling:The specified `max_length`: {max_length} is greater than the maximum length " + f"of the current model body: {max_acceptable_length}. Using {max_acceptable_length} instead." + ) + ], + ) + + @pytest.mark.skip(reason="The `trainer.train` arguments are now ignored, causing this test to fail.") + def test_trainer_max_length_is_smaller_than_max_acceptable_length(self): + trainer = SetFitTrainer( + model=self.model, + train_dataset=self.dataset, + eval_dataset=self.dataset, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + trainer.unfreeze(keep_body_frozen=True) + + # An alternative way of `assertNoLogs`, which is new in Python 3.10 + try: + with self.assertLogs(level=logging.WARNING) as cm: + max_length = 32 + trainer.train( + num_epochs=1, + batch_size=3, + learning_rate=1e-2, + l2_weight=0.0, + max_length=max_length, + ) + self.assertEqual(cm.output, []) + except AssertionError as e: + if e.args[0] != "no logs of level WARNING or higher triggered on root": + raise AssertionError(e) + + +class SetFitTrainerMultilabelTest(TestCase): + def setUp(self): + self.model = SetFitModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest" + ) + self.num_iterations = 1 + + def test_trainer_multilabel_support_callable_as_metric(self): + dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[1, 0, 0], [0, 1, 0], [0, 0, 1]]}) + + multilabel_f1_metric = evaluate.load("f1", "multilabel") + multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel") + + def compute_metrics(y_pred, y_test): + return { + "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"], + "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"], + } + + trainer = SetFitTrainer( + model=self.model, + train_dataset=dataset, + eval_dataset=dataset, + metric=compute_metrics, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + + trainer.train() + metrics = trainer.evaluate() + + self.assertEqual( + { + "f1": 1.0, + "accuracy": 1.0, + }, + metrics, + ) + + +@pytest.mark.skip( + reason=( + "The `trainer.freeze()` before `trainer.train()` now freezes the body as well as the head, " + "which means the backwards call from `trainer.train()` will fail." 
+ ) +) +class SetFitTrainerMultilabelDifferentiableTest(TestCase): + def setUp(self): + self.model = SetFitModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", + multi_target_strategy="one-vs-rest", + use_differentiable_head=True, + head_params={"out_features": 2}, + ) + self.num_iterations = 1 + + def test_trainer_multilabel_support_callable_as_metric(self): + dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]}) + + multilabel_f1_metric = evaluate.load("f1", "multilabel") + multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel") + + def compute_metrics(y_pred, y_test): + return { + "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"], + "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"], + } + + trainer = SetFitTrainer( + model=self.model, + train_dataset=dataset, + eval_dataset=dataset, + metric=compute_metrics, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + + trainer.freeze() + trainer.train() + + trainer.unfreeze(keep_body_frozen=False) + trainer.train(5) + metrics = trainer.evaluate() + + self.assertEqual( + { + "f1": 1.0, + "accuracy": 1.0, + }, + metrics, + ) + + +@require_optuna +class TrainerHyperParameterOptunaIntegrationTest(TestCase): + def setUp(self): + self.dataset = Dataset.from_dict( + {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + self.num_iterations = 1 + + def test_hyperparameter_search(self): + class MyTrialShortNamer(TrialShortNamer): + DEFAULTS = {"max_iter": 100, "solver": "liblinear"} + + def hp_space(trial): + return { + "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True), + "batch_size": trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]), + "max_iter": trial.suggest_int("max_iter", 50, 300), + "solver": trial.suggest_categorical("solver", ["newton-cg", "lbfgs", "liblinear"]), + } + + def model_init(params): + params = params or {} + max_iter = params.get("max_iter", 100) + solver = params.get("solver", "liblinear") + params = { + "head_params": { + "max_iter": max_iter, + "solver": solver, + } + } + return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", **params) + + def hp_name(trial): + return MyTrialShortNamer.shortname(trial.params) + + trainer = SetFitTrainer( + train_dataset=self.dataset, + eval_dataset=self.dataset, + num_iterations=self.num_iterations, + model_init=model_init, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + result = trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4) + assert isinstance(result, BestRun) + assert result.hyperparameters.keys() == {"learning_rate", "batch_size", "max_iter", "solver"} + + +# regression test for https://github.com/huggingface/setfit/issues/153 +@pytest.mark.parametrize( + "loss_class", + [ + losses.BatchAllTripletLoss, + losses.BatchHardTripletLoss, + losses.BatchSemiHardTripletLoss, + losses.BatchHardSoftMarginTripletLoss, + SupConLoss, + ], +) +def test_trainer_works_with_non_default_loss_class(loss_class): + dataset = Dataset.from_dict({"text": ["a 1", "b 1", "c 1", "a 2", "b 2", "c 2"], "label": [0, 1, 2, 0, 1, 2]}) + model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + trainer = SetFitTrainer( + model=model, + train_dataset=dataset, + 
eval_dataset=dataset, + num_iterations=1, + loss_class=loss_class, + ) + trainer.train() + # no asserts here because this is a regression test - we only test if an exception is raised + + +def test_trainer_evaluate_with_strings(): + dataset = Dataset.from_dict( + {"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]} + ) + model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + trainer = SetFitTrainer( + model=model, + train_dataset=dataset, + eval_dataset=dataset, + num_iterations=1, + ) + trainer.train() + # This used to fail due to "TypeError: can't convert np.ndarray of type numpy.str_. + # The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." + model.predict(["another positive sentence"]) + + +def test_trainer_evaluate_multilabel_f1(): + dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]}) + model = SetFitModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest" + ) + + trainer = SetFitTrainer( + model=model, + train_dataset=dataset, + eval_dataset=dataset, + metric="f1", + metric_kwargs={"average": "micro"}, + num_iterations=5, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + + trainer.train() + metrics = trainer.evaluate() + assert metrics == {"f1": 1.0} + + +def test_trainer_evaluate_on_cpu() -> None: + # This test used to fail if CUDA was available + dataset = Dataset.from_dict({"text": ["positive sentence", "negative sentence"], "label": [1, 0]}) + model = SetFitModel.from_pretrained( + "sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True + ) + + def compute_metric(y_pred, y_test) -> None: + assert y_pred.device == torch.device("cpu") + return 1.0 + + trainer = SetFitTrainer( + model=model, + train_dataset=dataset, + eval_dataset=dataset, + metric=compute_metric, + num_iterations=5, + ) + trainer.train() + trainer.evaluate() diff --git a/tests/test_deprecated_trainer_distillation.py b/tests/test_deprecated_trainer_distillation.py new file mode 100644 index 00000000..5d59c4f5 --- /dev/null +++ b/tests/test_deprecated_trainer_distillation.py @@ -0,0 +1,117 @@ +from unittest import TestCase + +import pytest +from datasets import Dataset +from sentence_transformers.losses import CosineSimilarityLoss + +from setfit import DistillationSetFitTrainer, SetFitTrainer +from setfit.modeling import SetFitModel + + +class DistillationSetFitTrainerTest(TestCase): + def setUp(self): + self.teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") + self.student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2") + self.num_iterations = 1 + + def test_trainer_works_with_default_columns(self): + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) + # train a teacher model + teacher_trainer = SetFitTrainer( + model=self.teacher_model, + train_dataset=dataset, + eval_dataset=dataset, + loss_class=CosineSimilarityLoss, + metric="accuracy", + ) + # Teacher Train and evaluate + teacher_trainer.train() + teacher_model = teacher_trainer.model + + student_trainer = DistillationSetFitTrainer( + teacher_model=teacher_model, + train_dataset=dataset, + student_model=self.student_model, + eval_dataset=dataset, + loss_class=CosineSimilarityLoss, + metric="accuracy", + ) + + # Student Train and 
evaluate + student_trainer.train() + metrics = student_trainer.evaluate() + print("Student results: ", metrics) + self.assertEqual(metrics["accuracy"], 1.0) + + def test_trainer_raises_error_with_missing_label(self): + labeled_dataset = Dataset.from_dict( + {"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + # train a teacher model + teacher_trainer = SetFitTrainer( + model=self.teacher_model, + train_dataset=labeled_dataset, + eval_dataset=labeled_dataset, + metric="accuracy", + num_iterations=self.num_iterations, + ) + # Teacher Train and evaluate + teacher_trainer.train() + + unlabeled_dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) + student_trainer = DistillationSetFitTrainer( + teacher_model=self.teacher_model, + student_model=self.student_model, + train_dataset=unlabeled_dataset, + eval_dataset=labeled_dataset, + num_iterations=self.num_iterations, + ) + student_trainer.train() + metrics = student_trainer.evaluate() + print("Student results: ", metrics) + self.assertEqual(metrics["accuracy"], 1.0) + + def test_trainer_raises_error_with_missing_text(self): + dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) + trainer = DistillationSetFitTrainer( + teacher_model=self.teacher_model, + train_dataset=dataset, + student_model=self.student_model, + eval_dataset=dataset, + num_iterations=self.num_iterations, + ) + with pytest.raises(ValueError): + trainer.train() + + def test_column_mapping_with_missing_text(self): + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) + trainer = DistillationSetFitTrainer( + teacher_model=self.teacher_model, + train_dataset=dataset, + student_model=self.student_model, + eval_dataset=dataset, + num_iterations=self.num_iterations, + column_mapping={"label_new": "label"}, + ) + with pytest.raises(ValueError): + trainer._validate_column_mapping(trainer.train_dataset) + + def test_column_mapping_multilabel(self): + dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]}) + + trainer = DistillationSetFitTrainer( + teacher_model=self.teacher_model, + train_dataset=dataset, + student_model=self.student_model, + eval_dataset=dataset, + num_iterations=self.num_iterations, + column_mapping={"text_new": "text", "label_new": "label"}, + ) + + trainer._validate_column_mapping(trainer.train_dataset) + formatted_dataset = trainer._apply_column_mapping(trainer.train_dataset, trainer.column_mapping) + + assert formatted_dataset.column_names == ["text", "label"] + assert formatted_dataset[0]["text"] == "a" + assert formatted_dataset[0]["label"] == [0, 1] + assert formatted_dataset[1]["text"] == "b" diff --git a/tests/test_modeling.py b/tests/test_modeling.py index c31417d2..a5e279f6 100644 --- a/tests/test_modeling.py +++ b/tests/test_modeling.py @@ -10,42 +10,12 @@ from sklearn.multioutput import ClassifierChain, MultiOutputClassifier from setfit import SetFitHead, SetFitModel -from setfit.modeling import MODEL_HEAD_NAME, sentence_pairs_generation, sentence_pairs_generation_multilabel +from setfit.modeling import MODEL_HEAD_NAME torch_cuda_available = pytest.mark.skipif(not torch.cuda.is_available(), reason="PyTorch must be compiled with CUDA") -def test_sentence_pairs_generation(): - sentences = np.array(["sent 1", "sent 2", "sent 3"]) - labels = np.array(["label 1", "label 2", "label 3"]) - - pairs = [] - n_iterations = 2 - - for _ in range(n_iterations): - pairs = 
sentence_pairs_generation(sentences, labels, pairs) - - assert len(pairs) == 12 - assert pairs[0].texts == ["sent 1", "sent 1"] - assert pairs[0].label == 1.0 - - -def test_sentence_pairs_generation_multilabel(): - sentences = np.array(["sent 1", "sent 2", "sent 3"]) - labels = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]) - - pairs = [] - n_iterations = 2 - - for _ in range(n_iterations): - pairs = sentence_pairs_generation_multilabel(sentences, labels, pairs) - - assert len(pairs) == 12 - assert pairs[0].texts == ["sent 1", "sent 1"] - assert pairs[0].label == 1.0 - - def test_setfit_model_body(): model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 00000000..d8d37712 --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest +from sentence_transformers import InputExample + +from setfit.sampler import ContrastiveDataset + + +@pytest.mark.parametrize( + "sampling_strategy, expected_pos_pairs, expected_neg_pairs", + [("unique", 4, 2), ("undersampling", 2, 2), ("oversampling", 4, 4)], +) +def test_sentence_pairs_generation(sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int): + sentences = np.array(["sent 1", "sent 2", "sent 3"]) + labels = np.array(["label 1", "label 1", "label 2"]) + + data = [InputExample(texts=[text], label=label) for text, label in zip(sentences, labels)] + multilabel = False + + data_sampler = ContrastiveDataset(data, multilabel, sampling_strategy=sampling_strategy) + + assert data_sampler.len_pos_pairs == expected_pos_pairs + assert data_sampler.len_neg_pairs == expected_neg_pairs + + pairs = [i for i in data_sampler] + + assert len(pairs) == expected_pos_pairs + expected_neg_pairs + assert pairs[0].texts == ["sent 1", "sent 1"] + assert pairs[0].label == 1.0 + + +@pytest.mark.parametrize( + "sampling_strategy, expected_pos_pairs, expected_neg_pairs", + [("unique", 6, 4), ("undersampling", 4, 4), ("oversampling", 6, 6)], +) +def test_sentence_pairs_generation_multilabel( + sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int +): + sentences = np.array(["sent 1", "sent 2", "sent 3", "sent 4"]) + labels = np.array([[1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + + data = [InputExample(texts=[text], label=label) for text, label in zip(sentences, labels)] + multilabel = True + + data_sampler = ContrastiveDataset(data, multilabel, sampling_strategy=sampling_strategy) + assert data_sampler.len_pos_pairs == expected_pos_pairs + assert data_sampler.len_neg_pairs == expected_neg_pairs + + pairs = [i for i in data_sampler] + assert len(pairs) == expected_pos_pairs + expected_neg_pairs diff --git a/tests/test_trainer.py b/tests/test_trainer.py index af2c9a82..c5654524 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,19 +1,24 @@ -import pathlib +import os import re import tempfile +from pathlib import Path from unittest import TestCase import evaluate import pytest import torch from datasets import Dataset, load_dataset +from pytest import LogCaptureFixture from sentence_transformers import losses +from transformers import TrainerCallback from transformers.testing_utils import require_optuna from transformers.utils.hp_naming import TrialShortNamer from setfit import logging -from setfit.modeling import SetFitModel, SupConLoss -from setfit.trainer import SetFitTrainer +from setfit.losses import SupConLoss +from setfit.modeling import SetFitModel 
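+# The legacy SetFitTrainer tests now live in tests/test_deprecated_trainer.py;
+# this module targets the new Trainer and TrainingArguments API.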
+from setfit.trainer import Trainer +from setfit.training_args import TrainingArguments from setfit.utils import BestRun @@ -21,10 +26,10 @@ logging.enable_propagation() -class SetFitTrainerTest(TestCase): +class TrainerTest(TestCase): def setUp(self): self.model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") - self.num_iterations = 1 + self.args = TrainingArguments(num_iterations=1) def test_trainer_works_with_model_init(self): def get_model(): @@ -34,11 +39,11 @@ def get_model(): dataset = Dataset.from_dict( {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} ) - trainer = SetFitTrainer( + trainer = Trainer( model_init=get_model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) trainer.train() @@ -49,11 +54,11 @@ def test_trainer_works_with_column_mapping(self): dataset = Dataset.from_dict( {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} ) - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) trainer.train() @@ -62,9 +67,7 @@ def test_trainer_works_with_column_mapping(self): def test_trainer_works_with_default_columns(self): dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) - trainer = SetFitTrainer( - model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations - ) + trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset) trainer.train() metrics = trainer.evaluate() self.assertEqual(metrics["accuracy"], 1.0) @@ -74,31 +77,25 @@ def test_trainer_works_with_alternate_dataset_for_evaluate(self): alternate_dataset = Dataset.from_dict( {"text": ["x", "y", "z"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]} ) - trainer = SetFitTrainer( - model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations - ) + trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset) trainer.train() metrics = trainer.evaluate(alternate_dataset) self.assertNotEqual(metrics["accuracy"], 1.0) def test_trainer_raises_error_with_missing_label(self): dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) - trainer = SetFitTrainer( - model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations - ) + trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset) with pytest.raises(ValueError): trainer.train() def test_trainer_raises_error_with_missing_text(self): """If the required columns are missing from the dataset, the library should throw an error and list the columns found.""" dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) - trainer = SetFitTrainer( - model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations - ) + trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset) expected_message = re.escape( "SetFit expected the dataset to have the columns ['label', 'text'], " "but only the columns ['extra_column', 'label'] were found. 
" - "Either make sure these columns are present, or specify which columns to use with column_mapping in SetFitTrainer." + "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer." ) with pytest.raises(ValueError, match=expected_message): trainer._validate_column_mapping(trainer.train_dataset) @@ -106,11 +103,11 @@ def test_trainer_raises_error_with_missing_text(self): def test_column_mapping_raises_error_when_mapped_columns_missing(self): """If the columns specified in the column mapping are missing from the dataset, the library should throw an error and list the columns found.""" dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) expected_message = re.escape( @@ -125,9 +122,7 @@ def test_trainer_raises_error_when_dataset_not_split(self): dataset = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 0, 1, 1]}).train_test_split( test_size=0.5 ) - trainer = SetFitTrainer( - model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations - ) + trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset) expected_message = re.escape( "SetFit expected a Dataset, but it got a DatasetDict with the splits ['test', 'train']. " "Did you mean to select one of these splits from the dataset?", @@ -138,12 +133,10 @@ def test_trainer_raises_error_when_dataset_not_split(self): def test_trainer_raises_error_when_dataset_is_dataset_dict_with_train(self): """Verify that a useful error is raised if we pass an unsplit dataset with only a `train` split to the trainer.""" with tempfile.TemporaryDirectory() as tmpdirname: - path = pathlib.Path(tmpdirname) / "test_dataset_dict_with_train.csv" + path = Path(tmpdirname) / "test_dataset_dict_with_train.csv" path.write_text("label,text\n1,good\n0,terrible\n") dataset = load_dataset("csv", data_files=str(path)) - trainer = SetFitTrainer( - model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations - ) + trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset) expected_message = re.escape( "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. 
" "Did you mean to select the training split with dataset['train']?", @@ -154,11 +147,11 @@ def test_trainer_raises_error_when_dataset_is_dataset_dict_with_train(self): def test_column_mapping_multilabel(self): dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]}) - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) @@ -186,12 +179,12 @@ def compute_metrics(y_pred, y_test): "accuracy": accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"], } - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, metric=compute_metrics, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) @@ -211,12 +204,12 @@ def test_raise_when_metric_value_is_invalid(self): {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} ) - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, metric="this-metric-does-not-exist", # invalid metric value - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) @@ -225,17 +218,8 @@ def test_raise_when_metric_value_is_invalid(self): with self.assertRaises(FileNotFoundError): trainer.evaluate() - def test_trainer_raises_error_with_wrong_warmup_proportion(self): - # warmup_proportion must not be > 1.0 - with pytest.raises(ValueError): - SetFitTrainer(warmup_proportion=1.1) - - # warmup_proportion must not be < 0.0 - with pytest.raises(ValueError): - SetFitTrainer(warmup_proportion=-0.1) - -class SetFitTrainerDifferentiableHeadTest(TestCase): +class TrainerDifferentiableHeadTest(TestCase): def setUp(self): self.dataset = Dataset.from_dict( {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} @@ -245,27 +229,22 @@ def setUp(self): use_differentiable_head=True, head_params={"out_features": 3}, ) - self.num_iterations = 1 + self.args = TrainingArguments(num_iterations=1) def test_trainer_max_length_exceeds_max_acceptable_length(self): - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=self.dataset, eval_dataset=self.dataset, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) trainer.unfreeze(keep_body_frozen=True) with self.assertLogs(level=logging.WARNING) as cm: max_length = 4096 max_acceptable_length = self.model.model_body.get_max_seq_length() - trainer.train( - num_epochs=1, - batch_size=3, - learning_rate=1e-2, - l2_weight=0.0, - max_length=max_length, - ) + args = TrainingArguments(num_iterations=1, max_length=max_length) + trainer.train(args) self.assertEqual( cm.output, [ @@ -277,38 +256,32 @@ def test_trainer_max_length_exceeds_max_acceptable_length(self): ) def test_trainer_max_length_is_smaller_than_max_acceptable_length(self): - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=self.dataset, eval_dataset=self.dataset, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) - trainer.unfreeze(keep_body_frozen=True) # An alternative way of `assertNoLogs`, which is new in Python 3.10 try: with self.assertLogs(level=logging.WARNING) as cm: max_length = 32 - 
trainer.train( - num_epochs=1, - batch_size=3, - learning_rate=1e-2, - l2_weight=0.0, - max_length=max_length, - ) + args = TrainingArguments(num_iterations=1, max_length=max_length) + trainer.train(args) self.assertEqual(cm.output, []) except AssertionError as e: if e.args[0] != "no logs of level WARNING or higher triggered on root": raise AssertionError(e) -class SetFitTrainerMultilabelTest(TestCase): +class TrainerMultilabelTest(TestCase): def setUp(self): self.model = SetFitModel.from_pretrained( "sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest" ) - self.num_iterations = 1 + self.args = TrainingArguments(num_iterations=1) def test_trainer_multilabel_support_callable_as_metric(self): dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[1, 0, 0], [0, 1, 0], [0, 0, 1]]}) @@ -322,12 +295,12 @@ def compute_metrics(y_pred, y_test): "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"], } - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, metric=compute_metrics, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) @@ -343,7 +316,7 @@ def compute_metrics(y_pred, y_test): ) -class SetFitTrainerMultilabelDifferentiableTest(TestCase): +class TrainerMultilabelDifferentiableTest(TestCase): def setUp(self): self.model = SetFitModel.from_pretrained( "sentence-transformers/paraphrase-albert-small-v2", @@ -351,7 +324,7 @@ def setUp(self): use_differentiable_head=True, head_params={"out_features": 2}, ) - self.num_iterations = 1 + self.args = TrainingArguments(num_iterations=1) def test_trainer_multilabel_support_callable_as_metric(self): dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]}) @@ -365,20 +338,16 @@ def compute_metrics(y_pred, y_test): "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"], } - trainer = SetFitTrainer( + trainer = Trainer( model=self.model, + args=self.args, train_dataset=dataset, eval_dataset=dataset, metric=compute_metrics, - num_iterations=self.num_iterations, column_mapping={"text_new": "text", "label_new": "label"}, ) - trainer.freeze() trainer.train() - - trainer.unfreeze(keep_body_frozen=False) - trainer.train(5) metrics = trainer.evaluate() self.assertEqual( @@ -396,7 +365,7 @@ def setUp(self): self.dataset = Dataset.from_dict( {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]} ) - self.num_iterations = 1 + self.args = TrainingArguments(num_iterations=1) def test_hyperparameter_search(self): class MyTrialShortNamer(TrialShortNamer): @@ -425,10 +394,10 @@ def model_init(params): def hp_name(trial): return MyTrialShortNamer.shortname(trial.params) - trainer = SetFitTrainer( + trainer = Trainer( + args=self.args, train_dataset=self.dataset, eval_dataset=self.dataset, - num_iterations=self.num_iterations, model_init=model_init, column_mapping={"text_new": "text", "label_new": "label"}, ) @@ -451,27 +420,26 @@ def hp_name(trial): def test_trainer_works_with_non_default_loss_class(loss_class): dataset = Dataset.from_dict({"text": ["a 1", "b 1", "c 1", "a 2", "b 2", "c 2"], "label": [0, 1, 2, 0, 1, 2]}) model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") - trainer = SetFitTrainer( + args = TrainingArguments(num_iterations=1, loss=loss_class) + trainer = Trainer( model=model, + 
args=args, train_dataset=dataset, eval_dataset=dataset, - num_iterations=1, - loss_class=loss_class, ) trainer.train() # no asserts here because this is a regression test - we only test if an exception is raised -def test_trainer_evaluate_with_strings(): +def test_trainer_evaluate_with_strings(model: SetFitModel): dataset = Dataset.from_dict( {"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]} ) - model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") - trainer = SetFitTrainer( + trainer = Trainer( model=model, + args=TrainingArguments(num_iterations=1), train_dataset=dataset, eval_dataset=dataset, - num_iterations=1, ) trainer.train() # This used to fail due to "TypeError: can't convert np.ndarray of type numpy.str_. @@ -485,13 +453,13 @@ def test_trainer_evaluate_multilabel_f1(): "sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest" ) - trainer = SetFitTrainer( + trainer = Trainer( model=model, + args=TrainingArguments(num_iterations=5), train_dataset=dataset, eval_dataset=dataset, metric="f1", metric_kwargs={"average": "micro"}, - num_iterations=5, column_mapping={"text_new": "text", "label_new": "label"}, ) @@ -502,9 +470,7 @@ def test_trainer_evaluate_multilabel_f1(): def test_trainer_evaluate_on_cpu() -> None: # This test used to fail if CUDA was available - dataset = Dataset.from_dict( - {"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]} - ) + dataset = Dataset.from_dict({"text": ["positive sentence", "negative sentence"], "label": [1, 0]}) model = SetFitModel.from_pretrained( "sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True ) @@ -513,12 +479,111 @@ def compute_metric(y_pred, y_test) -> None: assert y_pred.device == torch.device("cpu") return 1.0 - trainer = SetFitTrainer( + args = TrainingArguments(num_iterations=5) + trainer = Trainer( model=model, + args=args, train_dataset=dataset, eval_dataset=dataset, metric=compute_metric, - num_iterations=5, ) trainer.train() trainer.evaluate() + + +def test_no_model_no_model_init(): + with pytest.raises(RuntimeError, match="`Trainer` requires either a `model` or `model_init` argument."): + Trainer() + + +def test_model_and_model_init(model: SetFitModel): + def model_init() -> SetFitModel: + return model + + with pytest.raises(RuntimeError, match="`Trainer` requires either a `model` or `model_init` argument."): + Trainer(model=model, model_init=model_init) + + +def test_trainer_callbacks(model: SetFitModel): + trainer = Trainer(model=model) + assert len(trainer.callback_handler.callbacks) >= 2 + callback_names = {callback.__class__.__name__ for callback in trainer.callback_handler.callbacks} + assert {"DefaultFlowCallback", "ProgressCallback"} <= callback_names + + class TestCallback(TrainerCallback): + pass + + callback = TestCallback() + trainer.add_callback(callback) + assert len(trainer.callback_handler.callbacks) == len(callback_names) + 1 + assert trainer.callback_handler.callbacks[-1] == callback + + assert trainer.pop_callback(callback) == callback + trainer.add_callback(callback) + assert trainer.callback_handler.callbacks[-1] == callback + trainer.remove_callback(callback) + assert callback not in trainer.callback_handler.callbacks + + +def test_trainer_warn_freeze(model: SetFitModel): + trainer = Trainer(model) + with pytest.warns( + DeprecationWarning, + match="Trainer.freeze` is deprecated and will be removed in v2.0.0 of SetFit. 
" + "Please use `SetFitModel.freeze` directly instead.", + ): + trainer.freeze() + + +def test_train_with_kwargs(model: SetFitModel) -> None: + train_dataset = Dataset.from_dict({"text": ["positive sentence", "negative sentence"], "label": [1, 0]}) + trainer = Trainer(model, train_dataset=train_dataset) + with pytest.warns(DeprecationWarning, match="`Trainer.train` does not accept keyword arguments anymore."): + trainer.train(num_epochs=5) + + +def test_train_no_dataset(model: SetFitModel) -> None: + trainer = Trainer(model) + with pytest.raises(ValueError, match="Training requires a `train_dataset` given to the `Trainer` initialization."): + trainer.train() + + +def test_train_amp_save(model: SetFitModel, tmp_path: Path) -> None: + args = TrainingArguments(output_dir=tmp_path, use_amp=True, save_steps=5, num_epochs=5) + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]}) + trainer = Trainer(model, args=args, train_dataset=dataset, eval_dataset=dataset) + trainer.train() + assert trainer.evaluate() == {"accuracy": 1.0} + assert os.listdir(tmp_path) == ["step_5"] + + +def test_train_load_best(model: SetFitModel, tmp_path: Path, caplog: LogCaptureFixture) -> None: + args = TrainingArguments( + output_dir=tmp_path, + save_steps=5, + eval_steps=5, + evaluation_strategy="steps", + load_best_model_at_end=True, + num_epochs=5, + ) + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]}) + trainer = Trainer(model, args=args, train_dataset=dataset, eval_dataset=dataset) + with caplog.at_level(logging.INFO): + trainer.train() + + assert any("Load pretrained SentenceTransformer" in text for _, _, text in caplog.record_tuples) + + +def test_evaluate_with_strings(model: SetFitModel) -> None: + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": ["positive", "positive", "negative"]}) + trainer = Trainer(model, train_dataset=dataset, eval_dataset=dataset) + trainer.train() + metrics = trainer.evaluate() + assert "accuracy" in metrics + + +def test_trainer_wrong_args(model: SetFitModel, tmp_path: Path) -> None: + dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]}) + expected = "`args` must be a `TrainingArguments` instance imported from `setfit`." 
+ with pytest.raises(ValueError, match=expected): + Trainer(model, dataset) diff --git a/tests/test_trainer_distillation.py b/tests/test_trainer_distillation.py index 82dd0b05..7ddb4ba3 100644 --- a/tests/test_trainer_distillation.py +++ b/tests/test_trainer_distillation.py @@ -2,39 +2,36 @@ import pytest from datasets import Dataset -from sentence_transformers.losses import CosineSimilarityLoss -from setfit import DistillationSetFitTrainer, SetFitTrainer +from setfit import DistillationTrainer, Trainer from setfit.modeling import SetFitModel +from setfit.training_args import TrainingArguments -class DistillationSetFitTrainerTest(TestCase): +class DistillationTrainerTest(TestCase): def setUp(self): self.teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") self.student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2") - self.num_iterations = 1 + self.args = TrainingArguments(num_iterations=1) def test_trainer_works_with_default_columns(self): dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) # train a teacher model - teacher_trainer = SetFitTrainer( + teacher_trainer = Trainer( model=self.teacher_model, train_dataset=dataset, eval_dataset=dataset, - loss_class=CosineSimilarityLoss, metric="accuracy", ) # Teacher Train and evaluate teacher_trainer.train() - metrics = teacher_trainer.evaluate() teacher_model = teacher_trainer.model - student_trainer = DistillationSetFitTrainer( + student_trainer = DistillationTrainer( teacher_model=teacher_model, train_dataset=dataset, student_model=self.student_model, eval_dataset=dataset, - loss_class=CosineSimilarityLoss, metric="accuracy", ) @@ -45,37 +42,53 @@ def test_trainer_works_with_default_columns(self): self.assertEqual(metrics["accuracy"], 1.0) def test_trainer_raises_error_with_missing_label(self): - dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) - trainer = DistillationSetFitTrainer( + labeled_dataset = Dataset.from_dict( + {"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]} + ) + # train a teacher model + teacher_trainer = Trainer( + model=self.teacher_model, + train_dataset=labeled_dataset, + eval_dataset=labeled_dataset, + metric="accuracy", + args=self.args, + ) + # Teacher Train and evaluate + teacher_trainer.train() + + unlabeled_dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) + student_trainer = DistillationTrainer( teacher_model=self.teacher_model, - train_dataset=dataset, student_model=self.student_model, - eval_dataset=dataset, - num_iterations=self.num_iterations, + train_dataset=unlabeled_dataset, + eval_dataset=labeled_dataset, + args=self.args, ) - with pytest.raises(ValueError): - trainer.train() + student_trainer.train() + metrics = student_trainer.evaluate() + print("Student results: ", metrics) + self.assertEqual(metrics["accuracy"], 1.0) def test_trainer_raises_error_with_missing_text(self): dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) - trainer = DistillationSetFitTrainer( + trainer = DistillationTrainer( teacher_model=self.teacher_model, train_dataset=dataset, student_model=self.student_model, eval_dataset=dataset, - num_iterations=self.num_iterations, + args=self.args, ) with pytest.raises(ValueError): trainer.train() def test_column_mapping_with_missing_text(self): dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", 
"f"]}) - trainer = DistillationSetFitTrainer( + trainer = DistillationTrainer( teacher_model=self.teacher_model, train_dataset=dataset, student_model=self.student_model, eval_dataset=dataset, - num_iterations=self.num_iterations, + args=self.args, column_mapping={"label_new": "label"}, ) with pytest.raises(ValueError): @@ -84,12 +97,12 @@ def test_column_mapping_with_missing_text(self): def test_column_mapping_multilabel(self): dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]}) - trainer = DistillationSetFitTrainer( + trainer = DistillationTrainer( teacher_model=self.teacher_model, train_dataset=dataset, student_model=self.student_model, eval_dataset=dataset, - num_iterations=self.num_iterations, + args=self.args, column_mapping={"text_new": "text", "label_new": "label"}, ) @@ -102,22 +115,8 @@ def test_column_mapping_multilabel(self): assert formatted_dataset[1]["text"] == "b" -def train_diff(trainer: SetFitTrainer): - # Teacher Train and evaluate - trainer.freeze() # Freeze the head - trainer.train() # Train only the body - - # Unfreeze the head and unfreeze the body -> end-to-end training - trainer.unfreeze(keep_body_frozen=False) - - trainer.train(num_epochs=5) - - -def train_lr(trainer: SetFitTrainer): - trainer.train() - - -@pytest.mark.parametrize(("teacher_diff", "student_diff"), [[True, False], [True, False]]) +@pytest.mark.parametrize("teacher_diff", [True, False]) +@pytest.mark.parametrize("student_diff", [True, False]) def test_differentiable_models(teacher_diff: bool, student_diff: bool) -> None: if teacher_diff: teacher_model = SetFitModel.from_pretrained( @@ -125,34 +124,32 @@ def test_differentiable_models(teacher_diff: bool, student_diff: bool) -> None: use_differentiable_head=True, head_params={"out_features": 3}, ) - teacher_train_func = train_diff else: teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") - teacher_train_func = train_lr if student_diff: student_model = SetFitModel.from_pretrained( "sentence-transformers/paraphrase-MiniLM-L3-v2", use_differentiable_head=True, head_params={"out_features": 3}, ) - student_train_func = train_diff else: student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2") - student_train_func = train_lr dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) # train a teacher model - teacher_trainer = SetFitTrainer( + teacher_trainer = Trainer( model=teacher_model, train_dataset=dataset, eval_dataset=dataset, metric="accuracy", ) - teacher_train_func(teacher_trainer) + teacher_trainer.train() metrics = teacher_trainer.evaluate() + print("Teacher results: ", metrics) + assert metrics["accuracy"] == 1.0 teacher_model = teacher_trainer.model - student_trainer = DistillationSetFitTrainer( + student_trainer = DistillationTrainer( teacher_model=teacher_model, train_dataset=dataset, student_model=student_model, @@ -161,7 +158,7 @@ def test_differentiable_models(teacher_diff: bool, student_diff: bool) -> None: ) # Student Train and evaluate - student_train_func(student_trainer) + student_trainer.train() metrics = student_trainer.evaluate() print("Student results: ", metrics) assert metrics["accuracy"] == 1.0 diff --git a/tests/test_training_args.py b/tests/test_training_args.py new file mode 100644 index 00000000..a573e10e --- /dev/null +++ b/tests/test_training_args.py @@ -0,0 +1,113 @@ +from unittest import TestCase + +import pytest + +from setfit.training_args import 
TrainingArguments + + +class TestTrainingArguments(TestCase): + def test_raises_error_with_wrong_warmup_proportion(self): + # warmup_proportion must not be > 1.0 + with pytest.raises(ValueError): + TrainingArguments(warmup_proportion=1.1) + + # warmup_proportion must not be < 0.0 + with pytest.raises(ValueError): + TrainingArguments(warmup_proportion=-0.1) + + def test_batch_sizes(self): + batch_size_A = 12 + batch_size_B = 4 + batch_size_C = 6 + + args = TrainingArguments(batch_size=batch_size_A) + self.assertEqual(args.batch_size, (batch_size_A, batch_size_A)) + self.assertEqual(args.embedding_batch_size, batch_size_A) + self.assertEqual(args.classifier_batch_size, batch_size_A) + + args = TrainingArguments(batch_size=(batch_size_A, batch_size_B)) + self.assertEqual(args.batch_size, (batch_size_A, batch_size_B)) + self.assertEqual(args.embedding_batch_size, batch_size_A) + self.assertEqual(args.classifier_batch_size, batch_size_B) + + args = TrainingArguments(batch_size=(batch_size_A, batch_size_B), embedding_batch_size=batch_size_C) + self.assertEqual(args.batch_size, (batch_size_A, batch_size_B)) + self.assertEqual(args.embedding_batch_size, batch_size_C) + self.assertEqual(args.classifier_batch_size, batch_size_B) + + args = TrainingArguments(batch_size=batch_size_A, embedding_batch_size=batch_size_C) + self.assertEqual(args.batch_size, (batch_size_A, batch_size_A)) + self.assertEqual(args.embedding_batch_size, batch_size_C) + self.assertEqual(args.classifier_batch_size, batch_size_A) + + def test_num_epochs(self): + num_epochs_A = 12 + num_epochs_B = 4 + num_epochs_C = 6 + + args = TrainingArguments(num_epochs=num_epochs_A) + self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_A)) + self.assertEqual(args.embedding_num_epochs, num_epochs_A) + self.assertEqual(args.classifier_num_epochs, num_epochs_A) + + args = TrainingArguments(num_epochs=(num_epochs_A, num_epochs_B)) + self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_B)) + self.assertEqual(args.embedding_num_epochs, num_epochs_A) + self.assertEqual(args.classifier_num_epochs, num_epochs_B) + + args = TrainingArguments(num_epochs=(num_epochs_A, num_epochs_B), embedding_num_epochs=num_epochs_C) + self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_B)) + self.assertEqual(args.embedding_num_epochs, num_epochs_C) + self.assertEqual(args.classifier_num_epochs, num_epochs_B) + + args = TrainingArguments(num_epochs=num_epochs_A, embedding_num_epochs=num_epochs_C) + self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_A)) + self.assertEqual(args.embedding_num_epochs, num_epochs_C) + self.assertEqual(args.classifier_num_epochs, num_epochs_A) + + def test_learning_rates(self): + learning_rate_A = 1e-2 + learning_rate_B = 1e-3 + learning_rate_C = 1e-4 + + base = TrainingArguments() + + args = TrainingArguments(body_learning_rate=learning_rate_A) + self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_A)) + self.assertEqual(args.body_embedding_learning_rate, learning_rate_A) + self.assertEqual(args.body_classifier_learning_rate, learning_rate_A) + self.assertEqual(args.head_learning_rate, base.head_learning_rate) + + args = TrainingArguments(body_learning_rate=(learning_rate_A, learning_rate_B)) + self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_B)) + self.assertEqual(args.body_embedding_learning_rate, learning_rate_A) + self.assertEqual(args.body_classifier_learning_rate, learning_rate_B) + self.assertEqual(args.head_learning_rate, base.head_learning_rate) + + args = 
TrainingArguments( + body_learning_rate=(learning_rate_A, learning_rate_B), head_learning_rate=learning_rate_C + ) + self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_B)) + self.assertEqual(args.body_embedding_learning_rate, learning_rate_A) + self.assertEqual(args.body_classifier_learning_rate, learning_rate_B) + self.assertEqual(args.head_learning_rate, learning_rate_C) + + args = TrainingArguments( + body_learning_rate=learning_rate_A, + body_embedding_learning_rate=learning_rate_B, + head_learning_rate=learning_rate_C, + ) + # Perhaps not ideal, but body_learning_rate is never used directly: + self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_A)) + self.assertEqual(args.body_embedding_learning_rate, learning_rate_B) + self.assertEqual(args.body_classifier_learning_rate, learning_rate_A) + self.assertEqual(args.head_learning_rate, learning_rate_C) + + args = TrainingArguments( + body_classifier_learning_rate=learning_rate_A, + body_embedding_learning_rate=learning_rate_B, + head_learning_rate=learning_rate_C, + ) + self.assertEqual(args.body_embedding_learning_rate, learning_rate_B) + self.assertEqual(args.body_classifier_learning_rate, learning_rate_A) + self.assertEqual(args.head_learning_rate, learning_rate_C)
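+        # In short: body_learning_rate always resolves to an (embedding, classifier)
+        # tuple, while explicit body_embedding_learning_rate and
+        # body_classifier_learning_rate values take precedence over derived ones.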