This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Add Cohere Reranker (#269)
* Add cohere reranker

* Add tests

* Fix static

* Fix comments

* Add assert

* Fix dict

* Change name
izellevy authored Jan 31, 2024
1 parent 4b8ee29 commit 95c7b24
Showing 12 changed files with 245 additions and 39 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -74,7 +74,8 @@ module = [
'pinecone_text.*',
'pinecone_datasets',
'pinecone',
'transformers.*'
'transformers.*',
'cohere.*',
]
ignore_missing_imports = true

7 changes: 2 additions & 5 deletions src/canopy/chat_engine/chat_engine.py
@@ -1,4 +1,3 @@
import os
from abc import ABC, abstractmethod
from typing import Iterable, Union, Optional, cast

@@ -13,9 +12,7 @@
StreamingChatResponse, )
from canopy.models.data_models import Context, Messages, SystemMessage
from canopy.utils.config import ConfigurableMixin

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"

from canopy.utils.debugging import CANOPY_DEBUG_INFO

DEFAULT_SYSTEM_PROMPT = """Use the following pieces of context to answer the user question in the next messages. This context was retrieved from a knowledge base, and you should use only the facts from the context to answer. Always remember to include the source of the documents you used, from their 'source' field, in the format 'Source: $SOURCE_HERE'.
If you don't know the answer, just say that you don't know; don't try to make up an answer, use the context.
@@ -223,7 +220,7 @@ def chat(self,
stream=stream,
model_params=model_params_dict)
debug_info = {}
if CE_DEBUG_INFO:
if CANOPY_DEBUG_INFO:
debug_info['context'] = context.dict()
debug_info['context'].update(context.debug_info)

10 changes: 5 additions & 5 deletions src/canopy/context_engine/context_engine.py
@@ -1,4 +1,3 @@
import os
from abc import ABC, abstractmethod
from typing import List, Optional

@@ -8,8 +7,7 @@
from canopy.knowledge_base.base import BaseKnowledgeBase
from canopy.models.data_models import Context, Query
from canopy.utils.config import ConfigurableMixin

CE_DEBUG_INFO = os.getenv("CE_DEBUG_INFO", "FALSE").lower() == "true"
from canopy.utils.debugging import CANOPY_DEBUG_INFO


class BaseContextEngine(ABC, ConfigurableMixin):
@@ -110,8 +108,10 @@ def query(self, queries: List[Query],
namespace=namespace)
context = self.context_builder.build(query_results, max_context_tokens)

if CE_DEBUG_INFO:
context.debug_info["query_results"] = [qr.dict() for qr in query_results]
if CANOPY_DEBUG_INFO:
context.debug_info["query_results"] = [
{**qr.dict(), **qr.debug_info} for qr in query_results
]
return context

async def aquery(self, queries: List[Query], max_context_tokens: int,
28 changes: 22 additions & 6 deletions src/canopy/knowledge_base/knowledge_base.py
@@ -7,6 +7,8 @@
from pinecone import (ServerlessSpec, PodSpec,
Pinecone, PineconeApiException)

from canopy.utils.debugging import CANOPY_DEBUG_INFO

try:
from pinecone import GRPCIndex as Index
except ImportError:
@@ -437,20 +439,34 @@ def query(self,
results = [self._query_index(q,
global_metadata_filter,
namespace) for q in queries]
results = self._reranker.rerank(results)
ranked_results = self._reranker.rerank(results)

assert len(results) == len(ranked_results), ("Reranker returned a different"
" number of results "
"than the number of queries")
return [
QueryResult(
query=r.query,
query=rr.query,
documents=[
DocumentWithScore(
**d.dict(exclude={
'values', 'sparse_values', 'document_id'
'document_id'
})
)
for d in r.documents
]
) for r in results
for d in rr.documents
],
debug_info={"db_result": QueryResult(
query=r.query,
documents=[
DocumentWithScore(
**d.dict(exclude={
'document_id'
})
)
for d in r.documents
]
).dict()} if CANOPY_DEBUG_INFO else {}
) for rr, r in zip(ranked_results, results)
]

def _query_index(self,
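
With this change, `KnowledgeBase.query()` returns the reranked documents and, when `CANOPY_DEBUG_INFO` is enabled, keeps the pre-rerank results under `debug_info["db_result"]`. A minimal inspection sketch, assuming an already-initialized and connected `KnowledgeBase` instance named `kb` (the query text is made up for illustration):

```python
from canopy.models.data_models import Query

# Assumes CANOPY_DEBUG_INFO=true was set before the canopy modules were imported,
# and that `kb` is a connected KnowledgeBase instance.
results = kb.query([Query(text="What does a reranker do?")])
for r in results:
    reranked_ids = [d.id for d in r.documents]          # post-rerank order
    db_result = r.debug_info.get("db_result", {})       # pre-rerank snapshot (plain dict)
    db_ids = [d["id"] for d in db_result.get("documents", [])]
    print("reranked order:", reranked_ids)
    print("db order:", db_ids)
```
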
4 changes: 3 additions & 1 deletion src/canopy/knowledge_base/reranker/__init__.py
@@ -1 +1,3 @@
from .reranker import TransparentReranker, Reranker
from .reranker import Reranker
from .transparent import TransparentReranker
from .cohere import CohereReranker
84 changes: 84 additions & 0 deletions src/canopy/knowledge_base/reranker/cohere.py
@@ -0,0 +1,84 @@
import os
from typing import List, Optional


from canopy.knowledge_base.models import KBQueryResult
from canopy.knowledge_base.reranker import Reranker

try:
import cohere
from cohere import CohereAPIError
except (OSError, ImportError, ModuleNotFoundError):
_cohere_installed = False
else:
_cohere_installed = True


class CohereReranker(Reranker):
"""
Reranker that uses Cohere's Rerank API to rerank documents.
For each query and the documents returned for that query, returns the
documents reordered by their relevance to that query.
"""

def __init__(self,
model_name: str = 'rerank-english-v2.0',
*,
top_n: int = 10,
api_key: Optional[str] = None):
"""
Initializes the Cohere reranker.
Args:
model_name: The identifier of the model to use, one of:
rerank-english-v2.0, rerank-multilingual-v2.0
top_n: The number of most relevant documents to return; defaults to 10
api_key: API key for Cohere. If not passed, the `CO_API_KEY` environment
variable will be used.
"""

if not _cohere_installed:
raise ImportError(
"Failed to import cohere. Make sure you install cohere extra "
"dependencies by running: "
"pip install canopy-sdk[cohere]"
)
cohere_api_key = api_key or os.environ.get("CO_API_KEY")
if cohere_api_key is None:
raise RuntimeError(
"Cohere API key is required to use Cohere Reranker. "
"Please provide it as an argument "
"or set the CO_API_KEY environment variable."
)
self._client = cohere.Client(api_key=cohere_api_key)
self._model_name = model_name
self._top_n = top_n

def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
reranked_query_results: List[KBQueryResult] = []
for result in results:
texts = [doc.text for doc in result.documents]
try:
response = self._client.rerank(query=result.query,
documents=texts,
top_n=self._top_n,
model=self._model_name)
except CohereAPIError as e:
raise RuntimeError("Failed to rerank documents using Cohere."
f" Underlying Error:\n{e.message}")

reranked_docs = []
for rerank_result in response:
doc = result.documents[rerank_result.index].copy(
deep=True,
update=dict(score=rerank_result.relevance_score)
)
reranked_docs.append(doc)

reranked_query_results.append(KBQueryResult(query=result.query,
documents=reranked_docs))
return reranked_query_results

async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
raise NotImplementedError()
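
A short usage sketch for the new reranker, mirroring the test fixtures added in this commit (assumes `CO_API_KEY` is set in the environment; the sample chunks are made up):

```python
from canopy.knowledge_base.models import KBDocChunkWithScore, KBQueryResult
from canopy.knowledge_base.reranker import CohereReranker

# Hypothetical sample chunks, shaped like the test fixtures in this commit.
chunks = [
    KBDocChunkWithScore(
        id=f"doc_1_{i}",
        text=f"Sample chunk {i}",
        document_id="doc_1",
        source="doc_1",
        score=0.1 * i,
    )
    for i in range(4)
]

# Requires CO_API_KEY in the environment (or pass api_key=... explicitly).
reranker = CohereReranker(model_name="rerank-english-v2.0", top_n=2)
reranked = reranker.rerank([KBQueryResult(query="Sample query 1", documents=chunks)])

# Scores are replaced by Cohere's relevance scores, ordered from most to least relevant.
for doc in reranked[0].documents:
    print(doc.id, doc.score)
```
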
21 changes: 0 additions & 21 deletions src/canopy/knowledge_base/reranker/reranker.py
@@ -19,24 +19,3 @@ def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
@abstractmethod
async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
pass


class TransparentReranker(Reranker):
"""
A transparent reranker that does nothing; it simply returns the results as is. This is the default reranker.
The TransparentReranker is used as a placeholder for future development, "forcing" every result set to pass through a reranking step.
""" # noqa: E501
def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
"""
Returns the results as is.
Args:
results: A list of KBQueryResult to rerank.
Returns:
results: A list of KBQueryResult, same as the input.
""" # noqa: E501
return results

async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
return results
26 changes: 26 additions & 0 deletions src/canopy/knowledge_base/reranker/transparent.py
@@ -0,0 +1,26 @@
from typing import List

from canopy.knowledge_base.models import KBQueryResult
from canopy.knowledge_base.reranker import Reranker


class TransparentReranker(Reranker):
"""
A transparent reranker that does nothing; it simply returns the results as is. This is the default reranker.
The TransparentReranker is used as a placeholder for future development, "forcing" every result set to pass through a reranking step.
""" # noqa: E501

def rerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
"""
Returns the results as is.
Args:
results: A list of KBQueryResult to rerank.
Returns:
results: A list of KBQueryResult, same as the input.
""" # noqa: E501
return results

async def arerank(self, results: List[KBQueryResult]) -> List[KBQueryResult]:
return results
3 changes: 3 additions & 0 deletions src/canopy/utils/debugging.py
@@ -0,0 +1,3 @@
import os

CANOPY_DEBUG_INFO = os.getenv("CANOPY_DEBUG_INFO", "FALSE").lower() == "true"
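
The flag is evaluated once at import time, so it has to be set before any canopy module is imported. A minimal sketch of enabling it from a fresh Python process:

```python
import os

# Must be set before importing canopy modules; the flag is read at import time.
os.environ["CANOPY_DEBUG_INFO"] = "true"

from canopy.utils.debugging import CANOPY_DEBUG_INFO

print(CANOPY_DEBUG_INFO)  # True
```
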
Empty file.
71 changes: 71 additions & 0 deletions tests/system/reranker/test_cohere_reranker.py
@@ -0,0 +1,71 @@
import os

import pytest

from canopy.knowledge_base.models import KBQueryResult, KBDocChunkWithScore
from canopy.knowledge_base.reranker import CohereReranker


@pytest.fixture
def should_run_test():
if os.getenv("CO_API_KEY") is None:
pytest.skip(
"Couldn't find Cohere API key. Skipping Cohere tests."
)


@pytest.fixture
def cohere_reranker(should_run_test):
return CohereReranker()


@pytest.fixture
def documents():
return [
KBDocChunkWithScore(
id=f"doc_1_{i}",
text=f"Sample chunk {i}",
document_id="doc_1",
source="doc_1",
score=0.1 * i
) for i in range(4)
]


@pytest.fixture
def query_result(documents):
return KBQueryResult(query="Sample query 1",
documents=documents)


def test_rerank_empty(cohere_reranker):
results = cohere_reranker.rerank([])
assert results == []


def test_rerank(cohere_reranker, query_result, documents):
id_to_score = {d.id: d.score for d in query_result.documents}
ranked_result = next(iter(cohere_reranker.rerank([query_result])))
reranked_scores = [doc.score for doc in ranked_result.documents]

assert len(ranked_result.documents) == len(documents)
assert reranked_scores == sorted(reranked_scores, reverse=True)

# Make sure the scores are overridden by the reranker
for doc in ranked_result.documents:
assert doc.score != id_to_score[doc.id]


def test_bad_api_key(should_run_test, query_result):
with pytest.raises(RuntimeError, match="invalid api token"):
CohereReranker(api_key="bad key").rerank([query_result])


def test_model_name_invalid(should_run_test, query_result):
with pytest.raises(RuntimeError, match="model not found"):
CohereReranker(model_name="my-madeup-model").rerank([query_result])


def test_top_n(should_run_test, query_result):
results = CohereReranker(top_n=1).rerank([query_result])
assert len(results[0].documents) == 1
27 changes: 27 additions & 0 deletions tests/system/reranker/test_transparent_reranker.py
@@ -0,0 +1,27 @@
import pytest

from canopy.knowledge_base.models import KBDocChunkWithScore, KBQueryResult
from canopy.knowledge_base.reranker import TransparentReranker


@pytest.fixture
def documents():
return [
KBDocChunkWithScore(
id=f"doc_1_{i}",
text=f"Sample chunk {i}",
document_id="doc_1",
source="doc_1",
score=0.1 * i
) for i in range(1)
]


@pytest.fixture
def query_result(documents):
return KBQueryResult(query="Sample query 1",
documents=documents)


def test_rerank(query_result):
assert TransparentReranker().rerank([query_result]) == [query_result]
