Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Add Cohere Reranker #269

Merged
merged 7 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ module = [
'pinecone_text.*',
'pinecone_datasets',
'pinecone',
'transformers.*'
'transformers.*',
'cohere.*',
]
ignore_missing_imports = true

Expand Down
5 changes: 1 addition & 4 deletions src/canopy/chat_engine/chat_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from abc import ABC, abstractmethod
from typing import Iterable, Union, Optional, cast

Expand All @@ -13,9 +12,7 @@
StreamingChatResponse, )
from canopy.models.data_models import Context, Messages, SystemMessage
from canopy.utils.config import ConfigurableMixin

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"

from canopy.utils.debugging import CE_DEBUG_INFO

DEFAULT_SYSTEM_PROMPT = """Use the following pieces of context to answer the user question at the next messages. This context retrieved from a knowledge database and you should use only the facts from the context to answer. Always remember to include the source to the documents you used from their 'source' field in the format 'Source: $SOURCE_HERE'.
If you don't know the answer, just say that you don't know, don't try to make up an answer, use the context.
Expand Down
8 changes: 4 additions & 4 deletions src/canopy/context_engine/context_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from abc import ABC, abstractmethod
from typing import List, Optional

Expand All @@ -8,8 +7,7 @@
from canopy.knowledge_base.base import BaseKnowledgeBase
from canopy.models.data_models import Context, Query
from canopy.utils.config import ConfigurableMixin

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"
from canopy.utils.debugging import CE_DEBUG_INFO


class BaseContextEngine(ABC, ConfigurableMixin):
Expand Down Expand Up @@ -111,7 +109,9 @@ def query(self, queries: List[Query],
context = self.context_builder.build(query_results, max_context_tokens)

if CE_DEBUG_INFO:
context.debug_info["query_results"] = [qr.dict() for qr in query_results]
context.debug_info["query_results"] = [
{**qr.dict(), **qr.debug_info} for qr in query_results
]
return context

async def aquery(self, queries: List[Query], max_context_tokens: int,
Expand Down
25 changes: 19 additions & 6 deletions src/canopy/knowledge_base/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pinecone import (ServerlessSpec, PodSpec,
Pinecone, PineconeApiException)

from canopy.utils.debugging import CE_DEBUG_INFO

try:
from pinecone import GRPCIndex as Index
except ImportError:
Expand Down Expand Up @@ -437,20 +439,31 @@ def query(self,
results = [self._query_index(q,
global_metadata_filter,
namespace) for q in queries]
results = self._reranker.rerank(results)
ranked_results = self._reranker.rerank(results)

return [
QueryResult(
query=r.query,
query=rr.query,
documents=[
DocumentWithScore(
**d.dict(exclude={
'values', 'sparse_values', 'document_id'
'document_id'
})
)
for d in r.documents
]
) for r in results
for d in rr.documents
],
debug_info={"db_result": QueryResult(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just relised that we want debug info to be only dicts with literals like str or int. This allows easier serialisation of this object

query=r.query,
documents=[
DocumentWithScore(
**d.dict(exclude={
'document_id'
})
)
for d in r.documents
]
)} if CE_DEBUG_INFO else {}
) for rr, r in zip(ranked_results, results)
izellevy marked this conversation as resolved.
Show resolved Hide resolved
]

def _query_index(self,
Expand Down
4 changes: 3 additions & 1 deletion src/canopy/knowledge_base/reranker/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .reranker import TransparentReranker, Reranker
from .reranker import Reranker
from .transparent import TransparentReranker
from .cohere import CohereReranker
84 changes: 84 additions & 0 deletions src/canopy/knowledge_base/reranker/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
from typing import List, Optional


from canopy.knowledge_base.models import KBQueryResult
from canopy.knowledge_base.reranker import Reranker

try:
import cohere
from cohere import CohereAPIError
except (OSError, ImportError, ModuleNotFoundError):
_cohere_installed = False
else:
_cohere_installed = True


class CohereReranker(Reranker):
izellevy marked this conversation as resolved.
Show resolved Hide resolved
"""
Reranker that uses Cohere's text embedding to rerank documents.

For each query and documents returned for that query, returns a list
of documents ordered by their relevance to the provided query.
"""

def __init__(self,
model_name: str = 'rerank-english-v2.0',
*,
top_n: int = 10,
acatav marked this conversation as resolved.
Show resolved Hide resolved
api_key: Optional[str] = None):
"""
Initializes the Cohere reranker.

