Prepare v1.0.0 release - Trainer, TrainingArguments, SetFitABSA, logging, evaluation during training, callbacks, docs
#439
Merged
Changes from all commits · 218 commits
1acdd5c
Implement Trainer & TrainingArguments w. tests
tomaarsen 89f4435
Readded support for hyperparameter tuning
tomaarsen 5f2a6b3
Remove unused imports and reformat
tomaarsen 622f33b
Preserve desired behaviour despite deprecation of keep_body_frozen pa…
tomaarsen ff59154
Ensure that DeprecationWarnings are displayed
tomaarsen 3b4ef58
Set Trainer.freeze and Trainer.unfreeze methods normally
tomaarsen fd68274
Add TrainingArgument tests for num_epochs, batch_sizes, lr
tomaarsen 14602ea
Convert trainer.train arguments into a softer deprecation
tomaarsen 94106cc
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen a39e772
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit; br…
tomaarsen 9fc55a6
Use body/head_learning_rate instead of classifier/embedding_learning_…
tomaarsen 7d4ad00
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen aab2377
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen dee70b1
Reformat according to the newest black version
tomaarsen fb6547d
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen abbbb03
Remove "classifier" from var names in SetFitHead
tomaarsen 12d326e
Update DeprecationWarnings to include timeline
tomaarsen 70c0295
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen fc246cc
Convert training_argument imports to relative imports
tomaarsen 57aa54f
Make conditional explicit
tomaarsen 7ebdf93
Make conditional explicit
tomaarsen 4695293
Use assertEqual rather than assert
tomaarsen 4c6d0fd
Remove training_arguments from test func names
tomaarsen 5937ec2
Replace loss_class on Trainer with loss on TrainArgs
tomaarsen f1e3de9
Removed dead class argument
tomaarsen 6051095
Move SupConLoss to losses.py
tomaarsen bddd46a
Add deprecation to Trainer.(un)freeze
tomaarsen fa8a077
Prevent warning from always triggering
tomaarsen 85a3684
Export TrainingArguments in __init__
tomaarsen ca625a2
Update & add important missing docstrings
tomaarsen 868d7b7
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen 68e9094
Use standard dataclass initialization for SetFitModel
tomaarsen 19a6fc8
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen 0b2efa1
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen ca87c42
Remove duplicate space in DeprecationWarning
tomaarsen cc5282f
No longer require labeled data for DistillationTrainer
tomaarsen c6f5782
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen 36cbbfe
Update docs for v1.0.0
tomaarsen deb57ff
Remove references of SetFitTrainer
tomaarsen 46922d5
Update expected test output
tomaarsen f43d5b2
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen b0f9f58
Remove unused pipeline
tomaarsen 339f332
Execute deprecations
tomaarsen 9e0bf78
Stop importing now-removed function
tomaarsen ecabbcf
Initial setup for logging & callbacks
tomaarsen 6e6720b
Move sentence-transformer training into trainer.py
tomaarsen 826eb53
Add checkpointing, support EarlyStoppingCallback
tomaarsen 019a971
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen 1930973
Run formatting
tomaarsen e4f3f76
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen 0f66109
Merge pull request #4 from tomaarsen/feat/logging_callbacks
tomaarsen a87cdc0
Add additional trainer tests
tomaarsen d418759
Use isinstance, required by flake8 release from 1hr ago
tomaarsen 08892f6
sampler for refactor WIP
danstan5 0a2b664
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen 429de0f
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen 173f084
Run formatters
tomaarsen c23959a
Remove tests from modeling.py
tomaarsen 0fa3870
Add missing type hint
tomaarsen 3969f38
Adjust test to still pass if W&B/Tensorboard are installed
tomaarsen 567f1c9
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen 851f0bb
The log/eval/save steps should be saved on the state instead
tomaarsen 67ddedc
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen d37ee09
sampler logic fix "unique" strategy
danstan5 0ef8837
add sampler tests (not complete)
danstan5 131aa26
add sampling_strategy into TrainingArguments
danstan5 c6c6228
Merge branch 'refactor-sampling' of https://github.com/danstan5/setfi…
danstan5 7431005
num_iterations removed from TrainingArguments
danstan5 3bd2acc
run_fewshot compatible with <v.1.0.0
danstan5 3d07e6c
Run make style
tomaarsen 978daee
Use "no" as the default evaluation_strategy
tomaarsen 2802a3f
Move num_iterations back to TrainingArguments
tomaarsen 391f991
Fix broken trainer tests due to new default sampling
tomaarsen f8b7253
Use the Contrastive Dataset for Distillation
tomaarsen 38e9607
Set the default logging steps at 50
tomaarsen 4ead15d
Add max_steps argument to TrainingArguments
tomaarsen eb70336
Change max_steps conditional
tomaarsen 3478799
Merge pull request #5 from danstan5/refactor-sampling
tomaarsen d9c4a05
Merge branch 'main' of https://github.com/huggingface/setfit into ref…
tomaarsen 5b39f06
Seeds are now correctly applied for reproducibility
tomaarsen d8177db
Add files via upload
MosheWasserb 7c3feed
Don't scale gradients during evaluation
tomaarsen cdc8979
Use evaluation_strategy="steps" if eval_steps is set
tomaarsen e040167
Run formatting
tomaarsen d2f2489
Implement SetFit for ABSA from Intel Labs (#6)
tomaarsen 5c4569d
Import optuna under TYPE_CHECKING
tomaarsen ceeb725
Remove unused import, reformat
tomaarsen 5c669b5
Add MANIFEST.in with model_card_template
tomaarsen 8e201e5
Don't require transformers TrainingArgs in tests
tomaarsen 6ae5045
Update URLs in setup.py
tomaarsen ecaabb4
Increase min hf_hub version to 0.12.0 for SoftTemporaryDirectory
tomaarsen 4e79397
Include MANIFEST.in data via `include_package_data=True`
tomaarsen 65aff32
Use kwargs instead of args in super call
tomaarsen eeeac55
Use v0.13.0 as min. version as huggingface/huggingface_hub#1315
tomaarsen 3214f1b
Use en_core_web_sm for tests
tomaarsen 2b78bb0
Remove incorrect spacy_model from AspectModel/PolarityModel
tomaarsen b68f655
Rerun formatting
tomaarsen d85f0d9
Run CI on pre branch & workflow dispatch
tomaarsen b636cd7
Merge pull request #265 from tomaarsen/refactor_v2
tomaarsen 81952bf
Set development version to 1.0.0.dev0
tomaarsen 5b76361
Extend training argument tests
tomaarsen 54b5d55
Only create evaluation dataloader if eval_strat is set
tomaarsen 4788713
Run formatting
tomaarsen 74a5b7c
max_steps isn't optional
tomaarsen 7ef5bbc
Fix indentation of docstring
tomaarsen ca3030f
Apply fixes for HPO
tomaarsen f114572
Remove outdated tests
tomaarsen 8d118d5
Use SetFitModel as the model in CallbackHandler
tomaarsen b964238
Correctly set the total training steps based on args.max_steps
tomaarsen 2f06847
Add missing comma
tomaarsen fcb38fc
Capitalize first letter of sentence
tomaarsen 9fe6f0d
Run formatting
tomaarsen da338ad
Remove unused arguments in tests
tomaarsen be4c900
Initial documentation for SetFit v1.0.0
tomaarsen fb42dd7
Update the documentation related workflows
tomaarsen 04c45d7
Merge branch 'main' of https://github.com/huggingface/setfit into v1.…
tomaarsen bfe6ef6
Add figure to zero-shot how-to guide
tomaarsen 773b860
Add docs notebook building support
tomaarsen 883889c
Update broken, redirecting links
tomaarsen b4e5db0
polarity -> label
tomaarsen dbd707b
Mention extra download requirements for ABSA
tomaarsen 552cecc
Merge branch 'main' of https://github.com/huggingface/setfit into v1.…
tomaarsen 0d32dd1
Implement 'batch_size' on model.predict
tomaarsen 392cf0d
Add batch sizes to toctree
tomaarsen ee00c40
Merge pull request #443 from tomaarsen/feat/expose_batch_size
tomaarsen 17d6513
Save model head on CPU
tomaarsen dca6fd0
torch.Module -> torch.nn.Module
tomaarsen 4123609
Merge pull request #444 from tomaarsen/feat/cpu_load_diff_head
tomaarsen b5a6361
Add new top-level header to docs reference
tomaarsen 6ca989e
Update docs about return value of metric function
tomaarsen 93c52dd
Add "use_auth_token" to migration guide
tomaarsen 44daad4
Allow 'device' on SetFitModel.from_pretrained()
tomaarsen 6f06204
Add tests for SetFitABSA as well
tomaarsen c41b7c3
Merge pull request #445 from tomaarsen/feat/load_on_device
tomaarsen b8da4a3
Update which trainer methods are documented
tomaarsen 639750f
Link to the Hub in docstring
tomaarsen 9ffc262
Add scikit-learn API version of SetFit to related work
tomaarsen 2ef61bb
Batch Sizes + "for Inference"
tomaarsen b8b8417
Make first column bold in Sampling Strategy table
tomaarsen a2fa84f
Remove comment about Google Colab with Python 3.7
tomaarsen e2cf782
Rename file, remove distilBERT, fix typos
tomaarsen c5ea28d
Merge branch 'v1.0.0-pre' of https://github.com/huggingface/setfit in…
tomaarsen c7f49ad
Add ONNX tutorial to docs
tomaarsen 193f83f
Merge pull request #435 from huggingface/moshe
tomaarsen 8e0c55c
Update docstring of from_pretrained!
tomaarsen 19d6d9d
Revert "Update docstring of from_pretrained!"
tomaarsen 5058e31
Update docstring of from_pretrained!
tomaarsen 3e829ba
Update docstring edits of from_pretrained
tomaarsen d476ce0
Correctly format docstrings for API reference
tomaarsen dac5221
Also maybe log, evaluate & save at epoch end
tomaarsen 5edf540
Update README in preparation for documentation
tomaarsen c1b2f20
Link to scripts rather than scripts/setfit
tomaarsen 70bd935
Ensure correct device of "best model at the end"
tomaarsen c93b55a
Add "labels" in a configuration file
tomaarsen 4c0f152
Resolve flake issues
tomaarsen 1af337f
Add labels to migration guide
tomaarsen 3876d62
Update returns docstring for predict & __call__
tomaarsen 71be7a5
Use ndim rather than "multi_target_strategy is None"
tomaarsen 298fe39
Merge pull request #447 from tomaarsen/feat/configuration
tomaarsen cc97d10
Allow passing strings to model.predict
tomaarsen d85d537
Merge pull request #448 from tomaarsen/feat/predict_singular
tomaarsen 62f7eea
Allow partial column mappings
tomaarsen 6f226e5
Allow normalize_embeddings with diff head
tomaarsen f04e997
Merge pull request #449 from tomaarsen/feat/partial_col_mapping
tomaarsen f021e13
Merge pull request #450 from tomaarsen/fix/normalize_with_diff_head
tomaarsen 313bffc
Update phrasing in SetFit intro
tomaarsen 9976bb5
Heavily improve automatic model card generation
tomaarsen bbad20d
Rewrite first paragraph somewhat
tomaarsen 6cd51ed
Resolve issue with multi-label
tomaarsen 4a6852b
Set inference=False for multilabel models
tomaarsen 671611e
Add model card tests
tomaarsen 5f36d0e
Reformat
tomaarsen 086ee02
Satisfy flake8
tomaarsen 4990b09
Make model card generation more robust
tomaarsen 58d5815
Allow compute_metric to return a non-dict
tomaarsen 61cf947
Update tests as datasets are now column-mapped at init
tomaarsen 751ba80
Avoid bare except
tomaarsen 4a7255d
Avoid walrus operator for now for Python 3.7 compat
tomaarsen 8032131
Increase minimal datasets version for dataset inferring
tomaarsen 54d7127
Keep datasets version low, but skip test if < 2.14
tomaarsen 87420b3
Add reason to skipif
tomaarsen a73cb69
Always return dicts in id2label/label2id
tomaarsen 859691b
Introduce "no aspect", "aspect" labels for AspectModel
tomaarsen f9e6acb
Extend model card generation to ABSA + Tests
tomaarsen 4c4a9aa
Correctly use create_model_card in ABSA test
tomaarsen 0beedf2
Speed up model card tests for ABSA
tomaarsen 55e9380
Set default W&B project as "setfit" if not set via ENV var yet
tomaarsen 8ad41a8
Run formatting
tomaarsen a65a4e7
Remove the old ABSA model card template
tomaarsen d0cda23
Set fsspec<2023.12.0 due to breakages with older datasets
tomaarsen 7dcc35e
Make some model_card_data modifications for ABSA only once
tomaarsen 2cd004f
Reorder arguments
tomaarsen 3a5356e
Update absa models in docs
tomaarsen d513064
Move import
tomaarsen 1a60d09
Add model_card_data to from_pretrained
tomaarsen 681f8db
Remove useless brackets
tomaarsen d61ec69
Correct model_card docstring
tomaarsen 77aff7c
Only use the gold aspects/labels for training the polarity model
tomaarsen 2c09cfb
Merge branch 'v1.0.0-pre' of https://github.com/huggingface/setfit in…
tomaarsen 9dbca0b
Use text classification dataset examples
tomaarsen 368155c
Add model card generation documentation
tomaarsen 9c87685
Add spaCy version to ABSA model card
tomaarsen b257e82
Map to int to avoid potential warning
tomaarsen 3a6e23a
Store used spaCy model configuration in aspect/polarity model
tomaarsen 3bc0125
Correctly test against log
tomaarsen 35c7461
Reformat test imports
tomaarsen 5d04965
Try to resolve failing test on CI
tomaarsen 815e45a
debugging: test against trainer dataset size
tomaarsen 267e21d
Ignore log tests
tomaarsen 243fcb2
Add 'eval_max_steps', reduce load time before train
tomaarsen 8dd930c
Merge branch 'main' of https://github.com/huggingface/setfit into v1.…
tomaarsen c039e17
Also pass metric_kwargs to custom metric callable
tomaarsen 8d5fc46
Merge pull request #456 from tomaarsen/feat/use_metric_kwargs_with_cu…
tomaarsen 37592eb
Use gold aspects as True, and non-overlapping pred aspects as False
tomaarsen 522a420
Add missing +1 on edge case in Aspect Extractor
tomaarsen ebaf5a2
Update ABSA documentation slightly
tomaarsen 937c491
Specify AbsaTrainer methods
tomaarsen 3152e49
Update v1.0.0 migration; expand changelog
tomaarsen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
@@ -0,0 +1 @@

    include src/setfit/model_card_template.md
Large diffs are not rendered by default.
@@ -0,0 +1,9 @@

    # docstyle-ignore
    INSTALL_CONTENT = """
    # SetFit installation
    ! pip install setfit
    # To install from source instead of the last release, comment the command above and uncomment the following one.
    # ! pip install git+https://github.com/huggingface/setfit.git
    """

    notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
Three files were deleted in this pull request (their diffs are not shown).
@@ -0,0 +1,87 @@

# SetFit Sampling Strategies

SetFit supports various contrastive pair sampling strategies in [`TrainingArguments`]. In this conceptual guide, we will learn about the following four sampling strategies:

1. `"oversampling"` (the default)
2. `"undersampling"`
3. `"unique"`
4. `"num_iterations"`

Consider first reading the [SetFit conceptual guide](../setfit) for a background on contrastive learning and positive & negative pairs.
## Running example

Throughout this conceptual guide, we will use the following example scenario:

* 3 classes: "happy", "content", and "sad".
* 20 total samples: 8 "happy", 4 "content", and 8 "sad" samples.

Considering that the sentence pairs `(X, Y)` and `(Y, X)` result in the same embedding distance/loss, we only want to consider one of those two cases. Furthermore, we don't want pairs where both sentences are the same, e.g. no `(X, X)`.

The resulting positive and negative pairs can be visualized in a table like the one below. The `+` and `-` represent positive and negative pairs, respectively. Furthermore, `h-n` represents the n-th "happy" sentence, `c-n` the n-th "content" sentence, and `s-n` the n-th "sad" sentence. Note that the area below the diagonal is not used, as `(X, Y)` and `(Y, X)` result in the same embedding distances, and that the diagonal is not used because we are not interested in pairs where both sentences are identical.

| |h-1|h-2|h-3|h-4|h-5|h-6|h-7|h-8|c-1|c-2|c-3|c-4|s-1|s-2|s-3|s-4|s-5|s-6|s-7|s-8|
|-------|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|**h-1**| | + | + | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
|**h-2**| | | + | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
|**h-3**| | | | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
|**h-4**| | | | | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
|**h-5**| | | | | | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
|**h-6**| | | | | | | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
|**h-7**| | | | | | | | + | - | - | - | - | - | - | - | - | - | - | - | - |
|**h-8**| | | | | | | | | - | - | - | - | - | - | - | - | - | - | - | - |
|**c-1**| | | | | | | | | | + | + | + | - | - | - | - | - | - | - | - |
|**c-2**| | | | | | | | | | | + | + | - | - | - | - | - | - | - | - |
|**c-3**| | | | | | | | | | | | + | - | - | - | - | - | - | - | - |
|**c-4**| | | | | | | | | | | | | - | - | - | - | - | - | - | - |
|**s-1**| | | | | | | | | | | | | | + | + | + | + | + | + | + |
|**s-2**| | | | | | | | | | | | | | | + | + | + | + | + | + |
|**s-3**| | | | | | | | | | | | | | | | + | + | + | + | + |
|**s-4**| | | | | | | | | | | | | | | | | + | + | + | + |
|**s-5**| | | | | | | | | | | | | | | | | | + | + | + |
|**s-6**| | | | | | | | | | | | | | | | | | | + | + |
|**s-7**| | | | | | | | | | | | | | | | | | | | + |
|**s-8**| | | | | | | | | | | | | | | | | | | | |

As shown in the table above, we have 28 positive pairs for "happy", 6 positive pairs for "content", and another 28 positive pairs for "sad": 62 positive pairs in total. We also have 32 negative pairs between "happy" and "content", 64 negative pairs between "happy" and "sad", and 32 negative pairs between "content" and "sad": 128 negative pairs in total.
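For the curious, these counts follow directly from the class sizes; a minimal sketch in plain Python:

```python
from math import comb

class_sizes = {"happy": 8, "content": 4, "sad": 8}

# Positive pairs: unordered same-class pairs, excluding (X, X)
positive = sum(comb(n, 2) for n in class_sizes.values())

# Negative pairs: one sentence from each of two different classes
sizes = list(class_sizes.values())
negative = sum(
    sizes[i] * sizes[j]
    for i in range(len(sizes))
    for j in range(i + 1, len(sizes))
)

print(positive, negative)  # 62 128
```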
## Oversampling

By default, SetFit applies the oversampling strategy for its contrastive pairs. This strategy samples an equal number of positive and negative training pairs, oversampling the minority pair type to match the majority pair type. As the number of negative pairs is generally larger than the number of positive pairs, this usually involves oversampling the positive pairs.

In our running example, this would involve oversampling the 62 positive pairs up to 128, resulting in one epoch of 128 + 128 = 256 pairs. In summary:

* ✅ An equal number of positive and negative pairs is sampled.
* ✅ Every possible pair is used.
* ❌ There is some data duplication.

## Undersampling

Like oversampling, this strategy samples an equal number of positive and negative training pairs. However, it undersamples the majority pair type to match the minority pair type. This usually involves undersampling the negative pairs to match the positive pairs.

In our running example, this would involve undersampling the 128 negative pairs down to 62, resulting in one epoch of 62 + 62 = 124 pairs. In summary:

* ✅ An equal number of positive and negative pairs is sampled.
* ❌ **Not** every possible pair is used.
* ✅ There is **no** data duplication.

## Unique

Thirdly, the unique strategy does not sample an equal number of positive and negative training pairs. Instead, it simply samples every possible pair exactly once; no oversampling or undersampling is applied.

In our running example, this would involve sampling all negative and positive pairs, resulting in one epoch of 62 + 128 = 190 pairs. In summary:

* ❌ **Not** an equal number of positive and negative pairs is sampled.
* ✅ Every possible pair is used.
* ✅ There is **no** data duplication.

## `num_iterations`

Lastly, SetFit can still be used with a deprecated sampling strategy based on the `num_iterations` training argument. Unlike the other strategies, it does not depend on the number of possible pairs. Instead, it samples `num_iterations` positive pairs and `num_iterations` negative pairs for each training sample.

In our running example, with `num_iterations=20` we would sample 20 positive pairs and 20 negative pairs per training sample. Because there are 20 samples, this amounts to (20 + 20) * 20 = 800 pairs. Since there are only 190 unique pairs, this certainly involves some data duplication, yet it does not guarantee that every possible pair is used. In summary:

* ✅ An equal number of positive and negative pairs is sampled.
* ❌ Not every possible pair is necessarily used.
* ❌ There is some data duplication.
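As a reference, here is a minimal sketch of selecting a strategy, assuming the `sampling_strategy` and `num_iterations` arguments on the v1.0.0 `TrainingArguments` described in the commits above:

```python
from setfit import TrainingArguments

# Default: oversample the minority pair type so positives and negatives match
args = TrainingArguments(sampling_strategy="oversampling")

# Alternatives
args = TrainingArguments(sampling_strategy="undersampling")
args = TrainingArguments(sampling_strategy="unique")

# Deprecated-style sampling: 20 positive + 20 negative pairs per training sample
args = TrainingArguments(num_iterations=20)
```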
@@ -0,0 +1,28 @@

# Sentence Transformers Finetuning (SetFit)

SetFit is a model framework for efficiently training text classification models with surprisingly little training data. For example, with only 8 labeled examples per class on the Customer Reviews (CR) sentiment dataset, SetFit is competitive with fine-tuning RoBERTa Large on the full training set of 3k examples. Furthermore, SetFit is fast to train and run inference with, and it easily supports multilingual tasks.

Every SetFit model consists of two parts: a **sentence transformer** embedding model (the body) and a **classifier** (the head). These two parts are trained in two separate phases: the **embedding finetuning phase** and the **classifier training phase**. This conceptual guide elaborates on the intuition behind these phases and explains why SetFit works so well.
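To make the two phases concrete, here is a rough end-to-end training sketch, assuming the v1.0.0 `Trainer`/`TrainingArguments` API that this PR introduces; the dataset and checkpoint names are only illustrative:

```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments

# A small few-shot subset of a sentiment dataset (illustrative choice)
dataset = load_dataset("sst2")
train_dataset = dataset["train"].shuffle(seed=42).select(range(16))
eval_dataset = dataset["validation"]

# The body is a pretrained sentence transformer; the head defaults to logistic regression
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

args = TrainingArguments(batch_size=16, num_epochs=1)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping={"sentence": "text", "label": "label"},  # map dataset columns to "text"/"label"
)
trainer.train()
print(trainer.evaluate())
```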
## Embedding finetuning phase

The first phase has one primary goal: finetune a sentence transformer embedding model to produce useful embeddings for *our* classification task. The [Hugging Face Hub](https://huggingface.co/models?library=sentence-transformers) already has thousands of sentence transformer models available, many of which have been trained to group the embeddings of texts with similar semantic meaning very accurately.

However, models that are good at Semantic Textual Similarity (STS) are not necessarily immediately good at *our* classification task. For example, according to an embedding model, sentence 1) `"He biked to work."` will be much more similar to 2) `"He drove his car to work."` than to 3) `"Peter decided to take the bicycle to the beach party!"`. But if our classification task involves classifying texts into transportation modes, then we want our embedding model to place sentences 1 and 3 close together, and sentence 2 further away.

To do so, we can finetune the chosen sentence transformer embedding model. The goal here is to nudge the model to use its pretrained knowledge in a way that better aligns with our classification task, rather than making it completely forget what it has learned.

For finetuning, SetFit uses **contrastive learning**. This training approach involves creating **positive and negative pairs** of sentences. A sentence pair is positive if both sentences are of the same class, and negative otherwise. For example, in the case of binary "positive"-"negative" sentiment analysis, `("The movie was awesome", "I loved it")` is a positive pair, and `("The movie was awesome", "It was quite disappointing")` is a negative pair.

During training, the embedding model receives these pairs and converts the sentences to embeddings. If the pair is positive, the loss pulls the model weights so that the text embeddings become more similar; for a negative pair, it pushes them apart. Through this approach, sentences with the same label are embedded more similarly, and sentences with different labels less similarly.

Conveniently, this contrastive learning works with pairs rather than individual samples, and we can create plenty of unique pairs from just a few samples. For example, given 8 positive sentences and 8 negative sentences, we can create 28 + 28 = 56 positive pairs and 8 × 8 = 64 negative pairs, for 120 unique training pairs. The number of unique pairs grows quadratically with the number of sentences, which is why SetFit can train with just a few examples and still effectively finetune the sentence transformer embedding model. However, we should still be wary of overfitting.
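A minimal sketch of this pairing idea (an illustration only, not necessarily how SetFit's internal sampler is implemented):

```python
from itertools import combinations

sentences = [
    ("The movie was awesome", "positive"),
    ("I loved it", "positive"),
    ("It was quite disappointing", "negative"),
    ("I want my money back", "negative"),
]

positive_pairs, negative_pairs = [], []
for (text_a, label_a), (text_b, label_b) in combinations(sentences, 2):
    if label_a == label_b:
        positive_pairs.append((text_a, text_b))   # same class -> pull embeddings together
    else:
        negative_pairs.append((text_a, text_b))   # different class -> push embeddings apart

print(len(positive_pairs), len(negative_pairs))  # 2 4
```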
## Classifier training phase

Once the sentence transformer embedding model has been finetuned for our task at hand, we can start training the classifier. This phase has one primary goal: create a good mapping from the sentence transformer embeddings to the classes.

Unlike the first phase, the classifier is trained from scratch and uses the labeled samples directly, rather than pairs. By default, the classifier is a simple **logistic regression** classifier from scikit-learn. First, all training sentences are fed through the now-finetuned sentence transformer embedding model, and then the sentence embeddings and labels are used to fit the logistic regression classifier. The result is a strong and efficient classifier.

Using these two parts, SetFit models are efficient, performant, and easy to train, even on CPU-only devices.
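A rough sketch of this second phase in isolation, using `sentence-transformers` and scikit-learn directly; the checkpoint name is only an example, and in SetFit the body would already have been finetuned in phase one:

```python
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression

train_texts = ["The movie was awesome", "I loved it", "It was quite disappointing", "I want my money back"]
train_labels = [1, 1, 0, 0]

# In SetFit, this body would already be finetuned with contrastive pairs (phase 1)
body = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
embeddings = body.encode(train_texts)

# Phase 2: fit a simple head on the embeddings
head = LogisticRegression()
head.fit(embeddings, train_labels)

print(head.predict(body.encode(["What a fantastic film"])))
```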