Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
08892f6
sampler for refactor WIP
danstan5 Sep 14, 2023
429de0f
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
173f084
Run formatters
tomaarsen Oct 17, 2023
c23959a
Remove tests from modeling.py
tomaarsen Oct 17, 2023
567f1c9
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
67ddedc
Merge branch 'refactor_v2' of https://github.com/tomaarsen/setfit int…
tomaarsen Oct 17, 2023
d37ee09
sampler logic fix "unique" strategy
danstan5 Oct 19, 2023
0ef8837
add sampler tests (not complete)
danstan5 Oct 19, 2023
131aa26
add sampling_strategy into TrainingArguments
danstan5 Oct 19, 2023
c6c6228
Merge branch 'refactor-sampling' of https://github.com/danstan5/setfi…
danstan5 Oct 19, 2023
7431005
num_iterations removed from TrainingArguments
danstan5 Oct 19, 2023
3bd2acc
run_fewshot compatible with <v.1.0.0
danstan5 Oct 20, 2023
3d07e6c
Run make style
tomaarsen Oct 25, 2023
978daee
Use "no" as the default evaluation_strategy
tomaarsen Oct 25, 2023
2802a3f
Move num_iterations back to TrainingArguments
tomaarsen Oct 25, 2023
391f991
Fix broken trainer tests due to new default sampling
tomaarsen Oct 25, 2023
f8b7253
Use the Contrastive Dataset for Distillation
tomaarsen Oct 25, 2023
38e9607
Set the default logging steps at 50
tomaarsen Oct 25, 2023
4ead15d
Add max_steps argument to TrainingArguments
tomaarsen Oct 25, 2023
eb70336
Change max_steps conditional
tomaarsen Oct 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scripts/setfit/run_fewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def parse_args():
parser.add_argument("--override_results", default=False, action="store_true")
parser.add_argument("--keep_body_frozen", default=False, action="store_true")
parser.add_argument("--add_data_augmentation", default=False)
parser.add_argument("--evaluation_strategy", default=False)

args = parser.parse_args()

Expand Down Expand Up @@ -148,6 +149,8 @@ def main():
num_epochs=args.num_epochs,
num_iterations=args.num_iterations,
)
if not args.evaluation_strategy:
trainer.args.evaluation_strategy = "no"
if args.classifier == "pytorch":
trainer.freeze()
trainer.train()
Expand Down
76 changes: 1 addition & 75 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import requests
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from sentence_transformers import InputExample, SentenceTransformer, models
from sentence_transformers import SentenceTransformer, models
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
Expand Down Expand Up @@ -683,77 +683,3 @@ def _from_pretrained(
multi_target_strategy=multi_target_strategy,
normalize_embeddings=normalize_embeddings,
)


def sentence_pairs_generation(sentences, labels, pairs):
    """Append one positive and one negative pair per sentence to ``pairs``.

    For every sentence, a random same-label partner forms a positive pair
    (label 1.0) and a random different-label partner forms a negative pair
    (label 0.0).  The extended ``pairs`` list is returned.
    """
    unique_labels = np.unique(labels)
    label_to_idx = {lbl: pos for pos, lbl in enumerate(unique_labels)}
    # Pre-compute, per class, which sample indices share / differ in label.
    positive_idxs = [np.where(labels == lbl)[0] for lbl in unique_labels]
    negative_idxs = [np.where(labels != lbl)[0] for lbl in unique_labels]

    for anchor_idx, anchor in enumerate(sentences):
        class_pos = label_to_idx[labels[anchor_idx]]

        # Positive pair: random partner drawn from the same class.
        partner = sentences[np.random.choice(positive_idxs[class_pos])]
        pairs.append(InputExample(texts=[anchor, partner], label=1.0))

        # Negative pair: random partner drawn from any other class.
        partner = sentences[np.random.choice(negative_idxs[class_pos])]
        pairs.append(InputExample(texts=[anchor, partner], label=0.0))
    return pairs


