Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scripts/setfit/run_fewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def parse_args():
parser.add_argument("--override_results", default=False, action="store_true")
parser.add_argument("--keep_body_frozen", default=False, action="store_true")
parser.add_argument("--add_data_augmentation", default=False)
parser.add_argument("--unique_pairs", type=bool, default=False)

args = parser.parse_args()

Expand Down Expand Up @@ -147,6 +148,7 @@ def main():
batch_size=args.batch_size,
num_epochs=args.num_epochs,
num_iterations=args.num_iterations,
unique_pairs=args.unique_pairs,
)
if args.classifier == "pytorch":
trainer.freeze()
Expand Down
8 changes: 3 additions & 5 deletions scripts/setfit/run_fewshot_multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing_extensions import LiteralString

from setfit.data import SAMPLE_SIZES
from setfit.modeling import SetFitBaseModel, SKLearnWrapper, sentence_pairs_generation_multilabel
from setfit.modeling import SetFitBaseModel, SKLearnWrapper, sentence_pairs_generation
from setfit.utils import DEV_DATASET_TO_METRIC, LOSS_NAME_TO_CLASS, TEST_DATASET_TO_METRIC, load_data_splits_multilabel


Expand All @@ -37,7 +37,7 @@ def parse_args():
default=["go_emotions"],
)
parser.add_argument("--sample_sizes", type=int, nargs="+", default=SAMPLE_SIZES)
parser.add_argument("--num_epochs", type=int, default=20)
parser.add_argument("--num_epochs", type=int, default=20) # should this be iterations?
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--max_seq_length", type=int, default=256)
parser.add_argument(
Expand Down Expand Up @@ -124,9 +124,7 @@ def train(self, data: Dataset) -> SKLearnWrapper:

# sentence-transformers adaptation
batch_size = self.args.batch_size
train_examples = []
for _ in range(self.args.num_epochs):
train_examples = sentence_pairs_generation_multilabel(np.array(x_train), y_train, train_examples)
train_examples = sentence_pairs_generation(np.array(x_train), y_train, self.args.num_epochs, multilabel=True)

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
train_loss = self.loss_class(self.model)
Expand Down
51 changes: 0 additions & 51 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,57 +669,6 @@ def forward(self, sentence_features, labels=None, mask=None):
return loss


def sentence_pairs_generation(sentences, labels, pairs):
    """Append one positive and one negative sentence pair per input sentence.

    For each sentence, a positive partner is drawn at random from the same
    class and a negative partner from any other class. New ``InputExample``
    pairs (label 1.0 = positive, 0.0 = negative) are appended to ``pairs``,
    which is also returned.
    """
    class_values = np.unique(labels)
    # Precompute, per class value, the indices of samples in that class.
    indices_per_class = [np.where(labels == value)[0] for value in class_values]

    for anchor, anchor_label in zip(sentences, labels):
        # Positive partner: random sample sharing the anchor's class.
        class_pos = np.where(class_values == anchor_label)[0][0]
        partner = sentences[np.random.choice(indices_per_class[class_pos])]
        pairs.append(InputExample(texts=[anchor, partner], label=1.0))

        # Negative partner: random sample from any other class.
        other_indices = np.where(labels != anchor_label)[0]
        contrast = sentences[np.random.choice(other_indices)]
        pairs.append(InputExample(texts=[anchor, contrast], label=0.0))
    return pairs


def sentence_pairs_generation_multilabel(sentences, labels, pairs):
    """Append positive/negative sentence pairs for multilabel data.

    ``labels`` is a 2-D 0/1 indicator array (samples x classes). For every
    label set on a sentence, one positive partner sharing that label and one
    negative partner sharing *no* label are appended to ``pairs`` as
    ``InputExample``s (label 1.0 / 0.0). ``pairs`` is also returned.
    """
    for row, anchor in enumerate(sentences):
        anchor_labels = labels[row, :]
        # Samples that have no label in common with the anchor (zero dot
        # product of indicator rows) are the valid negative partners.
        disjoint = np.where(labels.dot(anchor_labels.T) == 0)[0]
        if len(disjoint) == 0:
            # No valid negative partner exists for this anchor -- skip it.
            continue

        for label_id in np.where(anchor_labels == 1)[0]:
            # Positive partner: random sample that also carries label_id.
            pos_idx = np.random.choice(np.where(labels[:, label_id] == 1)[0])
            pairs.append(InputExample(texts=[anchor, sentences[pos_idx]], label=1.0))

            # Negative partner: random sample from the disjoint set.
            neg_idx = np.random.choice(disjoint)
            pairs.append(InputExample(texts=[anchor, sentences[neg_idx]], label=0.0))
    return pairs


def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix):
# initialize two empty lists to hold the (sentence, sentence) pairs and
# labels to indicate if a pair is positive or negative
Expand Down
178 changes: 178 additions & 0 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from itertools import zip_longest
from typing import Generator, Iterable, List, Sequence

