@@ -0,0 +1,135 @@
"""
This example demonstrates how to use the CrossEncoder with instruction-tuned models like Qwen-reranker or BGE-reranker.
The new `prompt_template` and `prompt_template_kwargs` arguments in the `predict` and `rank` methods allow for
flexible and dynamic formatting of the input for such models.

This script covers three main scenarios:
1. Ranking without any template (baseline).
2. Ranking with a `prompt_template` provided at runtime.
3. Ranking with a dynamic `instruction` passed via `prompt_template_kwargs`.

Finally, it provides a guide on how to set a default prompt template in the model's `config.json`.
"""

from sentence_transformers.cross_encoder import CrossEncoder

# We use a Qwen Reranker model here. In a real-world scenario, this could also be
# an instruction-tuned model like 'BAAI/bge-reranker-large'.
model = CrossEncoder("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", trust_remote_code=True)
model.model.config.pad_token_id = model.tokenizer.pad_token_id

query = "What is the capital of China?"
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

# First, we create the sentence pairs for the query and all documents
sentence_pairs = [[query, doc] for doc in documents]

print("--- 1. Reranking without any template (Incorrect Usage of Qwen3 Reranker) ---")
# The model receives the plain query and document pairs.
baseline_scores = model.predict(sentence_pairs, convert_to_numpy=True)
scored_docs = sorted(zip(baseline_scores, documents), key=lambda x: x[0], reverse=True)

print("Query:", query)
for score, doc in scored_docs:
    print(f"{score:.4f}\t{doc}")

print("\n\n--- 2. Reranking with a runtime prompt_template ---")
# The query and document are formatted using the template before being passed to the model.
# This changes the input text and thus the resulting scores.
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
instruction = "Given a web search query, retrieve relevant passages that answer the query"
query_template = f"{prefix}<Instruct>: {instruction}\n<Query>: {{query}}\n"
document_template = f"<Document>: {{document}}{suffix}"

template = query_template + document_template
template_scores = model.predict(sentence_pairs, prompt_template=template)
scored_docs_template = sorted(zip(template_scores, documents), key=lambda x: x[0], reverse=True)

print("Using template:", template)
print("Query:", query)
for score, doc in scored_docs_template:
    print(f"{score:.4f}\t{doc}")
# The scores will be different from the baseline because the model processes a different text.

print("\n\n--- 3. Reranking with a dynamic instruction ---")
# This is useful for models that expect a specific instruction.
# The instruction can be changed at runtime.
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
instruct_template = f"{prefix}<Instruct>: {{instruction}}\n<Query>: {{query}}\n<Document>: {{document}}{suffix}"
instruct_kwargs_1 = {"instruction": "Given a query, find the most relevant document."}
instruct_kwargs_2 = {"instruction": "Given a question, find the incorrect answer."} # Misleading instruction

print(f"Using template: {instruct_template}")
print(f"With instruction 1: '{instruct_kwargs_1['instruction']}'")
instruction_scores_1 = model.predict(
    sentence_pairs, prompt_template=instruct_template, prompt_template_kwargs=instruct_kwargs_1
)
scored_docs_instruct_1 = sorted(zip(instruction_scores_1, documents), key=lambda x: x[0], reverse=True)
for score, doc in scored_docs_instruct_1:
    print(f"{score:.4f}\t{doc}")

print(f"\nWith instruction 2: '{instruct_kwargs_2['instruction']}'")
instruction_scores_2 = model.predict(
    sentence_pairs, prompt_template=instruct_template, prompt_template_kwargs=instruct_kwargs_2
)
scored_docs_instruct_2 = sorted(zip(instruction_scores_2, documents), key=lambda x: x[0], reverse=True)
for score, doc in scored_docs_instruct_2:
    print(f"{score:.4f}\t{doc}")
# The scores for instruction 1 and 2 will likely differ, as the instruction text changes the input.
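
# The same `prompt_template` and `prompt_template_kwargs` arguments are also accepted by `rank`,
# which scores and sorts the documents for a query in a single call. A minimal sketch reusing the
# objects defined above; with `return_documents=True`, each result also contains the document text.
print("\n\n--- Bonus: Reranking with `rank` and a prompt template ---")
ranked = model.rank(
    query,
    documents,
    prompt_template=instruct_template,
    prompt_template_kwargs=instruct_kwargs_1,
    return_documents=True,
)
for entry in ranked:
    print(f"{entry['score']:.4f}\t{entry['text']}")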

# --- 4. Guide: Setting a Default Prompt Template in config.json ---
#
# If you are a model creator or want to use a specific prompt format consistently
# without passing it in every `rank` or `predict` call, you can set a default
# template in the model's `config.json` file.
#
# Step 1: Save your base model to a directory.
#
# from sentence_transformers import CrossEncoder
# import json
#
# model = CrossEncoder("your-base-model-name")
# save_path = "path/to/your-instruct-model"
# model.save(save_path)
#
# Step 2: Modify the `config.json` in the saved directory.
# Add the "prompt_template" and "prompt_template_kwargs" keys to the
# "sentence_transformers" dictionary.
#
# // path/to/your-instruct-model/config.json
# {
#     ...
#     "sentence_transformers": {
#         "version": "3.0.0.dev0",
#         "prompt_template": "Instruct: {instruction}\nQuery: {query}\nDocument: {document}",
#         "prompt_template_kwargs": {
#             "instruction": "Given a query, find the most relevant document."
#         }
#     },
#     ...
# }
#
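# Step 2 can also be done programmatically. This is a minimal sketch that reuses `save_path`
# and the `json` import from Step 1; `setdefault` covers the case where the saved config does
# not yet contain a "sentence_transformers" entry.
#
# config_path = f"{save_path}/config.json"
# with open(config_path) as f:
#     config = json.load(f)
# config.setdefault("sentence_transformers", {})
# config["sentence_transformers"]["prompt_template"] = "Instruct: {instruction}\nQuery: {query}\nDocument: {document}"
# config["sentence_transformers"]["prompt_template_kwargs"] = {
#     "instruction": "Given a query, find the most relevant document."
# }
# with open(config_path, "w") as f:
#     json.dump(config, f, indent=2)
#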
# Step 3: Load the model from the modified path.
# It will now use the default template automatically.
#
# instruct_model = CrossEncoder(save_path, trust_remote_code=True)
# sentence_pairs = [[query, doc] for doc in documents]
# scores = instruct_model.predict(sentence_pairs)
#
# # This call is now equivalent to calling the original model with the full template arguments:
# # original_model.predict(sentence_pairs,
# # prompt_template="Instruct: {instruction}\nQuery: {query}\nDocument: {document}",
# # prompt_template_kwargs={"instruction": "Given a query, find the most relevant document."})
#
# You can still override the default template by passing arguments at runtime:
#
# # This will use the new instruction, overriding the default one.
# scores_new_instruction = instruct_model.predict(
#     sentence_pairs,
#     prompt_template_kwargs={"instruction": "Find the answer to the question."}
# )
53 changes: 52 additions & 1 deletion sentence_transformers/cross_encoder/CrossEncoder.py
@@ -170,6 +170,12 @@ def __init__(
        if config.architectures is not None:
            classifier_trained = any([arch.endswith("ForSequenceClassification") for arch in config.architectures])

        self.default_prompt_template: str | None = None
        self.default_prompt_template_kwargs: dict[str, Any] | None = None
        if hasattr(config, "sentence_transformers"):
            self.default_prompt_template = config.sentence_transformers.get("prompt_template")
            self.default_prompt_template_kwargs = config.sentence_transformers.get("prompt_template_kwargs")

        if num_labels is None and not classifier_trained:
            num_labels = 1

@@ -538,6 +544,8 @@ def predict(
        apply_softmax: bool | None = ...,
        convert_to_numpy: Literal[False] = ...,
        convert_to_tensor: Literal[False] = ...,
        prompt_template: str | None = None,
        prompt_template_kwargs: dict[str, Any] | None = None,
    ) -> torch.Tensor: ...

    @overload
@@ -550,6 +558,8 @@ def predict(
        apply_softmax: bool | None = ...,
        convert_to_numpy: Literal[True] = True,
        convert_to_tensor: Literal[False] = False,
        prompt_template: str | None = None,
        prompt_template_kwargs: dict[str, Any] | None = None,
    ) -> np.ndarray: ...

    @overload
@@ -562,6 +572,8 @@ def predict(
        apply_softmax: bool | None = ...,
        convert_to_numpy: bool = ...,
        convert_to_tensor: Literal[True] = ...,
        prompt_template: str | None = None,
        prompt_template_kwargs: dict[str, Any] | None = None,
    ) -> torch.Tensor: ...

    @overload
@@ -574,6 +586,8 @@ def predict(
        apply_softmax: bool | None = ...,
        convert_to_numpy: Literal[False] = ...,
        convert_to_tensor: Literal[False] = ...,
        prompt_template: str | None = None,
        prompt_template_kwargs: dict[str, Any] | None = None,
    ) -> list[torch.Tensor]: ...

    @torch.inference_mode()
@@ -587,6 +601,8 @@ def predict(
        apply_softmax: bool | None = False,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        prompt_template: str | None = None,
        prompt_template_kwargs: dict[str, Any] | None = None,
    ) -> list[torch.Tensor] | np.ndarray | torch.Tensor:
"""
Performs predictions with the CrossEncoder on the given sentence pairs.
Expand All @@ -599,13 +615,17 @@ def predict(
            activation_fn (callable, optional): Activation function applied on the logits output of the CrossEncoder.
                If None, the ``model.activation_fn`` will be used, which defaults to :class:`torch.nn.Sigmoid` if num_labels=1, else
                :class:`torch.nn.Identity`. Defaults to None.
            convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
            apply_softmax (bool, optional): If set to True and `model.num_labels > 1`, applies softmax on the logits
                output such that for each sample, the scores of each class sum to 1. Defaults to False.
            convert_to_numpy (bool, optional): Whether the output should be a list of numpy vectors. If False, output
                a list of PyTorch tensors. Defaults to True.
            convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
                Defaults to False.
            prompt_template (str, optional): A template to format the input sentence pairs. The template should have placeholders
                for `{query}` and `{document}`. For example: "Query: {query} Document: {document}".
            prompt_template_kwargs (dict[str, Any], optional): A dictionary of keyword arguments to format the prompt template.
                For example, you can provide an instruction: `{"instruction": "Determine the relevance."}` for a template like
                "Instruct: {instruction} Query: {query} Document: {document}".

        Returns:
            Union[List[torch.Tensor], np.ndarray, torch.Tensor]: Predictions for the passed sentence pairs.
@@ -634,6 +654,28 @@ def predict(
                logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
            )

        # If prompt_template is provided or default_prompt_template in config.json is set, use it to format the input sentence pairs
        final_prompt_template = prompt_template if prompt_template is not None else self.default_prompt_template
        if final_prompt_template:
            final_kwargs = self.default_prompt_template_kwargs.copy() if self.default_prompt_template_kwargs else {}
            if prompt_template_kwargs:
                # Update final_kwargs with any additional keyword arguments provided in prompt_template_kwargs
                final_kwargs.update(prompt_template_kwargs)

            formatted_sentences = []
            try:
                for query, doc in sentences:
                    all_kwargs = {"query": query, "document": doc, **final_kwargs}
                    formatted_sentences.append(final_prompt_template.format(**all_kwargs))
            except KeyError as e:
                # If a placeholder in the prompt template is not valid, raise an error
                available_keys = ["query", "document"] + list(final_kwargs.keys())
                raise KeyError(
                    f"A placeholder in the prompt template is not valid. The placeholder {e} was not found. "
                    f"Available placeholders are: {', '.join(sorted(list(set(available_keys))))}."
                ) from e
            sentences = formatted_sentences

        if activation_fn is not None:
            self.set_activation_fn(activation_fn, set_default=False)

@@ -681,6 +723,8 @@ def rank(
        apply_softmax=False,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        prompt_template: str | None = None,
        prompt_template_kwargs: dict[str, Any] | None = None,
    ) -> list[dict[Literal["corpus_id", "score", "text"], int | float | str]]:
"""
Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
@@ -696,6 +740,11 @@ def rank(
            convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
            apply_softmax (bool, optional): If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output. Defaults to False.
            convert_to_tensor (bool, optional): Convert the output to a tensor. Defaults to False.
            prompt_template (str, optional): A template to format the input sentence pairs. The template should have placeholders
                for `{query}` and `{document}`. For example: "Query: {query} Document: {document}".
            prompt_template_kwargs (dict[str, Any], optional): A dictionary of keyword arguments to format the prompt template.
                For example, you can provide an instruction: `{"instruction": "Determine the relevance."}` for a template like
                "Instruct: {instruction} Query: {query} Document: {document}".

        Returns:
            List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]: A sorted list with the "corpus_id", "score", and optionally "text" of the documents.
@@ -750,6 +799,8 @@ def rank(
            apply_softmax=apply_softmax,
            convert_to_numpy=convert_to_numpy,
            convert_to_tensor=convert_to_tensor,
            prompt_template=prompt_template,
            prompt_template_kwargs=prompt_template_kwargs,
        )

        results = []
75 changes: 75 additions & 0 deletions tests/cross_encoder/test_cross_encoder.py
@@ -132,6 +132,81 @@ def test_predict_softmax():
    assert not torch.isclose(scores.sum(1), torch.ones(len(corpus), device=scores.device)).all()


def test_predict_with_prompt_template():
    model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce")
    query = "A man is eating pasta."
    corpus = [
        "A man is eating food.",
        "A woman is playing violin.",
    ]
    pairs = [[query, doc] for doc in corpus]
    prompt_template = (
        "Instruct: Given a query and a document, determine if they are relevant. Query: {query} Document: {document}"
    )

    # 1. Test with prompt_template in predict
    scores_prompt = model.predict(pairs, prompt_template=prompt_template)

    # 2. Test without prompt_template
    scores_no_prompt = model.predict(pairs)

    # The scores should be different as the input to the model is different
    assert not np.allclose(scores_prompt, scores_no_prompt)

    # 3. Test with prompt_template in rank
    ranks_prompt = model.rank(query, corpus, prompt_template=prompt_template)
    ranks_no_prompt = model.rank(query, corpus)

    assert ranks_prompt[0]["score"] != ranks_no_prompt[0]["score"]
    assert ranks_prompt[1]["score"] != ranks_no_prompt[1]["score"]


def test_predict_with_default_prompt_template(tmp_path: Path):
    # 1. Create a base model and save it
    model_name = "cross-encoder-testing/reranker-bert-tiny-gooaq-bce"
    original_model = CrossEncoder(model_name)
    save_path = tmp_path / "model_with_template"
    original_model.save(str(save_path))

    # 2. Modify the config.json to add a default prompt template
    config_path = save_path / "config.json"
    with open(config_path) as f:
        config = json.load(f)

    prompt_template = "Instruct: {instruction} Query: {query} Document: {document}"
    prompt_kwargs = {"instruction": "Determine relevance."}
    if "sentence_transformers" not in config:
        config["sentence_transformers"] = {}
    config["sentence_transformers"]["prompt_template"] = prompt_template
    config["sentence_transformers"]["prompt_template_kwargs"] = prompt_kwargs

    with open(config_path, "w") as f:
        json.dump(config, f)

    # 3. Load the model with the modified config
    model_with_template = CrossEncoder(str(save_path))
    assert model_with_template.default_prompt_template == prompt_template
    assert model_with_template.default_prompt_template_kwargs == prompt_kwargs

    # 4. Perform prediction and compare results
    query = "A man is eating pasta."
    doc = "A man is eating food."

    # Prediction with the model that has a default template
    scores_with_default_template = model_with_template.predict([[query, doc]])

    # Prediction with the original model (no template)
    scores_original = original_model.predict([[query, doc]])

    # The scores should be different because one uses the template and the other doesn't.
    assert not np.allclose(scores_with_default_template, scores_original)

    # 5. Test that runtime arguments can overwrite the default template
    runtime_template = "Query: {query} and Doc: {document}"
    scores_runtime_template = model_with_template.predict([[query, doc]], prompt_template=runtime_template)
    assert not np.allclose(scores_with_default_template, scores_runtime_template)


@pytest.mark.parametrize(
"model_name", ["cross-encoder-testing/reranker-bert-tiny-gooaq-bce", "cross-encoder/nli-MiniLM2-L6-H768"]
)