diff --git a/platform/reworkd_platform/services/tokenizer/dependencies.py b/platform/reworkd_platform/services/tokenizer/dependencies.py index 6045a14a96..c076cf2171 100644 --- a/platform/reworkd_platform/services/tokenizer/dependencies.py +++ b/platform/reworkd_platform/services/tokenizer/dependencies.py @@ -3,5 +3,5 @@ from reworkd_platform.services.tokenizer.service import TokenService -def get_tokenizer(request: Request) -> TokenService: +def get_token_service(request: Request) -> TokenService: return TokenService(request.app.state.token_encoding) diff --git a/platform/reworkd_platform/services/tokenizer/lifetime.py b/platform/reworkd_platform/services/tokenizer/lifetime.py index db817da149..01e207f562 100644 --- a/platform/reworkd_platform/services/tokenizer/lifetime.py +++ b/platform/reworkd_platform/services/tokenizer/lifetime.py @@ -4,7 +4,7 @@ ENCODING_NAME = "cl100k_base" # gpt-4, gpt-3.5-turbo, text-embedding-ada-002 -async def init_tokenizer(app: FastAPI) -> None: # pragma: no cover +def init_tokenizer(app: FastAPI) -> None: # pragma: no cover """ Initialize tokenizer. diff --git a/platform/reworkd_platform/services/tokenizer/service.py b/platform/reworkd_platform/services/tokenizer/service.py index 89d6f2931c..d654dcf971 100644 --- a/platform/reworkd_platform/services/tokenizer/service.py +++ b/platform/reworkd_platform/services/tokenizer/service.py @@ -11,12 +11,5 @@ def tokenize(self, text: str) -> list[int]: def detokenize(self, tokens: list[int]) -> str: return self.encoding.decode(tokens) - def token_count(self, text: str) -> int: + def count(self, text: str) -> int: return len(self.tokenize(text)) - - def get_context_space( - self, prompt: str, max_tokens: int, reserved_response_tokens: int - ) -> int: - prompt_tokens = self.tokenize(prompt) - - return max_tokens - len(prompt_tokens) - reserved_response_tokens diff --git a/platform/reworkd_platform/tests/agent/agent/create_model_test.py b/platform/reworkd_platform/tests/agent/agent/create_model_test.py index c210128fae..94639b551b 100644 --- a/platform/reworkd_platform/tests/agent/agent/create_model_test.py +++ b/platform/reworkd_platform/tests/agent/agent/create_model_test.py @@ -1,27 +1,35 @@ +import itertools + import pytest -from reworkd_platform.schemas import ModelSettings -from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import ( - create_model, -) +from reworkd_platform.schemas import ModelSettings, UserBase +from reworkd_platform.web.api.agent.model_settings import create_model @pytest.mark.parametrize( - "settings", - [ - ModelSettings( - customTemperature=0.222, - customModelName="gpt-4", - maxTokens=1234, - ), - ModelSettings(), - ], + "settings, streaming", + list( + itertools.product( + [ + ModelSettings( + customTemperature=0.222, + customModelName="gpt-4", + maxTokens=1234, + ), + ModelSettings(), + ], + [True, False], + ) + ), ) -def test_create_model( - settings: ModelSettings, -): - model = create_model(settings) +def test_create_model(settings: ModelSettings, streaming: bool): + model = create_model( + settings, + UserBase(id="", email="test@example.com"), + streaming=streaming, + ) assert model.temperature == settings.temperature - assert model.model_name == settings.model + assert model.model_name.startswith(settings.model) assert model.max_tokens == settings.max_tokens + assert model.streaming == streaming diff --git a/platform/reworkd_platform/tests/test_reworkd_platform.py b/platform/reworkd_platform/tests/test_reworkd_platform.py index 1afa9082f0..474131a8ca 100644 --- a/platform/reworkd_platform/tests/test_reworkd_platform.py +++ b/platform/reworkd_platform/tests/test_reworkd_platform.py @@ -4,6 +4,7 @@ from starlette import status +@pytest.mark.skip(reason="Mysql needs to be mocked") @pytest.mark.anyio async def test_health(client: AsyncClient, fastapi_app: FastAPI) -> None: """ diff --git a/platform/reworkd_platform/tests/test_token_service.py b/platform/reworkd_platform/tests/test_token_service.py index e423d9bf0f..56d61acd26 100644 --- a/platform/reworkd_platform/tests/test_token_service.py +++ b/platform/reworkd_platform/tests/test_token_service.py @@ -19,17 +19,8 @@ def test_nothing(): validate_tokenize_and_detokenize(service, text, 0) -def test_context_space(): - prompt = "You're a wizard, Harry. Write a book based on the context below:" - max_tokens = 800 - - service = TokenService(encoding) - get_context_space = service.get_context_space(prompt, max_tokens, 500) - assert 0 < get_context_space < 800 - 500 - - def validate_tokenize_and_detokenize(service, text, expected_token_count): tokens = service.tokenize(text) assert text == service.detokenize(tokens) - assert len(tokens) == service.token_count(text) + assert len(tokens) == service.count(text) assert len(tokens) == expected_token_count diff --git a/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py b/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py index 577934c39c..475140455d 100644 --- a/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py +++ b/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py @@ -3,6 +3,8 @@ from fastapi import Depends from reworkd_platform.schemas import AgentRun, UserBase +from reworkd_platform.services.tokenizer.dependencies import get_token_service +from reworkd_platform.services.tokenizer.service import TokenService from reworkd_platform.settings import settings from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService from reworkd_platform.web.api.agent.agent_service.mock_agent_service import ( @@ -27,11 +29,14 @@ def func( run: AgentRun = Depends(validator), user: UserBase = Depends(get_current_user), agent_memory: AgentMemory = Depends(get_agent_memory), + token_service: TokenService = Depends(get_token_service), ) -> AgentService: if settings.ff_mock_mode_enabled: return MockAgentService() model = create_model(run.model_settings, user, streaming=streaming) - return OpenAIAgentService(model, run.model_settings.language, agent_memory) + return OpenAIAgentService( + model, run.model_settings.language, agent_memory, token_service + ) return func diff --git a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py index b832d9244b..85b596e6eb 100644 --- a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py +++ b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py @@ -1,12 +1,12 @@ from typing import List, Optional from lanarky.responses import StreamingResponse -from langchain.chat_models.base import BaseChatModel from langchain.output_parsers import PydanticOutputParser from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate from loguru import logger from pydantic import ValidationError +from reworkd_platform.services.tokenizer.service import TokenService from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments from reworkd_platform.web.api.agent.helpers import ( @@ -14,6 +14,7 @@ openai_error_handler, parse_with_handling, ) +from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI from reworkd_platform.web.api.agent.prompts import ( analyze_task_prompt, create_tasks_prompt, @@ -34,15 +35,28 @@ class OpenAIAgentService(AgentService): def __init__( self, - model: BaseChatModel, + model: WrappedChatOpenAI, language: str, agent_memory: AgentMemory, + token_service: TokenService, ): self.model = model self.agent_memory = agent_memory self.language = language + self.token_service = token_service async def start_goal_agent(self, *, goal: str) -> List[str]: + prompt = ChatPromptTemplate.from_messages( + [SystemMessagePromptTemplate(prompt=start_goal_prompt)] + ) + + self.model.max_tokens -= self.token_service.count( + prompt.format_prompt( + goal=goal, + language=self.language, + ).to_string(), + ) + completion = await call_model_with_handling( self.model, ChatPromptTemplate.from_messages( @@ -63,14 +77,22 @@ async def start_goal_agent(self, *, goal: str) -> List[str]: async def analyze_task_agent( self, *, goal: str, task: str, tool_names: List[str] ) -> Analysis: + functions = list(map(get_tool_function, get_user_tools(tool_names))) + prompt = analyze_task_prompt.format_prompt( + goal=goal, + task=task, + language=self.language, + ) + + self.model.max_tokens -= self.token_service.count(prompt.to_string()) + + # TODO: We are over counting tokens here + self.model.max_tokens -= self.token_service.count(str(functions)) + message = await openai_error_handler( func=self.model.apredict_messages, - messages=analyze_task_prompt.format_prompt( - goal=goal, - task=task, - language=self.language, - ).to_messages(), - functions=list(map(get_tool_function, get_user_tools(tool_names))), + messages=prompt.to_messages(), + functions=functions, ) function_call = message.additional_kwargs.get("function_call", {}) @@ -93,7 +115,9 @@ async def execute_task_agent( task: str, analysis: Analysis, ) -> StreamingResponse: - print("Execution analysis:", analysis) + # TODO: More mature way of calculating max_tokens + if self.model.max_tokens > 3000: + self.model.max_tokens = max(self.model.max_tokens - 1000, 3000) tool_class = get_tool_from_name(analysis.action) return await tool_class(self.model, self.language).call( @@ -109,18 +133,26 @@ async def create_tasks_agent( result: str, completed_tasks: Optional[List[str]] = None, ) -> List[str]: + prompt = ChatPromptTemplate.from_messages( + [SystemMessagePromptTemplate(prompt=create_tasks_prompt)] + ) + + args = { + "goal": goal, + "language": self.language, + "tasks": "\n".join(tasks), + "lastTask": last_task, + "result": result, + } + + self.model.max_tokens -= self.token_service.count( + prompt.format_prompt(**args).to_string(), + ) + completion = await call_model_with_handling( self.model, - ChatPromptTemplate.from_messages( - [SystemMessagePromptTemplate(prompt=create_tasks_prompt)] - ), - { - "goal": goal, - "language": self.language, - "tasks": "\n".join(tasks), - "lastTask": last_task, - "result": result, - }, + prompt, + args, ) previous_tasks = (completed_tasks or []) + tasks diff --git a/platform/reworkd_platform/web/api/agent/model_settings.py b/platform/reworkd_platform/web/api/agent/model_settings.py index 3d103be301..72e60fedd6 100644 --- a/platform/reworkd_platform/web/api/agent/model_settings.py +++ b/platform/reworkd_platform/web/api/agent/model_settings.py @@ -8,9 +8,13 @@ openai.api_base = settings.openai_api_base +class WrappedChatOpenAI(ChatOpenAI): + max_tokens: int + + def create_model( model_settings: ModelSettings, user: UserBase, streaming: bool = False -) -> ChatOpenAI: +) -> WrappedChatOpenAI: if model_settings.custom_api_key != "": api_key = model_settings.custom_api_key else: @@ -20,7 +24,7 @@ def create_model( model=model_settings.model, ) - return ChatOpenAI( + return WrappedChatOpenAI( client=None, # Meta private value but mypy will complain its missing openai_api_key=api_key, temperature=model_settings.temperature, diff --git a/platform/reworkd_platform/web/lifetime.py b/platform/reworkd_platform/web/lifetime.py index 1e635d6ac8..bc7a2604fc 100644 --- a/platform/reworkd_platform/web/lifetime.py +++ b/platform/reworkd_platform/web/lifetime.py @@ -7,6 +7,7 @@ from reworkd_platform.db.models import load_all_models from reworkd_platform.db.utils import create_engine from reworkd_platform.services.pinecone.lifetime import init_pinecone +from reworkd_platform.services.tokenizer.lifetime import init_tokenizer from reworkd_platform.services.vecs.lifetime import ( init_supabase_vecs, shutdown_supabase_vecs, @@ -59,6 +60,7 @@ def register_startup_event( async def _startup() -> None: # noqa: WPS430 _setup_db(app) init_pinecone() + init_tokenizer(app) init_supabase_vecs( app ) # create pg_connection connection pool at startup as its expensive