This repository has been archived by the owner on Nov 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 121
Add Cohere Reranker #269
Merged
Merged
Add Cohere Reranker #269
Changes from 4 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .reranker import TransparentReranker, Reranker | ||
from .reranker import Reranker | ||
from .transparent import TransparentReranker | ||
from .cohere import CohereReranker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
from typing import List, Optional | ||
|
||
|
||
from canopy.knowledge_base.models import KBQueryResult | ||
from canopy.knowledge_base.reranker import Reranker | ||
|
||
try: | ||
import cohere | ||
from cohere import CohereAPIError | ||
except (OSError, ImportError, ModuleNotFoundError): | ||
_cohere_installed = False | ||
else: | ||
_cohere_installed = True | ||
|
||
|
||
class CohereReranker(Reranker):
    """
    Reranker that uses Cohere's rerank endpoint to reorder documents.

    For each query and the documents returned for that query, returns a list
    of documents ordered by their relevance to the provided query.
    """

    def __init__(self,
                 model_name: str = 'rerank-english-v2.0',
                 *,
                 top_n: int = 10,
                 api_key: Optional[str] = None):
        """
        Initializes the Cohere reranker.

        Args:
            model_name: The identifier of the model to use, one of:
                        rerank-english-v2.0, rerank-multilingual-v2.0
            top_n: The number of most relevant documents to return,
                   defaults to 10
            api_key: API key for Cohere. If not passed, the `CO_API_KEY`
                     environment variable will be used.

        Raises:
            ImportError: If the `cohere` package could not be imported.
            RuntimeError: If no API key was passed and `CO_API_KEY` is unset
                          or empty.
        """
        if not _cohere_installed:
            raise ImportError(
                "Failed to import cohere. Make sure you install cohere extra "
                "dependencies by running: "
                "pip install canopy-sdk[cohere]"
            )
        cohere_api_key = api_key or os.environ.get("CO_API_KEY")
        # Reject an empty string as well as None: an empty CO_API_KEY would
        # otherwise be handed to the client as if it were a real key.
        if not cohere_api_key:
            raise RuntimeError(
                "Cohere API key is required to use Cohere Reranker. "
                "Please provide it as an argument "
                "or set the CO_API_KEY environment variable."
            )
        self._client = cohere.Client(api_key=cohere_api_key)
        self._model_name = model_name
        self._top_n = top_n

    def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """
        Reranks the documents of every query result using Cohere.

        Args:
            results: A list of KBQueryResult to rerank.

        Returns:
            A new list with one KBQueryResult per input result. Each holds
            deep copies of the documents Cohere returned (at most `top_n`),
            in the order Cohere returned them, with `score` replaced by
            Cohere's relevance score.

        Raises:
            RuntimeError: If the Cohere API call fails. The original
                          CohereAPIError is attached as `__cause__`.
        """
        reranked_query_results: List[KBQueryResult] = []
        for result in results:
            if not result.documents:
                # Nothing to order, so skip the remote call entirely.
                reranked_query_results.append(
                    KBQueryResult(query=result.query, documents=[])
                )
                continue

            texts = [doc.text for doc in result.documents]
            try:
                response = self._client.rerank(query=result.query,
                                               documents=texts,
                                               top_n=self._top_n,
                                               model=self._model_name)
            except CohereAPIError as e:
                # Chain the original error so its traceback is not lost.
                raise RuntimeError("Failed to rerank documents using Cohere."
                                   f" Underlying Error:\n{e.message}") from e

            # `index` refers to the position in the `texts` list sent above,
            # which matches the position in `result.documents`.
            reranked_docs = [
                result.documents[rerank_result.index].copy(
                    deep=True,
                    update=dict(score=rerank_result.relevance_score)
                )
                for rerank_result in response
            ]

            reranked_query_results.append(KBQueryResult(query=result.query,
                                                        documents=reranked_docs))
        return reranked_query_results

    async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """Async reranking is not implemented for the Cohere reranker."""
        raise NotImplementedError()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import List | ||
|
||
from canopy.knowledge_base.models import KBQueryResult | ||
from canopy.knowledge_base.reranker import Reranker | ||
|
||
|
||
class TransparentReranker(Reranker):
    """
    A pass-through reranker: it hands back exactly the results it was given.

    This is the default reranker. It exists as a placeholder for future
    development, "forcing" every result set to go through a reranking step.
    """

    def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """
        Hands the input back untouched.

        Args:
            results: The list of KBQueryResult to "rerank".

        Returns:
            The very same list object that was passed in.
        """
        return results

    async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """Async counterpart of `rerank`; also returns the input untouched."""
        return results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import os | ||