import numpy as np
from sentence_transformers import InputExample
from torch.utils.data import IterableDataset

from . import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def shuffle_combinations(iterable: Sequence, replacement: bool = False) -> Generator:
    """Generates shuffled pair combinations for any indexable data provided.

    Args:
        iterable: data to generate pair combinations from. Despite the
            parameter name, a ``Sequence`` is required: ``len()`` and integer
            indexing are used below, which a bare ``Iterable`` does not support.
        replacement: enable to include combinations of same samples,
            equivalent to itertools.combinations_with_replacement

    Returns:
        Generator of shuffled pairs as a tuple
    """
    n = len(iterable)
    # k=1 skips the diagonal (no self-pairs); k=0 keeps it for replacement.
    k = 1 if not replacement else 0
    # Upper-triangular indices enumerate every unordered pair exactly once.
    idxs = np.stack(np.triu_indices(n, k), axis=-1)
    # Fixed seed: the shuffled order is identical on every call, which lets
    # callers resume iteration of the same stream from a saved offset.
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]


class ConstrastiveDataset(IterableDataset):
    # NOTE(review): class name misspells "Contrastive"; kept as-is so existing
    # callers that reference this name keep working.
    def __init__(
        self, examples: List[InputExample], num_iterations: int, unique_pairs: bool, multilabel: bool
    ) -> None:
        """Generates positive and negative sentence pairs for contrastive learning.

        Args:
            examples (InputExample): text and labels in a sentence transformer dataclass
            num_iterations: sets the number of contrastive sample pairs to be generated
            unique_pairs: when true will only return up to the number of unique sentence
                pair combinations available
            multilabel: set to process "multilabel" labels array
        """
        super().__init__()

        # Cursors into the (deterministically ordered) positive/negative pair
        # streams; generation resumes from these offsets on the next call to
        # generate_sentence_pairs().
        self.pos_index = 0
        self.neg_index = 0
        self.multilabel = multilabel
        self.unique_pairs = unique_pairs
        # Only the first text of each example is used as the sentence.
        self.sentences = np.array([s.texts[0] for s in examples])
        self.labels = np.array([s.label for s in examples])
        # Upper bound on the number of pairs of EACH polarity per batch.
        self.max_pairs = num_iterations * len(examples)

        # generate dataset so __len__ method can be used
        self.generate_sentence_pairs()

    def generate_sentence_pairs(self) -> None:
        """Generates a new batch of positive and negative sentence pairs.

        Note: pos_index/ neg_index keep the position of pairs being generated.
        When unique_pairs is set, the shorter side is topped up by re-sampling
        (with duplicates) so both polarities contribute equally.
        """
        positive_pairs = self.positive_sentence_pairs(self.max_pairs, self.unique_pairs)
        negative_pairs = self.negative_sentence_pairs(self.max_pairs, self.unique_pairs)

        if self.unique_pairs:
            extra_pairs = abs(len(positive_pairs) - len(negative_pairs))

            if len(positive_pairs) > len(negative_pairs):
                logger.warning(
                    f"** Oversampling ({extra_pairs:,}) negative pairs to balance contrastive training samples."
                )
                negative_pairs += self.negative_sentence_pairs(max_pairs=extra_pairs, unique_pairs=False)

            if len(negative_pairs) > len(positive_pairs):
                logger.warning(
                    f"** Oversampling ({extra_pairs:,}) positive pairs to balance contrastive training samples."
                )
                positive_pairs += self.positive_sentence_pairs(max_pairs=extra_pairs, unique_pairs=False)

        self._num_pairs = len(positive_pairs) + len(negative_pairs)
        self.positive_pairs = positive_pairs
        self.negative_pairs = negative_pairs

    def positive_sentence_pairs(self, max_pairs: int, unique_pairs: bool = False) -> List[InputExample]:
        """Generates all unique or up to a max no. of combinations of positive sentence pairs.

        Samples positive combinations of sentences (without replacement) and maximises
        sampling of different classes in the pairs being generated.

        Args:
            max_pairs: returns when this many pairs are generated
            unique_pairs: if true, stop after one full pass over the unique
                combinations even if max_pairs has not been reached

        Returns:
            List of positive sentence pairs (up to the no. of unique pairs or max_pairs)
        """
        labels = self.labels
        sentences = self.sentences
        multilabel = self.multilabel
        pairs = []

        if multilabel:
            label_ids = np.arange(labels.shape[1])  # based on index = 1,0
        else:
            label_ids = np.unique(labels)  # based on class int
        while True:
            index = 0
            positive_combinators = []
            for _label in label_ids:
                # Sentences carrying this label (indicator column for
                # multilabel, exact class match otherwise).
                if multilabel:
                    label_sentences = sentences[np.where(labels[:, _label] == 1)[0]]
                else:
                    label_sentences = sentences[np.where(labels == _label)]
                # replacement=True allows (s, s) self-pairs as positives.
                positive_combinators.append(shuffle_combinations(label_sentences, replacement=True))

            # zip_longest round-robins across classes so no single class
            # dominates the front of the stream.
            for pos_pairs in zip_longest(*positive_combinators):
                for pos_pair in pos_pairs:
                    if pos_pair is not None:
                        index += 1
                        # shuffle_combinations is seeded, so its order repeats
                        # on every call; skipping the first pos_index items
                        # resumes exactly where the previous batch stopped.
                        if index > self.pos_index:
                            pairs.append(InputExample(texts=[*pos_pair], label=1.0))
                            if len(pairs) == max_pairs:
                                self.pos_index = index
                                return pairs
            # Full pass completed: reset the cursor; loop again (wrapping
            # around) unless only unique pairs were requested.
            self.pos_index = 0
            if unique_pairs:
                break
        logger.warning(f"** All ({len(pairs):,}) positive unique pairs generated")
        return pairs

    def negative_sentence_pairs(self, max_pairs: int, unique_pairs: bool = False) -> List[InputExample]:
        """Generates all or up to a max sample no. of negative combinations.

        Randomly samples negative combinations of sentences (without replacement).

        Args:
            max_pairs: returns when this many pairs are generated
            unique_pairs: if true, stop after one full pass over the unique
                combinations even if max_pairs has not been reached

        Returns:
            List of negative sentence pairs (up to the no. of unique pairs or max_pairs)
        """
        multilabel = self.multilabel
        pairs = []

        sentence_labels = list(zip(self.sentences, self.labels))
        while True:
            index = 0
            for (_sentence, _label), (sentence, label) in shuffle_combinations(sentence_labels):
                # logical_and checks if labels are both set for each class
                if (multilabel and not any(np.logical_and(_label, label))) or (not multilabel and _label != label):
                    index += 1
                    # Seeded shuffle order repeats per call, so skipping the
                    # first neg_index matches resumes the previous stream.
                    if index > self.neg_index:
                        pairs.append(InputExample(texts=[_sentence, sentence], label=0.0))
                        if len(pairs) == max_pairs:
                            self.neg_index = index
                            return pairs
            # Full pass completed: reset the cursor; wrap around unless only
            # unique pairs were requested.
            self.neg_index = 0
            if unique_pairs:
                break
        logger.warning(f"** All ({len(pairs):,}) negative unique pairs generated")
        return pairs

    def __iter__(self):
        # Interleave polarities: one positive, then one negative, in turn.
        for pos_pair, neg_pair in zip(self.positive_pairs, self.negative_pairs):
            # generates one of each in turn
            yield pos_pair
            yield neg_pair

        if self.pos_index or self.neg_index:
            # not all pairs combinations sampled so continues from last index
            self.generate_sentence_pairs()

    def __len__(self) -> int:
        # Number of pairs (both polarities) in the current batch.
        return self._num_pairs
Loading