def sentence_pairs_generation_multilabel(sentences, labels, pairs):
    """Append positive/negative pairs for multilabel data to ``pairs``.

    ``labels`` is a binary matrix of shape (n_sentences, n_classes).  For each
    active class of a sentence, one positive pair (a random sentence sharing
    that class, label 1.0) and one negative pair (a random sentence sharing no
    class at all, label 0.0) are appended.  Sentences for which no fully
    disjoint partner exists are skipped.  The extended ``pairs`` list is
    returned.
    """
    for anchor_idx, anchor in enumerate(sentences):
        anchor_labels = labels[anchor_idx, :]
        # Indices of samples that have no class in common with the anchor.
        disjoint_idxs = np.where(labels.dot(anchor_labels.T) == 0)[0]
        if len(disjoint_idxs) == 0:
            # No valid negative partner exists for this sentence; skip it.
            continue
        for cls in np.where(anchor_labels == 1)[0]:
            # Positive pair: random sentence that also carries class ``cls``.
            positive = sentences[np.random.choice(np.where(labels[:, cls] == 1)[0])]
            pairs.append(InputExample(texts=[anchor, positive], label=1.0))

            # Negative pair: random sentence sharing no label with the anchor.
            negative = sentences[np.random.choice(disjoint_idxs)]
            pairs.append(InputExample(texts=[anchor, negative], label=0.0))
    return pairs


def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix):
    """Append two similarity-labelled pairs per sentence to ``pairs``.

    Each sentence is paired with two randomly chosen distinct partners; each
    pair is labelled with the corresponding ``cos_sim_matrix`` entry (soft
    similarity targets, e.g. from a teacher model).  The extended ``pairs``
    list is returned.
    """
    all_idxs = list(range(len(sentences)))

    for anchor_idx, anchor in enumerate(sentences):
        # Candidate partners: every index except the anchor itself.
        candidates = [other for other in all_idxs if other != anchor_idx]
        for _ in range(2):
            partner_idx = int(np.random.choice(candidates))
            similarity = float(cos_sim_matrix[anchor_idx][partner_idx])
            pairs.append(InputExample(texts=[anchor, sentences[partner_idx]], label=similarity))

    return pairs
156 changes: 156 additions & 0 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from itertools import zip_longest
from typing import Generator, Iterable, List, Optional, Sequence

import numpy as np
import torch
from sentence_transformers import InputExample
from torch.utils.data import IterableDataset

from . import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def shuffle_combinations(iterable: Sequence, replacement: bool = True) -> Generator:
    """Generates shuffled pair combinations for any indexable data provided.

    Args:
        iterable: data to generate pair combinations from; must support
            ``len()`` and integer indexing (a list, tuple, or numpy array —
            a bare iterator is not sufficient)
        replacement: enable to include combinations of same samples,
            equivalent to itertools.combinations_with_replacement

    Returns:
        Generator of shuffled pairs as a tuple
    """
    n = len(iterable)
    k = 1 if not replacement else 0
    # Upper-triangular indices enumerate every unordered pair exactly once;
    # k=0 keeps the diagonal (self-pairs), k=1 drops it.
    idxs = np.stack(np.triu_indices(n, k), axis=-1)
    # Fixed seed so the pair order is reproducible across runs.
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]


