Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Add Cohere Reranker #269

Merged
merged 7 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ module = [
'pinecone_text.*',
'pinecone_datasets',
'pinecone',
'transformers.*'
'transformers.*',
'cohere.*',
]
ignore_missing_imports = true

Expand Down
4 changes: 3 additions & 1 deletion src/canopy/knowledge_base/reranker/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .reranker import TransparentReranker, Reranker
from .reranker import Reranker
from .transparent import TransparentReranker
from .cohere import CohereReranker
69 changes: 69 additions & 0 deletions src/canopy/knowledge_base/reranker/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import List, Optional

from canopy.knowledge_base.models import KBQueryResult
from canopy.knowledge_base.reranker import Reranker

try:
import cohere
except (OSError, ImportError, ModuleNotFoundError):
_cohere_installed = False
else:
_cohere_installed = True


class CohereReranker(Reranker):
izellevy marked this conversation as resolved.
Show resolved Hide resolved
"""
Reranker that uses Cohere's rerank models to rerank documents.

For each query and documents returned for that query, returns a list
of documents ordered by their relevance to the provided query.
"""

def __init__(self,
             model_name: str = 'rerank-english-v2.0',
             *,
             top_n: int = 10,
             api_key: Optional[str] = None):
    """
    Create a reranker backed by a Cohere client.

    Args:
        model_name: The identifier of the model to use, one of:
                    rerank-english-v2.0, rerank-multilingual-v2.0
        top_n: The number of most relevant documents to return,
               defaults to 10
        api_key: API key for Cohere. If not passed `CO_API_KEY` environment
                 variable will be used.

    Raises:
        ImportError: If the optional `cohere` package could not be imported.
    """
    # Fail early, with installation instructions, when the optional
    # dependency is missing.
    if not _cohere_installed:
        message = (
            "Failed to import cohere. Make sure you install cohere extra "
            "dependencies by running: "
            "pip install canopy-sdk[cohere]"
        )
        raise ImportError(message)

    client = cohere.Client(api_key=api_key)
    self._client = client
    self._model_name = model_name
    self._top_n = top_n

