24 changes: 12 additions & 12 deletions docs/cross_encoder/loss_overview.md
@@ -17,15 +17,15 @@ Loss functions play a critical role in the performance of your fine-tuned Cross
- ``(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets`` with labels of 0 for negative and 1 for positive with ``output_format="labeled-list"``,
```

| Inputs | Labels | Number of Model Output Labels | Appropriate Loss Functions |
|---------------------------------------------------|------------------------------------------|-------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `(sentence_A, sentence_B) pairs` | `class` | `num_classes` | <a href="../package_reference/cross_encoder/losses.html#crossentropyloss">`CrossEntropyLoss`</a> |
| `(anchor, positive) pairs` | `none` | `1` | <a href="../package_reference/cross_encoder/losses.html#multiplenegativesrankingloss">`MultipleNegativesRankingLoss`</a><br><a href="../package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss">`CachedMultipleNegativesRankingLoss`</a> |
| `(anchor, positive/negative) pairs` | `1 if positive, 0 if negative` | `1` | <a href="../package_reference/cross_encoder/losses.html#binarycrossentropyloss">`BinaryCrossEntropyLoss`</a> |
| `(sentence_A, sentence_B) pairs` | `float similarity score between 0 and 1` | `1` | <a href="../package_reference/cross_encoder/losses.html#binarycrossentropyloss">`BinaryCrossEntropyLoss`</a> |
| `(anchor, positive, negative) triplets` | `none` | `1` | <a href="../package_reference/cross_encoder/losses.html#multiplenegativesrankingloss">`MultipleNegativesRankingLoss`</a><br><a href="../package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss">`CachedMultipleNegativesRankingLoss`</a> |
| `(anchor, positive, negative_1, ..., negative_n)` | `none` | `1` | <a href="../package_reference/cross_encoder/losses.html#multiplenegativesrankingloss">`MultipleNegativesRankingLoss`</a><br><a href="../package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss">`CachedMultipleNegativesRankingLoss`</a> |
| `(query, [doc1, doc2, ..., docN])` | `[score1, score2, ..., scoreN]` | `1` | <ol style="margin-bottom: 0;line-height: inherit;"><li><a href="../package_reference/cross_encoder/losses.html#lambdaloss">`LambdaLoss`</a></li><li><a href="../package_reference/cross_encoder/losses.html#listnetloss">`ListNetLoss`</a></li><li><a href="../package_reference/cross_encoder/losses.html#plistmleloss">`PListMLELoss`</a></li><li><a href="../package_reference/cross_encoder/losses.html#ranknetloss">`RankNetLoss`</a></li><li><a href="../package_reference/cross_encoder/losses.html#listmleloss">`ListMLELoss`</a></li></ol> |
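
For instance, the last row of this table corresponds to a listwise learning-to-rank setup. A minimal sketch of what that can look like, assuming a recent sentence-transformers release with `CrossEncoderTrainer` (the model name and the `query`/`docs`/`labels` column names are illustrative, not prescribed):

```python
from datasets import Dataset

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import LambdaLoss

# A toy listwise dataset: each row pairs one query with a list of documents
# and a list of per-document relevance scores
train_dataset = Dataset.from_dict({
    "query": ["how to bake bread"],
    "docs": [["Preheat the oven to 230 C ...", "Bread is a staple food ...", "Cats are mammals."]],
    "labels": [[1.0, 0.5, 0.0]],
})

# One output label, matching the "Number of Model Output Labels" column above
model = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)
loss = LambdaLoss(model)

trainer = CrossEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```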

## Distillation
These loss functions are specifically designed to be used when distilling the knowledge from one model into another.
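
As a rough sketch of that workflow, assuming a recent sentence-transformers release with `CrossEncoderTrainer` (the teacher and student checkpoints and the example pairs are placeholders; `MSELoss` is one of the distillation losses):

```python
from datasets import Dataset

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import MSELoss

# Placeholder pairs; in practice these come from your own dataset
pairs = [
    ("what is the capital of france", "Paris is the capital of France."),
    ("what is the capital of france", "France borders Belgium and Spain."),
]

# 1. Score the pairs with a larger, stronger teacher cross-encoder (placeholder checkpoint)
teacher = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
teacher_scores = teacher.predict(pairs)

# 2. Train a smaller student to reproduce the teacher's scores with MSELoss
train_dataset = Dataset.from_dict({
    "sentence_A": [query for query, _ in pairs],
    "sentence_B": [passage for _, passage in pairs],
    "label": [float(score) for score in teacher_scores],
})
student = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)
trainer = CrossEncoderTrainer(model=student, train_dataset=train_dataset, loss=MSELoss(student))
trainer.train()
```
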
@@ -40,9 +40,9 @@ For example, when finetuning a small model to behave more like a larger & strong
In practice, not all loss functions get used equally often. The most common scenarios are:

* `(sentence_A, sentence_B) pairs` with `float similarity score` or `1 if positive, 0 if negative`: <a href="../package_reference/cross_encoder/losses.html#binarycrossentropyloss"><code>BinaryCrossEntropyLoss</code></a> is a traditional option that remains very challenging to outperform.
* `(anchor, positive) pairs` without any labels:
* <a href="../package_reference/cross_encoder/losses.html#multiplenegativesrankingloss"><code>MultipleNegativesRankingLoss</code></a> (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train <a href="../package_reference/sentence_transformer/SentenceTransformer.html#sentence_transformers.SentenceTransformer"><code>SentenceTransformer</code></a> models, and the loss is also applicable to <a href="../package_reference/cross_encoder/cross_encoder.html#sentence_transformers.cross_encoder.CrossEncoder"><code>CrossEncoder</code></a> models. This data is often relatively cheap to obtain, and <a href="../package_reference/util.html#sentence_transformers.util.mine_hard_negatives"><code>mine_hard_negatives</code></a> with <code>output_format="n-tuple"</code> or <code>output_format="triplet"</code> can easily be used to add hard negatives for this loss. <a href="../package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss"><code>CachedMultipleNegativesRankingLoss</code></a> is often used to keep the memory usage in check.
* `(anchor, positive) pairs` without any labels: combined with <a href="../package_reference/util.html#sentence_transformers.util.mine_hard_negatives"><code>mine_hard_negatives</code></a> (see the sketch after this list)
* with <code>output_format="labeled-list"</code>, then <a href="../package_reference/cross_encoder/losses.html#lambdaloss"><code>LambdaLoss</code></a> is frequently used for learning-to-rank tasks.
* with <code>output_format="labeled-pair"</code>, then <a href="../package_reference/cross_encoder/losses.html#binarycrossentropyloss"><code>BinaryCrossEntropyLoss</code></a> remains a strong option.
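
A minimal sketch of that last combination, assuming a recent sentence-transformers release with `CrossEncoderTrainer` (the dataset name, model checkpoints, and mining parameters below are illustrative, not a recommendation):

```python
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.util import mine_hard_negatives

# (anchor, positive) pairs without labels; the dataset name is illustrative
pairs = load_dataset("sentence-transformers/gooaq", split="train[:10000]")

# Mine hard negatives with a fast embedding model; "labeled-pair" yields
# (anchor, passage) pairs labeled 1 for positives and 0 for mined negatives
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
train_dataset = mine_hard_negatives(
    pairs,
    embedder,
    num_negatives=5,
    output_format="labeled-pair",
)

model = CrossEncoder("microsoft/MiniLM-L12-H384-uncased", num_labels=1)
trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=BinaryCrossEntropyLoss(model),
)
trainer.train()
```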

## Custom Loss Functions

5 changes: 5 additions & 0 deletions docs/package_reference/cross_encoder/losses.md
@@ -61,3 +61,8 @@ Sadly, there is no "one size fits all" loss function. Which loss function is sui
```{eval-rst}
.. autoclass:: sentence_transformers.cross_encoder.losses.MarginMSELoss
```

## RankNetLoss
```{eval-rst}
.. autoclass:: sentence_transformers.cross_encoder.losses.RankNetLoss
```
6 changes: 5 additions & 1 deletion examples/cross_encoder/training/ms_marco/README.md
@@ -56,8 +56,12 @@ In all scripts, the model is evaluated on subsets of `MS MARCO <https://huggingf
```{eval-rst}
This example uses the :class:`~sentence_transformers.cross_encoder.losses.PListMLELoss` with the default :class:`~sentence_transformers.cross_encoder.losses.PListMLELambdaWeight` position weighting. The script applies dataset pre-processing into ``(query, [doc1, doc2, ..., docN])`` with ``labels`` as ``[score1, score2, ..., scoreN]``.
```
* **[training_ms_marco_ranknet.py](training_ms_marco_ranknet.py)**:
```{eval-rst}
This example uses the :class:`~sentence_transformers.cross_encoder.losses.RankNetLoss`. The script applies dataset pre-processing into ``(query, [doc1, doc2, ..., docN])`` with ``labels`` as ``[score1, score2, ..., scoreN]`` (a sketch of this grouping step follows this list).
```
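
The listwise pre-processing mentioned for the scripts above is essentially a group-by-query step. A rough sketch of that idea (the ``query``/``passage``/``score`` column names are illustrative and may differ from the actual scripts):

```python
from collections import defaultdict

from datasets import Dataset


def to_listwise(dataset: Dataset) -> Dataset:
    """Group flat (query, passage, score) rows into (query, [docs], [labels]) rows."""
    grouped_docs = defaultdict(list)
    grouped_scores = defaultdict(list)
    for row in dataset:
        grouped_docs[row["query"]].append(row["passage"])
        grouped_scores[row["query"]].append(float(row["score"]))
    return Dataset.from_dict({
        "query": list(grouped_docs.keys()),
        "docs": [grouped_docs[query] for query in grouped_docs],
        "labels": [grouped_scores[query] for query in grouped_docs],
    })
```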

Out of these training scripts, I suspect that **[training_ms_marco_lambda_preprocessed.py](training_ms_marco_lambda_preprocessed.py)**, **[training_ms_marco_lambda_hard_neg.py](training_ms_marco_lambda_hard_neg.py)**, or **[training_ms_marco_bce_preprocessed.py](training_ms_marco_bce_preprocessed.py)** produces the strongest model, as anecdotally `LambdaLoss` and `BinaryCrossEntropyLoss` are quite strong. It seems that `LambdaLoss` > `ListNetLoss` > `PListMLELoss` > `RankNetLoss` > `ListMLELoss` out of all learning-to-rank losses, but your mileage may vary.

Additionally, you can also train with Distillation. See [Cross Encoder > Training Examples > Distillation](../distillation/README.md) for more details.

@@ -1,6 +1,7 @@
import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
@@ -21,6 +22,8 @@ def main():
num_epochs = 1

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
@@ -1,6 +1,7 @@
import logging
import traceback

import torch
from datasets import load_dataset, load_from_disk

from sentence_transformers.cross_encoder import CrossEncoder
@@ -21,6 +22,8 @@ def main():
dataset_size = 2_000_000

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
@@ -2,6 +2,7 @@
import traceback
from collections import defaultdict

import torch
from datasets import load_dataset
from torch import nn

@@ -25,6 +26,8 @@ def main():
num_epochs = 1

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
@@ -2,6 +2,7 @@
import traceback
from datetime import datetime

import torch
from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
@@ -32,6 +33,8 @@ def main():
dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
@@ -2,6 +2,7 @@
import traceback
from datetime import datetime

import torch
from datasets import Dataset, concatenate_datasets, load_dataset

from sentence_transformers import CrossEncoder, SentenceTransformer
@@ -33,6 +34,8 @@ def main():
dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
@@ -2,6 +2,7 @@
import traceback
from datetime import datetime

import torch
from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
@@ -32,6 +33,8 @@ def main():
dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
@@ -1,6 +1,7 @@
import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
@@ -30,6 +31,8 @@ def main():
respect_input_order = True # Whether to respect the original order of documents

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
@@ -1,6 +1,7 @@
import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
@@ -29,6 +30,8 @@ def main():
max_docs = None

# 1. Define our CrossEncoder model
# Set the seed so the new classifier weights are identical in subsequent runs
torch.manual_seed(12)
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)