
Add training option unique_pairs#268

Closed
danstan5 wants to merge 15 commits into huggingface:main from danstan5:training-unique-pairs

Conversation

@danstan5
Contributor

@danstan5 danstan5 commented Jan 12, 2023

Derived from #259

Changes:

  • Updates the existing sentence pair generator: rather than random sampling, an iterator returns new positive/negative combinations.
  • Pairs are generated up to a max_pairs count (set by num_iterations). If all unique pairs have been generated and unique_pairs=True, it will stop at this point (logging a warning).
  • Positive/negative pairs are generated separately to ensure positive pairs are not under-sampled (in the case of a high number of classes). When unique_pairs=True, negative/positive pair counts will remain balanced until all unique positive pairs have been added; then only unique negative pairs will continue to be added. Note: if concerned about an imbalance of negative pairs (in the case of a high number of classes), num_iterations can simply be decreased to balance this as desired.
  • The positive pair generator iterates through each class in turn, extracting pairs. This ensures maximum variety of samples across different classes.

Overall these sampler changes improve reproducibility and provide more representative sampling of the dataset, useful in cases of strong class imbalance or when the number of samples is much larger than the number of training iterations.

unique_pairs adds an option for efficient embedding training that makes the most of the available data for quicker training (testing to confirm this to follow...)
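As a rough illustration of the balancing described above (a sketch only, not the PR's implementation; the function name and shapes are hypothetical), a unique-pair sampler could look like:

```python
from itertools import combinations

def unique_pair_sample(sentences, labels, max_pairs):
    """Hypothetical sketch: draw unique positive and negative pairs
    alternately, so counts stay balanced until the positives run out,
    then keep drawing unique negatives up to max_pairs."""
    pos = ((a, b) for (a, la), (b, lb) in combinations(list(zip(sentences, labels)), 2) if la == lb)
    neg = ((a, b) for (a, la), (b, lb) in combinations(list(zip(sentences, labels)), 2) if la != lb)
    pairs = []
    while len(pairs) < max_pairs:
        drew = False
        for source, label in ((pos, 1.0), (neg, 0.0)):
            pair = next(source, None)
            if pair is not None and len(pairs) < max_pairs:
                pairs.append((pair, label))
                drew = True
        if not drew:  # every unique pair has been generated: stop early
            break
    return pairs
```

With 4 sentences in 2 classes there are 2 unique positive and 4 unique negative pairs, so the sampler stops at 6 pairs however large max_pairs is.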

Still todo:

  • Add tests
  • Update for multi-label
  • Run performance evaluation on test_set - runtime/ accuracy comparisons

@tomaarsen keen to get your thoughts on this so far 👍

Member

@tomaarsen tomaarsen left a comment


I'll have to look at it more closely some other time, but the writeup looks good! I'm glad my concerns are noted.

I do have some small nitpicks already, see my other comments. These are causing test failures.

danstan5 and others added 3 commits January 12, 2023 17:39
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
@danstan5 danstan5 marked this pull request as draft January 13, 2023 10:22
@danstan5
Contributor Author

  • I've rolled multi-label into the sentence_pairs_generation function. I think this is cleaner, as you can now see that the logic to generate pairs is exactly the same in both cases; the only difference is the formatting of labels into positive/negative pairs.
  • The original tests have been fixed so they pass again, and some new ones have been added for unique_pairs.

@tomaarsen tomaarsen added the enhancement New feature or request label Jan 14, 2023
@tomaarsen
Member

tomaarsen commented Jan 14, 2023

First of all, these changes look great! I much prefer simply having

train_examples = sentence_pairs_generation(...)

rather than

train_examples = []
for _ in range(...):
    train_examples = sentence_pairs_generation(..., train_examples)

I ran some quick experiments locally, and I noticed an interesting quirk. Before I stress you out: this quirk also exists on main. If I increase num_epochs from 1 to 2, then it runs two epochs instead of one as expected, but it also scales the number of steps per epoch.

So, whereas I might get 34 training steps times 1 epoch (so, 34 steps total) with num_epochs=1, I'll get 68 training steps times 2 epochs for num_epochs=2 (so, 136 steps total).

Simple script to reproduce the quirk
from datasets import load_dataset

from setfit import SetFitModel, SetFitTrainer, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

# Load a SetFit model from Hub
model: SetFitModel = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping={"sentence": "text", "label": "label"},  # Map dataset columns to text/label expected by trainer
    num_epochs=1, # <- Change 1 to see the quirk
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

It's caused by this line:

steps_per_epoch=train_steps,

The train_steps is defined like so:
train_steps = len(train_dataloader) * num_epochs

Which explains the behaviour. In my opinion, the number of steps per epoch should be equivalent to the number of datapoints in the dataloader. That is exactly the behaviour if steps_per_epoch is not supplied. So, could you remove supplying the steps_per_epoch parameter in your PR? That way, we can prevent this quirk.

Experiments on sst2

As mentioned, I ran some experiments. In particular, I used unique_pairs=True and unique_pairs=False while keeping the number of total optimization steps roughly equivalent between the two scripts. I also applied the fix of the quirk mentioned above.

unique_pairs = False w. 3 epochs of 640 steps (40 batches) each (1920 total optimization steps) (Baseline)
from datasets import load_dataset

from setfit import SetFitModel, SetFitTrainer, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

# Load a SetFit model from Hub
model: SetFitModel = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping={"sentence": "text", "label": "label"},  # Map dataset columns to text/label expected by trainer
    num_epochs=3,
    unique_pairs=False,
    seed=1,
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
unique_pairs = True w. 14 epochs of 136 steps (9 batches) each (1904 total optimization steps)
from datasets import load_dataset

from setfit import SetFitModel, SetFitTrainer, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

# Load a SetFit model from Hub
model: SetFitModel = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping={"sentence": "text", "label": "label"},  # Map dataset columns to text/label expected by trainer
    num_epochs=14,
    unique_pairs=True,
    seed=1,
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

Results

| Experiment | Evaluation Accuracy |
| --- | --- |
| unique_pairs = False w. 3 epochs (Baseline) | 86.697% (1.593) |
| unique_pairs = True w. 14 epochs (This PR) | 87.461% (0.593) |

Notes:

  1. Evaluation accuracy is displayed as mean and standard deviation between 12 executions.
  2. The seed on the SetFitTrainer is updated for every execution, but the seed on the sample_dataset calls stays the same. This way, all executions get the same input data.
  3. Each of the experiments took roughly 1 minute to train, give or take 3 seconds.

The results are almost statistically significant at a p-value of .074996 (assuming p < .05 for statistical significance), but certainly very promising. Especially the difference in standard deviation is very promising for this PR, as it displays the expected behaviour that non-random sampling produces more stable and consistent results. I'm quite excited about it.

I'm considering running more experiments. Feel free to run your own ones, too!

  • Tom Aarsen

@tomaarsen
Member

tomaarsen commented Jan 14, 2023

Experiments on emotion

I ran some more experiments. In particular, I used unique_pairs=True and unique_pairs=False while keeping the total number of optimization steps exactly equivalent between the two scripts. I also applied the fix of the quirk mentioned in my previous comment.

These experiments differ from the previous experiments on sst2 in that each epoch ran the same number of examples, and that I used the same number of epochs between the experiments. Furthermore, sst2 is binary, while the emotion dataset contains 6 classes.

unique_pairs = False w. num_iterations=10 and 2 epochs of 960 steps (60 batches) each (1920 total optimization steps) (Baseline)
from datasets import load_dataset

from setfit import SetFitModel, SetFitTrainer, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("SetFit/emotion")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

for seed in range(12):
    # Load a SetFit model from Hub
    model: SetFitModel = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-mpnet-base-v2",
    )

    # Create trainer
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        num_epochs=2,
        num_iterations=10,
        unique_pairs=False,
        seed=seed
    )

    # Train and evaluate
    trainer.train()
    metrics = trainer.evaluate()

    print(metrics)
unique_pairs = True w. num_iterations=10 and 2 epochs of 960 steps (60 batches) each (1920 total optimization steps)
from datasets import load_dataset

from setfit import SetFitModel, SetFitTrainer, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("SetFit/emotion")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

for seed in range(12):
    # Load a SetFit model from Hub
    model: SetFitModel = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-mpnet-base-v2",
    )

    # Create trainer
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        num_epochs=2,
        num_iterations=10,
        unique_pairs=True,
        seed=seed
    )

    # Train and evaluate
    trainer.train()
    metrics = trainer.evaluate()

    print(metrics)

Results

| Experiment | Evaluation Accuracy |
| --- | --- |
| main branch w. num_iterations=10 and 2 epochs (Baseline) | 44.850% (1.648) |
| unique_pairs = False w. num_iterations=10 and 2 epochs (Baseline) | 42.946% (1.615) |
| unique_pairs = True w. num_iterations=10 and 2 epochs (This PR) | 48.717% (1.261) |

Notes:

  1. Evaluation accuracy is displayed as mean and standard deviation between 12 executions.
  2. The seed on the SetFitTrainer is updated for every execution, but the seed on the sample_dataset calls stays the same. This way, all executions get the same input data.
  3. Each of the experiments took roughly 1:45 minutes to train, give or take 3 seconds.
  4. I included the first Baseline test to see whether unique_pairs=False would be equivalent to main.

(image attached)

These results are very much statistically significant! These seem like serious improvements. And again, we can note an increase in stability as unique_pairs is set to True through the decreased standard deviation.

However, from these results it is also clear that the performance of unique_pairs=False seems degraded relative to the main branch. The two sets of results are statistically significantly different, which is somewhat worrisome. I was expecting the behaviour of the main branch and unique_pairs=False to be equivalent. Could you elaborate on the differences?

I'm very positive about these results.

  • Tom Aarsen

@tomaarsen
Member

Dan and I had a discussion on this topic, which I will share for others reading through this interesting PR:

We theorize that the reduction in accuracy when unique_pairs=False is due to the new pairs being generated in a natural order, e.g. (0,0), (0,1), (0,2), (0,3), (0,4), (1,1), ..., causing non-representative sampling. This could be fixed by randomising the order of the generated combinations, but we'll have to figure out how to do this in a memory-efficient generator approach.

Additionally, Dan tested equal representation of positive pairs versus negative pairs compared to the unique_pairs approach from this PR. The unique_pairs approach often results in many more negative pairs, because there are usually fewer possible positive pairs. He discovered that the 50:50 approach is generally much better than when the training samples are skewed towards negative.
One possible solution here is to adapt the unique_pairs approach to compute num_sample = min(max_num_positive_pairs, max_num_negative_pairs, user_specified_max_samples), and sample that many positive and that many negative pairs.

Furthermore, I'm very interested in "freshly" generating samples per epoch. For example, if there are 20,000 possible negative pairs, and we sample 2,000 in a 20-epoch training scenario, then currently we use the same 2,000 samples for each of the 20 epochs. We might get an increase in performance if we re-sample 2,000 from the 20,000 for every epoch.

  • Tom Aarsen

@danstan5
Contributor Author

danstan5 commented Jan 29, 2023

Non-representative sampling fix - a shuffle_combinations function is used instead of itertools.combinations to shuffle the generated pairs. This is based on numpy-generated indices and indexing, and looks to be the fastest method (I've tested generating up to ~10^8 combinations with no issues).
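A minimal sketch of the shuffle_combinations idea (illustrative only, not the exact code from the commit): enumerate all C(n, 2) index pairs with numpy, permute the index array once, and yield the pairs lazily, so memory stays proportional to the index arrays rather than materialised pairs.

```python
import numpy as np

def shuffle_combinations(items, seed=None):
    # All unique index pairs (i, j) with i < j, as two flat arrays.
    row, col = np.triu_indices(len(items), k=1)
    # A single permutation over the pair indices gives a shuffled order.
    order = np.random.default_rng(seed).permutation(len(row))
    for k in order:
        yield items[row[k]], items[col[k]]
```

Iterating this generator visits every unique pair exactly once, in random order.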

Overall the new sampler improves reproducibility and samples labelled datasets more evenly, with slight accuracy gains and reduced standard deviation as a result:

| Emotion dataset (8 samples per class) | Main branch | New sampler |
| --- | --- | --- |
| 7 iterations (672 pairs, 42 optimization steps) | 44.9 (4.3) | 45.8 (3.1) |
| 13 iterations (1,248 pairs, 78 optimization steps) | 47.4 (4.8) | 47.9 (3.7) |
| 26 iterations (2,496 pairs, 156 optimization steps) | 47.9 (5.3) | 50.1 (4.9) |

Setting up unique_pairs - forcing the use of equal positive/negative samples shows better results, but requires sampling "unique pairs" differently...

  • One option is to generate all positive pairs, then the same number of negative pairs. This (named min_unique_pairs for reference) can result in a small number of training steps with very fast runtimes, but sub-optimal accuracy:
    num_sample = min(unique_positive_pairs, unique_negative_pairs, user_specified_max_samples)
  • An alternative here, max_unique_pairs, would take all negative pair combinations, then oversample positive pairs to balance. This uses the available data better, but isn't strictly "unique pairs" anymore:
    num_sample = min(max(unique_positive_pairs, unique_negative_pairs), user_specified_max_samples)
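To make the two options concrete, here is a hedged sketch (the helper name and signature are made up for illustration) computing both budgets from the per-class sample counts:

```python
from math import comb

def pair_budget(class_counts, max_samples, oversample=True):
    """class_counts: mapping of label -> number of labelled samples.
    Returns how many positive (= negative) pairs each option would use."""
    n_total = sum(class_counts.values())
    unique_pos = sum(comb(c, 2) for c in class_counts.values())  # within-class pairs
    unique_neg = comb(n_total, 2) - unique_pos                   # cross-class pairs
    if oversample:  # max_unique_pairs: keep all of the larger set, oversample the smaller
        return min(max(unique_pos, unique_neg), max_samples)
    # min_unique_pairs: cut both sets down to the smaller one
    return min(unique_pos, unique_neg, max_samples)
```

For example, 6 classes with 8 samples each give 168 unique positive and 960 unique negative pairs, so the oversampling option trains on 960 of each while the undersampling option trains on 168.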

Results on the multi-class (imbalanced) test datasets (opt_steps are a proxy for runtime), default 1 epoch:

| 8 samples per class | emotion opt_steps | emotion accuracy | sst5 opt_steps | sst5 accuracy | ag_news opt_steps | ag_news accuracy |
| --- | --- | --- | --- | --- | --- | --- |
| Unbalanced unique_pairs (all) | 74 | 46.4 (3.9) | 51 | 40.2 (2.8) | 33 | 82.2 (3.1) |
| min_unique_pairs (undersample) | 27 | 42.4 (3.9) | 23 | 40.7 (2.5) | 18 | 81.2 (2.8) |
| 3x epoch min_unique_pairs (undersample) | 81 | 47.7 (4.7) | 69 | 43.4 (2.6) | 54 | 83.0 (3.1) |
| max_unique_pairs (oversample) | 120 | 50.9 (3.7) | 80 | 43.9 (2.4) | 48 | 83.5 (2.8) |

Results show that increasing the number of epochs improves accuracy in roughly similar runtime, but is not as effective as more sampling. As noted in the last comment in this PR, generating samples for each epoch would be a nice addition; however, epochs are handled by the Model class and the torch dataloader, so that's probably beyond the scope of this PR.
Within the current constraints I would recommend a higher num_iterations (for more samples) instead of more epochs, to benefit from all of the labelled data available.

@danstan5
Contributor Author

danstan5 commented Jan 29, 2023

  • shuffle sample pair combinations(ca18e69) shows min unique pairs logic
  • update unique_pairs logic(ced309c) shows the max unique pairs logic, with a warning added to note oversampling to balance samples

I'm leaning towards the latter, due to the higher accuracy and because it provides a natural limit to the number of samples worth using in training. The num_iterations parameter can be reduced to speed up training further if needed (now with better sampling). @tomaarsen keen to get your thoughts on all the above

@danstan5 danstan5 marked this pull request as ready for review February 14, 2023 17:05
@danstan5 danstan5 requested a review from tomaarsen February 14, 2023 17:06
@tomaarsen
Member

tomaarsen commented Feb 14, 2023

@danstan5
My apologies, I lost sight of this PR. Let me dive back in.

My first hunch is to prefer oversampling rather than undersampling, for two reasons:

  1. In few-shot it is crucial to take advantage of as much of our data as possible.
  2. We can set an upper bound on oversampling using a user-provided sample maximum, but we can't as easily set a lower bound on undersampling.

That said, unique_pairs will be a poor name in this case.
I also want to point out that I intend for the next release to be v1.0.0, which means that I'm open to more breaking changes if we believe that they are beneficial for users.

As for the merge conflicts, no need to worry about them.

  • Tom Aarsen

@danstan5
Contributor Author

Nice one. Strongly agree on the oversampling... struggling for good names though (set_max_pairs or something?)

Sure, I think it's best placed in a more major release, as it will ultimately change existing results if you re-train setfit models.
Also, v1 is looking great so far! I will update this PR as needed after the refactor goes in. Let me know if you have any further suggestions on it until then.

Thanks, Dan

@danstan5 danstan5 mentioned this pull request Feb 14, 2023
@tomaarsen
Member

Sounds good. We'll have to think of a strong name.

Alternatively, we can consider more major changes: Only supporting the oversampling. After all, how much does that really differ from the current behaviour? We could also remove num_iterations, which is just a really unintuitive parameter for someone unfamiliar with SetFit, and implement a max_pairs parameter instead, which defaults to -1 for "no limit". That would result in num_sample = min(max(unique_positive_pairs, unique_negative_pairs), max_pairs), i.e. your oversampling solution.

I'm trying to think if there are any scenarios that would be possible now that won't be possible if we take that potential approach, and whether the performance would be degraded in any of the scenarios.

Another approach is to take it a step further and take new pairs at the start of every epoch. That said, if we indeed go with the oversampling solution, then either the positive or the negative pairs will simply be all of the unique pairs (assuming max_pairs=-1), and "freshly generating" them every epoch is just a waste of time. Only the oversampled section of the other set of pairs would differ between epochs with that approach. It may still be worth regenerating these every epoch for a more robust model, though.

I'm curious about your thoughts on this.

  • Tom Aarsen

@tomaarsen
Member

tomaarsen commented Feb 17, 2023

Hello!

I've run some more (much more thorough) experiments on this PR, as I and others suspected that this PR might help with a ubiquitous overfitting problem on SetFit models. After all, better data gives better models, and conceptually this PR should result in better data.

Experiment details

I have run three sets of experiments:

  • A: From the main branch at 7885128:

    python .\scripts\setfit\run_fewshot.py --is_dev_set=True --sample_sizes 2 4 8 16 --batch_size 32
  • B: From this PR at 3ce733f, using unique_pairs=False:

    python .\scripts\setfit\run_fewshot.py --is_dev_set=True --sample_sizes 2 4 8 16 --batch_size 32 --unique_pairs=False
  • C: From this PR at 3ce733f, using unique_pairs=True, i.e. the oversampling approach that you described:

    python .\scripts\setfit\run_fewshot.py --is_dev_set=True --sample_sizes 2 4 8 16 --batch_size 32 --unique_pairs=True

    With the following changes in place:

    Click to see the diff
    diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py
    index 0978061..d5f96a4 100644
    --- a/src/setfit/modeling.py
    +++ b/src/setfit/modeling.py
    @@ -771,6 +771,24 @@ def negative_sentence_pairs_generate(
         return pairs
     
     
    +def bounded_cycle(iterable, N):
    +    # bounded_cycle('ABCD', 10) --> A B C D A B C D A B
    +    n_yielded = 0
    +    saved = []
    +    for element in iterable:
    +        yield element
    +        n_yielded += 1
    +        if n_yielded >= N:
    +            return None
    +        saved.append(element)
    +    while saved:
    +        for element in saved:
    +            yield element
    +            n_yielded += 1
    +            if n_yielded >= N:
    +                return None
    +
    +
     def sentence_pairs_generation(
         sentences: np.ndarray,
         labels: np.ndarray,
    @@ -803,7 +821,7 @@ def sentence_pairs_generation(
                 logger.warning("** Oversampling positive pairs to balance contrastive training samples.")
                 positive_pairs += positive_sentence_pairs_generate(sentences, labels, extra_pairs, False, multilabel)
     
    -    return positive_pairs + negative_pairs
    +    return list(bounded_cycle(positive_pairs + negative_pairs, max_pairs * 2))
     
     
     def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix):

    These changes ensure that all of the experiments run with exactly the same number of training samples, even if fewer pairs would normally be selected when unique_pairs is set to True (which would normally be compensated for with a higher num_epochs).

Goals

I have also written a script to plot the results from the three separate results folders that were created from these three experiments. Ideally, we would like to see a situation where approach B is equivalent to approach A, i.e. if unique_pairs=False, then the existing performance of the main branch is preserved, and we would like to see approach C outperform the other two approaches.

Results

Click here to see the resulting plots

  • bbc-news
  • enron_spam
  • imdb
  • sst2
  • student-question-categories
  • subj
  • toxic_conversations
  • TREC-QC

Discussion

I am unable to see significant differences between the three approaches tested here, which is quite a shame. These results are quite surprising to me, as in my previous experiments [1] [2], I encountered improved performance when considering the same number of iterations.

I see three reasonable explanations:

  1. I have made a mistake either in this experiment or in both of my prior ones.
  2. Between acb27e8...3ce733f in this PR, the performance gain was somehow lost.
  3. Since I ran my prior experiments, the main branch has been updated such that its baseline performance has improved as well.

I'd like to further investigate option 2 this afternoon.

@danstan5 Please let me know if you spot a conceptual flaw in my experiments here. I may have overlooked something.

I'd like to notify @danielkorat and @MosheWasserb of these results, as I think they may benefit from the results of this experiment. I also think that @lewtun may be interested in this PR and its various experiments.

  • Tom Aarsen

@danstan5
Contributor Author

danstan5 commented Feb 17, 2023

Hey @tomaarsen love these plots, can you share the script for them??

You're right, there's no statistical difference in these results, but this is because with the default num_iterations of 20 the datasets are being oversampled anyway (so random sampling will look suitably uniform). For example, I just checked on SST2 (8 samples): unique_pairs -> 144 pairs, while num_iterations=20 -> 640.

The benefits of this PR really only come at low num_iterations, and unique_pairs provides a nice natural limit if you're not sure how to set num_iterations. So it's useful for very large datasets or in hyperparameter search/active-training loops etc.

@tomaarsen
Member

tomaarsen commented Feb 17, 2023

To add on to my previous comment, it does not seem like the performance degraded since earlier in this PR.

Experiment

I ran another experiment:

  • D: From this PR at acb27e8, using unique_pairs=True:

    python .\scripts\setfit\run_fewshot.py --is_dev_set=True --sample_sizes 2 4 8 16 --batch_size 32 --unique_pairs=True
    

    With the following changes in place:

    Click to see the diff
    diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py
    index 9e8a04d..7af18ee 100644
    --- a/src/setfit/modeling.py
    +++ b/src/setfit/modeling.py
    @@ -717,6 +717,24 @@ def negative_sentence_pairs_generate(
         return pairs
     
     
    +def bounded_cycle(iterable, N):
    +    # bounded_cycle('ABCD', 10) --> A B C D A B C D A B
    +    n_yielded = 0
    +    saved = []
    +    for element in iterable:
    +        yield element
    +        n_yielded += 1
    +        if n_yielded >= N:
    +            return None
    +        saved.append(element)
    +    while saved:
    +        for element in saved:
    +            yield element
    +            n_yielded += 1
    +            if n_yielded >= N:
    +                return None
    +
    +
     def sentence_pairs_generation(
         sentences: np.ndarray,
         labels: np.ndarray,
    @@ -742,7 +760,7 @@ def sentence_pairs_generation(
         max_neg_pairs = (num_iterations * len(sentences) * 2) - len(positive_pairs)
         negative_pairs = negative_sentence_pairs_generate(sentences, labels, max_neg_pairs, unique_pairs, multilabel)
     
    -    return positive_pairs + negative_pairs
    +    return list(bounded_cycle(positive_pairs + negative_pairs, num_iterations * len(sentences) * 2))
     
     
     def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix):

    As with C, these changes ensure the same number of iterations are used.

Results

Click to see the results of the experiment

  • bbc-news
  • enron_spam
  • sst2
  • subj

Discussion

As before, the performance seems nearly identical to the other results, despite now using the older version of this PR.

I'm curious to hear your thoughts,

  • Tom Aarsen

@tomaarsen
Member

tomaarsen commented Feb 17, 2023

So useful for very large datasets or in hyperparameter test/ active-training loops etc.

Wonderful! I'll try to see if I can experiment a bit further in those cases. I was under the impression that this PR was most interesting in small dataset situations. Most of all I'm a bit saddened that I can't seem to reproduce my earlier results where this PR seemed to result in a ~3% accuracy jump from 45% to 48% on the emotion dataset.

As for the script I used, I'll attach it here now, but I'll also fix up the comments, imports etc. (e.g. some comments talk about create_summary_table.py). I'll try to push it to the main branch as well, soon, as I've found it fairly useful.

Click to see plotting script
import argparse
import json
import os
from pathlib import Path
from collections import defaultdict
from glob import glob
from typing import List, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import string


def get_sample_sizes(path: str) -> List[int]:
    return sorted({int(name.split("-")[-2]) for name in glob(f"{path}/*/train-*-0")})


def get_formatted_ds_metrics(path: str, dataset: str, sample_sizes: List[int]) -> Tuple[str, defaultdict]:
    metric_name = "N/A"  # fallback if no results.json files are found
    split_metrics = defaultdict(list)

    for sample_size in sample_sizes:
        result_jsons = sorted(glob(os.path.join(path, dataset, f"train-{sample_size}-*", "results.json")))
        for result_json in result_jsons:
            with open(result_json) as f:
                result_dict = json.load(f)

            metric_name = result_dict.get("measure", "N/A")
            split_metrics[sample_size].append(result_dict["score"])

    return metric_name, split_metrics


def plot_summary_comparison(paths: List[str]) -> None:
    dataset_to_df = defaultdict(pd.DataFrame)
    dataset_to_metric = {}
    for path_index, path in enumerate(paths):
        ds_to_metric, this_dataset_to_df = get_summary_df(path)
        for dataset, df in this_dataset_to_df.items():
            df["path_index"] = path_index
            dataset_to_df[dataset] = pd.concat((dataset_to_df[dataset], df))
        dataset_to_metric = dataset_to_metric | ds_to_metric

    # Prepare folder for storing figures
    image_dir = Path("scripts") / "images"
    image_dir.mkdir(exist_ok=True)
    new_version = max([int(path.name[2:]) for path in image_dir.glob("v_*/") if path.name[2:].isdigit()], default=0) + 1
    output_dir = image_dir / f"v_{new_version}"
    output_dir.mkdir()

    for dataset, df in dataset_to_df.items():
        columns = [column for column in df.columns if not column.startswith("path")]
        fig, axes = plt.subplots(ncols=len(columns), sharey=True)
        for column_index, column in enumerate(columns):
            ax = axes[column_index]

            # Set the y label only for the first column
            if column_index == 0:
                ax.set_ylabel(dataset_to_metric[dataset])

            # Set positions to 0, 0.25, ..., one position per boxplot
            # This places the boxplots closer together
            n_boxplots = len(df["path_index"].unique())
            allotted_box_width = 0.2
            positions = [allotted_box_width * i for i in range(n_boxplots)]
            ax.set_xlim(-allotted_box_width * 0.75, allotted_box_width * (n_boxplots - 0.25))

            # ax.set_xticks(range(n_boxplots), rotation="vertical")
            df[[column, "path_index"]].groupby("path_index", sort=True).boxplot(
                subplots=False, ax=ax, column=column, positions=positions
            )

            k_shot = column.split("-")[-1]
            ax.set_xlabel(f"{k_shot}-shot")
            if n_boxplots > 1:
                # If there are multiple boxplots, override the labels at the bottom generated by pandas
                if n_boxplots <= 26:
                    ax.set_xticklabels(string.ascii_uppercase[:n_boxplots])
                else:
                    ax.set_xticklabels(range(n_boxplots))
            else:
                # Otherwise, just remove the xticks
                ax.tick_params(labelbottom=False)

        if n_boxplots > 1:
            fig.suptitle(f"Comparison between various baselines on the {dataset}\ndataset under various $K$-shot conditions")
        else:
            fig.suptitle(f"Results on the {dataset} dataset under various $K$-shot conditions")
        fig.tight_layout()
        plt.savefig(str(output_dir / dataset))


def get_summary_df(path: str) -> Tuple[dict, dict]:
    """Given per-split results, creates a summary table of all datasets,
    with average metrics and standard deviations.

    Args:
        path: path to per-split results: either `scripts/{method_name}/{results}/{model_name}`,
            or `final_results/{method_name}/{model_name}.tar.gz`
    """

    sample_sizes = get_sample_sizes(path)
    header_row = ["dataset", "measure"]
    for sample_size in sample_sizes:
        header_row.append(f"{sample_size}_avg")
        header_row.append(f"{sample_size}_std")

    dataset_to_metric = {}
    dataset_to_df = {}
    for dataset in next(os.walk(path))[1]:
        metric_name, split_metrics = get_formatted_ds_metrics(path, dataset, sample_sizes)
        dataset_df = pd.DataFrame(split_metrics.values(), index=[f"{dataset}-{key}" for key in split_metrics]).T
        dataset_to_metric[dataset] = metric_name
        dataset_to_df[dataset] = dataset_df
    return dataset_to_metric, dataset_to_df


def main() -> None:
    parser = argparse.ArgumentParser()

    parser.add_argument("--paths", nargs="+", type=str)
    args = parser.parse_args()

    plot_summary_comparison(args.paths)


if __name__ == "__main__":
    main()

When calling this script, I provide it with multiple paths like so:

python scripts/plot_summary_comparison.py --paths archived\results_main_16_02_78851287\paraphrase-mpnet-base-v2-CosineSimilarityLoss-logistic_regression-iterations_20-batch_32 scripts\setfit\results\paraphrase-mpnet-base-v2-CosineSimilarityLoss-logistic_regression-iterations_20-batch_32
  • Tom Aarsen

@danstan5
Contributor Author

danstan5 commented Feb 17, 2023

More commits!

> Furthermore, I'm very interested in "freshly" generating samples per epoch

These changes are a solution to the above:

  • Moves the contrastive pair sampler logic from .modelling into a .sampler module
  • Uses a ContrastiveDataset class that inherits from the torch IterableDataset instead of standalone functions. This matches exactly how sentence-transformers creates training samples; see here how calling the non-SetFit sampler uses the same logic as a SetFit sampler.
  • 47b20cb With this in a class, the state of generated samples is picked up again between epochs, so it continues generating new samples across epochs if needed.
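As a rough illustration of that last point, here is a minimal sketch (not the actual ContrastiveDataset, which inherits from the torch IterableDataset) of a dataset whose generator state carries over between epochs:

```python
from itertools import combinations


class ToyPairDataset:
    """Simplified stand-in for ContrastiveDataset: the pair generator is
    created once, so each epoch resumes where the previous one stopped
    instead of re-drawing the same pairs."""

    def __init__(self, sentences, pairs_per_epoch):
        self.pairs_per_epoch = pairs_per_epoch
        # Generator state lives on the instance, shared across epochs.
        self._pairs = combinations(sentences, 2)

    def __iter__(self):
        for _ in range(self.pairs_per_epoch):
            try:
                yield next(self._pairs)
            except StopIteration:  # all unique pairs exhausted
                return


dataset = ToyPairDataset(["a", "b", "c", "d"], pairs_per_epoch=2)
epoch_1 = list(dataset)  # [('a', 'b'), ('a', 'c')]
epoch_2 = list(dataset)  # continues: [('a', 'd'), ('b', 'c')]
```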

Breaking changes:

Impact:

  • Except for the shuffling (yet to be reviewed), this is no different when num_epochs = 1.
    In cases where not all unique pairs were generated in epoch 1, new samples will now continue to be drawn in later epochs.
  • This is quite a major code change for quite a rare case. Again, the benefits are only seen at very low num_iterations and epochs > 1. I personally like the alignment with the sentence-transformers "samplers" though, which of course SetFit is very closely linked to.

Thoughts, as ever, are appreciated @tomaarsen. I can make some of the code a bit clearer, add the distillation sampler into a similar class, update tests, etc., but I'm keen to get your thoughts first.

@tomaarsen
Member

I'm a great fan of how this is looking! I don't have a lot of time to dive into it now (I'll look on Monday), but is my initial assumption correct that with these changes in place, increasing num_epochs would become essentially equivalent to increasing num_iterations?
I think we'd make the entire project a lot more intuitive if num_iterations were deprecated.

tomaarsen added a commit to tomaarsen/setfit that referenced this pull request Feb 20, 2023
tomaarsen added a commit that referenced this pull request Feb 24, 2023
* Add comparison plotting script

As used in comments in #268

* Apply automatic formatting

* Write the command used to plot the graphs to the output directory
@tomaarsen
Member

tomaarsen commented Feb 24, 2023

I ran some tests with the new ContrastiveDataset: As expected, increasing num_epochs is now exactly equivalent to increasing num_iterations. To be concrete, the following situations sample exactly the same pairs:

  • num_epochs = 1, num_iterations = 20
  • num_epochs = 2, num_iterations = 10
  • num_epochs = 20, num_iterations = 1
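This equivalence can be reproduced with a toy sampler (illustrative only, not the SetFit implementation): because pairs are drawn from a single iterator whose position carries over between epochs, only the product of epochs and pairs-per-epoch determines which pairs are drawn.

```python
def draw_pairs(pair_stream, num_epochs, pairs_per_epoch):
    # One underlying iterator; its position carries over between epochs.
    stream = iter(pair_stream)
    drawn = []
    for _ in range(num_epochs):
        drawn.extend(next(stream) for _ in range(pairs_per_epoch))
    return drawn


pairs = [(i, i + 1) for i in range(40)]
# Only the product num_epochs * pairs_per_epoch matters:
assert draw_pairs(pairs, 1, 20) == draw_pairs(pairs, 2, 10) == draw_pairs(pairs, 20, 1)
```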

This allows us to deprecate one of them (i.e. num_iterations), which may help simplify the code further, too.
An interesting implementation may be to create two generators: for positive pairs and for negative pairs. These generators can be "refreshed" once they "run out", and they mean that we don't need a max_pairs argument to the methods, as we can control that whenever we yield from the generator. I think we can then also simplify by only offering unique pairs per epoch, which means the unique_pairs parameter could perhaps be removed as well.
If beneficial, we could introduce an even_sampling or sample_evenly argument that decides whether the number of positive and negative pairs should be equal. Alternatively, we can provide e.g. sampling_strategy which accepts either "oversampling", "undersampling" or e.g. "unbalanced". I quite like that, actually.
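A rough sketch of both ideas together, self-refreshing positive/negative generators plus a sampling_strategy argument; every name here is illustrative, not a proposed SetFit API:

```python
from itertools import combinations


def pair_generator(sentences, labels, positive):
    """Endless generator over sentence pairs of one polarity; "refreshes"
    itself once all unique pairs have been yielded."""
    while True:
        for (s1, l1), (s2, l2) in combinations(zip(sentences, labels), 2):
            if (l1 == l2) == positive:
                yield s1, s2


def sample_pairs(sentences, labels, sampling_strategy="oversampling"):
    # Count the unique positive and negative pairs available.
    n_pos = sum(1 for a, b in combinations(labels, 2) if a == b)
    n_neg = sum(1 for a, b in combinations(labels, 2) if a != b)
    if sampling_strategy == "oversampling":
        n_pos = n_neg = max(n_pos, n_neg)  # minority side repeats pairs
    elif sampling_strategy == "undersampling":
        n_pos = n_neg = min(n_pos, n_neg)  # majority side drops pairs
    # else "unbalanced": one pass of unique pairs on each side
    pos_gen = pair_generator(sentences, labels, positive=True)
    neg_gen = pair_generator(sentences, labels, positive=False)
    return (
        [next(pos_gen) for _ in range(n_pos)],
        [next(neg_gen) for _ in range(n_neg)],
    )


pos, neg = sample_pairs(["a", "b", "c", "d"], [0, 0, 1, 1], "oversampling")
# The 2 unique positive pairs are cycled up to match the 4 negatives.
```

With labels [0, 0, 1, 1] there are 2 unique positive and 4 unique negative pairs, so "oversampling" yields 4 of each, "undersampling" 2 of each, and "unbalanced" 2 and 4.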

I'd love to know your thoughts on this.

Edit: I'm unsure about the time efficiency consequences of using a generator, i.e. of having to do a bunch of CPU calls every iteration rather than once in bulk at the start of each epoch.

  • Tom Aarsen

@vahuja4

vahuja4 commented Jul 20, 2023

The experiments here look very interesting and promising. @tomaarsen - would you please tell me if these changes are going to be merged soon?

@tomaarsen
Member

This has been superseded by tomaarsen#5, which has been merged into v1.0.0-pre. I'll close this accordingly. Thanks for the great work @danstan5!

  • Tom Aarsen

@tomaarsen tomaarsen closed this Nov 24, 2023

Labels

enhancement New feature or request
