diff --git a/docs/cross_encoder/loss_overview.md b/docs/cross_encoder/loss_overview.md index e5258ba11..3e126e22f 100644 --- a/docs/cross_encoder/loss_overview.md +++ b/docs/cross_encoder/loss_overview.md @@ -17,15 +17,15 @@ Loss functions play a critical role in the performance of your fine-tuned Cross - ``(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets`` with labels of 0 for negative and 1 for positive with ``output_format="labeled-list"``, ``` -| Inputs | Labels | Number of Model Output Labels | Appropriate Loss Functions | -|---------------------------------------------------|------------------------------------------|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `(sentence_A, sentence_B) pairs` | `class` | `num_classes` | `CrossEntropyLoss` | -| `(anchor, positive) pairs` | `none` | `1` | `MultipleNegativesRankingLoss`
`CachedMultipleNegativesRankingLoss` | -| `(anchor, positive/negative) pairs` | `1 if positive, 0 if negative` | `1` | `BinaryCrossEntropyLoss` | -| `(sentence_A, sentence_B) pairs` | `float similarity score between 0 and 1` | `1` | `BinaryCrossEntropyLoss` | -| `(anchor, positive, negative) triplets` | `none` | `1` | `MultipleNegativesRankingLoss`
`CachedMultipleNegativesRankingLoss` | -| `(anchor, positive, negative_1, ..., negative_n)` | `none` | `1` | `MultipleNegativesRankingLoss`
`CachedMultipleNegativesRankingLoss` | -| `(query, [doc1, doc2, ..., docN])` | `[score1, score2, ..., scoreN]` | `1` | `LambdaLoss`
`ListNetLoss`
`ListMLELoss`
`PListMLELoss` | +| Inputs | Labels | Number of Model Output Labels | Appropriate Loss Functions | +|---------------------------------------------------|------------------------------------------|-------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `(sentence_A, sentence_B) pairs` | `class` | `num_classes` | `CrossEntropyLoss` | +| `(anchor, positive) pairs` | `none` | `1` | `MultipleNegativesRankingLoss`
`CachedMultipleNegativesRankingLoss` | +| `(anchor, positive/negative) pairs` | `1 if positive, 0 if negative` | `1` | `BinaryCrossEntropyLoss` | +| `(sentence_A, sentence_B) pairs` | `float similarity score between 0 and 1` | `1` | `BinaryCrossEntropyLoss` | +| `(anchor, positive, negative) triplets` | `none` | `1` | `MultipleNegativesRankingLoss`
`CachedMultipleNegativesRankingLoss` | +| `(anchor, positive, negative_1, ..., negative_n)` | `none` | `1` | `MultipleNegativesRankingLoss`
`CachedMultipleNegativesRankingLoss` | +| `(query, [doc1, doc2, ..., docN])` | `[score1, score2, ..., scoreN]` | `1` |
  1. `LambdaLoss`
  2. `ListNetLoss`
  3. `PListMLELoss`
  4. `RankNetLoss`
  5. `ListMLELoss`
| ## Distillation These loss functions are specifically designed to be used when distilling the knowledge from one model into another. @@ -40,9 +40,9 @@ For example, when finetuning a small model to behave more like a larger & strong In practice, not all loss functions get used equally often. The most common scenarios are: * `(sentence_A, sentence_B) pairs` with `float similarity score` or `1 if positive, 0 if negative`: BinaryCrossEntropyLoss is a traditional option that remains very challenging to outperform. -* `(anchor, positive) pairs` without any labels: - * MultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train SentenceTransformer models, and the loss is also applicable for CrossEncoder models. This data is often relatively cheap to obtain, and mine_hard_negatives with output_format="n-tuple" or output_format="triplet" can easily be used to add hard negatives for this loss. CachedMultipleNegativesRankingLoss is often used to keep the memory usage in check. - * Together with mine_hard_negatives with output_format="labeled-list", LambdaLoss is frequently used for learning-to-rank tasks. +* `(anchor, positive) pairs` without any labels, combined with mine_hard_negatives: + * with output_format="labeled-list", LambdaLoss is frequently used for learning-to-rank tasks. + * with output_format="labeled-pair", BinaryCrossEntropyLoss remains a strong option. ## Custom Loss Functions diff --git a/docs/package_reference/cross_encoder/losses.md b/docs/package_reference/cross_encoder/losses.md index 9c11024c1..3a5d008f1 100644 --- a/docs/package_reference/cross_encoder/losses.md +++ b/docs/package_reference/cross_encoder/losses.md @@ -61,3 +61,8 @@ Sadly, there is no "one size fits all" loss function. Which loss function is sui ```{eval-rst} .. autoclass:: sentence_transformers.cross_encoder.losses.MarginMSELoss ``` + +## RankNetLoss +```{eval-rst} +.. autoclass:: sentence_transformers.cross_encoder.losses.RankNetLoss +``` \ No newline at end of file diff --git a/examples/cross_encoder/training/ms_marco/README.md b/examples/cross_encoder/training/ms_marco/README.md index 417cfba90..8f74237d7 100644 --- a/examples/cross_encoder/training/ms_marco/README.md +++ b/examples/cross_encoder/training/ms_marco/README.md @@ -56,8 +56,12 @@ In all scripts, the model is evaluated on subsets of `MS MARCO `PListMLELoss` > `ListNetLoss` > `ListMLELoss` out of all learning to rank losses, but your milage may vary. +Out of these training scripts, I suspect that **[training_ms_marco_lambda_preprocessed.py](training_ms_marco_lambda_preprocessed.py)**, **[training_ms_marco_lambda_hard_neg.py](training_ms_marco_lambda_hard_neg.py)** or **[training_ms_marco_bce_preprocessed.py](training_ms_marco_bce_preprocessed.py)** produces the strongest model, as anecdotally `LambdaLoss` and `BinaryCrossEntropyLoss` are quite strong. It seems that `LambdaLoss` > `ListNetLoss` > `PListMLELoss` > `RankNetLoss` > `ListMLELoss` out of all learning-to-rank losses, but your mileage may vary. Additionally, you can also train with Distillation. See [Cross Encoder > Training Examples > Distillation](../distillation/README.md) for more details. &#13;
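The "most common scenarios" guidance above can be illustrated with a short, self-contained sketch of the `BinaryCrossEntropyLoss` setup for labeled pairs. This is not one of the shipped training scripts: the base model, column names, and toy data below are placeholders, and in practice the labeled pairs would typically come from `mine_hard_negatives` with `output_format="labeled-pair"`.

```python
from datasets import Dataset

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses

# Toy (sentence_A, sentence_B) pairs with 1/0 labels; real data would be mined hard negatives.
model = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What are pandas?"],
    "answer": ["Pandas are a kind of bear.", "Paris is the capital of France."],
    "label": [1.0, 0.0],
})
# One output logit per pair, trained with binary cross-entropy against the labels.
loss = losses.BinaryCrossEntropyLoss(model)

trainer = CrossEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```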
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py index 8be70e426..b01f768e1 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -21,6 +22,8 @@ def main(): num_epochs = 1 # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py index d87fbd67d..26c56d42a 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset, load_from_disk from sentence_transformers.cross_encoder import CrossEncoder @@ -21,6 +22,8 @@ def main(): dataset_size = 2_000_000 # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py index 62bc75c92..d1a4a9985 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py @@ -2,6 +2,7 @@ import traceback from collections import defaultdict +import torch from datasets import load_dataset from torch import nn @@ -25,6 +26,8 @@ def main(): num_epochs = 1 # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py index 3d13dede4..3712a1036 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py @@ -2,6 +2,7 @@ import traceback from datetime import datetime +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -32,6 +33,8 @@ def main(): dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. 
Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py index c7af97ca3..8b7e93d6f 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py @@ -2,6 +2,7 @@ import traceback from datetime import datetime +import torch from datasets import Dataset, concatenate_datasets, load_dataset from sentence_transformers import CrossEncoder, SentenceTransformer @@ -33,6 +34,8 @@ def main(): dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py index d2e728253..e2c7056e9 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py @@ -2,6 +2,7 @@ import traceback from datetime import datetime +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -32,6 +33,8 @@ def main(): dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py index ba579a2e2..7f6e1c7c9 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -30,6 +31,8 @@ def main(): respect_input_order = True # Whether to respect the original order of documents # 1. 
Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py index 570a00e7f..bc920142d 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -29,6 +30,8 @@ def main(): max_docs = None # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py index a93c3d3f2..dede60052 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -30,6 +31,8 @@ def main(): respect_input_order = True # Whether to respect the original order of documents # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py new file mode 100644 index 000000000..8fed1dc47 --- /dev/null +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import logging +import traceback +from datetime import datetime + +import torch +from datasets import load_dataset + +from sentence_transformers.cross_encoder import CrossEncoder +from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator +from sentence_transformers.cross_encoder.losses import RankNetLoss +from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer +from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments + + +def main(): + model_name = "microsoft/MiniLM-L12-H384-uncased" + + # Set the log level to INFO to get more information + logging.basicConfig( + format="%(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + ) + # train_batch_size and eval_batch_size inform the size of the batches, while mini_batch_size is used by the loss + # to subdivide the batch into smaller parts. This mini_batch_size largely informs the training speed and memory usage. + # Keep in mind that the loss does not process `train_batch_size` pairs, but `train_batch_size * num_docs` pairs. 
+ train_batch_size = 16 + eval_batch_size = 16 + mini_batch_size = 16 + num_epochs = 1 + max_docs = None + + dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + + # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) + model = CrossEncoder(model_name, num_labels=1) + print("Model max length:", model.max_length) + print("Model num labels:", model.num_labels) + + # 2. Load the MS MARCO dataset: https://huggingface.co/datasets/microsoft/ms_marco + logging.info("Read train dataset") + dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train") + + def listwise_mapper(batch, max_docs: int | None = 10): + processed_queries = [] + processed_docs = [] + processed_labels = [] + + for query, passages_info in zip(batch["query"], batch["passages"]): + # Extract passages and labels + passages = passages_info["passage_text"] + labels = passages_info["is_selected"] + + # Pair passages with labels and sort descending by label (positives first) + paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True) + + # Separate back to passages and labels + sorted_passages, sorted_labels = zip(*paired) if paired else ([], []) + + # Filter queries without any positive labels + if max(sorted_labels) < 1.0: + continue + + # Truncate to max_docs + if max_docs is not None: + sorted_passages = list(sorted_passages[:max_docs]) + sorted_labels = list(sorted_labels[:max_docs]) + + processed_queries.append(query) + processed_docs.append(sorted_passages) + processed_labels.append(sorted_labels) + + return { + "query": processed_queries, + "docs": processed_docs, + "labels": processed_labels, + } + + # Create a dataset with a "query" column with strings, a "docs" column with lists of strings, + # and a "labels" column with lists of floats + dataset = dataset.map( + lambda batch: listwise_mapper(batch=batch, max_docs=max_docs), + batched=True, + remove_columns=dataset.column_names, + desc="Processing listwise samples", + ) + + dataset = dataset.train_test_split(test_size=1_000, seed=12) + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + logging.info(train_dataset) + + # 3. Define our training loss + loss = RankNetLoss( + model=model, + mini_batch_size=mini_batch_size, + ) + + # 4. Define the evaluator. We use the CENanoBEIREvaluator, which is a light-weight evaluator for English reranking + evaluator = CrossEncoderNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=eval_batch_size) + evaluator(model) + + # 5. 
Define the training arguments + short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] + run_name = f"reranker-msmarco-v1.1-{short_model_name}-ranknetloss" + args = CrossEncoderTrainingArguments( + # Required parameter: + output_dir=f"models/{run_name}_{dt}", + # Optional training parameters: + num_train_epochs=num_epochs, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=eval_batch_size, + learning_rate=2e-5, + warmup_ratio=0.1, + fp16=False, # Set to False if you get an error that your GPU can't run on FP16 + bf16=True, # Set to True if you have a GPU that supports BF16 + # Track and load the best checkpoint based on NanoBEIR NDCG@10 + load_best_model_at_end=True, + metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10", + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=500, + save_strategy="steps", + save_steps=500, + save_total_limit=2, + logging_steps=250, + logging_first_step=True, + run_name=run_name, # Will be used in W&B if `wandb` is installed + seed=12, + ) + + # 6. Create the trainer & start training + trainer = CrossEncoderTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=loss, + evaluator=evaluator, + ) + trainer.train() + + # 7. Evaluate the final model, useful to include these in the model card + evaluator(model) + + # 8. Save the final model + final_output_dir = f"models/{run_name}_{dt}/final" + model.save_pretrained(final_output_dir) + + # 9. (Optional) save the model to the Hugging Face Hub! + # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first + try: + model.push_to_hub(run_name) + except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{run_name}')`." + ) + + + if __name__ == "__main__": + main() diff --git a/sentence_transformers/cross_encoder/losses/LambdaLoss.py b/sentence_transformers/cross_encoder/losses/LambdaLoss.py index d7e78b76d..9257b70e6 100644 --- a/sentence_transformers/cross_encoder/losses/LambdaLoss.py +++ b/sentence_transformers/cross_encoder/losses/LambdaLoss.py @@ -351,7 +351,7 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: def citation(self) -> str: return """ @inproceedings{wang2018lambdaloss, - title={The lambdaloss framework for ranking metric optimization}, + title={The LambdaLoss Framework for Ranking Metric Optimization}, author={Wang, Xuanhui and Li, Cheng and Golbandi, Nadav and Bendersky, Michael and Najork, Marc}, booktitle={Proceedings of the 27th ACM international conference on information and knowledge management}, pages={1313--1322}, diff --git a/sentence_transformers/cross_encoder/losses/ListMLELoss.py b/sentence_transformers/cross_encoder/losses/ListMLELoss.py index deb516809..542dfe960 100644 --- a/sentence_transformers/cross_encoder/losses/ListMLELoss.py +++ b/sentence_transformers/cross_encoder/losses/ListMLELoss.py @@ -15,7 +15,7 @@ def __init__( respect_input_order: bool = True, ) -> None: """ - This loss function implements the ListMLE learnin to rank algorithm, which uses a list-wise + This loss function implements the ListMLE learning to rank algorithm, which uses a list-wise approach based on maximum likelihood estimation of permutations. &#13;
It maximizes the likelihood of the permutation induced by the ground truth labels. @@ -116,10 +116,10 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: def citation(self) -> str: return """ @inproceedings{10.1145/1390156.1390306, - title = {Listwise approach to learning to rank: theory and algorithm}, + title = {Listwise Approach to Learning to Rank - Theory and Algorithm}, author = {Xia, Fen and Liu, Tie-Yan and Wang, Jue and Zhang, Wensheng and Li, Hang}, booktitle = {Proceedings of the 25th International Conference on Machine Learning}, - pages = {1192–1199}, + pages = {1192-1199}, year = {2008}, url = {https://doi.org/10.1145/1390156.1390306}, } diff --git a/sentence_transformers/cross_encoder/losses/ListNetLoss.py b/sentence_transformers/cross_encoder/losses/ListNetLoss.py index 374cc7867..923d6bad5 100644 --- a/sentence_transformers/cross_encoder/losses/ListNetLoss.py +++ b/sentence_transformers/cross_encoder/losses/ListNetLoss.py @@ -188,7 +188,7 @@ def get_config_dict(self) -> dict[str, float]: def citation(self) -> str: return """ @inproceedings{cao2007learning, - title={Learning to rank: from pairwise approach to listwise approach}, + title={Learning to Rank: From Pairwise Approach to Listwise Approach}, author={Cao, Zhe and Qin, Tao and Liu, Tie-Yan and Tsai, Ming-Feng and Li, Hang}, booktitle={Proceedings of the 24th international conference on Machine learning}, pages={129--136}, diff --git a/sentence_transformers/cross_encoder/losses/PListMLELoss.py b/sentence_transformers/cross_encoder/losses/PListMLELoss.py index fb722c9bb..bfa69ef63 100644 --- a/sentence_transformers/cross_encoder/losses/PListMLELoss.py +++ b/sentence_transformers/cross_encoder/losses/PListMLELoss.py @@ -284,7 +284,7 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: def citation(self) -> str: return """ @inproceedings{lan2014position, - title={Position-Aware ListMLE: A Sequential Learning Process for Ranking.}, + title={Position-Aware ListMLE: A Sequential Learning Process for Ranking}, author={Lan, Yanyan and Zhu, Yadong and Guo, Jiafeng and Niu, Shuzi and Cheng, Xueqi}, booktitle={UAI}, volume={14}, diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py new file mode 100644 index 000000000..fd24566ff --- /dev/null +++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from typing import Literal + +from torch import nn + +from sentence_transformers.cross_encoder import CrossEncoder +from sentence_transformers.cross_encoder.losses import LambdaLoss, NoWeightingScheme + + +class RankNetLoss(LambdaLoss): + def __init__( + self, + model: CrossEncoder, + k: int | None = None, + sigma: float = 1.0, + eps: float = 1e-10, + reduction_log: Literal["natural", "binary"] = "binary", + activation_fct: nn.Module | None = nn.Identity(), + mini_batch_size: int | None = None, + ) -> None: + """ + RankNet loss implementation for learning to rank. This loss function implements the RankNet algorithm, + which learns a ranking function by optimizing pairwise document comparisons using a neural network. + The implementation is optimized to handle padded documents efficiently by only processing valid + documents during model inference. 
+ + Args: + model (CrossEncoder): CrossEncoder model to be trained + sigma (float): Score difference weight used in sigmoid (default: 1.0) + eps (float): Small constant for numerical stability (default: 1e-10) + activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the + loss. Defaults to :class:`~torch.nn.Identity`. + mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant + impact on the memory consumption and speed of the training process. Three cases are possible: + - If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size. + - If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``. + - If ``mini_batch_size`` is <= 0, the entire batch is processed at once. + Defaults to None. + + References: + - Learning to Rank using Gradient Descent: https://icml.cc/Conferences/2015/wp-content/uploads/2015/06/icml_ranking.pdf + - `Cross Encoder > Training Examples > MS MARCO <../../../examples/cross_encoder/training/ms_marco/README.html>`_ + + Requirements: + 1. Query with multiple documents (pairwise approach) + 2. Documents must have relevance scores/labels. Both binary and continuous labels are supported. + + Inputs: + +----------------------------------------+--------------------------------+-------------------------------+ + | Texts | Labels | Number of Model Output Labels | + +========================================+================================+===============================+ + | (query, [doc1, doc2, ..., docN]) | [score1, score2, ..., scoreN] | 1 | + +----------------------------------------+--------------------------------+-------------------------------+ + + Recommendations: + - Use :class:`~sentence_transformers.util.mine_hard_negatives` with ``output_format="labeled-list"`` + to convert question-answer pairs to the required input format with hard negatives. + + Relations: + - :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` can be seen as an extension of this loss + where each score pair is weighted. Alternatively, this loss can be seen as a special case of the + :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` without a weighting scheme. + - :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` with its default NDCGLoss2++ weighting + scheme anecdotally performs better than the other losses with the same input format. + + Example: + :: + + from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses + from datasets import Dataset + + model = CrossEncoder("microsoft/mpnet-base") + train_dataset = Dataset.from_dict({ + "query": ["What are pandas?", "What is the capital of France?"], + "docs": [ + ["Pandas are a kind of bear.", "Pandas are kind of like fish."], + ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."], + ], + "labels": [[1, 0], [1, 1, 0]], + }) + loss = losses.RankNetLoss(model) + + trainer = CrossEncoderTrainer( + model=model, + train_dataset=train_dataset, + loss=loss, + ) + trainer.train() + """ + super().__init__( + model=model, + weighting_scheme=NoWeightingScheme(), + k=k, + sigma=sigma, + eps=eps, + reduction_log=reduction_log, + activation_fct=activation_fct, + mini_batch_size=mini_batch_size, + ) + + def get_config_dict(self) -> dict[str, float | int | str | None]: + """ + Get configuration parameters for this loss function. 
+ + Returns: + Dictionary containing the configuration parameters + """ + config = super().get_config_dict() + del config["weighting_scheme"] + return config + + @property + def citation(self) -> str: + return """ +@inproceedings{burges2005learning, + title={Learning to Rank using Gradient Descent}, + author={Burges, Chris and Shaked, Tal and Renshaw, Erin and Lazier, Ari and Deeds, Matt and Hamilton, Nicole and Hullender, Greg}, + booktitle={Proceedings of the 22nd international conference on Machine learning}, + pages={89--96}, + year={2005} +} +""" diff --git a/sentence_transformers/cross_encoder/losses/__init__.py b/sentence_transformers/cross_encoder/losses/__init__.py index b77ad9bb0..a048d5265 100644 --- a/sentence_transformers/cross_encoder/losses/__init__.py +++ b/sentence_transformers/cross_encoder/losses/__init__.py @@ -17,6 +17,7 @@ from .MSELoss import MSELoss from .MultipleNegativesRankingLoss import MultipleNegativesRankingLoss from .PListMLELoss import PListMLELambdaWeight, PListMLELoss +from .RankNetLoss import RankNetLoss __all__ = [ "BinaryCrossEntropyLoss", @@ -35,4 +36,5 @@ "NDCGLoss2Scheme", "LambdaRankScheme", "NDCGLoss2PPScheme", + "RankNetLoss", ]
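For intuition on what the new `RankNetLoss` optimizes: with `NoWeightingScheme`, the LambdaLoss machinery reduces to the plain RankNet pairwise logistic loss on score differences. The snippet below is an illustrative standalone computation (single query, natural-log variant, no padding or mini-batching), not the library implementation.

```python
import torch


def ranknet_pairwise_loss(scores: torch.Tensor, labels: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Plain RankNet loss for the documents of a single query.

    For every pair (i, j) with labels[i] > labels[j], the model is penalised by
    log(1 + exp(-sigma * (scores[i] - scores[j]))), i.e. whenever document i is
    not scored above document j.
    """
    score_diff = scores.unsqueeze(1) - scores.unsqueeze(0)  # entry (i, j) is s_i - s_j
    pair_mask = (labels.unsqueeze(1) > labels.unsqueeze(0)).float()  # 1 where i should outrank j
    pairwise_losses = torch.nn.functional.softplus(-sigma * score_diff)  # log(1 + exp(-sigma * diff))
    return (pairwise_losses * pair_mask).sum() / pair_mask.sum().clamp(min=1.0)


# Example: the first document is relevant, so only pairs (0, 1) and (0, 2) contribute.
scores = torch.tensor([2.5, 0.3, 1.2])
labels = torch.tensor([1.0, 0.0, 0.0])
print(ranknet_pairwise_loss(scores, labels))
```

The actual `RankNetLoss` additionally exposes the `k`, `reduction_log`, and `mini_batch_size` options it inherits from `LambdaLoss`, plus padding-aware handling of a variable number of documents per query.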