This repository has been archived by the owner on Nov 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add cohere reranker * Add tests * Fix static * Fix comments * Add assert * Fix dict * Change name
- Loading branch information
Showing
12 changed files
with
245 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .reranker import TransparentReranker, Reranker | ||
from .reranker import Reranker | ||
from .transparent import TransparentReranker | ||
from .cohere import CohereReranker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
from typing import List, Optional | ||
|
||
|
||
from canopy.knowledge_base.models import KBQueryResult | ||
from canopy.knowledge_base.reranker import Reranker | ||
|
||
try: | ||
import cohere | ||
from cohere import CohereAPIError | ||
except (OSError, ImportError, ModuleNotFoundError): | ||
_cohere_installed = False | ||
else: | ||
_cohere_installed = True | ||
|
||
|
||
class CohereReranker(Reranker): | ||
""" | ||
Reranker that uses Cohere's text embedding to rerank documents. | ||
For each query and documents returned for that query, returns a list | ||
of documents ordered by their relevance to the provided query. | ||
""" | ||
|
||
def __init__(self, | ||
model_name: str = 'rerank-english-v2.0', | ||
*, | ||
top_n: int = 10, | ||
api_key: Optional[str] = None): | ||
""" | ||
Initializes the Cohere reranker. | ||
Args: | ||
model_name: The identifier of the model to use, one of : | ||
rerank-english-v2.0, rerank-multilingual-v2.0 | ||
top_n: The number of most relevant documents return, defaults to 10 | ||
api_key: API key for Cohere. If not passed `CO_API_KEY` environment | ||
variable will be used. | ||
""" | ||
|
||
if not _cohere_installed: | ||
raise ImportError( | ||
"Failed to import cohere. Make sure you install cohere extra " | ||
"dependencies by running: " | ||
"pip install canopy-sdk[cohere]" | ||
) | ||
cohere_api_key = api_key or os.environ.get("CO_API_KEY") | ||
if cohere_api_key is None: | ||
raise RuntimeError( | ||
"Cohere API key is required to use Cohere Reranker. " | ||
"Please provide it as an argument " | ||
"or set the CO_API_KEY environment variable." | ||
) | ||
self._client = cohere.Client(api_key=cohere_api_key) | ||
self._model_name = model_name | ||
self._top_n = top_n | ||
|
||
def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]: | ||
reranked_query_results: List[KBQueryResult] = [] | ||
for result in results: | ||
texts = [doc.text for doc in result.documents] | ||
try: | ||
response = self._client.rerank(query=result.query, | ||
documents=texts, | ||
top_n=self._top_n, | ||
model=self._model_name) | ||
except CohereAPIError as e: | ||
raise RuntimeError("Failed to rerank documents using Cohere." | ||
f" Underlying Error:\n{e.message}") | ||
|
||
reranked_docs = [] | ||
for rerank_result in response: | ||
doc = result.documents[rerank_result.index].copy( | ||
deep=True, | ||
update=dict(score=rerank_result.relevance_score) | ||
) | ||
reranked_docs.append(doc) | ||
|
||
reranked_query_results.append(KBQueryResult(query=result.query, | ||
documents=reranked_docs)) | ||
return reranked_query_results | ||
|
||
async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]: | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import List | ||
|
||
from canopy.knowledge_base.models import KBQueryResult | ||
from canopy.knowledge_base.reranker import Reranker | ||
|
||
|
||
class TransparentReranker(Reranker): | ||
""" | ||
Transparent reranker that does nothing, it just returns the results as is. This is the default reranker. | ||
The TransparentReranker is used as a placeholder for future development "forcing" every result set to be reranked. | ||
""" # noqa: E501 | ||
|
||
def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]: | ||
""" | ||
Returns the results as is. | ||
Args: | ||
results: A list of KBQueryResult to rerank. | ||
Returns: | ||
results: A list of KBQueryResult, same as the input. | ||
""" # noqa: E501 | ||
return results | ||
|
||
async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]: | ||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import os | ||
|
||
CANOPY_DEBUG_INFO = os.getenv("CANOPY_DEBUG_INFO", "FALSE").lower() == "true" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBQueryResult, KBDocChunkWithScore | ||
from canopy.knowledge_base.reranker import CohereReranker | ||
|
||
|
||
@pytest.fixture | ||
def should_run_test(): | ||
if os.getenv("CO_API_KEY") is None: | ||
pytest.skip( | ||
"Couldn't find Cohere API key. Skipping Cohere tests." | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def cohere_reranker(should_run_test): | ||
return CohereReranker() | ||
|
||
|
||
@pytest.fixture | ||
def documents(): | ||
return [ | ||
KBDocChunkWithScore( | ||
id=f"doc_1_{i}", | ||
text=f"Sample chunk {i}", | ||
document_id="doc_1", | ||
source="doc_1", | ||
score=0.1 * i | ||
) for i in range(4) | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def query_result(documents): | ||
return KBQueryResult(query="Sample query 1", | ||
documents=documents) | ||
|
||
|
||
def test_rerank_empty(cohere_reranker): | ||
results = cohere_reranker.rerank([]) | ||
assert results == [] | ||
|
||
|
||
def test_rerank(cohere_reranker, query_result, documents): | ||
id_to_score = {d.id: d.score for d in query_result.documents} | ||
ranked_result = next(iter(cohere_reranker.rerank([query_result]))) | ||
reranked_scores = [doc.score for doc in ranked_result.documents] | ||
|
||
assert len(ranked_result.documents) == len(documents) | ||
assert reranked_scores == sorted(reranked_scores, reverse=True) | ||
|
||
# Make sure the scores are overriden by the reranker | ||
for doc in ranked_result.documents: | ||
assert doc.score != id_to_score[doc.id] | ||
|
||
|
||
def test_bad_api_key(should_run_test, query_result): | ||
with pytest.raises(RuntimeError, match="invalid api token"): | ||
CohereReranker(api_key="bad key").rerank([query_result]) | ||
|
||
|
||
def test_model_name_invalid(should_run_test, query_result): | ||
with pytest.raises(RuntimeError, match="model not found"): | ||
CohereReranker(model_name="my-madeup-model").rerank([query_result]) | ||
|
||
|
||
def test_top_n(should_run_test, query_result): | ||
results = CohereReranker(top_n=1).rerank([query_result]) | ||
assert len(results[0].documents) == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBDocChunkWithScore, KBQueryResult | ||
from canopy.knowledge_base.reranker import TransparentReranker | ||
|
||
|
||
@pytest.fixture | ||
def documents(): | ||
return [ | ||
KBDocChunkWithScore( | ||
id=f"doc_1_{i}", | ||
text=f"Sample chunk {i}", | ||
document_id="doc_1", | ||
source="doc_1", | ||
score=0.1 * i | ||
) for i in range(1) | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def query_result(documents): | ||
return KBQueryResult(query="Sample query 1", | ||
documents=documents) | ||
|
||
|
||
def test_rerank(query_result): | ||
assert TransparentReranker().rerank([query_result]) == [query_result] |