def sentence_pairs_remove_duplicates(sentences: "List[InputExample]") -> "List[InputExample]":
    """Drop duplicate and self-paired examples, keeping first occurrences in order.

    Two examples count as duplicates when they hold the same texts in any
    order, e.g. ``["a", "b"]`` and ``["b", "a"]`` map to the same key. Labels
    are not part of the key, so the first example seen for a given set of
    texts is the one that is kept. Examples holding two or more texts that
    are all identical (a sentence paired with itself) are discarded.

    Args:
        sentences: Examples exposing a ``texts`` sequence of strings.

    Returns:
        A new list with the surviving examples in their original relative
        order. The input list is not modified.
    """
    seen_keys = set()
    unique_examples = []
    for example in sentences:
        # Sorting makes the key independent of the order of texts in a pair.
        key = tuple(sorted(example.texts))
        if key in seen_keys:
            continue
        seen_keys.add(key)
        # Compare all texts instead of indexing key[0]/key[1]: indexing raises
        # IndexError for examples with fewer than two texts and misclassifies
        # e.g. ("a", "a", "b") as a self-pair.
        if len(key) < 2 or len(set(key)) > 1:
            unique_examples.append(example)
    return unique_examples
def test_sentence_pairs_remove_duplicates():
    """Check that deduplicated pairs never outnumber the unordered sentence combinations."""
    texts = np.array(["sent 1", "sent 2", "sent 3"])
    targets = np.array(["label 1", "label 2", "label 3"])
    n_rounds = 2

    # Feed the running pair list back in so each round extends the previous one.
    generated = []
    for _ in range(n_rounds):
        generated = sentence_pairs_generation(texts, targets, generated)
    deduplicated = sentence_pairs_remove_duplicates(generated)

    # The generator is expected to emit two pairs per sentence per round.
    expected_total = len(texts) * n_rounds * 2
    # Upper bound: one surviving pair per unordered combination of sentences.
    max_unique = len(list(combinations(texts, 2)))

    assert len(generated) == expected_total
    assert len(deduplicated) <= max_unique