Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions scripts/setfit/run_fewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import pathlib
import sys
import time
import warnings
from collections import Counter
from shutil import copyfile
Expand Down Expand Up @@ -55,6 +56,8 @@ def parse_args():
parser.add_argument("--override_results", default=False, action="store_true")
parser.add_argument("--keep_body_frozen", default=False, action="store_true")
parser.add_argument("--add_data_augmentation", default=False)
parser.add_argument("--remove_duplicate_samples", type=bool, default=False)
parser.add_argument("--train_time", type=bool, default=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea to add a training time argument here.


args = parser.parse_args()

Expand Down Expand Up @@ -134,6 +137,7 @@ def main():
model.model_body._modules["2"] = models.Normalize()

# Train on current split
st_time = time.time()
trainer = SetFitTrainer(
model=model,
train_dataset=train_data,
Expand All @@ -156,18 +160,19 @@ def main():
batch_size=args.batch_size,
)
else:
trainer.train()
trainer.train(remove_duplicate_samples=args.remove_duplicate_samples)
runtime = time.time() - st_time

# Evaluate the model on the test data
metrics = trainer.evaluate()
print(f"Metrics: {metrics}")

results = {"score": metrics[metric] * 100, "measure": metric}
if args.train_time:
results.update({"train_time": runtime})

with open(results_path, "w") as f_out:
json.dump(
{"score": metrics[metric] * 100, "measure": metric},
f_out,
sort_keys=True,
)
json.dump(results, f_out, sort_keys=True)


if __name__ == "__main__":
Expand Down
12 changes: 12 additions & 0 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,18 @@ def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix):
return pairs


def sentence_pairs_remove_duplicates(sentences: List[InputExample]):
    """Filter a list of sentence-pair examples down to unique, non-trivial pairs.

    Two examples are considered duplicates when they contain the same two texts,
    regardless of order (the pair key is the sorted tuple of texts). Pairs whose
    two texts are identical to each other are also dropped.

    Args:
        sentences: Sentence-pair examples; only the ``texts`` attribute is read.

    Returns:
        A new list preserving the first occurrence of each distinct pair.
    """
    seen_keys = set()
    unique_pairs = []
    for example in sentences:
        # Sorting makes the key order-insensitive: (a, b) and (b, a) collide.
        pair_key = tuple(sorted(example.texts))
        if pair_key in seen_keys:
            continue
        seen_keys.add(pair_key)
        # Drop degenerate pairs where both texts are the same string.
        if pair_key[0] != pair_key[1]:
            unique_pairs.append(example)
    return unique_pairs
Comment on lines +686 to +695
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This differs slightly from my understanding of #258. In particular, this function:

  1. removes duplicate pairs
  2. removes pairs where the items are identical

I believe that #258 only mentions the former. Did you find the pairs where both items being identical to be a problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. To be honest, I wrote these changes ages ago but only got round to evaluating them on a public dataset and opening the PR yesterday.
For my own use I was cutting as much data as possible to speed up training; this is a remnant of that which I forgot I had added!

I can run some tests again with 2. applied or not, and we can see what the analysis shows. My intuition would say labelling identical pairs wouldn't improve performance, but we can see..

But I will hold off until we have some resolution about whether it should be applied (based on your first comment @tomaarsen) 👍



class SKLearnWrapper:
def __init__(self, st_model=None, clf=None):
self.st_model = st_model
Expand Down
16 changes: 15 additions & 1 deletion src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from . import logging
from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna
from .modeling import SupConLoss, sentence_pairs_generation, sentence_pairs_generation_multilabel
from .modeling import (
SupConLoss,
sentence_pairs_generation,
sentence_pairs_generation_multilabel,
sentence_pairs_remove_duplicates,
)
from .utils import BestRun, default_hp_space_optuna


Expand Down Expand Up @@ -267,6 +272,7 @@ def train(
max_length: Optional[int] = None,
trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
show_progress_bar: bool = True,
remove_duplicate_samples: bool = False,
):
"""
Main training entry point.
Expand Down Expand Up @@ -294,6 +300,8 @@ def train(
The trial run or the hyperparameter dictionary for hyperparameter search.
show_progress_bar (`bool`, *optional*, defaults to `True`):
Whether to show a bar that indicates training progress.
remove_duplicate_samples (`bool`, *optional*, defaults to `False`):
Removes duplicate samples from the training set, to improve training speeds on large datasets.
"""
set_seed(self.seed) # Seed must be set before instantiating the model when using model_init.

Expand Down Expand Up @@ -363,6 +371,12 @@ def train(
np.array(x_train), np.array(y_train), train_examples
)

if remove_duplicate_samples:
prv_train_len = len(train_examples)
train_examples = sentence_pairs_remove_duplicates(train_examples)
chg_train_perc = (prv_train_len - len(train_examples)) / prv_train_len
logger.info(f"{chg_train_perc:.1%} of training examples removed as duplicates.")

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
train_loss = self.loss_class(self.model.model_body)
train_steps = len(train_dataloader) * num_epochs
Expand Down
23 changes: 22 additions & 1 deletion tests/test_modeling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from itertools import combinations
from unittest import TestCase

import numpy as np
Expand All @@ -9,7 +10,12 @@
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier

from setfit import SetFitHead, SetFitModel
from setfit.modeling import MODEL_HEAD_NAME, sentence_pairs_generation, sentence_pairs_generation_multilabel
from setfit.modeling import (
MODEL_HEAD_NAME,
sentence_pairs_generation,
sentence_pairs_generation_multilabel,
sentence_pairs_remove_duplicates,
)


def test_sentence_pairs_generation():
Expand Down Expand Up @@ -42,6 +48,21 @@ def test_sentence_pairs_generation_multilabel():
assert pairs[0].label == 1.0


def test_sentence_pairs_remove_duplicates():
    """Deduplication never yields more pairs than the number of distinct combinations."""
    texts = np.array(["sent 1", "sent 2", "sent 3"])
    text_labels = np.array(["label 1", "label 2", "label 3"])

    num_rounds = 2
    generated = []
    for _ in range(num_rounds):
        generated = sentence_pairs_generation(texts, text_labels, generated)
    deduplicated = sentence_pairs_remove_duplicates(generated)

    # Each round produces one positive and one negative pair per sentence.
    assert len(generated) == len(texts) * num_rounds * 2
    # After deduplication there can be at most one pair per unordered text combination.
    assert len(deduplicated) <= len(list(combinations(texts, 2)))


def test_setfit_model_body():
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")

Expand Down