def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
reranked_query_results: List[KBQueryResult] = []
for result in results:
texts = [doc.text for doc in result.documents]
response = self._client.rerank(query=result.query,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also needs to be wrapped in try clause.
Transient errors, such as rate limits, should be retried (if the Cohere client itself doesn't already do that for us).

Errors that are caused by wrong configuration (like a wrong model name or a bad API key) need to be re-raised with an actionable error message.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cohere retries internally. Since Cohere does not return distinct error types, it is hard to tell from the error what actually went wrong. For now I am raising a RuntimeError from the actual error.

documents=texts,
top_n=self._top_n,
model=self._model_name)
reranked_docs = []
for rerank_result in response:
doc = result.documents[rerank_result.index].copy(
deep=True,
update=dict(score=rerank_result.relevance_score)
acatav marked this conversation as resolved.
Show resolved Hide resolved
)
reranked_docs.append(doc)

reranked_query_results.append(KBQueryResult(query=result.query,
documents=reranked_docs))
return reranked_query_results

async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
    """
    Asynchronous variant of `rerank`; not implemented for this reranker.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
21 changes: 0 additions & 21 deletions src/canopy/knowledge_base/reranker/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,3 @@ def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
@abstractmethod
async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
pass


class TransparentReranker(Reranker):
    """
    No-op reranker: hands the results back untouched.

    This is the default reranker. It is used as a placeholder for future
    development, "forcing" every result set to be reranked.
    """

    def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """
        Return the results as is.

        Args:
            results: A list of KBQueryResult to rerank.

        Returns:
            The very same list that was passed in.
        """
        return results

    async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """
        Asynchronous counterpart of `rerank`; also returns the input as is.

        Args:
            results: A list of KBQueryResult to rerank.

        Returns:
            The very same list that was passed in.
        """
        return results
26 changes: 26 additions & 0 deletions src/canopy/knowledge_base/reranker/transparent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import List

from canopy.knowledge_base.models import KBQueryResult
from canopy.knowledge_base.reranker import Reranker


class TransparentReranker(Reranker):
    """
    Pass-through reranker that performs no reranking at all.

    It is the default reranker and serves as a placeholder for future
    development, "forcing" every result set to be reranked.
    """

    def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """
        Give back the query results unchanged.

        Args:
            results: A list of KBQueryResult to rerank.

        Returns:
            results: A list of KBQueryResult, same as the input.
        """
        return results

    async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
        """
        Give back the query results unchanged (async variant).

        Args:
            results: A list of KBQueryResult to rerank.

        Returns:
            results: A list of KBQueryResult, same as the input.
        """
        return results
Empty file.
62 changes: 62 additions & 0 deletions tests/system/reranker/test_cohere_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os

import pytest

from canopy.knowledge_base.models import KBQueryResult, KBDocChunkWithScore
from canopy.knowledge_base.reranker import CohereReranker


@pytest.fixture
def should_run_test():
    """Skip the requesting test when no `CO_API_KEY` is set in the environment."""
    if "CO_API_KEY" not in os.environ:
        pytest.skip("Couldn't find Cohere API key. Skipping Cohere tests.")


@pytest.fixture
def cohere_reranker(should_run_test):
    """Provide a CohereReranker built with default arguments."""
    # Requesting `should_run_test` first means the test is skipped before
    # a client is constructed when no API key is available.
    reranker = CohereReranker()
    return reranker


@pytest.fixture
def documents():
    """Build four sample chunks of document `doc_1` with scores 0.1 * index."""
    chunks = []
    for idx in range(4):
        chunk = KBDocChunkWithScore(
            id=f"doc_1_{idx}",
            text=f"Sample chunk {idx}",
            document_id="doc_1",
            source="doc_1",
            score=0.1 * idx,
        )
        chunks.append(chunk)
    return chunks


@pytest.fixture
def query_result(documents):
    """Wrap the sample documents in a KBQueryResult for a fixed query."""
    sample_result = KBQueryResult(query="Sample query 1", documents=documents)
    return sample_result


def test_rerank_empty(cohere_reranker):
    """Reranking an empty list of query results yields an empty list."""
    assert cohere_reranker.rerank([]) == []


def test_rerank(cohere_reranker, query_result, documents):
    """All documents come back, ordered by descending score."""
    reranked = cohere_reranker.rerank([query_result])
    first_result = next(iter(reranked))
    scores = [doc.score for doc in first_result.documents]

    assert len(first_result.documents) == len(documents)
    assert scores == sorted(scores, reverse=True)


def test_bad_api_key(should_run_test, query_result):
from cohere import CohereAPIError
with pytest.raises(CohereAPIError, match="invalid api token"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We try to eliminate the underlying service's errors, like CohereAPIError or OpenAIError, and replace them with an actionable error message (for example, something the user needs to change in the Canopy config, or the explicit env var to set).

In the future we will have our own error types like EncoderError, AuthenticationError etc. In the meantime, simply re-raise a RuntimeError for all of these cases (the CLI catches RuntimeErrors and prints them nicely).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the client: it does not return a specific error type for different failures, so we always get a CohereAPIError. For now I am raising a RuntimeError from that error; if they improve the client, we can write actionable messages.

CohereReranker(api_key="bad key").rerank([query_result])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing more negative tests - wrong model name, bad input (e.g. not strings) etc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for a wrong model name. Bad input is not possible, since we validate our data with pydantic.


def test_top_n(should_run_test, query_result):
    """With `top_n=1`, only a single document is returned for the query."""
    reranker = CohereReranker(top_n=1)
    first_result = reranker.rerank([query_result])[0]
    assert len(first_result.documents) == 1
27 changes: 27 additions & 0 deletions tests/system/reranker/test_transparent_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from canopy.knowledge_base.models import KBDocChunkWithScore, KBQueryResult
from canopy.knowledge_base.reranker import TransparentReranker


@pytest.fixture
def documents():
    """Build the sample chunk list; `range(1)` yields a single chunk."""
    chunks = []
    for idx in range(1):
        chunk = KBDocChunkWithScore(
            id=f"doc_1_{idx}",
            text=f"Sample chunk {idx}",
            document_id="doc_1",
            source="doc_1",
            score=0.1 * idx,
        )
        chunks.append(chunk)
    return chunks


@pytest.fixture
def query_result(documents):
    """Wrap the sample documents in a KBQueryResult for a fixed query."""
    sample_result = KBQueryResult(query="Sample query 1", documents=documents)
    return sample_result


def test_rerank(query_result):
    """The transparent reranker returns its input unchanged."""
    reranked = TransparentReranker().rerank([query_result])
    assert reranked == [query_result]
Loading