Args:
model_name: The identifier of the model to use, one of :
rerank-english-v2.0, rerank-multilingual-v2.0
top_n: The number of most relevant documents return, defaults to 10
api_key: API key for Cohere. If not passed `CO_API_KEY` environment
variable will be used.
"""

if not _cohere_installed:
raise ImportError(
"Failed to import cohere. Make sure you install cohere extra "
"dependencies by running: "
"pip install canopy-sdk[cohere]"
)
cohere_api_key = api_key or os.environ.get("CO_API_KEY")
if cohere_api_key is None:
raise RuntimeError(
"Cohere API key is required to use Cohere Reranker. "
"Please provide it as an argument "
"or set the CO_API_KEY environment variable."
)
self._client = cohere.Client(api_key=cohere_api_key)
self._model_name = model_name
self._top_n = top_n

def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
reranked_query_results: List[KBQueryResult] = []
for result in results:
texts = [doc.text for doc in result.documents]
try:
response = self._client.rerank(query=result.query,
documents=texts,
top_n=self._top_n,
model=self._model_name)
except CohereAPIError as e:
raise RuntimeError("Failed to rerank documents using Cohere."
f" Underlying Error:\n{e.message}")

reranked_docs = []
for rerank_result in response:
doc = result.documents[rerank_result.index].copy(
deep=True,
update=dict(score=rerank_result.relevance_score)
acatav marked this conversation as resolved.
Show resolved Hide resolved
)
reranked_docs.append(doc)

reranked_query_results.append(KBQueryResult(query=result.query,
documents=reranked_docs))
return reranked_query_results

async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
raise NotImplementedError()
21 changes: 0 additions & 21 deletions src/canopy/knowledge_base/reranker/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,3 @@ def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
@abstractmethod
async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
pass


class TransparentReranker(Reranker):
"""
Transparent reranker that does nothing, it just returns the results as is. This is the default reranker.
The TransparentReranker is used as a placeholder for future development "forcing" every result set to be reranked.
""" # noqa: E501
def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
"""
Returns the results as is.

Args:
results: A list of KBQueryResult to rerank.

Returns:
results: A list of KBQueryResult, same as the input.
""" # noqa: E501
return results

async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
return results
26 changes: 26 additions & 0 deletions src/canopy/knowledge_base/reranker/transparent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import List

from canopy.knowledge_base.models import KBQueryResult
from canopy.knowledge_base.reranker import Reranker


class TransparentReranker(Reranker):
"""
Transparent reranker that does nothing, it just returns the results as is. This is the default reranker.
The TransparentReranker is used as a placeholder for future development "forcing" every result set to be reranked.
""" # noqa: E501

def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
"""
Returns the results as is.

Args:
results: A list of KBQueryResult to rerank.

Returns:
results: A list of KBQueryResult, same as the input.
""" # noqa: E501
return results

async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
return results
3 changes: 3 additions & 0 deletions src/canopy/utils/debugging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"
Empty file.
71 changes: 71 additions & 0 deletions tests/system/reranker/test_cohere_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os

import pytest

from canopy.knowledge_base.models import KBQueryResult, KBDocChunkWithScore
from canopy.knowledge_base.reranker import CohereReranker


@pytest.fixture
def should_run_test():
if os.getenv("CO_API_KEY") is None:
pytest.skip(
"Couldn't find Cohere API key. Skipping Cohere tests."
)


@pytest.fixture
def cohere_reranker(should_run_test):
return CohereReranker()


@pytest.fixture
def documents():
return [
KBDocChunkWithScore(
id=f"doc_1_{i}",
text=f"Sample chunk {i}",
document_id="doc_1",
source="doc_1",
score=0.1 * i
) for i in range(4)
]


@pytest.fixture
def query_result(documents):
return KBQueryResult(query="Sample query 1",
documents=documents)


def test_rerank_empty(cohere_reranker):
results = cohere_reranker.rerank([])
assert results == []


def test_rerank(cohere_reranker, query_result, documents):
id_to_score = {d.id: d.score for d in query_result.documents}
ranked_result = next(iter(cohere_reranker.rerank([query_result])))
reranked_scores = [doc.score for doc in ranked_result.documents]

assert len(ranked_result.documents) == len(documents)
acatav marked this conversation as resolved.
Show resolved Hide resolved
assert reranked_scores == sorted(reranked_scores, reverse=True)

# Make sure the scores are overriden by the reranker
for doc in ranked_result.documents:
assert doc.score != id_to_score[doc.id]


def test_bad_api_key(should_run_test, query_result):
with pytest.raises(RuntimeError, match="invalid api token"):
CohereReranker(api_key="bad key").rerank([query_result])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing more negative tests - wrong model name, bad input (e.g. not strings) etc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added wrong model name, bad input is not possible since we validate our data with pydantic.


def test_model_name_invalid(should_run_test, query_result):
with pytest.raises(RuntimeError, match="model not found"):
CohereReranker(model_name="my-madeup-model").rerank([query_result])


def test_top_n(should_run_test, query_result):
results = CohereReranker(top_n=1).rerank([query_result])
assert len(results[0].documents) == 1
27 changes: 27 additions & 0 deletions tests/system/reranker/test_transparent_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from canopy.knowledge_base.models import KBDocChunkWithScore, KBQueryResult
from canopy.knowledge_base.reranker import TransparentReranker


@pytest.fixture
def documents():
return [
KBDocChunkWithScore(
id=f"doc_1_{i}",
text=f"Sample chunk {i}",
document_id="doc_1",
source="doc_1",
score=0.1 * i
) for i in range(1)
]


@pytest.fixture
def query_result(documents):
return KBQueryResult(query="Sample query 1",
documents=documents)


def test_rerank(query_result):
assert TransparentReranker().rerank([query_result]) == [query_result]
Loading