diff --git a/src/backend/chat/custom/langchain.py b/src/backend/chat/custom/langchain.py index 2ee432f68b..a902fd2ae7 100644 --- a/src/backend/chat/custom/langchain.py +++ b/src/backend/chat/custom/langchain.py @@ -59,6 +59,8 @@ def chat(self, chat_request: LangchainChatRequest, **kwargs: Any) -> Any: verbose=True, ) + raise NotImplementedError("Langchain is not yet implemented") + return self.agent_executor.stream( { "input": chat_request.message, diff --git a/src/backend/config/deployments.py b/src/backend/config/deployments.py index a00fb9f385..3cc664cd95 100644 --- a/src/backend/config/deployments.py +++ b/src/backend/config/deployments.py @@ -34,20 +34,20 @@ class ModelDeploymentName(StrEnum): is_available=CohereDeployment.is_available(), env_vars=COHERE_ENV_VARS, ), - ModelDeploymentName.SageMaker: Deployment( - name=ModelDeploymentName.SageMaker, - deployment_class=SageMakerDeployment, - models=SageMakerDeployment.list_models(), - is_available=SageMakerDeployment.is_available(), - env_vars=SAGE_MAKER_ENV_VARS, - ), - ModelDeploymentName.Azure: Deployment( - name=ModelDeploymentName.Azure, - deployment_class=AzureDeployment, - models=AzureDeployment.list_models(), - is_available=AzureDeployment.is_available(), - env_vars=AZURE_ENV_VARS, - ), + # ModelDeploymentName.SageMaker: Deployment( + # name=ModelDeploymentName.SageMaker, + # deployment_class=SageMakerDeployment, + # models=SageMakerDeployment.list_models(), + # is_available=SageMakerDeployment.is_available(), + # env_vars=SAGE_MAKER_ENV_VARS, + # ), + # ModelDeploymentName.Azure: Deployment( + # name=ModelDeploymentName.Azure, + # deployment_class=AzureDeployment, + # models=AzureDeployment.list_models(), + # is_available=AzureDeployment.is_available(), + # env_vars=AZURE_ENV_VARS, + # ), ModelDeploymentName.Bedrock: Deployment( name=ModelDeploymentName.Bedrock, deployment_class=BedrockDeployment, diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 523c6a8d77..6ba561be59 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -59,7 +59,7 @@ class ToolName(StrEnum): implementation=LangChainMinimapRetriever, parameter_definitions={ "query": { - "description": "Search API that takes a query or phrase. ", + "description": "Search API that takes a query or phrase. Results should be presented as an executive summary, grouped and summarized for the user with section headings and bullet points.", "type": "str", "required": True, } @@ -68,7 +68,7 @@ class ToolName(StrEnum): is_available=LangChainMinimapRetriever.is_available(), error_message="Minimap API not available.", category=Category.DataLoader, - description="Fetches the most relevant news and content from Minimap.ai. Results should be presented as an executive summary, grouped and summarized for the user with section headings and bullet points.", + description="Fetches the most relevant news and content from Minimap.ai.", ), # ToolName.Search_File: ManagedTool( # name=ToolName.Search_File, diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index 24fdfc7748..4679293022 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -1,6 +1,7 @@ import json import logging import os +import time from typing import Any, Dict, Generator, List import cohere @@ -12,11 +13,57 @@ from backend.model_deployments.base import BaseDeployment from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest +from backend.tools.minimap import LangChainMinimapRetriever + +from cohere.types.tool import Tool + +MINIMAP_TOOL = Tool( + name="Minimap", + description="Fetches the most relevant news and content from Minimap.ai.", + parameter_definitions={ + "query": { + "description": "Search API that takes a query or phrase. Results should be presented as an executive summary, grouped and summarized for the user with section headings and bullet points.", + "type": "str", + "required": True, + } + }, +) COHERE_API_KEY_ENV_VAR = "COHERE_API_KEY" COHERE_ENV_VARS = [COHERE_API_KEY_ENV_VAR] +MODELS = [ + { + 'name': 'command-r', + 'endpoints': ['generate', 'chat', 'summarize'], + 'finetuned': False, + 'context_length': 128000, + 'tokenizer_url': 'https://storage.googleapis.com/cohere-public/tokenizers/command-r.json', + 'default_endpoints': [] + }, + { + 'name': 'command-r-plus', + 'endpoints': ['generate', 'chat', 'summarize'], + 'finetuned': False, + 'context_length': 128000, + 'tokenizer_url': 'https://storage.googleapis.com/cohere-public/tokenizers/command-r-plus.json', + 'default_endpoints': ['chat'] + }, +] + +preamble = f""" +You are a news summarization assistant. You're equipped with Minimap.ai's news search tool. + +You will be provided with large number of news articles. You're task is to provide users with salient summaries of major trends in the news. Summaries should be presented as a mini news brief, with topic headings and contain bullet points with key information. + +You can elaborate on or use a more precise query than what the user provided to get more specific results. + +Always ask the user if there's a specific point or topic they want to drill down on. + +Today's date is {time.strftime("%Y-%m-%d")}. +""" + class CohereDeployment(BaseDeployment): """Cohere Platform Deployment.""" @@ -36,19 +83,20 @@ def list_models(cls) -> List[str]: if not CohereDeployment.is_available(): return [] - url = "https://api.cohere.ai/v1/models" - headers = { - "accept": "application/json", - "authorization": f"Bearer {cls.api_key}", - } + # url = "https://api.cohere.ai/v1/models" + # headers = { + # "accept": "application/json", + # "authorization": f"Bearer {cls.api_key}", + # } - response = requests.get(url, headers=headers) + # response = requests.get(url, headers=headers) + # logging.info(response.json()) + # if not response.ok: + # logging.warning("Couldn't get models from Cohere API.") + # return [] - if not response.ok: - logging.warning("Couldn't get models from Cohere API.") - return [] + models = MODELS - models = response.json()["models"] return [ model["name"] for model in models @@ -60,6 +108,7 @@ def is_available(cls) -> bool: return all([os.environ.get(var) is not None for var in COHERE_ENV_VARS]) def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: + logging.info(f"Invoking chat with chat_request") response = self.client.chat( **chat_request.model_dump(exclude={"stream"}), **kwargs, @@ -69,10 +118,22 @@ def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: def invoke_chat_stream( self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: - stream = self.client.chat_stream( + + chat_params = { **chat_request.model_dump(exclude={"stream", "file_ids"}), - **kwargs, + **kwargs + } + + # Append "Minimap" to the tools list + chat_params["tools"] = chat_params.get("tools", []) + [MINIMAP_TOOL] + + # Set the preamble + chat_params["preamble"] = preamble + + stream = self.client.chat_stream( + **chat_params, ) + for event in stream: yield to_dict(event) @@ -106,4 +167,5 @@ def invoke_tools( chat_request: CohereChatRequest, **kwargs: Any, ) -> Generator[StreamedChatResponse, None, None]: + yield from self.invoke_chat_stream(chat_request, **kwargs) diff --git a/src/backend/tools/minimap.py b/src/backend/tools/minimap.py index fac41b42b9..6fc2d6d1f2 100644 --- a/src/backend/tools/minimap.py +++ b/src/backend/tools/minimap.py @@ -1,4 +1,5 @@ import os +import pandas as pd from typing import Any, Dict, List # from langchain.text_splitter import CharacterTextSplitter @@ -35,6 +36,16 @@ logging = logging.getLogger(__name__) + +def hash_string(s: str) -> str: + """ + Hash a string using the djb2 algorithm. + """ + hash = 5381 + for x in s: + hash = ((hash << 5) + hash) + ord(x) + return str(hash) + class MinimapAPIWrapper(BaseModel): """ Wrapper around Minimap.ai API. @@ -64,7 +75,7 @@ class MinimapAPIWrapper(BaseModel): max_retry: int = 5 # Default values for the parameters - top_k_results: int = 100 + top_k_results: int = 50 MAX_QUERY_LENGTH: int = 300 doc_content_chars_max: int = 2000 @@ -90,16 +101,19 @@ def run(self, query: str) -> str: results = response_json.get("results", []) + + df = pd.DataFrame(results) + + df['title_hash'] = df['title'].apply(hash_string) + + # drop duplicates + df = df.drop_duplicates(subset=['title_hash']) + + results = df.to_dict(orient='records') + # limit the number of results to top_k_results results = results[:self.top_k_results] - # Rename id to doc_id - for result in results: - try: - result['doc_id'] = result.get('id') - except Exception as ex: - ... - return results except Exception as ex: diff --git a/src/community/config/deployments.py b/src/community/config/deployments.py index 932d55a07f..ffe6d31d2b 100644 --- a/src/community/config/deployments.py +++ b/src/community/config/deployments.py @@ -13,13 +13,13 @@ class ModelDeploymentName(StrEnum): AVAILABLE_MODEL_DEPLOYMENTS = { - ModelDeploymentName.HuggingFace: Deployment( - name=ModelDeploymentName.HuggingFace, - deployment_class=HuggingFaceDeployment, - models=HuggingFaceDeployment.list_models(), - is_available=HuggingFaceDeployment.is_available(), - env_vars=[], - ), + # ModelDeploymentName.HuggingFace: Deployment( + # name=ModelDeploymentName.HuggingFace, + # deployment_class=HuggingFaceDeployment, + # models=HuggingFaceDeployment.list_models(), + # is_available=HuggingFaceDeployment.is_available(), + # env_vars=[], + # ), # Add the below for local model deployments # ModelDeploymentName.LocalModel: Deployment( # name=ModelDeploymentName.LocalModel, diff --git a/src/interfaces/coral_web/src/components/Configuration/Tools.tsx b/src/interfaces/coral_web/src/components/Configuration/Tools.tsx index 4485cae25b..01b01acf77 100644 --- a/src/interfaces/coral_web/src/components/Configuration/Tools.tsx +++ b/src/interfaces/coral_web/src/components/Configuration/Tools.tsx @@ -72,12 +72,12 @@ const ToolSection = () => { updateEnabledTools(updatedTools); }; - // useEffect to enable all tools by default - React.useEffect(() => { - if (tools.length > 0 && enabledTools.length === 0) { - updateEnabledTools(tools); - } - }, []); + // // useEffect to enable all tools by default + // React.useEffect(() => { + // if (tools.length > 0 && enabledTools.length === 0) { + // updateEnabledTools(tools); + // } + // }, []); return (