Skip to content

Commit

Permalink
Added preamble and minimap as a default tool
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaytonSmith committed Aug 8, 2024
1 parent 76950f1 commit bd68cb9
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 49 deletions.
2 changes: 2 additions & 0 deletions src/backend/chat/custom/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def chat(self, chat_request: LangchainChatRequest, **kwargs: Any) -> Any:
verbose=True,
)

raise NotImplementedError("Langchain is not yet implemented")

return self.agent_executor.stream(
{
"input": chat_request.message,
Expand Down
28 changes: 14 additions & 14 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ class ModelDeploymentName(StrEnum):
is_available=CohereDeployment.is_available(),
env_vars=COHERE_ENV_VARS,
),
ModelDeploymentName.SageMaker: Deployment(
name=ModelDeploymentName.SageMaker,
deployment_class=SageMakerDeployment,
models=SageMakerDeployment.list_models(),
is_available=SageMakerDeployment.is_available(),
env_vars=SAGE_MAKER_ENV_VARS,
),
ModelDeploymentName.Azure: Deployment(
name=ModelDeploymentName.Azure,
deployment_class=AzureDeployment,
models=AzureDeployment.list_models(),
is_available=AzureDeployment.is_available(),
env_vars=AZURE_ENV_VARS,
),
# ModelDeploymentName.SageMaker: Deployment(
# name=ModelDeploymentName.SageMaker,
# deployment_class=SageMakerDeployment,
# models=SageMakerDeployment.list_models(),
# is_available=SageMakerDeployment.is_available(),
# env_vars=SAGE_MAKER_ENV_VARS,
# ),
# ModelDeploymentName.Azure: Deployment(
# name=ModelDeploymentName.Azure,
# deployment_class=AzureDeployment,
# models=AzureDeployment.list_models(),
# is_available=AzureDeployment.is_available(),
# env_vars=AZURE_ENV_VARS,
# ),
ModelDeploymentName.Bedrock: Deployment(
name=ModelDeploymentName.Bedrock,
deployment_class=BedrockDeployment,
Expand Down
4 changes: 2 additions & 2 deletions src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ToolName(StrEnum):
implementation=LangChainMinimapRetriever,
parameter_definitions={
"query": {
"description": "Search API that takes a query or phrase. ",
"description": "Search API that takes a query or phrase. Results should be presented as an executive summary, grouped and summarized for the user with section headings and bullet points.",
"type": "str",
"required": True,
}
Expand All @@ -68,7 +68,7 @@ class ToolName(StrEnum):
is_available=LangChainMinimapRetriever.is_available(),
error_message="Minimap API not available.",
category=Category.DataLoader,
description="Fetches the most relevant news and content from Minimap.ai. Results should be presented as an executive summary, grouped and summarized for the user with section headings and bullet points.",
description="Fetches the most relevant news and content from Minimap.ai.",
),
# ToolName.Search_File: ManagedTool(
# name=ToolName.Search_File,
Expand Down
86 changes: 74 additions & 12 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import time
from typing import Any, Dict, Generator, List

import cohere
Expand All @@ -12,11 +13,57 @@
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.tools.minimap import LangChainMinimapRetriever

from cohere.types.tool import Tool

# Cohere tool definition for the Minimap.ai news search. It is appended to the
# tool list of streamed chat requests by CohereDeployment.invoke_chat_stream.
MINIMAP_TOOL = Tool(
    name="Minimap",
    description="Fetches the most relevant news and content from Minimap.ai.",
    parameter_definitions={
        "query": {
            # Adjacent literals are concatenated at compile time into one string.
            "description": (
                "Search API that takes a query or phrase. "
                "Results should be presented as an executive summary, "
                "grouped and summarized for the user with section headings "
                "and bullet points."
            ),
            "type": "str",
            "required": True,
        },
    },
)

# Name of the environment variable expected to hold the Cohere API key.
COHERE_API_KEY_ENV_VAR = "COHERE_API_KEY"
# Environment variables that must all be set for CohereDeployment.is_available()
# to report the deployment as available.
COHERE_ENV_VARS = [COHERE_API_KEY_ENV_VAR]


# Static catalogue of Cohere models used by CohereDeployment.list_models in
# place of a live call to the Cohere models endpoint.
MODELS = [
    {
        "name": "command-r",
        "endpoints": ["generate", "chat", "summarize"],
        "finetuned": False,
        "context_length": 128000,
        "tokenizer_url": "https://storage.googleapis.com/cohere-public/tokenizers/command-r.json",
        "default_endpoints": [],
    },
    {
        "name": "command-r-plus",
        "endpoints": ["generate", "chat", "summarize"],
        "finetuned": False,
        "context_length": 128000,
        "tokenizer_url": "https://storage.googleapis.com/cohere-public/tokenizers/command-r-plus.json",
        "default_endpoints": ["chat"],
    },
]

# System preamble that CohereDeployment.invoke_chat_stream sends with every
# streamed chat request.
# NOTE(review): this f-string is evaluated once, when the module is imported,
# so the embedded date is the import date and goes stale in a long-running
# process. Building the preamble per request would fix that but requires a
# coordinated change in the caller.
preamble = f"""
You are a news summarization assistant. You're equipped with Minimap.ai's news search tool.
You will be provided with a large number of news articles. Your task is to provide users with salient summaries of major trends in the news. Summaries should be presented as a mini news brief, with topic headings and contain bullet points with key information.
You can elaborate on or use a more precise query than what the user provided to get more specific results.
Always ask the user if there's a specific point or topic they want to drill down on.
Today's date is {time.strftime("%Y-%m-%d")}.
"""

class CohereDeployment(BaseDeployment):
"""Cohere Platform Deployment."""

Expand All @@ -36,19 +83,20 @@ def list_models(cls) -> List[str]:
if not CohereDeployment.is_available():
return []

url = "https://api.cohere.ai/v1/models"
headers = {
"accept": "application/json",
"authorization": f"Bearer {cls.api_key}",
}
# url = "https://api.cohere.ai/v1/models"
# headers = {
# "accept": "application/json",
# "authorization": f"Bearer {cls.api_key}",
# }

response = requests.get(url, headers=headers)
# response = requests.get(url, headers=headers)
# logging.info(response.json())
# if not response.ok:
# logging.warning("Couldn't get models from Cohere API.")
# return []

if not response.ok:
logging.warning("Couldn't get models from Cohere API.")
return []
models = MODELS

models = response.json()["models"]
return [
model["name"]
for model in models
Expand All @@ -60,6 +108,7 @@ def is_available(cls) -> bool:
return all([os.environ.get(var) is not None for var in COHERE_ENV_VARS])

def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
logging.info(f"Invoking chat with chat_request")
response = self.client.chat(
**chat_request.model_dump(exclude={"stream"}),
**kwargs,
Expand All @@ -69,10 +118,22 @@ def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
def invoke_chat_stream(
    self, chat_request: CohereChatRequest, **kwargs: Any
) -> Generator[StreamedChatResponse, None, None]:
    """Stream a Cohere chat response with the Minimap tool and preamble applied.

    The request is dumped to keyword arguments (without ``stream`` and
    ``file_ids``), merged with ``kwargs`` (which win on key collisions), and
    then adjusted before being passed to ``self.client.chat_stream``:

    * ``MINIMAP_TOOL`` is appended to whatever tools the request carries.
    * ``preamble`` is always set to the module-level news-summarization
      preamble, replacing any preamble present in the request or kwargs.

    Args:
        chat_request: The chat request to send.
        **kwargs: Extra keyword arguments forwarded to ``chat_stream``.

    Yields:
        Each streamed event, converted with ``to_dict``.
    """
    chat_params = {
        **chat_request.model_dump(exclude={"stream", "file_ids"}),
        **kwargs,
    }

    # Append "Minimap" to the tools list. ``or []`` (rather than a ``.get``
    # default) also covers the case where the key exists with value None,
    # which would otherwise raise TypeError on concatenation.
    # NOTE(review): if the request already contains a tool named "Minimap",
    # it will be listed twice -- confirm the API tolerates that.
    existing_tools = chat_params.get("tools") or []
    chat_params["tools"] = [*existing_tools, MINIMAP_TOOL]

    # Set the preamble
    chat_params["preamble"] = preamble

    stream = self.client.chat_stream(**chat_params)

    for event in stream:
        yield to_dict(event)

Expand Down Expand Up @@ -106,4 +167,5 @@ def invoke_tools(
chat_request: CohereChatRequest,
**kwargs: Any,
) -> Generator[StreamedChatResponse, None, None]:

yield from self.invoke_chat_stream(chat_request, **kwargs)
30 changes: 22 additions & 8 deletions src/backend/tools/minimap.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pandas as pd
from typing import Any, Dict, List

# from langchain.text_splitter import CharacterTextSplitter
Expand Down Expand Up @@ -35,6 +36,16 @@

logging = logging.getLogger(__name__)


def hash_string(s: str) -> str:
    """Hash a string using the djb2 algorithm.

    Starts from 5381 and, for every character, computes
    ``value * 33 + ord(char)``. The running value is a Python int and is not
    truncated to a fixed width, so it grows with the length of the input.

    Args:
        s: The string to hash.

    Returns:
        The decimal string representation of the hash; ``"5381"`` for an
        empty string.
    """
    # Named ``digest`` rather than ``hash`` to avoid shadowing the builtin.
    digest = 5381
    for char in s:
        # (digest << 5) + digest == digest * 33
        digest = ((digest << 5) + digest) + ord(char)
    return str(digest)

class MinimapAPIWrapper(BaseModel):
"""
Wrapper around Minimap.ai API.
Expand Down Expand Up @@ -64,7 +75,7 @@ class MinimapAPIWrapper(BaseModel):
max_retry: int = 5

# Default values for the parameters
top_k_results: int = 100
top_k_results: int = 50
MAX_QUERY_LENGTH: int = 300
doc_content_chars_max: int = 2000

Expand All @@ -90,16 +101,19 @@ def run(self, query: str) -> str:

results = response_json.get("results", [])


df = pd.DataFrame(results)

df['title_hash'] = df['title'].apply(hash_string)

# drop duplicates
df = df.drop_duplicates(subset=['title_hash'])

results = df.to_dict(orient='records')

# limit the number of results to top_k_results
results = results[:self.top_k_results]

# Rename id to doc_id
for result in results:
try:
result['doc_id'] = result.get('id')
except Exception as ex:
...

return results

except Exception as ex:
Expand Down
14 changes: 7 additions & 7 deletions src/community/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ class ModelDeploymentName(StrEnum):


AVAILABLE_MODEL_DEPLOYMENTS = {
ModelDeploymentName.HuggingFace: Deployment(
name=ModelDeploymentName.HuggingFace,
deployment_class=HuggingFaceDeployment,
models=HuggingFaceDeployment.list_models(),
is_available=HuggingFaceDeployment.is_available(),
env_vars=[],
),
# ModelDeploymentName.HuggingFace: Deployment(
# name=ModelDeploymentName.HuggingFace,
# deployment_class=HuggingFaceDeployment,
# models=HuggingFaceDeployment.list_models(),
# is_available=HuggingFaceDeployment.is_available(),
# env_vars=[],
# ),
# Add the below for local model deployments
# ModelDeploymentName.LocalModel: Deployment(
# name=ModelDeploymentName.LocalModel,
Expand Down
12 changes: 6 additions & 6 deletions src/interfaces/coral_web/src/components/Configuration/Tools.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ const ToolSection = () => {
updateEnabledTools(updatedTools);
};

// useEffect to enable all tools by default
React.useEffect(() => {
if (tools.length > 0 && enabledTools.length === 0) {
updateEnabledTools(tools);
}
}, []);
// // useEffect to enable all tools by default
// React.useEffect(() => {
// if (tools.length > 0 && enabledTools.length === 0) {
// updateEnabledTools(tools);
// }
// }, []);

return (
<section className="relative flex flex-col gap-y-5 px-5">
Expand Down

0 comments on commit bd68cb9

Please sign in to comment.