class ContrastiveDataset(IterableDataset):
    def __init__(
        self,
        examples: List[InputExample],
        multilabel: bool,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
    ) -> None:
        """Generates positive and negative text pairs for contrastive learning.

        Args:
            examples (List[InputExample]): text and labels in a text transformer dataclass
            multilabel: set to process "multilabel" labels array
            num_iterations: if provided explicitly sets the number of pairs to be generated
                where n_pairs = n_iterations * n_sentences * 2 (for pos & neg pairs)
            sampling_strategy: "unique", "oversampling", or "undersampling"

        Raises:
            ValueError: if ``sampling_strategy`` is not one of the accepted values
                and ``num_iterations`` does not override it.
        """
        super().__init__()
        # Cursors used to cycle through the finite pair lists when the
        # requested length exceeds the number of generated pairs.
        self.pos_index = 0
        self.neg_index = 0
        self.pos_pairs = []
        self.neg_pairs = []
        # Only the first text of each InputExample is used as the sentence.
        self.sentences = np.array([s.texts[0] for s in examples])
        self.labels = np.array([s.label for s in examples])
        self.sentence_labels = list(zip(self.sentences, self.labels))

        if multilabel:
            self.generate_multilabel_pairs()
        else:
            self.generate_pairs()

        if num_iterations is not None and num_iterations > 0:
            # An explicit pair budget overrides the sampling strategy.
            self.len_pos_pairs = num_iterations * len(self.sentences)
            self.len_neg_pairs = num_iterations * len(self.sentences)

        elif sampling_strategy == "unique":
            # Use every generated pair exactly once.
            self.len_pos_pairs = len(self.pos_pairs)
            self.len_neg_pairs = len(self.neg_pairs)

        elif sampling_strategy == "undersampling":
            # Trim the majority side down to the minority side's size.
            self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs))

        elif sampling_strategy == "oversampling":
            # Repeat the minority side up to the majority side's size.
            self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs))

        else:
            raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")

    def generate_pairs(self) -> None:
        """Fill pos_pairs/neg_pairs from all shuffled single-label sentence pairs."""
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            if _label == label:
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))

    def generate_multilabel_pairs(self) -> None:
        """Fill pos_pairs/neg_pairs from all shuffled multilabel sentence pairs."""
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            if any(np.logical_and(_label, label)):
                # logical_and checks if labels are both set for each class
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))

    def get_positive_pairs(self) -> List[InputExample]:
        """Return ``len_pos_pairs`` positive pairs, cycling the list as needed."""
        pairs = []
        for _ in range(self.len_pos_pairs):
            if self.pos_index >= len(self.pos_pairs):
                self.pos_index = 0
            pairs.append(self.pos_pairs[self.pos_index])
            self.pos_index += 1
        return pairs

    def get_negative_pairs(self) -> List[InputExample]:
        """Return ``len_neg_pairs`` negative pairs, cycling the list as needed."""
        pairs = []
        for _ in range(self.len_neg_pairs):
            if self.neg_index >= len(self.neg_pairs):
                self.neg_index = 0
            pairs.append(self.neg_pairs[self.neg_index])
            self.neg_index += 1
        return pairs

    def __iter__(self):
        # Interleave positive and negative pairs; zip_longest lets one side
        # keep yielding after the shorter side is exhausted.
        for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
            if pos_pair is not None:
                yield pos_pair
            if neg_pair is not None:
                yield neg_pair

    def __len__(self) -> int:
        return self.len_pos_pairs + self.len_neg_pairs


class ContrastiveDistillationDataset(ContrastiveDataset):
    def __init__(
        self,
        examples: List[InputExample],
        cos_sim_matrix: torch.Tensor,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
    ) -> None:
        """Pair dataset for distillation: labels are teacher cosine similarities.

        Args:
            examples (List[InputExample]): student training texts
            cos_sim_matrix: pairwise teacher similarities, indexed by sentence position
            num_iterations: if provided explicitly sets the number of pairs to be generated
            sampling_strategy: kept for interface parity with ContrastiveDataset
        """
        self.cos_sim_matrix = cos_sim_matrix
        super().__init__(
            examples,
            multilabel=False,
            num_iterations=num_iterations,
            sampling_strategy=sampling_strategy,
        )
        # Internally we store all pairs in pos_pairs, regardless of sampling strategy.
        # After all, without labels, there isn't much of a strategy.
        # NOTE(review): this reassignment happens *after* super().__init__ has
        # already called generate_pairs(), which unpacks sentence_labels as
        # (text, id) — confirm that the labels set by the parent are positional
        # indices at that point, otherwise cos_sim_matrix is mis-indexed.
        self.sentence_labels = list(enumerate(self.sentences))

        self.len_neg_pairs = 0
        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
        else:
            self.len_pos_pairs = len(self.pos_pairs)

    def generate_pairs(self) -> None:
        """Fill pos_pairs with every sentence pair, labelled by teacher similarity."""
        for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
            self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
