diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml
index 9ced4d45..b3cdcd6b 100644
--- a/.github/workflows/quality.yml
+++ b/.github/workflows/quality.yml
@@ -5,9 +5,12 @@ on:
branches:
- main
- v*-release
+ - v*-pre
pull_request:
branches:
- main
+ - v*-pre
+ workflow_dispatch:
jobs:
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index afdcf1ec..45dccb7f 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -5,9 +5,12 @@ on:
branches:
- main
- v*-release
+ - v*-pre
pull_request:
branches:
- main
+ - v*-pre
+ workflow_dispatch:
jobs:
@@ -40,6 +43,8 @@ jobs:
run: |
python -m pip install --no-cache-dir --upgrade pip
python -m pip install --no-cache-dir ${{ matrix.requirements }}
+ python -m spacy download en_core_web_lg
+ python -m spacy download en_core_web_sm
if: steps.restore-cache.outputs.cache-hit != 'true'
- name: Install the checked-out setfit
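These spaCy models back the new ABSA tests. For local runs without the CI cache, a rough Python equivalent of the download step is sketched below; the fallback helper is an assumption for illustration, not part of this workflow:

```python
import spacy

# Hypothetical helper mirroring the CI step above: load a spaCy model,
# downloading it first if it is not already cached locally.
def ensure_spacy_model(name: str):
    try:
        return spacy.load(name)
    except OSError:
        from spacy.cli import download
        download(name)  # same effect as `python -m spacy download <name>`
        return spacy.load(name)

nlp = ensure_spacy_model("en_core_web_sm")
```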
diff --git a/.gitignore b/.gitignore
index a13745c3..6e89ff50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -149,3 +149,7 @@ scripts/tfew/run_tmux.sh
# macOS
.DS_Store
.vscode/settings.json
+
+# Common SetFit Trainer logging folders
+wandb
+runs/
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..69617566
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include src/setfit/span/model_card_template.md
\ No newline at end of file
diff --git a/README.md b/README.md
index f79a1001..4a8156e6 100644
--- a/README.md
+++ b/README.md
@@ -39,16 +39,14 @@ The examples below provide a quick overview on the various features supported in
`setfit` is integrated with the [Hugging Face Hub](https://huggingface.co/) and provides two main classes:
* `SetFitModel`: a wrapper that combines a pretrained body from `sentence_transformers` and a classification head from either [`scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) or [`SetFitHead`](https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py) (a differentiable head built upon `PyTorch` with similar APIs to `sentence_transformers`).
-* `SetFitTrainer`: a helper class that wraps the fine-tuning process of SetFit.
+* `Trainer`: a helper class that wraps the fine-tuning process of SetFit.
Here is an end-to-end example using a classification head from `scikit-learn`:
```python
from datasets import load_dataset
-from sentence_transformers.losses import CosineSimilarityLoss
-
-from setfit import SetFitModel, SetFitTrainer, sample_dataset
+from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
# Load a dataset from the Hugging Face Hub
@@ -61,17 +59,19 @@ eval_dataset = dataset["validation"]
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
-# Create trainer
-trainer = SetFitTrainer(
+args = TrainingArguments(
+ batch_size=16,
+ num_iterations=20, # The number of text pairs to generate for contrastive learning
+ num_epochs=1 # The number of epochs to use for contrastive learning
+)
+
+trainer = Trainer(
model=model,
+ args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
metric="accuracy",
- batch_size=16,
- num_iterations=20, # The number of text pairs to generate for contrastive learning
- num_epochs=1, # The number of epochs to use for contrastive learning
- column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
+ column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)
# Train and evaluate
@@ -81,7 +81,7 @@ metrics = trainer.evaluate()
# Push model to the Hub
trainer.push_to_hub("my-awesome-setfit-model")
-# Download from Hub and run inference
+# Download from Hub
model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")
# Run inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
@@ -92,9 +92,7 @@ Here is an end-to-end example using `SetFitHead`:
```python
from datasets import load_dataset
-from sentence_transformers.losses import CosineSimilarityLoss
-
-from setfit import SetFitModel, SetFitTrainer, sample_dataset
+from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
# Load a dataset from the Hugging Face Hub
@@ -103,6 +101,7 @@ dataset = load_dataset("sst2")
# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]
+num_classes = 2
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
@@ -111,36 +110,26 @@ model = SetFitModel.from_pretrained(
head_params={"out_features": num_classes},
)
-# Create trainer
-trainer = SetFitTrainer(
+args = TrainingArguments(
+ body_learning_rate=2e-5,
+ head_learning_rate=1e-2,
+ batch_size=16,
+ num_iterations=20, # The number of text pairs to generate for contrastive learning
+ num_epochs=(1, 25), # For finetuning the embeddings and training the classifier, respectively
+ l2_weight=0.0,
+ end_to_end=False, # Don't train the classifier end-to-end, i.e. only train the head
+)
+
+trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
metric="accuracy",
- batch_size=16,
- num_iterations=20, # The number of text pairs to generate for contrastive learning
- num_epochs=1, # The number of epochs to use for contrastive learning
- column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
+ column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)
# Train and evaluate
-trainer.freeze() # Freeze the head
-trainer.train() # Train only the body
-
-# Unfreeze the head and freeze the body -> head-only training
-trainer.unfreeze(keep_body_frozen=True)
-# or
-# Unfreeze the head and unfreeze the body -> end-to-end training
-trainer.unfreeze(keep_body_frozen=False)
-
-trainer.train(
- num_epochs=25, # The number of epochs to train the head or the whole model (body and head)
- batch_size=16,
- body_learning_rate=1e-5, # The body's learning rate
- learning_rate=1e-2, # The head's learning rate
- l2_weight=0.0, # Weight decay on **both** the body and head. If `None`, will use 0.01.
-)
+trainer.train()
metrics = trainer.evaluate()
# Push model to the Hub
@@ -175,7 +164,7 @@ This will initialise a multilabel classification head from `sklearn` - the follo
* `multi-output`: uses a `MultiOutputClassifier` head.
* `classifier-chain`: uses a `ClassifierChain` head.
-From here, you can instantiate a `SetFitTrainer` using the same example above, and train it as usual.
+From here, you can instantiate a `Trainer` using the same example above, and train it as usual.
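For instance, a minimal sketch of the multilabel setup described above (the strategy and model id are illustrative; `multi_target_strategy` is the `from_pretrained` argument added later in this diff):

```python
from setfit import SetFitModel

# Initialise a multilabel head by passing one of the strategies listed above;
# "one-vs-rest" wraps the logistic regression head in a OneVsRestClassifier.
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    multi_target_strategy="one-vs-rest",
)
```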
#### Example using the differentiable `SetFitHead`:
@@ -196,7 +185,6 @@ model = SetFitModel.from_pretrained(
SetFit can also be applied to scenarios where no labels are available. To do so, create a synthetic dataset of training examples:
```python
-from datasets import Dataset
from setfit import get_templated_dataset
candidate_labels = ["negative", "positive"]
@@ -206,22 +194,22 @@ train_dataset = get_templated_dataset(candidate_labels=candidate_labels, sample_
This will create examples of the form `"This sentence is {}"`, where the `{}` is filled in with one of the candidate labels. From here you can train a SetFit model as usual:
```python
-from setfit import SetFitModel, SetFitTrainer
+from setfit import SetFitModel, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
-trainer = SetFitTrainer(
+trainer = Trainer(
model=model,
train_dataset=train_dataset
)
trainer.train()
```
-We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with Bart), while being 5x faster to generate predictions with.
+We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with BART), while being 5x faster to generate predictions with.
### Running hyperparameter search
-`SetFitTrainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend:
+`Trainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend:
```bash
python -m pip install setfit[optuna]
@@ -267,23 +255,23 @@ def hp_space(trial): # Training parameters
**Note:** In practice, we found `num_iterations` to be the most important hyperparameter for the contrastive learning process.
-The next step is to instantiate a `SetFitTrainer` and call `hyperparameter_search()`:
+The next step is to instantiate a `Trainer` and call `hyperparameter_search()`:
```python
from datasets import Dataset
-from setfit import SetFitTrainer
+from setfit import Trainer
dataset = Dataset.from_dict(
- {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
- )
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+)
-trainer = SetFitTrainer(
+trainer = Trainer(
train_dataset=dataset,
eval_dataset=dataset,
model_init=model_init,
column_mapping={"text_new": "text", "label_new": "label"},
)
-best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=20)
+best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=5)
```
Finally, you can apply the hyperparameters you found to the trainer, and lock in the optimal model, before training for
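The hunk boundary cuts that sentence short; the final step it describes is sketched below, assuming `apply_hyperparameters` keeps the signature used in the SetFit docs (a hedged sketch, not part of this diff):

```python
# Apply the best trial's hyperparameters and retrain once. With
# final_model=True, the trainer is assumed to re-initialise the model
# it will train, locking in the chosen configuration.
trainer.apply_hyperparameters(best_run.hyperparameters, final_model=True)
trainer.train()
metrics = trainer.evaluate()
```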
@@ -300,9 +288,8 @@ If you have access to unlabeled data, you can use knowledge distillation to comp
```python
from datasets import load_dataset
-from sentence_transformers.losses import CosineSimilarityLoss
-
-from setfit import SetFitModel, SetFitTrainer, DistillationSetFitTrainer, sample_dataset
+from setfit import SetFitModel, Trainer, DistillationTrainer, sample_dataset
+from setfit.training_args import TrainingArguments
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
@@ -320,34 +307,37 @@ teacher_model = SetFitModel.from_pretrained(
)
# Create trainer for teacher model
-teacher_trainer = SetFitTrainer(
+teacher_trainer = Trainer(
model=teacher_model,
train_dataset=train_dataset_teacher,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
)
# Train teacher model
teacher_trainer.train()
+teacher_metrics = teacher_trainer.evaluate()
# Load small student model
student_model = SetFitModel.from_pretrained("paraphrase-MiniLM-L3-v2")
+args = TrainingArguments(
+ batch_size=16,
+ num_iterations=20,
+ num_epochs=1
+)
+
# Create trainer for knowledge distillation
-student_trainer = DistillationSetFitTrainer(
+student_trainer = DistillationTrainer(
teacher_model=teacher_model,
- train_dataset=train_dataset_student,
student_model=student_model,
+ args=args,
+ train_dataset=train_dataset_student,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
- metric="accuracy",
- batch_size=16,
- num_iterations=20,
- num_epochs=1,
)
# Train student with knowledge distillation
student_trainer.train()
+student_metrics = student_trainer.evaluate()
```
@@ -403,7 +393,8 @@ make style && make quality
## Citation
-```@misc{https://doi.org/10.48550/arxiv.2209.11055,
+```
+@misc{https://doi.org/10.48550/arxiv.2209.11055,
doi = {10.48550/ARXIV.2209.11055},
url = {https://arxiv.org/abs/2209.11055},
author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
@@ -411,5 +402,6 @@ make style && make quality
title = {Efficient Few-Shot Learning Without Prompts},
publisher = {arXiv},
year = {2022},
- copyright = {Creative Commons Attribution 4.0 International}}
+ copyright = {Creative Commons Attribution 4.0 International}
+}
```
diff --git a/docs/source/en/api/main.mdx b/docs/source/en/api/main.mdx
index ac2b77e4..a65b3db4 100644
--- a/docs/source/en/api/main.mdx
+++ b/docs/source/en/api/main.mdx
@@ -6,3 +6,7 @@
# SetFitHead
[[autodoc]] SetFitHead
+
+# AbsaModel
+
+[[autodoc]] AbsaModel
\ No newline at end of file
diff --git a/docs/source/en/api/trainer.mdx b/docs/source/en/api/trainer.mdx
index a51df833..3e3d39d1 100644
--- a/docs/source/en/api/trainer.mdx
+++ b/docs/source/en/api/trainer.mdx
@@ -1,8 +1,12 @@
-# SetFitTrainer
+# Trainer
-[[autodoc]] SetFitTrainer
+[[autodoc]] Trainer
-# DistillationSetFitTrainer
+# DistillationTrainer
-[[autodoc]] DistillationSetFitTrainer
\ No newline at end of file
+[[autodoc]] DistillationTrainer
+
+# AbsaTrainer
+
+[[autodoc]] AbsaTrainer
\ No newline at end of file
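The `AbsaModel` and `AbsaTrainer` documented above come from the new `setfit.span` package. A minimal usage sketch, assuming the ABSA classes mirror the `SetFitModel`/`Trainer` API (the base model id and dataset columns are illustrative):

```python
from setfit import AbsaModel, AbsaTrainer

# An AbsaModel pairs a spaCy-based aspect extractor with two SetFit models:
# one to filter aspect candidates and one to classify their polarity.
model = AbsaModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    spacy_model="en_core_web_lg",  # matches the model cached in CI above
)

trainer = AbsaTrainer(
    model,
    train_dataset=train_dataset,  # assumed columns: text, span, label, ordinal
)
trainer.train()

preds = model.predict(["The food was great, but the service was slow."])
```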
diff --git a/docs/source/en/quickstart.mdx b/docs/source/en/quickstart.mdx
index cc10ba5b..9e46933b 100644
--- a/docs/source/en/quickstart.mdx
+++ b/docs/source/en/quickstart.mdx
@@ -11,16 +11,14 @@ The examples below provide a quick overview on the various features supported in
`setfit` is integrated with the [Hugging Face Hub](https://huggingface.co/) and provides two main classes:
* `SetFitModel`: a wrapper that combines a pretrained body from `sentence_transformers` and a classification head from either [`scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) or [`SetFitHead`](https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py) (a differentiable head built upon `PyTorch` with similar APIs to `sentence_transformers`).
-* `SetFitTrainer`: a helper class that wraps the fine-tuning process of SetFit.
+* `Trainer`: a helper class that wraps the fine-tuning process of SetFit.
Here is an end-to-end example using a classification head from `scikit-learn`:
```python
from datasets import load_dataset
-from sentence_transformers.losses import CosineSimilarityLoss
-
-from setfit import SetFitModel, SetFitTrainer, sample_dataset
+from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
# Load a dataset from the Hugging Face Hub
@@ -33,17 +31,19 @@ eval_dataset = dataset["validation"]
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
-# Create trainer
-trainer = SetFitTrainer(
+args = TrainingArguments(
+ batch_size=16,
+ num_iterations=20, # The number of text pairs to generate for contrastive learning
+ num_epochs=1 # The number of epochs to use for contrastive learning
+)
+
+trainer = Trainer(
model=model,
+ args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
metric="accuracy",
- batch_size=16,
- num_iterations=20, # The number of text pairs to generate for contrastive learning
- num_epochs=1, # The number of epochs to use for contrastive learning
- column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
+ column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)
# Train and evaluate
@@ -53,7 +53,7 @@ metrics = trainer.evaluate()
# Push model to the Hub
trainer.push_to_hub("my-awesome-setfit-model")
-# Download from Hub and run inference
+# Download from Hub
model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")
# Run inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
@@ -64,9 +64,7 @@ Here is an end-to-end example using `SetFitHead`:
```python
from datasets import load_dataset
-from sentence_transformers.losses import CosineSimilarityLoss
-
-from setfit import SetFitModel, SetFitTrainer, sample_dataset
+from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
# Load a dataset from the Hugging Face Hub
@@ -75,6 +73,7 @@ dataset = load_dataset("sst2")
# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]
+num_classes = 2
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
@@ -83,36 +82,26 @@ model = SetFitModel.from_pretrained(
head_params={"out_features": num_classes},
)
-# Create trainer
-trainer = SetFitTrainer(
+args = TrainingArguments(
+ body_learning_rate=2e-5,
+ head_learning_rate=1e-2,
+ batch_size=16,
+ num_iterations=20, # The number of text pairs to generate for contrastive learning
+ num_epochs=(1, 25), # For finetuning the embeddings and training the classifier, respectively
+ l2_weight=0.0,
+ end_to_end=False, # Don't train the classifier end-to-end, i.e. only train the head
+)
+
+trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
metric="accuracy",
- batch_size=16,
- num_iterations=20, # The number of text pairs to generate for contrastive learning
- num_epochs=1, # The number of epochs to use for contrastive learning
- column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
+ column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)
# Train and evaluate
-trainer.freeze() # Freeze the head
-trainer.train() # Train only the body
-
-# Unfreeze the head and freeze the body -> head-only training
-trainer.unfreeze(keep_body_frozen=True)
-# or
-# Unfreeze the head and unfreeze the body -> end-to-end training
-trainer.unfreeze(keep_body_frozen=False)
-
-trainer.train(
- num_epochs=25, # The number of epochs to train the head or the whole model (body and head)
- batch_size=16,
- body_learning_rate=1e-5, # The body's learning rate
- learning_rate=1e-2, # The head's learning rate
- l2_weight=0.0, # Weight decay on **both** the body and head. If `None`, will use 0.01.
-)
+trainer.train()
metrics = trainer.evaluate()
# Push model to the Hub
@@ -147,7 +136,7 @@ This will initialise a multilabel classification head from `sklearn` - the follo
* `multi-output`: uses a `MultiOutputClassifier` head.
* `classifier-chain`: uses a `ClassifierChain` head.
-From here, you can instantiate a `SetFitTrainer` using the same example above, and train it as usual.
+From here, you can instantiate a `Trainer` using the same example above, and train it as usual.
#### Example using the differentiable `SetFitHead`:
@@ -168,7 +157,6 @@ model = SetFitModel.from_pretrained(
SetFit can also be applied to scenarios where no labels are available. To do so, create a synthetic dataset of training examples:
```python
-from datasets import Dataset
from setfit import get_templated_dataset
candidate_labels = ["negative", "positive"]
@@ -178,22 +166,22 @@ train_dataset = get_templated_dataset(candidate_labels=candidate_labels, sample_
This will create examples of the form `"This sentence is {}"`, where the `{}` is filled in with one of the candidate labels. From here you can train a SetFit model as usual:
```python
-from setfit import SetFitModel, SetFitTrainer
+from setfit import SetFitModel, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
-trainer = SetFitTrainer(
+trainer = Trainer(
model=model,
train_dataset=train_dataset
)
trainer.train()
```
-We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with Bart), while being 5x faster to generate predictions with.
+We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with BART), while being 5x faster to generate predictions with.
### Running hyperparameter search
-`SetFitTrainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend:
+`Trainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend:
```bash
python -m pip install setfit[optuna]
@@ -239,23 +227,23 @@ def hp_space(trial): # Training parameters
**Note:** In practice, we found `num_iterations` to be the most important hyperparameter for the contrastive learning process.
-The next step is to instantiate a `SetFitTrainer` and call `hyperparameter_search()`:
+The next step is to instantiate a `Trainer` and call `hyperparameter_search()`:
```python
from datasets import Dataset
-from setfit import SetFitTrainer
+from setfit import Trainer
dataset = Dataset.from_dict(
- {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
- )
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+)
-trainer = SetFitTrainer(
+trainer = Trainer(
train_dataset=dataset,
eval_dataset=dataset,
model_init=model_init,
column_mapping={"text_new": "text", "label_new": "label"},
)
-best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=20)
+best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=5)
```
Finally, you can apply the hyperparameters you found to the trainer, and lock in the optimal model, before training for
@@ -272,9 +260,8 @@ If you have access to unlabeled data, you can use knowledge distillation to comp
```python
from datasets import load_dataset
-from sentence_transformers.losses import CosineSimilarityLoss
-
-from setfit import SetFitModel, SetFitTrainer, DistillationSetFitTrainer, sample_dataset
+from setfit import SetFitModel, Trainer, DistillationTrainer, sample_dataset
+from setfit.training_args import TrainingArguments
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
@@ -292,32 +279,35 @@ teacher_model = SetFitModel.from_pretrained(
)
# Create trainer for teacher model
-teacher_trainer = SetFitTrainer(
+teacher_trainer = Trainer(
model=teacher_model,
train_dataset=train_dataset_teacher,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
)
# Train teacher model
teacher_trainer.train()
+teacher_metrics = teacher_trainer.evaluate()
# Load small student model
student_model = SetFitModel.from_pretrained("paraphrase-MiniLM-L3-v2")
+args = TrainingArguments(
+ batch_size=16,
+ num_iterations=20,
+ num_epochs=1
+)
+
# Create trainer for knowledge distillation
-student_trainer = DistillationSetFitTrainer(
+student_trainer = DistillationTrainer(
teacher_model=teacher_model,
- train_dataset=train_dataset_student,
student_model=student_model,
+ args=args,
+ train_dataset=train_dataset_student,
eval_dataset=eval_dataset,
- loss_class=CosineSimilarityLoss,
- metric="accuracy",
- batch_size=16,
- num_iterations=20,
- num_epochs=1,
)
# Train student with knowledge distillation
student_trainer.train()
+student_metrics = student_trainer.evaluate()
```
\ No newline at end of file
diff --git a/scripts/setfit/run_fewshot.py b/scripts/setfit/run_fewshot.py
index 1248fddc..08f7023e 100644
--- a/scripts/setfit/run_fewshot.py
+++ b/scripts/setfit/run_fewshot.py
@@ -59,6 +59,7 @@ def parse_args():
parser.add_argument("--override_results", default=False, action="store_true")
parser.add_argument("--keep_body_frozen", default=False, action="store_true")
parser.add_argument("--add_data_augmentation", default=False)
+ parser.add_argument("--evaluation_strategy", default=False)
args = parser.parse_args()
@@ -148,6 +149,8 @@ def main():
num_epochs=args.num_epochs,
num_iterations=args.num_iterations,
)
+ if not args.evaluation_strategy:
+ trainer.args.evaluation_strategy = "no"
if args.classifier == "pytorch":
trainer.freeze()
trainer.train()
diff --git a/setup.py b/setup.py
index dcd5a8ea..bdc32252 100644
--- a/setup.py
+++ b/setup.py
@@ -10,11 +10,18 @@
MAINTAINER_EMAIL = "lewis@huggingface.co"
INTEGRATIONS_REQUIRE = ["optuna"]
-REQUIRED_PKGS = ["datasets>=2.3.0", "sentence-transformers>=2.2.1", "evaluate>=0.3.0"]
+REQUIRED_PKGS = [
+ "datasets>=2.3.0",
+ "sentence-transformers>=2.2.1",
+ "evaluate>=0.3.0",
+ "huggingface_hub>=0.13.0",
+ "scikit-learn",
+]
+ABSA_REQUIRE = ["spacy"]
QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"]
ONNX_REQUIRE = ["onnxruntime", "onnx", "skl2onnx"]
OPENVINO_REQUIRE = ["hummingbird-ml<0.4.9", "openvino==2022.3.0"]
-TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE
+TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE + ABSA_REQUIRE
DOCS_REQUIRE = ["hf-doc-builder>=0.3.0"]
EXTRAS_REQUIRE = {
"optuna": INTEGRATIONS_REQUIRE,
@@ -23,6 +30,7 @@
"onnx": ONNX_REQUIRE,
"openvino": ONNX_REQUIRE + OPENVINO_REQUIRE,
"docs": DOCS_REQUIRE,
+ "absa": ABSA_REQUIRE,
}
@@ -46,11 +54,12 @@ def combine_requirements(base_keys):
long_description_content_type="text/markdown",
maintainer=MAINTAINER,
maintainer_email=MAINTAINER_EMAIL,
- url="https://github.com/SetFit/setfit",
- download_url="https://github.com/SetFit/setfit/tags",
+ url="https://github.com/huggingface/setfit",
+ download_url="https://github.com/huggingface/setfit/tags",
license="Apache 2.0",
package_dir={"": "src"},
packages=find_packages("src"),
+ include_package_data=True,
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS_REQUIRE,
classifiers=[
diff --git a/src/setfit/__init__.py b/src/setfit/__init__.py
index 287d89c5..f131eee0 100644
--- a/src/setfit/__init__.py
+++ b/src/setfit/__init__.py
@@ -1,6 +1,15 @@
__version__ = "0.8.0.dev0"
-from .data import add_templated_examples, get_templated_dataset, sample_dataset
+import warnings
+
+from .data import get_templated_dataset, sample_dataset
from .modeling import SetFitHead, SetFitModel
-from .trainer import SetFitTrainer
-from .trainer_distillation import DistillationSetFitTrainer
+from .span import AbsaModel, AbsaTrainer, AspectExtractor, AspectModel, PolarityModel
+from .trainer import SetFitTrainer, Trainer
+from .trainer_distillation import DistillationSetFitTrainer, DistillationTrainer
+from .training_args import TrainingArguments
+
+
+# Ensure that DeprecationWarnings are shown by default, as recommended by
+# https://docs.python.org/3/library/warnings.html#overriding-the-default-filter
+warnings.filterwarnings("default", category=DeprecationWarning)
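The effect of re-enabling this filter can be seen with a toy warning (the message below is illustrative; in practice the deprecated `SetFitTrainer` and `DistillationSetFitTrainer` aliases imported above are what emit these):

```python
import warnings

# Python hides DeprecationWarning outside of __main__ by default; restoring
# the "default" action makes each deprecated call site warn once.
warnings.filterwarnings("default", category=DeprecationWarning)

def old_api():
    warnings.warn(
        "`SetFitTrainer` is deprecated; use `Trainer` instead.",
        DeprecationWarning,
        stacklevel=2,
    )

old_api()  # -> DeprecationWarning printed to stderr
```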
diff --git a/src/setfit/data.py b/src/setfit/data.py
index ff5a0c33..2d9cd5f8 100644
--- a/src/setfit/data.py
+++ b/src/setfit/data.py
@@ -1,4 +1,3 @@
-import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import pandas as pd
@@ -21,15 +20,6 @@
SAMPLE_SIZES = [2, 4, 8, 16, 32, 64]
-def get_augmented_samples(*args, **kwargs) -> None:
- warnings.warn(
- "`get_augmented_samples` has been deprecated and will be removed in v1.0.0 of SetFit. "
- "Please use `get_templated_dataset` instead.",
- DeprecationWarning,
- stacklevel=2,
- )
-
-
def get_templated_dataset(
dataset: Optional[Dataset] = None,
candidate_labels: Optional[List[str]] = None,
@@ -115,15 +105,6 @@ def get_templated_dataset(
return dataset
-def add_templated_examples(*args, **kwargs) -> None:
- warnings.warn(
- "`add_templated_examples` has been deprecated and will be removed in v1.0.0 of SetFit. "
- "Please use `get_templated_dataset` instead.",
- DeprecationWarning,
- stacklevel=2,
- )
-
-
def get_candidate_labels(dataset_name: str, label_names_column: str = "label_text") -> List[str]:
dataset = load_dataset(dataset_name, split="train")
@@ -170,7 +151,7 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i
df = df.groupby(label_column)
# sample num_samples, or at least as much as possible
- df = df.apply(lambda x: x.sample(min(num_samples, len(x))))
+ df = df.apply(lambda x: x.sample(min(num_samples, len(x)), random_state=seed))
df = df.reset_index(drop=True)
all_samples = Dataset.from_pandas(df, features=dataset.features)
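With `random_state=seed` threaded into the pandas sampling, `sample_dataset` becomes deterministic. A quick sketch on toy data (`seed` is assumed to be a parameter of `sample_dataset`, since the fixed line references it):

```python
from datasets import Dataset
from setfit import sample_dataset

dataset = Dataset.from_dict(
    {"text": [f"sentence {i}" for i in range(20)], "label": [i % 2 for i in range(20)]}
)

# Same seed -> same few-shot sample, now that the per-class sampling is seeded.
few_shot_a = sample_dataset(dataset, label_column="label", num_samples=4, seed=42)
few_shot_b = sample_dataset(dataset, label_column="label", num_samples=4, seed=42)
assert few_shot_a["text"] == few_shot_b["text"]
```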
diff --git a/src/setfit/integrations.py b/src/setfit/integrations.py
index 94d7161e..a847d753 100644
--- a/src/setfit/integrations.py
+++ b/src/setfit/integrations.py
@@ -5,10 +5,10 @@
if TYPE_CHECKING:
- from .trainer import SetFitTrainer
+ from .trainer import Trainer
-def is_optuna_available():
+def is_optuna_available() -> bool:
return importlib.util.find_spec("optuna") is not None
@@ -17,7 +17,7 @@ def default_hp_search_backend():
return "optuna"
-def run_hp_search_optuna(trainer: "SetFitTrainer", n_trials: int, direction: str, **kwargs) -> BestRun:
+def run_hp_search_optuna(trainer: "Trainer", n_trials: int, direction: str, **kwargs) -> BestRun:
import optuna
# Heavily inspired by transformers.integrations.run_hp_search_optuna
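For context, a minimal sketch of the two callables this Optuna backend wires together, following the README example (the model id and search space are illustrative):

```python
from setfit import SetFitModel

def model_init(params):
    # Called once per trial; `params` holds the trial's sampled values.
    params = params or {}
    max_iter = params.get("max_iter", 100)
    return SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-mpnet-base-v2",
        head_params={"max_iter": max_iter},
    )

def hp_space(trial):
    return {
        "num_epochs": trial.suggest_int("num_epochs", 1, 3),
        "num_iterations": trial.suggest_categorical("num_iterations", [5, 10, 20]),
        "max_iter": trial.suggest_int("max_iter", 50, 300),
    }
```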
diff --git a/src/setfit/losses.py b/src/setfit/losses.py
new file mode 100644
index 00000000..369c8451
--- /dev/null
+++ b/src/setfit/losses.py
@@ -0,0 +1,100 @@
+import torch
+from torch import nn
+
+
+class SupConLoss(nn.Module):
+ """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+
+ It also supports the unsupervised contrastive loss in SimCLR.
+ """
+
+ def __init__(self, model, temperature=0.07, contrast_mode="all", base_temperature=0.07):
+ super(SupConLoss, self).__init__()
+ self.model = model
+ self.temperature = temperature
+ self.contrast_mode = contrast_mode
+ self.base_temperature = base_temperature
+
+ def forward(self, sentence_features, labels=None, mask=None):
+ """Computes loss for model.
+
+ If both `labels` and `mask` are None, it degenerates to SimCLR unsupervised loss:
+ https://arxiv.org/pdf/2002.05709.pdf
+
+        Args:
+            sentence_features: model inputs; the first element is encoded by
+                `self.model` into embeddings of shape [bsz, ...].
+            labels: ground truth of shape [bsz].
+            mask: contrastive mask of shape [bsz, bsz], where mask_{i,j}=1 if
+                sample j has the same class as sample i. Can be asymmetric.
+
+ Returns:
+ A loss scalar.
+ """
+ features = self.model(sentence_features[0])["sentence_embedding"]
+
+ # Normalize embeddings
+ features = torch.nn.functional.normalize(features, p=2, dim=1)
+
+ # Add n_views dimension
+ features = torch.unsqueeze(features, 1)
+
+ device = features.device
+
+ if len(features.shape) < 3:
+            raise ValueError("`features` needs to be [bsz, n_views, ...], at least 3 dimensions are required")
+ if len(features.shape) > 3:
+ features = features.view(features.shape[0], features.shape[1], -1)
+
+ batch_size = features.shape[0]
+ if labels is not None and mask is not None:
+ raise ValueError("Cannot define both `labels` and `mask`")
+ elif labels is None and mask is None:
+ mask = torch.eye(batch_size, dtype=torch.float32).to(device)
+ elif labels is not None:
+ labels = labels.contiguous().view(-1, 1)
+ if labels.shape[0] != batch_size:
+ raise ValueError("Num of labels does not match num of features")
+ mask = torch.eq(labels, labels.T).float().to(device)
+ else:
+ mask = mask.float().to(device)
+
+ contrast_count = features.shape[1]
+ contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
+ if self.contrast_mode == "one":
+ anchor_feature = features[:, 0]
+ anchor_count = 1
+ elif self.contrast_mode == "all":
+ anchor_feature = contrast_feature
+ anchor_count = contrast_count
+ else:
+ raise ValueError("Unknown mode: {}".format(self.contrast_mode))
+
+ # Compute logits
+ anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
+ # For numerical stability
+ logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+ logits = anchor_dot_contrast - logits_max.detach()
+
+ # Tile mask
+ mask = mask.repeat(anchor_count, contrast_count)
+ # Mask-out self-contrast cases
+ logits_mask = torch.scatter(
+ torch.ones_like(mask),
+ 1,
+ torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
+ 0,
+ )
+ mask = mask * logits_mask
+
+ # Compute log_prob
+ exp_logits = torch.exp(logits) * logits_mask
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+ # Compute mean of log-likelihood over positive
+ mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
+
+ # Loss
+ loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
+ loss = loss.view(anchor_count, batch_size).mean()
+
+ return loss
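As moved here from `modeling.py`, `SupConLoss` follows the sentence-transformers convention of wrapping the model it trains. A self-contained smoke test is sketched below; the dummy encoder is an assumption standing in for a real `SentenceTransformer` body:

```python
import torch
from torch import nn

from setfit.losses import SupConLoss

# Stand-in for a SentenceTransformer body: maps a features dict to embeddings.
class DummyEncoder(nn.Module):
    def forward(self, features):
        return {"sentence_embedding": torch.randn(features["batch"], 8)}

loss_fn = SupConLoss(model=DummyEncoder())
labels = torch.tensor([0, 0, 1, 1])  # two samples per class
loss = loss_fn([{"batch": 4}], labels=labels)
print(loss)  # a scalar; lower when same-class embeddings are closer
```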
diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py
index f971d952..793b2c72 100644
--- a/src/setfit/modeling.py
+++ b/src/setfit/modeling.py
@@ -1,8 +1,9 @@
import os
import tempfile
+import warnings
from dataclasses import dataclass
from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
# Google Colab runs on Python 3.7, so we need this to be compatible
@@ -15,23 +16,20 @@
import numpy as np
import requests
import torch
-import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
-from sentence_transformers import InputExample, SentenceTransformer, models
+from huggingface_hub.utils import validate_hf_hub_args
+from sentence_transformers import SentenceTransformer, models
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
+from torch import nn
from torch.utils.data import DataLoader
-from tqdm.auto import trange
+from tqdm.auto import tqdm, trange
from . import logging
from .data import SetFitDataset
-if TYPE_CHECKING:
- from numpy import ndarray
-
-
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
@@ -77,28 +75,19 @@
```bibtex
@article{{https://doi.org/10.48550/arxiv.2209.11055,
-doi = {{10.48550/ARXIV.2209.11055}},
-url = {{https://arxiv.org/abs/2209.11055}},
-author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}},
-keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}},
-title = {{Efficient Few-Shot Learning Without Prompts}},
-publisher = {{arXiv}},
-year = {{2022}},
-copyright = {{Creative Commons Attribution 4.0 International}}
+ doi = {{10.48550/ARXIV.2209.11055}},
+ url = {{https://arxiv.org/abs/2209.11055}},
+ author = {{Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}},
+ keywords = {{Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences}},
+ title = {{Efficient Few-Shot Learning Without Prompts}},
+ publisher = {{arXiv}},
+ year = {{2022}},
+ copyright = {{Creative Commons Attribution 4.0 International}}
}}
```
"""
-class SetFitBaseModel:
- def __init__(self, model, max_seq_length: int, add_normalization_layer: bool) -> None:
- self.model = SentenceTransformer(model)
- self.model.max_seq_length = max_seq_length
-
- if add_normalization_layer:
- self.model._modules["2"] = models.Normalize()
-
-
class SetFitHead(models.Dense):
"""
A SetFit head that supports multi-class classification for end-to-end training.
@@ -217,7 +206,7 @@ def predict(self, x_test: torch.Tensor) -> torch.Tensor:
return torch.where(probs >= 0.5, 1, 0)
return torch.argmax(probs, dim=-1)
- def get_loss_fn(self):
+ def get_loss_fn(self) -> nn.Module:
if self.multitarget: # if sigmoid output
return torch.nn.BCEWithLogitsLoss()
return torch.nn.CrossEntropyLoss()
@@ -241,13 +230,13 @@ def get_config_dict(self) -> Dict[str, Optional[Union[int, float, bool]]]:
}
@staticmethod
- def _init_weight(module):
+ def _init_weight(module) -> None:
if isinstance(module, nn.Linear):
- torch.nn.init.xavier_uniform_(module.weight)
+ nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
- torch.nn.init.constant_(module.bias, 1e-2)
+ nn.init.constant_(module.bias, 1e-2)
- def __repr__(self):
+ def __repr__(self) -> str:
return "SetFitHead({})".format(self.get_config_dict())
@@ -255,22 +244,10 @@ def __repr__(self):
class SetFitModel(PyTorchModelHubMixin):
"""A SetFit model with integration to the Hugging Face Hub."""
- def __init__(
- self,
- model_body: Optional[SentenceTransformer] = None,
- model_head: Optional[Union[SetFitHead, LogisticRegression]] = None,
- multi_target_strategy: Optional[str] = None,
- l2_weight: float = 1e-2,
- normalize_embeddings: bool = False,
- ) -> None:
- super(SetFitModel, self).__init__()
- self.model_body = model_body
- self.model_head = model_head
-
- self.multi_target_strategy = multi_target_strategy
- self.l2_weight = l2_weight
-
- self.normalize_embeddings = normalize_embeddings
+    model_body: Optional[SentenceTransformer] = None
+ model_head: Optional[Union[SetFitHead, LogisticRegression]] = None
+ multi_target_strategy: Optional[str] = None
+ normalize_embeddings: bool = False
@property
def has_differentiable_head(self) -> bool:
@@ -283,23 +260,46 @@ def fit(
y_train: Union[List[int], List[List[int]]],
num_epochs: int,
batch_size: Optional[int] = None,
- learning_rate: Optional[float] = None,
body_learning_rate: Optional[float] = None,
+ head_learning_rate: Optional[float] = None,
+ end_to_end: bool = False,
l2_weight: Optional[float] = None,
max_length: Optional[int] = None,
- show_progress_bar: Optional[bool] = None,
+ show_progress_bar: bool = True,
) -> None:
+ """Train the classifier head, only used if a differentiable PyTorch head is used.
+
+ Args:
+ x_train (`List[str]`): A list of training sentences.
+ y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
+ num_epochs (`int`): The number of epochs to train for.
+ batch_size (`int`, *optional*): The batch size to use.
+ body_learning_rate (`float`, *optional*): The learning rate for the `SentenceTransformer` body
+ in the `AdamW` optimizer. Disregarded if `end_to_end=False`.
+ head_learning_rate (`float`, *optional*): The learning rate for the differentiable torch head
+ in the `AdamW` optimizer.
+ end_to_end (`bool`, defaults to `False`): If True, train the entire model end-to-end.
+ Otherwise, freeze the `SentenceTransformer` body and only train the head.
+ l2_weight (`float`, *optional*): The l2 weight for both the model body and head
+ in the `AdamW` optimizer.
+ max_length (`int`, *optional*): The maximum token length a tokenizer can generate. If not provided,
+ the maximum length for the `SentenceTransformer` body is used.
+ show_progress_bar (`bool`, defaults to `True`): Whether to display a progress bar for the training
+ epochs and iterations.
+ """
if self.has_differentiable_head: # train with pyTorch
device = self.model_body.device
self.model_body.train()
self.model_head.train()
+ if not end_to_end:
+ self.freeze("body")
dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length)
criterion = self.model_head.get_loss_fn()
- optimizer = self._prepare_optimizer(learning_rate, body_learning_rate, l2_weight)
+ optimizer = self._prepare_optimizer(head_learning_rate, body_learning_rate, l2_weight)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
for epoch_idx in trange(num_epochs, desc="Epoch", disable=not show_progress_bar):
- for batch in dataloader:
+ for batch in tqdm(dataloader, desc="Iteration", disable=not show_progress_bar, leave=False):
features, labels = batch
optimizer.zero_grad()
@@ -309,15 +309,18 @@ def fit(
outputs = self.model_body(features)
if self.normalize_embeddings:
- outputs = torch.nn.functional.normalize(outputs, p=2, dim=1)
+ outputs = nn.functional.normalize(outputs, p=2, dim=1)
outputs = self.model_head(outputs)
logits = outputs["logits"]
- loss = criterion(logits, labels)
+ loss: torch.Tensor = criterion(logits, labels)
loss.backward()
optimizer.step()
scheduler.step()
+
+ if not end_to_end:
+ self.unfreeze("body")
else: # train with sklearn
embeddings = self.model_body.encode(x_train, normalize_embeddings=self.normalize_embeddings)
self.model_head.fit(embeddings, y_train)
@@ -364,12 +367,12 @@ def _prepare_dataloader(
def _prepare_optimizer(
self,
- learning_rate: float,
+ head_learning_rate: float,
body_learning_rate: Optional[float],
l2_weight: float,
) -> torch.optim.Optimizer:
- body_learning_rate = body_learning_rate or learning_rate
- l2_weight = l2_weight or self.l2_weight
+ body_learning_rate = body_learning_rate or head_learning_rate
+ l2_weight = l2_weight or 1e-2
optimizer = torch.optim.AdamW(
[
{
@@ -377,37 +380,79 @@ def _prepare_optimizer(
"lr": body_learning_rate,
"weight_decay": l2_weight,
},
- {
- "params": self.model_head.parameters(),
- "lr": learning_rate,
- "weight_decay": l2_weight,
- },
+ {"params": self.model_head.parameters(), "lr": head_learning_rate, "weight_decay": l2_weight},
],
)
return optimizer
def freeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
+ """Freeze the model body and/or the head, preventing further training on that component until unfrozen.
+
+ Args:
+ component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to freeze that component.
+ If no component is provided, freeze both. Defaults to None.
+ """
if component is None or component == "body":
self._freeze_or_not(self.model_body, to_freeze=True)
- if component is None or component == "head":
+ if (component is None or component == "head") and self.has_differentiable_head:
self._freeze_or_not(self.model_head, to_freeze=True)
- def unfreeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
+ def unfreeze(
+ self, component: Optional[Literal["body", "head"]] = None, keep_body_frozen: Optional[bool] = None
+ ) -> None:
+ """Unfreeze the model body and/or the head, allowing further training on that component.
+
+ Args:
+ component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to unfreeze that component.
+ If no component is provided, unfreeze both. Defaults to None.
+ keep_body_frozen (`bool`, *optional*): Deprecated argument, use `component` instead.
+ """
+ if keep_body_frozen is not None:
+ warnings.warn(
+ "`keep_body_frozen` is deprecated and will be removed in v2.0.0 of SetFit. "
+ 'Please either pass "head", "body" or no arguments to unfreeze both.',
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ # If the body must stay frozen, only unfreeze the head. Eventually, this entire if-branch
+ # can be removed.
+ if keep_body_frozen and not component:
+ component = "head"
+
if component is None or component == "body":
self._freeze_or_not(self.model_body, to_freeze=False)
- if component is None or component == "head":
+ if (component is None or component == "head") and self.has_differentiable_head:
self._freeze_or_not(self.model_head, to_freeze=False)
- def _freeze_or_not(self, model: torch.nn.Module, to_freeze: bool) -> None:
+ def _freeze_or_not(self, model: nn.Module, to_freeze: bool) -> None:
+ """Set `requires_grad=not to_freeze` for all parameters in `model`"""
for param in model.parameters():
param.requires_grad = not to_freeze
+ def encode(self, inputs: List[str], show_progress_bar: Optional[bool] = None) -> Union[torch.Tensor, np.ndarray]:
+ """Convert input sentences to embeddings using the `SentenceTransformer` body.
+
+ Args:
+ inputs (`List[str]`): The input sentences to embed.
+ show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding.
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]: A matrix with shape [INPUT_LENGTH, EMBEDDING_SIZE], as a
+ torch Tensor if this model has a differentiable Torch head, or otherwise as a numpy array.
+ """
+ return self.model_body.encode(
+ inputs,
+ normalize_embeddings=self.normalize_embeddings,
+ convert_to_tensor=self.has_differentiable_head,
+ show_progress_bar=show_progress_bar,
+ )
+
def _output_type_conversion(
- self, outputs: Union[torch.Tensor, "ndarray"], as_numpy: bool = False
- ) -> Union[torch.Tensor, "ndarray"]:
+ self, outputs: Union[torch.Tensor, np.ndarray], as_numpy: bool = False
+ ) -> Union[torch.Tensor, np.ndarray]:
"""Return `outputs` in the desired type:
* Numpy array if no differentiable head is used.
* Torch tensor if a differentiable head is used.
@@ -427,37 +472,74 @@ def _output_type_conversion(
return outputs
def predict(
- self, x_test: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None
- ) -> Union[torch.Tensor, "ndarray"]:
- embeddings = self.model_body.encode(
- x_test,
- normalize_embeddings=self.normalize_embeddings,
- convert_to_tensor=self.has_differentiable_head,
- show_progress_bar=show_progress_bar,
- )
+ self, inputs: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None
+ ) -> Union[torch.Tensor, np.ndarray]:
+ """Predict the various classes.
+
+ Args:
+ inputs (`List[str]`): The input sentences to predict classes for.
+ as_numpy (`bool`, defaults to `False`): Whether to output as numpy array instead.
+ show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding.
+
+ Example:
+ >>> model = SetFitModel.from_pretrained(...)
+ >>> model.predict(["What a boring display", "Exhilarating through and through", "I'm wowed!"])
+            tensor([0, 1, 1], dtype=torch.int32)
+
+        Returns:
+ `Union[torch.Tensor, np.ndarray]`: A vector with equal length to the inputs, denoting
+ to which class each input is predicted to belong.
+ """
+ embeddings = self.encode(inputs, show_progress_bar=show_progress_bar)
outputs = self.model_head.predict(embeddings)
return self._output_type_conversion(outputs, as_numpy=as_numpy)
def predict_proba(
- self, x_test: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None
- ) -> Union[torch.Tensor, "ndarray"]:
- embeddings = self.model_body.encode(
- x_test,
- normalize_embeddings=self.normalize_embeddings,
- convert_to_tensor=self.has_differentiable_head,
- show_progress_bar=show_progress_bar,
- )
+ self, inputs: List[str], as_numpy: bool = False, show_progress_bar: Optional[bool] = None
+ ) -> Union[torch.Tensor, np.ndarray]:
+ """Predict the probabilities of the various classes.
+ Args:
+ inputs (`List[str]`): The input sentences to predict class probabilities for.
+ as_numpy (`bool`, defaults to `False`): Whether to output as numpy array instead.
+ show_progress_bar (`Optional[bool]`, defaults to `None`): Whether to show a progress bar while encoding.
+
+ Example:
+ >>> model = SetFitModel.from_pretrained(...)
+ >>> model.predict_proba(["What a boring display", "Exhilarating through and through", "I'm wowed!"])
+ tensor([[0.9367, 0.0633],
+ [0.0627, 0.9373],
+ [0.0890, 0.9110]], dtype=torch.float64)
+
+ Returns:
+ `Union[torch.Tensor, np.ndarray]`: A matrix with shape [INPUT_LENGTH, NUM_CLASSES] denoting
+ probabilities of predicting an input as a class.
+ """
+ embeddings = self.encode(inputs, show_progress_bar=show_progress_bar)
outputs = self.model_head.predict_proba(embeddings)
return self._output_type_conversion(outputs, as_numpy=as_numpy)
+ @property
+ def device(self) -> torch.device:
+ """Get the Torch device that this model is on.
+
+ Returns:
+ torch.device: The device that the model is on.
+ """
+ return self.model_body.device
+
def to(self, device: Union[str, torch.device]) -> "SetFitModel":
"""Move this SetFitModel to `device`, and then return `self`. This method does not copy.
Args:
device (Union[str, torch.device]): The identifier of the device to move the model to.
+ Example:
+
+ >>> model = SetFitModel.from_pretrained(...)
+ >>> model.to("cpu")
+ >>> model(["cats are cute", "dogs are loyal"])
+
Returns:
SetFitModel: Returns the original model, but now on the desired device.
"""
@@ -492,7 +574,21 @@ def create_model_card(self, path: str, model_name: Optional[str] = "SetFit Model
with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
f.write(model_card_content)
- def __call__(self, inputs):
+ def __call__(self, inputs: List[str]) -> torch.Tensor:
+ """Predict the various classes.
+
+ Args:
+ inputs (`List[str]`): The input sentences to predict classes for.
+
+ Example:
+ >>> model = SetFitModel.from_pretrained(...)
+ >>> model(["What a boring display", "Exhilarating through and through", "I'm wowed!"])
+ tensor([0, 1, 1], dtype=torch.int32)
+
+ Returns:
+ `torch.Tensor`: A vector with equal length to the inputs, denoting to which class each
+ input is predicted to belong.
+ """
return self.predict(inputs)
def _save_pretrained(self, save_directory: Union[Path, str]) -> None:
@@ -502,6 +598,7 @@ def _save_pretrained(self, save_directory: Union[Path, str]) -> None:
joblib.dump(self.model_head, str(Path(save_directory) / MODEL_HEAD_NAME))
@classmethod
+ @validate_hf_hub_args
def _from_pretrained(
cls,
model_id: str,
@@ -511,13 +608,13 @@ def _from_pretrained(
proxies: Optional[Dict] = None,
resume_download: Optional[bool] = None,
local_files_only: Optional[bool] = None,
- use_auth_token: Optional[Union[bool, str]] = None,
+ token: Optional[Union[bool, str]] = None,
multi_target_strategy: Optional[str] = None,
use_differentiable_head: bool = False,
normalize_embeddings: bool = False,
**model_kwargs,
) -> "SetFitModel":
- model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=use_auth_token)
+ model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token)
target_device = model_body._target_device
model_body.to(target_device) # put `model_body` on the target device
@@ -541,7 +638,7 @@ def _from_pretrained(
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
- use_auth_token=use_auth_token,
+ token=token,
local_files_only=local_files_only,
)
except requests.exceptions.RequestException:
@@ -554,7 +651,7 @@ def _from_pretrained(
if model_head_file is not None:
model_head = joblib.load(model_head_file)
else:
- head_params = model_kwargs.get("head_params", {})
+ head_params = model_kwargs.pop("head_params", {})
if use_differentiable_head:
if multi_target_strategy is None:
use_multitarget = False
@@ -590,207 +687,12 @@ def _from_pretrained(
else:
model_head = clf
+ # Remove the `transformers` config
+ model_kwargs.pop("config", None)
return cls(
model_body=model_body,
model_head=model_head,
multi_target_strategy=multi_target_strategy,
normalize_embeddings=normalize_embeddings,
+ **model_kwargs,
)
-
-
-class SupConLoss(nn.Module):
- """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
-
- It also supports the unsupervised contrastive loss in SimCLR.
- """
-
- def __init__(self, model, temperature=0.07, contrast_mode="all", base_temperature=0.07):
- super(SupConLoss, self).__init__()
- self.model = model
- self.temperature = temperature
- self.contrast_mode = contrast_mode
- self.base_temperature = base_temperature
-
- def forward(self, sentence_features, labels=None, mask=None):
- """Computes loss for model.
-
- If both `labels` and `mask` are None, it degenerates to SimCLR unsupervised loss:
- https://arxiv.org/pdf/2002.05709.pdf
-
- Args:
- features: hidden vector of shape [bsz, n_views, ...].
- labels: ground truth of shape [bsz].
- mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
- has the same class as sample i. Can be asymmetric.
-
- Returns:
- A loss scalar.
- """
- features = self.model(sentence_features[0])["sentence_embedding"]
-
- # Normalize embeddings
- features = torch.nn.functional.normalize(features, p=2, dim=1)
-
- # Add n_views dimension
- features = torch.unsqueeze(features, 1)
-
- device = features.device
-
- if len(features.shape) < 3:
- raise ValueError("`features` needs to be [bsz, n_views, ...]," "at least 3 dimensions are required")
- if len(features.shape) > 3:
- features = features.view(features.shape[0], features.shape[1], -1)
-
- batch_size = features.shape[0]
- if labels is not None and mask is not None:
- raise ValueError("Cannot define both `labels` and `mask`")
- elif labels is None and mask is None:
- mask = torch.eye(batch_size, dtype=torch.float32).to(device)
- elif labels is not None:
- labels = labels.contiguous().view(-1, 1)
- if labels.shape[0] != batch_size:
- raise ValueError("Num of labels does not match num of features")
- mask = torch.eq(labels, labels.T).float().to(device)
- else:
- mask = mask.float().to(device)
-
- contrast_count = features.shape[1]
- contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
- if self.contrast_mode == "one":
- anchor_feature = features[:, 0]
- anchor_count = 1
- elif self.contrast_mode == "all":
- anchor_feature = contrast_feature
- anchor_count = contrast_count
- else:
- raise ValueError("Unknown mode: {}".format(self.contrast_mode))
-
- # Compute logits
- anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
- # For numerical stability
- logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
- logits = anchor_dot_contrast - logits_max.detach()
-
- # Tile mask
- mask = mask.repeat(anchor_count, contrast_count)
- # Mask-out self-contrast cases
- logits_mask = torch.scatter(
- torch.ones_like(mask),
- 1,
- torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
- 0,
- )
- mask = mask * logits_mask
-
- # Compute log_prob
- exp_logits = torch.exp(logits) * logits_mask
- log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
-
- # Compute mean of log-likelihood over positive
- mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
-
- # Loss
- loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
- loss = loss.view(anchor_count, batch_size).mean()
-
- return loss
-
-
-def sentence_pairs_generation(sentences, labels, pairs):
- # Initialize two empty lists to hold the (sentence, sentence) pairs and
- # labels to indicate if a pair is positive or negative
-
- num_classes = np.unique(labels)
- label_to_idx = {x: i for i, x in enumerate(num_classes)}
- positive_idxs = [np.where(labels == i)[0] for i in num_classes]
- negative_idxs = [np.where(labels != i)[0] for i in num_classes]
-
- for first_idx in range(len(sentences)):
- current_sentence = sentences[first_idx]
- label = labels[first_idx]
- second_idx = np.random.choice(positive_idxs[label_to_idx[label]])
- positive_sentence = sentences[second_idx]
- # Prepare a positive pair and update the sentences and labels
- # lists, respectively
- pairs.append(InputExample(texts=[current_sentence, positive_sentence], label=1.0))
-
- third_idx = np.random.choice(negative_idxs[label_to_idx[label]])
- negative_sentence = sentences[third_idx]
- # Prepare a negative pair of sentences and update our lists
- pairs.append(InputExample(texts=[current_sentence, negative_sentence], label=0.0))
- # Return a 2-tuple of our sentence pairs and labels
- return pairs
-
-
-def sentence_pairs_generation_multilabel(sentences, labels, pairs):
- # Initialize two empty lists to hold the (sentence, sentence) pairs and
- # labels to indicate if a pair is positive or negative
- for first_idx in range(len(sentences)):
- current_sentence = sentences[first_idx]
- sample_labels = np.where(labels[first_idx, :] == 1)[0]
- if len(np.where(labels.dot(labels[first_idx, :].T) == 0)[0]) == 0:
- continue
- else:
- for _label in sample_labels:
- second_idx = np.random.choice(np.where(labels[:, _label] == 1)[0])
- positive_sentence = sentences[second_idx]
- # Prepare a positive pair and update the sentences and labels
- # lists, respectively
- pairs.append(InputExample(texts=[current_sentence, positive_sentence], label=1.0))
-
- # Search for sample that don't have a label in common with current
- # sentence
- negative_idx = np.where(labels.dot(labels[first_idx, :].T) == 0)[0]
- negative_sentence = sentences[np.random.choice(negative_idx)]
- # Prepare a negative pair of sentences and update our lists
- pairs.append(InputExample(texts=[current_sentence, negative_sentence], label=0.0))
- # Return a 2-tuple of our sentence pairs and labels
- return pairs
-
-
-def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix):
- # initialize two empty lists to hold the (sentence, sentence) pairs and
- # labels to indicate if a pair is positive or negative
-
- idx = list(range(len(sentences)))
-
- for first_idx in range(len(sentences)):
- current_sentence = sentences[first_idx]
- second_idx = int(np.random.choice([x for x in idx if x != first_idx]))
-
- cos_sim = float(cos_sim_matrix[first_idx][second_idx])
- paired_sentence = sentences[second_idx]
- pairs.append(InputExample(texts=[current_sentence, paired_sentence], label=cos_sim))
-
- third_idx = np.random.choice([x for x in idx if x != first_idx])
- cos_sim = float(cos_sim_matrix[first_idx][third_idx])
- paired_sentence = sentences[third_idx]
- pairs.append(InputExample(texts=[current_sentence, paired_sentence], label=cos_sim))
-
- return pairs
-
-
-class SKLearnWrapper:
- def __init__(self, st_model=None, clf=None):
- self.st_model = st_model
- self.clf = clf
-
- def fit(self, x_train, y_train):
- embeddings = self.st_model.encode(x_train)
- self.clf.fit(embeddings, y_train)
-
- def predict(self, x_test):
- embeddings = self.st_model.encode(x_test)
- return self.clf.predict(embeddings)
-
- def predict_proba(self, x_test):
- embeddings = self.st_model.encode(x_test)
- return self.clf.predict_proba(embeddings)
-
- def save(self, path):
- self.st_model.save(path=path)
- joblib.dump(self.clf, f"{path}/setfit_head.pkl")
-
- def load(self, path):
- self.st_model = SentenceTransformer(model_name_or_path=path)
- self.clf = joblib.load(f"{path}/setfit_head.pkl")
diff --git a/src/setfit/pipeline.py b/src/setfit/pipeline.py
deleted file mode 100644
index 51e551ff..00000000
--- a/src/setfit/pipeline.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from .modeling import SKLearnWrapper
-
-
-class SetFitPipeline:
- def __init__(self, model_name_or_path) -> None:
- base_model = SKLearnWrapper()
- base_model.load(model_name_or_path)
- self.model = base_model
-
- def __call__(self, inputs, *args, **kwargs):
- model_outputs = self.model.predict(inputs)
- return model_outputs
diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py
new file mode 100644
index 00000000..0eceba1e
--- /dev/null
+++ b/src/setfit/sampler.py
@@ -0,0 +1,156 @@
+from itertools import zip_longest
+from typing import Generator, List, Optional, Sequence
+
+import numpy as np
+import torch
+from sentence_transformers import InputExample
+from torch.utils.data import IterableDataset
+
+from . import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def shuffle_combinations(iterable: Sequence, replacement: bool = True) -> Generator:
+ """Generates shuffled pair combinations for any iterable data provided.
+
+ Args:
+ iterable: data to generate pair combinations from
+ replacement: if True, also include pairs of a sample with itself,
+ equivalent to itertools.combinations_with_replacement
+
+ Returns:
+ Generator of shuffled pairs as a tuple
+ """
+ n = len(iterable)
+ k = 1 if not replacement else 0
+ idxs = np.stack(np.triu_indices(n, k), axis=-1)
+ for i in np.random.RandomState(seed=42).permutation(len(idxs)):
+ _idx, idx = idxs[i, :]
+ yield iterable[_idx], iterable[idx]
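+
+# Illustrative sketch (not part of the public API):
+#
+#     >>> pairs = list(shuffle_combinations(["a", "b", "c"]))
+#     >>> len(pairs)  # 3 unique pairs plus 3 self-pairs, as replacement=True
+#     6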
+
+
+class ContrastiveDataset(IterableDataset):
+ def __init__(
+ self,
+ examples: List[InputExample],
+ multilabel: bool,
+ num_iterations: Optional[int] = None,
+ sampling_strategy: str = "oversampling",
+ ) -> None:
+ """Generates positive and negative text pairs for contrastive learning.
+
+ Args:
+ examples (List[InputExample]): texts and labels in sentence-transformers `InputExample` dataclasses
+ multilabel: set to True to process multi-label (e.g. one-hot encoded) label arrays
+ sampling_strategy: "unique", "oversampling", or "undersampling"
+ num_iterations: if provided, explicitly sets the number of pairs to generate,
+ where n_pairs = n_iterations * n_sentences * 2 (for positive & negative pairs)
+ """
+ super().__init__()
+ self.pos_index = 0
+ self.neg_index = 0
+ self.pos_pairs = []
+ self.neg_pairs = []
+ self.sentences = np.array([s.texts[0] for s in examples])
+ self.labels = np.array([s.label for s in examples])
+ self.sentence_labels = list(zip(self.sentences, self.labels))
+
+ if multilabel:
+ self.generate_multilabel_pairs()
+ else:
+ self.generate_pairs()
+
+ if num_iterations is not None and num_iterations > 0:
+ self.len_pos_pairs = num_iterations * len(self.sentences)
+ self.len_neg_pairs = num_iterations * len(self.sentences)
+
+ elif sampling_strategy == "unique":
+ self.len_pos_pairs = len(self.pos_pairs)
+ self.len_neg_pairs = len(self.neg_pairs)
+
+ elif sampling_strategy == "undersampling":
+ self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
+ self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
+
+ elif sampling_strategy == "oversampling":
+ self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
+ self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
+
+ else:
+ raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")
+
+ def generate_pairs(self) -> None:
+ for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
+ if _label == label:
+ self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
+ else:
+ self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
+
+ def generate_multilabel_pairs(self) -> None:
+ for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
+ if any(np.logical_and(_label, label)):
+ # logical_and checks if labels are both set for each class
+ self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
+ else:
+ self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
+
+ def get_positive_pairs(self) -> List[InputExample]:
+ pairs = []
+ for _ in range(self.len_pos_pairs):
+ if self.pos_index >= len(self.pos_pairs):
+ self.pos_index = 0
+ pairs.append(self.pos_pairs[self.pos_index])
+ self.pos_index += 1
+ return pairs
+
+ def get_negative_pairs(self) -> List[InputExample]:
+ pairs = []
+ for _ in range(self.len_neg_pairs):
+ if self.neg_index >= len(self.neg_pairs):
+ self.neg_index = 0
+ pairs.append(self.neg_pairs[self.neg_index])
+ self.neg_index += 1
+ return pairs
+
+ def __iter__(self):
+ for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
+ if pos_pair is not None:
+ yield pos_pair
+ if neg_pair is not None:
+ yield neg_pair
+
+ def __len__(self) -> int:
+ return self.len_pos_pairs + self.len_neg_pairs
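+
+
+# Illustrative sketch, mirroring how `Trainer.get_dataloader` builds its
+# examples: with sentences ["good", "bad", "great"] and labels [1, 0, 1],
+# pair generation yields 4 positive pairs (self-pairs included) and 2 negative
+# pairs, so the default "oversampling" strategy gives len(dataset) == 8.
+#
+#     examples = [InputExample(texts=[text], label=label)
+#                 for text, label in [("good", 1), ("bad", 0), ("great", 1)]]
+#     dataset = ContrastiveDataset(examples, multilabel=False)
+#     assert len(dataset) == 8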
+
+
+class ContrastiveDistillationDataset(ContrastiveDataset):
+ def __init__(
+ self,
+ examples: List[InputExample],
+ cos_sim_matrix: torch.Tensor,
+ num_iterations: Optional[int] = None,
+ sampling_strategy: str = "oversampling",
+ ) -> None:
+ self.cos_sim_matrix = cos_sim_matrix
+ super().__init__(
+ examples,
+ multilabel=False,
+ num_iterations=num_iterations,
+ sampling_strategy=sampling_strategy,
+ )
+ # Internally we store all pairs in pos_pairs, regardless of sampling strategy.
+ # After all, without labels, there isn't much of a strategy.
+ # Note that the parent constructor already called `generate_pairs` with
+ # (sentence, label) tuples, while the cosine similarity matrix must be indexed
+ # by sentence position; rebuild the tuples as (sentence, index) and regenerate.
+ self.sentence_labels = list(zip(self.sentences, range(len(self.sentences))))
+ self.pos_pairs = []
+ self.generate_pairs()
+
+ self.len_neg_pairs = 0
+ if num_iterations is not None and num_iterations > 0:
+ self.len_pos_pairs = num_iterations * len(self.sentences)
+ else:
+ self.len_pos_pairs = len(self.pos_pairs)
+
+ def generate_pairs(self) -> None:
+ for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
+ self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
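+
+# Note: each distillation pair (sentence_i, sentence_j) is labelled with the
+# teacher's similarity cos_sim_matrix[i][j], so the student model body learns
+# to reproduce the teacher's embedding-space similarities.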
diff --git a/src/setfit/span/__init__.py b/src/setfit/span/__init__.py
new file mode 100644
index 00000000..7fc6f9db
--- /dev/null
+++ b/src/setfit/span/__init__.py
@@ -0,0 +1,3 @@
+from .aspect_extractor import AspectExtractor
+from .modeling import AbsaModel, AspectModel, PolarityModel
+from .trainer import AbsaTrainer
diff --git a/src/setfit/span/aspect_extractor.py b/src/setfit/span/aspect_extractor.py
new file mode 100644
index 00000000..096b9bb6
--- /dev/null
+++ b/src/setfit/span/aspect_extractor.py
@@ -0,0 +1,34 @@
+from typing import TYPE_CHECKING, List, Tuple
+
+
+if TYPE_CHECKING:
+ from spacy.tokens import Doc
+
+
+class AspectExtractor:
+ def __init__(self, spacy_model: str) -> None:
+ super().__init__()
+ import spacy
+
+ self.nlp = spacy.load(spacy_model)
+
+ def find_groups(self, aspect_mask: List[bool]):
+ start = None
+ for idx, flag in enumerate(aspect_mask):
+ if flag:
+ if start is None:
+ start = idx
+ else:
+ if start is not None:
+ yield slice(start, idx)
+ start = None
+ if start is not None:
+ yield slice(start, idx + 1)
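+
+ # Illustrative: find_groups([True, True, False, True]) yields slice(0, 2)
+ # and slice(3, 4), i.e. one slice per contiguous run of aspect tokens.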
+
+ def __call__(self, texts: List[str]) -> Tuple[List["Doc"], List[slice]]:
+ aspects_list = []
+ docs = list(self.nlp.pipe(texts))
+ for doc in docs:
+ aspect_mask = [token.pos_ in ("NOUN", "PROPN") for token in doc]
+ aspects_list.append(list(self.find_groups(aspect_mask)))
+ return docs, aspects_list
diff --git a/src/setfit/span/model_card_template.md b/src/setfit/span/model_card_template.md
new file mode 100644
index 00000000..31ec618f
--- /dev/null
+++ b/src/setfit/span/model_card_template.md
@@ -0,0 +1,64 @@
+---
+license: apache-2.0
+tags:
+- setfit
+- sentence-transformers
+- absa
+- token-classification
+pipeline_tag: token-classification
+---
+
+# {{ model_name | default("SetFit ABSA Model", true) }}
+
+This is a [SetFit ABSA model](https://github.com/huggingface/setfit) that can be used for Aspect Based Sentiment Analysis (ABSA). \
+In particular, this model is in charge of {{ "filtering aspect span candidates" if is_aspect else "classifying aspect polarities" }}.
+It has been trained using SetFit, an efficient few-shot learning technique that involves:
+
+1. Fine-tuning a [Sentence Transformer](https://www.sbert.net) with contrastive learning.
+2. Training a classification head with features from the fine-tuned Sentence Transformer.
+
+This model was trained within the context of a larger system for ABSA, which looks like so:
+
+1. Use a spaCy model to select possible aspect span candidates.
+2. {{ "**" if is_aspect else "" }}Use {{ "this" if is_aspect else "a" }} SetFit model to filter these possible aspect span candidates.{{ "**" if is_aspect else "" }}
+3. {{ "**" if not is_aspect else "" }}Use {{ "this" if not is_aspect else "a" }} SetFit model to classify the filtered aspect span candidates.{{ "**" if not is_aspect else "" }}
+
+## Usage
+
+To use this model for inference, first install the SetFit library:
+
+```bash
+pip install setfit
+```
+
+You can then run inference as follows:
+
+```python
+from setfit import AbsaModel
+
+# Download from Hub and run inference
+model = AbsaModel.from_pretrained(
+ "{{ aspect_model }}",
+ "{{ polarity_model }}",
+)
+# Run inference
+preds = model([
+ "The best pizza outside of Italy and really tasty.",
+ "The food here is great but the service is terrible",
+])
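+# `preds` is a list with one entry per input sentence; each entry is a list of
+# {"span": ..., "polarity": ...} dicts, e.g. [{"span": "pizza", "polarity": "positive"}]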
+```
+
+## BibTeX entry and citation info
+
+```bibtex
+@article{https://doi.org/10.48550/arxiv.2209.11055,
+ doi = {10.48550/ARXIV.2209.11055},
+ url = {https://arxiv.org/abs/2209.11055},
+ author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
+ keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
+ title = {Efficient Few-Shot Learning Without Prompts},
+ publisher = {arXiv},
+ year = {2022},
+ copyright = {Creative Commons Attribution 4.0 International}
+}
+```
\ No newline at end of file
diff --git a/src/setfit/span/modeling.py b/src/setfit/span/modeling.py
new file mode 100644
index 00000000..f25a72c1
--- /dev/null
+++ b/src/setfit/span/modeling.py
@@ -0,0 +1,292 @@
+import json
+import os
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import SoftTemporaryDirectory, validate_hf_hub_args
+from jinja2 import Environment, FileSystemLoader
+
+from .. import logging
+from ..modeling import SetFitModel
+from .aspect_extractor import AspectExtractor
+
+
+if TYPE_CHECKING:
+ from spacy.tokens import Doc
+
+logger = logging.get_logger(__name__)
+
+CONFIG_NAME = "config_span_setfit.json"
+
+
+@dataclass
+class SpanSetFitModel(SetFitModel):
+ span_context: int = 0
+
+ def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> Iterator[str]:
+ for doc, aspects in zip(docs, aspects_list):
+ for aspect_slice in aspects:
+ aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context]
+ # TODO: Investigate performance difference of different formats
+ yield aspect.text + ":" + doc.text
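+ # Illustrative: for the doc "The pizza was great." with the aspect span
+ # "pizza" and span_context == 0, this yields "pizza:The pizza was great."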
+
+ def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[List[bool]]:
+ inputs_list = list(self.prepend_aspects(docs, aspects_list))
+ preds = self.predict(inputs_list, as_numpy=True)
+ iter_preds = iter(preds)
+ return [[next(iter_preds) for _ in aspects] for aspects in aspects_list]
+
+ @classmethod
+ @validate_hf_hub_args
+ def _from_pretrained(
+ cls,
+ model_id: str,
+ span_context: Optional[int] = None,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ force_download: Optional[bool] = None,
+ proxies: Optional[Dict] = None,
+ resume_download: Optional[bool] = None,
+ local_files_only: Optional[bool] = None,
+ token: Optional[Union[bool, str]] = None,
+ **model_kwargs,
+ ) -> "SpanSetFitModel":
+ config_file: Optional[str] = None
+ if os.path.isdir(model_id):
+ if CONFIG_NAME in os.listdir(model_id):
+ config_file = os.path.join(model_id, CONFIG_NAME)
+ else:
+ try:
+ config_file = hf_hub_download(
+ repo_id=model_id,
+ filename=CONFIG_NAME,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ token=token,
+ local_files_only=local_files_only,
+ )
+ except requests.exceptions.RequestException:
+ pass
+
+ if config_file is not None:
+ with open(config_file, "r", encoding="utf-8") as f:
+ config = json.load(f)
+ model_kwargs.update(config)
+
+ if span_context is not None:
+ model_kwargs["span_context"] = span_context
+
+ return super(SpanSetFitModel, cls)._from_pretrained(
+ model_id=model_id,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ **model_kwargs,
+ )
+
+ def _save_pretrained(self, save_directory: Union[Path, str]) -> None:
+ path = os.path.join(save_directory, CONFIG_NAME)
+ with open(path, "w") as f:
+ json.dump({"span_context": self.span_context}, f, indent=2)
+
+ super()._save_pretrained(save_directory)
+
+ def create_model_card(self, path: str, model_name: Optional[str] = None) -> None:
+ """Creates and saves a model card for a SetFit model.
+
+ Args:
+ path (str): The path to save the model card to.
+ model_name (`str`, *optional*): The name of the model. If not provided, the model card template falls back to `SetFit ABSA Model`.
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ # If the model_path is a folder that exists locally, i.e. when create_model_card is called
+ # via push_to_hub, and the path is in a temporary folder, then we only take the last two
+ # directories
+ if model_name is not None:
+ model_path = Path(model_name)
+ if model_path.exists() and Path(tempfile.gettempdir()) in model_path.resolve().parents:
+ model_name = "/".join(model_path.parts[-2:])
+
+ environment = Environment(loader=FileSystemLoader(Path(__file__).parent))
+ template = environment.get_template("model_card_template.md")
+ is_aspect = isinstance(self, AspectModel)
+ aspect_model = "setfit-absa-aspect"
+ polarity_model = "setfit-absa-polarity"
+ if model_name is not None:
+ if is_aspect:
+ aspect_model = model_name
+ if model_name.endswith("-aspect"):
+ polarity_model = model_name[: -len("-aspect")] + "-polarity"
+ else:
+ polarity_model = model_name
+ if model_name.endswith("-polarity"):
+ aspect_model = model_name[: -len("-polarity")] + "-aspect"
+
+ model_card_content = template.render(
+ model_name=model_name, is_aspect=is_aspect, aspect_model=aspect_model, polarity_model=polarity_model
+ )
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
+ f.write(model_card_content)
+
+
+class AspectModel(SpanSetFitModel):
+ # TODO: Assumes binary SetFitModel with 0 == no aspect, 1 == aspect
+ def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[List[slice]]:
+ sentence_preds = super().__call__(docs, aspects_list)
+ return [
+ [aspect for aspect, pred in zip(aspects, preds) if pred == 1]
+ for aspects, preds in zip(aspects_list, sentence_preds)
+ ]
+
+
+@dataclass
+class PolarityModel(SpanSetFitModel):
+ span_context: int = 3
+
+
+@dataclass
+class AbsaModel:
+ aspect_extractor: AspectExtractor
+ aspect_model: AspectModel
+ polarity_model: PolarityModel
+
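+ # Illustrative: predict("The pizza was great.") might return
+ # [{"span": "pizza", "polarity": "positive"}], while a list of sentences
+ # returns one such list of dicts per sentence.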
+ def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
+ is_str = isinstance(inputs, str)
+ inputs_list = [inputs] if is_str else inputs
+ docs, aspects_list = self.aspect_extractor(inputs_list)
+ if sum(aspects_list, []) == []:
+ return aspects_list
+
+ aspects_list = self.aspect_model(docs, aspects_list)
+ if sum(aspects_list, []) == []:
+ return aspects_list
+
+ polarity_list = self.polarity_model(docs, aspects_list)
+ outputs = []
+ for doc, aspects, polarities in zip(docs, aspects_list, polarity_list):
+ outputs.append(
+ [
+ {"span": doc[aspect_slice].text, "polarity": polarity}
+ for aspect_slice, polarity in zip(aspects, polarities)
+ ]
+ )
+ return outputs if not is_str else outputs[0]
+
+ @property
+ def device(self) -> torch.device:
+ return self.aspect_model.device
+
+ def to(self, device: Union[str, torch.device]) -> "AbsaModel":
+ self.aspect_model.to(device)
+ self.polarity_model.to(device)
+ return self
+
+ def __call__(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
+ return self.predict(inputs)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, Path],
+ polarity_save_directory: Optional[Union[str, Path]] = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ) -> None:
+ if polarity_save_directory is None:
+ base_save_directory = Path(save_directory)
+ save_directory = base_save_directory.parent / (base_save_directory.name + "-aspect")
+ polarity_save_directory = base_save_directory.parent / (base_save_directory.name + "-polarity")
+ self.aspect_model.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+ self.polarity_model.save_pretrained(save_directory=polarity_save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_id: str,
+ polarity_model_id: Optional[str] = None,
+ spacy_model: Optional[str] = "en_core_web_lg",
+ span_contexts: Tuple[Optional[int], Optional[int]] = (None, None),
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict] = None,
+ token: Optional[Union[str, bool]] = None,
+ cache_dir: Optional[str] = None,
+ local_files_only: bool = False,
+ use_differentiable_head: bool = False,
+ normalize_embeddings: bool = False,
+ **model_kwargs,
+ ) -> "AbsaModel":
+ revision = None
+ if len(model_id.split("@")) == 2:
+ model_id, revision = model_id.split("@")
+ aspect_model = AspectModel.from_pretrained(
+ model_id,
+ span_context=span_contexts[0],
+ revision=revision,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ use_differentiable_head=use_differentiable_head,
+ normalize_embeddings=normalize_embeddings,
+ **model_kwargs,
+ )
+ if polarity_model_id:
+ model_id = polarity_model_id
+ revision = None
+ if len(model_id.split("@")) == 2:
+ model_id, revision = model_id.split("@")
+ polarity_model = PolarityModel.from_pretrained(
+ model_id,
+ span_context=span_contexts[1],
+ revision=revision,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ use_differentiable_head=use_differentiable_head,
+ normalize_embeddings=normalize_embeddings,
+ **model_kwargs,
+ )
+
+ aspect_extractor = AspectExtractor(spacy_model=spacy_model)
+
+ return cls(aspect_extractor, aspect_model, polarity_model)
+
+ def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
+ if "/" not in repo_id:
+ raise ValueError(
+ '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
+ )
+ if polarity_repo_id is not None and "/" not in polarity_repo_id:
+ raise ValueError(
+ '`polarity_repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
+ )
+ commit_message = kwargs.pop("commit_message", "Add SetFit ABSA model")
+
+ # Push the files to the repo in a single commit
+ with SoftTemporaryDirectory() as tmp_dir:
+ save_directory = Path(tmp_dir) / repo_id
+ polarity_save_directory = None if polarity_repo_id is None else Path(tmp_dir) / polarity_repo_id
+ self.save_pretrained(
+ save_directory=save_directory,
+ polarity_save_directory=polarity_save_directory,
+ push_to_hub=True,
+ commit_message=commit_message,
+ **kwargs,
+ )
diff --git a/src/setfit/span/trainer.py b/src/setfit/span/trainer.py
new file mode 100644
index 00000000..1d362616
--- /dev/null
+++ b/src/setfit/span/trainer.py
@@ -0,0 +1,324 @@
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+from datasets import Dataset
+from transformers.trainer_callback import TrainerCallback
+
+from setfit.span.modeling import AbsaModel, AspectModel, PolarityModel
+from setfit.training_args import TrainingArguments
+
+from .. import logging
+from ..trainer import ColumnMappingMixin, Trainer
+
+
+if TYPE_CHECKING:
+ import optuna
+
+logger = logging.get_logger(__name__)
+
+
+class AbsaTrainer(ColumnMappingMixin):
+ """Trainer to train a SetFit ABSA model.
+
+ Args:
+ model (`AbsaModel`):
+ The AbsaModel model to train.
+ args (`TrainingArguments`, *optional*):
+ The training arguments to use. If `polarity_args` is not defined, then `args` is used for both
+ the aspect and the polarity model.
+ polarity_args (`TrainingArguments`, *optional*):
+ The training arguments to use for the polarity model. If not defined, `args` is used for both
+ the aspect and the polarity model.
+ train_dataset (`Dataset`):
+ The training dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
+ eval_dataset (`Dataset`, *optional*):
+ The evaluation dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
+ metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
+ The metric to use for evaluation. If a string is provided, we treat it as the metric
+ name and load it with default settings.
+ If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
+ metric_kwargs (`Dict[str, Any]`, *optional*):
+ Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
+ This is useful, for example, for providing an averaging strategy for computing f1 in a multi-label setting.
+ callbacks: (`List[~transformers.TrainerCallback]`, *optional*):
+ A list of callbacks to customize the training loop. These are added to the list of default callbacks
+ detailed [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback).
+ If you want to remove one of the default callbacks used, use the `Trainer.remove_callback()` method.
+ column_mapping (`Dict[str, str]`, *optional*):
+ A mapping from the column names in the dataset to the column names expected by the model.
+ The expected format is a dictionary with the following format:
+ `{"text_column_name": "text", "span_column_name": "span", "label_column_name: "label", "ordinal_column_name": "ordinal"}`.
+ """
+
+ _REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"}
+
+ def __init__(
+ self,
+ model: AbsaModel,
+ args: Optional[TrainingArguments] = None,
+ polarity_args: Optional[TrainingArguments] = None,
+ train_dataset: Optional["Dataset"] = None,
+ eval_dataset: Optional["Dataset"] = None,
+ metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+ metric_kwargs: Optional[Dict[str, Any]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ column_mapping: Optional[Dict[str, str]] = None,
+ ) -> None:
+ self.model = model
+ self.aspect_extractor = model.aspect_extractor
+
+ if train_dataset is not None and column_mapping:
+ train_dataset = self._apply_column_mapping(train_dataset, column_mapping)
+ aspect_train_dataset, polarity_train_dataset = self.preprocess_dataset(
+ model.aspect_model, model.polarity_model, train_dataset
+ )
+ if eval_dataset is not None and column_mapping:
+ eval_dataset = self._apply_column_mapping(eval_dataset, column_mapping)
+ aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset(
+ model.aspect_model, model.polarity_model, eval_dataset
+ )
+
+ self.aspect_trainer = Trainer(
+ model.aspect_model,
+ args=args,
+ train_dataset=aspect_train_dataset,
+ eval_dataset=aspect_eval_dataset,
+ metric=metric,
+ metric_kwargs=metric_kwargs,
+ callbacks=callbacks,
+ )
+ self.aspect_trainer._set_logs_mapper(
+ {
+ "eval_embedding_loss": "eval_aspect_embedding_loss",
+ "embedding_loss": "aspect_embedding_loss",
+ }
+ )
+ self.polarity_trainer = Trainer(
+ model.polarity_model,
+ args=polarity_args or args,
+ train_dataset=polarity_train_dataset,
+ eval_dataset=polarity_eval_dataset,
+ metric=metric,
+ metric_kwargs=metric_kwargs,
+ callbacks=callbacks,
+ )
+ self.polarity_trainer._set_logs_mapper(
+ {
+ "eval_embedding_loss": "eval_polarity_embedding_loss",
+ "embedding_loss": "polarity_embedding_loss",
+ }
+ )
+
+ def preprocess_dataset(
+ self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Dataset
+ ) -> Tuple[Dataset, Dataset]:
+ if dataset is None:
+ return dataset, dataset
+
+ # Group by "text"
+ grouped_data = defaultdict(list)
+ for sample in dataset:
+ text = sample.pop("text")
+ grouped_data[text].append(sample)
+
+ def index_ordinal(text: str, target: str, ordinal: int) -> Tuple[int, int]:
+ find_from = 0
+ for _ in range(ordinal + 1):
+ start_idx = text.index(target, find_from)
+ find_from = start_idx + 1
+ return start_idx, start_idx + len(target)
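+
+ # e.g. index_ordinal("a b a", "a", ordinal=1) == (4, 5), the bounds of the
+ # second occurrence of "a" in the text.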
+
+ docs, aspects_list = self.aspect_extractor(grouped_data.keys())
+ intersected_aspect_list = []
+ polarity_labels = []
+ aspect_labels = []
+ for doc, aspects, text in zip(docs, aspects_list, grouped_data):
+ gold_aspects = []
+ gold_polarity_labels = []
+ for annotation in grouped_data[text]:
+ try:
+ start, end = index_ordinal(text, annotation["span"], annotation["ordinal"])
+ except ValueError:
+ logger.info(
+ f"The ordinal of {annotation['ordinal']} for span {annotation['span']!r} in {text!r} is too high. "
+ "Skipping this sample."
+ )
+ continue
+
+ gold_aspect_span = doc.char_span(start, end)
+ if gold_aspect_span is None:
+ continue
+ gold_aspects.append(slice(gold_aspect_span.start, gold_aspect_span.end))
+ gold_polarity_labels.append(annotation["label"])
+
+ # The Aspect model uses all predicted aspects, with labels depending on whether
+ # the predicted aspects are indeed true/gold aspects.
+ aspect_labels.extend([aspect in gold_aspects for aspect in aspects])
+
+ # The Polarity model uses the intersection of pred and gold aspects, with labels for the gold label.
+ intersected_aspects = []
+ for gold_aspect, gold_label in zip(gold_aspects, gold_polarity_labels):
+ if gold_aspect in aspects:
+ intersected_aspects.append(gold_aspect)
+ polarity_labels.append(gold_label)
+ intersected_aspect_list.append(intersected_aspects)
+
+ aspect_texts = list(aspect_model.prepend_aspects(docs, aspects_list))
+ polarity_texts = list(polarity_model.prepend_aspects(docs, intersected_aspect_list))
+ return Dataset.from_dict({"text": aspect_texts, "label": aspect_labels}), Dataset.from_dict(
+ {"text": polarity_texts, "label": polarity_labels}
+ )
+
+ def train(
+ self,
+ args: Optional[TrainingArguments] = None,
+ polarity_args: Optional[TrainingArguments] = None,
+ trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Main training entry point.
+
+ Args:
+ args (`TrainingArguments`, *optional*):
+ Temporarily change the aspect training arguments for this training call.
+ polarity_args (`TrainingArguments`, *optional*):
+ Temporarily change the polarity training arguments for this training call.
+ trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+ """
+ self.train_aspect(args=args, trial=trial, **kwargs)
+ self.train_polarity(args=polarity_args, trial=trial, **kwargs)
+
+ def train_aspect(
+ self,
+ args: Optional[TrainingArguments] = None,
+ trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Train the aspect model only.
+
+ Args:
+ args (`TrainingArguments`, *optional*):
+ Temporarily change the aspect training arguments for this training call.
+ trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+ """
+ self.aspect_trainer.train(args=args, trial=trial, **kwargs)
+
+ def train_polarity(
+ self,
+ args: Optional[TrainingArguments] = None,
+ trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Train the polarity model only.
+
+ Args:
+ args (`TrainingArguments`, *optional*):
+ Temporarily change the polarity training arguments for this training call.
+ trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+ """
+ self.polarity_trainer.train(args=args, trial=trial, **kwargs)
+
+ def add_callback(self, callback):
+ """
+ Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will instantiate a member of that class.
+ """
+ self.aspect_trainer.add_callback(callback)
+ self.polarity_trainer.add_callback(callback)
+
+ def pop_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.
+
+ If the callback is not found, returns `None` (and no error is raised).
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will pop the first member of that class found in the list of callbacks.
+
+ Returns:
+ `Tuple[~transformers.TrainerCallback, ~transformers.TrainerCallback]`: The callbacks removed from the aspect and polarity trainers, if found.
+ """
+ return self.aspect_trainer.pop_callback(callback), self.polarity_trainer.pop_callback(callback)
+
+ def remove_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will remove the first member of that class found in the list of callbacks.
+ """
+ self.aspect_trainer.remove_callback(callback)
+ self.polarity_trainer.remove_callback(callback)
+
+ def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
+ """Upload model checkpoint to the Hub using `huggingface_hub`.
+
+ See the full list of parameters for your `huggingface_hub` version in the\
+ [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub).
+
+ Args:
+ repo_id (`str`):
+ The full repository ID to push to, e.g. `"tomaarsen/setfit-aspect"`.
+ polarity_repo_id (`str`, *optional*):
+ The full repository ID to push the polarity model to, e.g. `"tomaarsen/setfit-polarity"`.
+ config (`dict`, *optional*):
+ Configuration object to be saved alongside the model weights.
+ commit_message (`str`, *optional*):
+ Message to commit while pushing.
+ private (`bool`, *optional*, defaults to `False`):
+ Whether the repository created should be private.
+ api_endpoint (`str`, *optional*):
+ The API endpoint to use when pushing the model to the hub.
+ token (`str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files.
+ If not set, will use the token set when logging in with
+ `transformers-cli login` (stored in `~/.huggingface`).
+ branch (`str`, *optional*):
+ The git branch on which to push the model. This defaults to
+ the default branch as specified in your repository, which
+ defaults to `"main"`.
+ create_pr (`boolean`, *optional*):
+ Whether or not to create a Pull Request from `branch` with that commit.
+ Defaults to `False`.
+ allow_patterns (`List[str]` or `str`, *optional*):
+ If provided, only files matching at least one pattern are pushed.
+ ignore_patterns (`List[str]` or `str`, *optional*):
+ If provided, files matching any of the patterns are not pushed.
+ """
+ return self.model.push_to_hub(repo_id=repo_id, polarity_repo_id=polarity_repo_id, **kwargs)
+
+ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, Dict[str, float]]:
+ """
+ Computes the metrics for a given classifier.
+
+ Args:
+ dataset (`Dataset`, *optional*):
+ The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via
+ the `eval_dataset` argument at `AbsaTrainer` initialization.
+
+ Returns:
+ `Dict[str, Dict[str, float]]`: The evaluation metrics.
+ """
+ aspect_eval_dataset = polarity_eval_dataset = None
+ if dataset:
+ aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset(
+ self.model.aspect_model, self.model.polarity_model, dataset
+ )
+ return {
+ "aspect": self.aspect_trainer.evaluate(aspect_eval_dataset),
+ "polarity": self.polarity_trainer.evaluate(polarity_eval_dataset),
+ }
diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py
index 6304ce5b..baecd053 100644
--- a/src/setfit/trainer.py
+++ b/src/setfit/trainer.py
@@ -1,23 +1,56 @@
import math
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+import shutil
+import time
+import warnings
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import evaluate
-import numpy as np
import torch
from datasets import Dataset, DatasetDict
-from sentence_transformers import InputExample, losses
+from sentence_transformers import InputExample, SentenceTransformer, losses
from sentence_transformers.datasets import SentenceLabelDataset
from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
+from sentence_transformers.util import batch_to_device
+from sklearn.preprocessing import LabelEncoder
+from torch import nn
+from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
-from tqdm.auto import trange
-from transformers.trainer_utils import HPSearchBackend, default_compute_objective, number_of_arguments, set_seed
+from tqdm.autonotebook import tqdm
+from transformers.integrations import get_reporting_integration_callbacks
+from transformers.trainer_callback import (
+ CallbackHandler,
+ DefaultFlowCallback,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+)
+from transformers.trainer_utils import (
+ HPSearchBackend,
+ default_compute_objective,
+ number_of_arguments,
+ set_seed,
+ speed_metrics,
+)
+from transformers.utils.import_utils import is_in_notebook
from . import logging
from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna
-from .modeling import SupConLoss, sentence_pairs_generation, sentence_pairs_generation_multilabel
+from .losses import SupConLoss
+from .sampler import ContrastiveDataset
+from .training_args import TrainingArguments
from .utils import BestRun, default_hp_space_optuna
+# typing.Literal requires Python 3.8+; Google Colab still runs Python 3.7, so fall back to typing_extensions.
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
if TYPE_CHECKING:
import optuna
@@ -27,123 +60,24 @@
logger = logging.get_logger(__name__)
-class SetFitTrainer:
- """Trainer to train a SetFit model.
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
- Args:
- model (`SetFitModel`, *optional*):
- The model to train. If not provided, a `model_init` must be passed.
- train_dataset (`Dataset`):
- The training dataset.
- eval_dataset (`Dataset`, *optional*):
- The evaluation dataset.
- model_init (`Callable[[], SetFitModel]`, *optional*):
- A function that instantiates the model to be used. If provided, each call to [`~SetFitTrainer.train`] will start
- from a new instance of the model as given by this function when a `trial` is passed.
- metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
- The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings.
- If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
- metric_kwargs (`Dict[str, Any]`, *optional*):
- Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
- For example useful for providing an averaging strategy for computing f1 in a multi-label setting.
- loss_class (`nn.Module`, *optional*, defaults to `CosineSimilarityLoss`):
- The loss function to use for contrastive training.
- num_iterations (`int`, *optional*, defaults to `20`):
- The number of iterations to generate sentence pairs for.
- This argument is ignored if triplet loss is used.
- It is only used in conjunction with `CosineSimilarityLoss`.
- num_epochs (`int`, *optional*, defaults to `1`):
- The number of epochs to train the Sentence Transformer body for.
- learning_rate (`float`, *optional*, defaults to `2e-5`):
- The learning rate to use for contrastive training.
- batch_size (`int`, *optional*, defaults to `16`):
- The batch size to use for contrastive training.
- seed (`int`, *optional*, defaults to 42):
- Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
- [`~SetTrainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.
- column_mapping (`Dict[str, str]`, *optional*):
- A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: {"text_column_name": "text", "label_column_name: "label"}.
- use_amp (`bool`, *optional*, defaults to `False`):
- Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
- warmup_proportion (`float`, *optional*, defaults to `0.1`):
- Proportion of the warmup in the total training steps.
- Must be greater than or equal to 0.0 and less than or equal to 1.0.
- distance_metric (`Callable`, defaults to `BatchHardTripletLossDistanceFunction.cosine_distance`):
- Function that returns a distance between two embeddings.
- It is set for the triplet loss and
- is ignored for `CosineSimilarityLoss` and `SupConLoss`.
- margin (`float`, defaults to `0.25`): Margin for the triplet loss.
- Negative samples should be at least margin further apart from the anchor than the positive.
- This is ignored for `CosineSimilarityLoss`, `BatchHardSoftMarginTripletLoss` and `SupConLoss`.
- samples_per_label (`int`, defaults to `2`): Number of consecutive, random and unique samples drawn per label.
- This is only relevant for triplet loss and ignored for `CosineSimilarityLoss`.
- Batch size should be a multiple of samples_per_label.
- """
+if is_in_notebook():
+ from transformers.utils.notebook import NotebookProgressCallback
- def __init__(
- self,
- model: Optional["SetFitModel"] = None,
- train_dataset: Optional["Dataset"] = None,
- eval_dataset: Optional["Dataset"] = None,
- model_init: Optional[Callable[[], "SetFitModel"]] = None,
- metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
- metric_kwargs: Optional[Dict[str, Any]] = None,
- loss_class=losses.CosineSimilarityLoss,
- num_iterations: int = 20,
- num_epochs: int = 1,
- learning_rate: float = 2e-5,
- batch_size: int = 16,
- seed: int = 42,
- column_mapping: Optional[Dict[str, str]] = None,
- use_amp: bool = False,
- warmup_proportion: float = 0.1,
- distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance,
- margin: float = 0.25,
- samples_per_label: int = 2,
- ) -> None:
- if (warmup_proportion < 0.0) or (warmup_proportion > 1.0):
- raise ValueError(
- f"warmup_proportion must be greater than or equal to 0.0 and less than or equal to 1.0! But it was: {warmup_proportion}"
- )
+ DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
- self.train_dataset = train_dataset
- self.eval_dataset = eval_dataset
- self.model_init = model_init
- self.metric = metric
- self.metric_kwargs = metric_kwargs
- self.loss_class = loss_class
- self.num_iterations = num_iterations
- self.num_epochs = num_epochs
- self.learning_rate = learning_rate
- self.batch_size = batch_size
- self.seed = seed
- self.column_mapping = column_mapping
- self.use_amp = use_amp
- self.warmup_proportion = warmup_proportion
- self.distance_metric = distance_metric
- self.margin = margin
- self.samples_per_label = samples_per_label
- if model is None:
- if model_init is not None:
- model = self.call_model_init()
- else:
- raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument")
- else:
- if model_init is not None:
- raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument, but not both")
-
- self.model = model
- self.hp_search_backend = None
- self._freeze = True # If True, will train the body only; otherwise, train the body and head
+class ColumnMappingMixin:
+ _REQUIRED_COLUMNS = {"text", "label"}
def _validate_column_mapping(self, dataset: "Dataset") -> None:
"""
Validates the provided column mapping against the dataset.
"""
- required_columns = {"text", "label"}
column_names = set(dataset.column_names)
- if self.column_mapping is None and not required_columns.issubset(column_names):
+ if self.column_mapping is None and not self._REQUIRED_COLUMNS.issubset(column_names):
# Issue #226: load_dataset will automatically assign points to "train" if no split is specified
if column_names == {"train"} and isinstance(dataset, DatasetDict):
raise ValueError(
@@ -157,12 +91,12 @@ def _validate_column_mapping(self, dataset: "Dataset") -> None:
)
else:
raise ValueError(
- f"SetFit expected the dataset to have the columns {sorted(required_columns)}, "
+ f"SetFit expected the dataset to have the columns {sorted(self._REQUIRED_COLUMNS)}, "
f"but only the columns {sorted(column_names)} were found. "
- "Either make sure these columns are present, or specify which columns to use with column_mapping in SetFitTrainer."
+ "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer."
)
if self.column_mapping is not None:
- missing_columns = required_columns.difference(self.column_mapping.values())
+ missing_columns = self._REQUIRED_COLUMNS.difference(self.column_mapping.values())
if missing_columns:
raise ValueError(
f"The following columns are missing from the column mapping: {missing_columns}. Please provide a mapping for all required columns."
@@ -181,7 +115,11 @@ def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, st
dataset = dataset.rename_columns(
{
**column_mapping,
- **{col: f"feat_{col}" for col in dataset.column_names if col not in column_mapping},
+ **{
+ col: f"feat_{col}"
+ for col in dataset.column_names
+ if col not in column_mapping and col not in self._REQUIRED_COLUMNS
+ },
}
)
dset_format = dataset.format
@@ -193,6 +131,128 @@ def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, st
)
return dataset
+
+class Trainer(ColumnMappingMixin):
+ """Trainer to train a SetFit model.
+
+ Args:
+ model (`SetFitModel`, *optional*):
+ The model to train. If not provided, a `model_init` must be passed.
+ args (`TrainingArguments`, *optional*):
+ The training arguments to use.
+ train_dataset (`Dataset`):
+ The training dataset.
+ eval_dataset (`Dataset`, *optional*):
+ The evaluation dataset.
+ model_init (`Callable[[], SetFitModel]`, *optional*):
+ A function that instantiates the model to be used. If provided, each call to
+ [`~Trainer.train`] will start from a new instance of the model as given by this
+ function when a `trial` is passed.
+ metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
+ The metric to use for evaluation. If a string is provided, we treat it as the metric
+ name and load it with default settings.
+ If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
+ metric_kwargs (`Dict[str, Any]`, *optional*):
+ Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
+ This is useful, for example, for providing an averaging strategy for computing f1 in a multi-label setting.
+ callbacks: (`List[~transformers.TrainerCallback]`, *optional*):
+ A list of callbacks to customize the training loop. These are added to the list of default callbacks
+ detailed [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback).
+ If you want to remove one of the default callbacks used, use the `Trainer.remove_callback()` method.
+ column_mapping (`Dict[str, str]`, *optional*):
+ A mapping from the column names in the dataset to the column names expected by the model.
+ The expected format is a dictionary with the following format:
+ `{"text_column_name": "text", "label_column_name: "label"}`.
+ """
+
+ def __init__(
+ self,
+ model: Optional["SetFitModel"] = None,
+ args: Optional[TrainingArguments] = None,
+ train_dataset: Optional["Dataset"] = None,
+ eval_dataset: Optional["Dataset"] = None,
+ model_init: Optional[Callable[[], "SetFitModel"]] = None,
+ metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+ metric_kwargs: Optional[Dict[str, Any]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ column_mapping: Optional[Dict[str, str]] = None,
+ ) -> None:
+ if args is not None and not isinstance(args, TrainingArguments):
+ raise ValueError("`args` must be a `TrainingArguments` instance imported from `setfit`.")
+ self.args = args or TrainingArguments()
+ self.train_dataset = train_dataset
+ self.eval_dataset = eval_dataset
+ self.model_init = model_init
+ self.metric = metric
+ self.metric_kwargs = metric_kwargs
+ self.column_mapping = column_mapping
+ self.logs_mapper = {}
+
+ # Seed must be set before instantiating the model when using model_init.
+ set_seed(self.args.seed)
+
+ if model is None:
+ if model_init is not None:
+ model = self.call_model_init()
+ else:
+ raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument.")
+ else:
+ if model_init is not None:
+ raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument, but not both.")
+
+ self.model = model
+ self.hp_search_backend = None
+
+ # Setup the callbacks
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ # TODO: Observe optimizer and scheduler by wrapping SentenceTransformer._get_scheduler
+ self.callback_handler = CallbackHandler(
+ callbacks, self.model.model_body, self.model.model_body.tokenizer, None, None
+ )
+ self.state = TrainerState()
+ self.control = TrainerControl()
+ self.add_callback(DEFAULT_PROGRESS_CALLBACK if self.args.show_progress_bar else PrinterCallback)
+ self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+ def add_callback(self, callback):
+ """
+ Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will instantiate a member of that class.
+ """
+ self.callback_handler.add_callback(callback)
+
+ def pop_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.
+
+ If the callback is not found, returns `None` (and no error is raised).
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will pop the first member of that class found in the list of callbacks.
+
+ Returns:
+ [`~transformers.TrainerCallback`]: The callback removed, if found.
+ """
+ return self.callback_handler.pop_callback(callback)
+
+ def remove_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will remove the first member of that class found in the list of callbacks.
+ """
+ self.callback_handler.remove_callback(callback)
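+
+ # Illustrative: any `transformers` callback class or instance can be used
+ # with `add_callback`/`remove_callback`, e.g.
+ # from transformers.integrations import TensorBoardCallback
+ # trainer.add_callback(TensorBoardCallback)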
+
def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = False) -> None:
"""Applies a dictionary of hyperparameters to both the trainer and the model
@@ -200,18 +260,11 @@ def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = Fals
params (`Dict[str, Any]`): The parameters, usually from `BestRun.hyperparameters`
final_model (`bool`, *optional*, defaults to `False`): If `True`, replace the `model_init()` function with a fixed model based on the parameters.
"""
- for key, value in params.items():
- if hasattr(self, key):
- old_attr = getattr(self, key, None)
- # Casting value to the proper type
- if old_attr is not None:
- value = type(old_attr)(value)
- setattr(self, key, value)
- elif number_of_arguments(self.model_init) == 0: # we do not warn if model_init could be using it
- logger.warning(
- f"Trying to set {key!r} in the hyperparameter search but there is no corresponding field in "
- "`SetFitTrainer`, and `model_init` does not take any arguments."
- )
+
+ if self.args is not None:
+ self.args = self.args.update(params, ignore_extra=True)
+ else:
+ self.args = TrainingArguments.from_dict(params, ignore_extra=True)
self.model = self.model_init(params)
if final_model:
@@ -248,172 +301,445 @@ def call_model_init(self, params: Optional[Dict[str, Any]] = None) -> "SetFitMod
return model
- def freeze(self) -> None:
- """
- Freeze SetFitModel's differentiable head.
- Note: call this function only when using the differentiable head.
- """
- if not self.model.has_differentiable_head:
- raise ValueError("Please use the differentiable head in `SetFitModel` when calling this function.")
+ def freeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
+ """Freeze the model body and/or the head, preventing further training on that component until unfrozen.
- self._freeze = True # Currently use self._freeze as a switch
- self.model.freeze("head")
+ This method is deprecated, use `SetFitModel.freeze` instead.
- def unfreeze(self, keep_body_frozen: bool = False) -> None:
+ Args:
+ component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to freeze that component.
+ If no component is provided, freeze both. Defaults to None.
"""
- Unfreeze SetFitModel's differentiable head.
- Note: call this function only when using the differentiable head.
+ warnings.warn(
+ f"`{self.__class__.__name__}.freeze` is deprecated and will be removed in v2.0.0 of SetFit. "
+ "Please use `SetFitModel.freeze` directly instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.model.freeze(component)
+
+ def unfreeze(
+ self, component: Optional[Literal["body", "head"]] = None, keep_body_frozen: Optional[bool] = None
+ ) -> None:
+ """Unfreeze the model body and/or the head, allowing further training on that component.
+
+ This method is deprecated, use `SetFitModel.unfreeze` instead.
Args:
- keep_body_frozen (`bool`, *optional*, defaults to `False`):
- Whether to freeze the body when unfreeze the head.
+ component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to unfreeze that component.
+ If no component is provided, unfreeze both. Defaults to None.
+ keep_body_frozen (`bool`, *optional*): Deprecated argument, use `component` instead.
"""
- if not self.model.has_differentiable_head:
- raise ValueError("Please use the differentiable head in `SetFitModel` when calling this function.")
-
- self._freeze = False # Currently use self._freeze as a switch
- self.model.unfreeze("head")
- if keep_body_frozen:
- self.model.freeze("body")
- else: # ensure to unfreeze the body
- self.model.unfreeze("body")
+ warnings.warn(
+ f"`{self.__class__.__name__}.unfreeze` is deprecated and will be removed in v2.0.0 of SetFit. "
+ "Please use `SetFitModel.unfreeze` directly instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.model.unfreeze(component, keep_body_frozen=keep_body_frozen)
def train(
self,
- num_epochs: Optional[int] = None,
- batch_size: Optional[int] = None,
- learning_rate: Optional[float] = None,
- body_learning_rate: Optional[float] = None,
- l2_weight: Optional[float] = None,
- max_length: Optional[int] = None,
+ args: Optional[TrainingArguments] = None,
trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
- show_progress_bar: bool = True,
+ **kwargs,
) -> None:
"""
Main training entry point.
Args:
- num_epochs (`int`, *optional*):
- Temporary change the number of epochs to train the Sentence Transformer body/head for.
- If ignore, will use the value given in initialization.
- batch_size (`int`, *optional*):
- Temporary change the batch size to use for contrastive training or logistic regression.
- If ignore, will use the value given in initialization.
- learning_rate (`float`, *optional*):
- Temporary change the learning rate to use for contrastive training or SetFitModel's head in logistic regression.
- If ignore, will use the value given in initialization.
- body_learning_rate (`float`, *optional*):
- Temporary change the learning rate to use for SetFitModel's body in logistic regression only.
- If ignore, will be the same as `learning_rate`.
- l2_weight (`float`, *optional*):
- Temporary change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression.
- max_length (int, *optional*, defaults to `None`):
- The maximum number of tokens for one data sample. Currently only for training the differentiable head.
- If `None`, will use the maximum number of tokens the model body can accept.
- If `max_length` is greater than the maximum number of acceptable tokens the model body can accept, it will be set to the maximum number of acceptable tokens.
+ args (`TrainingArguments`, *optional*):
+ Temporarily change the training arguments for this training call.
trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
The trial run or the hyperparameter dictionary for hyperparameter search.
- show_progress_bar (`bool`, *optional*, defaults to `True`):
- Whether to show a bar that indicates training progress.
"""
- set_seed(self.seed) # Seed must be set before instantiating the model when using model_init.
+ if len(kwargs):
+ warnings.warn(
+ f"`{self.__class__.__name__}.train` does not accept keyword arguments anymore. "
+ f"Please provide training arguments via a `TrainingArguments` instance to the `{self.__class__.__name__}` "
+ f"initialisation or the `{self.__class__.__name__}.train` method.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ args = args or self.args or TrainingArguments()
+
+ # Seed must be set before instantiating the model when using model_init.
+ set_seed(args.seed)
if trial: # Trial and model initialization
self._hp_search_setup(trial) # sets trainer parameters and initializes model
if self.train_dataset is None:
- raise ValueError("Training requires a `train_dataset` given to the `SetFitTrainer` initialization.")
+ raise ValueError(
+ f"Training requires a `train_dataset` given to the `{self.__class__.__name__}` initialization."
+ )
- self._validate_column_mapping(self.train_dataset)
- train_dataset = self.train_dataset
- if self.column_mapping is not None:
- logger.info("Applying column mapping to training dataset")
- train_dataset = self._apply_column_mapping(self.train_dataset, self.column_mapping)
-
- x_train = train_dataset["text"]
- y_train = train_dataset["label"]
- if self.loss_class is None:
- logger.warning("No `loss_class` detected! Using `CosineSimilarityLoss` as the default.")
- self.loss_class = losses.CosineSimilarityLoss
-
- num_epochs = num_epochs or self.num_epochs
- batch_size = batch_size or self.batch_size
- learning_rate = learning_rate or self.learning_rate
-
- if not self.model.has_differentiable_head or self._freeze:
- # sentence-transformers adaptation
- if self.loss_class in [
- losses.BatchAllTripletLoss,
- losses.BatchHardTripletLoss,
- losses.BatchSemiHardTripletLoss,
- losses.BatchHardSoftMarginTripletLoss,
- SupConLoss,
- ]:
- train_examples = [InputExample(texts=[text], label=label) for text, label in zip(x_train, y_train)]
- train_data_sampler = SentenceLabelDataset(train_examples, samples_per_label=self.samples_per_label)
-
- batch_size = min(batch_size, len(train_data_sampler))
- train_dataloader = DataLoader(train_data_sampler, batch_size=batch_size, drop_last=True)
-
- if self.loss_class is losses.BatchHardSoftMarginTripletLoss:
- train_loss = self.loss_class(
- model=self.model.model_body,
- distance_metric=self.distance_metric,
- )
- elif self.loss_class is SupConLoss:
- train_loss = self.loss_class(model=self.model.model_body)
+ parameters = []
+ for dataset, dataset_name in [(self.train_dataset, "training"), (self.eval_dataset, "evaluation")]:
+ if dataset is None:
+ continue
+
+ self._validate_column_mapping(dataset)
+ if self.column_mapping is not None:
+ logger.info(f"Applying column mapping to {dataset_name} dataset")
+ dataset = self._apply_column_mapping(dataset, self.column_mapping)
+
+ parameters.extend(self.dataset_to_parameters(dataset))
+
+ self.train_embeddings(*parameters, args=args)
+ training_parameters = parameters[: len(parameters) // 2] if self.eval_dataset else parameters
+ self.train_classifier(*training_parameters, args=args)
+
+ def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
+ return [dataset["text"], dataset["label"]]
+
+ def train_embeddings(
+ self,
+ x_train: List[str],
+ y_train: Optional[Union[List[int], List[List[int]]]] = None,
+ x_eval: Optional[List[str]] = None,
+ y_eval: Optional[Union[List[int], List[List[int]]]] = None,
+ args: Optional[TrainingArguments] = None,
+ ) -> None:
+ """
+ Method to perform the embedding phase: finetuning the `SentenceTransformer` body.
+
+ Args:
+            x_train (`List[str]`): A list of training sentences.
+            y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
+            x_eval (`List[str]`, *optional*): A list of evaluation sentences, used to log an evaluation loss during training.
+            y_eval (`Union[List[int], List[List[int]]]`, *optional*): A list of labels corresponding to the evaluation sentences.
+            args (`TrainingArguments`, *optional*):
+                Temporarily change the training arguments for this training call.
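+
+            Example (an illustrative sketch; the sentences, labels and argument values are placeholders)::
+
+                >>> trainer.train_embeddings(
+                ...     x_train=["the pizza was great", "the service was slow"],
+                ...     y_train=[1, 0],
+                ...     args=TrainingArguments(num_epochs=2),
+                ... )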
+ """
+ args = args or self.args or TrainingArguments()
+ # Since transformers v4.32.0, the log/eval/save steps should be saved on the state instead
+ self.state.logging_steps = args.logging_steps
+ self.state.eval_steps = args.eval_steps
+ self.state.save_steps = args.save_steps
+
+ train_dataloader, loss_func, batch_size = self.get_dataloader(x_train, y_train, args=args)
+ if x_eval is not None:
+ eval_dataloader, _, _ = self.get_dataloader(x_eval, y_eval, args=args)
+ else:
+ eval_dataloader = None
+
+ total_train_steps = len(train_dataloader) * args.embedding_num_epochs
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataloader)}")
+ logger.info(f" Num epochs = {args.embedding_num_epochs}")
+ logger.info(f" Total optimization steps = {total_train_steps}")
+ logger.info(f" Total train batch size = {batch_size}")
+
+ warmup_steps = math.ceil(total_train_steps * args.warmup_proportion)
+ self._train_sentence_transformer(
+ self.model.model_body,
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ args=args,
+ loss_func=loss_func,
+ warmup_steps=warmup_steps,
+ )
+
+ def get_dataloader(
+ self, x: List[str], y: Union[List[int], List[List[int]]], args: TrainingArguments
+ ) -> Tuple[DataLoader, nn.Module, int]:
+ # sentence-transformers adaptation
+ input_data = [InputExample(texts=[text], label=label) for text, label in zip(x, y)]
+
+ if args.loss in [
+ losses.BatchAllTripletLoss,
+ losses.BatchHardTripletLoss,
+ losses.BatchSemiHardTripletLoss,
+ losses.BatchHardSoftMarginTripletLoss,
+ SupConLoss,
+ ]:
+ data_sampler = SentenceLabelDataset(input_data, samples_per_label=args.samples_per_label)
+ batch_size = min(args.embedding_batch_size, len(data_sampler))
+ dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=True)
+
+ if args.loss is losses.BatchHardSoftMarginTripletLoss:
+ loss = args.loss(
+ model=self.model.model_body,
+ distance_metric=args.distance_metric,
+ )
+ elif args.loss is SupConLoss:
+ loss = args.loss(model=self.model.model_body)
+ else:
+ loss = args.loss(
+ model=self.model.model_body,
+ distance_metric=args.distance_metric,
+ margin=args.margin,
+ )
+ else:
+ data_sampler = ContrastiveDataset(
+ input_data, self.model.multi_target_strategy, args.num_iterations, args.sampling_strategy
+ )
+            # With the "unique" strategy, shuffle the sampler for additional randomisation
+            shuffle_sampler = args.sampling_strategy == "unique"
+ batch_size = min(args.embedding_batch_size, len(data_sampler))
+ dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
+ loss = args.loss(self.model.model_body)
+
+ return dataloader, loss, batch_size
+
+    def log(self, args: TrainingArguments, logs: Dict[str, float]) -> "TrainerControl":
+ """
+ Log `logs` on the various objects watching training.
+
+ Subclass and override this method to inject custom behavior.
+
+        Args:
+            args (`TrainingArguments`):
+                The training arguments used for this training run.
+            logs (`Dict[str, float]`):
+                The values to log.
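+
+        Example (an illustrative sketch of a subclass)::
+
+            >>> class PrintingTrainer(Trainer):
+            ...     def log(self, args, logs):
+            ...         print(logs)
+            ...         return super().log(args, logs)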
+ """
+ logs = {self.logs_mapper.get(key, key): value for key, value in logs.items()}
+ if self.state.epoch is not None:
+ logs["epoch"] = round(self.state.epoch, 2)
+
+ output = {**logs, **{"step": self.state.global_step}}
+ self.state.log_history.append(output)
+ return self.callback_handler.on_log(args, self.state, self.control, logs)
+
+ def _set_logs_mapper(self, logs_mapper: Dict[str, str]) -> None:
+ """Set the logging mapper.
+
+ Args:
+            logs_mapper (`Dict[str, str]`): The logging mapper, e.g. `{"eval_embedding_loss": "eval_aspect_embedding_loss"}`.
+ """
+ self.logs_mapper = logs_mapper
+
+ def _train_sentence_transformer(
+ self,
+ model_body: SentenceTransformer,
+ train_dataloader: DataLoader,
+ eval_dataloader: Optional[DataLoader],
+ args: TrainingArguments,
+ loss_func: nn.Module,
+ warmup_steps: int = 10000,
+ ) -> None:
+ """
+ Train the model with the given training objective
+ Each training objective is sampled in turn for one batch.
+ We sample only as many batches from each objective as there are in the smallest one
+ to make sure of equal training with each dataset.
+ """
+ # TODO: args.gradient_accumulation_steps
+ # TODO: fp16/bf16, etc.
+ # TODO: Safetensors
+
+ # Hardcoded training arguments
+ max_grad_norm = 1
+ weight_decay = 0.01
+
+ self.state.epoch = 0
+ start_time = time.time()
+ if args.max_steps > 0:
+ self.state.max_steps = args.max_steps
+ else:
+ self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ if args.use_amp:
+ scaler = torch.cuda.amp.GradScaler()
+
+ model_body.to(model_body._target_device)
+ loss_func.to(model_body._target_device)
+
+ # Use smart batching
+ train_dataloader.collate_fn = model_body.smart_batching_collate
+ if eval_dataloader:
+ eval_dataloader.collate_fn = model_body.smart_batching_collate
+
+ steps_per_epoch = len(train_dataloader)
+ num_train_steps = int(steps_per_epoch * args.embedding_num_epochs)
+
+ # Prepare optimizers
+ param_optimizer = list(loss_func.named_parameters())
+
+ no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
+ "weight_decay": weight_decay,
+ },
+ {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
+ ]
+
+        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.body_embedding_learning_rate)
+ scheduler_obj = model_body._get_scheduler(
+ optimizer, scheduler="WarmupLinear", warmup_steps=warmup_steps, t_total=num_train_steps
+ )
+
+ data_iterator = iter(train_dataloader)
+ skip_scheduler = False
+ for epoch in range(args.embedding_num_epochs):
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ loss_func.zero_grad()
+ loss_func.train()
+
+ for step in range(steps_per_epoch):
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ try:
+ data = next(data_iterator)
+ except StopIteration:
+ data_iterator = iter(train_dataloader)
+ data = next(data_iterator)
+
+ features, labels = data
+ labels = labels.to(model_body._target_device)
+ features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
+
+ if args.use_amp:
+ with autocast():
+ loss_value = loss_func(features, labels)
+
+ scale_before_step = scaler.get_scale()
+ scaler.scale(loss_value).backward()
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm)
+ scaler.step(optimizer)
+ scaler.update()
+
+ skip_scheduler = scaler.get_scale() != scale_before_step
else:
- train_loss = self.loss_class(
- model=self.model.model_body,
- distance_metric=self.distance_metric,
- margin=self.margin,
+ loss_value = loss_func(features, labels)
+ loss_value.backward()
+ torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm)
+ optimizer.step()
+
+ optimizer.zero_grad()
+
+ if not skip_scheduler:
+ scheduler_obj.step()
+
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1) / steps_per_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+ if self.control.should_log:
+ learning_rate = scheduler_obj.get_last_lr()[0]
+ metrics = {"embedding_loss": round(loss_value.item(), 4), "learning_rate": learning_rate}
+ self.control = self.log(args, metrics)
+
+ eval_loss = None
+ if self.control.should_evaluate and eval_dataloader is not None:
+ eval_loss = self._evaluate_with_loss(model_body, eval_dataloader, args, loss_func)
+ learning_rate = scheduler_obj.get_last_lr()[0]
+ metrics = {"eval_embedding_loss": round(eval_loss, 4), "learning_rate": learning_rate}
+ self.control = self.log(args, metrics)
+
+ self.control = self.callback_handler.on_evaluate(args, self.state, self.control, metrics)
+
+ loss_func.zero_grad()
+ loss_func.train()
+
+ if self.control.should_save:
+ checkpoint_dir = self._checkpoint(
+ self.args.output_dir, args.save_total_limit, self.state.global_step
)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ if eval_loss is not None and (
+ self.state.best_metric is None or eval_loss < self.state.best_metric
+ ):
+ self.state.best_metric = eval_loss
+ self.state.best_model_checkpoint = checkpoint_dir
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ break
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+
+ if self.control.should_training_stop:
+ break
+
+ if self.args.load_best_model_at_end and self.state.best_model_checkpoint:
+ dir_name = Path(self.state.best_model_checkpoint).name
+ if dir_name.startswith("step_"):
+ logger.info(f"Loading best SentenceTransformer model from step {dir_name[5:]}.")
+ self.model.model_body = SentenceTransformer(self.state.best_model_checkpoint, device=model_body.device)
+
+ # Ensure logging the speed metrics
+ num_train_samples = self.state.max_steps * args.embedding_batch_size # * args.gradient_accumulation_steps
+ metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+ self.control.should_log = True
+ self.log(args, metrics)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+ def _evaluate_with_loss(
+ self,
+ model_body: SentenceTransformer,
+ eval_dataloader: DataLoader,
+ args: TrainingArguments,
+ loss_func: nn.Module,
+ ) -> float:
+ model_body.eval()
+        eval_losses = []
+ for data in tqdm(
+ iter(eval_dataloader), total=len(eval_dataloader), leave=False, disable=not args.show_progress_bar
+ ):
+ features, labels = data
+ labels = labels.to(model_body._target_device)
+ features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
+
+ if args.use_amp:
+ with autocast():
+ loss_value = loss_func(features, labels)
+
+                eval_losses.append(loss_value.item())
else:
- train_examples = []
-
- for _ in trange(self.num_iterations, desc="Generating Training Pairs", disable=not show_progress_bar):
- if self.model.multi_target_strategy is not None:
- train_examples = sentence_pairs_generation_multilabel(
- np.array(x_train), np.array(y_train), train_examples
- )
- else:
- train_examples = sentence_pairs_generation(
- np.array(x_train), np.array(y_train), train_examples
- )
-
- train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
- train_loss = self.loss_class(self.model.model_body)
-
- total_train_steps = len(train_dataloader) * num_epochs
- logger.info("***** Running training *****")
- logger.info(f" Num examples = {len(train_examples)}")
- logger.info(f" Num epochs = {num_epochs}")
- logger.info(f" Total optimization steps = {total_train_steps}")
- logger.info(f" Total train batch size = {batch_size}")
-
- warmup_steps = math.ceil(total_train_steps * self.warmup_proportion)
- self.model.model_body.fit(
- train_objectives=[(train_dataloader, train_loss)],
- epochs=num_epochs,
- optimizer_params={"lr": learning_rate},
- warmup_steps=warmup_steps,
- show_progress_bar=show_progress_bar,
- use_amp=self.use_amp,
- )
+                eval_losses.append(loss_func(features, labels).item())
+
+ model_body.train()
+        return sum(eval_losses) / len(eval_losses)
+
+    def _checkpoint(self, checkpoint_path: str, checkpoint_save_total_limit: int, step: int) -> str:
+ # Delete old checkpoints
+ if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0:
+ old_checkpoints = []
+ for subdir in Path(checkpoint_path).glob("step_*"):
+ if subdir.name[5:].isdigit() and (
+ self.state.best_model_checkpoint is None or subdir != Path(self.state.best_model_checkpoint)
+ ):
+ old_checkpoints.append({"step": int(subdir.name[5:]), "path": str(subdir)})
+
+ if len(old_checkpoints) > checkpoint_save_total_limit - 1:
+ old_checkpoints = sorted(old_checkpoints, key=lambda x: x["step"])
+ shutil.rmtree(old_checkpoints[0]["path"])
+
+ checkpoint_file_path = str(Path(checkpoint_path) / f"step_{step}")
+ self.model.save_pretrained(checkpoint_file_path)
+ return checkpoint_file_path
+
+ def train_classifier(
+ self, x_train: List[str], y_train: Union[List[int], List[List[int]]], args: Optional[TrainingArguments] = None
+ ) -> None:
+ """
+ Method to perform the classifier phase: fitting a classifier head.
- if not self.model.has_differentiable_head or not self._freeze:
- # Train the final classifier
- self.model.fit(
- x_train,
- y_train,
- num_epochs=num_epochs,
- batch_size=batch_size,
- learning_rate=learning_rate,
- body_learning_rate=body_learning_rate,
- l2_weight=l2_weight,
- max_length=max_length,
- show_progress_bar=True,
- )
+ Args:
+ x_train (`List[str]`): A list of training sentences.
+ y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
+ args (`TrainingArguments`, *optional*):
+ Temporarily change the training arguments for this training call.
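+
+        Example (an illustrative sketch; the sentences and labels are placeholders)::
+
+            >>> trainer.train_classifier(
+            ...     x_train=["the pizza was great", "the service was slow"],
+            ...     y_train=[1, 0],
+            ... )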
+ """
+ args = args or self.args or TrainingArguments()
+
+ self.model.fit(
+ x_train,
+ y_train,
+ num_epochs=args.classifier_num_epochs,
+ batch_size=args.classifier_batch_size,
+ body_learning_rate=args.body_classifier_learning_rate,
+ head_learning_rate=args.head_learning_rate,
+ l2_weight=args.l2_weight,
+ max_length=args.max_length,
+ show_progress_bar=args.show_progress_bar,
+ end_to_end=args.end_to_end,
+ )
def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
"""
@@ -421,13 +747,16 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
Args:
dataset (`Dataset`, *optional*):
- The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed in the eval_dataset argument at `SetFitTrainer` initialization.
+ The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via
+ the `eval_dataset` argument at `Trainer` initialization.
Returns:
`Dict[str, float]`: The evaluation metrics.
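+
+        Example (illustrative; the metric value is a placeholder)::
+
+            >>> trainer.evaluate()
+            {'accuracy': 0.91}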
"""
eval_dataset = dataset or self.eval_dataset
+ if eval_dataset is None:
+ raise ValueError("No evaluation dataset provided to `Trainer.evaluate` nor the `Trainer` initialzation.")
self._validate_column_mapping(eval_dataset)
if self.column_mapping is not None:
@@ -442,6 +771,13 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
if isinstance(y_pred, torch.Tensor):
y_pred = y_pred.cpu()
+ # Normalize string outputs
+ if y_test and isinstance(y_test[0], str):
+ encoder = LabelEncoder()
+ encoder.fit(list(y_test) + list(y_pred))
+ y_test = encoder.transform(y_test)
+ y_pred = encoder.transform(y_pred)
+
if isinstance(self.metric, str):
metric_config = "multilabel" if self.model.multi_target_strategy is not None else None
metric_fn = evaluate.load(self.metric, config_name=metric_config)
@@ -472,7 +808,7 @@ def hyperparameter_search(
- To use this method, you need to have provided a `model_init` when initializing your [`SetFitTrainer`]: we need to
+ To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
reinitialize the model at each new run.
@@ -507,7 +843,7 @@ def hyperparameter_search(
if backend is None:
backend = default_hp_search_backend()
if backend is None:
- raise RuntimeError("optuna should be installed. " "To install optuna run `pip install optuna`. ")
+ raise RuntimeError("optuna should be installed. To install optuna run `pip install optuna`.")
backend = HPSearchBackend(backend)
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
@@ -539,7 +875,7 @@ def push_to_hub(self, repo_id: str, **kwargs) -> str:
Args:
repo_id (`str`):
- The full repository ID to push to, e.g. `"tomaarsen/setfit_sst2"`.
+ The full repository ID to push to, e.g. `"tomaarsen/setfit-sst2"`.
config (`dict`, *optional*):
Configuration object to be saved alongside the model weights.
commit_message (`str`, *optional*):
@@ -569,7 +905,66 @@ def push_to_hub(self, repo_id: str, **kwargs) -> str:
"""
if "/" not in repo_id:
raise ValueError(
- '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit_sst2".'
+ '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-sst2".'
)
commit_message = kwargs.pop("commit_message", "Add SetFit model")
return self.model.push_to_hub(repo_id, commit_message=commit_message, **kwargs)
+
+
+class SetFitTrainer(Trainer):
+ """
+ `SetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.
+ Please use `Trainer` instead.
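+
+    Example (a migration sketch; the argument values are placeholders)::
+
+        >>> # Previously: SetFitTrainer(model=model, num_epochs=2, batch_size=8, ...)
+        >>> trainer = Trainer(
+        ...     model=model,
+        ...     args=TrainingArguments(num_epochs=2, batch_size=8),
+        ... )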
+ """
+
+ def __init__(
+ self,
+ model: Optional["SetFitModel"] = None,
+ train_dataset: Optional["Dataset"] = None,
+ eval_dataset: Optional["Dataset"] = None,
+ model_init: Optional[Callable[[], "SetFitModel"]] = None,
+ metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+ metric_kwargs: Optional[Dict[str, Any]] = None,
+ loss_class=losses.CosineSimilarityLoss,
+ num_iterations: int = 20,
+ num_epochs: int = 1,
+ learning_rate: float = 2e-5,
+ batch_size: int = 16,
+ seed: int = 42,
+ column_mapping: Optional[Dict[str, str]] = None,
+ use_amp: bool = False,
+ warmup_proportion: float = 0.1,
+ distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance,
+ margin: float = 0.25,
+ samples_per_label: int = 2,
+ ):
+ warnings.warn(
+ "`SetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
+ "Please use `Trainer` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ args = TrainingArguments(
+ num_iterations=num_iterations,
+ num_epochs=num_epochs,
+ body_learning_rate=learning_rate,
+ head_learning_rate=learning_rate,
+ batch_size=batch_size,
+ seed=seed,
+ use_amp=use_amp,
+ warmup_proportion=warmup_proportion,
+ distance_metric=distance_metric,
+ margin=margin,
+ samples_per_label=samples_per_label,
+ loss=loss_class,
+ )
+ super().__init__(
+ model=model,
+ args=args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ model_init=model_init,
+ metric=metric,
+ metric_kwargs=metric_kwargs,
+ column_mapping=column_mapping,
+ )
diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py
index ca194066..da5ec4b8 100644
--- a/src/setfit/trainer_distillation.py
+++ b/src/setfit/trainer_distillation.py
@@ -1,245 +1,162 @@
-import math
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+import warnings
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union
-import numpy as np
import torch
+from datasets import Dataset
from sentence_transformers import InputExample, losses, util
-from sentence_transformers.datasets import SentenceLabelDataset
-from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
+from torch import nn
from torch.utils.data import DataLoader
-from transformers.trainer_utils import set_seed
-from . import SetFitTrainer, logging
-from .modeling import SupConLoss, sentence_pairs_generation_cos_sim
+from . import logging
+from .sampler import ContrastiveDistillationDataset
+from .trainer import Trainer
+from .training_args import TrainingArguments
if TYPE_CHECKING:
- import optuna
- from datasets import Dataset
-
from .modeling import SetFitModel
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
-class DistillationSetFitTrainer(SetFitTrainer):
+class DistillationTrainer(Trainer):
"""Trainer to compress a SetFit model with knowledge distillation.
Args:
teacher_model (`SetFitModel`):
The teacher model to mimic.
+ student_model (`SetFitModel`, *optional*):
+ The model to train. If not provided, a `model_init` must be passed.
+ args (`TrainingArguments`, *optional*):
+ The training arguments to use.
train_dataset (`Dataset`):
The training dataset.
- student_model (`SetFitModel`):
- The student model to train. If not provided, a `model_init` must be passed.
eval_dataset (`Dataset`, *optional*):
The evaluation dataset.
model_init (`Callable[[], SetFitModel]`, *optional*):
- A function that instantiates the model to be used. If provided, each call to [`~DistillationSetFitTrainer.train`] will start
- from a new instance of the model as given by this function when a `trial` is passed.
+ A function that instantiates the model to be used. If provided, each call to
+ [`~DistillationTrainer.train`] will start from a new instance of the model as given by this
+ function when a `trial` is passed.
metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
- The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings.
+ The metric to use for evaluation. If a string is provided, we treat it as the metric
+ name and load it with default settings.
If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
- loss_class (`nn.Module`, *optional*, defaults to `CosineSimilarityLoss`):
- The loss function to use for contrastive training.
- num_iterations (`int`, *optional*, defaults to `20`):
- The number of iterations to generate sentence pairs for.
- num_epochs (`int`, *optional*, defaults to `1`):
- The number of epochs to train the Sentence Transformer body for.
- learning_rate (`float`, *optional*, defaults to `2e-5`):
- The learning rate to use for contrastive training.
- batch_size (`int`, *optional*, defaults to `16`):
- The batch size to use for contrastive training.
- seed (`int`, *optional*, defaults to 42):
- Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
- [`~SetTrainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.
column_mapping (`Dict[str, str]`, *optional*):
- A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: {"text_column_name": "text", "label_column_name: "label"}.
- use_amp (`bool`, *optional*, defaults to `False`):
- Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
- warmup_proportion (`float`, *optional*, defaults to `0.1`):
- Proportion of the warmup in the total training steps.
- Must be greater than or equal to 0.0 and less than or equal to 1.0.
+            A mapping from the column names in the dataset to the column names expected by the model.
+            The expected format is a dictionary with the following format:
+            `{"text_column_name": "text", "label_column_name": "label"}`.
"""
+ _REQUIRED_COLUMNS = {"text"}
+
def __init__(
self,
teacher_model: "SetFitModel",
student_model: Optional["SetFitModel"] = None,
+        args: Optional[TrainingArguments] = None,
train_dataset: Optional["Dataset"] = None,
eval_dataset: Optional["Dataset"] = None,
model_init: Optional[Callable[[], "SetFitModel"]] = None,
metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
- loss_class: torch.nn.Module = losses.CosineSimilarityLoss,
- num_iterations: int = 20,
- num_epochs: int = 1,
- learning_rate: float = 2e-5,
- batch_size: int = 16,
- seed: int = 42,
column_mapping: Optional[Dict[str, str]] = None,
- use_amp: bool = False,
- warmup_proportion: float = 0.1,
) -> None:
- super(DistillationSetFitTrainer, self).__init__(
+ super().__init__(
model=student_model,
+ args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model_init=model_init,
metric=metric,
- loss_class=loss_class,
- num_iterations=num_iterations,
- num_epochs=num_epochs,
- learning_rate=learning_rate,
- batch_size=batch_size,
- seed=seed,
column_mapping=column_mapping,
- use_amp=use_amp,
- warmup_proportion=warmup_proportion,
)
self.teacher_model = teacher_model
self.student_model = self.model
- def train(
- self,
- num_epochs: Optional[int] = None,
- batch_size: Optional[int] = None,
- learning_rate: Optional[float] = None,
- body_learning_rate: Optional[float] = None,
- l2_weight: Optional[float] = None,
- trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
- show_progress_bar: bool = True,
- ):
+ def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
+ return [dataset["text"]]
+
+ def get_dataloader(
+ self, x: List[str], y: Optional[Union[List[int], List[List[int]]]], args: TrainingArguments
+ ) -> Tuple[DataLoader, nn.Module, int]:
+        x_embd_teacher = self.teacher_model.model_body.encode(
+            x, convert_to_tensor=self.teacher_model.has_differentiable_head
+        )
+        cos_sim_matrix = util.cos_sim(x_embd_teacher, x_embd_teacher)
+
+ input_data = [InputExample(texts=[text]) for text in x]
+ data_sampler = ContrastiveDistillationDataset(
+ input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy
+ )
+        # With the "unique" strategy, shuffle the sampler for additional randomisation
+        shuffle_sampler = args.sampling_strategy == "unique"
+ batch_size = min(args.embedding_batch_size, len(data_sampler))
+ dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
+ loss = args.loss(self.model.model_body)
+ return dataloader, loss, batch_size
+
+ def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None:
"""
- Main training entry point.
+ Method to perform the classifier phase: fitting the student classifier head.
Args:
- num_epochs (`int`, *optional*):
- Temporary change the number of epochs to train the Sentence Transformer body/head for.
- If ignore, will use the value given in initialization.
- batch_size (`int`, *optional*):
- Temporary change the batch size to use for contrastive training or logistic regression.
- If ignore, will use the value given in initialization.
- learning_rate (`float`, *optional*):
- Temporary change the learning rate to use for contrastive training or SetFitModel's head in logistic regression.
- If ignore, will use the value given in initialization.
- body_learning_rate (`float`, *optional*):
- Temporary change the learning rate to use for SetFitModel's body in logistic regression only.
- If ignore, will be the same as `learning_rate`.
- l2_weight (`float`, *optional*):
- Temporary change the weight of L2 regularization for SetFitModel's differentiable head in logistic regression.
- trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
- The trial run or the hyperparameter dictionary for hyperparameter search.
- show_progress_bar (`bool`, *optional*, defaults to `True`):
- Whether to show a bar that indicates training progress.
+ x_train (`List[str]`): A list of training sentences.
+ args (`TrainingArguments`, *optional*):
+ Temporarily change the training arguments for this training call.
"""
- set_seed(self.seed) # Seed must be set before instantiating the model when using model_init.
-
- if trial: # Trial and model initialization
- self._hp_search_setup(trial) # sets trainer parameters and initializes model
-
- if self.train_dataset is None:
- raise ValueError(
- "Training requires a `train_dataset` given to the `DistillationSetFitTrainer` initialization."
- )
-
- self._validate_column_mapping(self.train_dataset)
- train_dataset = self.train_dataset
- if self.column_mapping is not None:
- logger.info("Applying column mapping to training dataset")
- train_dataset = self._apply_column_mapping(self.train_dataset, self.column_mapping)
-
- x_train = train_dataset["text"]
- y_train = train_dataset["label"]
- if self.loss_class is None:
- logger.warning("No `loss_class` detected! Using `CosineSimilarityLoss` as the default.")
- self.loss_class = losses.CosineSimilarityLoss
-
- num_epochs = num_epochs or self.num_epochs
- batch_size = batch_size or self.batch_size
- learning_rate = learning_rate or self.learning_rate
-
- if not self.student_model.has_differentiable_head or self._freeze:
- # sentence-transformers adaptation
- if self.loss_class in [
- losses.BatchAllTripletLoss,
- losses.BatchHardTripletLoss,
- losses.BatchSemiHardTripletLoss,
- losses.BatchHardSoftMarginTripletLoss,
- SupConLoss,
- ]:
- train_examples = [InputExample(texts=[text], label=label) for text, label in zip(x_train, y_train)]
- train_data_sampler = SentenceLabelDataset(train_examples)
-
- batch_size = min(batch_size, len(train_data_sampler))
- train_dataloader = DataLoader(train_data_sampler, batch_size=batch_size, drop_last=True)
-
- if self.loss_class is losses.BatchHardSoftMarginTripletLoss:
- train_loss = self.loss_class(
- model=self.student_model,
- distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
- )
- elif self.loss_class is SupConLoss:
- train_loss = self.loss_class(model=self.student_model)
- else:
- train_loss = self.loss_class(
- model=self.student_model,
- distance_metric=BatchHardTripletLossDistanceFunction.cosine_distance,
- margin=0.25,
- )
- else:
- train_examples = []
-
- # **************** student training ****************
- x_train_embd_student = self.teacher_model.model_body.encode(
- x_train, convert_to_tensor=self.teacher_model.has_differentiable_head
- )
- y_train = self.teacher_model.model_head.predict(x_train_embd_student)
- if not self.teacher_model.has_differentiable_head and self.student_model.has_differentiable_head:
- y_train = torch.from_numpy(y_train)
- elif self.teacher_model.has_differentiable_head and not self.student_model.has_differentiable_head:
- y_train = y_train.detach().cpu().numpy()
-
- cos_sim_matrix = util.cos_sim(x_train_embd_student, x_train_embd_student)
-
- train_examples = []
- for _ in range(self.num_iterations):
- train_examples = sentence_pairs_generation_cos_sim(
- np.array(x_train), train_examples, cos_sim_matrix
- )
-
- # **************** student training END ****************
-
- train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
- train_loss = self.loss_class(self.student_model.model_body)
-
- total_train_steps = len(train_dataloader) * num_epochs
- logger.info("***** Running training *****")
- logger.info(f" Num examples = {len(train_examples)}")
- logger.info(f" Num epochs = {num_epochs}")
- logger.info(f" Total optimization steps = {total_train_steps}")
- logger.info(f" Total train batch size = {batch_size}")
-
- warmup_steps = math.ceil(total_train_steps * self.warmup_proportion)
- self.student_model.model_body.fit(
- train_objectives=[(train_dataloader, train_loss)],
- epochs=num_epochs,
- optimizer_params={"lr": learning_rate},
- warmup_steps=warmup_steps,
- show_progress_bar=show_progress_bar,
- use_amp=self.use_amp,
- )
-
- if not self.student_model.has_differentiable_head or not self._freeze:
- # Train the final classifier
- self.student_model.fit(
- x_train,
- y_train,
- num_epochs=num_epochs,
- batch_size=batch_size,
- learning_rate=learning_rate,
- body_learning_rate=body_learning_rate,
- l2_weight=l2_weight,
- show_progress_bar=show_progress_bar,
- )
+ y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head)
+ return super().train_classifier(x_train, y_train, args)
+
+
+class DistillationSetFitTrainer(DistillationTrainer):
+ """
+ `DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.
+ Please use `DistillationTrainer` instead.
+ """
+
+ def __init__(
+ self,
+ teacher_model: "SetFitModel",
+ student_model: Optional["SetFitModel"] = None,
+ train_dataset: Optional["Dataset"] = None,
+ eval_dataset: Optional["Dataset"] = None,
+ model_init: Optional[Callable[[], "SetFitModel"]] = None,
+ metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+ loss_class: torch.nn.Module = losses.CosineSimilarityLoss,
+ num_iterations: int = 20,
+ num_epochs: int = 1,
+ learning_rate: float = 2e-5,
+ batch_size: int = 16,
+ seed: int = 42,
+ column_mapping: Optional[Dict[str, str]] = None,
+ use_amp: bool = False,
+ warmup_proportion: float = 0.1,
+ ) -> None:
+ warnings.warn(
+ "`DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
+ "Please use `DistillationTrainer` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ args = TrainingArguments(
+ num_iterations=num_iterations,
+ num_epochs=num_epochs,
+ body_learning_rate=learning_rate,
+ head_learning_rate=learning_rate,
+ batch_size=batch_size,
+ seed=seed,
+ use_amp=use_amp,
+ warmup_proportion=warmup_proportion,
+ loss=loss_class,
+ )
+ super().__init__(
+ teacher_model=teacher_model,
+ student_model=student_model,
+ args=args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ model_init=model_init,
+ metric=metric,
+ column_mapping=column_mapping,
+ )
diff --git a/src/setfit/training_args.py b/src/setfit/training_args.py
new file mode 100644
index 00000000..9ed24fb7
--- /dev/null
+++ b/src/setfit/training_args.py
@@ -0,0 +1,344 @@
+from __future__ import annotations
+
+import inspect
+import json
+from copy import copy
+from dataclasses import dataclass, field, fields
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from sentence_transformers import losses
+from transformers import IntervalStrategy
+from transformers.integrations import get_available_reporting_integrations
+from transformers.training_args import default_logdir
+from transformers.utils import is_torch_available
+
+from . import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TrainingArguments:
+ """
+ TrainingArguments is the subset of the arguments which relate to the training loop itself.
+
+ Parameters:
+ output_dir (`str`, defaults to `"checkpoints"`):
+ The output directory where the model predictions and checkpoints will be written.
+ batch_size (`Union[int, Tuple[int, int]]`, defaults to `(16, 2)`):
+ Set the batch sizes for the embedding and classifier training phases respectively,
+ or set both if an integer is provided.
+ Note that the batch size for the classifier is only used with a differentiable PyTorch head.
+        num_epochs (`Union[int, Tuple[int, int]]`, defaults to `(1, 16)`):
+            Set the number of epochs for the embedding and classifier training phases respectively,
+            or set both if an integer is provided.
+            Note that the number of epochs for the classifier is only used with a differentiable PyTorch head.
+ max_steps (`int`, *optional*, defaults to `-1`):
+ If set to a positive number, the total number of training steps to perform. Overrides `num_epochs`.
+ The training may stop before reaching the set number of steps when all data is exhausted.
+ sampling_strategy (`str`, defaults to `"oversampling"`):
+ The sampling strategy of how to draw pairs in training. Possible values are:
+
+ - `"oversampling"`: Draws even number of positive/ negative sentence pairs until every
+ sentence pair has been drawn.
+ - `"undersampling"`: Draws the minimum number of positive/ negative sentence pairs until
+ every sentence pair in the minority class has been drawn.
+ - `"unique"`: Draws every sentence pair combination (likely resulting in unbalanced
+ number of positive/ negative sentence pairs).
+
+ The default is set to `"oversampling"`, ensuring all sentence pairs are drawn at least once.
+ Alternatively setting `num_iterations` will override this argument and determine the number
+ of generated sentence pairs.
+        num_iterations (`int`, *optional*):
+            If not set, the `sampling_strategy` will determine the number of sentence pairs to generate.
+            This argument sets the number of iterations to generate sentence pairs for,
+            and provides compatibility with SetFit <v1.0.0.
+ warmup_proportion (`float`, defaults to `0.1`):
+ Proportion of the warmup in the total training steps.
+ Must be greater than or equal to 0.0 and less than or equal to 1.0.
+ l2_weight (`float`, *optional*):
+ Optional l2 weight for both the model body and head, passed to the `AdamW` optimizer in the
+ classifier training phase if a differentiable PyTorch head is used.
+ max_length (`int`, *optional*):
+ The maximum token length a tokenizer can generate. If not provided, the maximum length for
+ the `SentenceTransformer` body is used.
+        samples_per_label (`int`, defaults to `2`):
+            Number of consecutive, random and unique samples drawn per label.
+            This is only relevant for triplet loss and ignored for `CosineSimilarityLoss`.
+            Batch size should be a multiple of `samples_per_label`.
+ show_progress_bar (`bool`, defaults to `True`):
+ Whether to display a progress bar for the training epochs and iterations.
+ seed (`int`, defaults to `42`):
+            Random seed that will be set at the beginning of training. To ensure reproducibility across
+            runs, use the [`~Trainer.model_init`] function to instantiate the model if it has some
+            randomly initialized parameters.
+ report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
+ The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
+ `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. Use `"all"` to report to
+ all integrations installed, `"none"` for no integrations.
+ run_name (`str`, *optional*):
+ A descriptor for the run. Typically used for [wandb](https://www.wandb.com/) and
+ [mlflow](https://www.mlflow.org/) logging.
+ logging_dir (`str`, *optional*):
+ [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+ *runs/**CURRENT_DATETIME_HOSTNAME***.
+ logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The logging strategy to adopt during training. Possible values are:
+
+ - `"no"`: No logging is done during training.
+ - `"epoch"`: Logging is done at the end of each epoch.
+ - `"steps"`: Logging is done every `logging_steps`.
+
+        logging_first_step (`bool`, *optional*, defaults to `True`):
+            Whether to log and evaluate the first `global_step` or not.
+ logging_steps (`int`, *optional*, defaults to 50):
+ Number of update steps between two logs if `logging_strategy="steps"`.
+ evaluation_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+ The evaluation strategy to adopt during training. Possible values are:
+
+ - `"no"`: No evaluation is done during training.
+ - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
+ - `"epoch"`: Evaluation is done at the end of each epoch.
+
+ eval_steps (`int`, *optional*):
+ Number of update steps between two evaluations if `evaluation_strategy="steps"`. Will default to the same
+ value as `logging_steps` if not set.
+ eval_delay (`float`, *optional*):
+ Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
+ evaluation_strategy.
+
+ save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The checkpoint save strategy to adopt during training. Possible values are:
+
+ - `"no"`: No save is done during training.
+ - `"epoch"`: Save is done at the end of each epoch.
+ - `"steps"`: Save is done every `save_steps`.
+        save_steps (`int`, *optional*, defaults to 500):
+            Number of update steps between two checkpoint saves if `save_strategy="steps"`.
+        save_total_limit (`int`, *optional*, defaults to `1`):
+            If a value is passed, will limit the total number of checkpoints. Deletes the older checkpoints in
+            `output_dir`. Note that the best model is always preserved if the `evaluation_strategy` is not `"no"`.
+        load_best_model_at_end (`bool`, *optional*, defaults to `False`):
+            Whether or not to load the best model found during training at the end of training.
+
+            When set to `True`, the parameter `save_strategy` needs to be the same as `evaluation_strategy`, and in
+            the case it is `"steps"`, `save_steps` must be a round multiple of `eval_steps`.
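+
+    Example (an illustrative sketch; the values are placeholders)::
+
+        >>> args = TrainingArguments(batch_size=(32, 4), num_epochs=1, body_learning_rate=(2e-5, 1e-5))
+        >>> args.embedding_batch_size, args.classifier_batch_size
+        (32, 4)
+        >>> args.update({"seed": 13}).seed
+        13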
+ """
+
+ output_dir: str = "checkpoints"
+
+ # batch_size is only used to conveniently set `embedding_batch_size` and `classifier_batch_size`
+ # which are used in practice
+ batch_size: Union[int, Tuple[int, int]] = field(default=(16, 2), repr=False)
+    embedding_batch_size: Optional[int] = None
+    classifier_batch_size: Optional[int] = None
+
+ # num_epochs is only used to conveniently set `embedding_num_epochs` and `classifier_num_epochs`
+ # which are used in practice
+ num_epochs: Union[int, Tuple[int, int]] = field(default=(1, 16), repr=False)
+    embedding_num_epochs: Optional[int] = None
+    classifier_num_epochs: Optional[int] = None
+
+ max_steps: int = -1
+
+ sampling_strategy: str = "oversampling"
+ num_iterations: Optional[int] = None
+
+ # As with batch_size and num_epochs, the first value in the tuple is the learning rate
+ # for the embeddings step, while the second value is the learning rate for the classifier step.
+ body_learning_rate: Union[float, Tuple[float, float]] = field(default=(2e-5, 1e-5), repr=False)
+    body_embedding_learning_rate: Optional[float] = None
+    body_classifier_learning_rate: Optional[float] = None
+ head_learning_rate: float = 1e-2
+
+ # Loss-related arguments
+ loss: Callable = losses.CosineSimilarityLoss
+ distance_metric: Callable = losses.BatchHardTripletLossDistanceFunction.cosine_distance
+ margin: float = 0.25
+
+ end_to_end: bool = field(default=False)
+
+ use_amp: bool = False
+ warmup_proportion: float = 0.1
+ l2_weight: Optional[float] = None
+ max_length: Optional[int] = None
+ samples_per_label: int = 2
+
+ # Arguments that do not affect performance
+ show_progress_bar: bool = True
+ seed: int = 42
+
+ # Logging & callbacks
+ report_to: str = "all"
+ run_name: Optional[str] = None
+ logging_dir: Optional[str] = None
+ logging_strategy: str = "steps"
+ logging_first_step: bool = True
+ logging_steps: int = 50
+
+ evaluation_strategy: str = "no"
+ eval_steps: Optional[int] = None
+ eval_delay: int = 0
+
+ save_strategy: str = "steps"
+ save_steps: int = 500
+ save_total_limit: Optional[int] = 1
+
+ load_best_model_at_end: bool = False
+ metric_for_best_model: str = field(default="embedding_loss", repr=False)
+ greater_is_better: bool = field(default=False, repr=False)
+
+ def __post_init__(self) -> None:
+ # Set `self.embedding_batch_size` and `self.classifier_batch_size` using values from `self.batch_size`
+ if isinstance(self.batch_size, int):
+ self.batch_size = (self.batch_size, self.batch_size)
+ if self.embedding_batch_size is None:
+ self.embedding_batch_size = self.batch_size[0]
+ if self.classifier_batch_size is None:
+ self.classifier_batch_size = self.batch_size[1]
+
+ # Set `self.embedding_num_epochs` and `self.classifier_num_epochs` using values from `self.num_epochs`
+ if isinstance(self.num_epochs, int):
+ self.num_epochs = (self.num_epochs, self.num_epochs)
+ if self.embedding_num_epochs is None:
+ self.embedding_num_epochs = self.num_epochs[0]
+ if self.classifier_num_epochs is None:
+ self.classifier_num_epochs = self.num_epochs[1]
+
+ # Set `self.body_embedding_learning_rate` and `self.body_classifier_learning_rate` using
+ # values from `self.body_learning_rate`
+ if isinstance(self.body_learning_rate, float):
+ self.body_learning_rate = (self.body_learning_rate, self.body_learning_rate)
+ if self.body_embedding_learning_rate is None:
+ self.body_embedding_learning_rate = self.body_learning_rate[0]
+ if self.body_classifier_learning_rate is None:
+ self.body_classifier_learning_rate = self.body_learning_rate[1]
+
+ if self.warmup_proportion < 0.0 or self.warmup_proportion > 1.0:
+ raise ValueError(
+ f"warmup_proportion must be greater than or equal to 0.0 and less than or equal to 1.0! But it was: {self.warmup_proportion}"
+ )
+
+ if self.report_to in (None, "all", ["all"]):
+ self.report_to = get_available_reporting_integrations()
+ elif self.report_to in ("none", ["none"]):
+ self.report_to = []
+ elif not isinstance(self.report_to, list):
+ self.report_to = [self.report_to]
+
+ if self.logging_dir is None:
+ self.logging_dir = default_logdir()
+
+ self.logging_strategy = IntervalStrategy(self.logging_strategy)
+ self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
+
+ if self.eval_steps is not None and self.evaluation_strategy == IntervalStrategy.NO:
+ logger.info('Using `evaluation_strategy="steps"` as `eval_steps` is defined.')
+ self.evaluation_strategy = IntervalStrategy.STEPS
+
+        # eval_steps has to be defined and non-zero; falls back to logging_steps if the latter is non-zero
+ if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
+ if self.logging_steps > 0:
+ self.eval_steps = self.logging_steps
+ else:
+ raise ValueError(
+ f"evaluation strategy {self.evaluation_strategy} requires either non-zero `eval_steps` or"
+ " `logging_steps`"
+ )
+
+ # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
+ if self.load_best_model_at_end:
+ if self.evaluation_strategy != self.save_strategy:
+ raise ValueError(
+ "`load_best_model_at_end` requires the save and eval strategy to match, but found\n- Evaluation "
+ f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}"
+ )
+ if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
+ if self.eval_steps < 1 or self.save_steps < 1:
+ if not (self.eval_steps < 1 and self.save_steps < 1):
+ raise ValueError(
+ "`load_best_model_at_end` requires the saving steps to be a multiple of the evaluation "
+ "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps"
+ f"{self.save_steps} and eval_steps {self.eval_steps}."
+ )
+ # Work around floating point precision issues
+ LARGE_MULTIPLIER = 1_000_000
+ if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0:
+ raise ValueError(
+ "`load_best_model_at_end` requires the saving steps to be a multiple of the evaluation "
+ f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}."
+ )
+ raise ValueError(
+ "`load_best_model_at_end` requires the saving steps to be a round multiple of the evaluation "
+ f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
+ )
+
+ # logging_steps must be non-zero for logging_strategy that is other than 'no'
+ if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0:
+ raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps")
+
+ def to_dict(self) -> Dict[str, Any]:
+ # filter out fields that are defined as field(init=False)
+ return {field.name: getattr(self, field.name) for field in fields(self) if field.init}
+
+ @classmethod
+ def from_dict(cls, arguments: Dict[str, Any], ignore_extra: bool = False) -> TrainingArguments:
+ if ignore_extra:
+ return cls(**{key: value for key, value in arguments.items() if key in inspect.signature(cls).parameters})
+ return cls(**arguments)
+
+ def copy(self) -> TrainingArguments:
+ return copy(self)
+
+ def update(self, arguments: Dict[str, Any], ignore_extra: bool = False) -> TrainingArguments:
+ return TrainingArguments.from_dict({**self.to_dict(), **arguments}, ignore_extra=ignore_extra)
+
+ def to_json_string(self):
+ """
+ Serializes this instance to a JSON string.
+ """
+ # TODO: This needs to be improved
+ return json.dumps({key: str(value) for key, value in self.to_dict().items()}, indent=2)
+
+ def to_sanitized_dict(self) -> Dict[str, Any]:
+ """
+ Sanitized serialization to use with TensorBoard’s hparams
+ """
+ d = self.to_dict()
+ d = {**d, **{"train_batch_size": self.embedding_batch_size, "eval_batch_size": self.embedding_batch_size}}
+
+ valid_types = [bool, int, float, str]
+ if is_torch_available():
+ valid_types.append(torch.Tensor)
+
+ return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
diff --git a/src/setfit/utils.py b/src/setfit/utils.py
index 57fb31d4..d75dc7cf 100644
--- a/src/setfit/utils.py
+++ b/src/setfit/utils.py
@@ -7,7 +7,7 @@
from sentence_transformers import losses
from .data import create_fewshot_splits, create_fewshot_splits_multilabel
-from .modeling import SupConLoss
+from .losses import SupConLoss
SEC_TO_NS_SCALE = 1000000000
@@ -135,7 +135,7 @@ def summary(self) -> None:
class BestRun(NamedTuple):
"""
- The best run found by a hyperparameter search (see [`~SetFitTrainer.hyperparameter_search`]).
+ The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).
Parameters:
run_id (`str`):
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..f92a81d8
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,29 @@
+import pytest
+from datasets import Dataset
+
+from setfit import AbsaModel, SetFitModel
+
+
+@pytest.fixture()
+def model() -> SetFitModel:
+ return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
+
+
+@pytest.fixture()
+def absa_model() -> AbsaModel:
+ return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm")
+
+
+@pytest.fixture()
+def absa_dataset() -> Dataset:
+ texts = [
+ "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.",
+ "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.",
+ "Food is great and inexpensive.",
+ "Good bagels and good cream cheese.",
+ "Good bagels and good cream cheese.",
+ ]
+ spans = ["food", "ambiance", "Food", "bagels", "cream cheese"]
+ labels = ["negative", "negative", "positive", "positive", "positive"]
+ ordinals = [0, 0, 0, 0, 0]
+ return Dataset.from_dict({"text": texts, "span": spans, "label": labels, "ordinal": ordinals})
diff --git a/tests/exporters/test_onnx.py b/tests/exporters/test_onnx.py
index 6c132d43..6e515d74 100644
--- a/tests/exporters/test_onnx.py
+++ b/tests/exporters/test_onnx.py
@@ -8,7 +8,8 @@
from setfit import SetFitModel
from setfit.data import get_templated_dataset
from setfit.exporters.onnx import export_onnx
-from setfit.trainer import SetFitTrainer
+from setfit.trainer import Trainer
+from setfit.training_args import TrainingArguments
@pytest.mark.parametrize(
@@ -71,25 +72,23 @@ def test_export_onnx_torch_head(out_features):
model_path, use_differentiable_head=True, head_params={"out_features": out_features}
)
- trainer = SetFitTrainer(
+ args = TrainingArguments(
+ num_iterations=15,
+ num_epochs=(1, 15),
+ batch_size=16,
+ body_learning_rate=(2e-5, 1e-5),
+ head_learning_rate=1e-2,
+ l2_weight=0.0,
+ end_to_end=True,
+ )
+ trainer = Trainer(
model=model,
+ args=args,
train_dataset=dataset,
eval_dataset=dataset,
- num_iterations=15,
column_mapping={"text": "text", "label": "label"},
)
- # Train and evaluate
- trainer.freeze() # Freeze the head
- trainer.train() # Train only the body
- # Unfreeze the head and unfreeze the body -> end-to-end training
- trainer.unfreeze(keep_body_frozen=False)
- trainer.train(
- num_epochs=15,
- batch_size=16,
- body_learning_rate=1e-5,
- learning_rate=1e-2,
- l2_weight=0.0,
- )
+ trainer.train()
# Export the sklearn based model
output_path = "model.onnx"
diff --git a/tests/span/test_modeling.py b/tests/span/test_modeling.py
new file mode 100644
index 00000000..0bc3ccb8
--- /dev/null
+++ b/tests/span/test_modeling.py
@@ -0,0 +1,86 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+import torch
+
+from setfit import AbsaModel
+from setfit.span.aspect_extractor import AspectExtractor
+from setfit.span.modeling import AspectModel, PolarityModel
+
+
+def test_loading():
+ model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm")
+ assert isinstance(model, AbsaModel)
+ assert isinstance(model.aspect_extractor, AspectExtractor)
+ assert isinstance(model.aspect_model, AspectModel)
+ assert isinstance(model.polarity_model, PolarityModel)
+
+ model = AbsaModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2@6c91e73a51599e35bd1145dfdcd3289215225009",
+ "sentence-transformers/paraphrase-albert-small-v2",
+ spacy_model="en_core_web_sm",
+ )
+ assert isinstance(model, AbsaModel)
+
+ model = AbsaModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2",
+ "sentence-transformers/paraphrase-albert-small-v2@6c91e73a51599e35bd1145dfdcd3289215225009",
+ spacy_model="en_core_web_sm",
+ )
+ assert isinstance(model, AbsaModel)
+
+ with pytest.raises(OSError):
+ model = AbsaModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2", spacy_model="not_a_spacy_model"
+ )
+
+ model = AbsaModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm", normalize_embeddings=True
+ )
+ assert model.aspect_model.normalize_embeddings
+ assert model.polarity_model.normalize_embeddings
+
+ aspect_model = AspectModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", span_context=12)
+ assert aspect_model.span_context == 12
+ polarity_model = PolarityModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", span_context=12)
+ assert polarity_model.span_context == 12
+
+ model = AbsaModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm", span_contexts=(12, None)
+ )
+ assert model.aspect_model.span_context == 12
+ assert model.polarity_model.span_context == 3 # <- default
+
+
+def test_save_load(absa_model: AbsaModel) -> None:
+ absa_model.polarity_model.span_context = 5
+
+ with TemporaryDirectory() as tmp_dir:
+ tmp_dir = str(Path(tmp_dir) / "model")
+ absa_model.save_pretrained(tmp_dir)
+ assert (Path(tmp_dir + "-aspect") / "config_span_setfit.json").exists()
+ assert (Path(tmp_dir + "-polarity") / "config_span_setfit.json").exists()
+
+ fresh_model = AbsaModel.from_pretrained(
+ tmp_dir + "-aspect", tmp_dir + "-polarity", spacy_model="en_core_web_sm"
+ )
+ assert fresh_model.polarity_model.span_context == 5
+
+ with TemporaryDirectory() as aspect_tmp_dir:
+ with TemporaryDirectory() as polarity_tmp_dir:
+ absa_model.save_pretrained(aspect_tmp_dir, polarity_tmp_dir)
+ assert (Path(aspect_tmp_dir) / "config_span_setfit.json").exists()
+ assert (Path(polarity_tmp_dir) / "config_span_setfit.json").exists()
+
+ fresh_model = AbsaModel.from_pretrained(aspect_tmp_dir, polarity_tmp_dir, spacy_model="en_core_web_sm")
+ assert fresh_model.polarity_model.span_context == 5
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to move a model between devices")
+def test_to(absa_model: AbsaModel) -> None:
+ assert absa_model.device.type == "cuda"
+ absa_model.to("cpu")
+ assert absa_model.device.type == "cpu"
+ assert absa_model.aspect_model.device.type == "cpu"
+ assert absa_model.polarity_model.device.type == "cpu"
diff --git a/tests/span/test_trainer.py b/tests/span/test_trainer.py
new file mode 100644
index 00000000..f89044dc
--- /dev/null
+++ b/tests/span/test_trainer.py
@@ -0,0 +1,75 @@
+from datasets import Dataset
+from transformers import TrainerCallback
+
+from setfit import AbsaTrainer
+from setfit.span.modeling import AbsaModel
+
+
+def test_trainer(absa_model: AbsaModel, absa_dataset: Dataset) -> None:
+ trainer = AbsaTrainer(absa_model, train_dataset=absa_dataset, eval_dataset=absa_dataset)
+ trainer.train()
+
+ metrics = trainer.evaluate()
+ assert "aspect" in metrics
+ assert "polarity" in metrics
+ assert "accuracy" in metrics["aspect"]
+ assert "accuracy" in metrics["polarity"]
+ assert metrics["aspect"]["accuracy"] > 0.0
+ assert metrics["polarity"]["accuracy"] > 0.0
+ new_metrics = trainer.evaluate(absa_dataset)
+ assert metrics == new_metrics
+
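+ # Expected return shapes: a single string should yield a list of
+ # {"span", "polarity"} dicts, while a list of inputs should yield one such
+ # list per input sentence.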
+ predict = absa_model.predict("Best pizza outside of Italy and really tasty.")
+ assert {"span": "pizza", "polarity": "positive"} in predict
+ predict = absa_model.predict(["Best pizza outside of Italy and really tasty.", "This is another sentence"])
+ assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list)
+ predict = absa_model(["Best pizza outside of Italy and really tasty.", "This is another sentence"])
+ assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list)
+
+
+def test_trainer_callbacks(absa_model: AbsaModel) -> None:
+ trainer = AbsaTrainer(absa_model)
+ assert len(trainer.aspect_trainer.callback_handler.callbacks) >= 2
+ callback_names = {callback.__class__.__name__ for callback in trainer.aspect_trainer.callback_handler.callbacks}
+ assert {"DefaultFlowCallback", "ProgressCallback"} <= callback_names
+
+ class TestCallback(TrainerCallback):
+ pass
+
+ callback = TestCallback()
+ trainer.add_callback(callback)
+ assert len(trainer.aspect_trainer.callback_handler.callbacks) == len(callback_names) + 1
+ assert len(trainer.polarity_trainer.callback_handler.callbacks) == len(callback_names) + 1
+ assert trainer.aspect_trainer.callback_handler.callbacks[-1] == callback
+ assert trainer.polarity_trainer.callback_handler.callbacks[-1] == callback
+
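+ # Since AbsaTrainer wraps two sub-trainers, popping a callback appears to
+ # return the pair removed from the (aspect, polarity) trainers.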
+ assert trainer.pop_callback(callback) == (callback, callback)
+ trainer.add_callback(callback)
+ assert trainer.aspect_trainer.callback_handler.callbacks[-1] == callback
+ assert trainer.polarity_trainer.callback_handler.callbacks[-1] == callback
+ trainer.remove_callback(callback)
+ assert callback not in trainer.aspect_trainer.callback_handler.callbacks
+ assert callback not in trainer.polarity_trainer.callback_handler.callbacks
+
+
+def test_train_ordinal_too_high(absa_model: AbsaModel) -> None:
+ absa_dataset = Dataset.from_dict(
+ {
+ "text": [
+ "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine."
+ ],
+ "span": ["food"],
+ "label": ["negative"],
+ "ordinal": [1],
+ }
+ )
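+ # "ordinal" presumably selects the n-th (0-based) occurrence of the span in
+ # the text; "food" occurs only once, so ordinal 1 is out of range and the
+ # trainer is expected to warn rather than raise.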
+ AbsaTrainer(absa_model, train_dataset=absa_dataset)
+ # TODO: Capture warning and test against it.
+
+
+def test_train_column_mapping(absa_model: AbsaModel, absa_dataset: Dataset) -> None:
+ absa_dataset = absa_dataset.rename_columns({"text": "sentence", "span": "aspect"})
+ trainer = AbsaTrainer(
+ absa_model, train_dataset=absa_dataset, column_mapping={"sentence": "text", "aspect": "span"}
+ )
+ trainer.train()
diff --git a/tests/test_deprecated_trainer.py b/tests/test_deprecated_trainer.py
new file mode 100644
index 00000000..8e1ce1d5
--- /dev/null
+++ b/tests/test_deprecated_trainer.py
@@ -0,0 +1,531 @@
+import pathlib
+import re
+import tempfile
+from unittest import TestCase
+
+import evaluate
+import pytest
+import torch
+from datasets import Dataset, load_dataset
+from sentence_transformers import losses
+from transformers.testing_utils import require_optuna
+from transformers.utils.hp_naming import TrialShortNamer
+
+from setfit import logging
+from setfit.losses import SupConLoss
+from setfit.modeling import SetFitModel
+from setfit.trainer import SetFitTrainer
+from setfit.utils import BestRun
+
+
+logging.set_verbosity_warning()
+logging.enable_propagation()
+
+
+class SetFitTrainerTest(TestCase):
+ def setUp(self):
+ self.model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
+ self.num_iterations = 1
+
+ def test_trainer_works_with_model_init(self):
+ def get_model():
+ model_name = "sentence-transformers/paraphrase-albert-small-v2"
+ return SetFitModel.from_pretrained(model_name)
+
+ dataset = Dataset.from_dict(
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+ trainer = SetFitTrainer(
+ model_init=get_model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+ trainer.train()
+ metrics = trainer.evaluate()
+ self.assertEqual(metrics["accuracy"], 1.0)
+
+ def test_trainer_works_with_column_mapping(self):
+ dataset = Dataset.from_dict(
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+ trainer.train()
+ metrics = trainer.evaluate()
+ self.assertEqual(metrics["accuracy"], 1.0)
+
+ def test_trainer_works_with_default_columns(self):
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
+ trainer = SetFitTrainer(
+ model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
+ )
+ trainer.train()
+ metrics = trainer.evaluate()
+ self.assertEqual(metrics["accuracy"], 1.0)
+
+ def test_trainer_works_with_alternate_dataset_for_evaluate(self):
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
+ alternate_dataset = Dataset.from_dict(
+ {"text": ["x", "y", "z"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+ trainer = SetFitTrainer(
+ model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
+ )
+ trainer.train()
+ metrics = trainer.evaluate(alternate_dataset)
+ self.assertNotEqual(metrics["accuracy"], 1.0)
+
+ def test_trainer_raises_error_with_missing_label(self):
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
+ trainer = SetFitTrainer(
+ model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
+ )
+ with pytest.raises(ValueError):
+ trainer.train()
+
+ def test_trainer_raises_error_with_missing_text(self):
+ """If the required columns are missing from the dataset, the library should throw an error and list the columns found."""
+ dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
+ trainer = SetFitTrainer(
+ model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
+ )
+ expected_message = re.escape(
+ "SetFit expected the dataset to have the columns ['label', 'text'], "
+ "but only the columns ['extra_column', 'label'] were found. "
+ "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer."
+ )
+ with pytest.raises(ValueError, match=expected_message):
+ trainer._validate_column_mapping(trainer.train_dataset)
+
+ def test_column_mapping_raises_error_when_mapped_columns_missing(self):
+ """If the columns specified in the column mapping are missing from the dataset, the library should throw an error and list the columns found."""
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+ expected_message = re.escape(
+ "The column mapping expected the columns ['label_new', 'text_new'] in the dataset, "
+ "but the dataset had the columns ['extra_column', 'text'].",
+ )
+ with pytest.raises(ValueError, match=expected_message):
+ trainer._validate_column_mapping(trainer.train_dataset)
+
+ def test_trainer_raises_error_when_dataset_not_split(self):
+ """Verify that an error is raised if we pass an unsplit dataset to the trainer."""
+ dataset = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 0, 1, 1]}).train_test_split(
+ test_size=0.5
+ )
+ trainer = SetFitTrainer(
+ model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
+ )
+ expected_message = re.escape(
+ "SetFit expected a Dataset, but it got a DatasetDict with the splits ['test', 'train']. "
+ "Did you mean to select one of these splits from the dataset?",
+ )
+ with pytest.raises(ValueError, match=expected_message):
+ trainer._validate_column_mapping(trainer.train_dataset)
+
+ def test_trainer_raises_error_when_dataset_is_dataset_dict_with_train(self):
+ """Verify that a useful error is raised if we pass an unsplit dataset with only a `train` split to the trainer."""
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ path = pathlib.Path(tmpdirname) / "test_dataset_dict_with_train.csv"
+ path.write_text("label,text\n1,good\n0,terrible\n")
+ dataset = load_dataset("csv", data_files=str(path))
+ trainer = SetFitTrainer(
+ model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
+ )
+ expected_message = re.escape(
+ "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. "
+ "Did you mean to select the training split with dataset['train']?",
+ )
+ with pytest.raises(ValueError, match=expected_message):
+ trainer._validate_column_mapping(trainer.train_dataset)
+
+ def test_column_mapping_multilabel(self):
+ dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]})
+
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+
+ trainer._validate_column_mapping(trainer.train_dataset)
+ formatted_dataset = trainer._apply_column_mapping(trainer.train_dataset, trainer.column_mapping)
+
+ assert formatted_dataset.column_names == ["text", "label"]
+
+ assert formatted_dataset[0]["text"] == "a"
+ assert formatted_dataset[0]["label"] == [0, 1]
+
+ assert formatted_dataset[1]["text"] == "b"
+
+ def test_trainer_support_callable_as_metric(self):
+ dataset = Dataset.from_dict(
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+
+ f1_metric = evaluate.load("f1")
+ accuracy_metric = evaluate.load("accuracy")
+
+ def compute_metrics(y_pred, y_test):
+ return {
+ "f1": f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
+ "accuracy": accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
+ }
+
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ metric=compute_metrics,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+
+ trainer.train()
+ metrics = trainer.evaluate()
+
+ self.assertEqual(
+ {
+ "f1": 1.0,
+ "accuracy": 1.0,
+ },
+ metrics,
+ )
+
+ def test_raise_when_metric_value_is_invalid(self):
+ dataset = Dataset.from_dict(
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ metric="this-metric-does-not-exist", # invalid metric value
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+
+ trainer.train()
+
+ with self.assertRaises(FileNotFoundError):
+ trainer.evaluate()
+
+ def test_trainer_raises_error_with_wrong_warmup_proportion(self):
+ # warmup_proportion must not be > 1.0
+ with pytest.raises(ValueError):
+ SetFitTrainer(warmup_proportion=1.1)
+
+ # warmup_proportion must not be < 0.0
+ with pytest.raises(ValueError):
+ SetFitTrainer(warmup_proportion=-0.1)
+
+
+class SetFitTrainerDifferentiableHeadTest(TestCase):
+ def setUp(self):
+ self.dataset = Dataset.from_dict(
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+ self.model = SetFitModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2",
+ use_differentiable_head=True,
+ head_params={"out_features": 3},
+ )
+ self.num_iterations = 1
+
+ @pytest.mark.skip(reason="The `trainer.train` arguments are now ignored, causing this test to fail.")
+ def test_trainer_max_length_exceeds_max_acceptable_length(self):
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=self.dataset,
+ eval_dataset=self.dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+ trainer.unfreeze(keep_body_frozen=True)
+ with self.assertLogs(level=logging.WARNING) as cm:
+ max_length = 4096
+ max_acceptable_length = self.model.model_body.get_max_seq_length()
+ trainer.train(
+ num_epochs=1,
+ batch_size=3,
+ learning_rate=1e-2,
+ l2_weight=0.0,
+ max_length=max_length,
+ )
+ self.assertEqual(
+ cm.output,
+ [
+ (
+ f"WARNING:setfit.modeling:The specified `max_length`: {max_length} is greater than the maximum length "
+ f"of the current model body: {max_acceptable_length}. Using {max_acceptable_length} instead."
+ )
+ ],
+ )
+
+ @pytest.mark.skip(reason="The `trainer.train` arguments are now ignored, causing this test to fail.")
+ def test_trainer_max_length_is_smaller_than_max_acceptable_length(self):
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=self.dataset,
+ eval_dataset=self.dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+ trainer.unfreeze(keep_body_frozen=True)
+
+ # An alternative to `assertNoLogs`, which is only available from Python 3.10
+ try:
+ with self.assertLogs(level=logging.WARNING) as cm:
+ max_length = 32
+ trainer.train(
+ num_epochs=1,
+ batch_size=3,
+ learning_rate=1e-2,
+ l2_weight=0.0,
+ max_length=max_length,
+ )
+ self.assertEqual(cm.output, [])
+ except AssertionError as e:
+ if e.args[0] != "no logs of level WARNING or higher triggered on root":
+ raise AssertionError(e)
+
+
+class SetFitTrainerMultilabelTest(TestCase):
+ def setUp(self):
+ self.model = SetFitModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest"
+ )
+ self.num_iterations = 1
+
+ def test_trainer_multilabel_support_callable_as_metric(self):
+ dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[1, 0, 0], [0, 1, 0], [0, 0, 1]]})
+
+ multilabel_f1_metric = evaluate.load("f1", "multilabel")
+ multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")
+
+ def compute_metrics(y_pred, y_test):
+ return {
+ "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
+ "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
+ }
+
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ metric=compute_metrics,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+
+ trainer.train()
+ metrics = trainer.evaluate()
+
+ self.assertEqual(
+ {
+ "f1": 1.0,
+ "accuracy": 1.0,
+ },
+ metrics,
+ )
+
+
+@pytest.mark.skip(
+ reason=(
+ "The `trainer.freeze()` before `trainer.train()` now freezes the body as well as the head, "
+ "which means the backwards call from `trainer.train()` will fail."
+ )
+)
+class SetFitTrainerMultilabelDifferentiableTest(TestCase):
+ def setUp(self):
+ self.model = SetFitModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2",
+ multi_target_strategy="one-vs-rest",
+ use_differentiable_head=True,
+ head_params={"out_features": 2},
+ )
+ self.num_iterations = 1
+
+ def test_trainer_multilabel_support_callable_as_metric(self):
+ dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})
+
+ multilabel_f1_metric = evaluate.load("f1", "multilabel")
+ multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")
+
+ def compute_metrics(y_pred, y_test):
+ return {
+ "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
+ "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
+ }
+
+ trainer = SetFitTrainer(
+ model=self.model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ metric=compute_metrics,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+
+ trainer.freeze()
+ trainer.train()
+
+ trainer.unfreeze(keep_body_frozen=False)
+ trainer.train(5)
+ metrics = trainer.evaluate()
+
+ self.assertEqual(
+ {
+ "f1": 1.0,
+ "accuracy": 1.0,
+ },
+ metrics,
+ )
+
+
+@require_optuna
+class TrainerHyperParameterOptunaIntegrationTest(TestCase):
+ def setUp(self):
+ self.dataset = Dataset.from_dict(
+ {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+ self.num_iterations = 1
+
+ def test_hyperparameter_search(self):
+ class MyTrialShortNamer(TrialShortNamer):
+ DEFAULTS = {"max_iter": 100, "solver": "liblinear"}
+
+ def hp_space(trial):
+ return {
+ "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
+ "batch_size": trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]),
+ "max_iter": trial.suggest_int("max_iter", 50, 300),
+ "solver": trial.suggest_categorical("solver", ["newton-cg", "lbfgs", "liblinear"]),
+ }
+
+ def model_init(params):
+ params = params or {}
+ max_iter = params.get("max_iter", 100)
+ solver = params.get("solver", "liblinear")
+ params = {
+ "head_params": {
+ "max_iter": max_iter,
+ "solver": solver,
+ }
+ }
+ return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", **params)
+
+ def hp_name(trial):
+ return MyTrialShortNamer.shortname(trial.params)
+
+ trainer = SetFitTrainer(
+ train_dataset=self.dataset,
+ eval_dataset=self.dataset,
+ num_iterations=self.num_iterations,
+ model_init=model_init,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+ result = trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
+ assert isinstance(result, BestRun)
+ assert result.hyperparameters.keys() == {"learning_rate", "batch_size", "max_iter", "solver"}
+
+
+# regression test for https://github.com/huggingface/setfit/issues/153
+@pytest.mark.parametrize(
+ "loss_class",
+ [
+ losses.BatchAllTripletLoss,
+ losses.BatchHardTripletLoss,
+ losses.BatchSemiHardTripletLoss,
+ losses.BatchHardSoftMarginTripletLoss,
+ SupConLoss,
+ ],
+)
+def test_trainer_works_with_non_default_loss_class(loss_class):
+ dataset = Dataset.from_dict({"text": ["a 1", "b 1", "c 1", "a 2", "b 2", "c 2"], "label": [0, 1, 2, 0, 1, 2]})
+ model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
+ trainer = SetFitTrainer(
+ model=model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ num_iterations=1,
+ loss_class=loss_class,
+ )
+ trainer.train()
+ # no asserts here because this is a regression test - we only check that no exception is raised
+
+
+def test_trainer_evaluate_with_strings():
+ dataset = Dataset.from_dict(
+ {"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]}
+ )
+ model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
+ trainer = SetFitTrainer(
+ model=model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ num_iterations=1,
+ )
+ trainer.train()
+ # This used to fail due to "TypeError: can't convert np.ndarray of type numpy.str_.
+ # The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool."
+ model.predict(["another positive sentence"])
+
+
+def test_trainer_evaluate_multilabel_f1():
+ dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})
+ model = SetFitModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest"
+ )
+
+ trainer = SetFitTrainer(
+ model=model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ metric="f1",
+ metric_kwargs={"average": "micro"},
+ num_iterations=5,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+
+ trainer.train()
+ metrics = trainer.evaluate()
+ assert metrics == {"f1": 1.0}
+
+
+def test_trainer_evaluate_on_cpu() -> None:
+ # This test used to fail if CUDA was available
+ dataset = Dataset.from_dict({"text": ["positive sentence", "negative sentence"], "label": [1, 0]})
+ model = SetFitModel.from_pretrained(
+ "sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True
+ )
+
+ def compute_metric(y_pred, y_test) -> float:
+ assert y_pred.device == torch.device("cpu")
+ return 1.0
+
+ trainer = SetFitTrainer(
+ model=model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ metric=compute_metric,
+ num_iterations=5,
+ )
+ trainer.train()
+ trainer.evaluate()
diff --git a/tests/test_deprecated_trainer_distillation.py b/tests/test_deprecated_trainer_distillation.py
new file mode 100644
index 00000000..5d59c4f5
--- /dev/null
+++ b/tests/test_deprecated_trainer_distillation.py
@@ -0,0 +1,117 @@
+from unittest import TestCase
+
+import pytest
+from datasets import Dataset
+from sentence_transformers.losses import CosineSimilarityLoss
+
+from setfit import DistillationSetFitTrainer, SetFitTrainer
+from setfit.modeling import SetFitModel
+
+
+class DistillationSetFitTrainerTest(TestCase):
+ def setUp(self):
+ self.teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
+ self.student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
+ self.num_iterations = 1
+
+ def test_trainer_works_with_default_columns(self):
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
+ # train a teacher model
+ teacher_trainer = SetFitTrainer(
+ model=self.teacher_model,
+ train_dataset=dataset,
+ eval_dataset=dataset,
+ loss_class=CosineSimilarityLoss,
+ metric="accuracy",
+ )
+ # Train the teacher
+ teacher_trainer.train()
+ teacher_model = teacher_trainer.model
+
+ student_trainer = DistillationSetFitTrainer(
+ teacher_model=teacher_model,
+ train_dataset=dataset,
+ student_model=self.student_model,
+ eval_dataset=dataset,
+ loss_class=CosineSimilarityLoss,
+ metric="accuracy",
+ )
+
+ # Train and evaluate the student
+ student_trainer.train()
+ metrics = student_trainer.evaluate()
+ print("Student results: ", metrics)
+ self.assertEqual(metrics["accuracy"], 1.0)
+
+ def test_trainer_works_with_unlabeled_train_dataset(self):
+ labeled_dataset = Dataset.from_dict(
+ {"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+ # train a teacher model
+ teacher_trainer = SetFitTrainer(
+ model=self.teacher_model,
+ train_dataset=labeled_dataset,
+ eval_dataset=labeled_dataset,
+ metric="accuracy",
+ num_iterations=self.num_iterations,
+ )
+ # Train the teacher
+ teacher_trainer.train()
+
+ unlabeled_dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
+ student_trainer = DistillationSetFitTrainer(
+ teacher_model=self.teacher_model,
+ student_model=self.student_model,
+ train_dataset=unlabeled_dataset,
+ eval_dataset=labeled_dataset,
+ num_iterations=self.num_iterations,
+ )
+ student_trainer.train()
+ metrics = student_trainer.evaluate()
+ print("Student results: ", metrics)
+ self.assertEqual(metrics["accuracy"], 1.0)
+
+ def test_trainer_raises_error_with_missing_text(self):
+ dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
+ trainer = DistillationSetFitTrainer(
+ teacher_model=self.teacher_model,
+ train_dataset=dataset,
+ student_model=self.student_model,
+ eval_dataset=dataset,
+ num_iterations=self.num_iterations,
+ )
+ with pytest.raises(ValueError):
+ trainer.train()
+
+ def test_column_mapping_with_missing_text(self):
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
+ trainer = DistillationSetFitTrainer(
+ teacher_model=self.teacher_model,
+ train_dataset=dataset,
+ student_model=self.student_model,
+ eval_dataset=dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"label_new": "label"},
+ )
+ with pytest.raises(ValueError):
+ trainer._validate_column_mapping(trainer.train_dataset)
+
+ def test_column_mapping_multilabel(self):
+ dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]})
+
+ trainer = DistillationSetFitTrainer(
+ teacher_model=self.teacher_model,
+ train_dataset=dataset,
+ student_model=self.student_model,
+ eval_dataset=dataset,
+ num_iterations=self.num_iterations,
+ column_mapping={"text_new": "text", "label_new": "label"},
+ )
+
+ trainer._validate_column_mapping(trainer.train_dataset)
+ formatted_dataset = trainer._apply_column_mapping(trainer.train_dataset, trainer.column_mapping)
+
+ assert formatted_dataset.column_names == ["text", "label"]
+ assert formatted_dataset[0]["text"] == "a"
+ assert formatted_dataset[0]["label"] == [0, 1]
+ assert formatted_dataset[1]["text"] == "b"
diff --git a/tests/test_modeling.py b/tests/test_modeling.py
index c31417d2..a5e279f6 100644
--- a/tests/test_modeling.py
+++ b/tests/test_modeling.py
@@ -10,42 +10,12 @@
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
from setfit import SetFitHead, SetFitModel
-from setfit.modeling import MODEL_HEAD_NAME, sentence_pairs_generation, sentence_pairs_generation_multilabel
+from setfit.modeling import MODEL_HEAD_NAME
torch_cuda_available = pytest.mark.skipif(not torch.cuda.is_available(), reason="PyTorch must be compiled with CUDA")
-def test_sentence_pairs_generation():
- sentences = np.array(["sent 1", "sent 2", "sent 3"])
- labels = np.array(["label 1", "label 2", "label 3"])
-
- pairs = []
- n_iterations = 2
-
- for _ in range(n_iterations):
- pairs = sentence_pairs_generation(sentences, labels, pairs)
-
- assert len(pairs) == 12
- assert pairs[0].texts == ["sent 1", "sent 1"]
- assert pairs[0].label == 1.0
-
-
-def test_sentence_pairs_generation_multilabel():
- sentences = np.array(["sent 1", "sent 2", "sent 3"])
- labels = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]])
-
- pairs = []
- n_iterations = 2
-
- for _ in range(n_iterations):
- pairs = sentence_pairs_generation_multilabel(sentences, labels, pairs)
-
- assert len(pairs) == 12
- assert pairs[0].texts == ["sent 1", "sent 1"]
- assert pairs[0].label == 1.0
-
-
def test_setfit_model_body():
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
new file mode 100644
index 00000000..d8d37712
--- /dev/null
+++ b/tests/test_sampler.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+from sentence_transformers import InputExample
+
+from setfit.sampler import ContrastiveDataset
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, expected_pos_pairs, expected_neg_pairs",
+ [("unique", 4, 2), ("undersampling", 2, 2), ("oversampling", 4, 4)],
+)
+def test_sentence_pairs_generation(sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int):
+ sentences = np.array(["sent 1", "sent 2", "sent 3"])
+ labels = np.array(["label 1", "label 1", "label 2"])
+
+ data = [InputExample(texts=[text], label=label) for text, label in zip(sentences, labels)]
+ multilabel = False
+
+ data_sampler = ContrastiveDataset(data, multilabel, sampling_strategy=sampling_strategy)
+
+ assert data_sampler.len_pos_pairs == expected_pos_pairs
+ assert data_sampler.len_neg_pairs == expected_neg_pairs
+
+ pairs = list(data_sampler)
+
+ assert len(pairs) == expected_pos_pairs + expected_neg_pairs
+ assert pairs[0].texts == ["sent 1", "sent 1"]
+ assert pairs[0].label == 1.0
+
+
+@pytest.mark.parametrize(
+ "sampling_strategy, expected_pos_pairs, expected_neg_pairs",
+ [("unique", 6, 4), ("undersampling", 4, 4), ("oversampling", 6, 6)],
+)
+def test_sentence_pairs_generation_multilabel(
+ sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int
+):
+ sentences = np.array(["sent 1", "sent 2", "sent 3", "sent 4"])
+ labels = np.array([[1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
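+ # Same counting in the multilabel case, where sharing any label makes a pair
+ # positive: 4 self-pairs plus (s1,s4) and (s2,s3) give 6 unique positives, and
+ # the 4 pairs sharing no label are the unique negatives.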
+
+ data = [InputExample(texts=[text], label=label) for text, label in zip(sentences, labels)]
+ multilabel = True
+
+ data_sampler = ContrastiveDataset(data, multilabel, sampling_strategy=sampling_strategy)
+ assert data_sampler.len_pos_pairs == expected_pos_pairs
+ assert data_sampler.len_neg_pairs == expected_neg_pairs
+
+ pairs = list(data_sampler)
+ assert len(pairs) == expected_pos_pairs + expected_neg_pairs
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index af2c9a82..c5654524 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -1,19 +1,24 @@
-import pathlib
+import os
import re
import tempfile
+from pathlib import Path
from unittest import TestCase
import evaluate
import pytest
import torch
from datasets import Dataset, load_dataset
+from pytest import LogCaptureFixture
from sentence_transformers import losses
+from transformers import TrainerCallback
from transformers.testing_utils import require_optuna
from transformers.utils.hp_naming import TrialShortNamer
from setfit import logging
-from setfit.modeling import SetFitModel, SupConLoss
-from setfit.trainer import SetFitTrainer
+from setfit.losses import SupConLoss
+from setfit.modeling import SetFitModel
+from setfit.trainer import Trainer
+from setfit.training_args import TrainingArguments
from setfit.utils import BestRun
@@ -21,10 +26,10 @@
logging.enable_propagation()
-class SetFitTrainerTest(TestCase):
+class TrainerTest(TestCase):
def setUp(self):
self.model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
- self.num_iterations = 1
+ self.args = TrainingArguments(num_iterations=1)
def test_trainer_works_with_model_init(self):
def get_model():
@@ -34,11 +39,11 @@ def get_model():
dataset = Dataset.from_dict(
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
- trainer = SetFitTrainer(
+ trainer = Trainer(
model_init=get_model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.train()
@@ -49,11 +54,11 @@ def test_trainer_works_with_column_mapping(self):
dataset = Dataset.from_dict(
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.train()
@@ -62,9 +67,7 @@ def test_trainer_works_with_column_mapping(self):
def test_trainer_works_with_default_columns(self):
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
- trainer = SetFitTrainer(
- model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
- )
+ trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset)
trainer.train()
metrics = trainer.evaluate()
self.assertEqual(metrics["accuracy"], 1.0)
@@ -74,31 +77,25 @@ def test_trainer_works_with_alternate_dataset_for_evaluate(self):
alternate_dataset = Dataset.from_dict(
{"text": ["x", "y", "z"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
- trainer = SetFitTrainer(
- model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
- )
+ trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset)
trainer.train()
metrics = trainer.evaluate(alternate_dataset)
self.assertNotEqual(metrics["accuracy"], 1.0)
def test_trainer_raises_error_with_missing_label(self):
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
- trainer = SetFitTrainer(
- model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
- )
+ trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset)
with pytest.raises(ValueError):
trainer.train()
def test_trainer_raises_error_with_missing_text(self):
"""If the required columns are missing from the dataset, the library should throw an error and list the columns found."""
dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
- trainer = SetFitTrainer(
- model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
- )
+ trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset)
expected_message = re.escape(
"SetFit expected the dataset to have the columns ['label', 'text'], "
"but only the columns ['extra_column', 'label'] were found. "
- "Either make sure these columns are present, or specify which columns to use with column_mapping in SetFitTrainer."
+ "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer."
)
with pytest.raises(ValueError, match=expected_message):
trainer._validate_column_mapping(trainer.train_dataset)
@@ -106,11 +103,11 @@ def test_trainer_raises_error_with_missing_text(self):
def test_column_mapping_raises_error_when_mapped_columns_missing(self):
"""If the columns specified in the column mapping are missing from the dataset, the library should throw an error and list the columns found."""
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
expected_message = re.escape(
@@ -125,9 +122,7 @@ def test_trainer_raises_error_when_dataset_not_split(self):
dataset = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 0, 1, 1]}).train_test_split(
test_size=0.5
)
- trainer = SetFitTrainer(
- model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
- )
+ trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset)
expected_message = re.escape(
"SetFit expected a Dataset, but it got a DatasetDict with the splits ['test', 'train']. "
"Did you mean to select one of these splits from the dataset?",
@@ -138,12 +133,10 @@ def test_trainer_raises_error_when_dataset_not_split(self):
def test_trainer_raises_error_when_dataset_is_dataset_dict_with_train(self):
"""Verify that a useful error is raised if we pass an unsplit dataset with only a `train` split to the trainer."""
with tempfile.TemporaryDirectory() as tmpdirname:
- path = pathlib.Path(tmpdirname) / "test_dataset_dict_with_train.csv"
+ path = Path(tmpdirname) / "test_dataset_dict_with_train.csv"
path.write_text("label,text\n1,good\n0,terrible\n")
dataset = load_dataset("csv", data_files=str(path))
- trainer = SetFitTrainer(
- model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
- )
+ trainer = Trainer(model=self.model, args=self.args, train_dataset=dataset, eval_dataset=dataset)
expected_message = re.escape(
"SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. "
"Did you mean to select the training split with dataset['train']?",
@@ -154,11 +147,11 @@ def test_trainer_raises_error_when_dataset_is_dataset_dict_with_train(self):
def test_column_mapping_multilabel(self):
dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]})
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
@@ -186,12 +179,12 @@ def compute_metrics(y_pred, y_test):
"accuracy": accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
}
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
metric=compute_metrics,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
@@ -211,12 +204,12 @@ def test_raise_when_metric_value_is_invalid(self):
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
metric="this-metric-does-not-exist", # invalid metric value
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
@@ -225,17 +218,8 @@ def test_raise_when_metric_value_is_invalid(self):
with self.assertRaises(FileNotFoundError):
trainer.evaluate()
- def test_trainer_raises_error_with_wrong_warmup_proportion(self):
- # warmup_proportion must not be > 1.0
- with pytest.raises(ValueError):
- SetFitTrainer(warmup_proportion=1.1)
-
- # warmup_proportion must not be < 0.0
- with pytest.raises(ValueError):
- SetFitTrainer(warmup_proportion=-0.1)
-
-class SetFitTrainerDifferentiableHeadTest(TestCase):
+class TrainerDifferentiableHeadTest(TestCase):
def setUp(self):
self.dataset = Dataset.from_dict(
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
@@ -245,27 +229,22 @@ def setUp(self):
use_differentiable_head=True,
head_params={"out_features": 3},
)
- self.num_iterations = 1
+ self.args = TrainingArguments(num_iterations=1)
def test_trainer_max_length_exceeds_max_acceptable_length(self):
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=self.dataset,
eval_dataset=self.dataset,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
trainer.unfreeze(keep_body_frozen=True)
with self.assertLogs(level=logging.WARNING) as cm:
max_length = 4096
max_acceptable_length = self.model.model_body.get_max_seq_length()
- trainer.train(
- num_epochs=1,
- batch_size=3,
- learning_rate=1e-2,
- l2_weight=0.0,
- max_length=max_length,
- )
+ args = TrainingArguments(num_iterations=1, max_length=max_length)
+ trainer.train(args)
self.assertEqual(
cm.output,
[
@@ -277,38 +256,32 @@ def test_trainer_max_length_exceeds_max_acceptable_length(self):
)
def test_trainer_max_length_is_smaller_than_max_acceptable_length(self):
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=self.dataset,
eval_dataset=self.dataset,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
- trainer.unfreeze(keep_body_frozen=True)
# An alternative way of `assertNoLogs`, which is new in Python 3.10
try:
with self.assertLogs(level=logging.WARNING) as cm:
max_length = 32
- trainer.train(
- num_epochs=1,
- batch_size=3,
- learning_rate=1e-2,
- l2_weight=0.0,
- max_length=max_length,
- )
+ args = TrainingArguments(num_iterations=1, max_length=max_length)
+ trainer.train(args)
self.assertEqual(cm.output, [])
except AssertionError as e:
if e.args[0] != "no logs of level WARNING or higher triggered on root":
raise AssertionError(e)
-class SetFitTrainerMultilabelTest(TestCase):
+class TrainerMultilabelTest(TestCase):
def setUp(self):
self.model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest"
)
- self.num_iterations = 1
+ self.args = TrainingArguments(num_iterations=1)
def test_trainer_multilabel_support_callable_as_metric(self):
dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[1, 0, 0], [0, 1, 0], [0, 0, 1]]})
@@ -322,12 +295,12 @@ def compute_metrics(y_pred, y_test):
"accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
}
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
metric=compute_metrics,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
@@ -343,7 +316,7 @@ def compute_metrics(y_pred, y_test):
)
-class SetFitTrainerMultilabelDifferentiableTest(TestCase):
+class TrainerMultilabelDifferentiableTest(TestCase):
def setUp(self):
self.model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-albert-small-v2",
@@ -351,7 +324,7 @@ def setUp(self):
use_differentiable_head=True,
head_params={"out_features": 2},
)
- self.num_iterations = 1
+ self.args = TrainingArguments(num_iterations=1)
def test_trainer_multilabel_support_callable_as_metric(self):
dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})
@@ -365,20 +338,16 @@ def compute_metrics(y_pred, y_test):
"accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
}
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=self.model,
+ args=self.args,
train_dataset=dataset,
eval_dataset=dataset,
metric=compute_metrics,
- num_iterations=self.num_iterations,
column_mapping={"text_new": "text", "label_new": "label"},
)
- trainer.freeze()
trainer.train()
-
- trainer.unfreeze(keep_body_frozen=False)
- trainer.train(5)
metrics = trainer.evaluate()
self.assertEqual(
@@ -396,7 +365,7 @@ def setUp(self):
self.dataset = Dataset.from_dict(
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
- self.num_iterations = 1
+ self.args = TrainingArguments(num_iterations=1)
def test_hyperparameter_search(self):
class MyTrialShortNamer(TrialShortNamer):
@@ -425,10 +394,10 @@ def model_init(params):
def hp_name(trial):
return MyTrialShortNamer.shortname(trial.params)
- trainer = SetFitTrainer(
+ trainer = Trainer(
+ args=self.args,
train_dataset=self.dataset,
eval_dataset=self.dataset,
- num_iterations=self.num_iterations,
model_init=model_init,
column_mapping={"text_new": "text", "label_new": "label"},
)
@@ -451,27 +420,26 @@ def hp_name(trial):
def test_trainer_works_with_non_default_loss_class(loss_class):
dataset = Dataset.from_dict({"text": ["a 1", "b 1", "c 1", "a 2", "b 2", "c 2"], "label": [0, 1, 2, 0, 1, 2]})
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
- trainer = SetFitTrainer(
+ args = TrainingArguments(num_iterations=1, loss=loss_class)
+ trainer = Trainer(
model=model,
+ args=args,
train_dataset=dataset,
eval_dataset=dataset,
- num_iterations=1,
- loss_class=loss_class,
)
trainer.train()
# no asserts here because this is a regression test - we only test if an exception is raised
-def test_trainer_evaluate_with_strings():
+def test_trainer_evaluate_with_strings(model: SetFitModel):
dataset = Dataset.from_dict(
{"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]}
)
- model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=model,
+ args=TrainingArguments(num_iterations=1),
train_dataset=dataset,
eval_dataset=dataset,
- num_iterations=1,
)
trainer.train()
# This used to fail due to "TypeError: can't convert np.ndarray of type numpy.str_.
@@ -485,13 +453,13 @@ def test_trainer_evaluate_multilabel_f1():
"sentence-transformers/paraphrase-albert-small-v2", multi_target_strategy="one-vs-rest"
)
- trainer = SetFitTrainer(
+ trainer = Trainer(
model=model,
+ args=TrainingArguments(num_iterations=5),
train_dataset=dataset,
eval_dataset=dataset,
metric="f1",
metric_kwargs={"average": "micro"},
- num_iterations=5,
column_mapping={"text_new": "text", "label_new": "label"},
)
@@ -502,9 +470,7 @@ def test_trainer_evaluate_multilabel_f1():
def test_trainer_evaluate_on_cpu() -> None:
# This test used to fail if CUDA was available
- dataset = Dataset.from_dict(
- {"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]}
- )
+ dataset = Dataset.from_dict({"text": ["positive sentence", "negative sentence"], "label": [1, 0]})
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True
)
@@ -513,12 +479,111 @@ def compute_metric(y_pred, y_test) -> None:
assert y_pred.device == torch.device("cpu")
return 1.0
- trainer = SetFitTrainer(
+ args = TrainingArguments(num_iterations=5)
+ trainer = Trainer(
model=model,
+ args=args,
train_dataset=dataset,
eval_dataset=dataset,
metric=compute_metric,
- num_iterations=5,
)
trainer.train()
trainer.evaluate()
+
+
+def test_no_model_no_model_init():
+ with pytest.raises(RuntimeError, match="`Trainer` requires either a `model` or `model_init` argument."):
+ Trainer()
+
+
+def test_model_and_model_init(model: SetFitModel):
+ def model_init() -> SetFitModel:
+ return model
+
+ with pytest.raises(RuntimeError, match="`Trainer` requires either a `model` or `model_init` argument."):
+ Trainer(model=model, model_init=model_init)
+
+
+def test_trainer_callbacks(model: SetFitModel):
+ trainer = Trainer(model=model)
+ assert len(trainer.callback_handler.callbacks) >= 2
+ callback_names = {callback.__class__.__name__ for callback in trainer.callback_handler.callbacks}
+ assert {"DefaultFlowCallback", "ProgressCallback"} <= callback_names
+
+ class TestCallback(TrainerCallback):
+ pass
+
+ callback = TestCallback()
+ trainer.add_callback(callback)
+ assert len(trainer.callback_handler.callbacks) == len(callback_names) + 1
+ assert trainer.callback_handler.callbacks[-1] == callback
+
+ assert trainer.pop_callback(callback) == callback
+ trainer.add_callback(callback)
+ assert trainer.callback_handler.callbacks[-1] == callback
+ trainer.remove_callback(callback)
+ assert callback not in trainer.callback_handler.callbacks
+
+
+def test_trainer_warn_freeze(model: SetFitModel):
+ trainer = Trainer(model)
+ with pytest.warns(
+ DeprecationWarning,
+ match="Trainer.freeze` is deprecated and will be removed in v2.0.0 of SetFit. "
+ "Please use `SetFitModel.freeze` directly instead.",
+ ):
+ trainer.freeze()
+
+
+def test_train_with_kwargs(model: SetFitModel) -> None:
+ train_dataset = Dataset.from_dict({"text": ["positive sentence", "negative sentence"], "label": [1, 0]})
+ trainer = Trainer(model, train_dataset=train_dataset)
+ with pytest.warns(DeprecationWarning, match="`Trainer.train` does not accept keyword arguments anymore."):
+ trainer.train(num_epochs=5)
+
+
+def test_train_no_dataset(model: SetFitModel) -> None:
+ trainer = Trainer(model)
+ with pytest.raises(ValueError, match="Training requires a `train_dataset` given to the `Trainer` initialization."):
+ trainer.train()
+
+
+def test_train_amp_save(model: SetFitModel, tmp_path: Path) -> None:
+ args = TrainingArguments(output_dir=tmp_path, use_amp=True, save_steps=5, num_epochs=5)
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]})
+ trainer = Trainer(model, args=args, train_dataset=dataset, eval_dataset=dataset)
+ trainer.train()
+ assert trainer.evaluate() == {"accuracy": 1.0}
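+ # `save_steps=5` should leave checkpoints named step_{N} under output_dir;
+ # here only a single step_5 folder is expected.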
+ assert os.listdir(tmp_path) == ["step_5"]
+
+
+def test_train_load_best(model: SetFitModel, tmp_path: Path, caplog: LogCaptureFixture) -> None:
+ args = TrainingArguments(
+ output_dir=tmp_path,
+ save_steps=5,
+ eval_steps=5,
+ evaluation_strategy="steps",
+ load_best_model_at_end=True,
+ num_epochs=5,
+ )
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]})
+ trainer = Trainer(model, args=args, train_dataset=dataset, eval_dataset=dataset)
+ with caplog.at_level(logging.INFO):
+ trainer.train()
+
+ assert any("Load pretrained SentenceTransformer" in text for _, _, text in caplog.record_tuples)
+
+
+def test_evaluate_with_strings(model: SetFitModel) -> None:
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": ["positive", "positive", "negative"]})
+ trainer = Trainer(model, train_dataset=dataset, eval_dataset=dataset)
+ trainer.train()
+ metrics = trainer.evaluate()
+ assert "accuracy" in metrics
+
+
+def test_trainer_wrong_args(model: SetFitModel, tmp_path: Path) -> None:
+ dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]})
+ expected = "`args` must be a `TrainingArguments` instance imported from `setfit`."
+ with pytest.raises(ValueError, match=expected):
+ Trainer(model, dataset)
diff --git a/tests/test_trainer_distillation.py b/tests/test_trainer_distillation.py
index 82dd0b05..7ddb4ba3 100644
--- a/tests/test_trainer_distillation.py
+++ b/tests/test_trainer_distillation.py
@@ -2,39 +2,36 @@
import pytest
from datasets import Dataset
-from sentence_transformers.losses import CosineSimilarityLoss
-from setfit import DistillationSetFitTrainer, SetFitTrainer
+from setfit import DistillationTrainer, Trainer
from setfit.modeling import SetFitModel
+from setfit.training_args import TrainingArguments
-class DistillationSetFitTrainerTest(TestCase):
+class DistillationTrainerTest(TestCase):
def setUp(self):
self.teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
self.student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
- self.num_iterations = 1
+ self.args = TrainingArguments(num_iterations=1)
def test_trainer_works_with_default_columns(self):
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
# train a teacher model
- teacher_trainer = SetFitTrainer(
+ teacher_trainer = Trainer(
model=self.teacher_model,
train_dataset=dataset,
eval_dataset=dataset,
- loss_class=CosineSimilarityLoss,
metric="accuracy",
)
# Teacher Train and evaluate
teacher_trainer.train()
- metrics = teacher_trainer.evaluate()
teacher_model = teacher_trainer.model
- student_trainer = DistillationSetFitTrainer(
+ student_trainer = DistillationTrainer(
teacher_model=teacher_model,
train_dataset=dataset,
student_model=self.student_model,
eval_dataset=dataset,
- loss_class=CosineSimilarityLoss,
metric="accuracy",
)
@@ -45,37 +42,53 @@ def test_trainer_works_with_default_columns(self):
self.assertEqual(metrics["accuracy"], 1.0)
def test_trainer_raises_error_with_missing_label(self):
- dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
- trainer = DistillationSetFitTrainer(
+ labeled_dataset = Dataset.from_dict(
+ {"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}
+ )
+ # train a teacher model
+ teacher_trainer = Trainer(
+ model=self.teacher_model,
+ train_dataset=labeled_dataset,
+ eval_dataset=labeled_dataset,
+ metric="accuracy",
+ args=self.args,
+ )
+ # Teacher Train and evaluate
+ teacher_trainer.train()
+
+ unlabeled_dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
+ student_trainer = DistillationTrainer(
teacher_model=self.teacher_model,
- train_dataset=dataset,
student_model=self.student_model,
- eval_dataset=dataset,
- num_iterations=self.num_iterations,
+ train_dataset=unlabeled_dataset,
+ eval_dataset=labeled_dataset,
+ args=self.args,
)
- with pytest.raises(ValueError):
- trainer.train()
+ student_trainer.train()
+ metrics = student_trainer.evaluate()
+ print("Student results: ", metrics)
+ self.assertEqual(metrics["accuracy"], 1.0)
def test_trainer_raises_error_with_missing_text(self):
dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
- trainer = DistillationSetFitTrainer(
+ trainer = DistillationTrainer(
teacher_model=self.teacher_model,
train_dataset=dataset,
student_model=self.student_model,
eval_dataset=dataset,
- num_iterations=self.num_iterations,
+ args=self.args,
)
with pytest.raises(ValueError):
trainer.train()
def test_column_mapping_with_missing_text(self):
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
- trainer = DistillationSetFitTrainer(
+ trainer = DistillationTrainer(
teacher_model=self.teacher_model,
train_dataset=dataset,
student_model=self.student_model,
eval_dataset=dataset,
- num_iterations=self.num_iterations,
+ args=self.args,
column_mapping={"label_new": "label"},
)
with pytest.raises(ValueError):
@@ -84,12 +97,12 @@ def test_column_mapping_with_missing_text(self):
def test_column_mapping_multilabel(self):
dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]})
- trainer = DistillationSetFitTrainer(
+ trainer = DistillationTrainer(
teacher_model=self.teacher_model,
train_dataset=dataset,
student_model=self.student_model,
eval_dataset=dataset,
- num_iterations=self.num_iterations,
+ args=self.args,
column_mapping={"text_new": "text", "label_new": "label"},
)
@@ -102,22 +115,8 @@ def test_column_mapping_multilabel(self):
assert formatted_dataset[1]["text"] == "b"
-def train_diff(trainer: SetFitTrainer):
- # Teacher Train and evaluate
- trainer.freeze() # Freeze the head
- trainer.train() # Train only the body
-
- # Unfreeze the head and unfreeze the body -> end-to-end training
- trainer.unfreeze(keep_body_frozen=False)
-
- trainer.train(num_epochs=5)
-
-
-def train_lr(trainer: SetFitTrainer):
- trainer.train()
-
-
-@pytest.mark.parametrize(("teacher_diff", "student_diff"), [[True, False], [True, False]])
+@pytest.mark.parametrize("teacher_diff", [True, False])
+@pytest.mark.parametrize("student_diff", [True, False])
def test_differentiable_models(teacher_diff: bool, student_diff: bool) -> None:
if teacher_diff:
teacher_model = SetFitModel.from_pretrained(
@@ -125,34 +124,32 @@ def test_differentiable_models(teacher_diff: bool, student_diff: bool) -> None:
use_differentiable_head=True,
head_params={"out_features": 3},
)
- teacher_train_func = train_diff
else:
teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
- teacher_train_func = train_lr
if student_diff:
student_model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-MiniLM-L3-v2",
use_differentiable_head=True,
head_params={"out_features": 3},
)
- student_train_func = train_diff
else:
student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
- student_train_func = train_lr
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
# train a teacher model
- teacher_trainer = SetFitTrainer(
+ teacher_trainer = Trainer(
model=teacher_model,
train_dataset=dataset,
eval_dataset=dataset,
metric="accuracy",
)
- teacher_train_func(teacher_trainer)
+ teacher_trainer.train()
metrics = teacher_trainer.evaluate()
+ print("Teacher results: ", metrics)
+ assert metrics["accuracy"] == 1.0
teacher_model = teacher_trainer.model
- student_trainer = DistillationSetFitTrainer(
+ student_trainer = DistillationTrainer(
teacher_model=teacher_model,
train_dataset=dataset,
student_model=student_model,
@@ -161,7 +158,7 @@ def test_differentiable_models(teacher_diff: bool, student_diff: bool) -> None:
)
# Student Train and evaluate
- student_train_func(student_trainer)
+ student_trainer.train()
metrics = student_trainer.evaluate()
print("Student results: ", metrics)
assert metrics["accuracy"] == 1.0
diff --git a/tests/test_training_args.py b/tests/test_training_args.py
new file mode 100644
index 00000000..a573e10e
--- /dev/null
+++ b/tests/test_training_args.py
@@ -0,0 +1,113 @@
+from unittest import TestCase
+
+import pytest
+
+from setfit.training_args import TrainingArguments
+
+
+class TestTrainingArguments(TestCase):
+ def test_raises_error_with_wrong_warmup_proportion(self):
+ # warmup_proportion must not be > 1.0
+ with pytest.raises(ValueError):
+ TrainingArguments(warmup_proportion=1.1)
+
+ # warmup_proportion must not be < 0.0
+ with pytest.raises(ValueError):
+ TrainingArguments(warmup_proportion=-0.1)
+
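+ # The tests below encode the apparent convention: a lone int applies to both
+ # the embedding and classifier phases, a tuple is read as (embedding,
+ # classifier), and an explicit embedding_batch_size/classifier_batch_size
+ # overrides that side. For example, TrainingArguments(batch_size=(32, 8))
+ # would presumably embed with batch size 32 and fit the head with batch size 8.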
+ def test_batch_sizes(self):
+ batch_size_A = 12
+ batch_size_B = 4
+ batch_size_C = 6
+
+ args = TrainingArguments(batch_size=batch_size_A)
+ self.assertEqual(args.batch_size, (batch_size_A, batch_size_A))
+ self.assertEqual(args.embedding_batch_size, batch_size_A)
+ self.assertEqual(args.classifier_batch_size, batch_size_A)
+
+ args = TrainingArguments(batch_size=(batch_size_A, batch_size_B))
+ self.assertEqual(args.batch_size, (batch_size_A, batch_size_B))
+ self.assertEqual(args.embedding_batch_size, batch_size_A)
+ self.assertEqual(args.classifier_batch_size, batch_size_B)
+
+ args = TrainingArguments(batch_size=(batch_size_A, batch_size_B), embedding_batch_size=batch_size_C)
+ self.assertEqual(args.batch_size, (batch_size_A, batch_size_B))
+ self.assertEqual(args.embedding_batch_size, batch_size_C)
+ self.assertEqual(args.classifier_batch_size, batch_size_B)
+
+ args = TrainingArguments(batch_size=batch_size_A, embedding_batch_size=batch_size_C)
+ self.assertEqual(args.batch_size, (batch_size_A, batch_size_A))
+ self.assertEqual(args.embedding_batch_size, batch_size_C)
+ self.assertEqual(args.classifier_batch_size, batch_size_A)
+
+ def test_num_epochs(self):
+ num_epochs_A = 12
+ num_epochs_B = 4
+ num_epochs_C = 6
+
+ args = TrainingArguments(num_epochs=num_epochs_A)
+ self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_A))
+ self.assertEqual(args.embedding_num_epochs, num_epochs_A)
+ self.assertEqual(args.classifier_num_epochs, num_epochs_A)
+
+ args = TrainingArguments(num_epochs=(num_epochs_A, num_epochs_B))
+ self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_B))
+ self.assertEqual(args.embedding_num_epochs, num_epochs_A)
+ self.assertEqual(args.classifier_num_epochs, num_epochs_B)
+
+ args = TrainingArguments(num_epochs=(num_epochs_A, num_epochs_B), embedding_num_epochs=num_epochs_C)
+ self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_B))
+ self.assertEqual(args.embedding_num_epochs, num_epochs_C)
+ self.assertEqual(args.classifier_num_epochs, num_epochs_B)
+
+ args = TrainingArguments(num_epochs=num_epochs_A, embedding_num_epochs=num_epochs_C)
+ self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_A))
+ self.assertEqual(args.embedding_num_epochs, num_epochs_C)
+ self.assertEqual(args.classifier_num_epochs, num_epochs_A)
+
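+ # body_learning_rate follows the same single-value-or-(embedding, classifier)
+ # convention as above, while head_learning_rate is tracked separately.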
+ def test_learning_rates(self):
+ learning_rate_A = 1e-2
+ learning_rate_B = 1e-3
+ learning_rate_C = 1e-4
+
+ base = TrainingArguments()
+
+ args = TrainingArguments(body_learning_rate=learning_rate_A)
+ self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_A))
+ self.assertEqual(args.body_embedding_learning_rate, learning_rate_A)
+ self.assertEqual(args.body_classifier_learning_rate, learning_rate_A)
+ self.assertEqual(args.head_learning_rate, base.head_learning_rate)
+
+ args = TrainingArguments(body_learning_rate=(learning_rate_A, learning_rate_B))
+ self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_B))
+ self.assertEqual(args.body_embedding_learning_rate, learning_rate_A)
+ self.assertEqual(args.body_classifier_learning_rate, learning_rate_B)
+ self.assertEqual(args.head_learning_rate, base.head_learning_rate)
+
+ args = TrainingArguments(
+ body_learning_rate=(learning_rate_A, learning_rate_B), head_learning_rate=learning_rate_C
+ )
+ self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_B))
+ self.assertEqual(args.body_embedding_learning_rate, learning_rate_A)
+ self.assertEqual(args.body_classifier_learning_rate, learning_rate_B)
+ self.assertEqual(args.head_learning_rate, learning_rate_C)
+
+ args = TrainingArguments(
+ body_learning_rate=learning_rate_A,
+ body_embedding_learning_rate=learning_rate_B,
+ head_learning_rate=learning_rate_C,
+ )
+ # Perhaps not ideal, but body_learning_rate is never used directly:
+ self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_A))
+ self.assertEqual(args.body_embedding_learning_rate, learning_rate_B)
+ self.assertEqual(args.body_classifier_learning_rate, learning_rate_A)
+ self.assertEqual(args.head_learning_rate, learning_rate_C)
+
+ args = TrainingArguments(
+ body_classifier_learning_rate=learning_rate_A,
+ body_embedding_learning_rate=learning_rate_B,
+ head_learning_rate=learning_rate_C,
+ )
+ self.assertEqual(args.body_embedding_learning_rate, learning_rate_B)
+ self.assertEqual(args.body_classifier_learning_rate, learning_rate_A)
+ self.assertEqual(args.head_learning_rate, learning_rate_C)