diff --git a/scripts/setfit/run_fewshot.py b/scripts/setfit/run_fewshot.py index 1248fddc..08f7023e 100644 --- a/scripts/setfit/run_fewshot.py +++ b/scripts/setfit/run_fewshot.py @@ -59,6 +59,7 @@ def parse_args(): parser.add_argument("--override_results", default=False, action="store_true") parser.add_argument("--keep_body_frozen", default=False, action="store_true") parser.add_argument("--add_data_augmentation", default=False) + parser.add_argument("--evaluation_strategy", default=False) args = parser.parse_args() @@ -148,6 +149,8 @@ def main(): num_epochs=args.num_epochs, num_iterations=args.num_iterations, ) + if not args.evaluation_strategy: + trainer.args.evaluation_strategy = "no" if args.classifier == "pytorch": trainer.freeze() trainer.train() diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 9df0ec4c..0662d2d3 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -17,7 +17,7 @@ import requests import torch from huggingface_hub import PyTorchModelHubMixin, hf_hub_download -from sentence_transformers import InputExample, SentenceTransformer, models +from sentence_transformers import SentenceTransformer, models from sklearn.linear_model import LogisticRegression from sklearn.multiclass import OneVsRestClassifier from sklearn.multioutput import ClassifierChain, MultiOutputClassifier @@ -683,77 +683,3 @@ def _from_pretrained( multi_target_strategy=multi_target_strategy, normalize_embeddings=normalize_embeddings, ) - - -def sentence_pairs_generation(sentences, labels, pairs): - # Initialize two empty lists to hold the (sentence, sentence) pairs and - # labels to indicate if a pair is positive or negative - - num_classes = np.unique(labels) - label_to_idx = {x: i for i, x in enumerate(num_classes)} - positive_idxs = [np.where(labels == i)[0] for i in num_classes] - negative_idxs = [np.where(labels != i)[0] for i in num_classes] - - for first_idx in range(len(sentences)): - current_sentence = sentences[first_idx] - label = 
labels[first_idx] - second_idx = np.random.choice(positive_idxs[label_to_idx[label]]) - positive_sentence = sentences[second_idx] - # Prepare a positive pair and update the sentences and labels - # lists, respectively - pairs.append(InputExample(texts=[current_sentence, positive_sentence], label=1.0)) - - third_idx = np.random.choice(negative_idxs[label_to_idx[label]]) - negative_sentence = sentences[third_idx] - # Prepare a negative pair of sentences and update our lists - pairs.append(InputExample(texts=[current_sentence, negative_sentence], label=0.0)) - # Return a 2-tuple of our sentence pairs and labels - return pairs - - -def sentence_pairs_generation_multilabel(sentences, labels, pairs): - # Initialize two empty lists to hold the (sentence, sentence) pairs and - # labels to indicate if a pair is positive or negative - for first_idx in range(len(sentences)): - current_sentence = sentences[first_idx] - sample_labels = np.where(labels[first_idx, :] == 1)[0] - if len(np.where(labels.dot(labels[first_idx, :].T) == 0)[0]) == 0: - continue - else: - for _label in sample_labels: - second_idx = np.random.choice(np.where(labels[:, _label] == 1)[0]) - positive_sentence = sentences[second_idx] - # Prepare a positive pair and update the sentences and labels - # lists, respectively - pairs.append(InputExample(texts=[current_sentence, positive_sentence], label=1.0)) - - # Search for sample that don't have a label in common with current - # sentence - negative_idx = np.where(labels.dot(labels[first_idx, :].T) == 0)[0] - negative_sentence = sentences[np.random.choice(negative_idx)] - # Prepare a negative pair of sentences and update our lists - pairs.append(InputExample(texts=[current_sentence, negative_sentence], label=0.0)) - # Return a 2-tuple of our sentence pairs and labels - return pairs - - -def sentence_pairs_generation_cos_sim(sentences, pairs, cos_sim_matrix): - # initialize two empty lists to hold the (sentence, sentence) pairs and - # labels to indicate if a pair 
is positive or negative - - idx = list(range(len(sentences))) - - for first_idx in range(len(sentences)): - current_sentence = sentences[first_idx] - second_idx = int(np.random.choice([x for x in idx if x != first_idx])) - - cos_sim = float(cos_sim_matrix[first_idx][second_idx]) - paired_sentence = sentences[second_idx] - pairs.append(InputExample(texts=[current_sentence, paired_sentence], label=cos_sim)) - - third_idx = np.random.choice([x for x in idx if x != first_idx]) - cos_sim = float(cos_sim_matrix[first_idx][third_idx]) - paired_sentence = sentences[third_idx] - pairs.append(InputExample(texts=[current_sentence, paired_sentence], label=cos_sim)) - - return pairs diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py new file mode 100644 index 00000000..1bea2e78 --- /dev/null +++ b/src/setfit/sampler.py @@ -0,0 +1,156 @@ +from itertools import zip_longest +from typing import Generator, Iterable, List, Optional + +import numpy as np +from sentence_transformers import InputExample +import torch +from torch.utils.data import IterableDataset + +from . import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator: + """Generates shuffled pair combinations for any iterable data provided. 
+ + Args: + iterable: data to generate pair combinations from + replacement: enable to include combinations of same samples, + equivalent to itertools.combinations_with_replacement + + Returns: + Generator of shuffled pairs as a tuple + """ + n = len(iterable) + k = 1 if not replacement else 0 + idxs = np.stack(np.triu_indices(n, k), axis=-1) + for i in np.random.RandomState(seed=42).permutation(len(idxs)): + _idx, idx = idxs[i, :] + yield iterable[_idx], iterable[idx] + + +class ContrastiveDataset(IterableDataset): + def __init__( + self, + examples: List[InputExample], + multilabel: bool, + num_iterations: Optional[None] = None, + sampling_strategy: str = "oversampling", + ) -> None: + """Generates positive and negative text pairs for contrastive learning. + + Args: + examples (InputExample): text and labels in a text transformer dataclass + multilabel: set to process "multilabel" labels array + sampling_strategy: "unique", "oversampling", or "undersampling" + num_iterations: if provided explicitly sets the number of pairs to be generated + where n_pairs = n_iterations * n_sentences * 2 (for pos & neg pairs) + """ + super().__init__() + self.pos_index = 0 + self.neg_index = 0 + self.pos_pairs = [] + self.neg_pairs = [] + self.sentences = np.array([s.texts[0] for s in examples]) + self.labels = np.array([s.label for s in examples]) + self.sentence_labels = list(zip(self.sentences, self.labels)) + + if multilabel: + self.generate_multilabel_pairs() + else: + self.generate_pairs() + + if num_iterations is not None and num_iterations > 0: + self.len_pos_pairs = num_iterations * len(self.sentences) + self.len_neg_pairs = num_iterations * len(self.sentences) + + elif sampling_strategy == "unique": + self.len_pos_pairs = len(self.pos_pairs) + self.len_neg_pairs = len(self.neg_pairs) + + elif sampling_strategy == "undersampling": + self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs)) + self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs)) + + 
elif sampling_strategy == "oversampling": + self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs)) + self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs)) + + else: + raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.") + + def generate_pairs(self) -> None: + for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): + if _label == label: + self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0)) + else: + self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0)) + + def generate_multilabel_pairs(self) -> None: + for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): + if any(np.logical_and(_label, label)): + # logical_and checks if labels are both set for each class + self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0)) + else: + self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0)) + + def get_positive_pairs(self) -> List[InputExample]: + pairs = [] + for _ in range(self.len_pos_pairs): + if self.pos_index >= len(self.pos_pairs): + self.pos_index = 0 + pairs.append(self.pos_pairs[self.pos_index]) + self.pos_index += 1 + return pairs + + def get_negative_pairs(self) -> List[InputExample]: + pairs = [] + for _ in range(self.len_neg_pairs): + if self.neg_index >= len(self.neg_pairs): + self.neg_index = 0 + pairs.append(self.neg_pairs[self.neg_index]) + self.neg_index += 1 + return pairs + + def __iter__(self): + for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()): + if pos_pair is not None: + yield pos_pair + if neg_pair is not None: + yield neg_pair + + def __len__(self) -> int: + return self.len_pos_pairs + self.len_neg_pairs + + +class ContrastiveDistillationDataset(ContrastiveDataset): + def __init__( + self, + examples: List[InputExample], + cos_sim_matrix: torch.Tensor, + num_iterations: Optional[None] = None, + sampling_strategy: str = 
"oversampling", + ) -> None: + self.cos_sim_matrix = cos_sim_matrix + super().__init__( + examples, + multilabel=False, + num_iterations=num_iterations, + sampling_strategy=sampling_strategy, + ) + # Internally we store all pairs in pos_pairs, regardless of sampling strategy. + # After all, without labels, there isn't much of a strategy. + self.sentence_labels = list(enumerate(self.sentences)) + + self.len_neg_pairs = 0 + if num_iterations is not None and num_iterations > 0: + self.len_pos_pairs = num_iterations * len(self.sentences) + else: + self.len_pos_pairs = len(self.pos_pairs) + + def generate_pairs(self) -> None: + for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels): + self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two])) diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index a4ec70f2..0abdd110 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -3,10 +3,9 @@ import time import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import evaluate -import numpy as np import torch from datasets import Dataset, DatasetDict from sentence_transformers import InputExample, SentenceTransformer, losses @@ -16,7 +15,7 @@ from torch import nn from torch.cuda.amp import autocast from torch.utils.data import DataLoader -from tqdm.autonotebook import tqdm, trange +from tqdm.autonotebook import tqdm from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer_callback import ( CallbackHandler, @@ -39,7 +38,7 @@ from . 
import logging from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna from .losses import SupConLoss -from .modeling import sentence_pairs_generation, sentence_pairs_generation_multilabel +from .sampler import ContrastiveDataset from .training_args import TrainingArguments from .utils import BestRun, default_hp_space_optuna @@ -368,17 +367,21 @@ def train( logger.info(f"Applying column mapping to {dataset_name} dataset") dataset = self._apply_column_mapping(dataset, self.column_mapping) - parameters.extend([dataset["text"], dataset["label"]]) + parameters.extend(self.dataset_to_parameters(dataset)) self.train_embeddings(*parameters, args=args) - self.train_classifier(*parameters[:2], args=args) + training_parameters = parameters[:len(parameters) // 2] if self.eval_dataset else parameters + self.train_classifier(*training_parameters, args=args) + + def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]: + return [dataset["text"], dataset["label"]] def train_embeddings( self, x_train: List[str], - y_train: Union[List[int], List[List[int]]], - x_eval: List[str] = None, - y_eval: Union[List[int], List[List[int]]] = None, + y_train: Optional[Union[List[int], List[List[int]]]] = None, + x_eval: Optional[List[str]] = None, + y_eval: Optional[Union[List[int], List[List[int]]]] = None, args: Optional[TrainingArguments] = None, ) -> None: """ @@ -423,6 +426,8 @@ def get_dataloader( self, x: List[str], y: Union[List[int], List[List[int]]], args: TrainingArguments ) -> Tuple[DataLoader, nn.Module, int]: # sentence-transformers adaptation + input_data = [InputExample(texts=[text], label=label) for text, label in zip(x, y)] + if args.loss in [ losses.BatchAllTripletLoss, losses.BatchHardTripletLoss, @@ -430,9 +435,7 @@ def get_dataloader( losses.BatchHardSoftMarginTripletLoss, SupConLoss, ]: - examples = [InputExample(texts=[text], label=label) for text, label in zip(x, y)] - data_sampler = SentenceLabelDataset(examples, 
samples_per_label=args.samples_per_label) - + data_sampler = SentenceLabelDataset(input_data, samples_per_label=args.samples_per_label) batch_size = min(args.embedding_batch_size, len(data_sampler)) dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=True) @@ -450,17 +453,15 @@ def get_dataloader( margin=args.margin, ) else: - examples = [] - - for _ in trange(args.num_iterations, desc="Generating Training Pairs", disable=not args.show_progress_bar): - if self.model.multi_target_strategy is not None: - examples = sentence_pairs_generation_multilabel(np.array(x), np.array(y), examples) - else: - examples = sentence_pairs_generation(np.array(x), np.array(y), examples) - - batch_size = args.embedding_batch_size - dataloader = DataLoader(examples, shuffle=True, batch_size=batch_size) + data_sampler = ContrastiveDataset( + input_data, self.model.multi_target_strategy, args.num_iterations, args.sampling_strategy + ) + # shuffle_sampler = True can be dropped in for further 'randomising' + shuffle_sampler = True if args.sampling_strategy == "unique" else False + batch_size = min(args.embedding_batch_size, len(data_sampler)) + dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False) loss = args.loss(self.model.model_body) + return dataloader, loss, batch_size def log(self, args: TrainingArguments, logs: Dict[str, float]) -> None: @@ -505,8 +506,10 @@ def _train_sentence_transformer( self.state.epoch = 0 start_time = time.time() - # TODO: Add max_steps via args.max_steps here? 
- self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs + if args.max_steps > 0: + self.state.max_steps = args.max_steps + else: + self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs self.control = self.callback_handler.on_train_begin(args, self.state, self.control) if args.use_amp: diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py index 5a27f585..2bc2d629 100644 --- a/src/setfit/trainer_distillation.py +++ b/src/setfit/trainer_distillation.py @@ -1,23 +1,19 @@ -import math import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union -import numpy as np +from datasets import Dataset import torch -from sentence_transformers import losses, util +from sentence_transformers import losses, util, InputExample +from torch import nn from torch.utils.data import DataLoader -from transformers.trainer_utils import set_seed from . import logging -from .modeling import sentence_pairs_generation_cos_sim +from .sampler import ContrastiveDistillationDataset from .trainer import Trainer from .training_args import TrainingArguments if TYPE_CHECKING: - import optuna - from datasets import Dataset - from .modeling import SetFitModel logging.set_verbosity_info() @@ -78,99 +74,27 @@ def __init__( self.teacher_model = teacher_model self.student_model = self.model - def train( - self, - args: Optional[TrainingArguments] = None, - trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, - **kwargs, - ) -> None: - """ - Main training entry point. - - Args: - args (`TrainingArguments`, *optional*): - Temporarily change the training arguments for this training call. - trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): - The trial run or the hyperparameter dictionary for hyperparameter search. 
- """ - if len(kwargs): - warnings.warn( - f"`{self.__class__.__name__}.train` does not accept keyword arguments anymore. " - f"Please provide training arguments via a `TrainingArguments` instance to the `{self.__class__.__name__}` " - f"initialisation or the `{self.__class__.__name__}.train` method.", - DeprecationWarning, - stacklevel=2, - ) - - args = args or self.args or TrainingArguments() - - set_seed(args.seed) # Seed must be set before instantiating the model when using model_init. - - if trial: # Trial and model initialization - self._hp_search_setup(trial) # sets trainer parameters and initializes model - - if self.train_dataset is None: - raise ValueError( - f"Training requires a `train_dataset` given to the `{self.__class__.__name__}` initialization." - ) - - self._validate_column_mapping(self.train_dataset) - train_dataset = self.train_dataset - if self.column_mapping is not None: - logger.info("Applying column mapping to training dataset") - train_dataset = self._apply_column_mapping(self.train_dataset, self.column_mapping) + def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]: + return [dataset["text"]] - x_train: List[str] = train_dataset["text"] - - self.train_embeddings(x_train, args) - self.train_classifier(x_train, args) - - def train_embeddings( - self, - x_train: List[str], - args: Optional[TrainingArguments] = None, - ) -> None: - """ - Method to perform the embedding phase: finetuning the student its `SentenceTransformer` body. - - Args: - x_train (`List[str]`): A list of training sentences. - args (`TrainingArguments`, *optional*): - Temporarily change the training arguments for this training call. 
- """ - args = args or self.args or TrainingArguments() - - # **************** student training ********************* - x_train_embd_student = self.teacher_model.model_body.encode( - x_train, convert_to_tensor=self.teacher_model.has_differentiable_head + def get_dataloader( + self, x: List[str], y: Optional[Union[List[int], List[List[int]]]], args: TrainingArguments + ) -> Tuple[DataLoader, nn.Module, int]: + x_embd_student = self.teacher_model.model_body.encode( + x, convert_to_tensor=self.teacher_model.has_differentiable_head ) - cos_sim_matrix = util.cos_sim(x_train_embd_student, x_train_embd_student) - - train_examples = [] - for _ in range(args.num_iterations): - train_examples = sentence_pairs_generation_cos_sim(np.array(x_train), train_examples, cos_sim_matrix) - # **************** student training END ***************** - - batch_size = args.embedding_batch_size - train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size) - train_loss = args.loss(self.student_model.model_body) - - total_train_steps = len(train_dataloader) * args.embedding_num_epochs - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_examples)}") - logger.info(f" Num epochs = {args.embedding_num_epochs}") - logger.info(f" Total optimization steps = {total_train_steps}") - logger.info(f" Total train batch size = {batch_size}") - - warmup_steps = math.ceil(total_train_steps * args.warmup_proportion) - self.student_model.model_body.fit( - train_objectives=[(train_dataloader, train_loss)], - epochs=args.embedding_num_epochs, - optimizer_params={"lr": args.body_embedding_learning_rate}, - warmup_steps=warmup_steps, - show_progress_bar=args.show_progress_bar, - use_amp=args.use_amp, + cos_sim_matrix = util.cos_sim(x_embd_student, x_embd_student) + + input_data = [InputExample(texts=[text]) for text in x] + data_sampler = ContrastiveDistillationDataset( + input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy ) + # 
shuffle_sampler = True can be dropped in for further 'randomising' + shuffle_sampler = True if args.sampling_strategy == "unique" else False + batch_size = min(args.embedding_batch_size, len(data_sampler)) + dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False) + loss = args.loss(self.model.model_body) + return dataloader, loss, batch_size def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None: """ diff --git a/src/setfit/training_args.py b/src/setfit/training_args.py index 3ba751cb..73fd95f8 100644 --- a/src/setfit/training_args.py +++ b/src/setfit/training_args.py @@ -30,8 +30,26 @@ class TrainingArguments: Set the number of epochs the embedding and classifier training phases respectively, or set both if an integer is provided. Note that the number of epochs for the classifier is only used with a differentiable PyTorch head. - num_iterations (`int`, defaults to `20`): - The number of iterations to generate sentence pairs for. + max_steps (`int`, *optional*, defaults to `-1`): + If set to a positive number, the total number of training steps to perform. Overrides `num_epochs`. + The training may stop before reaching the set number of steps when all data is exhausted. + sampling_strategy (`str`, defaults to `"oversampling"`): + The sampling strategy of how to draw pairs in training. Possible values are: + + - `"oversampling"`: Draws even number of positive/ negative sentence pairs until every + sentence pair has been drawn. + - `"undersampling"`: Draws the minimum number of positive/ negative sentence pairs until + every sentence pair in the minority class has been drawn. + - `"unique"`: Draws every sentence pair combination (likely resulting in unbalanced + number of positive/ negative sentence pairs). + + The default is set to `"oversampling"`, ensuring all sentence pairs are drawn at least once. 
+ Alternatively setting `num_iterations` will override this argument and determine the number + of generated sentence pairs. + num_iterations (`int`, *optional*): + If not set the `sampling_strategy` will determine the number of sentence pairs to generate. + This argument sets the number of iterations to generate sentence pairs for + and provides compatibility with Setfit