diff --git a/docs/cross_encoder/loss_overview.md b/docs/cross_encoder/loss_overview.md
index e5258ba11..3e126e22f 100644
--- a/docs/cross_encoder/loss_overview.md
+++ b/docs/cross_encoder/loss_overview.md
@@ -17,15 +17,15 @@ Loss functions play a critical role in the performance of your fine-tuned Cross
- ``(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets`` with labels of 0 for negative and 1 for positive with ``output_format="labeled-list"``,
```
-| Inputs                                             | Labels                                    | Number of Model Output Labels | Appropriate Loss Functions                                               |
-|-----------------------------------------------------|-------------------------------------------|-------------------------------|---------------------------------------------------------------------------|
-| `(sentence_A, sentence_B) pairs`                    | `class`                                   | `num_classes`                 | `CrossEntropyLoss`                                                        |
-| `(anchor, positive) pairs`                          | `none`                                    | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss`    |
-| `(anchor, positive/negative) pairs`                 | `1 if positive, 0 if negative`            | `1`                           | `BinaryCrossEntropyLoss`                                                  |
-| `(sentence_A, sentence_B) pairs`                    | `float similarity score between 0 and 1`  | `1`                           | `BinaryCrossEntropyLoss`                                                  |
-| `(anchor, positive, negative) triplets`             | `none`                                    | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss`    |
-| `(anchor, positive, negative_1, ..., negative_n)`   | `none`                                    | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss`    |
-| `(query, [doc1, doc2, ..., docN])`                  | `[score1, score2, ..., scoreN]`           | `1`                           | `LambdaLoss`<br>`ListNetLoss`<br>`ListMLELoss`<br>`PListMLELoss`          |
+| Inputs                                              | Labels                                    | Number of Model Output Labels | Appropriate Loss Functions                                                          |
+|-----------------------------------------------------|-------------------------------------------|-------------------------------|--------------------------------------------------------------------------------------|
+| `(sentence_A, sentence_B) pairs`                    | `class`                                   | `num_classes`                 | `CrossEntropyLoss`                                                                   |
+| `(anchor, positive) pairs`                          | `none`                                    | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss`               |
+| `(anchor, positive/negative) pairs`                 | `1 if positive, 0 if negative`            | `1`                           | `BinaryCrossEntropyLoss`                                                             |
+| `(sentence_A, sentence_B) pairs`                    | `float similarity score between 0 and 1`  | `1`                           | `BinaryCrossEntropyLoss`                                                             |
+| `(anchor, positive, negative) triplets`             | `none`                                    | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss`               |
+| `(anchor, positive, negative_1, ..., negative_n)`   | `none`                                    | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss`               |
+| `(query, [doc1, doc2, ..., docN])`                  | `[score1, score2, ..., scoreN]`           | `1`                           | `LambdaLoss`<br>`ListNetLoss`<br>`PListMLELoss`<br>`RankNetLoss`<br>`ListMLELoss`    |
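The listwise format in the last row maps directly onto a Hugging Face `Dataset` with one string column, one list-of-strings column, and one list-of-floats column. A minimal sketch of that shape (the column names follow the MS MARCO training scripts in this repository; the toy texts and labels are made up):

```python
from datasets import Dataset

# One query, a list of candidate documents, and one relevance score per document,
# matching the (query, [doc1, doc2, ..., docN]) row above.
train_dataset = Dataset.from_dict({
    "query": ["how many people live in berlin"],
    "docs": [[
        "Berlin has a population of roughly 3.7 million inhabitants.",
        "Berlin is well known for its museums.",
    ]],
    "labels": [[1, 0]],
})
```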
## Distillation
These loss functions are specifically designed to be used when distilling the knowledge from one model into another.
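As a rough illustration of that idea with the cross-encoder `MSELoss`, a larger teacher reranker can score `(sentence_A, sentence_B)` pairs and a smaller student can be trained to reproduce those scores. A minimal sketch; the model names and toy pairs below are placeholders, not recommendations:

```python
from datasets import Dataset

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import MSELoss

# Hypothetical teacher/student pair: distill a larger reranker into a smaller backbone
teacher = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
student = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)

pairs = [
    ("How many people live in Berlin?", "Berlin has a population of roughly 3.7 million."),
    ("How many people live in Berlin?", "Berlin is well known for its museums."),
]
# The teacher's scores become the labels the student learns to reproduce
teacher_scores = teacher.predict(pairs).tolist()

train_dataset = Dataset.from_dict({
    "sentence_A": [query for query, _ in pairs],
    "sentence_B": [passage for _, passage in pairs],
    "label": teacher_scores,
})

trainer = CrossEncoderTrainer(model=student, train_dataset=train_dataset, loss=MSELoss(student))
trainer.train()
```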
@@ -40,9 +40,9 @@ For example, when finetuning a small model to behave more like a larger & strong
In practice, not all loss functions get used equally often. The most common scenarios are:
* `(sentence_A, sentence_B) pairs` with `float similarity score` or `1 if positive, 0 if negative`: BinaryCrossEntropyLoss is a traditional option that remains very challenging to outperform.
-* `(anchor, positive) pairs` without any labels:
- * MultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train SentenceTransformer models, and the loss is also applicable for CrossEncoder models. This data is often relatively cheap to obtain, and mine_hard_negatives with output_format="n-tuple" or output_format="triplet" can easily be used to add hard negatives for this loss. CachedMultipleNegativesRankingLoss is often used to keep the memory usage in check.
- * Together with mine_hard_negatives with output_format="labeled-list", LambdaLoss is frequently used for learning-to-rank tasks.
+* `(anchor, positive) pairs` without any labels: combined with mine_hard_negatives
+ * with output_format="labeled-list", then LambdaLoss is frequently used for learning-to-rank tasks.
+ * with output_format="labeled-pair", then BinaryCrossEntropyLoss remains a strong option.
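As a concrete sketch of the `mine_hard_negatives` + `LambdaLoss` combination above (the dataset, model names, and mining parameters are placeholder choices; the key piece is `output_format="labeled-list"`):

```python
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import LambdaLoss
from sentence_transformers.util import mine_hard_negatives

# (anchor, positive) pairs without labels, e.g. (question, answer)
pairs = load_dataset("sentence-transformers/gooaq", split="train[:10000]")

# Mine hard negatives with a bi-encoder and emit the listwise "labeled-list" format:
# (query, [positive, negative_1, ..., negative_n]) with labels [1, 0, ..., 0]
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
hard_dataset = mine_hard_negatives(
    pairs,
    embedding_model,
    num_negatives=5,
    output_format="labeled-list",
)

model = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)
trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=hard_dataset,
    loss=LambdaLoss(model),
)
trainer.train()
```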
## Custom Loss Functions
diff --git a/docs/package_reference/cross_encoder/losses.md b/docs/package_reference/cross_encoder/losses.md
index 9c11024c1..3a5d008f1 100644
--- a/docs/package_reference/cross_encoder/losses.md
+++ b/docs/package_reference/cross_encoder/losses.md
@@ -61,3 +61,8 @@ Sadly, there is no "one size fits all" loss function. Which loss function is sui
```{eval-rst}
.. autoclass:: sentence_transformers.cross_encoder.losses.MarginMSELoss
```
+
+## RankNetLoss
+```{eval-rst}
+.. autoclass:: sentence_transformers.cross_encoder.losses.RankNetLoss
+```
\ No newline at end of file
diff --git a/examples/cross_encoder/training/ms_marco/README.md b/examples/cross_encoder/training/ms_marco/README.md
index 417cfba90..8f74237d7 100644
--- a/examples/cross_encoder/training/ms_marco/README.md
+++ b/examples/cross_encoder/training/ms_marco/README.md
@@ -56,8 +56,12 @@ In all scripts, the model is evaluated on subsets of `MS MARCO`
-… It seems that `LambdaLoss` > `PListMLELoss` > `ListNetLoss` > `ListMLELoss` out of all learning to rank losses, but your milage may vary.
+Out of these training scripts, I suspect that **[training_ms_marco_lambda_preprocessed.py](training_ms_marco_lambda_preprocessed.py)**, **[training_ms_marco_lambda_hard_neg.py](training_ms_marco_lambda_hard_neg.py)** or **[training_ms_marco_bce_preprocessed.py](training_ms_marco_bce_preprocessed.py)** produces the strongest model, as anecdotally `LambdaLoss` and `BinaryCrossEntropyLoss` are quite strong. It seems that `LambdaLoss` > `ListNetLoss` > `PListMLELoss` > `RankNetLoss` > `ListMLELoss` out of all learning-to-rank losses, but your mileage may vary.
Additionally, you can also train with Distillation. See [Cross Encoder > Training Examples > Distillation](../distillation/README.md) for more details.
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py
index 8be70e426..b01f768e1 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py
@@ -1,6 +1,7 @@
import logging
import traceback
+import torch
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
@@ -21,6 +22,8 @@ def main():
num_epochs = 1
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py
index d87fbd67d..26c56d42a 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py
@@ -1,6 +1,7 @@
import logging
import traceback
+import torch
from datasets import load_dataset, load_from_disk
from sentence_transformers.cross_encoder import CrossEncoder
@@ -21,6 +22,8 @@ def main():
dataset_size = 2_000_000
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py
index 62bc75c92..d1a4a9985 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py
@@ -2,6 +2,7 @@
import traceback
from collections import defaultdict
+import torch
from datasets import load_dataset
from torch import nn
@@ -25,6 +26,8 @@ def main():
num_epochs = 1
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py
index 3d13dede4..3712a1036 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py
@@ -2,6 +2,7 @@
import traceback
from datetime import datetime
+import torch
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
@@ -32,6 +33,8 @@ def main():
dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py
index c7af97ca3..8b7e93d6f 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py
@@ -2,6 +2,7 @@
import traceback
from datetime import datetime
+import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from sentence_transformers import CrossEncoder, SentenceTransformer
@@ -33,6 +34,8 @@ def main():
dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py
index d2e728253..e2c7056e9 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py
@@ -2,6 +2,7 @@
import traceback
from datetime import datetime
+import torch
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
@@ -32,6 +33,8 @@ def main():
dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py
index ba579a2e2..7f6e1c7c9 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py
@@ -1,6 +1,7 @@
import logging
import traceback
+import torch
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
@@ -30,6 +31,8 @@ def main():
respect_input_order = True # Whether to respect the original order of documents
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py
index 570a00e7f..bc920142d 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py
@@ -1,6 +1,7 @@
import logging
import traceback
+import torch
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
@@ -29,6 +30,8 @@ def main():
max_docs = None
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py
index a93c3d3f2..dede60052 100644
--- a/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py
@@ -1,6 +1,7 @@
import logging
import traceback
+import torch
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
@@ -30,6 +31,8 @@ def main():
respect_input_order = True # Whether to respect the original order of documents
# 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py
new file mode 100644
index 000000000..8fed1dc47
--- /dev/null
+++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+import logging
+import traceback
+from datetime import datetime
+
+import torch
+from datasets import load_dataset
+
+from sentence_transformers.cross_encoder import CrossEncoder
+from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
+from sentence_transformers.cross_encoder.losses import RankNetLoss
+from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
+from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
+
+
+def main():
+ model_name = "microsoft/MiniLM-L12-H384-uncased"
+
+ # Set the log level to INFO to get more information
+ logging.basicConfig(
+ format="%(asctime)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO,
+ )
+ # train_batch_size and eval_batch_size inform the size of the batches, while mini_batch_size is used by the loss
+ # to subdivide the batch into smaller parts. This mini_batch_size largely informs the training speed and memory usage.
+ # Keep in mind that the loss does not process `train_batch_size` pairs, but `train_batch_size * num_docs` pairs.
+ train_batch_size = 16
+ eval_batch_size = 16
+ mini_batch_size = 16
+ num_epochs = 1
+ max_docs = None
+
+ dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+
+ # 1. Define our CrossEncoder model
+ # Set the seed so the new classifier weights are identical in subsequent runs
+ torch.manual_seed(12)
+ model = CrossEncoder(model_name, num_labels=1)
+ print("Model max length:", model.max_length)
+ print("Model num labels:", model.num_labels)
+
+ # 2. Load the MS MARCO dataset: https://huggingface.co/datasets/microsoft/ms_marco
+ logging.info("Read train dataset")
+ dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")
+
+ def listwise_mapper(batch, max_docs: int | None = 10):
+ processed_queries = []
+ processed_docs = []
+ processed_labels = []
+
+ for query, passages_info in zip(batch["query"], batch["passages"]):
+ # Extract passages and labels
+ passages = passages_info["passage_text"]
+ labels = passages_info["is_selected"]
+
+ # Pair passages with labels and sort descending by label (positives first)
+ paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)
+
+ # Separate back to passages and labels
+ sorted_passages, sorted_labels = zip(*paired) if paired else ([], [])
+
+ # Filter queries without any positive labels
+            if not sorted_labels or max(sorted_labels) < 1.0:
+ continue
+
+ # Truncate to max_docs
+ if max_docs is not None:
+ sorted_passages = list(sorted_passages[:max_docs])
+ sorted_labels = list(sorted_labels[:max_docs])
+
+ processed_queries.append(query)
+ processed_docs.append(sorted_passages)
+ processed_labels.append(sorted_labels)
+
+ return {
+ "query": processed_queries,
+ "docs": processed_docs,
+ "labels": processed_labels,
+ }
+
+ # Create a dataset with a "query" column with strings, a "docs" column with lists of strings,
+ # and a "labels" column with lists of floats
+ dataset = dataset.map(
+ lambda batch: listwise_mapper(batch=batch, max_docs=max_docs),
+ batched=True,
+ remove_columns=dataset.column_names,
+ desc="Processing listwise samples",
+ )
+
+ dataset = dataset.train_test_split(test_size=1_000, seed=12)
+ train_dataset = dataset["train"]
+ eval_dataset = dataset["test"]
+ logging.info(train_dataset)
+
+ # 3. Define our training loss
+ loss = RankNetLoss(
+ model=model,
+ mini_batch_size=mini_batch_size,
+ )
+
+    # 4. Define the evaluator. We use CrossEncoderNanoBEIREvaluator, a lightweight evaluator for English reranking
+ evaluator = CrossEncoderNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=eval_batch_size)
+ evaluator(model)
+
+ # 5. Define the training arguments
+ short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+ run_name = f"reranker-msmarco-v1.1-{short_model_name}-ranknetloss"
+ args = CrossEncoderTrainingArguments(
+ # Required parameter:
+ output_dir=f"models/{run_name}_{dt}",
+ # Optional training parameters:
+ num_train_epochs=num_epochs,
+ per_device_train_batch_size=train_batch_size,
+ per_device_eval_batch_size=eval_batch_size,
+ learning_rate=2e-5,
+ warmup_ratio=0.1,
+ fp16=False, # Set to False if you get an error that your GPU can't run on FP16
+ bf16=True, # Set to True if you have a GPU that supports BF16
+ load_best_model_at_end=True,
+ metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10",
+ # Optional tracking/debugging parameters:
+ eval_strategy="steps",
+ eval_steps=500,
+ save_strategy="steps",
+ save_steps=500,
+ save_total_limit=2,
+ logging_steps=250,
+ logging_first_step=True,
+ run_name=run_name, # Will be used in W&B if `wandb` is installed
+ seed=12,
+ )
+
+ # 6. Create the trainer & start training
+ trainer = CrossEncoderTrainer(
+ model=model,
+ args=args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ loss=loss,
+ evaluator=evaluator,
+ )
+ trainer.train()
+
+ # 7. Evaluate the final model, useful to include these in the model card
+ evaluator(model)
+
+ # 8. Save the final model
+ final_output_dir = f"models/{run_name}_{dt}/final"
+ model.save_pretrained(final_output_dir)
+
+ # 9. (Optional) save the model to the Hugging Face Hub!
+ # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+ try:
+ model.push_to_hub(run_name)
+ except Exception:
+ logging.error(
+ f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+ f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
+ f"and saving it using `model.push_to_hub('{run_name}')`."
+ )
+
+
+if __name__ == "__main__":
+ main()
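Once a run of the script above has finished, the saved reranker can be loaded for inference like any other CrossEncoder. A small usage sketch, assuming the hypothetical output path produced by the script (the timestamp is a placeholder):

```python
from sentence_transformers.cross_encoder import CrossEncoder

# Placeholder path: the `final_output_dir` produced by the training script above
model = CrossEncoder("models/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-ranknetloss_2025-01-01_00-00-00/final")

query = "How many people live in Berlin?"
passages = [
    "Berlin has a population of roughly 3.7 million inhabitants.",
    "Berlin is well known for its museums.",
]

# Score (query, passage) pairs directly ...
scores = model.predict([(query, passage) for passage in passages])
print(scores)

# ... or rank all passages for the query
ranking = model.rank(query, passages, return_documents=True)
print(ranking)
```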
diff --git a/sentence_transformers/cross_encoder/losses/LambdaLoss.py b/sentence_transformers/cross_encoder/losses/LambdaLoss.py
index d7e78b76d..9257b70e6 100644
--- a/sentence_transformers/cross_encoder/losses/LambdaLoss.py
+++ b/sentence_transformers/cross_encoder/losses/LambdaLoss.py
@@ -351,7 +351,7 @@ def get_config_dict(self) -> dict[str, float | int | str | None]:
def citation(self) -> str:
return """
@inproceedings{wang2018lambdaloss,
- title={The lambdaloss framework for ranking metric optimization},
+ title={The LambdaLoss Framework for Ranking Metric Optimization},
author={Wang, Xuanhui and Li, Cheng and Golbandi, Nadav and Bendersky, Michael and Najork, Marc},
booktitle={Proceedings of the 27th ACM international conference on information and knowledge management},
pages={1313--1322},
diff --git a/sentence_transformers/cross_encoder/losses/ListMLELoss.py b/sentence_transformers/cross_encoder/losses/ListMLELoss.py
index deb516809..542dfe960 100644
--- a/sentence_transformers/cross_encoder/losses/ListMLELoss.py
+++ b/sentence_transformers/cross_encoder/losses/ListMLELoss.py
@@ -15,7 +15,7 @@ def __init__(
respect_input_order: bool = True,
) -> None:
"""
- This loss function implements the ListMLE learnin to rank algorithm, which uses a list-wise
+ This loss function implements the ListMLE learning to rank algorithm, which uses a list-wise
approach based on maximum likelihood estimation of permutations. It maximizes the likelihood
of the permutation induced by the ground truth labels.
@@ -116,10 +116,10 @@ def get_config_dict(self) -> dict[str, float | int | str | None]:
def citation(self) -> str:
return """
@inproceedings{10.1145/1390156.1390306,
- title = {Listwise approach to learning to rank: theory and algorithm},
+ title = {Listwise Approach to Learning to Rank - Theory and Algorithm},
author = {Xia, Fen and Liu, Tie-Yan and Wang, Jue and Zhang, Wensheng and Li, Hang},
booktitle = {Proceedings of the 25th International Conference on Machine Learning},
- pages = {1192–1199},
+ pages = {1192-1199},
year = {2008},
url = {https://doi.org/10.1145/1390156.1390306},
}
diff --git a/sentence_transformers/cross_encoder/losses/ListNetLoss.py b/sentence_transformers/cross_encoder/losses/ListNetLoss.py
index 374cc7867..923d6bad5 100644
--- a/sentence_transformers/cross_encoder/losses/ListNetLoss.py
+++ b/sentence_transformers/cross_encoder/losses/ListNetLoss.py
@@ -188,7 +188,7 @@ def get_config_dict(self) -> dict[str, float]:
def citation(self) -> str:
return """
@inproceedings{cao2007learning,
- title={Learning to rank: from pairwise approach to listwise approach},
+ title={Learning to Rank: From Pairwise Approach to Listwise Approach},
author={Cao, Zhe and Qin, Tao and Liu, Tie-Yan and Tsai, Ming-Feng and Li, Hang},
booktitle={Proceedings of the 24th international conference on Machine learning},
pages={129--136},
diff --git a/sentence_transformers/cross_encoder/losses/PListMLELoss.py b/sentence_transformers/cross_encoder/losses/PListMLELoss.py
index fb722c9bb..bfa69ef63 100644
--- a/sentence_transformers/cross_encoder/losses/PListMLELoss.py
+++ b/sentence_transformers/cross_encoder/losses/PListMLELoss.py
@@ -284,7 +284,7 @@ def get_config_dict(self) -> dict[str, float | int | str | None]:
def citation(self) -> str:
return """
@inproceedings{lan2014position,
- title={Position-Aware ListMLE: A Sequential Learning Process for Ranking.},
+ title={Position-Aware ListMLE: A Sequential Learning Process for Ranking},
author={Lan, Yanyan and Zhu, Yadong and Guo, Jiafeng and Niu, Shuzi and Cheng, Xueqi},
booktitle={UAI},
volume={14},
diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py
new file mode 100644
index 000000000..fd24566ff
--- /dev/null
+++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py
@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from torch import nn
+
+from sentence_transformers.cross_encoder import CrossEncoder
+from sentence_transformers.cross_encoder.losses import LambdaLoss, NoWeightingScheme
+
+
+class RankNetLoss(LambdaLoss):
+ def __init__(
+ self,
+ model: CrossEncoder,
+ k: int | None = None,
+ sigma: float = 1.0,
+ eps: float = 1e-10,
+ reduction_log: Literal["natural", "binary"] = "binary",
+ activation_fct: nn.Module | None = nn.Identity(),
+ mini_batch_size: int | None = None,
+ ) -> None:
+ """
+ RankNet loss implementation for learning to rank. This loss function implements the RankNet algorithm,
+ which learns a ranking function by optimizing pairwise document comparisons using a neural network.
+ The implementation is optimized to handle padded documents efficiently by only processing valid
+ documents during model inference.
+
+ Args:
+            model (CrossEncoder): CrossEncoder model to be trained
+            k (int, optional): Number of documents to consider when computing the loss; passed through to
+                :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss`. Defaults to None (all documents).
+            sigma (float): Score difference weight used in sigmoid (default: 1.0)
+            eps (float): Small constant for numerical stability (default: 1e-10)
+            reduction_log (str): Type of logarithm used in the loss computation, either "natural" or "binary"
+                (base 2); passed through to :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss`.
+                Defaults to "binary".
+ activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the
+ loss. Defaults to :class:`~torch.nn.Identity`.
+ mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant
+ impact on the memory consumption and speed of the training process. Three cases are possible:
+ - If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size.
+ - If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``.
+ - If ``mini_batch_size`` is <= 0, the entire batch is processed at once.
+ Defaults to None.
+
+ References:
+ - Learning to Rank using Gradient Descent: https://icml.cc/Conferences/2015/wp-content/uploads/2015/06/icml_ranking.pdf
+ - `Cross Encoder > Training Examples > MS MARCO <../../../examples/cross_encoder/training/ms_marco/README.html>`_
+
+ Requirements:
+ 1. Query with multiple documents (pairwise approach)
+ 2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.
+
+ Inputs:
+ +----------------------------------------+--------------------------------+-------------------------------+
+ | Texts | Labels | Number of Model Output Labels |
+ +========================================+================================+===============================+
+ | (query, [doc1, doc2, ..., docN]) | [score1, score2, ..., scoreN] | 1 |
+ +----------------------------------------+--------------------------------+-------------------------------+
+
+ Recommendations:
+ - Use :class:`~sentence_transformers.util.mine_hard_negatives` with ``output_format="labeled-list"``
+ to convert question-answer pairs to the required input format with hard negatives.
+
+ Relations:
+ - :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` can be seen as an extension of this loss
+ where each score pair is weighted. Alternatively, this loss can be seen as a special case of the
+ :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` without a weighting scheme.
+ - :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` with its default NDCGLoss2++ weighting
+ scheme anecdotally performs better than the other losses with the same input format.
+
+ Example:
+ ::
+
+ from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
+ from datasets import Dataset
+
+ model = CrossEncoder("microsoft/mpnet-base")
+ train_dataset = Dataset.from_dict({
+ "query": ["What are pandas?", "What is the capital of France?"],
+ "docs": [
+ ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
+ ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
+ ],
+ "labels": [[1, 0], [1, 1, 0]],
+ })
+ loss = losses.RankNetLoss(model)
+
+ trainer = CrossEncoderTrainer(
+ model=model,
+ train_dataset=train_dataset,
+ loss=loss,
+ )
+ trainer.train()
+ """
+ super().__init__(
+ model=model,
+ weighting_scheme=NoWeightingScheme(),
+ k=k,
+ sigma=sigma,
+ eps=eps,
+ reduction_log=reduction_log,
+ activation_fct=activation_fct,
+ mini_batch_size=mini_batch_size,
+ )
+
+ def get_config_dict(self) -> dict[str, float | int | str | None]:
+ """
+ Get configuration parameters for this loss function.
+
+ Returns:
+ Dictionary containing the configuration parameters
+ """
+ config = super().get_config_dict()
+ del config["weighting_scheme"]
+ return config
+
+ @property
+ def citation(self) -> str:
+ return """
+@inproceedings{burges2005learning,
+ title={Learning to Rank using Gradient Descent},
+ author={Burges, Chris and Shaked, Tal and Renshaw, Erin and Lazier, Ari and Deeds, Matt and Hamilton, Nicole and Hullender, Greg},
+ booktitle={Proceedings of the 22nd international conference on Machine learning},
+ pages={89--96},
+ year={2005}
+}
+"""
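For intuition on why `RankNetLoss` can simply delegate to `LambdaLoss` with a `NoWeightingScheme`: the core RankNet objective is an unweighted logistic loss over the score differences of every document pair whose labels disagree. A standalone sketch of that pairwise objective for a single query (not the library implementation, which additionally handles mini-batching, padding, and the `"binary"` logarithm reduction):

```python
import torch


def ranknet_loss(scores: torch.Tensor, labels: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Unweighted pairwise RankNet loss for one query: scores and labels both have shape (num_docs,)."""
    # Pairwise differences: entry (i, j) compares document i against document j
    score_diff = scores.unsqueeze(1) - scores.unsqueeze(0)
    label_diff = labels.unsqueeze(1) - labels.unsqueeze(0)

    # Only pairs where document i is strictly more relevant than document j contribute
    mask = label_diff > 0

    # -log(sigmoid(sigma * (s_i - s_j))) for each such pair, i.e. softplus(-sigma * (s_i - s_j))
    pair_losses = torch.nn.functional.softplus(-sigma * score_diff[mask])
    return pair_losses.mean()


# Three documents, the first two relevant: the poorly scored relevant document drives the loss
print(ranknet_loss(torch.tensor([2.0, 0.5, 1.0]), torch.tensor([1.0, 1.0, 0.0])))
```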
diff --git a/sentence_transformers/cross_encoder/losses/__init__.py b/sentence_transformers/cross_encoder/losses/__init__.py
index b77ad9bb0..a048d5265 100644
--- a/sentence_transformers/cross_encoder/losses/__init__.py
+++ b/sentence_transformers/cross_encoder/losses/__init__.py
@@ -17,6 +17,7 @@
from .MSELoss import MSELoss
from .MultipleNegativesRankingLoss import MultipleNegativesRankingLoss
from .PListMLELoss import PListMLELambdaWeight, PListMLELoss
+from .RankNetLoss import RankNetLoss
__all__ = [
"BinaryCrossEntropyLoss",
@@ -35,4 +36,5 @@
"NDCGLoss2Scheme",
"LambdaRankScheme",
"NDCGLoss2PPScheme",
+ "RankNetLoss",
]