51 changes: 27 additions & 24 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import time
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import evaluate
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from sentence_transformers import InputExample, SentenceTransformer, losses
Expand All @@ -16,7 +15,7 @@
from torch import nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm, trange
from tqdm.autonotebook import tqdm
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer_callback import (
CallbackHandler,
Expand All @@ -39,7 +38,7 @@
from . import logging
from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna
from .losses import SupConLoss
from .modeling import sentence_pairs_generation, sentence_pairs_generation_multilabel
from .sampler import ContrastiveDataset
from .training_args import TrainingArguments
from .utils import BestRun, default_hp_space_optuna

Expand Down Expand Up @@ -368,17 +367,21 @@ def train(
logger.info(f"Applying column mapping to {dataset_name} dataset")
dataset = self._apply_column_mapping(dataset, self.column_mapping)

parameters.extend([dataset["text"], dataset["label"]])
parameters.extend(self.dataset_to_parameters(dataset))

self.train_embeddings(*parameters, args=args)
self.train_classifier(*parameters[:2], args=args)
training_parameters = parameters[:len(parameters) // 2] if self.eval_dataset else parameters
self.train_classifier(*training_parameters, args=args)

def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
    """Extract the training columns from ``dataset`` as positional parameters."""
    return [dataset[column] for column in ("text", "label")]

def train_embeddings(
self,
x_train: List[str],
y_train: Union[List[int], List[List[int]]],
x_eval: List[str] = None,
y_eval: Union[List[int], List[List[int]]] = None,
y_train: Optional[Union[List[int], List[List[int]]]] = None,
x_eval: Optional[List[str]] = None,
y_eval: Optional[Union[List[int], List[List[int]]]] = None,
args: Optional[TrainingArguments] = None,
) -> None:
"""
Expand Down Expand Up @@ -423,16 +426,16 @@ def get_dataloader(
self, x: List[str], y: Union[List[int], List[List[int]]], args: TrainingArguments
) -> Tuple[DataLoader, nn.Module, int]:
# sentence-transformers adaptation
input_data = [InputExample(texts=[text], label=label) for text, label in zip(x, y)]

if args.loss in [
losses.BatchAllTripletLoss,
losses.BatchHardTripletLoss,
losses.BatchSemiHardTripletLoss,
losses.BatchHardSoftMarginTripletLoss,
SupConLoss,
]:
examples = [InputExample(texts=[text], label=label) for text, label in zip(x, y)]
data_sampler = SentenceLabelDataset(examples, samples_per_label=args.samples_per_label)

data_sampler = SentenceLabelDataset(input_data, samples_per_label=args.samples_per_label)
batch_size = min(args.embedding_batch_size, len(data_sampler))
dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=True)

Expand All @@ -450,17 +453,15 @@ def get_dataloader(
margin=args.margin,
)
else:
examples = []

for _ in trange(args.num_iterations, desc="Generating Training Pairs", disable=not args.show_progress_bar):
if self.model.multi_target_strategy is not None:
examples = sentence_pairs_generation_multilabel(np.array(x), np.array(y), examples)
else:
examples = sentence_pairs_generation(np.array(x), np.array(y), examples)

batch_size = args.embedding_batch_size
dataloader = DataLoader(examples, shuffle=True, batch_size=batch_size)
data_sampler = ContrastiveDataset(
input_data, self.model.multi_target_strategy, args.num_iterations, args.sampling_strategy
)
# shuffle_sampler = True can be dropped in for further 'randomising'
shuffle_sampler = True if args.sampling_strategy == "unique" else False
batch_size = min(args.embedding_batch_size, len(data_sampler))
dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
loss = args.loss(self.model.model_body)

return dataloader, loss, batch_size

def log(self, args: TrainingArguments, logs: Dict[str, float]) -> None:
Expand Down Expand Up @@ -505,8 +506,10 @@ def _train_sentence_transformer(

self.state.epoch = 0
start_time = time.time()
# TODO: Add max_steps via args.max_steps here?
self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs
if args.max_steps > 0:
self.state.max_steps = args.max_steps
else:
self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

if args.use_amp:
Expand Down
Loading