From 68b71926017f29ef36b412070d45cfb5549e8003 Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Thu, 20 Mar 2025 01:05:45 +0900 Subject: [PATCH 01/11] Add RankNetLoss and training script --- .../ms_marco/training_ms_marco_ranknet.py | 163 +++++++++++++ .../cross_encoder/losses/RankNetLoss.py | 214 ++++++++++++++++++ 2 files changed, 377 insertions(+) create mode 100644 examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py create mode 100644 sentence_transformers/cross_encoder/losses/RankNetLoss.py diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py new file mode 100644 index 000000000..f039a7ff2 --- /dev/null +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import logging +import traceback +from datetime import datetime + +from datasets import load_dataset + +from sentence_transformers.cross_encoder import CrossEncoder +from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator +from sentence_transformers.cross_encoder.losses import RankNetLoss +from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer +from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments + + +def main(): + model_name = "microsoft/MiniLM-L12-H384-uncased" + + # Set the log level to INFO to get more information + logging.basicConfig( + format="%(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + ) + # train_batch_size and eval_batch_size inform the size of the batches, while mini_batch_size is used by the loss + # to subdivide the batch into smaller parts. This mini_batch_size largely informs the training speed and memory usage. + # Keep in mind that the loss does not process `train_batch_size` pairs, but `train_batch_size * num_docs` pairs. + train_batch_size = 16 + eval_batch_size = 16 + mini_batch_size = 16 + num_epochs = 1 + max_docs = None + + + # 1. Define our CrossEncoder model + model = CrossEncoder(model_name, num_labels=1) + print("Model max length:", model.max_length) + print("Model num labels:", model.num_labels) + + # 2. 
Load the MS MARCO dataset: https://huggingface.co/datasets/microsoft/ms_marco + logging.info("Read train dataset") + dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train") + + def listwise_mapper(batch, max_docs: int | None = 10): + processed_queries = [] + processed_docs = [] + processed_labels = [] + + for query, passages_info in zip(batch["query"], batch["passages"]): + # Extract passages and labels + passages = passages_info["passage_text"] + labels = passages_info["is_selected"] + + # Pair passages with labels and sort descending by label (positives first) + paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True) + + # Separate back to passages and labels + sorted_passages, sorted_labels = zip(*paired) if paired else ([], []) + + # Filter queries without any positive labels + if max(sorted_labels) < 1.0: + continue + + # Truncate to max_docs + if max_docs is not None: + sorted_passages = list(sorted_passages[:max_docs]) + sorted_labels = list(sorted_labels[:max_docs]) + + processed_queries.append(query) + processed_docs.append(sorted_passages) + processed_labels.append(sorted_labels) + + return { + "query": processed_queries, + "docs": processed_docs, + "labels": processed_labels, + } + + # Create a dataset with a "query" column with strings, a "docs" column with lists of strings, + # and a "labels" column with lists of floats + dataset = dataset.map( + lambda batch: listwise_mapper(batch=batch, max_docs=max_docs), + batched=True, + remove_columns=dataset.column_names, + desc="Processing listwise samples", + ) + + dataset = dataset.train_test_split(test_size=1_000, seed=12) + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + logging.info(train_dataset) + + # 3. Define our training loss + loss = RankNetLoss( + model=model, + mini_batch_size=mini_batch_size, + ) + + # 4. Define the evaluator. We use the CENanoBEIREvaluator, which is a light-weight evaluator for English reranking + evaluator = CrossEncoderNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=eval_batch_size) + evaluator(model) + + # 5. Define the training arguments + short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] + run_name = f"reranker-msmarco-v1.1-{short_model_name}-ranknetloss" + args = CrossEncoderTrainingArguments( + # Required parameter: + output_dir=f"models/{run_name}", + # Optional training parameters: + num_train_epochs=num_epochs, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=eval_batch_size, + learning_rate=2e-5, + warmup_ratio=0.1, + fp16=False, # Set to False if you get an error that your GPU can't run on FP16 + bf16=True, # Set to True if you have a GPU that supports BF16 + # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch + load_best_model_at_end=True, + metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10", + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=500, + save_strategy="steps", + save_steps=500, + save_total_limit=2, + logging_steps=250, + logging_first_step=True, + run_name=run_name, # Will be used in W&B if `wandb` is installed + seed=12, + ) + + # 6. Create the trainer & start training + trainer = CrossEncoderTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=loss, + evaluator=evaluator, + ) + trainer.train() + + # 7. Evaluate the final model, useful to include these in the model card + evaluator(model) + + # 8. 
Save the final model + final_output_dir = f"models/{run_name}/final" + model.save_pretrained(final_output_dir) + + # 9. (Optional) save the model to the Hugging Face Hub! + # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first + try: + model.push_to_hub(run_name) + except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{run_name}')`." + ) + + +if __name__ == "__main__": + main() diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py new file mode 100644 index 000000000..29a446055 --- /dev/null +++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import Literal + +import torch +from torch import Tensor, nn + +from sentence_transformers.cross_encoder import CrossEncoder +from sentence_transformers.util import fullname + + +class RankNetLoss(nn.Module): + def __init__( + self, + model: CrossEncoder, + sigma: float = 1.0, + eps: float = 1e-10, + activation_fct: nn.Module | None = nn.Identity(), + mini_batch_size: int | None = None, + ) -> None: + """ + RankNet loss implementation for learning to rank. This loss function implements the RankNet algorithm, + which learns a ranking function by optimizing pairwise document comparisons using a neural network. + The implementation is optimized to handle padded documents efficiently by only processing valid + documents during model inference. + + Args: + model (CrossEncoder): CrossEncoder model to be trained + sigma (float): Score difference weight used in sigmoid (default: 1.0) + eps (float): Small constant for numerical stability (default: 1e-10) + activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the + loss. Defaults to :class:`~torch.nn.Identity`. + mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant + impact on the memory consumption and speed of the training process. Three cases are possible: + - If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size. + - If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``. + - If ``mini_batch_size`` is <= 0, the entire batch is processed at once. + Defaults to None. + + References: + - Learning to Rank using Gradient Descent: https://icml.cc/Conferences/2015/wp-content/uploads/2015/06/icml_ranking.pdf + + Requirements: + 1. Query with multiple documents (pairwise approach) + 2. Documents must have relevance scores/labels. Both binary and continuous labels are supported. 
+ + Inputs: + +----------------------------------------+--------------------------------+-------------------------------+ + | Texts | Labels | Number of Model Output Labels | + +========================================+================================+===============================+ + | (query, [doc1, doc2, ..., docN]) | [score1, score2, ..., scoreN] | 1 | + +----------------------------------------+--------------------------------+-------------------------------+ + + Example: + :: + + from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses + from datasets import Dataset + + model = CrossEncoder("microsoft/mpnet-base") + train_dataset = Dataset.from_dict({ + "query": ["What are pandas?", "What is the capital of France?"], + "docs": [ + ["Pandas are a kind of bear.", "Pandas are kind of like fish."], + ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."], + ], + "labels": [[1, 0], [1, 1, 0]], + }) + loss = losses.RankNetLoss(model) + + trainer = CrossEncoderTrainer( + model=model, + train_dataset=train_dataset, + loss=loss, + ) + trainer.train() + """ + super().__init__() + self.model = model + self.sigma = sigma + self.eps = eps + self.activation_fct = activation_fct + self.mini_batch_size = mini_batch_size + + if self.model.num_labels != 1: + raise ValueError( + f"{self.__class__.__name__} supports a model with 1 output label, " + f"but got a model with {self.model.num_labels} output labels." + ) + + def forward(self, inputs: list[list[str], list[list[str]]], labels: list[Tensor]) -> Tensor: + """ + Compute RankNet loss for a batch of queries and their documents. + + Args: + inputs: List of (queries, documents_list) + labels: Ground truth relevance scores, shape (batch_size, num_documents) + + Returns: + Tensor: RankNet loss over the batch + """ + if isinstance(labels, Tensor): + raise ValueError( + "RankNetLoss expects a list of labels for each sample, but got a single value for each sample." + ) + if len(inputs) != 2: + raise ValueError(f"RankNetLoss expects two inputs (queries, documents_list), but got {len(inputs)} inputs.") + + queries, docs_list = inputs + docs_per_query = [len(docs) for docs in docs_list] + max_docs = max(docs_per_query) + batch_size = len(queries) + + if docs_per_query != [len(labels) for labels in labels]: + raise ValueError( + f"Number of documents per query in inputs ({docs_per_query}) does not match number of labels per query ({[len(labels) for labels in labels]})." 
+ ) + + # Create input pairs for the model₩ + pairs = [(query, document) for query, docs in zip(queries, docs_list) for document in docs] + + if not pairs: + # Handle edge case where all documents are padded + return torch.tensor(0.0, device=self.model.device, requires_grad=True) + + mini_batch_size = self.mini_batch_size or batch_size + if mini_batch_size <= 0: + mini_batch_size = len(pairs) + + logits_list = [] + for i in range(0, len(pairs), mini_batch_size): + mini_batch_pairs = pairs[i : i + mini_batch_size] + + tokens = self.model.tokenizer( + mini_batch_pairs, + padding=True, + truncation=True, + return_tensors="pt", + ) + tokens = tokens.to(self.model.device) + + logits = self.model(**tokens)[0].view(-1) + logits_list.append(logits) + + logits = torch.cat(logits_list, dim=0) + logits = self.activation_fct(logits) + + # Create output tensor filled with 0 (padded logits will be ignored via labels) + logits_matrix = torch.full((batch_size, max_docs), -1e16, device=self.model.device) + + # Place logits in the desired positions in the logit matrix + doc_indices = torch.cat([torch.arange(len(docs)) for docs in docs_list], dim=0) + batch_indices = torch.repeat_interleave(torch.arange(batch_size), torch.tensor(docs_per_query)) + logits_matrix[batch_indices, doc_indices] = logits + + # Create labels matrix + labels_matrix = torch.full_like(logits_matrix, float("-inf")) + labels_matrix[batch_indices, doc_indices] = torch.cat(labels, dim=0).float() + labels_matrix = labels_matrix.to(self.model.device) + + # Calculate pairwise differences for scores and labels + score_diffs = logits_matrix[:, :, None] - logits_matrix[:, None, :] + label_diffs = labels_matrix[:, :, None] - labels_matrix[:, None, :] + + # Create mask for valid pairs (where both documents are not padded) + valid_pairs = torch.isfinite(label_diffs) + + # Create mask for pairs where l_i > l_j + positive_pairs = label_diffs > 0 + + # Calculate probabilities and target probabilities + P_ij = torch.sigmoid(self.sigma * score_diffs) + P_ij = torch.clamp(P_ij, min=self.eps, max=1-self.eps) + + # Calculate loss only for pairs where l_i > l_j (positive_pairs) + # This follows the TensorFlow Ranking implementation more closely + losses = -torch.log(P_ij) + + # Apply masks and compute mean loss + masked_loss = losses[valid_pairs & positive_pairs] + + # Handle case when there are no positive pairs + if masked_loss.numel() == 0: + return torch.tensor(0.0, device=self.model.device, requires_grad=True) + + loss = torch.mean(masked_loss) + + return loss + + def get_config_dict(self) -> dict[str, float | int | str | None]: + """ + Get configuration parameters for this loss function. 
+ + Returns: + Dictionary containing the configuration parameters + """ + return { + "sigma": self.sigma, + "eps": self.eps, + "activation_fct": fullname(self.activation_fct), + "mini_batch_size": self.mini_batch_size, + } + + @property + def citation(self) -> str: + return """ +@inproceedings{burges2005learning, + title={Learning to rank using gradient descent}, + author={Burges, Chris and Shaked, Tal and Renshaw, Erin and Lazier, Ari and Deeds, Matt and Hamilton, Nicole and Hullender, Greg}, + booktitle={Proceedings of the 22nd international conference on Machine learning}, + pages={89--96}, + year={2005} +} +""" \ No newline at end of file From 4cca438bcdd2ddade87e25755d7a4b9c84f8a88d Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Thu, 20 Mar 2025 10:50:43 +0900 Subject: [PATCH 02/11] Fix ListMLELoss documentation --- sentence_transformers/cross_encoder/losses/ListMLELoss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sentence_transformers/cross_encoder/losses/ListMLELoss.py b/sentence_transformers/cross_encoder/losses/ListMLELoss.py index deb516809..961a00db6 100644 --- a/sentence_transformers/cross_encoder/losses/ListMLELoss.py +++ b/sentence_transformers/cross_encoder/losses/ListMLELoss.py @@ -15,7 +15,7 @@ def __init__( respect_input_order: bool = True, ) -> None: """ - This loss function implements the ListMLE learnin to rank algorithm, which uses a list-wise + This loss function implements the ListMLE learning to rank algorithm, which uses a list-wise approach based on maximum likelihood estimation of permutations. It maximizes the likelihood of the permutation induced by the ground truth labels. From 1eb6ad77609293b86e2716f9e75f2ae0892f7279 Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Thu, 20 Mar 2025 10:59:16 +0900 Subject: [PATCH 03/11] Fix RankNet to class LambdaLoss --- .../cross_encoder/losses/RankNetLoss.py | 132 +++--------------- 1 file changed, 18 insertions(+), 114 deletions(-) diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py index 29a446055..1cf166980 100644 --- a/sentence_transformers/cross_encoder/losses/RankNetLoss.py +++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py @@ -2,19 +2,21 @@ from typing import Literal -import torch -from torch import Tensor, nn +from torch import nn from sentence_transformers.cross_encoder import CrossEncoder +from sentence_transformers.cross_encoder.losses import LambdaLoss, NoWeightingScheme from sentence_transformers.util import fullname -class RankNetLoss(nn.Module): +class RankNetLoss(LambdaLoss): def __init__( self, model: CrossEncoder, + k: int | None = None, sigma: float = 1.0, eps: float = 1e-10, + reduction_log: Literal["natural", "binary"] = "binary", activation_fct: nn.Module | None = nn.Identity(), mini_batch_size: int | None = None, ) -> None: @@ -75,117 +77,16 @@ def __init__( ) trainer.train() """ - super().__init__() - self.model = model - self.sigma = sigma - self.eps = eps - self.activation_fct = activation_fct - self.mini_batch_size = mini_batch_size - - if self.model.num_labels != 1: - raise ValueError( - f"{self.__class__.__name__} supports a model with 1 output label, " - f"but got a model with {self.model.num_labels} output labels." - ) - - def forward(self, inputs: list[list[str], list[list[str]]], labels: list[Tensor]) -> Tensor: - """ - Compute RankNet loss for a batch of queries and their documents. 
- - Args: - inputs: List of (queries, documents_list) - labels: Ground truth relevance scores, shape (batch_size, num_documents) - - Returns: - Tensor: RankNet loss over the batch - """ - if isinstance(labels, Tensor): - raise ValueError( - "RankNetLoss expects a list of labels for each sample, but got a single value for each sample." - ) - if len(inputs) != 2: - raise ValueError(f"RankNetLoss expects two inputs (queries, documents_list), but got {len(inputs)} inputs.") - - queries, docs_list = inputs - docs_per_query = [len(docs) for docs in docs_list] - max_docs = max(docs_per_query) - batch_size = len(queries) - - if docs_per_query != [len(labels) for labels in labels]: - raise ValueError( - f"Number of documents per query in inputs ({docs_per_query}) does not match number of labels per query ({[len(labels) for labels in labels]})." - ) - - # Create input pairs for the model₩ - pairs = [(query, document) for query, docs in zip(queries, docs_list) for document in docs] - - if not pairs: - # Handle edge case where all documents are padded - return torch.tensor(0.0, device=self.model.device, requires_grad=True) - - mini_batch_size = self.mini_batch_size or batch_size - if mini_batch_size <= 0: - mini_batch_size = len(pairs) - - logits_list = [] - for i in range(0, len(pairs), mini_batch_size): - mini_batch_pairs = pairs[i : i + mini_batch_size] - - tokens = self.model.tokenizer( - mini_batch_pairs, - padding=True, - truncation=True, - return_tensors="pt", - ) - tokens = tokens.to(self.model.device) - - logits = self.model(**tokens)[0].view(-1) - logits_list.append(logits) - - logits = torch.cat(logits_list, dim=0) - logits = self.activation_fct(logits) - - # Create output tensor filled with 0 (padded logits will be ignored via labels) - logits_matrix = torch.full((batch_size, max_docs), -1e16, device=self.model.device) - - # Place logits in the desired positions in the logit matrix - doc_indices = torch.cat([torch.arange(len(docs)) for docs in docs_list], dim=0) - batch_indices = torch.repeat_interleave(torch.arange(batch_size), torch.tensor(docs_per_query)) - logits_matrix[batch_indices, doc_indices] = logits - - # Create labels matrix - labels_matrix = torch.full_like(logits_matrix, float("-inf")) - labels_matrix[batch_indices, doc_indices] = torch.cat(labels, dim=0).float() - labels_matrix = labels_matrix.to(self.model.device) - - # Calculate pairwise differences for scores and labels - score_diffs = logits_matrix[:, :, None] - logits_matrix[:, None, :] - label_diffs = labels_matrix[:, :, None] - labels_matrix[:, None, :] - - # Create mask for valid pairs (where both documents are not padded) - valid_pairs = torch.isfinite(label_diffs) - - # Create mask for pairs where l_i > l_j - positive_pairs = label_diffs > 0 - - # Calculate probabilities and target probabilities - P_ij = torch.sigmoid(self.sigma * score_diffs) - P_ij = torch.clamp(P_ij, min=self.eps, max=1-self.eps) - - # Calculate loss only for pairs where l_i > l_j (positive_pairs) - # This follows the TensorFlow Ranking implementation more closely - losses = -torch.log(P_ij) - - # Apply masks and compute mean loss - masked_loss = losses[valid_pairs & positive_pairs] - - # Handle case when there are no positive pairs - if masked_loss.numel() == 0: - return torch.tensor(0.0, device=self.model.device, requires_grad=True) - - loss = torch.mean(masked_loss) - - return loss + super().__init__( + model=model, + weighting_scheme=NoWeightingScheme(), + k=k, + sigma=sigma, + eps=eps, + reduction_log=reduction_log, + 
activation_fct=activation_fct, + mini_batch_size=mini_batch_size, + ) def get_config_dict(self) -> dict[str, float | int | str | None]: """ @@ -195,8 +96,11 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: Dictionary containing the configuration parameters """ return { + "weighting_scheme": fullname(self.weighting_scheme), + "k": self.k, "sigma": self.sigma, "eps": self.eps, + "reduction_log": self.reduction_log, "activation_fct": fullname(self.activation_fct), "mini_batch_size": self.mini_batch_size, } From 79f9e2236480c61395d3f974f7382903bba5b8b6 Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Thu, 20 Mar 2025 11:04:56 +0900 Subject: [PATCH 04/11] Update training script for RankNetLoss --- .../training/ms_marco/training_ms_marco_ranknet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py index f039a7ff2..d97098abf 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py @@ -31,6 +31,7 @@ def main(): num_epochs = 1 max_docs = None + dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. Define our CrossEncoder model model = CrossEncoder(model_name, num_labels=1) @@ -105,7 +106,7 @@ def listwise_mapper(batch, max_docs: int | None = 10): run_name = f"reranker-msmarco-v1.1-{short_model_name}-ranknetloss" args = CrossEncoderTrainingArguments( # Required parameter: - output_dir=f"models/{run_name}", + output_dir=f"models/{run_name}_{dt}", # Optional training parameters: num_train_epochs=num_epochs, per_device_train_batch_size=train_batch_size, @@ -144,7 +145,7 @@ def listwise_mapper(batch, max_docs: int | None = 10): evaluator(model) # 8. Save the final model - final_output_dir = f"models/{run_name}/final" + final_output_dir = f"models/{run_name}_{dt}/final" model.save_pretrained(final_output_dir) # 9. (Optional) save the model to the Hugging Face Hub! 
From 8364c2902a06472890a2160e3664c2d13cdf7497 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Mar 2025 14:04:37 +0100 Subject: [PATCH 05/11] Use super().get_config_dict() and remove weighting scheme It's a bit confusing to include the weighting scheme in the config if the RankNet loss doesn't have a notion of that --- .../cross_encoder/losses/RankNetLoss.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py index 1cf166980..3a1cc7f5b 100644 --- a/sentence_transformers/cross_encoder/losses/RankNetLoss.py +++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py @@ -6,7 +6,6 @@ from sentence_transformers.cross_encoder import CrossEncoder from sentence_transformers.cross_encoder.losses import LambdaLoss, NoWeightingScheme -from sentence_transformers.util import fullname class RankNetLoss(LambdaLoss): @@ -95,15 +94,9 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: Returns: Dictionary containing the configuration parameters """ - return { - "weighting_scheme": fullname(self.weighting_scheme), - "k": self.k, - "sigma": self.sigma, - "eps": self.eps, - "reduction_log": self.reduction_log, - "activation_fct": fullname(self.activation_fct), - "mini_batch_size": self.mini_batch_size, - } + config = super().get_config_dict() + del config["weighting_scheme"] + return config @property def citation(self) -> str: @@ -115,4 +108,4 @@ def citation(self) -> str: pages={89--96}, year={2005} } -""" \ No newline at end of file +""" From 5fe7a1cc14473b61c0c95e33c19ca05c31ce16a5 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Mar 2025 14:04:52 +0100 Subject: [PATCH 06/11] Add to __init__.py for easier import --- sentence_transformers/cross_encoder/losses/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sentence_transformers/cross_encoder/losses/__init__.py b/sentence_transformers/cross_encoder/losses/__init__.py index b77ad9bb0..a048d5265 100644 --- a/sentence_transformers/cross_encoder/losses/__init__.py +++ b/sentence_transformers/cross_encoder/losses/__init__.py @@ -17,6 +17,7 @@ from .MSELoss import MSELoss from .MultipleNegativesRankingLoss import MultipleNegativesRankingLoss from .PListMLELoss import PListMLELambdaWeight, PListMLELoss +from .RankNetLoss import RankNetLoss __all__ = [ "BinaryCrossEntropyLoss", @@ -35,4 +36,5 @@ "NDCGLoss2Scheme", "LambdaRankScheme", "NDCGLoss2PPScheme", + "RankNetLoss", ] From 73e3ad1cd36476a743cac8016f86eb7156c6a1be Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Mar 2025 14:05:15 +0100 Subject: [PATCH 07/11] Correctly capitalize citation titles --- sentence_transformers/cross_encoder/losses/LambdaLoss.py | 2 +- sentence_transformers/cross_encoder/losses/ListMLELoss.py | 4 ++-- sentence_transformers/cross_encoder/losses/ListNetLoss.py | 2 +- sentence_transformers/cross_encoder/losses/PListMLELoss.py | 2 +- sentence_transformers/cross_encoder/losses/RankNetLoss.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sentence_transformers/cross_encoder/losses/LambdaLoss.py b/sentence_transformers/cross_encoder/losses/LambdaLoss.py index d7e78b76d..9257b70e6 100644 --- a/sentence_transformers/cross_encoder/losses/LambdaLoss.py +++ b/sentence_transformers/cross_encoder/losses/LambdaLoss.py @@ -351,7 +351,7 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: def citation(self) -> str: return """ @inproceedings{wang2018lambdaloss, - 
title={The lambdaloss framework for ranking metric optimization}, + title={The LambdaLoss Framework for Ranking Metric Optimization}, author={Wang, Xuanhui and Li, Cheng and Golbandi, Nadav and Bendersky, Michael and Najork, Marc}, booktitle={Proceedings of the 27th ACM international conference on information and knowledge management}, pages={1313--1322}, diff --git a/sentence_transformers/cross_encoder/losses/ListMLELoss.py b/sentence_transformers/cross_encoder/losses/ListMLELoss.py index 961a00db6..542dfe960 100644 --- a/sentence_transformers/cross_encoder/losses/ListMLELoss.py +++ b/sentence_transformers/cross_encoder/losses/ListMLELoss.py @@ -116,10 +116,10 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: def citation(self) -> str: return """ @inproceedings{10.1145/1390156.1390306, - title = {Listwise approach to learning to rank: theory and algorithm}, + title = {Listwise Approach to Learning to Rank - Theory and Algorithm}, author = {Xia, Fen and Liu, Tie-Yan and Wang, Jue and Zhang, Wensheng and Li, Hang}, booktitle = {Proceedings of the 25th International Conference on Machine Learning}, - pages = {1192–1199}, + pages = {1192-1199}, year = {2008}, url = {https://doi.org/10.1145/1390156.1390306}, } diff --git a/sentence_transformers/cross_encoder/losses/ListNetLoss.py b/sentence_transformers/cross_encoder/losses/ListNetLoss.py index 374cc7867..923d6bad5 100644 --- a/sentence_transformers/cross_encoder/losses/ListNetLoss.py +++ b/sentence_transformers/cross_encoder/losses/ListNetLoss.py @@ -188,7 +188,7 @@ def get_config_dict(self) -> dict[str, float]: def citation(self) -> str: return """ @inproceedings{cao2007learning, - title={Learning to rank: from pairwise approach to listwise approach}, + title={Learning to Rank: From Pairwise Approach to Listwise Approach}, author={Cao, Zhe and Qin, Tao and Liu, Tie-Yan and Tsai, Ming-Feng and Li, Hang}, booktitle={Proceedings of the 24th international conference on Machine learning}, pages={129--136}, diff --git a/sentence_transformers/cross_encoder/losses/PListMLELoss.py b/sentence_transformers/cross_encoder/losses/PListMLELoss.py index fb722c9bb..bfa69ef63 100644 --- a/sentence_transformers/cross_encoder/losses/PListMLELoss.py +++ b/sentence_transformers/cross_encoder/losses/PListMLELoss.py @@ -284,7 +284,7 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: def citation(self) -> str: return """ @inproceedings{lan2014position, - title={Position-Aware ListMLE: A Sequential Learning Process for Ranking.}, + title={Position-Aware ListMLE: A Sequential Learning Process for Ranking}, author={Lan, Yanyan and Zhu, Yadong and Guo, Jiafeng and Niu, Shuzi and Cheng, Xueqi}, booktitle={UAI}, volume={14}, diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py index 3a1cc7f5b..177ff2ed0 100644 --- a/sentence_transformers/cross_encoder/losses/RankNetLoss.py +++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py @@ -102,7 +102,7 @@ def get_config_dict(self) -> dict[str, float | int | str | None]: def citation(self) -> str: return """ @inproceedings{burges2005learning, - title={Learning to rank using gradient descent}, + title={Learning to Rank using Gradient Descent}, author={Burges, Chris and Shaked, Tal and Renshaw, Erin and Lazier, Ari and Deeds, Matt and Hamilton, Nicole and Hullender, Greg}, booktitle={Proceedings of the 22nd international conference on Machine learning}, pages={89--96}, From 
138164599160f186e43f20f1a886f96b32f693ae Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Mar 2025 14:07:39 +0100 Subject: [PATCH 08/11] Introduce reproducibility for the msmarco scripts --- .../cross_encoder/training/ms_marco/training_ms_marco_bce.py | 3 +++ .../training/ms_marco/training_ms_marco_bce_preprocessed.py | 3 +++ .../cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py | 3 +++ .../training/ms_marco/training_ms_marco_lambda.py | 3 +++ .../training/ms_marco/training_ms_marco_lambda_hard_neg.py | 3 +++ .../training/ms_marco/training_ms_marco_lambda_preprocessed.py | 3 +++ .../training/ms_marco/training_ms_marco_listmle.py | 3 +++ .../training/ms_marco/training_ms_marco_listnet.py | 3 +++ .../training/ms_marco/training_ms_marco_plistmle.py | 3 +++ .../training/ms_marco/training_ms_marco_ranknet.py | 3 +++ 10 files changed, 30 insertions(+) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py index 8be70e426..b01f768e1 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -21,6 +22,8 @@ def main(): num_epochs = 1 # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py index d87fbd67d..26c56d42a 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_bce_preprocessed.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset, load_from_disk from sentence_transformers.cross_encoder import CrossEncoder @@ -21,6 +22,8 @@ def main(): dataset_size = 2_000_000 # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py index 62bc75c92..d1a4a9985 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_cmnrl.py @@ -2,6 +2,7 @@ import traceback from collections import defaultdict +import torch from datasets import load_dataset from torch import nn @@ -25,6 +26,8 @@ def main(): num_epochs = 1 # 1. 
Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py index 3d13dede4..3712a1036 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda.py @@ -2,6 +2,7 @@ import traceback from datetime import datetime +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -32,6 +33,8 @@ def main(): dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py index c7af97ca3..8b7e93d6f 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_hard_neg.py @@ -2,6 +2,7 @@ import traceback from datetime import datetime +import torch from datasets import Dataset, concatenate_datasets, load_dataset from sentence_transformers import CrossEncoder, SentenceTransformer @@ -33,6 +34,8 @@ def main(): dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py index d2e728253..e2c7056e9 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_lambda_preprocessed.py @@ -2,6 +2,7 @@ import traceback from datetime import datetime +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -32,6 +33,8 @@ def main(): dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. 
Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py index ba579a2e2..7f6e1c7c9 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_listmle.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -30,6 +31,8 @@ def main(): respect_input_order = True # Whether to respect the original order of documents # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py index 570a00e7f..bc920142d 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_listnet.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -29,6 +30,8 @@ def main(): max_docs = None # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py index a93c3d3f2..dede60052 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_plistmle.py @@ -1,6 +1,7 @@ import logging import traceback +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -30,6 +31,8 @@ def main(): respect_input_order = True # Whether to respect the original order of documents # 1. Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) diff --git a/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py index d97098abf..8fed1dc47 100644 --- a/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py +++ b/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py @@ -4,6 +4,7 @@ import traceback from datetime import datetime +import torch from datasets import load_dataset from sentence_transformers.cross_encoder import CrossEncoder @@ -34,6 +35,8 @@ def main(): dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # 1. 
Define our CrossEncoder model + # Set the seed so the new classifier weights are identical in subsequent runs + torch.manual_seed(12) model = CrossEncoder(model_name, num_labels=1) print("Model max length:", model.max_length) print("Model num labels:", model.num_labels) From 55e7f49e197e1b7c1e3c83741c1b54a26086e488 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Mar 2025 16:44:01 +0100 Subject: [PATCH 09/11] Add more docs for RankNetLoss --- .../cross_encoder/losses/RankNetLoss.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py index 177ff2ed0..e32c14f50 100644 --- a/sentence_transformers/cross_encoder/losses/RankNetLoss.py +++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py @@ -52,6 +52,17 @@ def __init__( | (query, [doc1, doc2, ..., docN]) | [score1, score2, ..., scoreN] | 1 | +----------------------------------------+--------------------------------+-------------------------------+ + Recommendations: + - Use :class:`~sentence_transformers.util.mine_hard_negatives` with ``output_format="labeled-list"`` + to convert question-answer pairs to the required input format with hard negatives. + + Relations: + - :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` can be seen as an extension of this loss + where each score pair is weighted. Alternatively, this loss can be seen as a special case of the + :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` without a weighting scheme. + - :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` with its default NDCGLoss2++ weighting + scheme anecdotally performs better than the other losses with the same input format. + Example: :: From 9e0121571cbe083dc08e5c7590488d7bc5198ada Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Mar 2025 17:02:37 +0100 Subject: [PATCH 10/11] Add RankNet to Loss Overview & API Reference --- docs/cross_encoder/loss_overview.md | 24 +++++++++---------- .../package_reference/cross_encoder/losses.md | 5 ++++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/docs/cross_encoder/loss_overview.md b/docs/cross_encoder/loss_overview.md index e5258ba11..3e126e22f 100644 --- a/docs/cross_encoder/loss_overview.md +++ b/docs/cross_encoder/loss_overview.md @@ -17,15 +17,15 @@ Loss functions play a critical role in the performance of your fine-tuned Cross - ``(anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets`` with labels of 0 for negative and 1 for positive with ``output_format="labeled-list"``, ``` -| Inputs | Labels | Number of Model Output Labels | Appropriate Loss Functions | -|---------------------------------------------------|------------------------------------------|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `(sentence_A, sentence_B) pairs` | `class` | `num_classes` | `CrossEntropyLoss` | -| `(anchor, positive) pairs` | `none` | `1` | `MultipleNegativesRankingLoss`
`CachedMultipleNegativesRankingLoss` |
-| `(anchor, positive/negative) pairs`               | `1 if positive, 0 if negative`           | `1`                           | `BinaryCrossEntropyLoss` |
-| `(sentence_A, sentence_B) pairs`                  | `float similarity score between 0 and 1` | `1`                           | `BinaryCrossEntropyLoss` |
-| `(anchor, positive, negative) triplets`           | `none`                                   | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss` |
-| `(anchor, positive, negative_1, ..., negative_n)` | `none`                                   | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss` |
-| `(query, [doc1, doc2, ..., docN])`                | `[score1, score2, ..., scoreN]`          | `1`                           | `LambdaLoss`<br>`ListNetLoss`<br>`ListMLELoss`<br>`PListMLELoss` |
+| Inputs                                            | Labels                                   | Number of Model Output Labels | Appropriate Loss Functions |
+|----------------------------------------------------|------------------------------------------|-------------------------------|----------------------------------------------------------------------------------------------------|
+| `(sentence_A, sentence_B) pairs`                  | `class`                                  | `num_classes`                 | `CrossEntropyLoss` |
+| `(anchor, positive) pairs`                        | `none`                                   | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss` |
+| `(anchor, positive/negative) pairs`               | `1 if positive, 0 if negative`           | `1`                           | `BinaryCrossEntropyLoss` |
+| `(sentence_A, sentence_B) pairs`                  | `float similarity score between 0 and 1` | `1`                           | `BinaryCrossEntropyLoss` |
+| `(anchor, positive, negative) triplets`           | `none`                                   | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss` |
+| `(anchor, positive, negative_1, ..., negative_n)` | `none`                                   | `1`                           | `MultipleNegativesRankingLoss`<br>`CachedMultipleNegativesRankingLoss` |
+| `(query, [doc1, doc2, ..., docN])`                | `[score1, score2, ..., scoreN]`          | `1`                           | 1. `LambdaLoss`<br>2. `ListNetLoss`<br>3. `PListMLELoss`<br>4. `RankNetLoss`<br>5. `ListMLELoss` 
| ## Distillation These loss functions are specifically designed to be used when distilling the knowledge from one model into another. @@ -40,9 +40,9 @@ For example, when finetuning a small model to behave more like a larger & strong In practice, not all loss functions get used equally often. The most common scenarios are: * `(sentence_A, sentence_B) pairs` with `float similarity score` or `1 if positive, 0 if negative`: BinaryCrossEntropyLoss is a traditional option that remains very challenging to outperform. -* `(anchor, positive) pairs` without any labels: - * MultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train SentenceTransformer models, and the loss is also applicable for CrossEncoder models. This data is often relatively cheap to obtain, and mine_hard_negatives with output_format="n-tuple" or output_format="triplet" can easily be used to add hard negatives for this loss. CachedMultipleNegativesRankingLoss is often used to keep the memory usage in check. - * Together with mine_hard_negatives with output_format="labeled-list", LambdaLoss is frequently used for learning-to-rank tasks. +* `(anchor, positive) pairs` without any labels: combined with mine_hard_negatives + * with output_format="labeled-list", then LambdaLoss is frequently used for learning-to-rank tasks. + * with output_format="labeled-pair", then BinaryCrossEntropyLoss remains a strong option. ## Custom Loss Functions diff --git a/docs/package_reference/cross_encoder/losses.md b/docs/package_reference/cross_encoder/losses.md index 9c11024c1..3a5d008f1 100644 --- a/docs/package_reference/cross_encoder/losses.md +++ b/docs/package_reference/cross_encoder/losses.md @@ -61,3 +61,8 @@ Sadly, there is no "one size fits all" loss function. Which loss function is sui ```{eval-rst} .. autoclass:: sentence_transformers.cross_encoder.losses.MarginMSELoss ``` + +## RankNetLoss +```{eval-rst} +.. autoclass:: sentence_transformers.cross_encoder.losses.RankNetLoss +``` \ No newline at end of file From c347a333cefd2c0c5111c519042ce40bbd1f4292 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Mar 2025 17:11:09 +0100 Subject: [PATCH 11/11] Expand on RankNet docs slightly --- examples/cross_encoder/training/ms_marco/README.md | 6 +++++- sentence_transformers/cross_encoder/losses/RankNetLoss.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/cross_encoder/training/ms_marco/README.md b/examples/cross_encoder/training/ms_marco/README.md index 417cfba90..8f74237d7 100644 --- a/examples/cross_encoder/training/ms_marco/README.md +++ b/examples/cross_encoder/training/ms_marco/README.md @@ -56,8 +56,12 @@ In all scripts, the model is evaluated on subsets of `MS MARCO `PListMLELoss` > `ListNetLoss` > `ListMLELoss` out of all learning to rank losses, but your milage may vary. +Out of these training scripts, I suspect that **[training_ms_marco_lambda_preprocessed.py](training_ms_marco_lambda_preprocessed.py)**, **[training_ms_marco_lambda_hard_neg.py](training_ms_marco_lambda_hard_neg.py)** or **[training_ms_marco_bce_preprocessed.py](training_ms_marco_bce_preprocessed.py)** produces the strongest model, as anecdotally `LambdaLoss` and `BinaryCrossEntropyLoss` are quite strong. It seems that `LambdaLoss` > `ListNetLoss` > `PListMLELoss` > `RankNetLoss` > `ListMLELoss` out of all learning to rank losses, but your milage may vary. Additionally, you can also train with Distillation. 
See [Cross Encoder > Training Examples > Distillation](../distillation/README.md) for more details. diff --git a/sentence_transformers/cross_encoder/losses/RankNetLoss.py b/sentence_transformers/cross_encoder/losses/RankNetLoss.py index e32c14f50..fd24566ff 100644 --- a/sentence_transformers/cross_encoder/losses/RankNetLoss.py +++ b/sentence_transformers/cross_encoder/losses/RankNetLoss.py @@ -40,6 +40,7 @@ def __init__( References: - Learning to Rank using Gradient Descent: https://icml.cc/Conferences/2015/wp-content/uploads/2015/06/icml_ranking.pdf + - `Cross Encoder > Training Examples > MS MARCO <../../../examples/cross_encoder/training/ms_marco/README.html>`_ Requirements: 1. Query with multiple documents (pairwise approach)