Merged

97 commits
1acdd5c
Implement Trainer & TrainingArguments w. tests
tomaarsen Jan 11, 2023
89f4435
Readded support for hyperparameter tuning
tomaarsen Jan 11, 2023
5f2a6b3
Remove unused imports and reformat
tomaarsen Jan 11, 2023
622f33b
Preserve desired behaviour despite deprecation of keep_body_frozen pa…
tomaarsen Jan 11, 2023
ff59154
Ensure that DeprecationWarnings are displayed
tomaarsen Jan 11, 2023
3b4ef58
Set Trainer.freeze and Trainer.unfreeze methods normally
tomaarsen Jan 11, 2023
fd68274
Add TrainingArgument tests for num_epochs, batch_sizes, lr
tomaarsen Jan 11, 2023
14602ea
Convert trainer.train arguments into a softer deprecation
tomaarsen Jan 11, 2023
94106cc
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Jan 22, 2023
a39e772
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit; br…
tomaarsen Jan 23, 2023
9fc55a6
Use body/head_learning_rate instead of classifier/embedding_learning_…
tomaarsen Jan 23, 2023
7d4ad00
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Jan 23, 2023
aab2377
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 6, 2023
dee70b1
Reformat according to the newest black version
tomaarsen Feb 6, 2023
fb6547d
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 6, 2023
abbbb03
Remove "classifier" from var names in SetFitHead
tomaarsen Feb 6, 2023
12d326e
Update DeprecationWarnings to include timeline
tomaarsen Feb 6, 2023
70c0295
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 6, 2023
fc246cc
Convert training_argument imports to relative imports
tomaarsen Feb 6, 2023
57aa54f
Make conditional explicit
tomaarsen Feb 6, 2023
7ebdf93
Make conditional explicit
tomaarsen Feb 6, 2023
4695293
Use assertEqual rather than assert
tomaarsen Feb 6, 2023
4c6d0fd
Remove training_arguments from test func names
tomaarsen Feb 6, 2023
5937ec2
Replace loss_class on Trainer with loss on TrainArgs
tomaarsen Feb 6, 2023
f1e3de9
Removed dead class argument
tomaarsen Feb 6, 2023
6051095
Move SupConLoss to losses.py
tomaarsen Feb 6, 2023
bddd46a
Add deprecation to Trainer.(un)freeze
tomaarsen Feb 7, 2023
fa8a077
Prevent warning from always triggering
tomaarsen Feb 7, 2023
85a3684
Export TrainingArguments in __init__
tomaarsen Feb 7, 2023
ca625a2
Update & add important missing docstrings
tomaarsen Feb 7, 2023
868d7b7
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 7, 2023
68e9094
Use standard dataclass initialization for SetFitModel
tomaarsen Feb 8, 2023
19a6fc8
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 15, 2023
0b2efa1
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Feb 15, 2023
ca87c42
Remove duplicate space in DeprecationWarning
tomaarsen Feb 16, 2023
cc5282f
No longer require labeled data for DistillationTrainer
tomaarsen Mar 3, 2023
c6f5782
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Mar 3, 2023
36cbbfe
Update docs for v1.0.0
tomaarsen Mar 6, 2023
deb57ff
Remove references of SetFitTrainer
tomaarsen Mar 6, 2023
46922d5
Update expected test output
tomaarsen Mar 6, 2023
f43d5b2
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Apr 19, 2023
b0f9f58
Remove unused pipeline
tomaarsen Apr 19, 2023
339f332
Execute deprecations
tomaarsen Apr 19, 2023
9e0bf78
Stop importing now-removed function
tomaarsen Apr 19, 2023
ecabbcf
Initial setup for logging & callbacks
tomaarsen Jul 6, 2023
6e6720b
Move sentence-transformer training into trainer.py
tomaarsen Jul 6, 2023
826eb53
Add checkpointing, support EarlyStoppingCallback
tomaarsen Jul 28, 2023
019a971
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Jul 29, 2023
1930973
Run formatting
tomaarsen Jul 29, 2023
e4f3f76
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Jul 29, 2023
0f66109
Merge pull request #4 from tomaarsen/feat/logging_callbacks
tomaarsen Jul 29, 2023
a87cdc0
Add additional trainer tests
tomaarsen Jul 29, 2023
d418759
Use isinstance, required by flake8 release from 1hr ago
tomaarsen Jul 29, 2023
08892f6
sampler for refactor WIP
danstan5 Sep 14, 2023
0a2b664
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Oct 17, 2023
429de0f
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
173f084
Run formatters
tomaarsen Oct 17, 2023
c23959a
Remove tests from modeling.py
tomaarsen Oct 17, 2023
0fa3870
Add missing type hint
tomaarsen Oct 17, 2023
3969f38
Adjust test to still pass if W&B/Tensorboard are installed
tomaarsen Oct 17, 2023
567f1c9
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
851f0bb
The log/eval/save steps should be saved on the state instead
tomaarsen Oct 17, 2023
67ddedc
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
d37ee09
sampler logic fix "unique" strategy
danstan5 Oct 19, 2023
0ef8837
add sampler tests (not complete)
danstan5 Oct 19, 2023
131aa26
add sampling_strategy into TrainingArguments
danstan5 Oct 19, 2023
c6c6228
Merge branch 'refactor-sampling' of https://github.com/danstan5/setfi…
danstan5 Oct 19, 2023
7431005
num_iterations removed from TrainingArguments
danstan5 Oct 19, 2023
3bd2acc
run_fewshot compatible with <v.1.0.0
danstan5 Oct 20, 2023
3d07e6c
Run make style
tomaarsen Oct 25, 2023
978daee
Use "no" as the default evaluation_strategy
tomaarsen Oct 25, 2023
2802a3f
Move num_iterations back to TrainingArguments
tomaarsen Oct 25, 2023
391f991
Fix broken trainer tests due to new default sampling
tomaarsen Oct 25, 2023
f8b7253
Use the Contrastive Dataset for Distillation
tomaarsen Oct 25, 2023
38e9607
Set the default logging steps at 50
tomaarsen Oct 25, 2023
4ead15d
Add max_steps argument to TrainingArguments
tomaarsen Oct 25, 2023
eb70336
Change max_steps conditional
tomaarsen Oct 25, 2023
3478799
Merge pull request #5 from danstan5/refactor-sampling
tomaarsen Oct 27, 2023
d9c4a05
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen Nov 9, 2023
5b39f06
Seeds are now correctly applied for reproducibility
tomaarsen Nov 9, 2023
7c3feed
Don't scale gradients during evaluation
tomaarsen Nov 9, 2023
cdc8979
Use evaluation_strategy="steps" if eval_steps is set
tomaarsen Nov 9, 2023
e040167
Run formatting
tomaarsen Nov 9, 2023
d2f2489
Implement SetFit for ABSA from Intel Labs (#6)
tomaarsen Nov 9, 2023
5c4569d
Import optuna under TYPE_CHECKING
tomaarsen Nov 9, 2023
ceeb725
Remove unused import, reformat
tomaarsen Nov 9, 2023
5c669b5
Add MANIFEST.in with model_card_template
tomaarsen Nov 9, 2023
8e201e5
Don't require transformers TrainingArgs in tests
tomaarsen Nov 9, 2023
6ae5045
Update URLs in setup.py
tomaarsen Nov 9, 2023
ecaabb4
Increase min hf_hub version to 0.12.0 for SoftTemporaryDirectory
tomaarsen Nov 9, 2023
4e79397
Include MANIFEST.in data via `include_package_data=True`
tomaarsen Nov 9, 2023
65aff32
Use kwargs instead of args in super call
tomaarsen Nov 9, 2023
eeeac55
Use v0.13.0 as min. version as huggingface/huggingface_hub#1315
tomaarsen Nov 9, 2023
3214f1b
Use en_core_web_sm for tests
tomaarsen Nov 10, 2023
2b78bb0
Remove incorrect spacy_model from AspectModel/PolarityModel
tomaarsen Nov 10, 2023
b68f655
Rerun formatting
tomaarsen Nov 10, 2023
d85f0d9
Run CI on pre branch & workflow dispatch
tomaarsen Nov 10, 2023
3 changes: 3 additions & 0 deletions .github/workflows/quality.yml
@@ -5,9 +5,12 @@ on:
branches:
- main
- v*-release
- v*-pre
pull_request:
branches:
- main
- v*-pre
workflow_dispatch:

jobs:

5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
@@ -5,9 +5,12 @@ on:
branches:
- main
- v*-release
- v*-pre
pull_request:
branches:
- main
- v*-pre
workflow_dispatch:

jobs:

@@ -40,6 +43,8 @@ jobs:
run: |
python -m pip install --no-cache-dir --upgrade pip
python -m pip install --no-cache-dir ${{ matrix.requirements }}
python -m spacy download en_core_web_lg
python -m spacy download en_core_web_sm
if: steps.restore-cache.outputs.cache-hit != 'true'

- name: Install the checked-out setfit
4 changes: 4 additions & 0 deletions .gitignore
@@ -149,3 +149,7 @@ scripts/tfew/run_tmux.sh
# macOS
.DS_Store
.vscode/settings.json

# Common SetFit Trainer logging folders
wandb
runs/
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
include src/setfit/span/model_card_template.md
120 changes: 56 additions & 64 deletions README.md
@@ -39,16 +39,14 @@ The examples below provide a quick overview on the various features supported in
`setfit` is integrated with the [Hugging Face Hub](https://huggingface.co/) and provides two main classes:

* `SetFitModel`: a wrapper that combines a pretrained body from `sentence_transformers` and a classification head from either [`scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) or [`SetFitHead`](https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py) (a differentiable head built upon `PyTorch` with similar APIs to `sentence_transformers`).
* `SetFitTrainer`: a helper class that wraps the fine-tuning process of SetFit.
* `Trainer`: a helper class that wraps the fine-tuning process of SetFit.

Here is an end-to-end example using a classification head from `scikit-learn`:


```python
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset


# Load a dataset from the Hugging Face Hub
@@ -61,17 +59,19 @@ eval_dataset = dataset["validation"]
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
args = TrainingArguments(
batch_size=16,
num_iterations=20, # The number of text pairs to generate for contrastive learning
num_epochs=1 # The number of epochs to use for contrastive learning
)

trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss_class=CosineSimilarityLoss,
metric="accuracy",
batch_size=16,
num_iterations=20, # The number of text pairs to generate for contrastive learning
num_epochs=1, # The number of epochs to use for contrastive learning
column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)

# Train and evaluate
@@ -81,7 +81,7 @@ metrics = trainer.evaluate()
# Push model to the Hub
trainer.push_to_hub("my-awesome-setfit-model")

# Download from Hub and run inference
# Download from Hub
model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")
# Run inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
@@ -92,9 +92,7 @@ Here is an end-to-end example using `SetFitHead`:

```python
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset


# Load a dataset from the Hugging Face Hub
@@ -103,6 +101,7 @@ dataset = load_dataset("sst2")
# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]
num_classes = 2

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
@@ -111,36 +110,26 @@ model = SetFitModel.from_pretrained(
head_params={"out_features": num_classes},
)

# Create trainer
trainer = SetFitTrainer(
args = TrainingArguments(
body_learning_rate=2e-5,
head_learning_rate=1e-2,
batch_size=16,
num_iterations=20, # The number of text pairs to generate for contrastive learning
num_epochs=(1, 25), # For finetuning the embeddings and training the classifier, respectively
l2_weight=0.0,
end_to_end=False, # Don't train the classifier end-to-end, i.e. only train the head
)

trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss_class=CosineSimilarityLoss,
metric="accuracy",
batch_size=16,
num_iterations=20, # The number of text pairs to generate for contrastive learning
num_epochs=1, # The number of epochs to use for contrastive learning
column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)

# Train and evaluate
trainer.freeze() # Freeze the head
trainer.train() # Train only the body

# Unfreeze the head and freeze the body -> head-only training
trainer.unfreeze(keep_body_frozen=True)
# or
# Unfreeze the head and unfreeze the body -> end-to-end training
trainer.unfreeze(keep_body_frozen=False)

trainer.train(
num_epochs=25, # The number of epochs to train the head or the whole model (body and head)
batch_size=16,
body_learning_rate=1e-5, # The body's learning rate
learning_rate=1e-2, # The head's learning rate
l2_weight=0.0, # Weight decay on **both** the body and head. If `None`, will use 0.01.
)
trainer.train()
metrics = trainer.evaluate()

# Push model to the Hub
@@ -175,7 +164,7 @@ This will initialise a multilabel classification head from `sklearn` - the follo
* `multi-output`: uses a `MultiOutputClassifier` head.
* `classifier-chain`: uses a `ClassifierChain` head.

From here, you can instantiate a `SetFitTrainer` using the same example above, and train it as usual.
From here, you can instantiate a `Trainer` using the same example above, and train it as usual.
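The strategy names above map onto standard scikit-learn meta-estimators. Here is a minimal sketch of that mapping (the `heads` dict is purely illustrative and not part of the `setfit` API; it assumes scikit-learn is installed):

```python
# Sketch: which scikit-learn meta-estimator backs each multi_target_strategy.
# Illustrative only — setfit constructs the head internally from the strategy name.
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier

heads = {
    "one-vs-rest": OneVsRestClassifier(LogisticRegression()),      # one binary classifier per class
    "multi-output": MultiOutputClassifier(LogisticRegression()),   # independent classifier per label
    "classifier-chain": ClassifierChain(LogisticRegression()),     # earlier labels feed later classifiers
}
```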

#### Example using the differentiable `SetFitHead`:

@@ -196,7 +185,6 @@ model = SetFitModel.from_pretrained(
SetFit can also be applied to scenarios where no labels are available. To do so, create a synthetic dataset of training examples:

```python
from datasets import Dataset
from setfit import get_templated_dataset

candidate_labels = ["negative", "positive"]
@@ -206,22 +194,22 @@ train_dataset = get_templated_dataset(candidate_labels=candidate_labels, sample_
This will create examples of the form `"This sentence is {}"`, where the `{}` is filled in with one of the candidate labels. From here you can train a SetFit model as usual:

```python
from setfit import SetFitModel, SetFitTrainer
from setfit import SetFitModel, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
trainer = SetFitTrainer(
trainer = Trainer(
model=model,
train_dataset=train_dataset
)
trainer.train()
```

We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with Bart), while being 5x faster to generate predictions with.
We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with BART), while being 5x faster to generate predictions with.


### Running hyperparameter search

`SetFitTrainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend:
`Trainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend:

```bash
python -m pip install setfit[optuna]
@@ -267,23 +255,23 @@ def hp_space(trial): # Training parameters

**Note:** In practice, we found `num_iterations` to be the most important hyperparameter for the contrastive learning process.
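A back-of-the-envelope calculation shows why `num_iterations` matters: each training sample yields `num_iterations` positive and `num_iterations` negative pairs, so pair count — and hence contrastive training time — scales linearly with it (an illustrative sketch of the arithmetic, not `setfit` code):

```python
# Illustrative pair-count arithmetic for SetFit's contrastive sampling:
# each sample contributes num_iterations positive + num_iterations negative pairs.
num_samples = 8 * 2       # e.g. 8 examples per class, 2 classes (as in the sst2 example)
num_iterations = 20
total_pairs = 2 * num_iterations * num_samples
print(total_pairs)        # 640 pairs per epoch
```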

The next step is to instantiate a `SetFitTrainer` and call `hyperparameter_search()`:
The next step is to instantiate a `Trainer` and call `hyperparameter_search()`:

```python
from datasets import Dataset
from setfit import SetFitTrainer
from setfit import Trainer

dataset = Dataset.from_dict(
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)
{"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)

trainer = SetFitTrainer(
trainer = Trainer(
train_dataset=dataset,
eval_dataset=dataset,
model_init=model_init,
column_mapping={"text_new": "text", "label_new": "label"},
)
best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=20)
best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=5)
```

Finally, you can apply the hyperparameters you found to the trainer, and lock in the optimal model, before training for
@@ -300,9 +288,8 @@ If you have access to unlabeled data, you can use knowledge distillation to comp

```python
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, DistillationSetFitTrainer, sample_dataset
from setfit import SetFitModel, Trainer, DistillationTrainer, sample_dataset
from setfit.training_args import TrainingArguments

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
@@ -320,34 +307,37 @@ teacher_model = SetFitModel.from_pretrained(
)

# Create trainer for teacher model
teacher_trainer = SetFitTrainer(
teacher_trainer = Trainer(
model=teacher_model,
train_dataset=train_dataset_teacher,
eval_dataset=eval_dataset,
loss_class=CosineSimilarityLoss,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()

# Load small student model
student_model = SetFitModel.from_pretrained("paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
batch_size=16,
num_iterations=20,
num_epochs=1
)

# Create trainer for knowledge distillation
student_trainer = DistillationSetFitTrainer(
student_trainer = DistillationTrainer(
teacher_model=teacher_model,
train_dataset=train_dataset_student,
student_model=student_model,
args=args,
train_dataset=train_dataset_student,
eval_dataset=eval_dataset,
loss_class=CosineSimilarityLoss,
metric="accuracy",
batch_size=16,
num_iterations=20,
num_epochs=1,
)

# Train student with knowledge distillation
student_trainer.train()
student_metrics = student_trainer.evaluate()
```


@@ -403,13 +393,15 @@ make style && make quality

## Citation

```@misc{https://doi.org/10.48550/arxiv.2209.11055,
```
@misc{https://doi.org/10.48550/arxiv.2209.11055,
doi = {10.48550/ARXIV.2209.11055},
url = {https://arxiv.org/abs/2209.11055},
author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Efficient Few-Shot Learning Without Prompts},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}}
copyright = {Creative Commons Attribution 4.0 International}
}
```
4 changes: 4 additions & 0 deletions docs/source/en/api/main.mdx
@@ -6,3 +6,7 @@
# SetFitHead

[[autodoc]] SetFitHead

# AbsaModel

[[autodoc]] AbsaModel
12 changes: 8 additions & 4 deletions docs/source/en/api/trainer.mdx
@@ -1,8 +1,12 @@

# SetFitTrainer
# Trainer

[[autodoc]] SetFitTrainer
[[autodoc]] Trainer

# DistillationSetFitTrainer
# DistillationTrainer

[[autodoc]] DistillationSetFitTrainer
[[autodoc]] DistillationTrainer

# AbsaTrainer

[[autodoc]] AbsaTrainer