|
||
# True only when the CE_DEBUG_INFO environment variable is set to "true"
# (any letter case); an unset variable falls back to "FALSE", i.e. False.
CE_DEBUG_INFO = os.environ.get("CE_DEBUG_INFO", "FALSE").lower() == "true"
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBQueryResult, KBDocChunkWithScore | ||
from canopy.knowledge_base.reranker import CohereReranker | ||
|
||
|
||
@pytest.fixture
def should_run_test():
    """Skip the requesting test when the CO_API_KEY env variable is not set."""
    api_key = os.getenv("CO_API_KEY")
    if api_key is None:
        pytest.skip("Couldn't find Cohere API key. Skipping Cohere tests.")
|
||
|
||
@pytest.fixture
def cohere_reranker(should_run_test):
    """A CohereReranker built with default arguments (skipped without a key)."""
    reranker = CohereReranker()
    return reranker
|
||
|
||
@pytest.fixture
def documents():
    """Four scored chunks of document "doc_1" with scores 0.0, 0.1, 0.2, 0.3."""
    chunks = []
    for idx in range(4):
        chunk = KBDocChunkWithScore(
            id=f"doc_1_{idx}",
            text=f"Sample chunk {idx}",
            document_id="doc_1",
            source="doc_1",
            score=0.1 * idx,
        )
        chunks.append(chunk)
    return chunks
|
||
|
||
@pytest.fixture
def query_result(documents):
    """A single KBQueryResult wrapping the `documents` fixture."""
    sample = KBQueryResult(query="Sample query 1", documents=documents)
    return sample
|
||
|
||
def test_rerank_empty(cohere_reranker):
    """Reranking an empty list of query results yields an empty list."""
    assert cohere_reranker.rerank([]) == []
|
||
|
||
def test_rerank(cohere_reranker, query_result, documents):
    """Reranked documents keep their count, are sorted, and get new scores."""
    original_scores = {}
    for original_doc in query_result.documents:
        original_scores[original_doc.id] = original_doc.score

    reranked = cohere_reranker.rerank([query_result])
    ranked_result = reranked[0]
    new_scores = [ranked_doc.score for ranked_doc in ranked_result.documents]

    # The default top_n (10) exceeds the fixture size, so nothing is dropped.
    assert len(ranked_result.documents) == len(documents)

    # Documents must come back in descending score order.
    assert new_scores == sorted(new_scores, reverse=True)

    # Make sure the scores are overridden by the reranker
    for ranked_doc in ranked_result.documents:
        assert ranked_doc.score != original_scores[ranked_doc.id]
|
||
|
||
def test_bad_api_key(should_run_test, query_result):
    """An invalid API key surfaces as a RuntimeError mentioning the token."""
    batch = [query_result]
    with pytest.raises(RuntimeError, match="invalid api token"):
        reranker = CohereReranker(api_key="bad key")
        reranker.rerank(batch)
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing more negative tests - wrong model name, bad input (e.g. not strings), etc. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added a wrong-model-name test; bad input is not possible since we validate our data with pydantic. |
||
|
||
def test_model_name_invalid(should_run_test, query_result):
    """An unknown model name surfaces as a RuntimeError ("model not found")."""
    batch = [query_result]
    with pytest.raises(RuntimeError, match="model not found"):
        reranker = CohereReranker(model_name="my-madeup-model")
        reranker.rerank(batch)
|
||
|
||
def test_top_n(should_run_test, query_result):
    """With top_n=1 only a single document is returned for the query."""
    reranker = CohereReranker(top_n=1)
    results = reranker.rerank([query_result])
    first_result = results[0]
    assert len(first_result.documents) == 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBDocChunkWithScore, KBQueryResult | ||
from canopy.knowledge_base.reranker import TransparentReranker | ||
|
||
|
||
@pytest.fixture
def documents():
    """A single scored chunk of document "doc_1" (score 0.0)."""
    chunks = []
    for idx in range(1):
        chunk = KBDocChunkWithScore(
            id=f"doc_1_{idx}",
            text=f"Sample chunk {idx}",
            document_id="doc_1",
            source="doc_1",
            score=0.1 * idx,
        )
        chunks.append(chunk)
    return chunks
|
||
|
||
@pytest.fixture
def query_result(documents):
    """A single KBQueryResult wrapping the `documents` fixture."""
    sample = KBQueryResult(query="Sample query 1", documents=documents)
    return sample
|
||
|
||
def test_rerank(query_result):
    """The transparent reranker returns its input unchanged."""
    batch = [query_result]
    reranked = TransparentReranker().rerank(batch)
    assert reranked == [query_result]
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realised that we want debug info to contain only dicts with literals like str or int. This allows easier serialisation of this object.